Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 125 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,124 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_property_setter_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: property setter stores and replays prior input."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

property_attrs: dict[str, set[str]] = defaultdict(set)
for cls in [stmt for stmt in tree.body if isinstance(stmt, ast.ClassDef)]:
props = {
item.name
for item in cls.body
if isinstance(item, ast.FunctionDef)
and any(
isinstance(dec, ast.Name) and dec.id == "property"
for dec in item.decorator_list
)
}
setters = {
item.name
for item in cls.body
if isinstance(item, ast.FunctionDef)
and any(
isinstance(dec, ast.Attribute)
and dec.attr == "setter"
and isinstance(dec.value, ast.Name)
and dec.value.id in props
for dec in item.decorator_list
)
}
if setters:
property_attrs[cls.name].update(setters)

instances: dict[str, str] = {}
for stmt in tree.body:
if not (
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id in property_attrs
):
continue
instances[stmt.targets[0].id] = stmt.value.func.id

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

def _property_key(expr: ast.AST | None) -> tuple[str, str] | None:
if not (
isinstance(expr, ast.Attribute)
and isinstance(expr.value, ast.Name)
and expr.value.id in instances
):
return None
class_name = instances[expr.value.id]
if expr.attr not in property_attrs[class_name]:
return None
return expr.value.id, expr.attr

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
mutated = {
key
for stmt in ast.walk(node)
if isinstance(stmt, ast.Assign)
and bool(_expr_names(stmt.value) & params)
and any(isinstance(expr, ast.Call) for expr in ast.walk(stmt.value))
for target in stmt.targets
for key in [_property_key(target)]
if key is not None
}
if not mutated:
continue

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & params:
continue
tested = {
key
for expr in ast.walk(child.test)
for key in [_property_key(expr)]
if key is not None
}
returned = {
key
for stmt in child.body
if isinstance(stmt, ast.Return)
for key in [_property_key(stmt.value)]
if key is not None
}
if tested & returned & mutated:
return [{
"pattern": "PROPERTY_SETTER_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns property state "
"populated from a prior input through its setter"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3597,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"PROPERTY_SETTER_REPLAY": RulePolicy(
"PROPERTY_SETTER_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3884,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_property_setter_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3923,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("property_setter_replay", detect_property_setter_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4820,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "PROPERTY_SETTER_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down