diff --git a/.gitignore b/.gitignore
index 0d99c03..d6f53cd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,3 +60,4 @@ ENV/
scripts/
experiment_configs/
logs/
+datasets/
diff --git a/experiment_config.yaml b/experiment_config.yaml
index 92c1d02..9ebeefc 100644
--- a/experiment_config.yaml
+++ b/experiment_config.yaml
@@ -1,20 +1,24 @@
llm:
- model_name: Qwen3-8B
- api_key: not-needed
- api_url: http://0.0.0.0:8000/v1
+ model_name: Qwen3-8B-SelfCoTFull
+ api_key: EMPTY
+ api_url: None
temperature: 0.6
- max_tokens: 32768
+ max_tokens: 8192
max_workers: 64
enable_thinking: true
judge:
- model_name: gemma-3-4b-it
- api_key: not-needed
- api_url: http://127.0.0.1:8006/v1
+ model_name: Qwen3-8B-SelfCoTFull
+ api_key: EMPTY
+ api_url: None
temperature: 0.0
max_tokens: 4096
+ enable_llm_judge: true
-repeats: 1
-max_samples: 2
+badcase:
+ enabled: true
+
+repeats: 3
+max_samples: 0
datasets_path: datasets
-results_path: results
+results_path: results/Qwen3-8B-SelfCoTFull
diff --git a/run_all.py b/run_all.py
index 66d67ae..6b15e74 100644
--- a/run_all.py
+++ b/run_all.py
@@ -1,32 +1,77 @@
-"""统一评测入口
+"""统一评测入口 & 可复用评测工具
-运行所有数据集的评测脚本。
-所有配置通过 experiment_config.yaml 管理,无需命令行参数。
+提供三层使用方式:
+1. python run_all.py → 运行当前 experiment_config.yaml 下的所有数据集
+2. scripts/run_single_model.py → 指定 model yaml,运行所有数据集
+3. scripts/run_single_dataset.py → 指定 dataset yaml,在所有 model 上运行
"""
+import os
+import shutil
import subprocess
import sys
from pathlib import Path
+from typing import Dict, List, Optional
+import yaml
-# 数据集列表
DATASETS = [
"ToMBench",
"Tomato",
"ToMQA",
+ "ToMi",
]
+EXPERIMENT_CONFIG = Path("experiment_config.yaml")
+MODEL_CONFIGS_DIR = Path("experiment_configs")
-def run_dataset(dataset: str) -> bool:
- """运行指定数据集的评测
- Args:
- dataset: 数据集名称
+# ---------------------------------------------------------------------------
+# 可复用工具函数
+# ---------------------------------------------------------------------------
+
+
+def apply_config(yaml_path: str) -> None:
+    """Copy the given yaml over experiment_config.yaml for the task run.py scripts to read.
+
+    Args:
+        yaml_path: Path of the config file to activate.
+
+    Raises:
+        FileNotFoundError: If ``yaml_path`` does not exist.
+    """
+    source = Path(yaml_path)
+    if source.exists():
+        # copy2 also carries over file metadata such as timestamps.
+        shutil.copy2(source, EXPERIMENT_CONFIG)
+        print(f"[config] {source} → {EXPERIMENT_CONFIG}")
+        return
+    raise FileNotFoundError(f"配置文件不存在: {source}")
+
+
+def discover_model_configs(configs_dir: str = str(MODEL_CONFIGS_DIR)) -> List[Path]:
+    """Find every model yaml directly under the configs directory, sorted by path.
+
+    Args:
+        configs_dir: Directory to scan (not recursive).
+
+    Returns:
+        Sorted paths of regular files ending in ``.yaml`` or ``.yml``.
+
+    Raises:
+        FileNotFoundError: If ``configs_dir`` is not a directory.
+        RuntimeError: If the directory holds no yaml config file.
+    """
+    root = Path(configs_dir)
+    if not root.is_dir():
+        raise FileNotFoundError(f"模型配置目录不存在: {root}")
+
+    found = [
+        entry
+        for entry in root.iterdir()
+        if entry.suffix in (".yaml", ".yml") and entry.is_file()
+    ]
+    found.sort()
+
+    if not found:
+        raise RuntimeError(f"{root} 中没有找到任何 yaml 配置文件")
+    return found
+
+
+def get_model_name(yaml_path: Path) -> str:
+    """Read ``llm.model_name`` from an experiment config yaml.
+
+    Args:
+        yaml_path: Path of the experiment config file.
+
+    Returns:
+        The configured model name, or the file stem when the file is empty,
+        has no ``llm`` section, or has no non-empty ``model_name``.
+    """
+    with open(yaml_path, encoding="utf-8") as f:
+        # safe_load returns None for an empty document.
+        cfg = yaml.safe_load(f) or {}
+    # ``llm:`` with no body parses to None, so ``or {}`` is needed as well as
+    # the .get default.
+    llm_cfg = cfg.get("llm") or {}
+    return llm_cfg.get("model_name") or yaml_path.stem
+
- Returns:
- 是否成功
- """
+def get_dataset_name(yaml_path: str) -> str:
+    """Read the dataset name from a dataset config yaml.
+
+    Args:
+        yaml_path: Path of the dataset config file.
+
+    Returns:
+        The value of the top-level ``dataset`` key.
+
+    Raises:
+        ValueError: If the ``dataset`` key is missing or empty, including
+            when the file itself is empty.
+    """
+    with open(yaml_path, encoding="utf-8") as f:
+        # safe_load returns None for an empty document; normalise it so the
+        # missing-field error below is raised instead of AttributeError.
+        cfg = yaml.safe_load(f) or {}
+    name = cfg.get("dataset")
+    if not name:
+        raise ValueError(f"yaml 中缺少 'dataset' 字段: {yaml_path}")
+    return name
+
+
+def run_dataset(dataset: str) -> bool:
+ """运行指定数据集的评测脚本。"""
+ project_root = Path(__file__).resolve().parent
run_script = Path(f"tasks/{dataset}/run.py")
- if not run_script.exists():
+ if not (project_root / run_script).exists():
print(f"[{dataset}] run.py not found, skipping.")
return False
@@ -34,11 +79,16 @@ def run_dataset(dataset: str) -> bool:
print(f"Running: {dataset}")
print(f"{'='*60}")
+ env = os.environ.copy()
+ env["PYTHONPATH"] = str(project_root) + os.pathsep + env.get("PYTHONPATH", "")
+
try:
- result = subprocess.run(
+ subprocess.run(
[sys.executable, str(run_script)],
check=True,
capture_output=False,
+ cwd=str(project_root),
+ env=env,
)
return True
except subprocess.CalledProcessError as e:
@@ -52,14 +102,31 @@ def run_dataset(dataset: str) -> bool:
return False
+def run_datasets(datasets: Optional[List[str]] = None) -> Dict[str, bool]:
+    """Run the given datasets one after another and report each outcome.
+
+    Args:
+        datasets: Dataset names to run; ``None`` means every entry of DATASETS.
+
+    Returns:
+        Mapping ``{dataset: success}`` in the order the datasets were run.
+    """
+    targets = DATASETS if datasets is None else datasets
+    return {name: run_dataset(name) for name in targets}
+
+
+# ---------------------------------------------------------------------------
+# 默认入口:运行所有数据集
+# ---------------------------------------------------------------------------
+
+
def main():
- for dataset in DATASETS:
- run_dataset(dataset)
+ results = run_datasets()
print(f"\n{'='*60}")
print("All datasets completed.")
+ for ds, ok in results.items():
+ status = "OK" if ok else "FAILED"
+ print(f" {ds}: {status}")
print(f"{'='*60}")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/src/judge.py b/src/judge.py
new file mode 100644
index 0000000..ad8c9bd
--- /dev/null
+++ b/src/judge.py
@@ -0,0 +1,118 @@
+"""LLM 语义判断模块
+
+当结构化输出提取失败(max_retry 耗尽)时,使用 LLM 判断模型的原始回答
+与标准答案在语义上是否一致。
+"""
+
+import logging
+import re
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel
+from tqdm import tqdm
+
+from src.llm import LLMClient
+
+logger = logging.getLogger(__name__)
+
+JUDGE_PROMPT_TEMPLATE = """\
+You are an impartial judge. Given a model's response and a gold (correct) answer, \
+determine whether the model's response contains an answer that is semantically \
+equivalent to the gold answer.
+
+Focus ONLY on whether the final answer matches in meaning. Ignore formatting, \
+extra explanation, or reasoning traces.
+
+## Model Response
+{raw_response}
+
+## Gold Answer
+{gold_answer}
+
+## Question (for context)
+{question}
+
+Does the model's response contain an answer semantically equivalent to the gold answer?\
+"""
+
+
+class JudgeVerdict(BaseModel):
+    """Output schema for the LLM judge: a single boolean verdict."""
+    is_correct: bool
+
+
+def _strip_think_tags(text: str) -> str:
+    """Remove <think>...</think> blocks so the judge sees only the final answer."""
+    return re.sub(r"<think>[\s\S]*?</think>", "", text).strip()
+
+
+def judge_single(
+    client: LLMClient,
+    raw_response: str,
+    gold_answer: str,
+    question: str = "",
+) -> bool:
+    """Judge whether one raw response is semantically equivalent to the gold answer.
+
+    Args:
+        client: LLMClient used as the judge (typically a low-temperature one).
+        raw_response: Full raw output of the model under evaluation.
+        gold_answer: Reference answer.
+        question: Original question, given to the judge as context; may be empty.
+
+    Returns:
+        True when the judge reports a semantic match. False when it reports a
+        mismatch or when the result exposes no ``is_correct`` attribute.
+    """
+    verdict = client.generate_structure(
+        JUDGE_PROMPT_TEMPLATE.format(
+            raw_response=_strip_think_tags(raw_response),
+            gold_answer=gold_answer,
+            question=question,
+        ),
+        JudgeVerdict,
+        max_retry=3,
+    )
+    return getattr(verdict, "is_correct", False)
+
+
+def batch_judge(
+    client: LLMClient,
+    items: List[Dict[str, Any]],
+) -> List[bool]:
+    """Run the semantic judge over a batch of items in parallel.
+
+    Args:
+        client: Judge LLMClient; its ``max_workers`` sizes the thread pool.
+        items: Each element is {"raw_response": str, "gold_answer": str,
+            "question": str}; "question" is optional and defaults to "".
+
+    Returns:
+        A list of booleans with the same length and order as ``items``. A
+        judge call that raises is logged and counted as False.
+    """
+    if not items:
+        return []
+
+    results: List[bool] = []
+    with ThreadPoolExecutor(client.max_workers) as executor:
+        futures = [
+            executor.submit(
+                judge_single,
+                client,
+                item["raw_response"],
+                item["gold_answer"],
+                item.get("question", ""),
+            )
+            for item in items
+        ]
+
+        # Collect inside the ``with`` block: leaving it waits for every
+        # future, so a progress bar created afterwards would only ever see
+        # already-finished work.
+        for future in tqdm(
+            futures,
+            total=len(futures),
+            desc="LLM Judge",
+            miniters=100,
+        ):
+            try:
+                results.append(future.result())
+            except Exception:
+                # Best-effort: one failed judge call must not abort the batch,
+                # but keep the traceback so the failure can be diagnosed.
+                logger.warning(
+                    "[Judge] single judge call failed, treating as incorrect",
+                    exc_info=True,
+                )
+                results.append(False)
+
+    return results
diff --git a/src/llm/__init__.py b/src/llm/__init__.py
index b1fdb82..0693070 100644
--- a/src/llm/__init__.py
+++ b/src/llm/__init__.py
@@ -6,9 +6,10 @@
- 结构化输出: generate_structure(), batch_generate_structure()
"""
-from .client import LLMClient, LLMUsage
+from .client import LLMClient, LLMUsage, StructuredResult
__all__ = [
"LLMClient",
"LLMUsage",
+ "StructuredResult",
]
diff --git a/src/llm/client.py b/src/llm/client.py
index c8cdbfc..fe2cbf3 100644
--- a/src/llm/client.py
+++ b/src/llm/client.py
@@ -39,6 +39,27 @@ class LLMUsage:
latency: float = 0.0
+@dataclass
+class StructuredResult:
+    """Wraps a structured generation result with raw output metadata.
+
+    Provides backward-compatible access to the parsed Pydantic fields
+    (e.g. ``r.answer``) while also exposing the raw LLM response for
+    bad-case analysis and LLM-judge fallback.
+    """
+    # Parsed Pydantic object; attribute lookups that miss on this wrapper
+    # are forwarded to it.
+    parsed: BaseModel
+    # Message content of the last LLM response ("" when none was captured).
+    raw_response: str = ""
+    # Reasoning text of the last LLM response ("" when none was captured).
+    reasoning_content: str = ""
+    # False when every structured-extraction retry was exhausted.
+    extraction_success: bool = True
+
+    @property
+    def answer(self):
+        """Return ``parsed.answer``, or "" when the parsed object has no such field."""
+        return getattr(self.parsed, "answer", "")
+
+    def __getattr__(self, name: str):
+        """Forward attributes not found on the wrapper to the parsed object."""
+        # __getattr__ only runs after normal lookup failed. If ``parsed``
+        # itself is missing (e.g. copy/pickle probing a not-yet-initialised
+        # instance), reading self.parsed would re-enter this method forever.
+        if name == "parsed":
+            raise AttributeError(name)
+        return getattr(self.parsed, name)
+
+
# ---------------------------------------------------------------------------
# LLM Client
# ---------------------------------------------------------------------------
@@ -271,8 +292,8 @@ def generate_structure(
prompt: str,
response_object: Type[BaseModel],
max_retry: int = 5,
- ) -> BaseModel:
- """调用 LLM,返回 Pydantic 对象(自动适配不同模型)。
+ ) -> "StructuredResult":
+ """调用 LLM,返回 StructuredResult(自动适配不同模型)。
两阶段降级策略:
1. 首选:chat.completions.parse() - 直接返回 Pydantic 对象,最佳体验
@@ -284,13 +305,12 @@ def generate_structure(
max_retry: 最大重试次数
Returns:
- response_object 的实例,失败时返回空实例
+ StructuredResult, extraction_success=False when all retries exhausted
"""
# 首次检测:尝试使用 parse API
if self._parse_supported is None:
with self._parse_lock:
if self._parse_supported is None:
- # 尝试一次,成功则标记支持,失败则不支持
try:
result = self._generate_with_parse(prompt, response_object, max_retry=1)
self._parse_supported = True
@@ -298,9 +318,7 @@ def generate_structure(
except Exception:
self._parse_supported = False
logging.warning(f"[LLM] Model {self.model} parse API failed, switching to JSON object mode")
- # 继续使用降级模式
- # 根据检测结果选择模式
if self._parse_supported:
return self._generate_with_parse(prompt, response_object, max_retry)
else:
@@ -311,12 +329,15 @@ def _generate_with_parse(
prompt: str,
response_object: Type[BaseModel],
max_retry: int = 5,
- ) -> BaseModel:
+ ) -> "StructuredResult":
"""使用 parse API 的原生结构化输出。"""
extra_body: Dict[str, Any] = {"top_k": self.top_k}
if not self.enable_thinking:
extra_body["chat_template_kwargs"] = {"enable_thinking": False}
+ last_raw = ""
+ last_reasoning = ""
+
for attempt in range(max_retry):
try:
start = time.time()
@@ -336,30 +357,45 @@ def _generate_with_parse(
usage.completion_tokens = response.usage.completion_tokens
usage.total_tokens = response.usage.total_tokens
- result = response.choices[0].message.parsed
+ msg = response.choices[0].message
+ last_raw = msg.content or ""
+ last_reasoning = (
+ getattr(msg, "reasoning_content", "")
+ or getattr(msg, "reasoning", "")
+ or ""
+ )
+
+ result = msg.parsed
self._track_usage(usage, success=True)
- return result
+ return StructuredResult(
+ parsed=result,
+ raw_response=last_raw,
+ reasoning_content=last_reasoning,
+ extraction_success=True,
+ )
except Exception as e:
- import traceback
logging.warning(f"[LLM] parse mode attempt {attempt + 1}")
logging.error(f"[LLM] parse mode all {max_retry} attempts exhausted")
self._track_usage(LLMUsage(), success=False)
- return response_object.model_construct()
+ return StructuredResult(
+ parsed=response_object.model_construct(),
+ raw_response=last_raw,
+ reasoning_content=last_reasoning,
+ extraction_success=False,
+ )
def _generate_with_json_object(
self,
prompt: str,
response_object: Type[BaseModel],
max_retry: int = 5,
- ) -> BaseModel:
+ ) -> "StructuredResult":
"""降级模式:使用 json_object response_format + prompt 引导 + 解析验证。"""
import json
- # 构建 schema 描述
schema_desc = self._format_schema_for_prompt(response_object)
- # 增强提示词
enhanced_prompt = f"""{prompt}
---
@@ -373,6 +409,9 @@ def _generate_with_json_object(
if not self.enable_thinking:
extra_body["chat_template_kwargs"] = {"enable_thinking": False}
+ last_raw = ""
+ last_reasoning = ""
+
for attempt in range(max_retry):
try:
start = time.time()
@@ -393,24 +432,40 @@ def _generate_with_json_object(
usage.completion_tokens = response.usage.completion_tokens
usage.total_tokens = response.usage.total_tokens
- content = response.choices[0].message.content or ""
- # 提取 JSON
+ msg = response.choices[0].message
+ content = msg.content or ""
+ reasoning = (
+ getattr(msg, "reasoning_content", "")
+ or getattr(msg, "reasoning", "")
+ or ""
+ )
+ last_raw = content
+ last_reasoning = reasoning
+
json_data = self._extract_json(content)
if json_data is None:
raise ValueError(f"Failed to extract valid JSON: {content[:200]}")
- # 用 Pydantic 验证(不符合就重试)
result = response_object.model_validate(json_data)
self._track_usage(usage, success=True)
- return result
+ return StructuredResult(
+ parsed=result,
+ raw_response=content,
+ reasoning_content=reasoning,
+ extraction_success=True,
+ )
except Exception as e:
- import traceback
logging.warning(f"[LLM] json_object mode attempt {attempt + 1}")
logging.error(f"[LLM] json_object mode all {max_retry} attempts exhausted")
self._track_usage(LLMUsage(), success=False)
- return response_object.model_construct()
+ return StructuredResult(
+ parsed=response_object.model_construct(),
+ raw_response=last_raw,
+ reasoning_content=last_reasoning,
+ extraction_success=False,
+ )
def _extract_json(self, text: str) -> Optional[Dict[str, Any]]:
"""从文本中提取 JSON。"""
@@ -471,9 +526,9 @@ def batch_generate_structure(
self,
prompts: List[str],
response_object: Type[BaseModel],
- ) -> List[BaseModel]:
+ ) -> List["StructuredResult"]:
"""
- 批量调用 LLM,返回 Pydantic 对象列表。
+ 批量调用 LLM,返回 StructuredResult 列表。
支持并行调用。
@@ -482,7 +537,7 @@ def batch_generate_structure(
response_object: 必须提供的 Pydantic 模型类
Returns:
- response_object 实例的列表
+ StructuredResult 实例的列表
"""
with ThreadPoolExecutor(self.max_workers) as executor:
futures = [
diff --git a/src/runner.py b/src/runner.py
index c392dc2..cf743bf 100644
--- a/src/runner.py
+++ b/src/runner.py
@@ -6,12 +6,12 @@
import os
import sys
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import yaml
from src.dataloader import load_dataset
-from src.llm import LLMClient
+from src.llm import LLMClient, StructuredResult
import logging
# 将日志级别设置为 WARNING 或更高
logging.getLogger("urllib3").setLevel(logging.WARNING)
@@ -65,13 +65,17 @@ def load_experiment_config(config_path: str) -> Dict[str, Any]:
"""
with open(config_path, encoding="utf-8") as f:
config = yaml.safe_load(f)
+ judge_block = config.get("judge", {})
+ badcase_block = config.get("badcase", {})
return {
"llm_config": config.get("llm", {}),
"repeats": config.get("repeats", 1),
"max_samples": config.get("max_samples", 0),
"datasets_path": config.get("datasets_path", "datasets"),
"results_path": config.get("results_path", "results"),
- "judge_config": config.get("judge", {}), # 覆盖数据集的 judge 配置
+ "judge_config": judge_block,
+ "enable_llm_judge": judge_block.get("enable_llm_judge", False),
+ "badcase_enabled": badcase_block.get("enabled", False),
}
@@ -141,6 +145,8 @@ def save_common_results(
metadata: Optional[Dict[str, Any]] = None,
dataset_config: Optional[Dict[str, Any]] = None,
experiment_config: Optional[Dict[str, Any]] = None,
+ badcases: Optional[List[Dict[str, Any]]] = None,
+ all_metrics_with_judge: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[Path, Path, Path]:
"""保存评测结果
@@ -157,6 +163,8 @@ def save_common_results(
metadata: 额外元数据(如 judge_model)
dataset_config: 数据集配置字典(保存到 config.json)
experiment_config: 实验配置字典(保存到 config.json,会过滤 api_key 和 api_url)
+ badcases: bad case 记录列表(保存到 badcases.jsonl)
+ all_metrics_with_judge: LLM judge 兜底后的 metrics 列表(保存到 metrics.json)
Returns:
(config_path, metrics_path, prediction_path) 元组
@@ -176,17 +184,14 @@ def save_common_results(
"repeats": len(all_metrics),
}
- # 添加 dataset_config 内容(排除 schemas_module 等非 JSON 可序列化对象)
if dataset_config:
dataset_config_copy = dict(dataset_config)
dataset_config_copy.pop("schemas_module", None)
- dataset_config_copy.pop("schema", None) # schema 是类对象,不可序列化
+ dataset_config_copy.pop("schema", None)
config_data["dataset_config"] = dataset_config_copy
- # 添加 experiment_config 内容(排除敏感信息)
if experiment_config:
experiment_config_copy = dict(experiment_config)
- # 过滤敏感信息
if "llm_config" in experiment_config_copy:
llm_config_copy = dict(experiment_config_copy["llm_config"])
llm_config_copy.pop("api_key", None)
@@ -208,13 +213,18 @@ def save_common_results(
encoding="utf-8",
)
- # 2. 保存 metrics.json
+ # 2. 保存 metrics.json(含 strict 和可选的 judge 两套指标)
avg_metrics = _compute_average_metrics(all_metrics)
- metrics_data = {
+ metrics_data: Dict[str, Any] = {
"avg_metrics": avg_metrics,
"all_metrics": all_metrics,
}
+ if all_metrics_with_judge:
+ avg_judge = _compute_average_metrics(all_metrics_with_judge)
+ metrics_data["avg_metrics_with_judge"] = avg_judge
+ metrics_data["all_metrics_with_judge"] = all_metrics_with_judge
+
metrics_path = output_dir / "metrics.json"
metrics_path.write_text(
json.dumps(metrics_data, ensure_ascii=False, indent=2),
@@ -235,6 +245,14 @@ def save_common_results(
}
f.write(json.dumps(record, ensure_ascii=False) + "\n")
+ # 4. 保存 badcases.jsonl(可选)
+ if badcases:
+ badcases_path = output_dir / "badcases.jsonl"
+ with open(badcases_path, "w", encoding="utf-8") as f:
+ for bc in badcases:
+ f.write(json.dumps(bc, ensure_ascii=False) + "\n")
+ print(f" - badcases.jsonl ({len(badcases)} bad cases)")
+
print(f"Results saved to: {output_dir}")
print(f" - config.json")
print(f" - metrics.json")
@@ -290,3 +308,95 @@ def load_and_limit_data(
random.seed(seed)
data = random.sample(data, min(max_samples, len(data)))
return data
+
+
+# ---------------------------------------------------------------------------
+# Bad-case collection & LLM-judge helpers
+# ---------------------------------------------------------------------------
+
+
+def collect_badcases(
+    results: List[StructuredResult],
+    predictions: List[str],
+    gold_answers: List[str],
+    prompts: List[str],
+    dataset_name: str,
+    is_correct_fn: Callable[[str, str], bool],
+    repeat_idx: int = 0,
+    judge_verdicts: Optional[List[Optional[bool]]] = None,
+) -> List[Dict[str, Any]]:
+    """Build one record per sample that failed extraction or was answered wrongly.
+
+    Args:
+        results: StructuredResult list, one per sample.
+        predictions: Extracted predicted answers.
+        gold_answers: Reference answers.
+        prompts: Input prompts.
+        dataset_name: Dataset name stored in every record.
+        is_correct_fn: ``(prediction, gold) -> bool`` correctness check; only
+            called for samples whose extraction succeeded.
+        repeat_idx: Index of the current repeat.
+        judge_verdicts: LLM-judge verdicts indexed by sample position (same
+            length as ``results``), or None when the judge is not enabled.
+
+    Returns:
+        List of bad-case dicts; correct samples are left out.
+    """
+    records: List[Dict[str, Any]] = []
+    rows = zip(results, predictions, gold_answers, prompts)
+
+    for idx, (res, pred, gold, prompt) in enumerate(rows):
+        if res.extraction_success:
+            if is_correct_fn(pred, gold):
+                continue
+            kind = "wrong_answer"
+        else:
+            kind = "extraction_failed"
+
+        verdict = None if judge_verdicts is None else judge_verdicts[idx]
+        record: Dict[str, Any] = {
+            "repeat": repeat_idx,
+            "sample_idx": idx,
+            "dataset": dataset_name,
+            "error_type": kind,
+            "prompt": prompt,
+            "raw_response": res.raw_response,
+            "reasoning_content": res.reasoning_content,
+            "prediction": pred,
+            "gold_answer": gold,
+            "judge_result": verdict,
+        }
+        records.append(record)
+
+    return records
+
+
+def build_corrected_predictions(
+    predictions: List[str],
+    results: List[StructuredResult],
+    judge_verdicts: List[bool],
+    gold_answers: List[str],
+) -> List[str]:
+    """Build the prediction list after applying the LLM-judge fallback.
+
+    Samples whose extraction succeeded keep their original prediction.
+    Samples whose extraction failed but whose judge verdict is truthy are
+    replaced by the gold answer. Everything else keeps its original value.
+
+    Args:
+        predictions: Original predictions.
+        results: StructuredResult list, one per sample.
+        judge_verdicts: Verdicts for the extraction-failed samples only, in
+            sample order; failed samples beyond its end count as not matched.
+        gold_answers: Reference answers.
+
+    Returns:
+        A new list with the same length as ``predictions``.
+    """
+    fixed = list(predictions)
+    # One verdict is consumed per extraction-failed sample, in order; an
+    # exhausted iterator yields False.
+    pending = iter(judge_verdicts)
+
+    for pos, res in enumerate(results):
+        if res.extraction_success:
+            continue
+        if next(pending, False):
+            fixed[pos] = gold_answers[pos]
+
+    return fixed
diff --git a/tables/SUMMARY.md b/tables/SUMMARY.md
index 684fb10..62680c6 100644
--- a/tables/SUMMARY.md
+++ b/tables/SUMMARY.md
@@ -1,7 +1,8 @@
## 总览表格:Accuracy
-| 数据集 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it |
-|---|-:|-:|-:|-:|-:|
-| ToMBench | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 |
-| ToMQA | - | 0.3549 | 0.5825 | 0.5611 | - |
-| Tomato | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 |
+| 数据集 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | deepseek-chat | deepseek-r1 | Qwen3-8B-SelfCoT | Qwen3-8B-SelfCoTFull |
+| --- | -: | -: | -: | -: | -: | -: | -: | -: | -: | -: |
+| ToMBench | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 | 0.6198 | 0.7698 | 0.8097 | 0.6262 | 0.6667 |
+| ToMQA | - | 0.3549 | 0.5825 | 0.5611 | - | - | - | - | 0.5484 | 0.5547 |
+| Tomato | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 | 0.5885 | 0.7938 | 0.7894 | 0.6280 | 0.5816 |
+| ToMi | - | - | - | - | - | - | - | - | 0.7508 | 0.7561 |
diff --git "a/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md" "b/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md"
index 81f426b..46c2a32 100644
--- "a/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md"
+++ "b/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md"
@@ -1,37 +1,37 @@
# ToMBench - 其他指标
-| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it |
-|---|---|---|---|---|---|
-| by_ability.Belief: Beliefs based action/emotions | 0.6620 | 0.4390 | 0.6385 | 0.6761 | 0.6150 |
-| by_ability.Belief: Content false beliefs | 0.7017 | 0.4417 | 0.5267 | 0.5850 | 0.5817 |
-| by_ability.Belief: Content false beliefs Belief: Second-order beliefs | 0.8133 | 0.3333 | 0.7933 | 0.8467 | 0.6667 |
-| by_ability.Belief: Identity false beliefs | 0.8000 | 0.4333 | 0.6667 | 0.7833 | 0.7167 |
-| by_ability.Belief: Location false beliefs | 0.7533 | 0.3933 | 0.7667 | 0.8000 | 0.8083 |
-| by_ability.Belief: Location false beliefs Belief: Second-order beliefs | 0.4967 | 0.1933 | 0.1500 | 0.3167 | 0.3567 |
-| by_ability.Belief: Sequence false beliefs | 0.5300 | 0.4433 | 0.5633 | 0.6100 | 0.5133 |
-| by_ability.Desire: Desire-action contradiction | 0.6667 | 0.4833 | 0.6917 | 0.7500 | 0.6250 |
-| by_ability.Desire: Desires influence on actions | 0.4912 | 0.3596 | 0.5175 | 0.4912 | 0.4781 |
-| by_ability.Desire: Desires influence on emotions (beliefs) | 0.4722 | 0.3611 | 0.5278 | 0.5139 | 0.5417 |
-| by_ability.Desire: Discrepant desires | 0.4667 | 0.2500 | 0.3333 | 0.4167 | 0.4000 |
-| by_ability.Desire: Multiple desires | 0.6500 | 0.3833 | 0.6000 | 0.7167 | 0.6500 |
-| by_ability.Emotion: Atypical emotional reactions | 0.4567 | 0.2433 | 0.5033 | 0.5733 | 0.5667 |
-| by_ability.Emotion: Discrepant emotions | 0.5667 | 0.3833 | 0.6250 | 0.7000 | 0.5750 |
-| by_ability.Emotion: Emotion regulation | 0.4333 | 0.3000 | 0.3667 | 0.4500 | 0.5000 |
-| by_ability.Emotion: Hidden emotions | 0.5792 | 0.3250 | 0.6500 | 0.7250 | 0.5333 |
-| by_ability.Emotion: Mixed emotions | 0.3417 | 0.5000 | 0.4500 | 0.4083 | 0.3250 |
-| by_ability.Emotion: Moral emotions | 0.7333 | 0.4750 | 0.7000 | 0.7250 | 0.7500 |
-| by_ability.Emotion: Typical emotional reactions | 0.8400 | 0.6533 | 0.8567 | 0.8333 | 0.8933 |
-| by_ability.Intention: Completion of failed actions | 0.3500 | 0.3167 | 0.4333 | 0.4167 | 0.4500 |
-| by_ability.Intention: Discrepant intentions | 0.8250 | 0.5250 | 0.7250 | 0.8000 | 0.6750 |
-| by_ability.Intention: Intentions explanations | 0.7423 | 0.3795 | 0.6974 | 0.7449 | 0.6256 |
-| by_ability.Intention: Prediction of actions | 0.6667 | 0.2667 | 0.5833 | 0.4667 | 0.5000 |
-| by_ability.Knowledge: Information-knowledge links | 0.3333 | 0.2350 | 0.3917 | 0.4117 | 0.3083 |
-| by_ability.Knowledge: Knowledge-attention links | 0.4000 | 0.2333 | 0.1667 | 0.3667 | 0.3000 |
-| by_ability.Knowledge: Knowledge-pretend play links | 0.1111 | 0.2778 | 0.2444 | 0.2444 | 0.1000 |
-| by_ability.Knowledge: Percepts-knowledge links | 0.7000 | 0.4083 | 0.6000 | 0.6917 | 0.6333 |
-| by_ability.Non-Literal Communication: Faux pas | 0.6310 | 0.5435 | 0.7095 | 0.7393 | 0.6601 |
-| by_ability.Non-Literal Communication: Involuntary lies | 0.7698 | 0.4127 | 0.7619 | 0.8889 | 0.6905 |
-| by_ability.Non-Literal Communication: Irony/Sarcasm | 0.5897 | 0.2821 | 0.7436 | 0.7436 | 0.6154 |
-| by_ability.Non-literal communication: Egocentric lies | 0.9167 | 0.4750 | 0.9083 | 0.9333 | 0.8500 |
-| by_ability.Non-literal communication: Humor | 0.9250 | 0.4250 | 0.8167 | 0.9250 | 0.7500 |
-| by_ability.Non-literal communication: White lies | 0.9083 | 0.3333 | 0.8167 | 0.8667 | 0.7250 |
+| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart |
+| --- | --- | --- | --- | --- | --- | --- |
+| by_ability.Belief: Beliefs based action/emotions | 0.6620 | 0.4390 | 0.6385 | 0.6761 | 0.6150 | 0.6197 |
+| by_ability.Belief: Content false beliefs | 0.7017 | 0.4417 | 0.5267 | 0.5850 | 0.5817 | 0.6417 |
+| by_ability.Belief: Content false beliefs Belief: Second-order beliefs | 0.8133 | 0.3333 | 0.7933 | 0.8467 | 0.6667 | 0.3667 |
+| by_ability.Belief: Identity false beliefs | 0.8000 | 0.4333 | 0.6667 | 0.7833 | 0.7167 | 0.7167 |
+| by_ability.Belief: Location false beliefs | 0.7533 | 0.3933 | 0.7667 | 0.8000 | 0.8083 | 0.9333 |
+| by_ability.Belief: Location false beliefs Belief: Second-order beliefs | 0.4967 | 0.1933 | 0.1500 | 0.3167 | 0.3567 | 0.3733 |
+| by_ability.Belief: Sequence false beliefs | 0.5300 | 0.4433 | 0.5633 | 0.6100 | 0.5133 | 0.5100 |
+| by_ability.Desire: Desire-action contradiction | 0.6667 | 0.4833 | 0.6917 | 0.7500 | 0.6250 | 0.6833 |
+| by_ability.Desire: Desires influence on actions | 0.4912 | 0.3596 | 0.5175 | 0.4912 | 0.4781 | 0.4254 |
+| by_ability.Desire: Desires influence on emotions (beliefs) | 0.4722 | 0.3611 | 0.5278 | 0.5139 | 0.5417 | 0.4583 |
+| by_ability.Desire: Discrepant desires | 0.4667 | 0.2500 | 0.3333 | 0.4167 | 0.4000 | 0.3833 |
+| by_ability.Desire: Multiple desires | 0.6500 | 0.3833 | 0.6000 | 0.7167 | 0.6500 | 0.6833 |
+| by_ability.Emotion: Atypical emotional reactions | 0.4567 | 0.2433 | 0.5033 | 0.5733 | 0.5667 | 0.5000 |
+| by_ability.Emotion: Discrepant emotions | 0.5667 | 0.3833 | 0.6250 | 0.7000 | 0.5750 | 0.5167 |
+| by_ability.Emotion: Emotion regulation | 0.4333 | 0.3000 | 0.3667 | 0.4500 | 0.5000 | 0.3167 |
+| by_ability.Emotion: Hidden emotions | 0.5792 | 0.3250 | 0.6500 | 0.7250 | 0.5333 | 0.6458 |
+| by_ability.Emotion: Mixed emotions | 0.3417 | 0.5000 | 0.4500 | 0.4083 | 0.3250 | 0.4583 |
+| by_ability.Emotion: Moral emotions | 0.7333 | 0.4750 | 0.7000 | 0.7250 | 0.7500 | 0.6833 |
+| by_ability.Emotion: Typical emotional reactions | 0.8400 | 0.6533 | 0.8567 | 0.8333 | 0.8933 | 0.8367 |
+| by_ability.Intention: Completion of failed actions | 0.3500 | 0.3167 | 0.4333 | 0.4167 | 0.4500 | 0.3500 |
+| by_ability.Intention: Discrepant intentions | 0.8250 | 0.5250 | 0.7250 | 0.8000 | 0.6750 | 0.7167 |
+| by_ability.Intention: Intentions explanations | 0.7423 | 0.3795 | 0.6974 | 0.7449 | 0.6256 | 0.6667 |
+| by_ability.Intention: Prediction of actions | 0.6667 | 0.2667 | 0.5833 | 0.4667 | 0.5000 | 0.5667 |
+| by_ability.Knowledge: Information-knowledge links | 0.3333 | 0.2350 | 0.3917 | 0.4117 | 0.3083 | 0.3550 |
+| by_ability.Knowledge: Knowledge-attention links | 0.4000 | 0.2333 | 0.1667 | 0.3667 | 0.3000 | 0.2667 |
+| by_ability.Knowledge: Knowledge-pretend play links | 0.1111 | 0.2778 | 0.2444 | 0.2444 | 0.1000 | 0.0556 |
+| by_ability.Knowledge: Percepts-knowledge links | 0.7000 | 0.4083 | 0.6000 | 0.6917 | 0.6333 | 0.7667 |
+| by_ability.Non-Literal Communication: Faux pas | 0.6310 | 0.5435 | 0.7095 | 0.7393 | 0.6601 | 0.6911 |
+| by_ability.Non-Literal Communication: Involuntary lies | 0.7698 | 0.4127 | 0.7619 | 0.8889 | 0.6905 | 0.7143 |
+| by_ability.Non-Literal Communication: Irony/Sarcasm | 0.5897 | 0.2821 | 0.7436 | 0.7436 | 0.6154 | 0.8077 |
+| by_ability.Non-literal communication: Egocentric lies | 0.9167 | 0.4750 | 0.9083 | 0.9333 | 0.8500 | 0.8417 |
+| by_ability.Non-literal communication: Humor | 0.9250 | 0.4250 | 0.8167 | 0.9250 | 0.7500 | 0.8583 |
+| by_ability.Non-literal communication: White lies | 0.9083 | 0.3333 | 0.8167 | 0.8667 | 0.7250 | 0.7750 |
diff --git "a/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md"
index ed2e5d2..b107b2d 100644
--- "a/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md"
+++ "b/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md"
@@ -1,7 +1,7 @@
# ToMBench - 基础指标
-| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it |
-|---|---|---|---|---|---|
-| accuracy | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 |
-| correct | 1812.3333 | 1173.6667 | 1785.3333 | 1912.6667 | 1720 |
-| total | 2860 | 2860 | 2860 | 2860 | 2860 |
+| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | Qwen3-8B-SelfCoT | Qwen3-8B-SelfCoTFull |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| accuracy | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 | 0.6198 | 0.6262 | 0.6667 |
+| correct | 1812.3333 | 1173.6667 | 1785.3333 | 1912.6667 | 1720 | 1772.6667 | 1791 | 1906.762 |
+| total | 2860 | 2860 | 2860 | 2860 | 2860 | 2860 | 2860 | 2860 |
diff --git "a/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md" "b/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md"
index ac0c22c..6874294 100644
--- "a/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md"
+++ "b/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md"
@@ -1,10 +1,10 @@
# ToMQA - 其他指标
-| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B |
-|---|---|---|---|
-| by_dimension.first_order_belief | 0.4174 | 0.6253 | 0.6226 |
-| by_dimension.second_order_belief | 0.2923 | 0.5397 | 0.4997 |
-| by_difficulty.easy | 0.3549 | 0.5825 | 0.5611 |
-| by_order.1 | 0.4174 | 0.6253 | 0.6226 |
-| by_order.2 | 0.2923 | 0.5397 | 0.4997 |
-| by_task_type.qa | 0.3549 | 0.5825 | 0.5611 |
+| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | Qwen3-8B-SIPColdStart |
+| --- | --- | --- | --- | --- |
+| by_dimension.first_order_belief | 0.4174 | 0.6253 | 0.6226 | - |
+| by_dimension.second_order_belief | 0.2923 | 0.5397 | 0.4997 | - |
+| by_difficulty.easy | 0.3549 | 0.5825 | 0.5611 | - |
+| by_order.1 | 0.4174 | 0.6253 | 0.6226 | - |
+| by_order.2 | 0.2923 | 0.5397 | 0.4997 | - |
+| by_task_type.qa | 0.3549 | 0.5825 | 0.5611 | - |
diff --git "a/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md"
index 1e67d53..e27db2d 100644
--- "a/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md"
+++ "b/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md"
@@ -1,7 +1,7 @@
# ToMQA - 基础指标
-| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B |
-|---|---|---|---|
-| accuracy | 0.3549 | 0.5825 | 0.5611 |
-| correct | 4258.3333 | 6990.3333 | 6733.3333 |
-| total | 12000 | 12000 | 12000 |
+| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | Qwen3-8B-SIPColdStart | Qwen3-8B-SelfCoT | Qwen3-8B-SelfCoTFull |
+| --- | --- | --- | --- | --- | --- | --- |
+| accuracy | 0.3549 | 0.5825 | 0.5611 | - | 0.5484 | 0.5547 |
+| correct | 4258.3333 | 6990.3333 | 6733.3333 | - | 6581 | 6657 |
+| total | 12000 | 12000 | 12000 | - | 12000 | 12000 |
diff --git "a/tables/ToMi/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/ToMi/\345\237\272\347\241\200\346\214\207\346\240\207.md"
new file mode 100644
index 0000000..f59d06d
--- /dev/null
+++ "b/tables/ToMi/\345\237\272\347\241\200\346\214\207\346\240\207.md"
@@ -0,0 +1,7 @@
+# ToMi - 基础指标
+
+| 指标 \ 模型 | Qwen3-8B-SelfCoT | Qwen3-8B-SelfCoTFull |
+| --- | --- | --- |
+| accuracy | 0.7508 | 0.7561 |
+| correct | 4500.3333 | 4532.3333 |
+| total | 5994 | 5994 |
diff --git "a/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md" "b/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md"
index 11e181f..2b23db1 100644
--- "a/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md"
+++ "b/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md"
@@ -1,13 +1,13 @@
# Tomato - 其他指标
-| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it |
-|---|---|---|---|---|---|
-| by_dimension_1.belief | 0.6564 | 0.3883 | 0.6032 | 0.6240 | 0.5362 |
-| by_dimension_1.desire | 0.7431 | 0.4309 | 0.7047 | 0.7379 | 0.5997 |
-| by_dimension_1.emotion | 0.7001 | 0.3939 | 0.6754 | 0.6995 | 0.5699 |
-| by_dimension_1.intention | 0.6892 | 0.4093 | 0.6649 | 0.6958 | 0.5556 |
-| by_dimension_1.knowledge | 0.6411 | 0.3959 | 0.5883 | 0.6362 | 0.5581 |
-| by_dimension_2.first_order | 0.7473 | 0.4155 | 0.7013 | 0.7322 | 0.5904 |
-| by_dimension_2.second_order | 0.6226 | 0.3914 | 0.5902 | 0.6226 | 0.5364 |
-| by_dimension_3.__none__ | 0.7106 | 0.4207 | 0.6693 | 0.7030 | 0.5854 |
-| by_dimension_3.false_belief | 0.5347 | 0.3044 | 0.5083 | 0.5281 | 0.4363 |
+| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart |
+| --- | --- | --- | --- | --- | --- | --- |
+| by_dimension_1.belief | 0.6564 | 0.3883 | 0.6032 | 0.6240 | 0.5362 | 0.5407 |
+| by_dimension_1.desire | 0.7431 | 0.4309 | 0.7047 | 0.7379 | 0.5997 | 0.6263 |
+| by_dimension_1.emotion | 0.7001 | 0.3939 | 0.6754 | 0.6995 | 0.5699 | 0.6195 |
+| by_dimension_1.intention | 0.6892 | 0.4093 | 0.6649 | 0.6958 | 0.5556 | 0.6168 |
+| by_dimension_1.knowledge | 0.6411 | 0.3959 | 0.5883 | 0.6362 | 0.5581 | 0.5469 |
+| by_dimension_2.first_order | 0.7473 | 0.4155 | 0.7013 | 0.7322 | 0.5904 | 0.6486 |
+| by_dimension_2.second_order | 0.6226 | 0.3914 | 0.5902 | 0.6226 | 0.5364 | 0.5294 |
+| by_dimension_3.__none__ | 0.7106 | 0.4207 | 0.6693 | 0.7030 | 0.5854 | 0.6146 |
+| by_dimension_3.false_belief | 0.5347 | 0.3044 | 0.5083 | 0.5281 | 0.4363 | 0.4396 |
diff --git "a/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md"
index 6579601..2d650da 100644
--- "a/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md"
+++ "b/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md"
@@ -1,7 +1,7 @@
# Tomato - 基础指标
-| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it |
-|---|---|---|---|---|---|
-| accuracy | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 |
-| correct | 3696.3333 | 2178.6667 | 3485 | 3656 | 3041.6667 |
-| total | 5401 | 5401 | 5401 | 5401 | 5401 |
+| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | Qwen3-8B-SelfCoT | Qwen3-8B-SelfCoTFull |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| accuracy | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 | 0.5885 | 0.6280 | 0.5816 |
+| correct | 3696.3333 | 2178.6667 | 3485 | 3656 | 3041.6667 | 3178.3333 | 3391.6667 | 3141 |
+| total | 5401 | 5401 | 5401 | 5401 | 5401 | 5401 | 5401 | 5401 |
diff --git a/tasks/ToMBench/run.py b/tasks/ToMBench/run.py
index 625dacd..32e7f8f 100644
--- a/tasks/ToMBench/run.py
+++ b/tasks/ToMBench/run.py
@@ -3,39 +3,40 @@
from pathlib import Path
from typing import Any, Dict, List
-# 添加父目录到路径以导入 src
sys.path.insert(0, str(Path(__file__).parent.parent))
-from src.dataloader import load_dataset
-from src.llm import LLMClient
from src import runner
+from src import judge as judge_module
from ToMBench.prompts import get_template, build_prompt
from ToMBench.metrics import compute_metrics
import logging
-# 彻底关闭 httpx 和 httpcore 的请求日志
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+
+
+def _is_correct(pred: str, gold: str) -> bool:
+ """Return True when the prediction equals the gold answer exactly."""
+ # Strict equality: no case folding or whitespace trimming is applied here.
+ # NOTE(review): the ToMQA and Tomato variants in this change also require a
+ # non-empty prediction (bool(...) guard); confirm that is not needed here.
+ return pred == gold
+
+
def main():
- # 加载数据集配置
dataset_config = runner.load_dataset_config("tasks/ToMBench/config.yaml")
-
- # 加载实验配置
experiment_config = runner.load_experiment_config("experiment_config.yaml")
schema = dataset_config["schema"]
prompt_method = dataset_config["default_prompt"]
-
- # 获取 prompt 模板
template = get_template(prompt_method)
-
- # 创建 LLM 客户端
client = runner.create_llm_client(experiment_config["llm_config"])
- # 加载数据
+ badcase_enabled = experiment_config["badcase_enabled"]
+ enable_judge = experiment_config["enable_llm_judge"]
+ judge_client = None
+ if enable_judge:
+ judge_client = runner.create_llm_client(experiment_config["judge_config"])
+
data = runner.load_and_limit_data(
subset=dataset_config["subset"],
datasets_path=experiment_config["datasets_path"],
@@ -46,17 +47,18 @@ def main():
print(f"Prompt method: {prompt_method}")
print(f"Repeats: {experiment_config['repeats']}")
- # 构建 prompts(每个 repeat 构建相同的 prompts)
prompts = [build_prompt(template, row) for row in data]
all_prompts = prompts * experiment_config["repeats"]
- # 批量结构化推理
print(f"Running inference ({len(all_prompts)} prompts)...")
results = client.batch_generate_structure(all_prompts, schema)
- # 使用数据集的 metrics 函数计算
- all_predictions = []
- all_metrics = []
+ gold_answers = [row['Answer']['Correct Answer'][0] for row in data]
+
+ all_predictions: List[List[str]] = []
+ all_metrics: List[Dict[str, Any]] = []
+ all_metrics_with_judge: List[Dict[str, Any]] = []
+ all_badcases: List[Dict[str, Any]] = []
for i in range(experiment_config["repeats"]):
start = i * len(data)
@@ -65,13 +67,55 @@ def main():
predictions = [r.answer for r in repeat_results]
all_predictions.append(predictions)
- # 调用数据集的 metrics 函数
metrics = compute_metrics(predictions, data)
all_metrics.append(metrics)
print(f"Run {i+1}: Accuracy={metrics['accuracy']:.4f}, Correct={metrics['correct']}/{metrics['total']}")
- # 保存结果
- gold_answers = [row['Answer']['Correct Answer'][0] for row in data]
+ # --- LLM Judge 兜底 ---
+ judge_verdicts = None
+ if judge_client:
+ failed_items = [
+ {
+ "raw_response": r.raw_response,
+ "gold_answer": gold,
+ "question": prompt,
+ }
+ for r, gold, prompt in zip(repeat_results, gold_answers, prompts)
+ if not r.extraction_success
+ ]
+ if failed_items:
+ judge_results = judge_module.batch_judge(judge_client, failed_items)
+ judge_verdicts_full: List[bool] = []
+ ji = 0
+ for r in repeat_results:
+ if not r.extraction_success:
+ judge_verdicts_full.append(judge_results[ji])
+ ji += 1
+ else:
+ judge_verdicts_full.append(False)
+ judge_verdicts = judge_verdicts_full
+
+ corrected = runner.build_corrected_predictions(
+ predictions, repeat_results, judge_results, gold_answers,
+ )
+ metrics_j = compute_metrics(corrected, data)
+ all_metrics_with_judge.append(metrics_j)
+ print(
+ f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, "
+ f"Recovered={metrics_j['correct'] - metrics['correct']}"
+ )
+ else:
+ all_metrics_with_judge.append(metrics)
+
+ # --- Bad case 收集 ---
+ if badcase_enabled:
+ bcs = runner.collect_badcases(
+ repeat_results, predictions, gold_answers, prompts,
+ dataset_config["dataset"], _is_correct,
+ repeat_idx=i, judge_verdicts=judge_verdicts,
+ )
+ all_badcases.extend(bcs)
+
runner.save_common_results(
dataset_name=dataset_config["dataset"],
model=experiment_config["llm_config"]["model_name"],
@@ -82,11 +126,12 @@ def main():
results_path=experiment_config["results_path"],
dataset_config=dataset_config,
experiment_config=experiment_config,
+ badcases=all_badcases if badcase_enabled else None,
+ all_metrics_with_judge=all_metrics_with_judge if enable_judge else None,
)
- # 打印统计摘要
runner.print_summary_stats(all_metrics, experiment_config["repeats"], len(gold_answers))
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/tasks/ToMQA/run.py b/tasks/ToMQA/run.py
index 4161b17..ef680f3 100644
--- a/tasks/ToMQA/run.py
+++ b/tasks/ToMQA/run.py
@@ -1,17 +1,18 @@
"""ToMQA 评测脚本(基于结构化输出)"""
import sys
from pathlib import Path
+from typing import Any, Dict, List
-# 添加父目录到路径以导入 src
sys.path.insert(0, str(Path(__file__).parent.parent))
from src import runner
+from src import judge as judge_module
+
from ToMQA.prompts import get_template, build_prompt
-from ToMQA.metrics import compute_metrics
+from ToMQA.metrics import compute_metrics, normalize_answer
import logging
-# 关闭不必要日志
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
@@ -34,23 +35,27 @@ def extract_gold_answers(data):
return golds
+def _is_correct(pred: str, gold: str) -> bool:
+ """Return True when the normalized prediction is non-empty and equals the normalized gold answer.
+
+ Both sides are passed through normalize_answer (imported from ToMQA.metrics)
+ before being compared.
+ """
+ p = normalize_answer(pred)
+ g = normalize_answer(gold)
+ # bool(p) makes a prediction that normalizes to a falsy value (e.g. "") count
+ # as incorrect even if the normalized gold answer is falsy as well.
+ return bool(p) and p == g
+
+
def main():
- # 加载数据集配置
dataset_config = runner.load_dataset_config("tasks/ToMQA/config.yaml")
-
- # 加载实验配置
experiment_config = runner.load_experiment_config("experiment_config.yaml")
schema = dataset_config["schema"]
prompt_method = dataset_config["default_prompt"]
-
- # 获取 prompt 模板
template = get_template(prompt_method)
-
- # 创建 LLM 客户端
client = runner.create_llm_client(experiment_config["llm_config"])
- # 加载数据
+ badcase_enabled = experiment_config["badcase_enabled"]
+ enable_judge = experiment_config["enable_llm_judge"]
+ judge_client = None
+ if enable_judge:
+ judge_client = runner.create_llm_client(experiment_config["judge_config"])
+
data = runner.load_and_limit_data(
subset=dataset_config["subset"],
datasets_path=experiment_config["datasets_path"],
@@ -61,17 +66,18 @@ def main():
print(f"Prompt method: {prompt_method}")
print(f"Repeats: {experiment_config['repeats']}")
- # 构建 prompts
prompts = [build_prompt(template, row) for row in data]
all_prompts = prompts * experiment_config["repeats"]
- # 批量结构化推理
print(f"Running inference ({len(all_prompts)} prompts)...")
results = client.batch_generate_structure(all_prompts, schema)
- # 计算 metrics
- all_predictions = []
- all_metrics = []
+ gold_answers = extract_gold_answers(data)
+
+ all_predictions: List[List[str]] = []
+ all_metrics: List[Dict[str, Any]] = []
+ all_metrics_with_judge: List[Dict[str, Any]] = []
+ all_badcases: List[Dict[str, Any]] = []
for i in range(experiment_config["repeats"]):
start = i * len(data)
@@ -87,8 +93,51 @@ def main():
f"Correct={metrics['correct']}/{metrics['total']}"
)
- # 保存结果
- gold_answers = extract_gold_answers(data)
+ # --- LLM Judge 兜底 ---
+ judge_verdicts = None
+ if judge_client:
+ failed_items = [
+ {
+ "raw_response": r.raw_response,
+ "gold_answer": gold,
+ "question": prompt,
+ }
+ for r, gold, prompt in zip(repeat_results, gold_answers, prompts)
+ if not r.extraction_success
+ ]
+ if failed_items:
+ judge_results = judge_module.batch_judge(judge_client, failed_items)
+ judge_verdicts_full: List[bool] = []
+ ji = 0
+ for r in repeat_results:
+ if not r.extraction_success:
+ judge_verdicts_full.append(judge_results[ji])
+ ji += 1
+ else:
+ judge_verdicts_full.append(False)
+ judge_verdicts = judge_verdicts_full
+
+ corrected = runner.build_corrected_predictions(
+ predictions, repeat_results, judge_results, gold_answers,
+ )
+ metrics_j = compute_metrics(corrected, data)
+ all_metrics_with_judge.append(metrics_j)
+ print(
+ f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, "
+ f"Recovered={metrics_j['correct'] - metrics['correct']}"
+ )
+ else:
+ all_metrics_with_judge.append(metrics)
+
+ # --- Bad case 收集 ---
+ if badcase_enabled:
+ bcs = runner.collect_badcases(
+ repeat_results, predictions, gold_answers, prompts,
+ dataset_config["dataset"], _is_correct,
+ repeat_idx=i, judge_verdicts=judge_verdicts,
+ )
+ all_badcases.extend(bcs)
+
runner.save_common_results(
dataset_name=dataset_config["dataset"],
model=experiment_config["llm_config"]["model_name"],
@@ -99,9 +148,10 @@ def main():
results_path=experiment_config["results_path"],
dataset_config=dataset_config,
experiment_config=experiment_config,
+ badcases=all_badcases if badcase_enabled else None,
+ all_metrics_with_judge=all_metrics_with_judge if enable_judge else None,
)
- # 打印统计摘要
runner.print_summary_stats(all_metrics, experiment_config["repeats"], len(gold_answers))
diff --git a/tasks/ToMi/metrics.py b/tasks/ToMi/metrics.py
index 9a7f9da..aa469dd 100644
--- a/tasks/ToMi/metrics.py
+++ b/tasks/ToMi/metrics.py
@@ -11,14 +11,18 @@ def _normalize_word(text: Any) -> str:
def compute_metrics(predictions: List[str], data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算 ToMi 的 metrics(单词答案精确匹配)"""
- gold_answers = [_normalize_word(row.get("output", "")) for row in data]
+ # Gold answer per row: the first entry of row["Answer"]["Correct_Answer"],
+ # normalized with _normalize_word; rows without such an entry get "" as gold.
+ # NOTE(review): .get("Answer", {}) only covers a missing key; a row whose
+ # "Answer" is present but not a dict (e.g. None) would raise here - confirm
+ # the dataset never produces that.
+ gold_answers = []
+ for row in data:
+ correct = row.get("Answer", {}).get("Correct_Answer", [])
+ gold_answers.append(_normalize_word(correct[0]) if correct else "")
+
pred_answers = [_normalize_word(p) for p in predictions]
- correct = sum(1 for p, g in zip(pred_answers, gold_answers) if p == g)
- accuracy = correct / len(pred_answers) if pred_answers else 0
+ # Renamed from "correct" so the counter does not collide with the per-row
+ # "correct" list used in the loop above.
+ # NOTE(review): a row with "" as gold compares equal to a prediction that
+ # normalizes to "" - presumably unintended; verify such rows cannot occur.
+ correct_count = sum(1 for p, g in zip(pred_answers, gold_answers) if p == g)
+ accuracy = correct_count / len(pred_answers) if pred_answers else 0
return {
"accuracy": accuracy,
- "correct": correct,
+ "correct": correct_count,
"total": len(pred_answers),
}
diff --git a/tasks/ToMi/prompts.py b/tasks/ToMi/prompts.py
index 7b6c9fe..c3e1184 100644
--- a/tasks/ToMi/prompts.py
+++ b/tasks/ToMi/prompts.py
@@ -23,8 +23,9 @@
def build_prompt(template: str, row: Dict[str, Any]) -> str:
"""构建 prompt"""
- story = row.get("instruction", "")
- question = row.get("input", "")
+ # The story text is read from row["Story"]["full_story"]; a missing or
+ # non-dict "Story" falls back to an empty mapping, so story becomes "".
+ story_info = row.get("Story", {}) if isinstance(row.get("Story"), dict) else {}
+ # "or" also turns falsy stored values (e.g. None) into "" before formatting.
+ story = story_info.get("full_story", "") or ""
+ question = row.get("Question", "") or ""
return template.format(story=story, question=question)
diff --git a/tasks/ToMi/run.py b/tasks/ToMi/run.py
index 5d96c26..f0c8f0d 100644
--- a/tasks/ToMi/run.py
+++ b/tasks/ToMi/run.py
@@ -3,44 +3,49 @@
from pathlib import Path
from typing import Any, Dict, List
-# 添加父目录到路径以导入 src
sys.path.insert(0, str(Path(__file__).parent.parent))
from src import runner
+from src import judge as judge_module
from ToMi.prompts import get_template, build_prompt
from ToMi.metrics import compute_metrics
import logging
-# 关闭不必要日志
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
def extract_gold_answers(data: List[Dict[str, Any]]) -> List[str]:
- """提取标准答案。"""
- return [str(row.get("output", "")).strip().lower() for row in data]
+ """Extract the gold answers (first element of each row's Answer.Correct_Answer list).
+
+ Each answer is converted to str, stripped and lower-cased; a row with no
+ Correct_Answer entry contributes "".
+ """
+ answers = []
+ for row in data:
+ # NOTE(review): .get("Answer", {}) only covers a missing key; an "Answer"
+ # value that is present but not a dict would raise - confirm against the data.
+ correct = row.get("Answer", {}).get("Correct_Answer", [])
+ answers.append(str(correct[0]).strip().lower() if correct else "")
+ return answers
+
+
+def _is_correct(pred: str, gold: str) -> bool:
+ """Return True when prediction and gold match after str(), strip() and lower()."""
+ # NOTE(review): str() coercion turns a None prediction into the text "none",
+ # and two empty strings compare equal here - confirm both cases are intended.
+ return str(pred).strip().lower() == str(gold).strip().lower()
def main():
- # 加载数据集配置
dataset_config = runner.load_dataset_config("tasks/ToMi/config.yaml")
-
- # 加载实验配置
experiment_config = runner.load_experiment_config("experiment_config.yaml")
schema = dataset_config["schema"]
prompt_method = dataset_config["default_prompt"]
-
- # 获取 prompt 模板
template = get_template(prompt_method)
-
- # 创建 LLM 客户端
client = runner.create_llm_client(experiment_config["llm_config"])
- # 加载数据
+ badcase_enabled = experiment_config["badcase_enabled"]
+ enable_judge = experiment_config["enable_llm_judge"]
+ judge_client = None
+ if enable_judge:
+ judge_client = runner.create_llm_client(experiment_config["judge_config"])
+
data = runner.load_and_limit_data(
subset=dataset_config["subset"],
datasets_path=experiment_config["datasets_path"],
@@ -51,17 +56,19 @@ def main():
print(f"Prompt method: {prompt_method}")
print(f"Repeats: {experiment_config['repeats']}")
- # 构建 prompts(每个 repeat 构建相同的 prompts)
prompts = [build_prompt(template, row) for row in data]
all_prompts = prompts * experiment_config["repeats"]
- # 批量结构化推理
print(f"Running inference ({len(all_prompts)} prompts)...")
results = client.batch_generate_structure(all_prompts, schema)
- # 计算 metrics
- all_predictions = []
- all_metrics = []
+ gold_answers = extract_gold_answers(data)
+
+ all_predictions: List[List[str]] = []
+ all_metrics: List[Dict[str, Any]] = []
+ all_metrics_with_judge: List[Dict[str, Any]] = []
+ all_badcases: List[Dict[str, Any]] = []
+
for i in range(experiment_config["repeats"]):
start = i * len(data)
end = start + len(data)
@@ -73,8 +80,51 @@ def main():
all_metrics.append(metrics)
print(f"Run {i+1}: Accuracy={metrics['accuracy']:.4f}, Correct={metrics['correct']}/{metrics['total']}")
- # 保存结果
- gold_answers = extract_gold_answers(data)
+ # --- LLM Judge 兜底 ---
+ judge_verdicts = None
+ if judge_client:
+ failed_items = [
+ {
+ "raw_response": r.raw_response,
+ "gold_answer": gold,
+ "question": prompt,
+ }
+ for r, gold, prompt in zip(repeat_results, gold_answers, prompts)
+ if not r.extraction_success
+ ]
+ if failed_items:
+ judge_results = judge_module.batch_judge(judge_client, failed_items)
+ judge_verdicts_full: List[bool] = []
+ ji = 0
+ for r in repeat_results:
+ if not r.extraction_success:
+ judge_verdicts_full.append(judge_results[ji])
+ ji += 1
+ else:
+ judge_verdicts_full.append(False)
+ judge_verdicts = judge_verdicts_full
+
+ corrected = runner.build_corrected_predictions(
+ predictions, repeat_results, judge_results, gold_answers,
+ )
+ metrics_j = compute_metrics(corrected, data)
+ all_metrics_with_judge.append(metrics_j)
+ print(
+ f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, "
+ f"Recovered={metrics_j['correct'] - metrics['correct']}"
+ )
+ else:
+ all_metrics_with_judge.append(metrics)
+
+ # --- Bad case 收集 ---
+ if badcase_enabled:
+ bcs = runner.collect_badcases(
+ repeat_results, predictions, gold_answers, prompts,
+ dataset_config["dataset"], _is_correct,
+ repeat_idx=i, judge_verdicts=judge_verdicts,
+ )
+ all_badcases.extend(bcs)
+
runner.save_common_results(
dataset_name=dataset_config["dataset"],
model=experiment_config["llm_config"]["model_name"],
@@ -85,9 +135,10 @@ def main():
results_path=experiment_config["results_path"],
dataset_config=dataset_config,
experiment_config=experiment_config,
+ badcases=all_badcases if badcase_enabled else None,
+ all_metrics_with_judge=all_metrics_with_judge if enable_judge else None,
)
- # 打印统计摘要
runner.print_summary_stats(all_metrics, experiment_config["repeats"], len(gold_answers))
diff --git a/tasks/Tomato/run.py b/tasks/Tomato/run.py
index 6194152..45cb2b1 100644
--- a/tasks/Tomato/run.py
+++ b/tasks/Tomato/run.py
@@ -11,8 +11,11 @@
sys.path.insert(0, str(Path(__file__).parent.parent))
from src import runner
+from src import judge as judge_module
+
from Tomato.prompts import get_template, build_prompt
from Tomato.metrics import compute_metrics
+
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
@@ -97,6 +100,10 @@ def shuffle_mcq_options(mcq: Dict[str, Any], seed: int) -> Dict[str, Any]:
return {**mcq, "original_choices": new_choices, "gold_letter": new_gold}
+def _is_correct(pred: str, gold: str) -> bool:
+ """Return True when the prediction is non-empty and equals the gold value exactly."""
+ # bool(pred) rejects empty/None predictions before the equality check.
+ return bool(pred) and pred == gold
+
+
def main() -> None:
dataset_config = runner.load_dataset_config("tasks/Tomato/config.yaml")
experiment_config = runner.load_experiment_config("experiment_config.yaml")
@@ -106,6 +113,12 @@ def main() -> None:
template = get_template(prompt_method)
client = runner.create_llm_client(experiment_config["llm_config"])
+ badcase_enabled = experiment_config["badcase_enabled"]
+ enable_judge = experiment_config["enable_llm_judge"]
+ judge_client = None
+ if enable_judge:
+ judge_client = runner.create_llm_client(experiment_config["judge_config"])
+
data = runner.load_and_limit_data(
subset=dataset_config["subset"],
datasets_path=experiment_config["datasets_path"],
@@ -122,16 +135,21 @@ def main() -> None:
all_prompts: List[str] = []
repeat_data: List[List[Dict[str, Any]]] = []
+ repeat_prompts_list: List[List[str]] = []
for i in range(repeats):
shuffled_rows: List[Dict[str, Any]] = []
+ cur_prompts: List[str] = []
for j, row in enumerate(data):
shuffled_mcq = shuffle_mcq_options(row["_mcq"], seed=42 * (i + 1) + j)
shuffled_row = dict(row)
shuffled_row["_mcq"] = shuffled_mcq
shuffled_rows.append(shuffled_row)
- all_prompts.append(build_prompt(template, shuffled_row))
+ p = build_prompt(template, shuffled_row)
+ all_prompts.append(p)
+ cur_prompts.append(p)
repeat_data.append(shuffled_rows)
+ repeat_prompts_list.append(cur_prompts)
print(f"Running inference ({len(all_prompts)} prompts)...")
results = client.batch_generate_structure(all_prompts, schema)
@@ -139,7 +157,9 @@ def main() -> None:
n = len(data)
all_predictions: List[List[str]] = []
all_metrics: List[Dict[str, Any]] = []
+ all_metrics_with_judge: List[Dict[str, Any]] = []
all_gold: List[List[str]] = []
+ all_badcases: List[Dict[str, Any]] = []
for i in range(repeats):
start = i * n
@@ -149,11 +169,59 @@ def main() -> None:
predictions = [r.answer for r in repeat_results]
all_predictions.append(predictions)
+ repeat_gold = [row["_mcq"]["gold_letter"] for row in rows]
+ repeat_pr = repeat_prompts_list[i]
+
metrics = compute_metrics(predictions, rows)
all_metrics.append(metrics)
- all_gold.append([row["_mcq"]["gold_letter"] for row in rows])
+ all_gold.append(repeat_gold)
print(f"Run {i+1}: Accuracy={metrics['accuracy']:.4f}, Correct={metrics['correct']}/{metrics['total']}")
+ # --- LLM Judge 兜底 ---
+ judge_verdicts = None
+ if judge_client:
+ failed_items = [
+ {
+ "raw_response": r.raw_response,
+ "gold_answer": gold,
+ "question": prompt,
+ }
+ for r, gold, prompt in zip(repeat_results, repeat_gold, repeat_pr)
+ if not r.extraction_success
+ ]
+ if failed_items:
+ judge_results = judge_module.batch_judge(judge_client, failed_items)
+ judge_verdicts_full: List[bool] = []
+ ji = 0
+ for r in repeat_results:
+ if not r.extraction_success:
+ judge_verdicts_full.append(judge_results[ji])
+ ji += 1
+ else:
+ judge_verdicts_full.append(False)
+ judge_verdicts = judge_verdicts_full
+
+ corrected = runner.build_corrected_predictions(
+ predictions, repeat_results, judge_results, repeat_gold,
+ )
+ metrics_j = compute_metrics(corrected, rows)
+ all_metrics_with_judge.append(metrics_j)
+ print(
+ f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, "
+ f"Recovered={metrics_j['correct'] - metrics['correct']}"
+ )
+ else:
+ all_metrics_with_judge.append(metrics)
+
+ # --- Bad case 收集 ---
+ if badcase_enabled:
+ bcs = runner.collect_badcases(
+ repeat_results, predictions, repeat_gold, repeat_pr,
+ dataset_config["dataset"], _is_correct,
+ repeat_idx=i, judge_verdicts=judge_verdicts,
+ )
+ all_badcases.extend(bcs)
+
runner.save_common_results(
dataset_name=dataset_config["dataset"],
model=experiment_config["llm_config"]["model_name"],
@@ -162,6 +230,10 @@ def main() -> None:
gold_answers=all_gold,
all_metrics=all_metrics,
results_path=experiment_config["results_path"],
+ dataset_config=dataset_config,
+ experiment_config=experiment_config,
+ badcases=all_badcases if badcase_enabled else None,
+ all_metrics_with_judge=all_metrics_with_judge if enable_judge else None,
)
runner.print_summary_stats(all_metrics, repeats, n)