flexllm 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexllm/__init__.py +224 -0
- flexllm/__main__.py +1096 -0
- flexllm/async_api/__init__.py +9 -0
- flexllm/async_api/concurrent_call.py +100 -0
- flexllm/async_api/concurrent_executor.py +1036 -0
- flexllm/async_api/core.py +373 -0
- flexllm/async_api/interface.py +12 -0
- flexllm/async_api/progress.py +277 -0
- flexllm/base_client.py +988 -0
- flexllm/batch_tools/__init__.py +16 -0
- flexllm/batch_tools/folder_processor.py +317 -0
- flexllm/batch_tools/table_processor.py +363 -0
- flexllm/cache/__init__.py +10 -0
- flexllm/cache/response_cache.py +293 -0
- flexllm/chain_of_thought_client.py +1120 -0
- flexllm/claudeclient.py +402 -0
- flexllm/client_pool.py +698 -0
- flexllm/geminiclient.py +563 -0
- flexllm/llm_client.py +523 -0
- flexllm/llm_parser.py +60 -0
- flexllm/mllm_client.py +559 -0
- flexllm/msg_processors/__init__.py +174 -0
- flexllm/msg_processors/image_processor.py +729 -0
- flexllm/msg_processors/image_processor_helper.py +485 -0
- flexllm/msg_processors/messages_processor.py +341 -0
- flexllm/msg_processors/unified_processor.py +1404 -0
- flexllm/openaiclient.py +256 -0
- flexllm/pricing/__init__.py +104 -0
- flexllm/pricing/data.json +1201 -0
- flexllm/pricing/updater.py +223 -0
- flexllm/provider_router.py +213 -0
- flexllm/token_counter.py +270 -0
- flexllm/utils/__init__.py +1 -0
- flexllm/utils/core.py +41 -0
- flexllm-0.3.3.dist-info/METADATA +573 -0
- flexllm-0.3.3.dist-info/RECORD +39 -0
- flexllm-0.3.3.dist-info/WHEEL +4 -0
- flexllm-0.3.3.dist-info/entry_points.txt +3 -0
- flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
模型定价自动更新脚本
|
|
6
|
+
|
|
7
|
+
从 OpenRouter API 获取所有模型定价,更新到 data.json
|
|
8
|
+
|
|
9
|
+
使用方法:
|
|
10
|
+
# 预览更新内容
|
|
11
|
+
python -m flexllm.pricing.updater
|
|
12
|
+
|
|
13
|
+
# 直接更新 data.json
|
|
14
|
+
python -m flexllm.pricing.updater --apply
|
|
15
|
+
|
|
16
|
+
# 输出 JSON 格式
|
|
17
|
+
python -m flexllm.pricing.updater --json
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import argparse
import json
import math
import os
import re
import shutil
import sys
import urllib.request
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
|
|
27
|
+
|
|
28
|
+
# OpenRouter model-listing endpoint; fetch_models() reads the "data" array of its JSON response.
OPENROUTER_API = "https://openrouter.ai/api/v1/models"

# Pricing table regenerated by this script (data.json next to this module).
PRICING_FILE = Path(__file__).parent / "data.json"

