AstrBot 4.9.2__py3-none-any.whl → 4.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/message.py +6 -4
- astrbot/core/agent/response.py +22 -1
- astrbot/core/agent/run_context.py +1 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +99 -20
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astr_agent_run_util.py +42 -3
- astrbot/core/astr_agent_tool_exec.py +34 -4
- astrbot/core/config/default.py +127 -184
- astrbot/core/core_lifecycle.py +3 -0
- astrbot/core/db/__init__.py +72 -0
- astrbot/core/db/po.py +59 -0
- astrbot/core/db/sqlite.py +240 -0
- astrbot/core/message/components.py +4 -5
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +6 -1
- astrbot/core/pipeline/respond/stage.py +1 -1
- astrbot/core/platform/sources/telegram/tg_event.py +9 -0
- astrbot/core/platform/sources/webchat/webchat_event.py +22 -18
- astrbot/core/provider/entities.py +41 -0
- astrbot/core/provider/manager.py +203 -93
- astrbot/core/provider/sources/anthropic_source.py +55 -11
- astrbot/core/provider/sources/gemini_source.py +84 -33
- astrbot/core/provider/sources/openai_source.py +21 -6
- astrbot/core/star/command_management.py +449 -0
- astrbot/core/star/context.py +4 -0
- astrbot/core/star/filter/command.py +1 -0
- astrbot/core/star/filter/command_group.py +1 -0
- astrbot/core/star/star_handler.py +4 -0
- astrbot/core/star/star_manager.py +2 -0
- astrbot/core/utils/llm_metadata.py +63 -0
- astrbot/core/utils/migra_helper.py +93 -0
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/chat.py +56 -13
- astrbot/dashboard/routes/command.py +82 -0
- astrbot/dashboard/routes/config.py +291 -33
- astrbot/dashboard/routes/stat.py +96 -0
- astrbot/dashboard/routes/tools.py +20 -4
- astrbot/dashboard/server.py +1 -0
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/METADATA +2 -2
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/RECORD +43 -40
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/WHEEL +0 -0
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/provider/manager.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import copy
|
|
2
3
|
import traceback
|
|
3
4
|
from typing import Protocol, runtime_checkable
|
|
4
5
|
|
|
@@ -32,10 +33,12 @@ class ProviderManager:
|
|
|
32
33
|
persona_mgr: PersonaManager,
|
|
33
34
|
):
|
|
34
35
|
self.reload_lock = asyncio.Lock()
|
|
36
|
+
self.resource_lock = asyncio.Lock()
|
|
35
37
|
self.persona_mgr = persona_mgr
|
|
36
38
|
self.acm = acm
|
|
37
39
|
config = acm.confs["default"]
|
|
38
40
|
self.providers_config: list = config["provider"]
|
|
41
|
+
self.provider_sources_config: list = config.get("provider_sources", [])
|
|
39
42
|
self.provider_settings: dict = config["provider_settings"]
|
|
40
43
|
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
|
41
44
|
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
|
@@ -148,6 +151,7 @@ class ProviderManager:
|
|
|
148
151
|
|
|
149
152
|
"""
|
|
150
153
|
provider = None
|
|
154
|
+
provider_id = None
|
|
151
155
|
if umo:
|
|
152
156
|
provider_id = sp.get(
|
|
153
157
|
f"provider_perf_{provider_type.value}",
|
|
@@ -185,6 +189,12 @@ class ProviderManager:
|
|
|
185
189
|
)
|
|
186
190
|
else:
|
|
187
191
|
raise ValueError(f"Unknown provider type: {provider_type}")
|
|
192
|
+
|
|
193
|
+
if not provider and provider_id:
|
|
194
|
+
logger.warning(
|
|
195
|
+
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
|
196
|
+
)
|
|
197
|
+
|
|
188
198
|
return provider
|
|
189
199
|
|
|
190
200
|
async def initialize(self):
|
|
@@ -251,7 +261,136 @@ class ProviderManager:
|
|
|
251
261
|
# 初始化 MCP Client 连接
|
|
252
262
|
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
|
253
263
|
|
|
264
|
+
def dynamic_import_provider(self, type: str):
    """Dynamically import the adapter module for a provider type.

    Each known ``type`` string maps to exactly one adapter class that is
    imported from the ``.sources`` package. The imported name is not used
    inside this method; presumably the import is done so that the adapter
    module's import-time registration runs (TODO confirm against the
    adapter modules).

    Args:
        type (str): Provider request type, e.g. ``"openai_chat_completion"``.

    Raises:
        ImportError: Propagated from the import statement when the adapter
            module (or one of its dependencies) cannot be imported.

    NOTE(review): the ``match`` has no default case, so an unrecognised
    ``type`` does NOT raise here; the method simply returns ``None``.
    """
    match type:
        # --- chat completion adapters ---
        case "openai_chat_completion":
            from .sources.openai_source import (
                ProviderOpenAIOfficial as ProviderOpenAIOfficial,
            )
        case "zhipu_chat_completion":
            from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
        case "groq_chat_completion":
            from .sources.groq_source import ProviderGroq as ProviderGroq
        case "anthropic_chat_completion":
            from .sources.anthropic_source import (
                ProviderAnthropic as ProviderAnthropic,
            )
        case "googlegenai_chat_completion":
            from .sources.gemini_source import (
                ProviderGoogleGenAI as ProviderGoogleGenAI,
            )
        # --- speech-to-text adapters ---
        case "sensevoice_stt_selfhost":
            from .sources.sensevoice_selfhosted_source import (
                ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
            )
        case "openai_whisper_api":
            from .sources.whisper_api_source import (
                ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
            )
        case "openai_whisper_selfhost":
            from .sources.whisper_selfhosted_source import (
                ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
            )
        case "xinference_stt":
            from .sources.xinference_stt_provider import (
                ProviderXinferenceSTT as ProviderXinferenceSTT,
            )
        # --- text-to-speech adapters ---
        case "openai_tts_api":
            from .sources.openai_tts_api_source import (
                ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
            )
        case "edge_tts":
            from .sources.edge_tts_source import (
                ProviderEdgeTTS as ProviderEdgeTTS,
            )
        case "gsv_tts_selfhost":
            from .sources.gsv_selfhosted_source import (
                ProviderGSVTTS as ProviderGSVTTS,
            )
        case "gsvi_tts_api":
            from .sources.gsvi_tts_source import (
                ProviderGSVITTS as ProviderGSVITTS,
            )
        case "fishaudio_tts_api":
            from .sources.fishaudio_tts_api_source import (
                ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
            )
        case "dashscope_tts":
            from .sources.dashscope_tts import (
                ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
            )
        case "azure_tts":
            from .sources.azure_tts_source import (
                AzureTTSProvider as AzureTTSProvider,
            )
        case "minimax_tts_api":
            from .sources.minimax_tts_api_source import (
                ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
            )
        case "volcengine_tts":
            from .sources.volcengine_tts import (
                ProviderVolcengineTTS as ProviderVolcengineTTS,
            )
        case "gemini_tts":
            from .sources.gemini_tts_source import (
                ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
            )
        # --- embedding adapters ---
        case "openai_embedding":
            from .sources.openai_embedding_source import (
                OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
            )
        case "gemini_embedding":
            from .sources.gemini_embedding_source import (
                GeminiEmbeddingProvider as GeminiEmbeddingProvider,
            )
        # --- rerank adapters ---
        case "vllm_rerank":
            from .sources.vllm_rerank_source import (
                VLLMRerankProvider as VLLMRerankProvider,
            )
        case "xinference_rerank":
            from .sources.xinference_rerank_source import (
                XinferenceRerankProvider as XinferenceRerankProvider,
            )
        case "bailian_rerank":
            from .sources.bailian_rerank_source import (
                BailianRerankProvider as BailianRerankProvider,
            )
|
|
366
|
+
|
|
367
|
+
def get_merged_provider_config(self, provider_config: dict) -> dict:
    """Return ``provider_config`` merged with the provider source it references.

    When ``provider_config["provider_source_id"]`` is non-empty and matches
    the ``id`` of an entry in ``self.provider_sources_config``, the source's
    keys are used as defaults and the provider's own keys override them.
    The resulting ``id`` is always the provider's id, never the source's.

    The returned dict is fully independent of both inputs: the provider
    config and the matched source are deep-copied, so mutating nested values
    of the result (lists, dicts) cannot leak into the shared source config.

    Args:
        provider_config: A single provider entry from the configuration.

    Returns:
        dict: One merged provider configuration. If no source is referenced
        or the referenced source is not found, this is just a deep copy of
        ``provider_config``.

    Raises:
        KeyError: If a matching source is found but ``provider_config`` has
            no ``"id"`` key.
    """
    merged = copy.deepcopy(provider_config)
    source_id = merged.get("provider_source_id", "")
    if not source_id:
        return merged

    source = next(
        (ps for ps in self.provider_sources_config if ps.get("id") == source_id),
        None,
    )
    if not source:
        return merged

    provider_id = merged["id"]
    # Provider-level keys take precedence over the source's keys. The source
    # is deep-copied so its nested values are not shared with the result.
    merged = {**copy.deepcopy(source), **merged}
    # Keep the provider's own id rather than the source's id.
    merged["id"] = provider_id
    return merged
|
|
389
|
+
|
|
254
390
|
async def load_provider(self, provider_config: dict):
|
|
391
|
+
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
|
|
392
|
+
provider_config = self.get_merged_provider_config(provider_config)
|
|
393
|
+
|
|
255
394
|
if not provider_config["enable"]:
|
|
256
395
|
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
|
257
396
|
return
|
|
@@ -264,99 +403,7 @@ class ProviderManager:
|
|
|
264
403
|
|
|
265
404
|
# 动态导入
|
|
266
405
|
try:
|
|
267
|
-
|
|
268
|
-
case "openai_chat_completion":
|
|
269
|
-
from .sources.openai_source import (
|
|
270
|
-
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
|
271
|
-
)
|
|
272
|
-
case "zhipu_chat_completion":
|
|
273
|
-
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
|
274
|
-
case "groq_chat_completion":
|
|
275
|
-
from .sources.groq_source import ProviderGroq as ProviderGroq
|
|
276
|
-
case "anthropic_chat_completion":
|
|
277
|
-
from .sources.anthropic_source import (
|
|
278
|
-
ProviderAnthropic as ProviderAnthropic,
|
|
279
|
-
)
|
|
280
|
-
case "googlegenai_chat_completion":
|
|
281
|
-
from .sources.gemini_source import (
|
|
282
|
-
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
|
283
|
-
)
|
|
284
|
-
case "sensevoice_stt_selfhost":
|
|
285
|
-
from .sources.sensevoice_selfhosted_source import (
|
|
286
|
-
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
|
287
|
-
)
|
|
288
|
-
case "openai_whisper_api":
|
|
289
|
-
from .sources.whisper_api_source import (
|
|
290
|
-
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
|
291
|
-
)
|
|
292
|
-
case "openai_whisper_selfhost":
|
|
293
|
-
from .sources.whisper_selfhosted_source import (
|
|
294
|
-
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
|
295
|
-
)
|
|
296
|
-
case "xinference_stt":
|
|
297
|
-
from .sources.xinference_stt_provider import (
|
|
298
|
-
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
|
299
|
-
)
|
|
300
|
-
case "openai_tts_api":
|
|
301
|
-
from .sources.openai_tts_api_source import (
|
|
302
|
-
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
|
303
|
-
)
|
|
304
|
-
case "edge_tts":
|
|
305
|
-
from .sources.edge_tts_source import (
|
|
306
|
-
ProviderEdgeTTS as ProviderEdgeTTS,
|
|
307
|
-
)
|
|
308
|
-
case "gsv_tts_selfhost":
|
|
309
|
-
from .sources.gsv_selfhosted_source import (
|
|
310
|
-
ProviderGSVTTS as ProviderGSVTTS,
|
|
311
|
-
)
|
|
312
|
-
case "gsvi_tts_api":
|
|
313
|
-
from .sources.gsvi_tts_source import (
|
|
314
|
-
ProviderGSVITTS as ProviderGSVITTS,
|
|
315
|
-
)
|
|
316
|
-
case "fishaudio_tts_api":
|
|
317
|
-
from .sources.fishaudio_tts_api_source import (
|
|
318
|
-
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
|
319
|
-
)
|
|
320
|
-
case "dashscope_tts":
|
|
321
|
-
from .sources.dashscope_tts import (
|
|
322
|
-
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
|
323
|
-
)
|
|
324
|
-
case "azure_tts":
|
|
325
|
-
from .sources.azure_tts_source import (
|
|
326
|
-
AzureTTSProvider as AzureTTSProvider,
|
|
327
|
-
)
|
|
328
|
-
case "minimax_tts_api":
|
|
329
|
-
from .sources.minimax_tts_api_source import (
|
|
330
|
-
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
|
331
|
-
)
|
|
332
|
-
case "volcengine_tts":
|
|
333
|
-
from .sources.volcengine_tts import (
|
|
334
|
-
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
|
335
|
-
)
|
|
336
|
-
case "gemini_tts":
|
|
337
|
-
from .sources.gemini_tts_source import (
|
|
338
|
-
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
|
339
|
-
)
|
|
340
|
-
case "openai_embedding":
|
|
341
|
-
from .sources.openai_embedding_source import (
|
|
342
|
-
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
|
343
|
-
)
|
|
344
|
-
case "gemini_embedding":
|
|
345
|
-
from .sources.gemini_embedding_source import (
|
|
346
|
-
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
|
347
|
-
)
|
|
348
|
-
case "vllm_rerank":
|
|
349
|
-
from .sources.vllm_rerank_source import (
|
|
350
|
-
VLLMRerankProvider as VLLMRerankProvider,
|
|
351
|
-
)
|
|
352
|
-
case "xinference_rerank":
|
|
353
|
-
from .sources.xinference_rerank_source import (
|
|
354
|
-
XinferenceRerankProvider as XinferenceRerankProvider,
|
|
355
|
-
)
|
|
356
|
-
case "bailian_rerank":
|
|
357
|
-
from .sources.bailian_rerank_source import (
|
|
358
|
-
BailianRerankProvider as BailianRerankProvider,
|
|
359
|
-
)
|
|
406
|
+
self.dynamic_import_provider(provider_config["type"])
|
|
360
407
|
except (ImportError, ModuleNotFoundError) as e:
|
|
361
408
|
logger.critical(
|
|
362
409
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
|
@@ -499,6 +546,7 @@ class ProviderManager:
|
|
|
499
546
|
|
|
500
547
|
# 和配置文件保持同步
|
|
501
548
|
self.providers_config = astrbot_config["provider"]
|
|
549
|
+
self.provider_sources_config = astrbot_config.get("provider_sources", [])
|
|
502
550
|
config_ids = [provider["id"] for provider in self.providers_config]
|
|
503
551
|
logger.info(f"providers in user's config: {config_ids}")
|
|
504
552
|
for key in list(self.inst_map.keys()):
|
|
@@ -570,6 +618,68 @@ class ProviderManager:
|
|
|
570
618
|
)
|
|
571
619
|
del self.inst_map[provider_id]
|
|
572
620
|
|
|
621
|
+
async def delete_provider(
|
|
622
|
+
self, provider_id: str | None = None, provider_source_id: str | None = None
|
|
623
|
+
):
|
|
624
|
+
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
|
|
625
|
+
async with self.resource_lock:
|
|
626
|
+
# delete from config
|
|
627
|
+
target_prov_ids = []
|
|
628
|
+
if provider_id:
|
|
629
|
+
target_prov_ids.append(provider_id)
|
|
630
|
+
else:
|
|
631
|
+
for prov in self.providers_config:
|
|
632
|
+
if prov.get("provider_source_id") == provider_source_id:
|
|
633
|
+
target_prov_ids.append(prov.get("id"))
|
|
634
|
+
config = self.acm.default_conf
|
|
635
|
+
for tpid in target_prov_ids:
|
|
636
|
+
await self.terminate_provider(tpid)
|
|
637
|
+
config["provider"] = [
|
|
638
|
+
prov for prov in config["provider"] if prov.get("id") != tpid
|
|
639
|
+
]
|
|
640
|
+
config.save_config()
|
|
641
|
+
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
|
|
642
|
+
|
|
643
|
+
async def update_provider(self, origin_provider_id: str, new_config: dict):
|
|
644
|
+
"""Update provider config and reload the instance. Config will be saved after update."""
|
|
645
|
+
async with self.resource_lock:
|
|
646
|
+
npid = new_config.get("id", None)
|
|
647
|
+
if not npid:
|
|
648
|
+
raise ValueError("New provider config must have an 'id' field")
|
|
649
|
+
config = self.acm.default_conf
|
|
650
|
+
for provider in config["provider"]:
|
|
651
|
+
if (
|
|
652
|
+
provider.get("id", None) == npid
|
|
653
|
+
and provider.get("id", None) != origin_provider_id
|
|
654
|
+
):
|
|
655
|
+
raise ValueError(f"Provider ID {npid} already exists")
|
|
656
|
+
# update config
|
|
657
|
+
for idx, provider in enumerate(config["provider"]):
|
|
658
|
+
if provider.get("id", None) == origin_provider_id:
|
|
659
|
+
config["provider"][idx] = new_config
|
|
660
|
+
break
|
|
661
|
+
else:
|
|
662
|
+
raise ValueError(f"Provider ID {origin_provider_id} not found")
|
|
663
|
+
config.save_config()
|
|
664
|
+
# reload instance
|
|
665
|
+
await self.reload(new_config)
|
|
666
|
+
|
|
667
|
+
async def create_provider(self, new_config: dict):
|
|
668
|
+
"""Add new provider config and load the instance. Config will be saved after addition."""
|
|
669
|
+
async with self.resource_lock:
|
|
670
|
+
npid = new_config.get("id", None)
|
|
671
|
+
if not npid:
|
|
672
|
+
raise ValueError("New provider config must have an 'id' field")
|
|
673
|
+
config = self.acm.default_conf
|
|
674
|
+
for provider in config["provider"]:
|
|
675
|
+
if provider.get("id", None) == npid:
|
|
676
|
+
raise ValueError(f"Provider ID {npid} already exists")
|
|
677
|
+
# add to config
|
|
678
|
+
config["provider"].append(new_config)
|
|
679
|
+
config.save_config()
|
|
680
|
+
# load instance
|
|
681
|
+
await self.load_provider(new_config)
|
|
682
|
+
|
|
573
683
|
async def terminate(self):
|
|
574
684
|
for provider_inst in self.provider_insts:
|
|
575
685
|
if hasattr(provider_inst, "terminate"):
|
|
@@ -6,10 +6,12 @@ from mimetypes import guess_type
|
|
|
6
6
|
import anthropic
|
|
7
7
|
from anthropic import AsyncAnthropic
|
|
8
8
|
from anthropic.types import Message
|
|
9
|
+
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
|
10
|
+
from anthropic.types.usage import Usage
|
|
9
11
|
|
|
10
12
|
from astrbot import logger
|
|
11
13
|
from astrbot.api.provider import Provider
|
|
12
|
-
from astrbot.core.provider.entities import LLMResponse
|
|
14
|
+
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
|
13
15
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
|
14
16
|
from astrbot.core.utils.io import download_image_by_url
|
|
15
17
|
|
|
@@ -45,7 +47,7 @@ class ProviderAnthropic(Provider):
|
|
|
45
47
|
base_url=self.base_url,
|
|
46
48
|
)
|
|
47
49
|
|
|
48
|
-
self.set_model(provider_config
|
|
50
|
+
self.set_model(provider_config.get("model", "unknown"))
|
|
49
51
|
|
|
50
52
|
def _prepare_payload(self, messages: list[dict]):
|
|
51
53
|
"""准备 Anthropic API 的请求 payload
|
|
@@ -107,12 +109,32 @@ class ProviderAnthropic(Provider):
|
|
|
107
109
|
|
|
108
110
|
return system_prompt, new_messages
|
|
109
111
|
|
|
112
|
+
def _extract_usage(self, usage: Usage) -> TokenUsage:
    """Convert an Anthropic ``Usage`` object into a ``TokenUsage``.

    ``input_tokens`` is stored as ``input_other`` and
    ``cache_read_input_tokens`` as ``input_cached``. Every field falls back
    to ``0`` when the API reports a falsy value (e.g. ``None``), so the
    result never carries ``None`` counters.

    Args:
        usage: Usage block of a completed (or just-started) message.

    Returns:
        TokenUsage: Token counts taken from ``usage``.
    """
    # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
    return TokenUsage(
        input_other=usage.input_tokens or 0,
        input_cached=usage.cache_read_input_tokens or 0,
        # Guarded like the input fields for consistency.
        output=usage.output_tokens or 0,
    )
