Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ ENV/
scripts/
experiment_configs/
logs/
datasets/
24 changes: 14 additions & 10 deletions experiment_config.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
llm:
model_name: Qwen3-8B
api_key: not-needed
api_url: http://0.0.0.0:8000/v1
model_name: Qwen3-8B-SelfCoTFull
api_key: EMPTY
api_url: None
temperature: 0.6
max_tokens: 32768
max_tokens: 8192
max_workers: 64
enable_thinking: true

judge:
model_name: gemma-3-4b-it
api_key: not-needed
api_url: http://127.0.0.1:8006/v1
model_name: Qwen3-8B-SelfCoTFull
api_key: EMPTY
api_url: None
temperature: 0.0
max_tokens: 4096
enable_llm_judge: true

repeats: 1
max_samples: 2
badcase:
enabled: true

repeats: 3
max_samples: 0
datasets_path: datasets
results_path: results
results_path: results/Qwen3-8B-SelfCoTFull
99 changes: 83 additions & 16 deletions run_all.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,94 @@
"""统一评测入口
"""统一评测入口 & 可复用评测工具

运行所有数据集的评测脚本。
所有配置通过 experiment_config.yaml 管理,无需命令行参数。
提供三层使用方式:
1. python run_all.py → 运行当前 experiment_config.yaml 下的所有数据集
2. scripts/run_single_model.py → 指定 model yaml,运行所有数据集
3. scripts/run_single_dataset.py → 指定 dataset yaml,在所有 model 上运行
"""
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional

import yaml

# Names of the datasets to evaluate; run_dataset() looks for tasks/<name>/run.py.
DATASETS = [
    "ToMBench",
    "Tomato",
    "ToMQA",
    "ToMi",
]

# Active experiment config file; apply_config() overwrites it with the chosen yaml.
EXPERIMENT_CONFIG = Path("experiment_config.yaml")
# Default directory that discover_model_configs() scans for per-model yaml files.
MODEL_CONFIGS_DIR = Path("experiment_configs")

def run_dataset(dataset: str) -> bool:
"""运行指定数据集的评测

Args:
dataset: 数据集名称
# ---------------------------------------------------------------------------
# 可复用工具函数
# ---------------------------------------------------------------------------


def apply_config(yaml_path: str) -> None:
    """Install ``yaml_path`` as the active ``experiment_config.yaml``.

    The file is copied (contents and metadata) onto ``EXPERIMENT_CONFIG``.
    According to the original author's note, the per-task ``run.py`` scripts
    read their settings from that file.

    Args:
        yaml_path: Path of the yaml file to activate.

    Raises:
        FileNotFoundError: If ``yaml_path`` does not exist.
    """
    src = Path(yaml_path)
    if not src.exists():
        raise FileNotFoundError(f"配置文件不存在: {src}")
    try:
        shutil.copy2(src, EXPERIMENT_CONFIG)
    except shutil.SameFileError:
        # The requested config already *is* the active config file, so there
        # is nothing to copy; treat it as success instead of crashing.
        pass
    print(f"[config] {src} → {EXPERIMENT_CONFIG}")


def discover_model_configs(configs_dir: str = str(MODEL_CONFIGS_DIR)) -> List[Path]:
    """Collect the yaml files located directly inside ``configs_dir``.

    Only regular files with a ``.yaml`` or ``.yml`` suffix are returned; the
    directory is not searched recursively.

    Args:
        configs_dir: Directory to scan. Defaults to ``MODEL_CONFIGS_DIR``.

    Returns:
        The matching paths in sorted order.

    Raises:
        FileNotFoundError: If ``configs_dir`` is not an existing directory.
        RuntimeError: If the directory contains no yaml file.
    """
    root = Path(configs_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"模型配置目录不存在: {root}")

    found = []
    for entry in root.iterdir():
        if entry.suffix not in (".yaml", ".yml"):
            continue
        if entry.is_file():
            found.append(entry)
    found.sort()

    if not found:
        raise RuntimeError(f"{root} 中没有找到任何 yaml 配置文件")
    return found


