diff --git a/src/sandbox/anti_hack.py b/src/sandbox/anti_hack.py index 7130ad2..fe697a6 100644 --- a/src/sandbox/anti_hack.py +++ b/src/sandbox/anti_hack.py @@ -72,22 +72,28 @@ class HackDetector(ast.NodeVisitor): """AST visitor that detects hack patterns in generated code. - Two checks: + Three checks: 1. Hard blacklist: forbidden imports/attribute access (ctypes, vllm, etc.) 2. Torch API whitelist: any torch.X() call NOT in ALLOWED_TORCH_API -> hack + 3. getattr(torch, "xxx") / import alias detection: catch obfuscation attempts """ def __init__(self, blacklist: List[str] = None): self.violations: List[str] = [] self.blacklist = blacklist or [] + # Track import aliases: {"tr": "torch", "ts": "torch.sum", ...} + self._aliases: dict = {} - # ---- Hard blacklist: imports ---- + # ---- Hard blacklist + alias tracking: imports ---- def visit_Import(self, node: ast.Import): for alias in node.names: if self._is_blacklisted(alias.name): self.violations.append( f"Forbidden import: 'import {alias.name}' (line {node.lineno})" ) + # Track alias: import torch as tr -> "tr" -> "torch" + if alias.asname: + self._aliases[alias.asname] = alias.name self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom): @@ -96,12 +102,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom): self.violations.append( f"Forbidden import: 'from {node.module} import {names}' (line {node.lineno})" ) + # Track aliases: from torch import sum as ts -> "ts" -> "torch.sum" + if node.module: + for alias in node.names: + key = alias.asname or alias.name + self._aliases[key] = f"{node.module}.{alias.name}" self.generic_visit(node) - # ---- Torch API whitelist: call detection ---- + # ---- Torch API whitelist + getattr detection: call detection ---- def visit_Call(self, node: ast.Call): # Get the full attribute chain of the call, e.g. "torch.sum" call_chain = self._get_attr_chain(node.func) + call_chain = self._resolve_alias(call_chain) if call_chain else call_chain if call_chain: # Check hard blacklist first @@ -115,6 +127,23 @@ def visit_Call(self, node: ast.Call): f"Forbidden torch API: '{call_chain}()' not in allowed whitelist (line {node.lineno})" ) + # Detect getattr(torch, "sum") — dynamic attribute access + if ( + isinstance(node.func, ast.Name) and node.func.id == "getattr" + and len(node.args) >= 2 + and isinstance(node.args[1], ast.Constant) + and isinstance(node.args[1].value, str) + ): + # Reconstruct: getattr(torch, "sum") -> "torch.sum" + base = self._get_attr_chain(node.args[0]) + if base: + full = f"{base}.{node.args[1].value}" + full = self._resolve_alias(full) + if self._is_torch_api(full) and not self._is_allowed(full): + self.violations.append( + f"Forbidden torch API via getattr: 'getattr({base}, \"{node.args[1].value}\")' (line {node.lineno})" + ) + # Detect __import__("vllm...") if isinstance(node.func, ast.Name) and node.func.id == "__import__": if node.args and isinstance(node.args[0], ast.Constant): @@ -168,10 +197,19 @@ def _is_allowed(self, call_chain: str) -> bool: for prefix in _TRITON_ALLOWED_PREFIXES: if call_chain.startswith(prefix): return True - # Allow torch.Tensor attributes accessed as calls (e.g. x.to(), x.clone()) - # These are typically tensor method calls, not torch API calls return False + def _resolve_alias(self, call_chain: str) -> str: + """Resolve import aliases: 'tr.sum' -> 'torch.sum'.""" + if not call_chain: + return call_chain + parts = call_chain.split(".") + first = parts[0] + if first in self._aliases: + parts[0] = self._aliases[first] + return ".".join(parts) + return call_chain + def _is_importlib_call(self, node: ast.Call) -> bool: func = node.func if isinstance(func, ast.Attribute) and func.attr == "import_module": @@ -188,10 +226,7 @@ def _get_attr_chain(self, node: ast.AST) -> str: if isinstance(node, ast.Name): parts.append(node.id) elif isinstance(node, ast.Call): - # chained call like foo().bar — get the inner call chain - inner = self._get_attr_chain(node.func) - if inner: - return None # too complex, skip + # chained call like foo().bar — skip return None else: return None