Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions src/sandbox/anti_hack.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,28 @@
class HackDetector(ast.NodeVisitor):
"""AST visitor that detects hack patterns in generated code.

Two checks:
Three checks:
1. Hard blacklist: forbidden imports/attribute access (ctypes, vllm, etc.)
2. Torch API whitelist: any torch.X() call NOT in ALLOWED_TORCH_API -> hack
3. getattr(torch, "xxx") / import alias detection: catch obfuscation attempts
"""

def __init__(self, blacklist: List[str] = None):
self.violations: List[str] = []
self.blacklist = blacklist or []
# Track import aliases: {"tr": "torch", "ts": "torch.sum", ...}
self._aliases: dict = {}

# ---- Hard blacklist: imports ----
# ---- Hard blacklist + alias tracking: imports ----
def visit_Import(self, node: ast.Import):
for alias in node.names:
if self._is_blacklisted(alias.name):
self.violations.append(
f"Forbidden import: 'import {alias.name}' (line {node.lineno})"
)
# Track alias: import torch as tr -> "tr" -> "torch"
if alias.asname:
self._aliases[alias.asname] = alias.name
self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom):
Expand All @@ -96,12 +102,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
self.violations.append(
f"Forbidden import: 'from {node.module} import {names}' (line {node.lineno})"
)
# Track aliases: from torch import sum as ts -> "ts" -> "torch.sum"
if node.module:
for alias in node.names:
key = alias.asname or alias.name
self._aliases[key] = f"{node.module}.{alias.name}"
self.generic_visit(node)

# ---- Torch API whitelist: call detection ----
# ---- Torch API whitelist + getattr detection: call detection ----
def visit_Call(self, node: ast.Call):
# Get the full attribute chain of the call, e.g. "torch.sum"
call_chain = self._get_attr_chain(node.func)
call_chain = self._resolve_alias(call_chain) if call_chain else call_chain

if call_chain:
# Check hard blacklist first
Expand All @@ -115,6 +127,23 @@ def visit_Call(self, node: ast.Call):
f"Forbidden torch API: '{call_chain}()' not in allowed whitelist (line {node.lineno})"
)

# Detect getattr(torch, "sum") — dynamic attribute access
if (
isinstance(node.func, ast.Name) and node.func.id == "getattr"
and len(node.args) >= 2
and isinstance(node.args[1], ast.Constant)
and isinstance(node.args[1].value, str)
):
# Reconstruct: getattr(torch, "sum") -> "torch.sum"
base = self._get_attr_chain(node.args[0])
if base:
full = f"{base}.{node.args[1].value}"
full = self._resolve_alias(full)
if self._is_torch_api(full) and not self._is_allowed(full):
self.violations.append(
f"Forbidden torch API via getattr: 'getattr({base}, \"{node.args[1].value}\")' (line {node.lineno})"
)

# Detect __import__("vllm...")
if isinstance(node.func, ast.Name) and node.func.id == "__import__":
if node.args and isinstance(node.args[0], ast.Constant):
Expand Down Expand Up @@ -168,10 +197,19 @@ def _is_allowed(self, call_chain: str) -> bool:
for prefix in _TRITON_ALLOWED_PREFIXES:
if call_chain.startswith(prefix):
return True
# Allow torch.Tensor attributes accessed as calls (e.g. x.to(), x.clone())
# These are typically tensor method calls, not torch API calls
return False

def _resolve_alias(self, call_chain: str) -> str:
"""Resolve import aliases: 'tr.sum' -> 'torch.sum'."""
if not call_chain:
return call_chain
parts = call_chain.split(".")
first = parts[0]
if first in self._aliases:
parts[0] = self._aliases[first]
return ".".join(parts)
return call_chain

def _is_importlib_call(self, node: ast.Call) -> bool:
func = node.func
if isinstance(func, ast.Attribute) and func.attr == "import_module":
Expand All @@ -188,10 +226,7 @@ def _get_attr_chain(self, node: ast.AST) -> str:
if isinstance(node, ast.Name):
parts.append(node.id)
elif isinstance(node, ast.Call):
# chained call like foo().bar — get the inner call chain
inner = self._get_attr_chain(node.func)
if inner:
return None # too complex, skip
# chained call like foo().bar — skip
return None
else:
return None
Expand Down
Loading