Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Skip LFS object download when uv clones ltx-core. The only LFS-tracked
# files in ltx-2-internal are test fixtures under packages/ltx-core/tests/assets/
# which are never imported by the package source, so smudging them is wasted
# bandwidth (and currently fails for unauthenticated LFS clients anyway).
GIT_LFS_SKIP_SMUDGE=1
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
# Clone and setup
git clone https://github.com/RightNow-AI/autokernel.git
cd autokernel

# Load project env (sets GIT_LFS_SKIP_SMUDGE=1 so uv can clone ltx-core
# without needing GitHub LFS auth — the LFS objects are test fixtures we
# don't import). Or use direnv / mise if you prefer.
set -a && . ./.env && set +a

uv sync

# One-time setup: test data + baselines
Expand Down
23 changes: 19 additions & 4 deletions bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,11 +1029,13 @@ def run_performance(kernel_fn: Callable, config: dict, gpu: GPUSpec,
if not bench_sizes:
bench_sizes = [sizes[-1]]

# Find the primary benchmark size (large or biggest)
# Find the primary benchmark size: "model_primary" (extract.py convention,
# the actual model shape) wins; otherwise "large" (upstream default);
# otherwise the last entry as a fallback.
primary_label = None
primary_size = None
for label, sz in sizes:
if label == "large":
if label in ("model_primary", "large"):
primary_label = label
primary_size = sz
break
Expand Down Expand Up @@ -1262,6 +1264,17 @@ def main():

config = KERNEL_CONFIGS[kernel_type]

# If kernel.py declares TEST_SIZES (extract.py emits these for kernels
# extracted from a real model profile), use them instead of bench.py's
# hardcoded defaults. Otherwise the kernel would be benched against a
# generic shape that has nothing to do with the user's model.
custom_sizes = getattr(kernel_module, "TEST_SIZES", None)
if custom_sizes:
config = dict(config) # shallow-copy so we don't mutate the global
config["test_sizes"] = custom_sizes
print(f" using model-specific TEST_SIZES from kernel.py "
f"({len(custom_sizes)} sizes; primary: '{custom_sizes[0][0]}')")

# ------------------------------------------------------------------
# GPU Detection
# ------------------------------------------------------------------
Expand Down Expand Up @@ -1306,7 +1319,7 @@ def main():
_perf_primary_label = None
_perf_primary_size = None
for _pl, _ps in _perf_sizes:
if _pl == "large":
if _pl in ("model_primary", "large"):
_perf_primary_label = _pl
_perf_primary_size = _ps
break
Expand All @@ -1321,7 +1334,9 @@ def main():
try:
sizes_filter = args.sizes
if args.quick:
sizes_filter = "large"
# In quick mode, bench only the primary size (the model shape if
# custom TEST_SIZES, else "large").
sizes_filter = _perf_primary_label or "large"
torch.cuda.reset_peak_memory_stats()
perf_results = run_performance(kernel_fn, config, gpu, sizes_filter=sizes_filter)
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
Expand Down
114 changes: 103 additions & 11 deletions extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
from __future__ import annotations

import argparse
import ast
import json
import operator
import os
import re
import sys
from functools import reduce
from typing import Any, Dict, List, Optional, Tuple


def _prod(xs):
return reduce(operator.mul, xs, 1)

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -173,22 +180,105 @@
# Shape parsing
# ---------------------------------------------------------------------------

def parse_shape_info(shape_info_str: str, op_type: str) -> Optional[Dict[str, int]]:
def _parse_bracket_shapes(
shape_info_str: str, op_type: str, name: str
) -> Optional[Dict[str, int]]:
"""Parse PyTorch's `record_shapes=True` bracket format -- e.g.
``"[[4096], [32640, 4096], [4096, 4096], [], []]"`` -- into a canonical
dim dict. Returns None on any parsing/dispatch failure so the caller
can fall through to the legacy key=value parser or default shapes.

Layout knowledge per ATen op (input-tensor order matters):
aten::mm(a, b) -> [[M, K], [K, N]]
aten::bmm(a, b) -> [[B, M, K], [B, K, N]]
aten::addmm(c, a, b) -> [[bias], [M, K], [K, N], ...]
aten::linear(x, w, b) -> [[..., K], [N, K], [N]] (weight is transposed!)
flash attention -> [[B, H, T, D], [B, H, T, D], [B, H, T, D], ...]
norms -> [[..., dim], [dim], [dim], ...]

`nvjet_*` / `cudnn_*` rows that survived dedup inherit their ATen
partner's shapes, so we use a name-prefix-then-shape-shape fallback.
"""
try:
parsed = ast.literal_eval(shape_info_str)
except (ValueError, SyntaxError):
return None
if not isinstance(parsed, (list, tuple)) or not parsed:
return None
tensors = [list(s) for s in parsed if isinstance(s, (list, tuple)) and s]
if not tensors:
return None

if op_type in ("matmul", "fused_mlp"):
if name.startswith("aten::linear") and len(tensors) >= 2 and len(tensors[1]) == 2:
x, w = tensors[0], tensors[1]
return {"M": _prod(x[:-1]), "K": x[-1], "N": w[0]}
if name.startswith("aten::addmm") and len(tensors) >= 3 and len(tensors[1]) == 2:
return {"M": tensors[1][0], "K": tensors[1][1], "N": tensors[2][1]}
if name.startswith("aten::bmm") and len(tensors) >= 2 and len(tensors[0]) == 3:
return {"M": tensors[0][1], "K": tensors[0][2], "N": tensors[1][2]}
if name.startswith("aten::mm") and len(tensors) >= 2 and len(tensors[0]) == 2:
return {"M": tensors[0][0], "K": tensors[0][1], "N": tensors[1][1]}
# nvjet_* / cudnn_* / generic kernel symbols: scan for a contraction-shaped
# 2-D tensor pair (A.shape[1] == B.shape[0]).
for i in range(len(tensors) - 1):
a, b = tensors[i], tensors[i + 1]
if len(a) == 2 and len(b) == 2 and a[1] == b[0]:
return {"M": a[0], "K": a[1], "N": b[1]}

if op_type in ("flash_attention", "rotary_embedding"):
q = tensors[0]
if len(q) == 4:
return {"batch": q[0], "heads": q[1], "seq_len": q[2], "head_dim": q[3]}
if len(q) == 5: # LTX RoPE has [B, H, T, 1, D/2]
return {"batch": q[0], "heads": q[1], "seq_len": q[2], "head_dim": q[4]}

if op_type in ("layernorm", "rmsnorm"):
x = tensors[0]
if len(x) >= 1:
return {"batch": _prod(x[:-1]), "dim": x[-1]}

if op_type == "softmax":
x = tensors[0]
if len(x) >= 1:
return {"rows": _prod(x[:-1]), "cols": x[-1]}

if op_type == "cross_entropy":
x = tensors[0]
if len(x) >= 2:
return {"batch": x[0], "vocab": x[1]}

if op_type == "reduce":
x = tensors[0]
if len(x) >= 2:
return {"M": _prod(x[:-1]), "N": x[-1]}

return None


def parse_shape_info(
shape_info_str: str, op_type: str, name: str = ""
) -> Optional[Dict[str, int]]:
"""
Parse a shape_info string like "M=4096, N=4096, K=4096" into a dict.
Parse a shape_info string into a canonical dim dict.

Handles various formats:
- "M=4096, N=4096, K=4096"
- "B=1, H=32, N=4096, D=128"
- "batch=4096, vocab=32000"
- "rows=4096, cols=4096"
Two input formats are supported:
1. PyTorch bracket format from torch.profiler -- e.g.
"[[4096], [32640, 4096], [4096, 4096], [], []]"
2. Hand-written / legacy key=value -- e.g. "M=4096, N=4096, K=4096"

Returns None if parsing fails.
Returns None on failure so callers fall through to default shapes.
"""
if not shape_info_str or not isinstance(shape_info_str, str):
return None

# Match key=value pairs
s = shape_info_str.lstrip()
if s.startswith("[") or s.startswith("("):
parsed = _parse_bracket_shapes(s, op_type, name)
if parsed:
return parsed

# Match key=value pairs (legacy / hand-written format)
pairs = re.findall(r"([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(\d+)", shape_info_str)
if not pairs:
return None
Expand Down Expand Up @@ -517,8 +607,10 @@ def extract_kernels(
gpu_time_ms = kernel_info.get("gpu_time_ms", kernel_info.get("total_gpu_time_ms", 0.0))
shape_info_str = kernel_info.get("shape_info", kernel_info.get("shape", ""))

# Parse model shape
model_shape = parse_shape_info(shape_info_str, op_type)
# Parse model shape (kernel name is needed to disambiguate matmul flavours)
model_shape = parse_shape_info(
shape_info_str, op_type, kernel_info.get("name", "")
)
if model_shape is None:
# Try to use a "shapes" dict directly if provided
if isinstance(kernel_info.get("shapes"), dict):
Expand Down
124 changes: 124 additions & 0 deletions models/ltx_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""LTX-2 transformer wrapper for AutoKernel profiling.

Wraps the 20B `LTXModel` (Lightricks LTX-2 diffusion transformer) loaded from a
Comfy-format `.safetensors` checkpoint. Exposes a `forward(self, x)` shim that
ignores `x` and feeds production-shape AV `Modality` inputs into the transformer.

Scope is the transformer only. NOT included:
- Video VAE encode/decode
- Audio VAE encode + vocoder
- Text encoder (Gemma) and the embeddings connector
- The denoising sampling loop
Random tensors at the dimensions `LTXModel` expects stand in for those upstream
components.

Input size: 1080p, 121 frames @ 24fps (~5s) — matches `PIPELINE_SIZE_1080P_121F`
in the upstream ltx-bench harness.
"""
from __future__ import annotations

import torch
from torch import nn

from ltx_core.components.patchifiers import (
AudioPatchifier,
VideoLatentPatchifier,
)
from ltx_core.guidance.perturbations import BatchedPerturbationConfig
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
from ltx_core.model.transformer import (
LTXV_MODEL_COMFY_RENAMING_MAP,
LTXModelConfigurator,
)
from ltx_core.model.transformer.modality import Modality
from ltx_core.tools import AudioLatentTools, VideoLatentTools
from ltx_core.types import (
VIDEO_SCALE_FACTORS,
AudioLatentShape,
VideoLatentShape,
VideoPixelShape,
)

CHECKPOINT_PATH = "/models/comfyui_models/checkpoints/ltx2.3-20b-20k-step-2358400-v4.safetensors"

HEIGHT = 1088
WIDTH = 1920
NUM_FRAMES = 121
FPS = 24.0

CONTEXT_TOKENS = 64
VIDEO_LATENT_CHANNELS = 128


def _timesteps_from_mask(denoise_mask: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
if sigma.dim() == 1:
sigma = sigma.view(-1, *([1] * (denoise_mask.dim() - 1)))
return denoise_mask * sigma


class LTXModelWrapper(nn.Module):
"""No-arg wrapper around the LTX-2 20B transformer for AutoKernel.

`__init__` loads real checkpoint weights and pre-builds the dummy AV inputs.
`forward(x)` ignores `x` and runs one transformer step.
"""

def __init__(self) -> None:
super().__init__()
device = torch.device("cuda")
dtype = torch.bfloat16

builder = SingleGPUModelBuilder(
model_path=CHECKPOINT_PATH,
model_class_configurator=LTXModelConfigurator,
model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
)
self.ltx = builder.build(device=device, dtype=dtype).eval()

pixel_shape = VideoPixelShape(
batch=1, frames=NUM_FRAMES, height=HEIGHT, width=WIDTH, fps=FPS,
)
v_shape = VideoLatentShape.from_pixel_shape(
pixel_shape,
latent_channels=VIDEO_LATENT_CHANNELS,
scale_factors=VIDEO_SCALE_FACTORS,
)
a_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)

v_tools = VideoLatentTools(VideoLatentPatchifier(patch_size=1), v_shape, FPS)
a_tools = AudioLatentTools(AudioPatchifier(patch_size=1), a_shape)

v_state = v_tools.create_initial_state(device=device, dtype=dtype)
a_state = a_tools.create_initial_state(device=device, dtype=dtype)

v_ctx_dim = self.ltx.transformer_blocks[0].attn2.to_k.in_features
a_ctx_dim = self.ltx.transformer_blocks[0].audio_attn2.to_k.in_features

sigma = torch.full((1,), 0.5, device=device, dtype=torch.float32)
v_context = torch.randn(1, CONTEXT_TOKENS, v_ctx_dim, device=device, dtype=dtype)
a_context = torch.randn(1, CONTEXT_TOKENS, a_ctx_dim, device=device, dtype=dtype)

self._video = Modality(
enabled=True,
latent=v_state.latent,
sigma=sigma,
timesteps=_timesteps_from_mask(v_state.denoise_mask, sigma),
positions=v_state.positions,
context=v_context,
attention_mask=v_state.attention_mask,
)
self._audio = Modality(
enabled=True,
latent=a_state.latent,
sigma=sigma,
timesteps=_timesteps_from_mask(a_state.denoise_mask, sigma),
positions=a_state.positions,
context=a_context,
attention_mask=a_state.attention_mask,
)
self._perturbations = BatchedPerturbationConfig.empty(1)

@torch.inference_mode()
def forward(self, x: torch.Tensor) -> torch.Tensor: # noqa: ARG002
vx, _ax = self.ltx(self._video, self._audio, self._perturbations)
return vx
Loading