Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 170 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,169 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_dict_view_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: replay through values/items views of a mutated dict."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

dict_names = {
stmt.targets[0].id
for stmt in tree.body
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and (
isinstance(stmt.value, ast.Dict)
or (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id == "dict"
)
)
}

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

def _input_derived(expr: ast.AST | None, params: set[str]) -> bool:
return bool(_expr_names(expr) & params)

def _mutated_dict(expr: ast.AST, params: set[str]) -> str | None:
if isinstance(expr, ast.Assign) and _input_derived(expr.value, params):
for target in expr.targets:
if isinstance(target, ast.Subscript):
root = _ast_root_name(target)
if root in dict_names:
return root
if not isinstance(expr, ast.Call):
return None
if not (
isinstance(expr.func, ast.Attribute)
and expr.func.attr in {"update", "setdefault", "__setitem__"}
):
return None
root = _ast_root_name(expr.func.value)
if root not in dict_names:
return None
if any(_input_derived(arg, params) for arg in expr.args):
return root
if any(_input_derived(keyword.value, params) for keyword in expr.keywords):
return root
return None

def _view_call(expr: ast.AST | None) -> tuple[str, str] | None:
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr in {"values", "items"}
):
return None
root = _ast_root_name(expr.func.value)
if root in dict_names:
return root, expr.func.attr
return None

def _return_view(expr: ast.AST | None, aliases: dict[str, tuple[str, str]]) -> str | None:
if isinstance(expr, ast.Subscript):
return _return_view(expr.value, aliases)
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "next"
and len(expr.args) == 1
):
return None
iter_call = expr.args[0]
if not (
isinstance(iter_call, ast.Call)
and isinstance(iter_call.func, ast.Name)
and iter_call.func.id == "iter"
and len(iter_call.args) == 1
):
return None
source = iter_call.args[0]
view = _view_call(source)
if view:
return view[0]
if isinstance(source, ast.Name) and source.id in aliases:
return aliases[source.id][0]
return None

def _body_has_only_view_calls(body: list[ast.stmt], aliases: dict[str, tuple[str, str]]) -> bool:
allowed_names = {"next", "iter"}
allowed_attrs = {"values", "items"}
for stmt in body:
for expr in ast.walk(stmt):
if not isinstance(expr, ast.Call):
continue
if isinstance(expr.func, ast.Name) and expr.func.id in allowed_names:
continue
if isinstance(expr.func, ast.Attribute) and expr.func.attr in allowed_attrs:
continue
if isinstance(expr.func, ast.Name) and expr.func.id in aliases:
continue
return False
return True

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
mutated = {
root
for expr in ast.walk(node)
for root in [_mutated_dict(expr, params)]
if root is not None
}
if not mutated:
continue

aliases = {
stmt.targets[0].id: view
for stmt in ast.walk(node)
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
for view in [_view_call(stmt.value)]
if view is not None
}

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _expr_names(child.test) & params:
continue
if not _body_has_only_view_calls(child.body, aliases):
continue
returned = {
root
for stmt in child.body
if isinstance(stmt, ast.Return)
for root in [_return_view(stmt.value, aliases)]
if root is not None
}
if returned & mutated:
return [{
"pattern": "DICT_VIEW_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns a values/items view "
"from a dict mutated with a prior input"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3642,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"DICT_VIEW_REPLAY": RulePolicy(
"DICT_VIEW_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3929,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_dict_view_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3968,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("dict_view_replay", detect_dict_view_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4865,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "DICT_VIEW_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down