Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion bench_loop/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ def info() -> None:
default=None,
help="API key for the endpoint. Also reads OPENAI_API_KEY env var. Required for vLLM, sglang, and most cloud providers.",
)
@click.option(
"--max-tokens",
"max_tokens",
default=None,
type=int,
help=(
"Override max_tokens for every task (fixtures default to 2048). Thinking models "
"can burn through 2048 tokens of <think> before reaching an answer, get cut off "
"mid-thought, and have the raw reasoning dumped into content -- which then fails "
"coding's code-block extraction and dataextract's JSON parse. Raise this (e.g. 8192) "
"when benchmarking a reasoning model to rule that out before trusting a low score."
),
)
def run(
model: str,
endpoint: str,
Expand All @@ -132,6 +145,7 @@ def run(
command_used: str | None,
remote: bool,
api_key: str | None,
max_tokens: int | None,
) -> None:
"""Run a benchmark."""
# Set API key from CLI flag if provided (takes precedence over env var)
Expand Down Expand Up @@ -244,6 +258,7 @@ def on_progress(event: dict) -> None:
suites=selected_suites,
harness=harness,
remote=remote,
max_tokens=max_tokens,
on_progress=on_progress,
)
)
Expand Down Expand Up @@ -288,13 +303,18 @@ def on_progress(event: dict) -> None:
elif "timeout" in msg.lower() or type_name == "ReadTimeout":
click.echo("Timeout. Try a smaller model, fewer suites, or check network stability.", err=True)
raise SystemExit(1)
print_run_report(benchmark)
# Save before printing -- a console-rendering crash (e.g. legacy Windows
# cp1252 terminals choking on an emoji) must not cost the whole run's data.
save_run(
benchmark,
endpoint=endpoint,
publish_profile=publish_profile,
command_used=command_used,
)
try:
print_run_report(benchmark)
except Exception as exc: # noqa: BLE001
click.echo(f"(report rendering failed: {exc}; run was already saved)", err=True)