|
|
119
|
+
|
|
120
|
+
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
    """Copy the counters present in a delta usage into ``token_usage``.

    Only fields of ``usage`` that are not ``None`` overwrite the matching
    field of ``token_usage``; the others are left untouched.

    Args:
        token_usage: Accumulated usage, updated in place.
        usage: Usage block carried by a ``message_delta`` stream event.
    """
    # (attribute on the delta usage, attribute on the accumulated usage)
    field_pairs = (
        ("input_tokens", "input_other"),
        ("cache_read_input_tokens", "input_cached"),
        ("output_tokens", "output"),
    )
    for delta_attr, total_attr in field_pairs:
        reported = getattr(usage, delta_attr)
        if reported is not None:
            setattr(token_usage, total_attr, reported)
|
|
127
|
+
|
|
110
128
|
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
|
111
129
|
if tools:
|
|
112
130
|
if tool_list := tools.get_func_desc_anthropic_style():
|
|
113
131
|
payloads["tools"] = tool_list
|
|
114
132
|
|
|
115
|
-
|
|
133
|
+
extra_body = self.provider_config.get("custom_extra_body", {})
|
|
134
|
+
|
|
135
|
+
completion = await self.client.messages.create(
|
|
136
|
+
**payloads, stream=False, extra_body=extra_body
|
|
137
|
+
)
|
|
116
138
|
|
|
117
139
|
assert isinstance(completion, Message)
|
|
118
140
|
logger.debug(f"completion: {completion}")
|
|
@@ -131,6 +153,10 @@ class ProviderAnthropic(Provider):
|
|
|
131
153
|
llm_response.tools_call_args.append(content_block.input)
|
|
132
154
|
llm_response.tools_call_name.append(content_block.name)
|
|
133
155
|
llm_response.tools_call_ids.append(content_block.id)
|
|
156
|
+
|
|
157
|
+
llm_response.id = completion.id
|
|
158
|
+
llm_response.usage = self._extract_usage(completion.usage)
|
|
159
|
+
|
|
134
160
|
# TODO(Soulter): 处理 end_turn 情况
|
|
135
161
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
|
136
162
|
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
|
@@ -151,10 +177,19 @@ class ProviderAnthropic(Provider):
|
|
|
151
177
|
# 用于累积最终结果
|
|
152
178
|
final_text = ""
|
|
153
179
|
final_tool_calls = []
|
|
180
|
+
id = None
|
|
181
|
+
usage = TokenUsage()
|
|
182
|
+
extra_body = self.provider_config.get("custom_extra_body", {})
|
|
154
183
|
|
|
155
|
-
async with self.client.messages.stream(
|
|
184
|
+
async with self.client.messages.stream(
|
|
185
|
+
**payloads, extra_body=extra_body
|
|
186
|
+
) as stream:
|
|
156
187
|
assert isinstance(stream, anthropic.AsyncMessageStream)
|
|
157
188
|
async for event in stream:
|
|
189
|
+
if event.type == "message_start":
|
|
190
|
+
# the usage contains input token usage
|
|
191
|
+
id = event.message.id
|
|
192
|
+
usage = self._extract_usage(event.message.usage)
|
|
158
193
|
if event.type == "content_block_start":
|
|
159
194
|
if event.content_block.type == "text":
|
|
160
195
|
# 文本块开始
|
|
@@ -162,6 +197,8 @@ class ProviderAnthropic(Provider):
|
|
|
162
197
|
role="assistant",
|
|
163
198
|
completion_text="",
|
|
164
199
|
is_chunk=True,
|
|
200
|
+
usage=usage,
|
|
201
|
+
id=id,
|
|
165
202
|
)
|
|
166
203
|
elif event.content_block.type == "tool_use":
|
|
167
204
|
# 工具使用块开始,初始化缓冲区
|
|
@@ -179,6 +216,8 @@ class ProviderAnthropic(Provider):
|
|
|
179
216
|
role="assistant",
|
|
180
217
|
completion_text=event.delta.text,
|
|
181
218
|
is_chunk=True,
|
|
219
|
+
usage=usage,
|
|
220
|
+
id=id,
|
|
182
221
|
)
|
|
183
222
|
elif event.delta.type == "input_json_delta":
|
|
184
223
|
# 工具调用参数增量
|
|
@@ -215,6 +254,8 @@ class ProviderAnthropic(Provider):
|
|
|
215
254
|
tools_call_name=[tool_info["name"]],
|
|
216
255
|
tools_call_ids=[tool_info["id"]],
|
|
217
256
|
is_chunk=True,
|
|
257
|
+
usage=usage,
|
|
258
|
+
id=id,
|
|
218
259
|
)
|
|
219
260
|
except json.JSONDecodeError:
|
|
220
261
|
# JSON 解析失败,跳过这个工具调用
|
|
@@ -223,11 +264,17 @@ class ProviderAnthropic(Provider):
|
|
|
223
264
|
# 清理缓冲区
|
|
224
265
|
del tool_use_buffer[event.index]
|
|
225
266
|
|
|
267
|
+
elif event.type == "message_delta":
|
|
268
|
+
if event.usage:
|
|
269
|
+
self._update_usage(usage, event.usage)
|
|
270
|
+
|
|
226
271
|
# 返回最终的完整结果
|
|
227
272
|
final_response = LLMResponse(
|
|
228
273
|
role="assistant",
|
|
229
274
|
completion_text=final_text,
|
|
230
275
|
is_chunk=False,
|
|
276
|
+
usage=usage,
|
|
277
|
+
id=id,
|
|
231
278
|
)
|
|
232
279
|
|
|
233
280
|
if final_tool_calls:
|
|
@@ -277,10 +324,9 @@ class ProviderAnthropic(Provider):
|
|
|
277
324
|
|
|
278
325
|
system_prompt, new_messages = self._prepare_payload(context_query)
|
|
279
326
|
|
|
280
|
-
|
|
281
|
-
model_config["model"] = model or self.get_model()
|
|
327
|
+
model = model or self.get_model()
|
|
282
328
|
|
|
283
|
-
payloads = {"messages": new_messages,
|
|
329
|
+
payloads = {"messages": new_messages, "model": model}
|
|
284
330
|
|
|
285
331
|
# Anthropic has a different way of handling system prompts
|
|
286
332
|
if system_prompt:
|
|
@@ -290,7 +336,6 @@ class ProviderAnthropic(Provider):
|
|
|
290
336
|
try:
|
|
291
337
|
llm_response = await self._query(payloads, func_tool)
|
|
292
338
|
except Exception as e:
|
|
293
|
-
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
|
294
339
|
raise e
|
|
295
340
|
|
|
296
341
|
return llm_response
|
|
@@ -332,10 +377,9 @@ class ProviderAnthropic(Provider):
|
|
|
332
377
|
|
|
333
378
|
system_prompt, new_messages = self._prepare_payload(context_query)
|
|
334
379
|
|
|
335
|
-
|
|
336
|
-
model_config["model"] = model or self.get_model()
|
|
380
|
+
model = model or self.get_model()
|
|
337
381
|
|
|
338
|
-
payloads = {"messages": new_messages,
|
|
382
|
+
payloads = {"messages": new_messages, "model": model}
|
|
339
383
|
|
|
340
384
|
# Anthropic has a different way of handling system prompts
|
|
341
385
|
if system_prompt:
|