diff --git a/.gitignore b/.gitignore index 67404c9..b40fcca 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ build/* **/cache/** *.pdf *.csv -site/* \ No newline at end of file +site/* +tool_cache/* \ No newline at end of file diff --git a/extract_ths_tool_cache/compress.sh b/extract_ths_tool_cache/compress.sh new file mode 100644 index 0000000..537b3b9 --- /dev/null +++ b/extract_ths_tool_cache/compress.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# 定义目标文件夹和输出文件名前缀 +TARGET_DIR="tool_cache" +OUTPUT_PREFIX="extract_ths_tool_cache/tool_cache.tar.gz.part_" +SIZE_LIMIT="50m" + +echo "正在压缩并切分 $TARGET_DIR ..." + +# tar -c: 创建 +# -z: gzip压缩 +# -f -: 将结果输出到标准输出流 +# split -b: 按大小切分 +# -: 从标准输入流读取 +tar -czf - "$TARGET_DIR" --exclude='.DS_Store' | split -b $SIZE_LIMIT - "$OUTPUT_PREFIX" + +echo "压缩完成!生成的切分文件如下:" +ls -lh ${OUTPUT_PREFIX}* \ No newline at end of file diff --git a/extract_ths_tool_cache/extract.sh b/extract_ths_tool_cache/extract.sh new file mode 100644 index 0000000..cad8de5 --- /dev/null +++ b/extract_ths_tool_cache/extract.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# 定义切分文件的前缀 +TARGET_DIR="tool_cache" +INPUT_PREFIX="extract_ths_tool_cache/tool_cache.tar.gz.part_" + +# 检查是否有分卷文件 +if ! ls ${INPUT_PREFIX}* 1> /dev/null 2>&1; then + echo "错误:未发现分卷文件 ${INPUT_PREFIX}*" + exit 1 +fi + +echo "正在合并并解压文件..." + +# cat: 合并所有分卷 +# tar -x: 解压 +cat ${INPUT_PREFIX}* | tar -xzf - + +echo "解压完成!文件夹 '$TARGET_DIR' 已还原。" \ No newline at end of file diff --git a/extract_ths_tool_cache/merge.py b/extract_ths_tool_cache/merge.py new file mode 100644 index 0000000..cb5871a --- /dev/null +++ b/extract_ths_tool_cache/merge.py @@ -0,0 +1,103 @@ +import os +import json +import shutil + +# 配置路径(Python 字符串不需要反斜杠转义空格) +SRC_ROOT_BASE = "/Users/tsc/研究工作/金融ASIO/代码/finance-mcp_3" +DST_ROOT_BASE = "/Users/tsc/研究工作/金融ASIO/代码/finance-mcp" + +# 需要合并的根文件夹 +TARGET_FOLDERS = ["tool_cache", "cache"] + +def load_json_data(file_path): + """加载 JSON 数据,支持列表或单个对象""" + try: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + return data if isinstance(data, list) else [data] + except Exception as e: + print(f" [跳过] 无法解析 JSON: {file_path}. 错误: {e}") + return None + +def save_json_data(file_path, data): + """保存 JSON 数据""" + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + +def get_code(sample): + """根据你提供的结构提取 code""" + try: + # 路径: sample -> tool_args -> code + return sample.get("tool_args", {}).get("code") + except: + return None + +def merge_json_files(src_file, dst_file): + """核心逻辑:合并两个文件中的 samples,按 code 去重""" + src_data = load_json_data(src_file) + dst_data = load_json_data(dst_file) + + if src_data is None or dst_data is None: + return False # 读取出错,不执行合并 + + # 获取目标文件中已有的所有 code + existing_codes = {get_code(s) for s in dst_data if get_code(s) is not None} + + initial_count = len(dst_data) + for sample in src_data: + code = get_code(sample) + if code and code not in existing_codes: + dst_data.append(sample) + existing_codes.add(code) + + new_added = len(dst_data) - initial_count + if new_added > 0: + save_json_data(dst_file, dst_data) + print(f" [合并完成] {os.path.basename(dst_file)}: 新增了 {new_added} 条 code 样本") + else: + print(f" [无需合并] {os.path.basename(dst_file)}: 未发现新 code") + return True + +def start_recursive_merge(): + for target in TARGET_FOLDERS: + src_folder = os.path.join(SRC_ROOT_BASE, target) + dst_folder = os.path.join(DST_ROOT_BASE, target) + + if not os.path.exists(src_folder): + print(f"源文件夹不存在,跳过: {src_folder}") + continue + + print(f"\n>>> 正在扫描目录: {target}") + + # 使用 os.walk 递归遍历所有子文件夹 + for root, dirs, files in os.walk(src_folder): + # 计算当前子目录相对于源根目录的路径 + rel_path = os.path.relpath(root, src_folder) + # 对应的目标子目录路径 + target_dst_dir = os.path.join(dst_folder, rel_path) + + # 1. 如果目标子目录不存在,直接创建 + if not os.path.exists(target_dst_dir): + os.makedirs(target_dst_dir) + print(f"创建新目录: {target_dst_dir}") + + # 2. 处理当前目录下的所有文件 + for filename in files: + # 隐藏文件跳过 (如 .DS_Store) + if filename.startswith('.'): continue + + src_file_path = os.path.join(root, filename) + dst_file_path = os.path.join(target_dst_dir, filename) + + if not os.path.exists(dst_file_path): + # 如果目标位置没有这个文件,直接整体复制 + shutil.copy2(src_file_path, dst_file_path) + print(f" [新文件] 已复制: {rel_path}/{filename}") + else: + # 如果目标位置有重名文件,执行深度合并 + merge_json_files(src_file_path, dst_file_path) + +if __name__ == "__main__": + print("开始深度合并任务...") + start_recursive_merge() + print("\n所有任务已结束。") \ No newline at end of file diff --git a/extract_ths_tool_cache/tool_cache.tar.gz.part_aa b/extract_ths_tool_cache/tool_cache.tar.gz.part_aa new file mode 100644 index 0000000..492e9b7 Binary files /dev/null and b/extract_ths_tool_cache/tool_cache.tar.gz.part_aa differ diff --git a/extract_ths_tool_cache/tool_cache.tar.gz.part_ab b/extract_ths_tool_cache/tool_cache.tar.gz.part_ab new file mode 100644 index 0000000..4afbc88 Binary files /dev/null and b/extract_ths_tool_cache/tool_cache.tar.gz.part_ab differ diff --git a/extract_ths_tool_cache/tool_cache.tar.gz.part_ac b/extract_ths_tool_cache/tool_cache.tar.gz.part_ac new file mode 100644 index 0000000..77287ff Binary files /dev/null and b/extract_ths_tool_cache/tool_cache.tar.gz.part_ac differ diff --git a/finance_mcp/config/ths_local.yaml b/finance_mcp/config/ths_local.yaml new file mode 100644 index 0000000..c5c779a --- /dev/null +++ b/finance_mcp/config/ths_local.yaml @@ -0,0 +1,156 @@ +flow: + crawl_ths_company: + flow_content: | + ReadLocalThsOp(tag="company") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取公司资料信息,例如:详细情况,高管介绍,发行相关,参控股公司,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_holder: + flow_content: | + ReadLocalThsOp(tag="holder") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取股东研究信息,例如:股东人数、十大流通股东、十大股东、十大债券持有人、控股层级关系,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_operate: + flow_content: | + ReadLocalThsOp(tag="operate") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取经营分析信息,例如:主营介绍、运营业务数据、主营构成分析、主要客户及供应商、董事会经营评述、产品价格,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_equity: + flow_content: | + ReadLocalThsOp(tag="equity") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取股本结构信息,例如:解禁时间表、总股本构成、A股结构图、历次股本变动,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_capital: + flow_content: | + ReadLocalThsOp(tag="capital") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取资本运作信息,例如:募集资金来源、项目投资、收购兼并、股权投资、参股IPO、股权转让、关联交易、质押解冻,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_worth: + flow_content: | + ReadLocalThsOp(tag="worth") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取盈利预测信息,例如:业绩预测、业绩预测详表、研报评级,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_news: + flow_content: | + ReadLocalThsOp(tag="news") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取新闻公告信息,例如:新闻与股价联动、公告列表、热点新闻列表、研报列表,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_concept: + flow_content: | + ReadLocalThsOp(tag="concept") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取概念题材信息,例如:常规概念、其他概念、题材要点、概念对比,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_position: + flow_content: | + ReadLocalThsOp(tag="position") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取主力持仓信息,例如:机构持股汇总、机构持股明细、被举牌情况、IPO获配机构,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_finance: + flow_content: | + ReadLocalThsOp(tag="finance") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取财务分析信息,例如:财务诊断、财务指标、指标变动说明、资产负债构成、财务报告、杜邦分析,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_bonus: + flow_content: | + ReadLocalThsOp(tag="bonus") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取分红融资信息,例如:分红诊断、分红情况、增发机构获配明细、增发概况、配股概况,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_event: + flow_content: | + ReadLocalThsOp(tag="event") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取公司大事信息,例如:高管持股变动、股东持股变动、担保明细、违规处理、机构调研、投资者互动,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true + + crawl_ths_field: + flow_content: | + ReadLocalThsOp(tag="field") + enable_cache: false + cache_expire_hours: 1 + description: "通过A股股票代码获取行业对比信息,例如:行业地位、行业新闻,最后返回和query相关的信息。" + input_schema: + code: + type: string + description: "stock code" + required: true diff --git a/finance_mcp/core/agent/conduct_research_op.py b/finance_mcp/core/agent/conduct_research_op.py index ba4d40a..ecf2899 100644 --- a/finance_mcp/core/agent/conduct_research_op.py +++ b/finance_mcp/core/agent/conduct_research_op.py @@ -74,7 +74,7 @@ def build_tool_call(self) -> ToolCall: ) async def async_execute(self): - """Run the multi-step research loop and produce a final answer. + """Run the multistep research loop and produce a final answer. The method performs the following high-level steps: diff --git a/finance_mcp/core/crawl/__init__.py b/finance_mcp/core/crawl/__init__.py index cb728ef..7e4c595 100644 --- a/finance_mcp/core/crawl/__init__.py +++ b/finance_mcp/core/crawl/__init__.py @@ -13,10 +13,12 @@ """ from .crawl4ai_op import Crawl4aiOp, Crawl4aiLongTextOp +from .read_local_ths_op import ReadLocalThsOp from .ths_url_op import ThsUrlOp __all__ = [ "Crawl4aiOp", "Crawl4aiLongTextOp", "ThsUrlOp", + "ReadLocalThsOp", ] diff --git a/finance_mcp/core/crawl/read_local_ths_op.py b/finance_mcp/core/crawl/read_local_ths_op.py new file mode 100644 index 0000000..338dd69 --- /dev/null +++ b/finance_mcp/core/crawl/read_local_ths_op.py @@ -0,0 +1,64 @@ +import json +from pathlib import Path +from typing import Dict + +from flowllm.core.context import C +from flowllm.core.op import BaseAsyncOp +from loguru import logger + + +@C.register_op() +class ReadLocalThsOp(BaseAsyncOp): + # Class-level cache: {tag: {code: tool_result}} + _cache: Dict[str, Dict[str, str]] = None + + def __init__(self, tag: str = "", **kwargs): + super().__init__(**kwargs) + self.tag: str = tag + # Initialize class-level cache if not exists + if ReadLocalThsOp._cache is None: + ReadLocalThsOp._cache = {} + + def _load_cache(self) -> Dict[str, str]: + """Load all crawl_ths_{tag}*.json files and build code->tool_result mapping.""" + tool_cache_dir = Path("tool_cache") + pattern = f"crawl_ths_{self.tag}*.json" + matching_files = list(tool_cache_dir.glob(pattern)) + + total_files = len(matching_files) + logger.info(f"Found {total_files} files matching pattern '{pattern}'") + + result_dict = {} + for idx, file_path in enumerate(matching_files, 1): + logger.info(f"Loading file [{idx}/{total_files}]: {file_path.name}") + + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + file_records = 0 + for item in data: + code = item['tool_args']['code'] + result_dict[code] = item['tool_result'] + file_records += 1 + + logger.info(f" → Processed {file_records} records from {file_path.name}") + + logger.info(f"✓ Successfully loaded {len(result_dict)} total records for tag={self.tag}") + return result_dict + + async def async_execute(self): + """Read tool_result for self.context.code from cached data.""" + # Load cache if not exists + if self.tag not in ReadLocalThsOp._cache: + ReadLocalThsOp._cache[self.tag] = self._load_cache() + + # Get code from context + code = self.context.code + if not code: + self.context.response.answer = f"No code={code} found in context." + logger.info(self.context.response.answer) + return + + # Get tool_result from cache + tool_result = ReadLocalThsOp._cache[self.tag].get(code) + self.context.response.answer = tool_result diff --git a/test_op/clean_empty_results.py b/test_op/clean_empty_results.py new file mode 100644 index 0000000..3c85d88 --- /dev/null +++ b/test_op/clean_empty_results.py @@ -0,0 +1,167 @@ +"""清理无效爬取结果脚本 + +该脚本用于清理 tool_cache 目录中 tool_result 为 "No relevant content found matching the query." 的条目, +并同步删除 progress 目录中相应的进度记录。 +""" + +import json +import os +from loguru import logger + +# 配置 +BASE_CACHE_DIR = "tool_cache" +PROGRESS_DIR = os.path.join(BASE_CACHE_DIR, "progress") +INVALID_RESULT = "No relevant content found matching the query." + + +def clean_cache_files(): + """清理 tool_cache 目录中的无效条目""" + + # 统计信息 + total_removed = 0 + removed_codes_by_tool = {} # {tool_name: set(codes)} + + # 遍历 tool_cache 目录下的所有 JSON 文件(不包括 progress 子目录) + for filename in os.listdir(BASE_CACHE_DIR): + file_path = os.path.join(BASE_CACHE_DIR, filename) + + # 跳过目录和非 JSON 文件 + if os.path.isdir(file_path) or not filename.endswith(".json"): + continue + + logger.info(f"处理文件: {filename}") + + try: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception as e: + logger.error(f"读取文件失败 {filename}: {e}") + continue + + if not isinstance(data, list): + logger.warning(f"文件格式不正确,跳过: {filename}") + continue + + # 过滤无效条目 + original_count = len(data) + valid_records = [] + removed_codes = set() + + for record in data: + tool_result = record.get("tool_result", "") + if tool_result == INVALID_RESULT: + # 记录被移除的 code + tool_args = record.get("tool_args", {}) + code = tool_args.get("code", "") + tool_name = record.get("tool_name", "") + if code: + removed_codes.add(code) + if tool_name not in removed_codes_by_tool: + removed_codes_by_tool[tool_name] = set() + removed_codes_by_tool[tool_name].add(code) + logger.debug(f" 移除无效条目: tool={tool_name}, code={code}") + else: + valid_records.append(record) + + removed_count = original_count - len(valid_records) + total_removed += removed_count + + if removed_count > 0: + # 保存清理后的文件 + with open(file_path, "w", encoding="utf-8") as f: + json.dump(valid_records, f, ensure_ascii=False, indent=2) + logger.info(f" 清理完成: 移除 {removed_count} 条,剩余 {len(valid_records)} 条") + else: + logger.info(f" 无需清理: 共 {original_count} 条,均有效") + + logger.info(f"\n{'='*60}") + logger.info(f"缓存文件清理完成,共移除 {total_removed} 条无效记录") + + return removed_codes_by_tool + + +def clean_progress_files(removed_codes_by_tool: dict): + """清理 progress 目录中的相应条目""" + + if not os.path.exists(PROGRESS_DIR): + logger.warning(f"进度目录不存在: {PROGRESS_DIR}") + return + + total_removed = 0 + + for tool_name, codes_to_remove in removed_codes_by_tool.items(): + progress_file = os.path.join(PROGRESS_DIR, f"{tool_name}_progress.json") + + if not os.path.exists(progress_file): + logger.warning(f"进度文件不存在: {progress_file}") + continue + + logger.info(f"处理进度文件: {tool_name}_progress.json") + + try: + with open(progress_file, "r", encoding="utf-8") as f: + progress_data = json.load(f) + except Exception as e: + logger.error(f"读取进度文件失败 {progress_file}: {e}") + continue + + completed_codes = set(progress_data.get("completed_codes", [])) + time_records = progress_data.get("time_records", {}) + + original_count = len(completed_codes) + + # 移除无效的 codes + for code in codes_to_remove: + if code in completed_codes: + completed_codes.remove(code) + if code in time_records: + del time_records[code] + + removed_count = original_count - len(completed_codes) + total_removed += removed_count + + if removed_count > 0: + # 保存更新后的进度文件 + progress_data["completed_codes"] = list(completed_codes) + progress_data["time_records"] = time_records + + with open(progress_file, "w", encoding="utf-8") as f: + json.dump(progress_data, f, ensure_ascii=False, indent=2) + logger.info(f" 清理完成: 移除 {removed_count} 条,剩余 {len(completed_codes)} 条") + else: + logger.info(f" 无需清理") + + logger.info(f"\n{'='*60}") + logger.info(f"进度文件清理完成,共移除 {total_removed} 条记录") + + +def main(): + """主函数""" + logger.info("="*60) + logger.info("开始清理无效爬取结果...") + logger.info(f"缓存目录: {BASE_CACHE_DIR}") + logger.info(f"进度目录: {PROGRESS_DIR}") + logger.info(f"无效结果标识: {INVALID_RESULT}") + logger.info("="*60 + "\n") + + # 第一步:清理缓存文件 + logger.info("【第一步】清理缓存文件中的无效条目...") + removed_codes_by_tool = clean_cache_files() + + # 第二步:清理进度文件 + logger.info("\n【第二步】清理进度文件中的相应条目...") + clean_progress_files(removed_codes_by_tool) + + logger.info("\n" + "="*60) + logger.info("✓ 所有清理任务完成!") + logger.info("="*60) + + # 打印汇总 + if removed_codes_by_tool: + logger.info("\n清理汇总:") + for tool_name, codes in removed_codes_by_tool.items(): + logger.info(f" {tool_name}: 移除 {len(codes)} 个股票代码") + + +if __name__ == "__main__": + main() diff --git a/test_op/test_crawl4ai_op.py b/test_op/test_crawl4ai_op.py index ae59dfd..7b6b8b1 100644 --- a/test_op/test_crawl4ai_op.py +++ b/test_op/test_crawl4ai_op.py @@ -6,6 +6,7 @@ """ import asyncio +import time from finance_mcp import FinanceMcpApp from finance_mcp.core.crawl import Crawl4aiOp @@ -13,13 +14,15 @@ async def main() -> None: """Execute the crawl operation for a sample stock information page.""" - + t1 = time.time() async with FinanceMcpApp(): # Instantiate and run the crawling operator against a THS stock page. op = Crawl4aiOp() - await op.async_call(url="https://stockpage.10jqka.com.cn/601899/") + # await op.async_call(url="https://stockpage.10jqka.com.cn/601899/") + await op.async_call(url="https://basic.10jqka.com.cn/601899/equity.html#stockpage") print(op.output) + print(f"Total time: {time.time() - t1:.2f}s") if __name__ == "__main__": asyncio.run(main()) diff --git a/test_op/test_project_http.py b/test_op/test_project_http.py index a2d7bb1..8512813 100644 --- a/test_op/test_project_http.py +++ b/test_op/test_project_http.py @@ -16,7 +16,9 @@ # Service configuration service_args = [ "finance-mcp", - "config=default,stream_agent", + "backend=http", + # "config=default,stream_agent", + "config=default,ths_local", "llm.default.model_name=qwen3-30b-a3b-thinking-2507", ] @@ -113,7 +115,11 @@ def main() -> None: for endpoint, data in [ # ("conduct_research", {"research_topic": "茅台怎么样?"}), # ("dashscope_deep_research", {"query": "茅台怎么样?"}), - ("langchain_deep_research", {"query": "茅台怎么样?"}), + # ("langchain_deep_research", {"query": "茅台怎么样?"}), + ("crawl_ths_company", {"code": "000001"}), + ("crawl_ths_company", {"code": "000002"}), + ("crawl_ths_company", {"code": "000004"}), + ("crawl_ths_holder", {"code": "000004"}), ]: test_http_service( endpoint=endpoint, diff --git a/test_op/test_ths_sse.py b/test_op/test_ths_sse.py new file mode 100644 index 0000000..137b9ec --- /dev/null +++ b/test_op/test_ths_sse.py @@ -0,0 +1,236 @@ +import asyncio +import json +import os +import random +import uuid + +# 禁用代理,避免 httpx 读取系统代理配置导致连接失败 +os.environ.setdefault("NO_PROXY", "*") +import pandas as pd +from datetime import datetime +from fastmcp.client.client import CallToolResult +from loguru import logger + +from finance_mcp.core.utils.fastmcp_client import FastMcpClient + +# --- 配置区 --- +HOST = "127.0.0.1" # 使用 IPv4 地址,避免 IPv6 连接问题 +PORT = 8050 +CSV_PATH = "tushare_stock_basic_20251226104714.csv" +BASE_CACHE_DIR = "tool_cache" +PROGRESS_DIR = os.path.join(BASE_CACHE_DIR, "progress") +MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB +SAVE_BATCH_SIZE = 1 # 每1个保存一次 +MAX_CONCURRENCY = 5 # 最大并发数(信号量控制,设为1即串行) +MIN_WAIT_ON_EMPTY = 60 # 无输出时最小等待秒数 +MAX_WAIT_ON_EMPTY = 120 # 无输出时最大等待秒数 +NORMAL_WAIT_SECONDS = 2 # 正常请求间隔秒数 + +# 针对每个页面结构设计的全量提取 Query +TOOLS_CONFIG = [ + # ("crawl_ths_company", "提取公司的完整资料:1.基本信息(行业、产品、主营、办公地址);2.高管介绍(所有高管的姓名、职务、薪资、详细个人简历);3.发行相关(上市日期、首日表现、募资额);4.所有参控股公司的名称、持股比例、业务、盈亏情况。"), + # ("crawl_ths_holder", "提取股东研究全量数据:1.历年股东人数及户均持股数;2.前十大股东及流通股东名单(含持股数、性质、变动情况);3.实际控制人详情及控股层级关系描述;4.股权质押、冻结的详细明细表。"), + # ("crawl_ths_operate", "提取经营分析数据:1.主营构成分析表(按行业、产品、区域划分的营业收入、利润、毛利率及同比变化);2.经营评述(公司对业务、核心竞争力的详细自我评估)。"), + # ("crawl_ths_equity", "提取股本结构信息:1.历次股本变动原因、日期及变动后的总股本;2.限售股份解禁的时间表、解禁数量及占总股本比例。"), + # ("crawl_ths_capital", "提取资本运作详情:1.资产重组、收购、合并的详细历史记录;2.对外投资明细及进展情况。"), + # ("crawl_ths_worth", "提取盈利预测信息:1.各机构最新评级汇总(买入/增持次数);2.未来三年的营收预测、净利润预测及EPS预测均值。"), + # ("crawl_ths_news", "提取最新新闻公告:1.公司最新重要公告标题及日期;2.媒体报道的新闻摘要及舆情评价。"), + # ("crawl_ths_concept", "提取所有概念题材:列出公司所属的所有概念板块,并详细提取每个概念对应的具体入选理由和业务关联性。"), + # ("crawl_ths_position", "提取主力持仓情况:1.各类机构(基金、保险、QFII等)持仓总数及占比;2.前十大具体机构持仓名单及变动。"), + # ("crawl_ths_finance", "提取财务分析详情:1.主要财务指标(盈利、成长、偿债等);2.资产负债表、利润表、现金流量表的核心科目及审计意见。"), + # ("crawl_ths_bonus", "提取分红融资记录:1.历年现金分红、送转股份方案及实施日期;2.历次增发、配股等融资详情。"), + # ("crawl_ths_event", "提取公司大事记录:1.股东及高管持股变动明细;2.对外担保记录、违规处理、机构调研及投资者互动记录。"), + ("crawl_ths_field", "提取行业对比数据:1.公司在所属行业内的规模、成长、盈利各项排名;2.与行业均值及同类竞品的关键财务指标对比。") +] + +if not os.path.exists(BASE_CACHE_DIR): + os.makedirs(BASE_CACHE_DIR, exist_ok=True) +os.makedirs(PROGRESS_DIR, exist_ok=True) + +class BatchResultSaver: + def __init__(self, tool_name): + self.tool_name = tool_name + self.buffer = [] + self.file_index = 1 + + def _get_file_path(self): + return os.path.join(BASE_CACHE_DIR, f"{self.tool_name}_{self.file_index:02d}.json") + + def add_record(self, tool_args, result_text): + now = datetime.now() + record = { + "_id": str(uuid.uuid4()), + "cache_key": f"{self.tool_name}::{json.dumps(tool_args, ensure_ascii=False)}", + "created_at": now.isoformat(), + "metadata": {"task_id": "comprehensive_crawl", "timestamp": now.isoformat()}, + "tool_args": tool_args, + "tool_name": self.tool_name, + "tool_result": result_text, + "updated_at": now.isoformat() + } + self.buffer.append(record) + if len(self.buffer) >= SAVE_BATCH_SIZE: + self.flush() + + def flush(self): + if not self.buffer: return + file_path = self._get_file_path() + + # 自动切分 50MB + if os.path.exists(file_path) and os.path.getsize(file_path) > MAX_FILE_SIZE: + self.file_index += 1 + file_path = self._get_file_path() + + data = [] + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + try: data = json.load(f) + except: data = [] + + data.extend(self.buffer) + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + logger.info(f"Save: {file_path} | Count: {len(data)}") + self.buffer = [] + +class ProgressTracker: + """进度跟踪器,支持断点继续""" + def __init__(self, tool_name): + self.tool_name = tool_name + self.progress_file = os.path.join(PROGRESS_DIR, f"{tool_name}_progress.json") + self.completed_codes = set() + self.time_records = {} # 记录每个code的处理时间 + self.load_progress() + + def load_progress(self): + if os.path.exists(self.progress_file): + try: + with open(self.progress_file, "r", encoding="utf-8") as f: + data = json.load(f) + self.completed_codes = set(data.get("completed_codes", [])) + self.time_records = data.get("time_records", {}) + except Exception as e: + logger.warning(f"加载进度文件失败 {self.progress_file}: {e}") + self.completed_codes = set() + self.time_records = {} + + def save_progress(self): + try: + data = { + "completed_codes": list(self.completed_codes), + "time_records": self.time_records + } + with open(self.progress_file, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"保存进度文件失败 {self.progress_file}: {e}") + + def mark_completed(self, code, elapsed_time=None): + self.completed_codes.add(code) + if elapsed_time is not None: + self.time_records[code] = { + "elapsed_seconds": round(elapsed_time, 2), + "completed_at": datetime.now().isoformat() + } + + def is_completed(self, code): + return code in self.completed_codes + + def get_remaining_codes(self, all_codes): + return [code for code in all_codes if not self.is_completed(code)] + +async def process_single_stock(client, tool_name, code, deep_query, saver, progress_tracker, semaphore, index, total): + """处理单个股票的异步任务""" + async with semaphore: # 信号量控制并发数 + args = {"code": code, "query": deep_query} + start_time = datetime.now() # 记录开始时间 + try: + logger.info(f"Calling {tool_name} for {code} ({index}/{total})") + result: CallToolResult = await client.call_tool(tool_name, args) + elapsed_time = (datetime.now() - start_time).total_seconds() # 计算耗时 + + if not result.is_error: + content = result.content[0].text if result.content else "" + + # 检查内容是否包含不当内容警告,直接跳过不保存 + if "Input data may contain inappropriate content" in content: + logger.warning(f"工具 {tool_name} 处理 {code} 返回不当内容警告, 跳过该记录") + return + + # 检查内容是否为空或太短,如果是则等待50-100秒 + if not content or len(content.strip()) < 10: + wait_seconds = random.randint(MIN_WAIT_ON_EMPTY, MAX_WAIT_ON_EMPTY) + logger.warning( + f"工具 {tool_name} 处理 {code} 无输出或输出过短, " + f"耗时 {elapsed_time:.2f}秒, 等待 {wait_seconds} 秒..." + ) + await asyncio.sleep(wait_seconds) + else: + # 有正常输出,保存记录 + saver.add_record(args, content) + # 标记该代码已完成,记录耗时 + progress_tracker.mark_completed(code, elapsed_time) + # 定期保存进度 + if index % SAVE_BATCH_SIZE == 0: + progress_tracker.save_progress() + # 正常等待间隔 + logger.info(f"成功处理 {code}, 耗时 {elapsed_time:.2f}秒, 等待 {NORMAL_WAIT_SECONDS} 秒...") + await asyncio.sleep(NORMAL_WAIT_SECONDS) + else: + error_msg = result.content[0].text[:100] if result.content else "No error message" + logger.error(f"Error {code}: {error_msg}, 耗时 {elapsed_time:.2f}秒") + except Exception as e: + import traceback + elapsed_time = (datetime.now() - start_time).total_seconds() + logger.error(f"Failed {code}: {e}, 耗时 {elapsed_time:.2f}秒") + logger.error(f"Traceback: {traceback.format_exc()}") + +async def test_mcp_service(): + df = pd.read_csv(CSV_PATH, dtype={'symbol': str}) + stock_codes = df['symbol'].dropna().apply(lambda x: str(x).zfill(6)).tolist() + + mcp_config = {"type": "sse", "url": f"http://{HOST}:{PORT}/sse"} + + # 信号量控制并发数 + semaphore = asyncio.Semaphore(MAX_CONCURRENCY) + + async with FastMcpClient(name="full-info-crawler", config=mcp_config) as client: + # 外层循环:工具 (及设计好的深度 Query) + for tool_name, deep_query in TOOLS_CONFIG: + logger.info(f"### 开始爬取工具: {tool_name}") + saver = BatchResultSaver(tool_name) + progress_tracker = ProgressTracker(tool_name) + + # 获取剩余需要处理的股票代码 + remaining_codes = progress_tracker.get_remaining_codes(stock_codes) + total_remaining = len(remaining_codes) + logger.info(f"### 工具 {tool_name} 剩余 {total_remaining} 个股票代码待处理") + + # 创建所有并发任务 + tasks = [] + for i, code in enumerate(remaining_codes, start=1): + task = process_single_stock( + client=client, + tool_name=tool_name, + code=code, + deep_query=deep_query, + saver=saver, + progress_tracker=progress_tracker, + semaphore=semaphore, + index=i, + total=total_remaining + ) + tasks.append(task) + + # 并发执行所有任务,信号量控制最多10个同时运行 + logger.info(f"启动 {len(tasks)} 个并发任务,信号量限制为 {MAX_CONCURRENCY}") + await asyncio.gather(*tasks) + + # 保存最后的进度 + progress_tracker.save_progress() + saver.flush() # 一个工具所有代码跑完,冲刷最后的数据 + logger.info(f"### 工具 {tool_name} 完成!") + +if __name__ == "__main__": + asyncio.run(test_mcp_service()) \ No newline at end of file diff --git a/test_op/ths_sse.py b/test_op/ths_sse.py new file mode 100644 index 0000000..a047d92 --- /dev/null +++ b/test_op/ths_sse.py @@ -0,0 +1,412 @@ +"""同花顺数据全量爬取脚本(集成服务启动) + +该脚本整合了服务启动和数据爬取功能,无需手动在两个终端中分别启动服务和脚本。 +使用 FinanceMcpServiceRunner 自动启动和管理 finance-mcp 服务,并在同一进程中执行批量爬取任务。 + +功能特性: +1. 自动启动 finance-mcp 服务(SSE模式) +2. 支持断点续传(基于进度文件) +3. 并发控制(信号量) +4. 批量保存结果到 JSON 文件 +5. 自动文件切分(50MB) +6. 智能等待策略(无输出时长等待,正常时短等待) +""" + +import asyncio +import json +import os +import random +import uuid +import pandas as pd +from datetime import datetime +from fastmcp.client.client import CallToolResult +from loguru import logger + +from finance_mcp.core.utils.fastmcp_client import FastMcpClient +from finance_mcp.core.utils.service_runner import FinanceMcpServiceRunner + +# --- 服务配置 --- +SERVICE_ARGS = [ + "finance-mcp", + "config=default,ths", + 'disabled_flows=["tavily_search","mock_search","react_agent"]', + "mcp.transport=sse", +] +HOST = "localhost" +PORT = 8050 +os.environ.setdefault("NO_PROXY", "*") +# --- 数据配置 --- +CSV_PATH = "tushare_stock_basic_20251226104714.csv" +BASE_CACHE_DIR = "tool_cache" +PROGRESS_DIR = os.path.join(BASE_CACHE_DIR, "progress") +MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB +SAVE_BATCH_SIZE = 1 # 每1个保存一次 +MAX_CONCURRENCY = 5 # 最大并发数(信号量控制) +MIN_WAIT_ON_EMPTY = 60 # 无输出时最小等待秒数 +MAX_WAIT_ON_EMPTY = 120 # 无输出时最大等待秒数 +NORMAL_WAIT_SECONDS = 3 # 正常请求间隔秒数 + +# 无效结果标识(这些结果需要重新爬取) +INVALID_RESULTS = [ + "No relevant content found matching the query.", + "未找到与查询匹配的相关内容", +] + +# 针对每个页面结构设计的全量提取 Query +TOOLS_CONFIG = [ + ("crawl_ths_company", "提取公司的完整资料:1.基本信息(行业、产品、主营、办公地址);2.高管介绍(所有高管的姓名、职务、薪资、详细个人简历);3.发行相关(上市日期、首日表现、募资额);4.所有参控股公司的名称、持股比例、业务、盈亏情况。"), + ("crawl_ths_holder", "提取股东研究全量数据:1.历年股东人数及户均持股数;2.前十大股东及流通股东名单(含持股数、性质、变动情况);3.实际控制人详情及控股层级关系描述;4.股权质押、冻结的详细明细表。"), + ("crawl_ths_operate", "提取经营分析数据:1.主营构成分析表(按行业、产品、区域划分的营业收入、利润、毛利率及同比变化);2.经营评述(公司对业务、核心竞争力的详细自我评估)。"), + ("crawl_ths_equity", "提取股本结构信息:1.历次股本变动原因、日期及变动后的总股本;2.限售股份解禁的时间表、解禁数量及占总股本比例。"), + ("crawl_ths_capital", "提取资本运作详情:1.资产重组、收购、合并的详细历史记录;2.对外投资明细及进展情况。"), + ("crawl_ths_worth", "提取盈利预测信息:1.各机构最新评级汇总(买入/增持次数);2.未来三年的营收预测、净利润预测及EPS预测均值。"), + ("crawl_ths_news", "提取最新新闻公告:1.公司最新重要公告标题及日期;2.媒体报道的新闻摘要及舆情评价。"), + ("crawl_ths_concept", "提取所有概念题材:列出公司所属的所有概念板块,并详细提取每个概念对应的具体入选理由和业务关联性。"), + ("crawl_ths_position", "提取主力持仓情况:1.各类机构(基金、保险、QFII等)持仓总数及占比;2.前十大具体机构持仓名单及变动。"), + ("crawl_ths_finance", "提取财务分析详情:1.主要财务指标(盈利、成长、偿债等);2.资产负债表、利润表、现金流量表的核心科目及审计意见。"), + ("crawl_ths_bonus", "提取分红融资记录:1.历年现金分红、送转股份方案及实施日期;2.历次增发、配股等融资详情。"), + ("crawl_ths_event", "提取公司大事记录:1.股东及高管持股变动明细;2.对外担保记录、违规处理、机构调研及投资者互动记录。"), + ("crawl_ths_field", "提取行业对比数据:1.公司在所属行业内的规模、成长、盈利各项排名;2.与行业均值及同类竞品的关键财务指标对比。") +] + +# 创建必要的目录 +if not os.path.exists(BASE_CACHE_DIR): + os.makedirs(BASE_CACHE_DIR, exist_ok=True) +os.makedirs(PROGRESS_DIR, exist_ok=True) + + +class BatchResultSaver: + """批量结果保存器,支持自动文件切分""" + + def __init__(self, tool_name: str): + self.tool_name = tool_name + self.buffer = [] + self.file_index = 1 + + def _get_file_path(self) -> str: + """获取当前保存文件路径""" + return os.path.join(BASE_CACHE_DIR, f"{self.tool_name}_{self.file_index:02d}.json") + + def add_record(self, tool_args: dict, result_text: str): + """添加一条记录到缓冲区""" + now = datetime.now() + record = { + "_id": str(uuid.uuid4()), + "cache_key": f"{self.tool_name}::{json.dumps(tool_args, ensure_ascii=False)}", + "created_at": now.isoformat(), + "metadata": {"task_id": "comprehensive_crawl", "timestamp": now.isoformat()}, + "tool_args": tool_args, + "tool_name": self.tool_name, + "tool_result": result_text, + "updated_at": now.isoformat() + } + self.buffer.append(record) + if len(self.buffer) >= SAVE_BATCH_SIZE: + self.flush() + + def flush(self): + """将缓冲区数据写入文件""" + if not self.buffer: + return + + file_path = self._get_file_path() + + # 自动切分:如果文件超过 50MB,切换到新文件 + if os.path.exists(file_path) and os.path.getsize(file_path) > MAX_FILE_SIZE: + self.file_index += 1 + file_path = self._get_file_path() + + # 读取现有数据 + data = [] + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + try: + data = json.load(f) + except: + data = [] + + # 追加新数据并保存 + data.extend(self.buffer) + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + logger.info(f"保存文件: {file_path} | 总记录数: {len(data)}") + self.buffer = [] + + +class ProgressTracker: + """进度跟踪器,支持断点继续""" + + def __init__(self, tool_name: str): + self.tool_name = tool_name + self.progress_file = os.path.join(PROGRESS_DIR, f"{tool_name}_progress.json") + self.completed_codes = set() + self.time_records = {} # 记录每个code的处理时间 + self.load_progress() + + def load_progress(self): + """加载已有的进度数据""" + if os.path.exists(self.progress_file): + try: + with open(self.progress_file, "r", encoding="utf-8") as f: + data = json.load(f) + self.completed_codes = set(data.get("completed_codes", [])) + self.time_records = data.get("time_records", {}) + logger.info(f"加载进度文件: {self.progress_file},已完成 {len(self.completed_codes)} 个") + except Exception as e: + logger.warning(f"加载进度文件失败 {self.progress_file}: {e}") + self.completed_codes = set() + self.time_records = {} + + def save_progress(self): + """保存当前进度""" + try: + data = { + "completed_codes": list(self.completed_codes), + "time_records": self.time_records + } + with open(self.progress_file, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"保存进度文件失败 {self.progress_file}: {e}") + + def mark_completed(self, code: str, elapsed_time: float | None = None): + """标记某个代码为已完成""" + self.completed_codes.add(code) + if elapsed_time is not None: + self.time_records[code] = { + "elapsed_seconds": round(elapsed_time, 2), + "completed_at": datetime.now().isoformat() + } + + def is_completed(self, code: str) -> bool: + """检查某个代码是否已完成""" + return code in self.completed_codes + + def get_remaining_codes(self, all_codes: list) -> list: + """获取剩余需要处理的代码列表""" + return [code for code in all_codes if not self.is_completed(code)] + + +MAX_RETRIES = 3 # 最大重试次数 + + +async def process_single_stock( + client: FastMcpClient, + tool_name: str, + code: str, + deep_query: str, + saver: BatchResultSaver, + progress_tracker: ProgressTracker, + semaphore: asyncio.Semaphore, + index: int, + total: int +): + """处理单个股票的异步任务""" + async with semaphore: # 信号量控制并发数 + args = {"code": code, "query": deep_query} + + for attempt in range(1, MAX_RETRIES + 1): + start_time = datetime.now() + + try: + logger.info(f"[{index}/{total}] 调用 {tool_name} 处理股票 {code} (尝试 {attempt}/{MAX_RETRIES})") + result: CallToolResult = await client.call_tool(tool_name, args) + elapsed_time = (datetime.now() - start_time).total_seconds() + + if not result.is_error: + content = result.content[0].text if result.content else "" + + # 检查内容是否为空、太短或是无效结果 + is_invalid = ( + not content + or len(content.strip()) < 10 + or content.strip() in INVALID_RESULTS + ) + + if is_invalid: + if attempt < MAX_RETRIES: + wait_seconds = random.randint(MIN_WAIT_ON_EMPTY, MAX_WAIT_ON_EMPTY) + logger.warning( + f"工具 {tool_name} 处理 {code} 无效结果: '{content[:50]}...', " + f"耗时 {elapsed_time:.2f}秒, 等待 {wait_seconds} 秒后重试 ({attempt}/{MAX_RETRIES})..." + ) + await asyncio.sleep(wait_seconds) + continue # 重试 + else: + logger.error( + f"工具 {tool_name} 处理 {code} 达到最大重试次数 {MAX_RETRIES}, " + f"仍为无效结果: '{content[:50]}...', 跳过" + ) + return # 达到最大重试次数,放弃 + else: + # 有正常输出,保存记录 + saver.add_record(args, content) + # 标记该代码已完成,记录耗时 + progress_tracker.mark_completed(code, elapsed_time) + # 定期保存进度 + if index % SAVE_BATCH_SIZE == 0: + progress_tracker.save_progress() + # 正常等待间隔 + logger.info(f"✓ 成功处理 {code}, 耗时 {elapsed_time:.2f}秒, 等待 {NORMAL_WAIT_SECONDS} 秒...") + await asyncio.sleep(NORMAL_WAIT_SECONDS) + return # 成功,退出重试循环 + else: + error_msg = result.content[0].text[:100] if result.content else "No error message" + if attempt < MAX_RETRIES: + wait_seconds = random.randint(MIN_WAIT_ON_EMPTY, MAX_WAIT_ON_EMPTY) + logger.warning(f"✗ 错误 {code}: {error_msg}, 等待 {wait_seconds} 秒后重试 ({attempt}/{MAX_RETRIES})...") + await asyncio.sleep(wait_seconds) + continue # 重试 + else: + logger.error(f"✗ 错误 {code}: {error_msg}, 达到最大重试次数 {MAX_RETRIES}, 跳过") + return + + except Exception as e: + import traceback + elapsed_time = (datetime.now() - start_time).total_seconds() + error_str = str(e) + + # 检查是否是“不当内容”错误,这种情况不需要重试 + if "inappropriate content" in error_str: + logger.warning(f"⚠ {code}: 返回不当内容错误,跳过不重试") + return + + if attempt < MAX_RETRIES: + wait_seconds = random.randint(MIN_WAIT_ON_EMPTY, MAX_WAIT_ON_EMPTY) + logger.warning(f"✗ 失败 {code}: {e}, 等待 {wait_seconds} 秒后重试 ({attempt}/{MAX_RETRIES})...") + await asyncio.sleep(wait_seconds) + continue # 重试 + else: + logger.error(f"✗ 失败 {code}: {e}, 达到最大重试次数 {MAX_RETRIES}, 跳过") + logger.error(f"Traceback: {traceback.format_exc()}") + return + + +async def run_crawl_task(): + """执行爬取任务的主函数""" + # 读取股票代码列表 + logger.info(f"读取股票代码列表: {CSV_PATH}") + df = pd.read_csv(CSV_PATH, dtype={'symbol': str}) + stock_codes = df['symbol'].dropna().apply(lambda x: str(x).zfill(6)).tolist() + logger.info(f"共加载 {len(stock_codes)} 个股票代码") + + # MCP 客户端配置 + mcp_config = {"type": "sse", "url": f"http://{HOST}:{PORT}/sse"} + + # 信号量控制并发数 + semaphore = asyncio.Semaphore(MAX_CONCURRENCY) + + # 【第二步】汇总统计各工具待爬取数量 + logger.info(f"\n{'='*60}") + logger.info("【第二步】统计各工具待爬取数量...") + logger.info(f"{'='*60}") + + total_tasks = 0 + tool_stats = [] + for tool_name, deep_query in TOOLS_CONFIG: + progress_tracker = ProgressTracker(tool_name) + remaining_codes = progress_tracker.get_remaining_codes(stock_codes) + total_remaining = len(remaining_codes) + completed_count = len(stock_codes) - total_remaining + total_tasks += total_remaining + tool_stats.append((tool_name, completed_count, total_remaining)) + # 只列出还需要爬取的工具 + if total_remaining > 0: + logger.info( + f" {tool_name}: 已完成 {completed_count}, 待爬取 {total_remaining}" + ) + + logger.info(f"{'='*60}") + logger.info(f"汇总: 共 {len(TOOLS_CONFIG)} 个工具, 总计待爬取 {total_tasks} 条记录") + logger.info(f"{'='*60}\n") + + if total_tasks == 0: + logger.info("所有工具已完成爬取,无需继续") + return + + # 【第三步】开始爬取任务 + logger.info(f"\n{'='*60}") + logger.info("【第三步】开始爬取任务...") + logger.info(f"{'='*60}\n") + + async with FastMcpClient(name="full-info-crawler", config=mcp_config) as client: + # 外层循环:遍历所有工具 + for tool_name, deep_query in TOOLS_CONFIG: + logger.info(f"\n{'-'*60}") + logger.info(f"开始爬取工具: {tool_name}") + logger.info(f"查询内容: {deep_query}") + logger.info(f"{'-'*60}") + + saver = BatchResultSaver(tool_name) + progress_tracker = ProgressTracker(tool_name) + + # 获取剩余需要处理的股票代码 + remaining_codes = progress_tracker.get_remaining_codes(stock_codes) + total_remaining = len(remaining_codes) + completed_count = len(stock_codes) - total_remaining + + logger.info( + f"工具 {tool_name}: 总计 {len(stock_codes)} 个股票, " + f"已完成 {completed_count} 个, 剩余 {total_remaining} 个" + ) + + if total_remaining == 0: + logger.info(f"工具 {tool_name} 所有股票已处理完成,跳过") + continue + + # 创建所有并发任务 + tasks = [] + for i, code in enumerate(remaining_codes, start=1): + task = process_single_stock( + client=client, + tool_name=tool_name, + code=code, + deep_query=deep_query, + saver=saver, + progress_tracker=progress_tracker, + semaphore=semaphore, + index=i, + total=total_remaining + ) + tasks.append(task) + + # 并发执行所有任务,信号量控制最多 MAX_CONCURRENCY 个同时运行 + logger.info(f"启动 {len(tasks)} 个并发任务,最大并发数: {MAX_CONCURRENCY}") + await asyncio.gather(*tasks) + + # 保存最后的进度 + progress_tracker.save_progress() + saver.flush() + logger.info(f"\n{'='*80}") + logger.info(f"✓ 工具 {tool_name} 完成!") + logger.info(f"{'='*80}\n") + + +def main(): + """主函数:启动服务并运行爬取任务""" + logger.info("="*80) + logger.info("开始启动 finance-mcp 服务...") + logger.info(f"服务参数: {SERVICE_ARGS}") + logger.info(f"监听地址: {HOST}:{PORT}") + logger.info("="*80) + + # 使用 FinanceMcpServiceRunner 启动服务 + with FinanceMcpServiceRunner( + SERVICE_ARGS, + host=HOST, + port=PORT, + ) as service: + logger.info(f"✓ 服务已启动,监听端口: {service.port}") + logger.info("开始执行爬取任务...\n") + + # 运行爬取任务 + asyncio.run(run_crawl_task()) + + logger.info("\n" + "="*80) + logger.info("✓ 所有任务已完成") + logger.info("="*80) + + +if __name__ == "__main__": + main()