@main.command()
Expand Down
8 changes: 8 additions & 0 deletions bench_loop/runner/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def run_benchmark(
runs: int | None = None, # accepted but currently unused (single-run)
timeout_sec: float | None = None, # accepted but unused
remote: bool = False, # mark as remote/cloud benchmark
max_tokens: int | None = None, # override every task's fixture max_tokens
) -> BenchmarkRun:
# API back-compat: allow `run_benchmark(config)` where config has
# the same attributes (model/endpoint/provider/suite_names/harness/...).
Expand All @@ -51,6 +52,7 @@ async def run_benchmark(
cfg_suites = getattr(cfg, "suite_names", None) or getattr(cfg, "suites", None)
suites = cfg_suites or suites
harness = getattr(cfg, "harness", None) or harness
max_tokens = getattr(cfg, "max_tokens", None) or max_tokens
elif suite_names and not suites:
suites = suite_names

Expand Down Expand Up @@ -151,6 +153,7 @@ async def run_benchmark(
task_results: list[TaskResult] = []
for task in tasks:
if suite_name == "speed":
# Speed fixtures use deliberately tiny caps to measure raw throughput — never override.
result = await _run_speed_task(
provider_module,
endpoint,
Expand All @@ -160,6 +163,7 @@ async def run_benchmark(
harness=harness_adapter,
provider_name=provider,
remote=remote,
max_tokens_override=None,
)
else:
result = await suite.run_task(
Expand All @@ -169,6 +173,7 @@ async def run_benchmark(
task,
harness=harness_adapter,
provider_name=provider,
max_tokens_override=max_tokens,
)
task_results.append(result)
speed_meta = result.metadata.get("speed_metrics") if isinstance(result.metadata, dict) else None
Expand Down Expand Up @@ -252,13 +257,16 @@ async def _run_speed_task(
harness: Any | None = None,
provider_name: str = "ollama",
remote: bool = False,
max_tokens_override: int | None = None,
) -> TaskResult:
trial_results: list[TaskResult] = []
request = (
harness.prepare(task, provider_name=provider_name)
if harness is not None
else {"messages": task.messages, **task.config}
)
if max_tokens_override is not None:
request["max_tokens"] = max_tokens_override

# Use streaming for remote/cloud to get real TTFT + tok/s
use_streaming = remote and hasattr(provider_module, "chat_streaming")
Expand Down
6 changes: 4 additions & 2 deletions bench_loop/suites/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,13 @@ async def run_task(
task: BenchmarkTask,
harness: Any | None = None,
provider_name: str = "ollama",
max_tokens_override: int | None = None,
) -> TaskResult:
"""Run a multi-turn agent conversation against the model, executing tools
between turns. This OVERRIDES the default single-shot run_task and is
the heart of the agent suite.
"""
per_turn_max_tokens = max_tokens_override or 512
validation = task.validation or {}
max_turns = int(validation.get("max_turns", self.DEFAULT_MAX_TURNS))
allowed_tools = list(validation.get("tools", list(TOOL_SCHEMAS)))
Expand Down Expand Up @@ -362,14 +364,14 @@ async def run_task(
id=task.id,
suite=self.name,
messages=messages,
config={**task.config, "tools": tool_schemas, "max_tokens": 512},
config={**task.config, "tools": tool_schemas, "max_tokens": per_turn_max_tokens},
validation=validation,
metadata=task.metadata,
)
request = (
harness.prepare(synthetic_task, provider_name=provider_name)
if harness is not None
else {"messages": messages, "tools": tool_schemas, "max_tokens": 512}
else {"messages": messages, "tools": tool_schemas, "max_tokens": per_turn_max_tokens}
)

response = await provider_module.chat(
Expand Down
3 changes: 3 additions & 0 deletions bench_loop/suites/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ async def run_task(
task: BenchmarkTask,
harness: Any | None = None,
provider_name: str = "ollama",
max_tokens_override: int | None = None,
) -> TaskResult:
request = (
harness.prepare(task, provider_name=provider_name)
if harness is not None
else {"messages": task.messages, **task.config}
)
if max_tokens_override is not None:
request["max_tokens"] = max_tokens_override
response = await provider_module.chat(
endpoint=endpoint,
model=model,
Expand Down
91 changes: 85 additions & 6 deletions bench_loop/suites/dataextract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Any

Expand All @@ -18,6 +19,83 @@
"DE-13.discounts": "description",
}

# Matches a triple-backtick fenced block with an optional "json" language tag
# and captures its body (group 1) up to the next triple backtick. The body is
# matched lazily so each fence is captured separately; DOTALL lets the body
# span newlines and IGNORECASE makes the tag case-insensitive ("JSON", "Json").
_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def _find_matching_bracket(text: str, start: int) -> int | None:
"""Return the index of the bracket that closes `text[start]`, string-aware."""
depth = 0
in_string = False
escape = False
for index in range(start, len(text)):
ch = text[index]
if in_string:
if escape:
escape = False
elif ch == "\\":
escape = True
elif ch == '"':
in_string = False
continue
if ch == '"':
in_string = True
elif ch in "{[":
depth += 1
elif ch in "}]":
depth -= 1
if depth == 0:
return index
return None


def extract_json(text: str) -> tuple[Any, str]:
    """Best-effort JSON extraction from a model response.

    Returns (parsed_value, method). method == "none" means nothing parseable
    was found -- callers should treat that (not a parsed `None`) as failure,
    since these fixtures never expect a bare top-level `null`.

    A strict `json.loads` on the whole response fails the moment a thinking
    model's response has so much as a stray "Sure, here's the JSON:" in
    front of it (or unclosed reasoning dumped in ahead of the real answer --
    see openai_compat.py's reasoning-to-content fallback). This recovers the
    JSON object/array from inside that surrounding text instead of scoring 0
    outright, the same way coding.py already tolerates prose around a code
    fence.

    Strategies are tried in order and the first success wins:

    * "direct" -- the whole (stripped) response parses as JSON.
    * "fenced" -- the body of a triple-backtick fence parses as JSON. Every
      fence is tried in order, so a non-JSON fence earlier in the response
      (e.g. a code snippet inside the reasoning) does not hide a later JSON
      fence.
    * "bracket_scan" -- the first balanced `{...}` / `[...]` span that parses
      as JSON, scanning left to right.

    Args:
        text: Raw response text; may be empty or None-like.

    Returns:
        A `(value, method)` tuple where method is one of "direct", "fenced",
        "bracket_scan" or "none". value is None when method is "none".
    """
    if not text:
        return None, "none"
    stripped = text.strip()
    if not stripped:
        return None, "none"

    # json.loads signals malformed text with JSONDecodeError (a ValueError) and
    # pathologically deep nesting with RecursionError; either one just means
    # "this candidate is not usable, fall through to the next strategy".
    parse_errors = (ValueError, RecursionError)

    try:
        return json.loads(stripped), "direct"
    except parse_errors:
        pass

    for fence_match in _JSON_FENCE_RE.finditer(stripped):
        candidate = fence_match.group(1).strip()
        try:
            return json.loads(candidate), "fenced"
        except parse_errors:
            continue

    for start, ch in enumerate(stripped):
        if ch not in "{[":
            continue
        end = _find_matching_bracket(stripped, start)
        if end is None:
            continue
        candidate = stripped[start : end + 1]
        try:
            return json.loads(candidate), "bracket_scan"
        except parse_errors:
            continue

    return None, "none"


class DataExtractSuite(BenchmarkSuite):
name = "dataextract"
Expand Down Expand Up @@ -172,23 +250,23 @@ def evaluate(self, task: BenchmarkTask, response: dict[str, Any]) -> TaskResult:
response_text = self.response_text(response)
expected = task.validation.get("expected")
scenario_id = str(task.validation.get("scenario_id") or task.id.upper())
try:
parsed = json.loads(response_text)
except Exception as exc:
parsed, extraction_method = extract_json(response_text)
if extraction_method == "none":
return self.build_result(
task=task,
passed=False,
score=0.0,
response=response,
output=response_text,
error=f"Invalid JSON: {exc}",
error="Invalid JSON: no parseable JSON object/array found in response",
metadata={
"scenario_id": scenario_id,
"evaluation_status": "invalid_json",
"summary": f"Invalid JSON: {exc}",
"note": "Official score is 0 when the response is not valid JSON.",
"summary": "Invalid JSON: no parseable JSON object/array found in response",
"note": "Official score is 0 when no JSON could be extracted from the response.",
"category": task.validation.get("category"),
"title": task.validation.get("title"),
"json_extraction_method": extraction_method,
},
)
exact_shape, fields_only, no_missing, compliance_notes = self._evaluate_compliance(expected, parsed)
Expand Down Expand Up @@ -216,6 +294,7 @@ def evaluate(self, task: BenchmarkTask, response: dict[str, Any]) -> TaskResult:
"note": note,
"category": task.validation.get("category"),
"title": task.validation.get("title"),
"json_extraction_method": extraction_method,
},
)

Expand Down