api-key-manager 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api_key_manager-2.1.0.dist-info/METADATA +709 -0
- api_key_manager-2.1.0.dist-info/RECORD +73 -0
- api_key_manager-2.1.0.dist-info/WHEEL +5 -0
- api_key_manager-2.1.0.dist-info/entry_points.txt +2 -0
- api_key_manager-2.1.0.dist-info/top_level.txt +1 -0
- key_manager/__init__.py +16 -0
- key_manager/__main__.py +5 -0
- key_manager/api_models.py +358 -0
- key_manager/checker.py +51 -0
- key_manager/cli.py +270 -0
- key_manager/config.py +61 -0
- key_manager/core.py +205 -0
- key_manager/detector.py +335 -0
- key_manager/errors.py +179 -0
- key_manager/i18n.py +142 -0
- key_manager/logger.py +207 -0
- key_manager/model_capabilities.py +412 -0
- key_manager/parser.py +153 -0
- key_manager/providers/__init__.py +283 -0
- key_manager/providers/ai302.py +109 -0
- key_manager/providers/anthropic.py +109 -0
- key_manager/providers/baichuan.py +97 -0
- key_manager/providers/base.py +312 -0
- key_manager/providers/cerebras.py +109 -0
- key_manager/providers/cohere.py +90 -0
- key_manager/providers/cstcloud.py +122 -0
- key_manager/providers/dashscope.py +120 -0
- key_manager/providers/dashscope_coding.py +122 -0
- key_manager/providers/deepseek.py +166 -0
- key_manager/providers/dmxapi.py +109 -0
- key_manager/providers/doubao.py +109 -0
- key_manager/providers/fireworks.py +109 -0
- key_manager/providers/google.py +99 -0
- key_manager/providers/grok.py +109 -0
- key_manager/providers/groq.py +109 -0
- key_manager/providers/huggingface.py +54 -0
- key_manager/providers/hyperbolic.py +109 -0
- key_manager/providers/infini.py +135 -0
- key_manager/providers/infini_coding.py +124 -0
- key_manager/providers/kimi.py +121 -0
- key_manager/providers/kimi_coding.py +124 -0
- key_manager/providers/longcat.py +123 -0
- key_manager/providers/mimo.py +109 -0
- key_manager/providers/mimo_plan.py +140 -0
- key_manager/providers/minimax.py +97 -0
- key_manager/providers/minimax_plan.py +122 -0
- key_manager/providers/mistral.py +109 -0
- key_manager/providers/models_registry.py +2901 -0
- key_manager/providers/modelscope.py +134 -0
- key_manager/providers/nvidia.py +109 -0
- key_manager/providers/ocoolai.py +109 -0
- key_manager/providers/openai.py +140 -0
- key_manager/providers/openrouter.py +119 -0
- key_manager/providers/perplexity.py +109 -0
- key_manager/providers/poe.py +109 -0
- key_manager/providers/ppio.py +109 -0
- key_manager/providers/replicate.py +54 -0
- key_manager/providers/siliconflow.py +121 -0
- key_manager/providers/stepfun.py +132 -0
- key_manager/providers/tencent_hunyuan.py +122 -0
- key_manager/providers/together.py +134 -0
- key_manager/providers/yi.py +97 -0
- key_manager/providers/zai.py +109 -0
- key_manager/providers/zhipu.py +127 -0
- key_manager/providers/zhipu_coding.py +124 -0
- key_manager/proxy.py +70 -0
- key_manager/ssrf.py +68 -0
- key_manager/storage.py +134 -0
- key_manager/tester.py +137 -0
- key_manager/url_override.py +5 -0
- key_manager/validator.py +185 -0
- key_manager/web.py +1512 -0
- key_manager/webhook.py +257 -0
key_manager/logger.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import json
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class KeyLogger:
    """Plain-text logger for one category, writing one file per calendar day.

    Lines are written to ``<logs_dir>/<category>_<YYYY-MM-DD>.log`` through
    the standard ``logging`` logger named ``keymanager.<category>``.  That
    logger is process-global, so every ``KeyLogger`` created for the same
    category shares its handlers; a file handler is only attached when the
    logger has none yet.
    """

    # Attribute stamped on the handlers this class creates, so that the
    # day roll-over only ever replaces handlers we own.
    _DATE_ATTR = "_keylogger_date"

    def __init__(self, logs_dir: str, category: str):
        """Bind to the category logger and attach a file handler if needed.

        Args:
            logs_dir: Directory that receives the log files (created on demand).
            category: Used both in the logger name and as the file-name prefix.
        """
        self.logs_dir = Path(logs_dir)
        self.category = category
        self.logger = logging.getLogger(f"keymanager.{category}")
        self.logger.setLevel(logging.DEBUG)
        if not self.logger.handlers:
            self._setup_file_handler()

    def _setup_file_handler(self):
        """Attach a UTF-8 file handler pointing at today's log file."""
        self.logs_dir.mkdir(parents=True, exist_ok=True)
        date_str = datetime.now().strftime("%Y-%m-%d")
        log_file = self.logs_dir / f"{self.category}_{date_str}.log"

        handler = logging.FileHandler(log_file, encoding="utf-8")
        handler.setLevel(logging.DEBUG)
        # The line is fully formatted by log(); emit it verbatim.
        handler.setFormatter(logging.Formatter("%(message)s"))
        setattr(handler, self._DATE_ATTR, date_str)
        self.logger.addHandler(handler)

    def _rotate_if_date_changed(self):
        """Swap our handler for a fresh one when the calendar day has changed.

        Without this the file name chosen at start-up is used for the whole
        lifetime of the process.  Handlers not created by this class carry no
        date stamp and are left untouched.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        stale = [
            handler
            for handler in self.logger.handlers
            if getattr(handler, self._DATE_ATTR, today) != today
        ]
        if not stale:
            return
        for handler in stale:
            self.logger.removeHandler(handler)
            handler.close()
        self._setup_file_handler()

    def log(self, action: str, provider: str, key_masked: str,
            status: str, detail: str = "", latency: float = 0.0):
        """Write one formatted line.

        Args:
            action: Short action tag shown in brackets.
            provider: Provider name; left-justified to 10 characters.
            key_masked: Already-masked key text, written as given.
            status: Outcome text written after ``->``.
            detail: Optional detail; when empty, the trailing
                ``(detail, latency)`` part is omitted entirely.
            latency: Seconds, printed with two decimals (only with ``detail``).
        """
        self._rotate_if_date_changed()
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        provider_padded = provider.ljust(10)
        line = f"[{timestamp}] [{action}] [{provider_padded}] {key_masked} -> {status}"
        if detail:
            line += f" ({detail}, {latency:.2f}s)"
        self.logger.info(line)

    def flush(self):
        """Flush every handler currently attached to the category logger."""
        for handler in self.logger.handlers:
            handler.flush()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ProjectLogger:
    """Project-wide operation log: a plain-text main log plus a JSONL log.

    Human-readable lines go to ``<logs_dir>/main_<YYYY-MM-DD>.log`` through
    the process-global ``project.main`` logger; structured entries are
    appended to ``<logs_dir>/operations_<YYYY-MM-DD>.jsonl``.  Both targets
    follow the calendar day, so the reader methods (which always look at
    today's files) see what the writer methods produce even in a
    long-running process.
    """

    # Attribute stamped on handlers created by this class so that the
    # day roll-over only replaces handlers we own.
    _DATE_ATTR = "_projectlogger_date"

    def __init__(self, logs_dir: str = "./data/logs"):
        """Create the log directory if needed and set up today's targets."""
        self.logs_dir = Path(logs_dir)
        self.logs_dir.mkdir(parents=True, exist_ok=True)
        self._setup_loggers()

    @staticmethod
    def _today() -> str:
        """Return today's local date as ``YYYY-MM-DD`` (used in file names)."""
        return datetime.now().strftime("%Y-%m-%d")

    @staticmethod
    def _timestamp() -> str:
        """Return the local time as ``YYYY-MM-DD HH:MM:SS`` for text lines."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def _setup_loggers(self):
        """Point the main logger and the JSONL path at today's files."""
        date_str = self._today()
        self._log_date = date_str

        # Main log - all operations
        self.main_logger = self._create_logger(
            "main",
            self.logs_dir / f"main_{date_str}.log"
        )

        # JSON log for structured data
        self.json_log_file = self.logs_dir / f"operations_{date_str}.jsonl"

    def _create_logger(self, name: str, log_file: Path) -> logging.Logger:
        """Return ``project.<name>``, attaching a file handler if it has none."""
        logger = logging.getLogger(f"project.{name}")
        logger.setLevel(logging.DEBUG)
        if not logger.handlers:
            handler = logging.FileHandler(log_file, encoding="utf-8")
            handler.setLevel(logging.DEBUG)
            # Lines are pre-formatted by the log_* methods.
            handler.setFormatter(logging.Formatter("%(message)s"))
            setattr(handler, self._DATE_ATTR,
                    getattr(self, "_log_date", self._today()))
            logger.addHandler(handler)
        return logger

    def _roll_over_if_needed(self):
        """Re-target both logs when the calendar day has changed.

        The file names used to be fixed at construction time, so after
        midnight writes kept going to yesterday's files while the readers
        looked at today's.  Only handlers stamped by this class with an
        older date are replaced.
        """
        today = self._today()
        if getattr(self, "_log_date", None) == today:
            return
        for handler in list(self.main_logger.handlers):
            if getattr(handler, self._DATE_ATTR, today) != today:
                self.main_logger.removeHandler(handler)
                handler.close()
        self.logs_dir.mkdir(parents=True, exist_ok=True)
        self._setup_loggers()

    def _emit(self, line: str):
        """Write one pre-formatted line to the main log."""
        self._roll_over_if_needed()
        self.main_logger.info(line)

    def _write_json_log(self, operation: str, data: dict):
        """Append one structured entry to today's JSONL file.

        ``data`` is merged after ``timestamp`` and ``operation``, so keys of
        the same name in ``data`` take precedence.
        """
        self._roll_over_if_needed()
        entry = {
            "timestamp": datetime.now().isoformat(),
            "operation": operation,
            **data
        }
        with open(self.json_log_file, "a", encoding="utf-8") as f:
            # default=str is deliberate: callers may pass exception objects
            # (e.g. in ``errors``) and logging must not raise because of them.
            f.write(json.dumps(entry, ensure_ascii=False, default=str) + "\n")

    def log_import(self, filename: str, new_keys: int, duplicates: int, errors: list):
        """Record an import: counts in the text log, full error list in JSONL."""
        line = f"[{self._timestamp()}] [IMPORT] {filename} -> new={new_keys}, dupes={duplicates}"
        if errors:
            line += f", errors={len(errors)}"
        self._emit(line)
        self._write_json_log("import", {
            "filename": filename,
            "new_keys": new_keys,
            "duplicates": duplicates,
            "errors": errors
        })

    def log_check(self, provider: str, key_masked: str, status: str,
                  status_code: Optional[int] = None, latency_ms: float = 0,
                  error: Optional[str] = None):
        """Record a key check.

        The ``(code=..., ...ms)`` part is only added to the text line when
        ``status_code`` is truthy; the JSONL entry always carries every field.
        """
        provider_padded = provider.ljust(10)
        line = f"[{self._timestamp()}] [CHECK ] [{provider_padded}] {key_masked} -> {status}"
        if status_code:
            line += f" (code={status_code}, {latency_ms:.0f}ms)"
        if error:
            line += f" error={error}"
        self._emit(line)
        self._write_json_log("check", {
            "provider": provider,
            "key_masked": key_masked,
            "status": status,
            "status_code": status_code,
            "latency_ms": latency_ms,
            "error": error
        })

    def log_test(self, provider: str, key_masked: str, max_tokens: Optional[int],
                 max_concurrency: Optional[int], models_count: int = 0):
        """Record the result of a capability test for one key."""
        provider_padded = provider.ljust(10)
        line = f"[{self._timestamp()}] [TEST ] [{provider_padded}] {key_masked} -> tokens={max_tokens}, conc={max_concurrency}, models={models_count}"
        self._emit(line)
        self._write_json_log("test", {
            "provider": provider,
            "key_masked": key_masked,
            "max_tokens": max_tokens,
            "max_concurrency": max_concurrency,
            "models_count": models_count
        })

    def log_manual_check(self, key_masked: str, provider: str, status: str,
                         check_type: str = "fast", error: Optional[str] = None):
        """Record a manually triggered check and its ``check_type``."""
        line = f"[{self._timestamp()}] [MANUAL ] [{provider.ljust(10)}] {key_masked} -> {status} ({check_type})"
        if error:
            line += f" error={error}"
        self._emit(line)
        self._write_json_log("manual_check", {
            "key_masked": key_masked,
            "provider": provider,
            "status": status,
            "check_type": check_type,
            "error": error
        })

    def log_export(self, count: int, format: str = "txt"):
        """Record an export of ``count`` keys in the given format."""
        line = f"[{self._timestamp()}] [EXPORT ] Exported {count} keys as {format}"
        self._emit(line)
        self._write_json_log("export", {"count": count, "format": format})

    def log_web_action(self, action: str, detail: str = ""):
        """Record a web action in the text log only (no JSONL entry)."""
        line = f"[{self._timestamp()}] [WEB ] {action}"
        if detail:
            line += f" - {detail}"
        self._emit(line)

    def get_recent_logs(self, lines: int = 100) -> list:
        """Return up to ``lines`` trailing entries of today's main log.

        Returns an empty list when today's file does not exist or when
        ``lines`` is not positive (``[-0:]`` would otherwise return the
        whole file).
        """
        if lines <= 0:
            return []
        log_file = self.logs_dir / f"main_{self._today()}.log"
        if not log_file.exists():
            return []

        with open(log_file, "r", encoding="utf-8") as f:
            all_lines = f.readlines()
        return [line.strip() for line in all_lines[-lines:]]

    def get_operations_log(self, limit: int = 50) -> list:
        """Return up to ``limit`` trailing entries of today's JSONL log.

        Blank and malformed lines are skipped.  Returns an empty list when
        today's file does not exist or ``limit`` is not positive.
        """
        if limit <= 0:
            return []
        json_file = self.logs_dir / f"operations_{self._today()}.jsonl"
        if not json_file.exists():
            return []

        entries = []
        with open(json_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except json.JSONDecodeError:
                    # A partially written or corrupted line must not hide
                    # the valid entries around it.
                    continue
        return entries[-limit:]

    @staticmethod
    def _describe_file(path: Path) -> dict:
        """Return name, size and ISO modification time for one file."""
        info = path.stat()
        return {
            "name": path.name,
            "size": info.st_size,
            "modified": datetime.fromtimestamp(info.st_mtime).isoformat()
        }

    def get_log_files(self) -> list:
        """List all ``*.log`` files, then all ``*.jsonl`` files.

        Each group is sorted by path in descending order.
        """
        files = []
        for pattern in ("*.log", "*.jsonl"):
            for path in sorted(self.logs_dir.glob(pattern), reverse=True):
                files.append(self._describe_file(path))
        return files
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# Global logger instance
|
|
207
|
+
# Created at import time with the default "./data/logs" directory, so merely
# importing this module creates that directory and sets up the main file logger.
project_logger = ProjectLogger()
|
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
"""
|
|
2
|
+
模型能力检测模块
|
|
3
|
+
|
|
4
|
+
从 GitHub 同步 Cherry Studio 的模型能力规则,支持三层降级:
|
|
5
|
+
1. 从 GitHub 拉取最新版
|
|
6
|
+
2. 使用本地缓存(7天 TTL)
|
|
7
|
+
3. 使用硬编码兜底数据
|
|
8
|
+
|
|
9
|
+
用法:
|
|
10
|
+
from key_manager.model_capabilities import detector
|
|
11
|
+
|
|
12
|
+
# 初始化(在应用启动时调用)
|
|
13
|
+
await detector.load()
|
|
14
|
+
|
|
15
|
+
# 检测模型能力
|
|
16
|
+
if detector.is_vision_model("gpt-4o"):
|
|
17
|
+
print("支持视觉")
|
|
18
|
+
|
|
19
|
+
if detector.is_tool_model("claude-sonnet-4"):
|
|
20
|
+
print("支持工具调用")
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import re
|
|
25
|
+
import logging
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from datetime import datetime, timedelta, timezone
|
|
28
|
+
|
|
29
|
+
import httpx
|
|
30
|
+
|
|
31
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Configuration
# Remote rules file; the two local paths below are relative, so they resolve
# against the process's current working directory.
CAPS_URL = "https://raw.githubusercontent.com/Townrain/API-Key-Manager/main/data/model_capabilities.json"
CACHE_FILE = Path("data") / "cache" / "model_capabilities.json"
FALLBACK_FILE = Path("data") / "model_capabilities_fallback.json"
CACHE_TTL = timedelta(weeks=1)  # 7 days
SCHEMA_VERSION = 1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ModelCapabilityDetector:
|
|
42
|
+
"""模型能力检测器(三层降级)"""
|
|
43
|
+
|
|
44
|
+
def __init__(self):
|
|
45
|
+
self._caps: dict | None = None
|
|
46
|
+
self._loaded_from: str | None = None
|
|
47
|
+
self._loaded_at: datetime | None = None
|
|
48
|
+
|
|
49
|
+
async def load(self, force: bool = False) -> str:
|
|
50
|
+
"""
|
|
51
|
+
加载模型能力配置(三层降级)
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
str: 数据来源 ("github", "cache", "fallback")
|
|
55
|
+
|
|
56
|
+
降级顺序:
|
|
57
|
+
1. 从 GitHub 拉取最新版
|
|
58
|
+
2. 使用本地缓存(7天 TTL)
|
|
59
|
+
3. 使用硬编码兜底数据
|
|
60
|
+
"""
|
|
61
|
+
if not force and self._caps and self._loaded_at:
|
|
62
|
+
# 检查是否需要重新加载(1小时内的请求直接返回缓存)
|
|
63
|
+
if datetime.utcnow() - self._loaded_at < timedelta(hours=1):
|
|
64
|
+
return self._loaded_from
|
|
65
|
+
|
|
66
|
+
# 第 1 层:尝试从 GitHub 拉取
|
|
67
|
+
try:
|
|
68
|
+
self._caps = await self._fetch_from_github()
|
|
69
|
+
if self._validate_schema(self._caps):
|
|
70
|
+
self._save_to_cache(self._caps)
|
|
71
|
+
self._loaded_from = "github"
|
|
72
|
+
self._loaded_at = datetime.utcnow()
|
|
73
|
+
logger.info("从 GitHub 加载模型能力配置成功")
|
|
74
|
+
return self._loaded_from
|
|
75
|
+
else:
|
|
76
|
+
logger.warning("GitHub 返回的配置 schema 版本不兼容")
|
|
77
|
+
except Exception as e:
|
|
78
|
+
logger.warning(f"GitHub 拉取失败: {e}")
|
|
79
|
+
|
|
80
|
+
# 第 2 层:使用本地缓存
|
|
81
|
+
try:
|
|
82
|
+
self._caps = self._load_from_cache()
|
|
83
|
+
self._loaded_from = "cache"
|
|
84
|
+
self._loaded_at = datetime.utcnow()
|
|
85
|
+
logger.info("从本地缓存加载模型能力配置成功")
|
|
86
|
+
return self._loaded_from
|
|
87
|
+
except Exception as e:
|
|
88
|
+
logger.warning(f"本地缓存加载失败: {e}")
|
|
89
|
+
|
|
90
|
+
# 第 3 层:使用硬编码兜底
|
|
91
|
+
self._caps = self._load_fallback()
|
|
92
|
+
self._loaded_from = "fallback"
|
|
93
|
+
self._loaded_at = datetime.utcnow()
|
|
94
|
+
logger.info("使用硬编码兜底数据")
|
|
95
|
+
return self._loaded_from
|
|
96
|
+
|
|
97
|
+
async def _fetch_from_github(self) -> dict:
|
|
98
|
+
"""从 GitHub 拉取最新配置"""
|
|
99
|
+
async with httpx.AsyncClient() as client:
|
|
100
|
+
resp = await client.get(CAPS_URL, timeout=10)
|
|
101
|
+
resp.raise_for_status()
|
|
102
|
+
return resp.json()
|
|
103
|
+
|
|
104
|
+
def _validate_schema(self, data: dict) -> bool:
|
|
105
|
+
"""验证 schema 版本"""
|
|
106
|
+
return data.get("schema_version") == SCHEMA_VERSION
|
|
107
|
+
|
|
108
|
+
def _save_to_cache(self, data: dict):
|
|
109
|
+
"""保存到本地缓存"""
|
|
110
|
+
CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
|
|
111
|
+
CACHE_FILE.write_text(
|
|
112
|
+
json.dumps(data, indent=2, ensure_ascii=False),
|
|
113
|
+
encoding="utf-8"
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
def _load_from_cache(self) -> dict:
|
|
117
|
+
"""从本地缓存加载"""
|
|
118
|
+
if not CACHE_FILE.exists():
|
|
119
|
+
raise FileNotFoundError("缓存文件不存在")
|
|
120
|
+
|
|
121
|
+
data = json.loads(CACHE_FILE.read_text(encoding="utf-8"))
|
|
122
|
+
|
|
123
|
+
# 检查 schema 版本
|
|
124
|
+
if not self._validate_schema(data):
|
|
125
|
+
raise ValueError(f"Schema 版本不兼容: {data.get('schema_version')}")
|
|
126
|
+
|
|
127
|
+
# 检查是否过期
|
|
128
|
+
updated_str = data.get("updated_at", "")
|
|
129
|
+
if updated_str:
|
|
130
|
+
updated = datetime.fromisoformat(updated_str.rstrip("Z"))
|
|
131
|
+
if datetime.utcnow() - updated > CACHE_TTL:
|
|
132
|
+
raise ValueError("缓存已过期")
|
|
133
|
+
|
|
134
|
+
return data
|
|
135
|
+
|
|
136
|
+
def _load_fallback(self) -> dict:
|
|
137
|
+
"""加载硬编码兜底数据"""
|
|
138
|
+
if FALLBACK_FILE.exists():
|
|
139
|
+
return json.loads(FALLBACK_FILE.read_text(encoding="utf-8"))
|
|
140
|
+
|
|
141
|
+
# 如果兜底文件也不存在,返回最小可用数据
|
|
142
|
+
return {
|
|
143
|
+
"schema_version": 1,
|
|
144
|
+
"updated_at": "2026-01-01T00:00:00Z",
|
|
145
|
+
"source": "hardcoded-minimal",
|
|
146
|
+
"capabilities": {
|
|
147
|
+
"vision": {
|
|
148
|
+
"allowed_patterns": ["gpt-4o", "claude-3", "gemini", "o1", "o3"],
|
|
149
|
+
"excluded_patterns": ["o1-mini", "o3-mini"]
|
|
150
|
+
},
|
|
151
|
+
"tooluse": {
|
|
152
|
+
"allowed_patterns": ["gpt-4o", "gpt-4", "claude", "deepseek", "gemini"],
|
|
153
|
+
"excluded_patterns": ["o1-mini"]
|
|
154
|
+
},
|
|
155
|
+
"embedding": {
|
|
156
|
+
"embedding_regex": "(?i)(?:text-embedding|embed|bge-|e5-|gte-|voyage-)",
|
|
157
|
+
"rerank_regex": "(?i)(?:rerank|re-rank)"
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
def is_vision_model(self, model_id: str) -> bool:
|
|
163
|
+
"""
|
|
164
|
+
判断是否为视觉模型
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
model_id: 模型 ID (如 "gpt-4o", "claude-sonnet-4")
|
|
168
|
+
|
|
169
|
+
Returns:
|
|
170
|
+
bool: 是否支持视觉输入
|
|
171
|
+
"""
|
|
172
|
+
if not self._caps:
|
|
173
|
+
return False
|
|
174
|
+
|
|
175
|
+
vision = self._caps["capabilities"].get("vision", {})
|
|
176
|
+
allowed = vision.get("allowed_patterns", [])
|
|
177
|
+
excluded = vision.get("excluded_patterns", [])
|
|
178
|
+
|
|
179
|
+
# 检查排除列表
|
|
180
|
+
for pattern in excluded:
|
|
181
|
+
try:
|
|
182
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
183
|
+
return False
|
|
184
|
+
except re.error:
|
|
185
|
+
continue
|
|
186
|
+
|
|
187
|
+
# 检查允许列表
|
|
188
|
+
for pattern in allowed:
|
|
189
|
+
try:
|
|
190
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
191
|
+
return True
|
|
192
|
+
except re.error:
|
|
193
|
+
continue
|
|
194
|
+
|
|
195
|
+
return False
|
|
196
|
+
|
|
197
|
+
def is_tool_model(self, model_id: str) -> bool:
|
|
198
|
+
"""
|
|
199
|
+
判断是否支持工具调用
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
model_id: 模型 ID
|
|
203
|
+
|
|
204
|
+
Returns:
|
|
205
|
+
bool: 是否支持工具调用
|
|
206
|
+
"""
|
|
207
|
+
if not self._caps:
|
|
208
|
+
return False
|
|
209
|
+
|
|
210
|
+
tooluse = self._caps["capabilities"].get("tooluse", {})
|
|
211
|
+
allowed = tooluse.get("allowed_patterns", [])
|
|
212
|
+
excluded = tooluse.get("excluded_patterns", [])
|
|
213
|
+
|
|
214
|
+
# 检查排除列表
|
|
215
|
+
for pattern in excluded:
|
|
216
|
+
try:
|
|
217
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
218
|
+
return False
|
|
219
|
+
except re.error:
|
|
220
|
+
continue
|
|
221
|
+
|
|
222
|
+
# 检查允许列表
|
|
223
|
+
for pattern in allowed:
|
|
224
|
+
try:
|
|
225
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
226
|
+
return True
|
|
227
|
+
except re.error:
|
|
228
|
+
continue
|
|
229
|
+
|
|
230
|
+
return False
|
|
231
|
+
|
|
232
|
+
def is_embedding_model(self, model_id: str) -> bool:
|
|
233
|
+
"""
|
|
234
|
+
判断是否为嵌入模型
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
model_id: 模型 ID
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
bool: 是否为嵌入模型
|
|
241
|
+
"""
|
|
242
|
+
if not self._caps:
|
|
243
|
+
return False
|
|
244
|
+
|
|
245
|
+
regex = self._caps["capabilities"]["embedding"].get("embedding_regex")
|
|
246
|
+
if regex:
|
|
247
|
+
try:
|
|
248
|
+
return bool(re.search(regex, model_id, re.IGNORECASE))
|
|
249
|
+
except re.error:
|
|
250
|
+
return False
|
|
251
|
+
return False
|
|
252
|
+
|
|
253
|
+
def is_rerank_model(self, model_id: str) -> bool:
|
|
254
|
+
"""
|
|
255
|
+
判断是否为重排模型
|
|
256
|
+
|
|
257
|
+
Args:
|
|
258
|
+
model_id: 模型 ID
|
|
259
|
+
|
|
260
|
+
Returns:
|
|
261
|
+
bool: 是否为重排模型
|
|
262
|
+
"""
|
|
263
|
+
if not self._caps:
|
|
264
|
+
return False
|
|
265
|
+
|
|
266
|
+
regex = self._caps["capabilities"]["embedding"].get("rerank_regex")
|
|
267
|
+
if regex:
|
|
268
|
+
try:
|
|
269
|
+
return bool(re.search(regex, model_id, re.IGNORECASE))
|
|
270
|
+
except re.error:
|
|
271
|
+
return False
|
|
272
|
+
return False
|
|
273
|
+
|
|
274
|
+
def is_reasoning_model(self, model_id: str) -> bool:
|
|
275
|
+
"""
|
|
276
|
+
判断是否为推理模型
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
model_id: 模型 ID
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
bool: 是否为推理模型
|
|
283
|
+
"""
|
|
284
|
+
if not self._caps:
|
|
285
|
+
return False
|
|
286
|
+
|
|
287
|
+
# 从配置中获取推理模型列表
|
|
288
|
+
reasoning = self._caps["capabilities"].get('reasoning', {})
|
|
289
|
+
allowed = reasoning.get('allowed_patterns', [])
|
|
290
|
+
excluded = reasoning.get('excluded_patterns', [])
|
|
291
|
+
|
|
292
|
+
# 检查排除列表
|
|
293
|
+
for pattern in excluded:
|
|
294
|
+
try:
|
|
295
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
296
|
+
return False
|
|
297
|
+
except re.error:
|
|
298
|
+
continue
|
|
299
|
+
|
|
300
|
+
# 检查允许列表
|
|
301
|
+
for pattern in allowed:
|
|
302
|
+
try:
|
|
303
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
304
|
+
return True
|
|
305
|
+
except re.error:
|
|
306
|
+
continue
|
|
307
|
+
|
|
308
|
+
return False
|
|
309
|
+
|
|
310
|
+
def is_websearch_model(self, model_id: str) -> bool:
|
|
311
|
+
"""
|
|
312
|
+
判断是否为联网搜索模型
|
|
313
|
+
|
|
314
|
+
Args:
|
|
315
|
+
model_id: 模型 ID
|
|
316
|
+
|
|
317
|
+
Returns:
|
|
318
|
+
bool: 是否支持联网搜索
|
|
319
|
+
"""
|
|
320
|
+
if not self._caps:
|
|
321
|
+
return False
|
|
322
|
+
|
|
323
|
+
# 从配置中获取联网模型列表
|
|
324
|
+
websearch = self._caps["capabilities"].get('websearch', {})
|
|
325
|
+
allowed = websearch.get('allowed_patterns', [])
|
|
326
|
+
excluded = websearch.get('excluded_patterns', [])
|
|
327
|
+
|
|
328
|
+
# 检查排除列表
|
|
329
|
+
for pattern in excluded:
|
|
330
|
+
try:
|
|
331
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
332
|
+
return False
|
|
333
|
+
except re.error:
|
|
334
|
+
continue
|
|
335
|
+
|
|
336
|
+
# 检查允许列表
|
|
337
|
+
for pattern in allowed:
|
|
338
|
+
try:
|
|
339
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
340
|
+
return True
|
|
341
|
+
except re.error:
|
|
342
|
+
continue
|
|
343
|
+
|
|
344
|
+
return False
|
|
345
|
+
|
|
346
|
+
def is_free_model(self, model_id: str) -> bool:
|
|
347
|
+
"""
|
|
348
|
+
判断是否为免费模型
|
|
349
|
+
|
|
350
|
+
Args:
|
|
351
|
+
model_id: 模型 ID
|
|
352
|
+
|
|
353
|
+
Returns:
|
|
354
|
+
bool: 是否为免费模型
|
|
355
|
+
"""
|
|
356
|
+
if not self._caps:
|
|
357
|
+
return False
|
|
358
|
+
|
|
359
|
+
# 从配置中获取免费模型列表
|
|
360
|
+
free = self._caps["capabilities"].get('free', {})
|
|
361
|
+
allowed = free.get('allowed_patterns', [])
|
|
362
|
+
|
|
363
|
+
# 检查允许列表
|
|
364
|
+
for pattern in allowed:
|
|
365
|
+
try:
|
|
366
|
+
if re.search(pattern, model_id, re.IGNORECASE):
|
|
367
|
+
return True
|
|
368
|
+
except re.error:
|
|
369
|
+
continue
|
|
370
|
+
|
|
371
|
+
return False
|
|
372
|
+
|
|
373
|
+
def get_model_capabilities(self, model_id: str) -> dict[str, bool]:
|
|
374
|
+
"""
|
|
375
|
+
获取模型的所有能力
|
|
376
|
+
|
|
377
|
+
Args:
|
|
378
|
+
model_id: 模型 ID
|
|
379
|
+
|
|
380
|
+
Returns:
|
|
381
|
+
dict: 能力字典
|
|
382
|
+
"""
|
|
383
|
+
return {
|
|
384
|
+
"vision": self.is_vision_model(model_id),
|
|
385
|
+
"tooluse": self.is_tool_model(model_id),
|
|
386
|
+
"embedding": self.is_embedding_model(model_id),
|
|
387
|
+
"rerank": self.is_rerank_model(model_id),
|
|
388
|
+
"reasoning": self.is_reasoning_model(model_id),
|
|
389
|
+
"websearch": self.is_websearch_model(model_id),
|
|
390
|
+
"free": self.is_free_model(model_id),
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
@property
|
|
394
|
+
def loaded_from(self) -> str | None:
|
|
395
|
+
"""数据来源: github / cache / fallback"""
|
|
396
|
+
return self._loaded_from
|
|
397
|
+
|
|
398
|
+
@property
|
|
399
|
+
def is_loaded(self) -> bool:
|
|
400
|
+
"""是否已加载"""
|
|
401
|
+
return self._caps is not None
|
|
402
|
+
|
|
403
|
+
@property
|
|
404
|
+
def updated_at(self) -> str | None:
|
|
405
|
+
"""数据更新时间"""
|
|
406
|
+
if self._caps:
|
|
407
|
+
return self._caps.get("updated_at")
|
|
408
|
+
return None
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
# Global singleton
|
|
412
|
+
# Constructing the detector loads nothing; every is_* predicate returns False
# until load() has been awaited.
detector = ModelCapabilityDetector()
|