diff --git a/bench_loop/cli.py b/bench_loop/cli.py index 0788c58..589d362 100644 --- a/bench_loop/cli.py +++ b/bench_loop/cli.py @@ -117,6 +117,19 @@ def info() -> None: default=None, help="API key for the endpoint. Also reads OPENAI_API_KEY env var. Required for vLLM, sglang, and most cloud providers.", ) +@click.option( + "--max-tokens", + "max_tokens", + default=None, + type=int, + help=( + "Override max_tokens for every task (fixtures default to 2048). Thinking models " + "can burn through 2048 tokens of before reaching an answer, get cut off " + "mid-thought, and have the raw reasoning dumped into content -- which then fails " + "coding's code-block extraction and dataextract's JSON parse. Raise this (e.g. 8192) " + "when benchmarking a reasoning model to rule that out before trusting a low score." + ), +) def run( model: str, endpoint: str, @@ -132,6 +145,7 @@ def run( command_used: str | None, remote: bool, api_key: str | None, + max_tokens: int | None, ) -> None: """Run a benchmark.""" # Set API key from CLI flag if provided (takes precedence over env var) @@ -244,6 +258,7 @@ def on_progress(event: dict) -> None: suites=selected_suites, harness=harness, remote=remote, + max_tokens=max_tokens, on_progress=on_progress, ) ) @@ -288,13 +303,18 @@ def on_progress(event: dict) -> None: elif "timeout" in msg.lower() or type_name == "ReadTimeout": click.echo("Timeout. Try a smaller model, fewer suites, or check network stability.", err=True) raise SystemExit(1) - print_run_report(benchmark) + # Save before printing -- a console-rendering crash (e.g. legacy Windows + # cp1252 terminals choking on an emoji) must not cost the whole run's data. save_run( benchmark, endpoint=endpoint, publish_profile=publish_profile, command_used=command_used, ) + try: + print_run_report(benchmark) + except Exception as exc: # noqa: BLE001 + click.echo(f"(report rendering failed: {exc}; run was already saved)", err=True) @main.command() diff --git a/bench_loop/runner/orchestrator.py b/bench_loop/runner/orchestrator.py index d70bd33..cb9f1b6 100644 --- a/bench_loop/runner/orchestrator.py +++ b/bench_loop/runner/orchestrator.py @@ -40,6 +40,7 @@ async def run_benchmark( runs: int | None = None, # accepted but currently unused (single-run) timeout_sec: float | None = None, # accepted but unused remote: bool = False, # mark as remote/cloud benchmark + max_tokens: int | None = None, # override every task's fixture max_tokens ) -> BenchmarkRun: # API back-compat: allow `run_benchmark(config)` where config has # the same attributes (model/endpoint/provider/suite_names/harness/...). @@ -51,6 +52,7 @@ async def run_benchmark( cfg_suites = getattr(cfg, "suite_names", None) or getattr(cfg, "suites", None) suites = cfg_suites or suites harness = getattr(cfg, "harness", None) or harness + max_tokens = getattr(cfg, "max_tokens", None) or max_tokens elif suite_names and not suites: suites = suite_names @@ -151,6 +153,7 @@ async def run_benchmark( task_results: list[TaskResult] = [] for task in tasks: if suite_name == "speed": + # Speed fixtures use deliberately tiny caps to measure raw throughput — never override. result = await _run_speed_task( provider_module, endpoint, @@ -160,6 +163,7 @@ async def run_benchmark( harness=harness_adapter, provider_name=provider, remote=remote, + max_tokens_override=None, ) else: result = await suite.run_task( @@ -169,6 +173,7 @@ async def run_benchmark( task, harness=harness_adapter, provider_name=provider, + max_tokens_override=max_tokens, ) task_results.append(result) speed_meta = result.metadata.get("speed_metrics") if isinstance(result.metadata, dict) else None @@ -252,6 +257,7 @@ async def _run_speed_task( harness: Any | None = None, provider_name: str = "ollama", remote: bool = False, + max_tokens_override: int | None = None, ) -> TaskResult: trial_results: list[TaskResult] = [] request = ( @@ -259,6 +265,8 @@ async def _run_speed_task( if harness is not None else {"messages": task.messages, **task.config} ) + if max_tokens_override is not None: + request["max_tokens"] = max_tokens_override # Use streaming for remote/cloud to get real TTFT + tok/s use_streaming = remote and hasattr(provider_module, "chat_streaming") diff --git a/bench_loop/suites/agent.py b/bench_loop/suites/agent.py index 2be2702..87d026d 100644 --- a/bench_loop/suites/agent.py +++ b/bench_loop/suites/agent.py @@ -328,11 +328,13 @@ async def run_task( task: BenchmarkTask, harness: Any | None = None, provider_name: str = "ollama", + max_tokens_override: int | None = None, ) -> TaskResult: """Run a multi-turn agent conversation against the model, executing tools between turns. This OVERRIDES the default single-shot run_task and is the heart of the agent suite. """ + per_turn_max_tokens = max_tokens_override or 512 validation = task.validation or {} max_turns = int(validation.get("max_turns", self.DEFAULT_MAX_TURNS)) allowed_tools = list(validation.get("tools", list(TOOL_SCHEMAS))) @@ -362,14 +364,14 @@ async def run_task( id=task.id, suite=self.name, messages=messages, - config={**task.config, "tools": tool_schemas, "max_tokens": 512}, + config={**task.config, "tools": tool_schemas, "max_tokens": per_turn_max_tokens}, validation=validation, metadata=task.metadata, ) request = ( harness.prepare(synthetic_task, provider_name=provider_name) if harness is not None - else {"messages": messages, "tools": tool_schemas, "max_tokens": 512} + else {"messages": messages, "tools": tool_schemas, "max_tokens": per_turn_max_tokens} ) response = await provider_module.chat( diff --git a/bench_loop/suites/base.py b/bench_loop/suites/base.py index 30af55f..065e613 100644 --- a/bench_loop/suites/base.py +++ b/bench_loop/suites/base.py @@ -75,12 +75,15 @@ async def run_task( task: BenchmarkTask, harness: Any | None = None, provider_name: str = "ollama", + max_tokens_override: int | None = None, ) -> TaskResult: request = ( harness.prepare(task, provider_name=provider_name) if harness is not None else {"messages": task.messages, **task.config} ) + if max_tokens_override is not None: + request["max_tokens"] = max_tokens_override response = await provider_module.chat( endpoint=endpoint, model=model, diff --git a/bench_loop/suites/dataextract.py b/bench_loop/suites/dataextract.py index a0a53b2..6a612a4 100644 --- a/bench_loop/suites/dataextract.py +++ b/bench_loop/suites/dataextract.py @@ -2,6 +2,7 @@ from __future__ import annotations import json +import re from pathlib import Path from typing import Any @@ -18,6 +19,83 @@ "DE-13.discounts": "description", } +_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL | re.IGNORECASE) + + +def _find_matching_bracket(text: str, start: int) -> int | None: + """Return the index of the bracket that closes `text[start]`, string-aware.""" + depth = 0 + in_string = False + escape = False + for index in range(start, len(text)): + ch = text[index] + if in_string: + if escape: + escape = False + elif ch == "\\": + escape = True + elif ch == '"': + in_string = False + continue + if ch == '"': + in_string = True + elif ch in "{[": + depth += 1 + elif ch in "}]": + depth -= 1 + if depth == 0: + return index + return None + + +def extract_json(text: str) -> tuple[Any, str]: + """Best-effort JSON extraction from a model response. + + Returns (parsed_value, method). method == "none" means nothing parseable + was found -- callers should treat that (not a parsed `None`) as failure, + since these fixtures never expect a bare top-level `null`. + + A strict `json.loads` on the whole response fails the moment a thinking + model's response has so much as a stray "Sure, here's the JSON:" in + front of it (or unclosed reasoning dumped in ahead of the real answer -- + see openai_compat.py's reasoning-to-content fallback). This recovers the + JSON object/array from inside that surrounding text instead of scoring 0 + outright, the same way coding.py already tolerates prose around a code + fence. + """ + if not text: + return None, "none" + stripped = text.strip() + if not stripped: + return None, "none" + + try: + return json.loads(stripped), "direct" + except Exception: + pass + + fence_match = _JSON_FENCE_RE.search(stripped) + if fence_match: + candidate = fence_match.group(1).strip() + try: + return json.loads(candidate), "fenced" + except Exception: + pass + + for start, ch in enumerate(stripped): + if ch not in "{[": + continue + end = _find_matching_bracket(stripped, start) + if end is None: + continue + candidate = stripped[start : end + 1] + try: + return json.loads(candidate), "bracket_scan" + except Exception: + continue + + return None, "none" + class DataExtractSuite(BenchmarkSuite): name = "dataextract" @@ -172,23 +250,23 @@ def evaluate(self, task: BenchmarkTask, response: dict[str, Any]) -> TaskResult: response_text = self.response_text(response) expected = task.validation.get("expected") scenario_id = str(task.validation.get("scenario_id") or task.id.upper()) - try: - parsed = json.loads(response_text) - except Exception as exc: + parsed, extraction_method = extract_json(response_text) + if extraction_method == "none": return self.build_result( task=task, passed=False, score=0.0, response=response, output=response_text, - error=f"Invalid JSON: {exc}", + error="Invalid JSON: no parseable JSON object/array found in response", metadata={ "scenario_id": scenario_id, "evaluation_status": "invalid_json", - "summary": f"Invalid JSON: {exc}", - "note": "Official score is 0 when the response is not valid JSON.", + "summary": "Invalid JSON: no parseable JSON object/array found in response", + "note": "Official score is 0 when no JSON could be extracted from the response.", "category": task.validation.get("category"), "title": task.validation.get("title"), + "json_extraction_method": extraction_method, }, ) exact_shape, fields_only, no_missing, compliance_notes = self._evaluate_compliance(expected, parsed) @@ -216,6 +294,7 @@ def evaluate(self, task: BenchmarkTask, response: dict[str, Any]) -> TaskResult: "note": note, "category": task.validation.get("category"), "title": task.validation.get("title"), + "json_extraction_method": extraction_method, }, )