# Regexes tested with re.search() against raw OpenRouter model IDs by
# should_exclude(); a match makes normalize_model_id() return None, so the
# model is skipped entirely.
EXCLUDE_PATTERNS = [
    r":free$",  # free variants
    r":floor$",  # ":floor" variants
    r":extended$",  # ":extended" variants
]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def fetch_models() -> List[Dict]:
    """Download the model catalogue from the OpenRouter API.

    Sends a GET request to OPENROUTER_API with a custom User-Agent and a
    30 second timeout.

    Returns:
        The list stored under the ``"data"`` key of the JSON response, or an
        empty list when that key is missing.
    """
    request = urllib.request.Request(
        OPENROUTER_API,
        headers={"User-Agent": "flexllm-pricing-updater/1.0"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw_body = response.read()
    payload = json.loads(raw_body.decode("utf-8"))
    return payload.get("data", [])
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def should_exclude(model_id: str) -> bool:
    """Tell whether *model_id* matches any regex in EXCLUDE_PATTERNS.

    Each pattern is applied with ``re.search``, so it may match anywhere in
    the ID unless it is anchored.
    """
    return any(re.search(pattern, model_id) for pattern in EXCLUDE_PATTERNS)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def normalize_model_id(model_id: str) -> Optional[str]:
    """Map an OpenRouter model ID to the bare model name used in data.json.

    Excluded IDs (see should_exclude) yield None. Otherwise a trailing
    ``:word`` suffix (such as ``:thinking``) is removed, followed by
    everything up to and including the first ``/``::

        openai/gpt-4o               -> gpt-4o
        anthropic/claude-3.5-sonnet -> claude-3.5-sonnet
    """
    if should_exclude(model_id):
        return None

    without_suffix = re.sub(r":\w+$", "", model_id)
    _, slash, tail = without_suffix.partition("/")
    return tail if slash else without_suffix
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def parse_pricing(model: Dict) -> Optional[Tuple[float, float]]:
    """Extract per-million-token prices from one OpenRouter model entry.

    OpenRouter reports ``pricing.prompt`` and ``pricing.completion`` as USD
    per single token; both are scaled to USD per 1M tokens.

    Args:
        model: One element of the OpenRouter ``/models`` ``data`` array.

    Returns:
        ``(input_price, output_price)`` in $/1M tokens, each rounded to 4
        decimals. Returns None for free models (both prices 0), negative or
        non-finite prices, and missing or unparseable pricing data.
    """
    # "pricing" may be absent, null, or malformed; none of these may raise.
    pricing = model.get("pricing") or {}
    if not isinstance(pricing, dict):
        return None

    try:
        # OpenRouter quotes $/token; convert to $/1M tokens.
        input_price = float(pricing.get("prompt", "0")) * 1e6
        output_price = float(pricing.get("completion", "0")) * 1e6
    except (ValueError, TypeError):
        return None

    # NaN/inf would be serialized as non-standard JSON (NaN / Infinity).
    if not (math.isfinite(input_price) and math.isfinite(output_price)):
        return None
    # Free models carry no useful price information.
    if input_price == 0 and output_price == 0:
        return None
    # Negative values are not real per-token prices.
    if input_price < 0 or output_price < 0:
        return None

    return (round(input_price, 4), round(output_price, 4))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def collect_pricing(models: Optional[List[Dict]] = None) -> Dict[str, Dict[str, float]]:
    """Build the pricing table from OpenRouter model entries.

    Args:
        models: Raw OpenRouter model entries. When None (the default) they
            are downloaded with fetch_models(); passing a list allows offline
            use.

    Returns:
        ``{model_name: {"input": price_per_1m, "output": price_per_1m}}``.
        When several OpenRouter IDs normalize to the same name, the cheapest
        entry is kept: lowest input price first, ties broken by the lower
        output price.
    """
    if models is None:
        models = fetch_models()

    pricing_map: Dict[str, Dict[str, float]] = {}
    for model in models:
        # A null "id" is treated like a missing one (skipped below).
        name = normalize_model_id(model.get("id") or "")
        if not name:
            continue

        pricing = parse_pricing(model)
        if not pricing:
            continue
        input_price, output_price = pricing

        existing = pricing_map.get(name)
        # Compare (input, output) so an equal input price no longer hides a
        # cheaper output price.
        if existing is not None and (input_price, output_price) >= (
            existing["input"],
            existing["output"],
        ):
            continue

        pricing_map[name] = {
            "input": input_price,
            "output": output_price,
        }

    return pricing_map
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def update_pricing_file(pricing_map: Dict[str, Dict[str, float]]) -> bool:
    """Write *pricing_map* to data.json (PRICING_FILE).

    The file gets a ``_meta`` section (update date, source URL, unit) and a
    ``models`` section sorted by model name. The JSON is first written to a
    temporary file in the same directory and then moved over PRICING_FILE
    with ``os.replace``. An interrupted or failed write therefore leaves the
    previous data.json intact instead of truncated. The permission bits of an
    existing data.json are carried over to the new file.

    Args:
        pricing_map: ``{model_name: {"input": price, "output": price}}``.

    Returns:
        True on success; False if writing failed (the error is printed).
    """
    data = {
        "_meta": {
            "updated_at": datetime.now().strftime("%Y-%m-%d"),
            "source": "OpenRouter API (https://openrouter.ai/api/v1/models)",
            "unit": "$/1M tokens",
        },
        # Sorted by model name for stable, diff-friendly output.
        "models": dict(sorted(pricing_map.items())),
    }

    target = Path(PRICING_FILE)
    # Same directory as the target so os.replace() is a same-filesystem rename.
    tmp_path = target.with_name(f".{target.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        if target.exists():
            # Keep the permission bits of the file being replaced.
            shutil.copymode(target, tmp_path)
        os.replace(tmp_path, target)
        return True
    except Exception as e:
        print(f"Error writing {PRICING_FILE}: {e}")
        # Best-effort cleanup of the partial temp file.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        return False
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def main():
    """Command-line entry point for the pricing updater.

    Without flags, prints a preview of the first 30 models. ``--json`` dumps
    the whole pricing map as JSON, and ``--apply`` rewrites data.json.
    Progress and error messages go to stderr, so ``--json`` output on stdout
    is pure JSON and can be redirected or piped. Exits with status 1 if
    fetching or writing fails.
    """
    parser = argparse.ArgumentParser(
        description="从 OpenRouter API 更新模型定价表"
    )
    parser.add_argument(
        "--apply",
        action="store_true",
        help="直接更新 data.json(默认只预览)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="输出 JSON 格式",
    )
    args = parser.parse_args()

    print("Fetching models from OpenRouter API...", file=sys.stderr)
    try:
        pricing_map = collect_pricing()
    except (OSError, ValueError) as e:
        # URLError/HTTPError and socket timeouts subclass OSError; malformed
        # JSON or invalid UTF-8 raise ValueError subclasses.
        print(f"Error fetching models: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Found {len(pricing_map)} models with pricing info", file=sys.stderr)

    if args.json:
        print(json.dumps(pricing_map, indent=2, ensure_ascii=False))
        return

    if args.apply:
        print(f"\nUpdating {PRICING_FILE}...")
        if update_pricing_file(pricing_map):
            print(f"✓ Successfully updated data.json ({len(pricing_map)} models)")
        else:
            print("✗ Failed to update data.json", file=sys.stderr)
            sys.exit(1)
    else:
        print("\n" + "=" * 60)
        print("Preview (use --apply to update data.json):")
        print("=" * 60 + "\n")

        for name in sorted(pricing_map.keys())[:30]:
            p = pricing_map[name]
            print(f" {name:<40} ${p['input']:<10.4f} / ${p['output']:<10.4f}")

        if len(pricing_map) > 30:
            print(f" ... and {len(pricing_map) - 30} more models")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
多 Provider 负载均衡和故障转移
|
|
6
|
+
|
|
7
|
+
支持多个 API endpoint 的轮询、加权分配和自动 fallback。
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import time
|
|
11
|
+
import random
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import List, Optional, Literal
|
|
14
|
+
from threading import Lock
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Names of the routing strategies ProviderRouter.get_next() understands.
Strategy = Literal["round_robin", "weighted", "random", "fallback"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ProviderConfig:
    """
    Configuration of a single provider endpoint.

    Attributes:
        base_url: Base URL of the API.
        api_key: API key; defaults to the placeholder string "EMPTY".
        weight: Relative weight, used by the "weighted" routing strategy.
        model: Optional model override (None = no override). ProviderRouter
            does not read it; presumably consumed by the caller — confirm.
        enabled: Providers with enabled=False are dropped when a
            ProviderRouter is constructed.
    """
    base_url: str
    api_key: str = "EMPTY"
    weight: float = 1.0
    model: Optional[str] = None
    enabled: bool = True
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ProviderStatus:
    """Runtime health state that ProviderRouter keeps for one provider."""
    # The provider this state belongs to.
    config: ProviderConfig
    # Failures recorded since the last success or recovery.
    failures: int = 0
    # Time of the most recent recorded failure (0 = never failed).
    last_failure: float = 0
    # Set to False once failures reaches the router's failure_threshold.
    is_healthy: bool = True
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ProviderRouter:
    """
    Provider router: picks an endpoint for each request and tracks health.

    Strategies:
        - round_robin: cycle through the healthy providers in order,
          starting with the first one
        - weighted: random pick proportional to ProviderConfig.weight
        - random: uniform random pick
        - fallback: primary/backup - always the first healthy provider

    Health tracking: a provider is marked unhealthy after
    ``failure_threshold`` consecutive mark_failed() calls. It is tried again
    once ``recovery_time`` seconds have passed since its last failure (measured
    with time.monotonic(), so wall-clock changes do not affect it). If no
    provider is healthy, all enabled providers are used. Public methods
    serialize on an internal threading.Lock.
    """

    def __init__(
        self,
        providers: List[ProviderConfig],
        strategy: Strategy = "round_robin",
        failure_threshold: int = 3,
        recovery_time: float = 60.0,
    ):
        """
        Build a router.

        Args:
            providers: Provider configurations; disabled ones are dropped.
            strategy: Routing strategy (see the class docstring). Unknown
                values behave like "fallback".
            failure_threshold: Consecutive failures before a provider is
                marked unhealthy.
            recovery_time: Seconds after the last failure before an unhealthy
                provider is tried again.

        Raises:
            ValueError: If *providers* is empty or none of them is enabled.
        """
        if not providers:
            raise ValueError("至少需要一个 provider")

        self.strategy = strategy
        self.failure_threshold = failure_threshold
        self.recovery_time = recovery_time

        self._providers = [
            ProviderStatus(config=p) for p in providers if p.enabled
        ]
        # Next round-robin slot (index into the current healthy list).
        self._index = 0
        self._lock = Lock()

        if not self._providers:
            raise ValueError("没有可用的 provider")

    def _find_status(self, provider: ProviderConfig) -> Optional[ProviderStatus]:
        """
        Return the status entry for *provider*, or None if it is unknown.

        Prefers the exact config object (as handed out by get_next()). This
        stays correct when several providers share a base_url. It falls back
        to the first entry with an equal base_url for callers that pass an
        equivalent but distinct ProviderConfig. Callers hold self._lock.
        """
        for p in self._providers:
            if p.config is provider:
                return p
        for p in self._providers:
            if p.config.base_url == provider.base_url:
                return p
        return None

    def _get_healthy_providers(self) -> List[ProviderStatus]:
        """
        Return the providers that may currently receive traffic.

        Unhealthy providers whose last failure is older than recovery_time
        are reset to healthy here. If nothing is healthy, every enabled
        provider is returned so requests still have somewhere to go.
        Callers hold self._lock, since this mutates provider state.
        """
        now = time.monotonic()
        healthy = []

        for p in self._providers:
            # Attempt recovery once the cool-down has elapsed.
            if not p.is_healthy and (now - p.last_failure) > self.recovery_time:
                p.is_healthy = True
                p.failures = 0

            if p.is_healthy:
                healthy.append(p)

        return healthy if healthy else self._providers  # all down: use all

    def get_next(self) -> ProviderConfig:
        """
        Return the provider to use for the next request.

        Returns:
            The ProviderConfig chosen according to ``self.strategy``.
        """
        with self._lock:
            healthy = self._get_healthy_providers()

            if self.strategy == "round_robin":
                # Use the current slot, then advance. The modulo keeps the slot
                # valid when the healthy list shrank since the previous call.
                chosen = healthy[self._index % len(healthy)]
                self._index = (self._index + 1) % len(healthy)
                return chosen.config

            if self.strategy == "weighted":
                total = sum(p.config.weight for p in healthy)
                r = random.uniform(0, total)
                cumsum = 0
                for p in healthy:
                    cumsum += p.config.weight
                    if r <= cumsum:
                        return p.config
                # Floating-point rounding can leave r just above the last sum.
                return healthy[-1].config

            if self.strategy == "random":
                return random.choice(healthy).config

            # "fallback" (and any unrecognized strategy): first healthy provider.
            return healthy[0].config

    def mark_failed(self, provider: ProviderConfig) -> None:
        """
        Record a failed request to *provider*.

        The provider is marked unhealthy once its consecutive failure count
        reaches failure_threshold. Unknown providers are ignored.

        Args:
            provider: The config that failed (normally the one from get_next()).
        """
        with self._lock:
            status = self._find_status(provider)
            if status is None:
                return
            status.failures += 1
            status.last_failure = time.monotonic()
            if status.failures >= self.failure_threshold:
                status.is_healthy = False

    def mark_success(self, provider: ProviderConfig) -> None:
        """
        Record a successful request: reset the failure count and mark healthy.

        Unknown providers are ignored.

        Args:
            provider: The config that succeeded (normally the one from get_next()).
        """
        with self._lock:
            status = self._find_status(provider)
            if status is None:
                return
            status.failures = 0
            status.is_healthy = True

    def get_all_healthy(self) -> List[ProviderConfig]:
        """Return every provider currently eligible for traffic (all, if none is healthy)."""
        with self._lock:
            return [p.config for p in self._get_healthy_providers()]

    @property
    def stats(self) -> dict:
        """Snapshot of router state: counts, strategy and per-provider health."""
        with self._lock:
            return {
                "total": len(self._providers),
                "healthy": sum(1 for p in self._providers if p.is_healthy),
                "strategy": self.strategy,
                "providers": [
                    {
                        "base_url": p.config.base_url,
                        "healthy": p.is_healthy,
                        "failures": p.failures,
                    }
                    for p in self._providers
                ],
            }
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def create_router_from_urls(
    urls: List[str],
    api_key: str = "EMPTY",
    strategy: Strategy = "round_robin",
) -> ProviderRouter:
    """
    Convenience constructor: build a ProviderRouter from plain URLs.

    Every URL becomes a ProviderConfig that shares *api_key* and keeps the
    default weight, model and enabled settings.

    Args:
        urls: API base URLs, in routing order.
        api_key: API key used for every provider.
        strategy: Routing strategy passed to ProviderRouter.

    Returns:
        A new ProviderRouter.

    Raises:
        ValueError: Propagated from ProviderRouter when *urls* is empty.
    """
    provider_configs = [
        ProviderConfig(base_url=endpoint, api_key=api_key)
        for endpoint in urls
    ]
    router = ProviderRouter(provider_configs, strategy=strategy)
    return router
|
flexllm/token_counter.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Token 计数和成本估算模块
|
|
6
|
+
|
|
7
|
+
支持使用 tiktoken 精确计算,或在缺失时使用估算方法。
|
|
8
|
+
定价数据从 pricing/data.json 加载,可通过 `flexllm pricing --update` 更新。
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import hashlib
|
|
12
|
+
import json
|
|
13
|
+
from typing import Union, List, Dict, Any, Optional
|
|
14
|
+
|
|
15
|
+
# tiktoken 是可选依赖
|
|
16
|
+
try:
|
|
17
|
+
import tiktoken
|
|
18
|
+
TIKTOKEN_AVAILABLE = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
TIKTOKEN_AVAILABLE = False
|
|
21
|
+
|
|
22
|
+
# 从 pricing 模块导入定价功能
|
|
23
|
+
from .pricing import get_pricing, reload_pricing, estimate_cost as _estimate_cost
|
|
24
|
+
|
|
25
|
+
# Compatibility shim for the old API: pricing data is now loaded through the
# pricing package's get_pricing() instead of being defined in this module.
def _get_model_pricing():
    """Return the pricing table provided by ``get_pricing()``."""
    return get_pricing()

# Kept for backward compatibility with code that imports MODEL_PRICING.
# NOTE(review): this is evaluated once, at import time. Whether it tracks later
# reload_pricing() calls depends on whether get_pricing() hands out a shared
# dict that is refreshed in place — TODO confirm; call get_pricing() directly
# when up-to-date data matters.
MODEL_PRICING = _get_model_pricing()
|
|
31
|
+
|
|
32
|
+
# Model name -> tiktoken encoding name. _get_encoding() uses this to choose
# an encoder; models it cannot match fall back to "cl100k_base".
MODEL_TO_ENCODING = {
    # GPT-5 family
    "gpt-5": "o200k_base",
    "gpt-5.1": "o200k_base",
    "gpt-5.1-codex": "o200k_base",
    "gpt-5.2": "o200k_base",
    "gpt-5.2-pro": "o200k_base",
    # GPT-4o family
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    # GPT-4.1 family
    "gpt-4.1": "o200k_base",
    "gpt-4.1-mini": "o200k_base",
    "gpt-4.1-nano": "o200k_base",
    # GPT-4 family
    "gpt-4-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    # o-series reasoning models
    "o1": "o200k_base",
    "o1-mini": "o200k_base",
    "o1-pro": "o200k_base",
    "o3": "o200k_base",
    "o3-mini": "o200k_base",
    "o4-mini": "o200k_base",
}

# tiktoken encoder objects keyed by encoding name, filled lazily by _get_encoding().
_encoding_cache: Dict[str, Any] = {}
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_encoding(model: str):
|
|
65
|
+
"""获取模型对应的 tiktoken 编码器"""
|
|
66
|
+
if not TIKTOKEN_AVAILABLE:
|
|
67
|
+
return None
|
|
68
|
+
|
|
69
|
+
encoding_name = MODEL_TO_ENCODING.get(model, "cl100k_base")
|
|
70
|
+
if encoding_name not in _encoding_cache:
|
|
71
|
+
_encoding_cache[encoding_name] = tiktoken.get_encoding(encoding_name)
|
|
72
|
+
return _encoding_cache[encoding_name]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _estimate_tokens_simple(text: str) -> int:
|
|
76
|
+
"""简单估算:中文约 2 字符/token,英文约 4 字符/token"""
|
|
77
|
+
if not text:
|
|
78
|
+
return 0
|
|
79
|
+
# 粗略统计中文字符
|
|
80
|
+
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
|
81
|
+
other_chars = len(text) - chinese_chars
|
|
82
|
+
return chinese_chars // 2 + other_chars // 4 + 1
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def count_tokens(
    content: Union[str, List[Dict], Dict],
    model: str = "gpt-4o"
) -> int:
    """
    Count (or estimate) the tokens in *content*.

    Args:
        content: A plain string, a single dict (counted via its JSON dump),
            or a list of chat messages. For a messages list:
            - every string-valued field (e.g. role and content) is counted;
            - multimodal content lists contribute their "text" parts plus a
              flat 85 tokens per "image_url" part;
            - each message adds 4 tokens of overhead and the list adds 2;
            - non-dict list elements are ignored.
        model: Model name, used to pick the tiktoken encoder.

    Returns:
        The exact token count when tiktoken is available, otherwise the
        heuristic estimate from _estimate_tokens_simple().
    """
    # messages format
    if isinstance(content, list):
        total = 0
        for msg in content:
            if not isinstance(msg, dict):
                continue
            total += 4  # fixed per-message overhead (role + content markers)
            for value in msg.values():
                if isinstance(value, str):
                    total += count_tokens(value, model)
                elif isinstance(value, list):
                    # multimodal content parts
                    for item in value:
                        if not isinstance(item, dict):
                            continue
                        if "text" in item:
                            total += count_tokens(item["text"], model)
                        elif "image_url" in item:
                            # Flat image estimate (~85 for low detail; high
                            # detail is roughly 170 per tile, so this undercounts).
                            total += 85
        return total + 2  # end-of-conversation markers

    if isinstance(content, dict):
        return count_tokens(json.dumps(content, ensure_ascii=False), model)

    # plain text
    text = str(content)
    encoding = _get_encoding(model)

    if encoding:
        # disallowed_special=() counts strings such as "<|endoftext|>" as
        # ordinary text; tiktoken's default raises ValueError on them.
        return len(encoding.encode(text, disallowed_special=()))
    return _estimate_tokens_simple(text)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def count_messages_tokens(
    messages_list: List[List[Dict]],
    model: str = "gpt-4o"
) -> int:
    """
    Total token count over a batch of conversations.

    Args:
        messages_list: One messages list per request.
        model: Model name forwarded to count_tokens().

    Returns:
        The sum of count_tokens(messages, model) over the batch (0 when empty).
    """
    total = 0
    for conversation in messages_list:
        total += count_tokens(conversation, model)
    return total
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def estimate_cost(
    input_tokens: int,
    output_tokens: int = 0,
    model: str = "gpt-4o"
) -> float:
    """
    Estimated price of one API call, in USD.

    Thin wrapper that delegates to ``estimate_cost`` from the pricing package
    (imported in this module as ``_estimate_cost``).

    Args:
        input_tokens: Prompt token count.
        output_tokens: Completion token count; pass 0 when unknown.
        model: Model name whose pricing is used.

    Returns:
        The cost computed by the pricing package.
    """
    usd = _estimate_cost(input_tokens, output_tokens, model)
    return usd
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def estimate_batch_cost(
    messages_list: List[List[Dict]],
    model: str = "gpt-4o",
    avg_output_tokens: int = 500
) -> Dict[str, Any]:
    """
    Estimate the cost of running a batch of requests.

    Args:
        messages_list: One messages list per request.
        model: Model name used for token counting and pricing.
        avg_output_tokens: Assumed completion length per request.

    Returns:
        A dict with the keys "count", "input_tokens",
        "estimated_output_tokens", "total_tokens", "estimated_cost_usd"
        (rounded to 4 decimals) and "model".
    """
    request_count = len(messages_list)
    prompt_tokens = count_messages_tokens(messages_list, model)
    completion_tokens = request_count * avg_output_tokens
    usd = estimate_cost(prompt_tokens, completion_tokens, model)

    summary: Dict[str, Any] = {}
    summary["count"] = request_count
    summary["input_tokens"] = prompt_tokens
    summary["estimated_output_tokens"] = completion_tokens
    summary["total_tokens"] = prompt_tokens + completion_tokens
    summary["estimated_cost_usd"] = round(usd, 4)
    summary["model"] = model
    return summary
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _normalize_message_for_hash(message: Dict) -> Dict:
|
|
199
|
+
"""
|
|
200
|
+
规范化消息用于 hash 计算,将 base64 图片替换为其内容 hash
|
|
201
|
+
|
|
202
|
+
这样做的好处:
|
|
203
|
+
1. 减少 hash 计算的数据量(base64 可能有几 MB)
|
|
204
|
+
2. 同一张图片即使重新编码也会产生相同的缓存键
|
|
205
|
+
"""
|
|
206
|
+
if not isinstance(message, dict):
|
|
207
|
+
return message
|
|
208
|
+
|
|
209
|
+
result = {}
|
|
210
|
+
for key, value in message.items():
|
|
211
|
+
if key == "content" and isinstance(value, list):
|
|
212
|
+
# 处理多模态内容(OpenAI 格式)
|
|
213
|
+
normalized_content = []
|
|
214
|
+
for item in value:
|
|
215
|
+
if isinstance(item, dict) and "image_url" in item:
|
|
216
|
+
image_url = item["image_url"]
|
|
217
|
+
if isinstance(image_url, dict):
|
|
218
|
+
url = image_url.get("url", "")
|
|
219
|
+
else:
|
|
220
|
+
url = str(image_url)
|
|
221
|
+
|
|
222
|
+
# 检查是否是 base64 数据
|
|
223
|
+
if url.startswith("data:image"):
|
|
224
|
+
# 提取 base64 部分并计算 hash
|
|
225
|
+
base64_data = url.split(",", 1)[-1] if "," in url else url
|
|
226
|
+
img_hash = hashlib.md5(base64_data.encode()).hexdigest()[:16]
|
|
227
|
+
# 用短 hash 替代完整 base64
|
|
228
|
+
normalized_item = {
|
|
229
|
+
"type": item.get("type", "image_url"),
|
|
230
|
+
"image_url": {"url": f"img_hash:{img_hash}"}
|
|
231
|
+
}
|
|
232
|
+
normalized_content.append(normalized_item)
|
|
233
|
+
else:
|
|
234
|
+
# URL 类型保持不变
|
|
235
|
+
normalized_content.append(item)
|
|
236
|
+
else:
|
|
237
|
+
normalized_content.append(item)
|
|
238
|
+
result[key] = normalized_content
|
|
239
|
+
else:
|
|
240
|
+
result[key] = value
|
|
241
|
+
return result
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def messages_hash(
    messages: List[Dict],
    model: str = "",
    **kwargs
) -> str:
    """
    Compute a deterministic cache key for a chat request.

    The key is the MD5 hex digest of a JSON document (sorted keys,
    non-ASCII kept) containing the normalized messages (inline base64 images
    reduced to short digests by _normalize_message_for_hash), the model name,
    and every keyword argument whose value is not None.

    Args:
        messages: Chat messages of the request.
        model: Model name.
        **kwargs: Other result-affecting parameters (temperature, max_tokens,
            ...); None values are treated as absent.

    Returns:
        A 32-character hexadecimal MD5 digest.
    """
    payload = {
        "messages": list(map(_normalize_message_for_hash, messages)),
        "model": model,
    }
    payload.update((key, value) for key, value in kwargs.items() if value is not None)
    serialized = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    return hashlib.md5(serialized.encode()).hexdigest()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .core import async_retry
|