def get_model_name(yaml_path: Path) -> str:
    """Return ``llm.model_name`` from an experiment config yaml.

    Falls back to the file stem whenever the name cannot be read: the
    document is empty or not a mapping, the ``llm`` section is missing or
    null, or ``model_name`` is missing or empty.

    Args:
        yaml_path: Path of the experiment config yaml.

    Returns:
        The configured model name, or the yaml file's stem as a fallback.
    """
    with open(yaml_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # safe_load() yields None for an empty document and may yield a list or a
    # scalar, so check the shape before calling .get() on anything.
    llm_cfg = cfg.get("llm") if isinstance(cfg, dict) else None
    name = llm_cfg.get("model_name") if isinstance(llm_cfg, dict) else None
    if name:
        return str(name)
    # Path() wrapper keeps the fallback working when a plain string is passed.
    return Path(yaml_path).stem


Returns:
是否成功
"""
def get_dataset_name(yaml_path: str) -> str:
    """Return the ``dataset`` field of a dataset config yaml.

    Args:
        yaml_path: Path of the dataset config yaml.

    Returns:
        The value stored under the top-level ``dataset`` key.

    Raises:
        ValueError: If the yaml is empty, is not a mapping, or has no
            non-empty ``dataset`` field.
    """
    with open(yaml_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # safe_load() yields None for an empty document; report that through the
    # same ValueError as a missing field rather than an AttributeError.
    name = cfg.get("dataset") if isinstance(cfg, dict) else None
    if not name:
        raise ValueError(f"yaml 中缺少 'dataset' 字段: {yaml_path}")
    return name


def run_dataset(dataset: str) -> bool:
"""运行指定数据集的评测脚本。"""
project_root = Path(__file__).resolve().parent
run_script = Path(f"tasks/{dataset}/run.py")
if not run_script.exists():
if not (project_root / run_script).exists():
print(f"[{dataset}] run.py not found, skipping.")
return False

print(f"\n{'='*60}")
print(f"Running: {dataset}")
print(f"{'='*60}")

env = os.environ.copy()
env["PYTHONPATH"] = str(project_root) + os.pathsep + env.get("PYTHONPATH", "")

try:
result = subprocess.run(
subprocess.run(
[sys.executable, str(run_script)],
check=True,
capture_output=False,
cwd=str(project_root),
env=env,
)
return True
except subprocess.CalledProcessError as e:
Expand All @@ -52,14 +102,31 @@ def run_dataset(dataset: str) -> bool:
return False


def run_datasets(datasets: Optional[List[str]] = None) -> Dict[str, bool]:
    """Run the given datasets one after another.

    Args:
        datasets: Dataset names to run; ``None`` selects every entry of
            ``DATASETS``.

    Returns:
        A mapping from dataset name to the success flag returned by
        ``run_dataset``.
    """
    targets = DATASETS if datasets is None else datasets
    return {name: run_dataset(name) for name in targets}


# ---------------------------------------------------------------------------
# 默认入口:运行所有数据集
# ---------------------------------------------------------------------------


def main():
    """Run every dataset once and print a per-dataset OK/FAILED summary."""
    # run_datasets() already iterates over DATASETS; looping over them here
    # as well would execute each evaluation twice.
    results = run_datasets()

    print(f"\n{'='*60}")
    print("All datasets completed.")
    for ds, ok in results.items():
        status = "OK" if ok else "FAILED"
        print(f" {ds}: {status}")
    print(f"{'='*60}")


if __name__ == "__main__":
    # Invoke the entry point exactly once; a second call would rerun the
    # whole evaluation.
    main()
118 changes: 118 additions & 0 deletions src/judge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""LLM 语义判断模块

当结构化输出提取失败(max_retry 耗尽)时,使用 LLM 判断模型的原始回答
与标准答案在语义上是否一致。
"""

import logging
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from pydantic import BaseModel
from tqdm import tqdm

from src.llm import LLMClient

logger = logging.getLogger(__name__)

# Prompt sent to the judge model. judge_single() fills the {raw_response},
# {gold_answer} and {question} placeholders through str.format().
# A trailing backslash inside the literal suppresses the newline that follows
# it, so wrapped source lines render as one line in the final prompt.
JUDGE_PROMPT_TEMPLATE = """\
You are an impartial judge. Given a model's response and a gold (correct) answer, \
determine whether the model's response contains an answer that is semantically \
equivalent to the gold answer.

Focus ONLY on whether the final answer matches in meaning. Ignore formatting, \
extra explanation, or reasoning traces.

## Model Response
{raw_response}

## Gold Answer
{gold_answer}

## Question (for context)
{question}

Does the model's response contain an answer semantically equivalent to the gold answer?\
"""


class JudgeVerdict(BaseModel):
    """Output schema for the LLM judge: the verdict is constrained to a boolean."""

    # True when the judge considers the response semantically equivalent to
    # the gold answer.
    is_correct: bool


def _strip_think_tags(text: str) -> str:
"""Remove <think>...</think> blocks so the judge sees only the final answer."""
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip()


def judge_single(
    client: LLMClient,
    raw_response: str,
    gold_answer: str,
    question: str = "",
) -> bool:
    """Ask the judge model whether one raw response matches the gold answer.

    Args:
        client: LLMClient used for judging (per the original note, usually
            configured with a low temperature).
        raw_response: Full raw output of the model under evaluation.
        gold_answer: Reference answer.
        question: Original question, passed as context; may be empty.

    Returns:
        The ``is_correct`` attribute of the structured verdict, or ``False``
        when the returned object has no such attribute.
    """
    fields = {
        "raw_response": _strip_think_tags(raw_response),
        "gold_answer": gold_answer,
        "question": question,
    }
    verdict = client.generate_structure(
        JUDGE_PROMPT_TEMPLATE.format(**fields),
        JudgeVerdict,
        max_retry=3,
    )
    return getattr(verdict, "is_correct", False)


def batch_judge(
    client: LLMClient,
    items: List[Dict[str, Any]],
) -> List[bool]:
    """Judge a batch of responses concurrently.

    Each item is submitted to a thread pool sized by ``client.max_workers``
    and evaluated with ``judge_single``. A call that raises is recorded as
    ``False`` and logged with its traceback.

    Args:
        client: Judge LLMClient.
        items: One dict per response with the keys ``"raw_response"`` and
            ``"gold_answer"`` and an optional ``"question"`` (defaults to
            ``""``).

    Returns:
        A list of booleans with the same length and order as ``items``.
    """
    if not items:
        return []

    with ThreadPoolExecutor(client.max_workers) as executor:
        futures = [
            executor.submit(
                judge_single,
                client,
                item["raw_response"],
                item["gold_answer"],
                item.get("question", ""),
            )
            for item in items
        ]

        # Collect in submission order so results line up with ``items``.
        # Doing it inside the ``with`` lets the progress bar advance while
        # work is still running.
        results: List[bool] = []
        for future in tqdm(
            futures,
            total=len(futures),
            desc="LLM Judge",
            miniters=100,
        ):
            try:
                results.append(future.result())
            except Exception:
                # Best effort: a failed judge call counts as incorrect, but
                # keep the traceback so the cause can be diagnosed.
                logger.warning(
                    "[Judge] single judge call failed, treating as incorrect",
                    exc_info=True,
                )
                results.append(False)

    return results
3 changes: 2 additions & 1 deletion src/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
- 结构化输出: generate_structure(), batch_generate_structure()
"""

from .client import LLMClient, LLMUsage
from .client import LLMClient, LLMUsage, StructuredResult

__all__ = [
"LLMClient",
"LLMUsage",
"StructuredResult",
]
Loading