fast-agent-mcp 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_agent/__init__.py +183 -0
- fast_agent/acp/__init__.py +19 -0
- fast_agent/acp/acp_aware_mixin.py +304 -0
- fast_agent/acp/acp_context.py +437 -0
- fast_agent/acp/content_conversion.py +136 -0
- fast_agent/acp/filesystem_runtime.py +427 -0
- fast_agent/acp/permission_store.py +269 -0
- fast_agent/acp/server/__init__.py +5 -0
- fast_agent/acp/server/agent_acp_server.py +1472 -0
- fast_agent/acp/slash_commands.py +1050 -0
- fast_agent/acp/terminal_runtime.py +408 -0
- fast_agent/acp/tool_permission_adapter.py +125 -0
- fast_agent/acp/tool_permissions.py +474 -0
- fast_agent/acp/tool_progress.py +814 -0
- fast_agent/agents/__init__.py +85 -0
- fast_agent/agents/agent_types.py +64 -0
- fast_agent/agents/llm_agent.py +350 -0
- fast_agent/agents/llm_decorator.py +1139 -0
- fast_agent/agents/mcp_agent.py +1337 -0
- fast_agent/agents/tool_agent.py +271 -0
- fast_agent/agents/workflow/agents_as_tools_agent.py +849 -0
- fast_agent/agents/workflow/chain_agent.py +212 -0
- fast_agent/agents/workflow/evaluator_optimizer.py +380 -0
- fast_agent/agents/workflow/iterative_planner.py +652 -0
- fast_agent/agents/workflow/maker_agent.py +379 -0
- fast_agent/agents/workflow/orchestrator_models.py +218 -0
- fast_agent/agents/workflow/orchestrator_prompts.py +248 -0
- fast_agent/agents/workflow/parallel_agent.py +250 -0
- fast_agent/agents/workflow/router_agent.py +353 -0
- fast_agent/cli/__init__.py +0 -0
- fast_agent/cli/__main__.py +73 -0
- fast_agent/cli/commands/acp.py +159 -0
- fast_agent/cli/commands/auth.py +404 -0
- fast_agent/cli/commands/check_config.py +783 -0
- fast_agent/cli/commands/go.py +514 -0
- fast_agent/cli/commands/quickstart.py +557 -0
- fast_agent/cli/commands/serve.py +143 -0
- fast_agent/cli/commands/server_helpers.py +114 -0
- fast_agent/cli/commands/setup.py +174 -0
- fast_agent/cli/commands/url_parser.py +190 -0
- fast_agent/cli/constants.py +40 -0
- fast_agent/cli/main.py +115 -0
- fast_agent/cli/terminal.py +24 -0
- fast_agent/config.py +798 -0
- fast_agent/constants.py +41 -0
- fast_agent/context.py +279 -0
- fast_agent/context_dependent.py +50 -0
- fast_agent/core/__init__.py +92 -0
- fast_agent/core/agent_app.py +448 -0
- fast_agent/core/core_app.py +137 -0
- fast_agent/core/direct_decorators.py +784 -0
- fast_agent/core/direct_factory.py +620 -0
- fast_agent/core/error_handling.py +27 -0
- fast_agent/core/exceptions.py +90 -0
- fast_agent/core/executor/__init__.py +0 -0
- fast_agent/core/executor/executor.py +280 -0
- fast_agent/core/executor/task_registry.py +32 -0
- fast_agent/core/executor/workflow_signal.py +324 -0
- fast_agent/core/fastagent.py +1186 -0
- fast_agent/core/logging/__init__.py +5 -0
- fast_agent/core/logging/events.py +138 -0
- fast_agent/core/logging/json_serializer.py +164 -0
- fast_agent/core/logging/listeners.py +309 -0
- fast_agent/core/logging/logger.py +278 -0
- fast_agent/core/logging/transport.py +481 -0
- fast_agent/core/prompt.py +9 -0
- fast_agent/core/prompt_templates.py +183 -0
- fast_agent/core/validation.py +326 -0
- fast_agent/event_progress.py +62 -0
- fast_agent/history/history_exporter.py +49 -0
- fast_agent/human_input/__init__.py +47 -0
- fast_agent/human_input/elicitation_handler.py +123 -0
- fast_agent/human_input/elicitation_state.py +33 -0
- fast_agent/human_input/form_elements.py +59 -0
- fast_agent/human_input/form_fields.py +256 -0
- fast_agent/human_input/simple_form.py +113 -0
- fast_agent/human_input/types.py +40 -0
- fast_agent/interfaces.py +310 -0
- fast_agent/llm/__init__.py +9 -0
- fast_agent/llm/cancellation.py +22 -0
- fast_agent/llm/fastagent_llm.py +931 -0
- fast_agent/llm/internal/passthrough.py +161 -0
- fast_agent/llm/internal/playback.py +129 -0
- fast_agent/llm/internal/silent.py +41 -0
- fast_agent/llm/internal/slow.py +38 -0
- fast_agent/llm/memory.py +275 -0
- fast_agent/llm/model_database.py +490 -0
- fast_agent/llm/model_factory.py +388 -0
- fast_agent/llm/model_info.py +102 -0
- fast_agent/llm/prompt_utils.py +155 -0
- fast_agent/llm/provider/anthropic/anthropic_utils.py +84 -0
- fast_agent/llm/provider/anthropic/cache_planner.py +56 -0
- fast_agent/llm/provider/anthropic/llm_anthropic.py +796 -0
- fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +462 -0
- fast_agent/llm/provider/bedrock/bedrock_utils.py +218 -0
- fast_agent/llm/provider/bedrock/llm_bedrock.py +2207 -0
- fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py +84 -0
- fast_agent/llm/provider/google/google_converter.py +466 -0
- fast_agent/llm/provider/google/llm_google_native.py +681 -0
- fast_agent/llm/provider/openai/llm_aliyun.py +31 -0
- fast_agent/llm/provider/openai/llm_azure.py +143 -0
- fast_agent/llm/provider/openai/llm_deepseek.py +76 -0
- fast_agent/llm/provider/openai/llm_generic.py +35 -0
- fast_agent/llm/provider/openai/llm_google_oai.py +32 -0
- fast_agent/llm/provider/openai/llm_groq.py +42 -0
- fast_agent/llm/provider/openai/llm_huggingface.py +85 -0
- fast_agent/llm/provider/openai/llm_openai.py +1195 -0
- fast_agent/llm/provider/openai/llm_openai_compatible.py +138 -0
- fast_agent/llm/provider/openai/llm_openrouter.py +45 -0
- fast_agent/llm/provider/openai/llm_tensorzero_openai.py +128 -0
- fast_agent/llm/provider/openai/llm_xai.py +38 -0
- fast_agent/llm/provider/openai/multipart_converter_openai.py +561 -0
- fast_agent/llm/provider/openai/openai_multipart.py +169 -0
- fast_agent/llm/provider/openai/openai_utils.py +67 -0
- fast_agent/llm/provider/openai/responses.py +133 -0
- fast_agent/llm/provider_key_manager.py +139 -0
- fast_agent/llm/provider_types.py +34 -0
- fast_agent/llm/request_params.py +61 -0
- fast_agent/llm/sampling_converter.py +98 -0
- fast_agent/llm/stream_types.py +9 -0
- fast_agent/llm/usage_tracking.py +445 -0
- fast_agent/mcp/__init__.py +56 -0
- fast_agent/mcp/common.py +26 -0
- fast_agent/mcp/elicitation_factory.py +84 -0
- fast_agent/mcp/elicitation_handlers.py +164 -0
- fast_agent/mcp/gen_client.py +83 -0
- fast_agent/mcp/helpers/__init__.py +36 -0
- fast_agent/mcp/helpers/content_helpers.py +352 -0
- fast_agent/mcp/helpers/server_config_helpers.py +25 -0
- fast_agent/mcp/hf_auth.py +147 -0
- fast_agent/mcp/interfaces.py +92 -0
- fast_agent/mcp/logger_textio.py +108 -0
- fast_agent/mcp/mcp_agent_client_session.py +411 -0
- fast_agent/mcp/mcp_aggregator.py +2175 -0
- fast_agent/mcp/mcp_connection_manager.py +723 -0
- fast_agent/mcp/mcp_content.py +262 -0
- fast_agent/mcp/mime_utils.py +108 -0
- fast_agent/mcp/oauth_client.py +509 -0
- fast_agent/mcp/prompt.py +159 -0
- fast_agent/mcp/prompt_message_extended.py +155 -0
- fast_agent/mcp/prompt_render.py +84 -0
- fast_agent/mcp/prompt_serialization.py +580 -0
- fast_agent/mcp/prompts/__init__.py +0 -0
- fast_agent/mcp/prompts/__main__.py +7 -0
- fast_agent/mcp/prompts/prompt_constants.py +18 -0
- fast_agent/mcp/prompts/prompt_helpers.py +238 -0
- fast_agent/mcp/prompts/prompt_load.py +186 -0
- fast_agent/mcp/prompts/prompt_server.py +552 -0
- fast_agent/mcp/prompts/prompt_template.py +438 -0
- fast_agent/mcp/resource_utils.py +215 -0
- fast_agent/mcp/sampling.py +200 -0
- fast_agent/mcp/server/__init__.py +4 -0
- fast_agent/mcp/server/agent_server.py +613 -0
- fast_agent/mcp/skybridge.py +44 -0
- fast_agent/mcp/sse_tracking.py +287 -0
- fast_agent/mcp/stdio_tracking_simple.py +59 -0
- fast_agent/mcp/streamable_http_tracking.py +309 -0
- fast_agent/mcp/tool_execution_handler.py +137 -0
- fast_agent/mcp/tool_permission_handler.py +88 -0
- fast_agent/mcp/transport_tracking.py +634 -0
- fast_agent/mcp/types.py +24 -0
- fast_agent/mcp/ui_agent.py +48 -0
- fast_agent/mcp/ui_mixin.py +209 -0
- fast_agent/mcp_server_registry.py +89 -0
- fast_agent/py.typed +0 -0
- fast_agent/resources/examples/data-analysis/analysis-campaign.py +189 -0
- fast_agent/resources/examples/data-analysis/analysis.py +68 -0
- fast_agent/resources/examples/data-analysis/fastagent.config.yaml +41 -0
- fast_agent/resources/examples/data-analysis/mount-point/WA_Fn-UseC_-HR-Employee-Attrition.csv +1471 -0
- fast_agent/resources/examples/mcp/elicitations/elicitation_account_server.py +88 -0
- fast_agent/resources/examples/mcp/elicitations/elicitation_forms_server.py +297 -0
- fast_agent/resources/examples/mcp/elicitations/elicitation_game_server.py +164 -0
- fast_agent/resources/examples/mcp/elicitations/fastagent.config.yaml +35 -0
- fast_agent/resources/examples/mcp/elicitations/fastagent.secrets.yaml.example +17 -0
- fast_agent/resources/examples/mcp/elicitations/forms_demo.py +107 -0
- fast_agent/resources/examples/mcp/elicitations/game_character.py +65 -0
- fast_agent/resources/examples/mcp/elicitations/game_character_handler.py +256 -0
- fast_agent/resources/examples/mcp/elicitations/tool_call.py +21 -0
- fast_agent/resources/examples/mcp/state-transfer/agent_one.py +18 -0
- fast_agent/resources/examples/mcp/state-transfer/agent_two.py +18 -0
- fast_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +27 -0
- fast_agent/resources/examples/mcp/state-transfer/fastagent.secrets.yaml.example +15 -0
- fast_agent/resources/examples/researcher/fastagent.config.yaml +61 -0
- fast_agent/resources/examples/researcher/researcher-eval.py +53 -0
- fast_agent/resources/examples/researcher/researcher-imp.py +189 -0
- fast_agent/resources/examples/researcher/researcher.py +36 -0
- fast_agent/resources/examples/tensorzero/.env.sample +2 -0
- fast_agent/resources/examples/tensorzero/Makefile +31 -0
- fast_agent/resources/examples/tensorzero/README.md +56 -0
- fast_agent/resources/examples/tensorzero/agent.py +35 -0
- fast_agent/resources/examples/tensorzero/demo_images/clam.jpg +0 -0
- fast_agent/resources/examples/tensorzero/demo_images/crab.png +0 -0
- fast_agent/resources/examples/tensorzero/demo_images/shrimp.png +0 -0
- fast_agent/resources/examples/tensorzero/docker-compose.yml +105 -0
- fast_agent/resources/examples/tensorzero/fastagent.config.yaml +19 -0
- fast_agent/resources/examples/tensorzero/image_demo.py +67 -0
- fast_agent/resources/examples/tensorzero/mcp_server/Dockerfile +25 -0
- fast_agent/resources/examples/tensorzero/mcp_server/entrypoint.sh +35 -0
- fast_agent/resources/examples/tensorzero/mcp_server/mcp_server.py +31 -0
- fast_agent/resources/examples/tensorzero/mcp_server/pyproject.toml +11 -0
- fast_agent/resources/examples/tensorzero/simple_agent.py +25 -0
- fast_agent/resources/examples/tensorzero/tensorzero_config/system_schema.json +29 -0
- fast_agent/resources/examples/tensorzero/tensorzero_config/system_template.minijinja +11 -0
- fast_agent/resources/examples/tensorzero/tensorzero_config/tensorzero.toml +35 -0
- fast_agent/resources/examples/workflows/agents_as_tools_extended.py +73 -0
- fast_agent/resources/examples/workflows/agents_as_tools_simple.py +50 -0
- fast_agent/resources/examples/workflows/chaining.py +37 -0
- fast_agent/resources/examples/workflows/evaluator.py +77 -0
- fast_agent/resources/examples/workflows/fastagent.config.yaml +26 -0
- fast_agent/resources/examples/workflows/graded_report.md +89 -0
- fast_agent/resources/examples/workflows/human_input.py +28 -0
- fast_agent/resources/examples/workflows/maker.py +156 -0
- fast_agent/resources/examples/workflows/orchestrator.py +70 -0
- fast_agent/resources/examples/workflows/parallel.py +56 -0
- fast_agent/resources/examples/workflows/router.py +69 -0
- fast_agent/resources/examples/workflows/short_story.md +13 -0
- fast_agent/resources/examples/workflows/short_story.txt +19 -0
- fast_agent/resources/setup/.gitignore +30 -0
- fast_agent/resources/setup/agent.py +28 -0
- fast_agent/resources/setup/fastagent.config.yaml +65 -0
- fast_agent/resources/setup/fastagent.secrets.yaml.example +38 -0
- fast_agent/resources/setup/pyproject.toml.tmpl +23 -0
- fast_agent/skills/__init__.py +9 -0
- fast_agent/skills/registry.py +235 -0
- fast_agent/tools/elicitation.py +369 -0
- fast_agent/tools/shell_runtime.py +402 -0
- fast_agent/types/__init__.py +59 -0
- fast_agent/types/conversation_summary.py +294 -0
- fast_agent/types/llm_stop_reason.py +78 -0
- fast_agent/types/message_search.py +249 -0
- fast_agent/ui/__init__.py +38 -0
- fast_agent/ui/console.py +59 -0
- fast_agent/ui/console_display.py +1080 -0
- fast_agent/ui/elicitation_form.py +946 -0
- fast_agent/ui/elicitation_style.py +59 -0
- fast_agent/ui/enhanced_prompt.py +1400 -0
- fast_agent/ui/history_display.py +734 -0
- fast_agent/ui/interactive_prompt.py +1199 -0
- fast_agent/ui/markdown_helpers.py +104 -0
- fast_agent/ui/markdown_truncator.py +1004 -0
- fast_agent/ui/mcp_display.py +857 -0
- fast_agent/ui/mcp_ui_utils.py +235 -0
- fast_agent/ui/mermaid_utils.py +169 -0
- fast_agent/ui/message_primitives.py +50 -0
- fast_agent/ui/notification_tracker.py +205 -0
- fast_agent/ui/plain_text_truncator.py +68 -0
- fast_agent/ui/progress_display.py +10 -0
- fast_agent/ui/rich_progress.py +195 -0
- fast_agent/ui/streaming.py +774 -0
- fast_agent/ui/streaming_buffer.py +449 -0
- fast_agent/ui/tool_display.py +422 -0
- fast_agent/ui/usage_display.py +204 -0
- fast_agent/utils/__init__.py +5 -0
- fast_agent/utils/reasoning_stream_parser.py +77 -0
- fast_agent/utils/time.py +22 -0
- fast_agent/workflow_telemetry.py +261 -0
- fast_agent_mcp-0.4.7.dist-info/METADATA +788 -0
- fast_agent_mcp-0.4.7.dist-info/RECORD +261 -0
- fast_agent_mcp-0.4.7.dist-info/WHEEL +4 -0
- fast_agent_mcp-0.4.7.dist-info/entry_points.txt +7 -0
- fast_agent_mcp-0.4.7.dist-info/licenses/LICENSE +201 -0
fast_agent/llm/provider/openai/llm_openai.py
@@ -0,0 +1,1195 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from mcp import Tool
|
|
9
|
+
from mcp.types import (
|
|
10
|
+
CallToolRequest,
|
|
11
|
+
CallToolRequestParams,
|
|
12
|
+
ContentBlock,
|
|
13
|
+
TextContent,
|
|
14
|
+
)
|
|
15
|
+
from openai import APIError, AsyncOpenAI, AuthenticationError, DefaultAioHttpClient
|
|
16
|
+
from openai.lib.streaming.chat import ChatCompletionStreamState
|
|
17
|
+
|
|
18
|
+
# from openai.types.beta.chat import
|
|
19
|
+
from openai.types.chat import (
|
|
20
|
+
ChatCompletionMessage,
|
|
21
|
+
ChatCompletionMessageParam,
|
|
22
|
+
ChatCompletionSystemMessageParam,
|
|
23
|
+
ChatCompletionToolParam,
|
|
24
|
+
)
|
|
25
|
+
from pydantic_core import from_json
|
|
26
|
+
|
|
27
|
+
from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL, REASONING
|
|
28
|
+
from fast_agent.core.exceptions import ProviderKeyError
|
|
29
|
+
from fast_agent.core.logging.logger import get_logger
|
|
30
|
+
from fast_agent.core.prompt import Prompt
|
|
31
|
+
from fast_agent.event_progress import ProgressAction
|
|
32
|
+
from fast_agent.llm.fastagent_llm import FastAgentLLM, RequestParams
|
|
33
|
+
from fast_agent.llm.model_database import ModelDatabase
|
|
34
|
+
from fast_agent.llm.provider.openai.multipart_converter_openai import OpenAIConverter, OpenAIMessage
|
|
35
|
+
from fast_agent.llm.provider_types import Provider
|
|
36
|
+
from fast_agent.llm.stream_types import StreamChunk
|
|
37
|
+
from fast_agent.llm.usage_tracking import TurnUsage
|
|
38
|
+
from fast_agent.mcp.helpers.content_helpers import get_text, text_content
|
|
39
|
+
from fast_agent.types import LlmStopReason, PromptMessageExtended
|
|
40
|
+
|
|
41
|
+
_logger = get_logger(__name__)
|
|
42
|
+
|
|
43
|
+
DEFAULT_OPENAI_MODEL = "gpt-5-mini"
|
|
44
|
+
DEFAULT_REASONING_EFFORT = "low"
|
|
45
|
+
|
|
46
|
+
# Stream capture mode - when enabled, saves all streaming chunks to files for debugging
|
|
47
|
+
# Set FAST_AGENT_LLM_TRACE=1 (or any non-empty value) to enable
|
|
48
|
+
STREAM_CAPTURE_ENABLED = bool(os.environ.get("FAST_AGENT_LLM_TRACE"))
|
|
49
|
+
STREAM_CAPTURE_DIR = Path("stream-debug")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _stream_capture_filename(turn: int) -> Path | None:
|
|
53
|
+
"""Generate filename for stream capture. Returns None if capture is disabled."""
|
|
54
|
+
if not STREAM_CAPTURE_ENABLED:
|
|
55
|
+
return None
|
|
56
|
+
STREAM_CAPTURE_DIR.mkdir(parents=True, exist_ok=True)
|
|
57
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
58
|
+
return STREAM_CAPTURE_DIR / f"{timestamp}_turn{turn}"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _save_stream_request(filename_base: Path | None, arguments: dict[str, Any]) -> None:
|
|
62
|
+
"""Save the request arguments to a _request.json file."""
|
|
63
|
+
if not filename_base:
|
|
64
|
+
return
|
|
65
|
+
try:
|
|
66
|
+
request_file = filename_base.with_name(f"{filename_base.name}_request.json")
|
|
67
|
+
with open(request_file, "w") as f:
|
|
68
|
+
json.dump(arguments, f, indent=2, default=str)
|
|
69
|
+
except Exception as e:
|
|
70
|
+
_logger.debug(f"Failed to save stream request: {e}")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _save_stream_chunk(filename_base: Path | None, chunk: Any) -> None:
|
|
74
|
+
"""Save a streaming chunk to file when capture mode is enabled."""
|
|
75
|
+
if not filename_base:
|
|
76
|
+
return
|
|
77
|
+
try:
|
|
78
|
+
chunk_file = filename_base.with_name(f"{filename_base.name}.jsonl")
|
|
79
|
+
try:
|
|
80
|
+
payload: Any = chunk.model_dump()
|
|
81
|
+
except Exception:
|
|
82
|
+
payload = str(chunk)
|
|
83
|
+
|
|
84
|
+
with open(chunk_file, "a") as f:
|
|
85
|
+
f.write(json.dumps(payload) + "\n")
|
|
86
|
+
except Exception as e:
|
|
87
|
+
_logger.debug(f"Failed to save stream chunk: {e}")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class OpenAILLM(FastAgentLLM[ChatCompletionMessageParam, ChatCompletionMessage]):
|
|
91
|
+
# Config section name override (falls back to provider value)
|
|
92
|
+
config_section: str | None = None
|
|
93
|
+
# OpenAI-specific parameter exclusions
|
|
94
|
+
OPENAI_EXCLUDE_FIELDS = {
|
|
95
|
+
FastAgentLLM.PARAM_MESSAGES,
|
|
96
|
+
FastAgentLLM.PARAM_MODEL,
|
|
97
|
+
FastAgentLLM.PARAM_MAX_TOKENS,
|
|
98
|
+
FastAgentLLM.PARAM_SYSTEM_PROMPT,
|
|
99
|
+
FastAgentLLM.PARAM_PARALLEL_TOOL_CALLS,
|
|
100
|
+
FastAgentLLM.PARAM_USE_HISTORY,
|
|
101
|
+
FastAgentLLM.PARAM_MAX_ITERATIONS,
|
|
102
|
+
FastAgentLLM.PARAM_TEMPLATE_VARS,
|
|
103
|
+
FastAgentLLM.PARAM_MCP_METADATA,
|
|
104
|
+
FastAgentLLM.PARAM_STOP_SEQUENCES,
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
def __init__(self, provider: Provider = Provider.OPENAI, *args, **kwargs) -> None:
|
|
108
|
+
super().__init__(*args, provider=provider, **kwargs)
|
|
109
|
+
|
|
110
|
+
# Initialize logger with name if available
|
|
111
|
+
self.logger = get_logger(f"{__name__}.{self.name}" if self.name else __name__)
|
|
112
|
+
|
|
113
|
+
# Set up reasoning-related attributes
|
|
114
|
+
self._reasoning_effort = kwargs.get("reasoning_effort", None)
|
|
115
|
+
if self.context and self.context.config and self.context.config.openai:
|
|
116
|
+
if self._reasoning_effort is None and hasattr(
|
|
117
|
+
self.context.config.openai, "reasoning_effort"
|
|
118
|
+
):
|
|
119
|
+
self._reasoning_effort = self.context.config.openai.reasoning_effort
|
|
120
|
+
|
|
121
|
+
# Determine reasoning mode for the selected model
|
|
122
|
+
chosen_model = self.default_request_params.model if self.default_request_params else None
|
|
123
|
+
self._reasoning_mode = ModelDatabase.get_reasoning(chosen_model)
|
|
124
|
+
self._reasoning = self._reasoning_mode == "openai"
|
|
125
|
+
if self._reasoning_mode:
|
|
126
|
+
self.logger.info(
|
|
127
|
+
f"Using reasoning model '{chosen_model}' (mode='{self._reasoning_mode}') with "
|
|
128
|
+
f"'{self._reasoning_effort}' reasoning effort"
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
|
132
|
+
"""Initialize OpenAI-specific default parameters"""
|
|
133
|
+
# Get base defaults from parent (includes ModelDatabase lookup)
|
|
134
|
+
base_params = super()._initialize_default_params(kwargs)
|
|
135
|
+
|
|
136
|
+
# Override with OpenAI-specific settings
|
|
137
|
+
chosen_model = kwargs.get("model", DEFAULT_OPENAI_MODEL)
|
|
138
|
+
base_params.model = chosen_model
|
|
139
|
+
|
|
140
|
+
return base_params
|
|
141
|
+
|
|
142
|
+
def _base_url(self) -> str:
|
|
143
|
+
return self.context.config.openai.base_url if self.context.config.openai else None
|
|
144
|
+
|
|
145
|
+
def _default_headers(self) -> dict[str, str] | None:
|
|
146
|
+
"""
|
|
147
|
+
Get custom headers from configuration.
|
|
148
|
+
Subclasses can override this to provide provider-specific headers.
|
|
149
|
+
"""
|
|
150
|
+
provider_config = self._get_provider_config()
|
|
151
|
+
return getattr(provider_config, "default_headers", None) if provider_config else None
|
|
152
|
+
|
|
153
|
+
def _get_provider_config(self):
|
|
154
|
+
"""Return the config section for this provider, if available."""
|
|
155
|
+
context_config = getattr(self.context, "config", None)
|
|
156
|
+
if not context_config:
|
|
157
|
+
return None
|
|
158
|
+
section_name = self.config_section or getattr(self.provider, "value", None)
|
|
159
|
+
if not section_name:
|
|
160
|
+
return None
|
|
161
|
+
return getattr(context_config, section_name, None)
|
|
162
|
+
|
|
163
|
+
def _openai_client(self) -> AsyncOpenAI:
|
|
164
|
+
"""
|
|
165
|
+
Create an OpenAI client instance.
|
|
166
|
+
Subclasses can override this to provide different client types (e.g., AzureOpenAI).
|
|
167
|
+
|
|
168
|
+
Note: The returned client should be used within an async context manager
|
|
169
|
+
to ensure proper cleanup of aiohttp sessions.
|
|
170
|
+
"""
|
|
171
|
+
try:
|
|
172
|
+
kwargs: dict[str, Any] = {
|
|
173
|
+
"api_key": self._api_key(),
|
|
174
|
+
"base_url": self._base_url(),
|
|
175
|
+
"http_client": DefaultAioHttpClient(),
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
# Add custom headers if configured
|
|
179
|
+
default_headers = self._default_headers()
|
|
180
|
+
if default_headers:
|
|
181
|
+
kwargs["default_headers"] = default_headers
|
|
182
|
+
|
|
183
|
+
return AsyncOpenAI(**kwargs)
|
|
184
|
+
except AuthenticationError as e:
|
|
185
|
+
raise ProviderKeyError(
|
|
186
|
+
"Invalid OpenAI API key",
|
|
187
|
+
"The configured OpenAI API key was rejected.\n"
|
|
188
|
+
"Please check that your API key is valid and not expired.",
|
|
189
|
+
) from e
|
|
190
|
+
|
|
191
|
+
def _streams_tool_arguments(self) -> bool:
|
|
192
|
+
"""
|
|
193
|
+
Determine whether the current provider streams tool call arguments incrementally.
|
|
194
|
+
|
|
195
|
+
Official OpenAI and Azure OpenAI endpoints stream arguments. Most third-party
|
|
196
|
+
OpenAI-compatible gateways (e.g. OpenRouter, Moonshot) deliver the full arguments
|
|
197
|
+
once, so we should treat them as non-streaming to restore the legacy \"Calling Tool\"
|
|
198
|
+
display experience.
|
|
199
|
+
"""
|
|
200
|
+
if self.provider in (Provider.AZURE, Provider.HUGGINGFACE):
|
|
201
|
+
return True
|
|
202
|
+
|
|
203
|
+
if self.provider == Provider.OPENAI:
|
|
204
|
+
base_url = self._base_url()
|
|
205
|
+
if not base_url:
|
|
206
|
+
return True
|
|
207
|
+
lowered = base_url.lower()
|
|
208
|
+
return "api.openai" in lowered or "openai.azure" in lowered or "azure.com" in lowered
|
|
209
|
+
|
|
210
|
+
return False
|
|
211
|
+
|
|
212
|
+
def _emit_tool_notification_fallback(
|
|
213
|
+
self,
|
|
214
|
+
tool_calls: Any,
|
|
215
|
+
notified_indices: set[int],
|
|
216
|
+
*,
|
|
217
|
+
streams_arguments: bool,
|
|
218
|
+
model: str,
|
|
219
|
+
) -> None:
|
|
220
|
+
"""Emit start/stop notifications when streaming metadata was missing."""
|
|
221
|
+
if not tool_calls:
|
|
222
|
+
return
|
|
223
|
+
|
|
224
|
+
for index, tool_call in enumerate(tool_calls):
|
|
225
|
+
if index in notified_indices:
|
|
226
|
+
continue
|
|
227
|
+
|
|
228
|
+
tool_name = None
|
|
229
|
+
tool_use_id = None
|
|
230
|
+
|
|
231
|
+
try:
|
|
232
|
+
tool_use_id = getattr(tool_call, "id", None)
|
|
233
|
+
function = getattr(tool_call, "function", None)
|
|
234
|
+
if function:
|
|
235
|
+
tool_name = getattr(function, "name", None)
|
|
236
|
+
except Exception:
|
|
237
|
+
tool_use_id = None
|
|
238
|
+
tool_name = None
|
|
239
|
+
|
|
240
|
+
if not tool_name:
|
|
241
|
+
tool_name = "tool"
|
|
242
|
+
if not tool_use_id:
|
|
243
|
+
tool_use_id = f"tool-{index}"
|
|
244
|
+
|
|
245
|
+
payload = {
|
|
246
|
+
"tool_name": tool_name,
|
|
247
|
+
"tool_use_id": tool_use_id,
|
|
248
|
+
"index": index,
|
|
249
|
+
"streams_arguments": streams_arguments,
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
self._notify_tool_stream_listeners("start", payload)
|
|
253
|
+
self.logger.info(
|
|
254
|
+
"Model emitted fallback tool notification",
|
|
255
|
+
data={
|
|
256
|
+
"progress_action": ProgressAction.CALLING_TOOL,
|
|
257
|
+
"agent_name": self.name,
|
|
258
|
+
"model": model,
|
|
259
|
+
"tool_name": tool_name,
|
|
260
|
+
"tool_use_id": tool_use_id,
|
|
261
|
+
"tool_event": "start",
|
|
262
|
+
"streams_arguments": streams_arguments,
|
|
263
|
+
"fallback": True,
|
|
264
|
+
},
|
|
265
|
+
)
|
|
266
|
+
self._notify_tool_stream_listeners("stop", payload)
|
|
267
|
+
self.logger.info(
|
|
268
|
+
"Model emitted fallback tool notification",
|
|
269
|
+
data={
|
|
270
|
+
"progress_action": ProgressAction.CALLING_TOOL,
|
|
271
|
+
"agent_name": self.name,
|
|
272
|
+
"model": model,
|
|
273
|
+
"tool_name": tool_name,
|
|
274
|
+
"tool_use_id": tool_use_id,
|
|
275
|
+
"tool_event": "stop",
|
|
276
|
+
"streams_arguments": streams_arguments,
|
|
277
|
+
"fallback": True,
|
|
278
|
+
},
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
def _handle_reasoning_delta(
|
|
282
|
+
self,
|
|
283
|
+
*,
|
|
284
|
+
reasoning_mode: str | None,
|
|
285
|
+
reasoning_text: str,
|
|
286
|
+
reasoning_active: bool,
|
|
287
|
+
reasoning_segments: list[str],
|
|
288
|
+
) -> bool:
|
|
289
|
+
"""Stream reasoning text and track whether a thinking block is open."""
|
|
290
|
+
if not reasoning_text:
|
|
291
|
+
return reasoning_active
|
|
292
|
+
|
|
293
|
+
if reasoning_mode == "tags":
|
|
294
|
+
if not reasoning_active:
|
|
295
|
+
reasoning_active = True
|
|
296
|
+
self._notify_stream_listeners(StreamChunk(text=reasoning_text, is_reasoning=True))
|
|
297
|
+
reasoning_segments.append(reasoning_text)
|
|
298
|
+
return reasoning_active
|
|
299
|
+
|
|
300
|
+
if reasoning_mode in {"stream", "reasoning_content", "gpt_oss"}:
|
|
301
|
+
# Emit reasoning as-is
|
|
302
|
+
self._notify_stream_listeners(StreamChunk(text=reasoning_text, is_reasoning=True))
|
|
303
|
+
reasoning_segments.append(reasoning_text)
|
|
304
|
+
return reasoning_active
|
|
305
|
+
|
|
306
|
+
return reasoning_active
|
|
307
|
+
|
|
308
|
+
def _handle_tool_delta(
|
|
309
|
+
self,
|
|
310
|
+
*,
|
|
311
|
+
delta_tool_calls: Any,
|
|
312
|
+
tool_call_started: dict[int, dict[str, Any]],
|
|
313
|
+
streams_arguments: bool,
|
|
314
|
+
model: str,
|
|
315
|
+
notified_tool_indices: set[int],
|
|
316
|
+
) -> None:
|
|
317
|
+
"""Emit tool call start/delta events and keep state in sync."""
|
|
318
|
+
for tool_call in delta_tool_calls:
|
|
319
|
+
index = tool_call.index
|
|
320
|
+
if index is None:
|
|
321
|
+
continue
|
|
322
|
+
|
|
323
|
+
existing_info = tool_call_started.get(index)
|
|
324
|
+
tool_use_id = tool_call.id or (
|
|
325
|
+
existing_info.get("tool_use_id") if existing_info else None
|
|
326
|
+
)
|
|
327
|
+
function_name = (
|
|
328
|
+
tool_call.function.name
|
|
329
|
+
if tool_call.function and tool_call.function.name
|
|
330
|
+
else (existing_info.get("tool_name") if existing_info else None)
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
if existing_info is None and tool_use_id and function_name:
|
|
334
|
+
tool_call_started[index] = {
|
|
335
|
+
"tool_name": function_name,
|
|
336
|
+
"tool_use_id": tool_use_id,
|
|
337
|
+
"streams_arguments": streams_arguments,
|
|
338
|
+
}
|
|
339
|
+
self._notify_tool_stream_listeners(
|
|
340
|
+
"start",
|
|
341
|
+
{
|
|
342
|
+
"tool_name": function_name,
|
|
343
|
+
"tool_use_id": tool_use_id,
|
|
344
|
+
"index": index,
|
|
345
|
+
"streams_arguments": streams_arguments,
|
|
346
|
+
},
|
|
347
|
+
)
|
|
348
|
+
self.logger.info(
|
|
349
|
+
"Model started streaming tool call",
|
|
350
|
+
data={
|
|
351
|
+
"progress_action": ProgressAction.CALLING_TOOL,
|
|
352
|
+
"agent_name": self.name,
|
|
353
|
+
"model": model,
|
|
354
|
+
"tool_name": function_name,
|
|
355
|
+
"tool_use_id": tool_use_id,
|
|
356
|
+
"tool_event": "start",
|
|
357
|
+
"streams_arguments": streams_arguments,
|
|
358
|
+
},
|
|
359
|
+
)
|
|
360
|
+
notified_tool_indices.add(index)
|
|
361
|
+
elif existing_info:
|
|
362
|
+
if tool_use_id:
|
|
363
|
+
existing_info["tool_use_id"] = tool_use_id
|
|
364
|
+
if function_name:
|
|
365
|
+
existing_info["tool_name"] = function_name
|
|
366
|
+
|
|
367
|
+
if tool_call.function and tool_call.function.arguments:
|
|
368
|
+
info = tool_call_started.setdefault(
|
|
369
|
+
index,
|
|
370
|
+
{
|
|
371
|
+
"tool_name": function_name,
|
|
372
|
+
"tool_use_id": tool_use_id,
|
|
373
|
+
"streams_arguments": streams_arguments,
|
|
374
|
+
},
|
|
375
|
+
)
|
|
376
|
+
self._notify_tool_stream_listeners(
|
|
377
|
+
"delta",
|
|
378
|
+
{
|
|
379
|
+
"tool_name": info.get("tool_name"),
|
|
380
|
+
"tool_use_id": info.get("tool_use_id"),
|
|
381
|
+
"index": index,
|
|
382
|
+
"chunk": tool_call.function.arguments,
|
|
383
|
+
"streams_arguments": info.get("streams_arguments", False),
|
|
384
|
+
},
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
def _finalize_tool_calls_on_stop(
|
|
388
|
+
self,
|
|
389
|
+
*,
|
|
390
|
+
tool_call_started: dict[int, dict[str, Any]],
|
|
391
|
+
streams_arguments: bool,
|
|
392
|
+
model: str,
|
|
393
|
+
notified_tool_indices: set[int],
|
|
394
|
+
) -> None:
|
|
395
|
+
"""Emit stop events for any in-flight tool calls and clear state."""
|
|
396
|
+
for index, info in list(tool_call_started.items()):
|
|
397
|
+
self._notify_tool_stream_listeners(
|
|
398
|
+
"stop",
|
|
399
|
+
{
|
|
400
|
+
"tool_name": info.get("tool_name"),
|
|
401
|
+
"tool_use_id": info.get("tool_use_id"),
|
|
402
|
+
"index": index,
|
|
403
|
+
"streams_arguments": info.get("streams_arguments", False),
|
|
404
|
+
},
|
|
405
|
+
)
|
|
406
|
+
self.logger.info(
|
|
407
|
+
"Model finished streaming tool call",
|
|
408
|
+
data={
|
|
409
|
+
"progress_action": ProgressAction.CALLING_TOOL,
|
|
410
|
+
"agent_name": self.name,
|
|
411
|
+
"model": model,
|
|
412
|
+
"tool_name": info.get("tool_name"),
|
|
413
|
+
"tool_use_id": info.get("tool_use_id"),
|
|
414
|
+
"tool_event": "stop",
|
|
415
|
+
"streams_arguments": info.get("streams_arguments", False),
|
|
416
|
+
},
|
|
417
|
+
)
|
|
418
|
+
notified_tool_indices.add(index)
|
|
419
|
+
tool_call_started.clear()
|
|
420
|
+
|
|
421
|
+
def _emit_text_delta(
|
|
422
|
+
self,
|
|
423
|
+
*,
|
|
424
|
+
content: str,
|
|
425
|
+
model: str,
|
|
426
|
+
estimated_tokens: int,
|
|
427
|
+
streams_arguments: bool,
|
|
428
|
+
reasoning_active: bool,
|
|
429
|
+
) -> tuple[int, bool]:
|
|
430
|
+
"""Emit text deltas and close any active reasoning block."""
|
|
431
|
+
if reasoning_active:
|
|
432
|
+
reasoning_active = False
|
|
433
|
+
|
|
434
|
+
self._notify_stream_listeners(StreamChunk(text=content, is_reasoning=False))
|
|
435
|
+
estimated_tokens = self._update_streaming_progress(content, model, estimated_tokens)
|
|
436
|
+
self._notify_tool_stream_listeners(
|
|
437
|
+
"text",
|
|
438
|
+
{
|
|
439
|
+
"chunk": content,
|
|
440
|
+
"streams_arguments": streams_arguments,
|
|
441
|
+
},
|
|
442
|
+
)
|
|
443
|
+
|
|
444
|
+
return estimated_tokens, reasoning_active
|
|
445
|
+
|
|
446
|
+
def _close_reasoning_if_active(self, reasoning_active: bool) -> bool:
|
|
447
|
+
"""Return reasoning state; kept for symmetry."""
|
|
448
|
+
return False if reasoning_active else reasoning_active
|
|
449
|
+
|
|
450
|
+
async def _process_stream(
|
|
451
|
+
self,
|
|
452
|
+
stream,
|
|
453
|
+
model: str,
|
|
454
|
+
capture_filename: Path | None = None,
|
|
455
|
+
) -> tuple[Any, list[str]]:
|
|
456
|
+
"""Process the streaming response and display real-time token usage."""
|
|
457
|
+
# Track estimated output tokens by counting text chunks
|
|
458
|
+
estimated_tokens = 0
|
|
459
|
+
reasoning_active = False
|
|
460
|
+
reasoning_segments: list[str] = []
|
|
461
|
+
reasoning_mode = ModelDatabase.get_reasoning(model)
|
|
462
|
+
|
|
463
|
+
# For providers/models that emit non-OpenAI deltas, fall back to manual accumulation
|
|
464
|
+
stream_mode = ModelDatabase.get_stream_mode(model)
|
|
465
|
+
provider_requires_manual = self.provider in [
|
|
466
|
+
Provider.GENERIC,
|
|
467
|
+
Provider.OPENROUTER,
|
|
468
|
+
Provider.GOOGLE_OAI,
|
|
469
|
+
]
|
|
470
|
+
if stream_mode == "manual" or provider_requires_manual:
|
|
471
|
+
return await self._process_stream_manual(stream, model, capture_filename)
|
|
472
|
+
|
|
473
|
+
# Use ChatCompletionStreamState helper for accumulation (OpenAI only)
|
|
474
|
+
state = ChatCompletionStreamState()
|
|
475
|
+
|
|
476
|
+
# Track tool call state for stream events
|
|
477
|
+
tool_call_started: dict[int, dict[str, Any]] = {}
|
|
478
|
+
streams_arguments = self._streams_tool_arguments()
|
|
479
|
+
notified_tool_indices: set[int] = set()
|
|
480
|
+
|
|
481
|
+
# Process the stream chunks
|
|
482
|
+
# Cancellation is handled via asyncio.Task.cancel() which raises CancelledError
|
|
483
|
+
async for chunk in stream:
|
|
484
|
+
# Save chunk if stream capture is enabled
|
|
485
|
+
_save_stream_chunk(capture_filename, chunk)
|
|
486
|
+
# Handle chunk accumulation
|
|
487
|
+
state.handle_chunk(chunk)
|
|
488
|
+
# Process streaming events for tool calls
|
|
489
|
+
if chunk.choices:
|
|
490
|
+
choice = chunk.choices[0]
|
|
491
|
+
delta = choice.delta
|
|
492
|
+
reasoning_text = self._extract_reasoning_text(
|
|
493
|
+
reasoning=getattr(delta, "reasoning", None),
|
|
494
|
+
reasoning_content=getattr(delta, "reasoning_content", None),
|
|
495
|
+
)
|
|
496
|
+
reasoning_active = self._handle_reasoning_delta(
|
|
497
|
+
reasoning_mode=reasoning_mode,
|
|
498
|
+
reasoning_text=reasoning_text,
|
|
499
|
+
reasoning_active=reasoning_active,
|
|
500
|
+
reasoning_segments=reasoning_segments,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
# Handle tool call streaming
|
|
504
|
+
if delta.tool_calls:
|
|
505
|
+
self._handle_tool_delta(
|
|
506
|
+
delta_tool_calls=delta.tool_calls,
|
|
507
|
+
tool_call_started=tool_call_started,
|
|
508
|
+
streams_arguments=streams_arguments,
|
|
509
|
+
model=model,
|
|
510
|
+
notified_tool_indices=notified_tool_indices,
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
# Handle text content streaming
|
|
514
|
+
if delta.content:
|
|
515
|
+
estimated_tokens, reasoning_active = self._emit_text_delta(
|
|
516
|
+
content=delta.content,
|
|
517
|
+
model=model,
|
|
518
|
+
estimated_tokens=estimated_tokens,
|
|
519
|
+
streams_arguments=streams_arguments,
|
|
520
|
+
reasoning_active=reasoning_active,
|
|
521
|
+
)
|
|
522
|
+
|
|
523
|
+
# Fire "stop" event when tool calls complete
|
|
524
|
+
if choice.finish_reason == "tool_calls":
|
|
525
|
+
self._finalize_tool_calls_on_stop(
|
|
526
|
+
tool_call_started=tool_call_started,
|
|
527
|
+
streams_arguments=streams_arguments,
|
|
528
|
+
model=model,
|
|
529
|
+
notified_tool_indices=notified_tool_indices,
|
|
530
|
+
)
|
|
531
|
+
|
|
532
|
+
# Check if we hit the length limit to avoid LengthFinishReasonError
|
|
533
|
+
current_snapshot = state.current_completion_snapshot
|
|
534
|
+
if current_snapshot.choices and current_snapshot.choices[0].finish_reason == "length":
|
|
535
|
+
# Return the current snapshot directly to avoid exception
|
|
536
|
+
final_completion = current_snapshot
|
|
537
|
+
else:
|
|
538
|
+
# Get the final completion with usage data (may include structured output parsing)
|
|
539
|
+
final_completion = state.get_final_completion()
|
|
540
|
+
|
|
541
|
+
reasoning_active = self._close_reasoning_if_active(reasoning_active)
|
|
542
|
+
|
|
543
|
+
# Log final usage information
|
|
544
|
+
if hasattr(final_completion, "usage") and final_completion.usage:
|
|
545
|
+
actual_tokens = final_completion.usage.completion_tokens
|
|
546
|
+
# Emit final progress with actual token count
|
|
547
|
+
token_str = str(actual_tokens).rjust(5)
|
|
548
|
+
data = {
|
|
549
|
+
"progress_action": ProgressAction.STREAMING,
|
|
550
|
+
"model": model,
|
|
551
|
+
"agent_name": self.name,
|
|
552
|
+
"chat_turn": self.chat_turn(),
|
|
553
|
+
"details": token_str.strip(),
|
|
554
|
+
}
|
|
555
|
+
self.logger.info("Streaming progress", data=data)
|
|
556
|
+
|
|
557
|
+
self.logger.info(
|
|
558
|
+
f"Streaming complete - Model: {model}, Input tokens: {final_completion.usage.prompt_tokens}, Output tokens: {final_completion.usage.completion_tokens}"
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
final_message = None
|
|
562
|
+
if hasattr(final_completion, "choices") and final_completion.choices:
|
|
563
|
+
final_message = getattr(final_completion.choices[0], "message", None)
|
|
564
|
+
tool_calls = getattr(final_message, "tool_calls", None) if final_message else None
|
|
565
|
+
self._emit_tool_notification_fallback(
|
|
566
|
+
tool_calls,
|
|
567
|
+
notified_tool_indices,
|
|
568
|
+
streams_arguments=streams_arguments,
|
|
569
|
+
model=model,
|
|
570
|
+
)
|
|
571
|
+
|
|
572
|
+
return final_completion, reasoning_segments
|
|
573
|
+
|
|
574
|
+
def _normalize_role(self, role: str | None) -> str:
|
|
575
|
+
"""Ensure the role string matches MCP expectations."""
|
|
576
|
+
default_role = "assistant"
|
|
577
|
+
if not role:
|
|
578
|
+
return default_role
|
|
579
|
+
|
|
580
|
+
lowered = role.lower()
|
|
581
|
+
allowed_roles = {"assistant", "user", "system", "tool"}
|
|
582
|
+
if lowered in allowed_roles:
|
|
583
|
+
return lowered
|
|
584
|
+
|
|
585
|
+
for candidate in allowed_roles:
|
|
586
|
+
if len(lowered) % len(candidate) == 0:
|
|
587
|
+
repetitions = len(lowered) // len(candidate)
|
|
588
|
+
if candidate * repetitions == lowered:
|
|
589
|
+
self.logger.info(
|
|
590
|
+
"Collapsing repeated role value from provider",
|
|
591
|
+
data={
|
|
592
|
+
"original_role": role,
|
|
593
|
+
"normalized_role": candidate,
|
|
594
|
+
},
|
|
595
|
+
)
|
|
596
|
+
return candidate
|
|
597
|
+
|
|
598
|
+
self.logger.warning(
|
|
599
|
+
"Model emitted unsupported role; defaulting to assistant",
|
|
600
|
+
data={"original_role": role},
|
|
601
|
+
)
|
|
602
|
+
return default_role
|
|
603
|
+
|
|
604
|
+
# TODO - as per other comment this needs to go in another class. There are a number of "special" cases dealt with
|
|
605
|
+
# here to deal with OpenRouter idiosyncrasies between e.g. Anthropic and Gemini models.
|
|
606
|
+
async def _process_stream_manual(
|
|
607
|
+
self,
|
|
608
|
+
stream,
|
|
609
|
+
model: str,
|
|
610
|
+
capture_filename: Path | None = None,
|
|
611
|
+
) -> tuple[Any, list[str]]:
|
|
612
|
+
"""Manual stream processing for providers like Ollama that may not work with ChatCompletionStreamState."""
|
|
613
|
+
|
|
614
|
+
from openai.types.chat import ChatCompletionMessageToolCall
|
|
615
|
+
|
|
616
|
+
# Track estimated output tokens by counting text chunks
|
|
617
|
+
estimated_tokens = 0
|
|
618
|
+
reasoning_active = False
|
|
619
|
+
reasoning_segments: list[str] = []
|
|
620
|
+
reasoning_mode = ModelDatabase.get_reasoning(model)
|
|
621
|
+
|
|
622
|
+
# Manual accumulation of response data
|
|
623
|
+
accumulated_content = ""
|
|
624
|
+
role = "assistant"
|
|
625
|
+
tool_calls_map = {} # Use a map to accumulate tool calls by index
|
|
626
|
+
function_call = None
|
|
627
|
+
finish_reason = None
|
|
628
|
+
usage_data = None
|
|
629
|
+
|
|
630
|
+
# Track tool call state for stream events
|
|
631
|
+
tool_call_started: dict[int, dict[str, Any]] = {}
|
|
632
|
+
streams_arguments = self._streams_tool_arguments()
|
|
633
|
+
notified_tool_indices: set[int] = set()
|
|
634
|
+
|
|
635
|
+
# Process the stream chunks manually
|
|
636
|
+
# Cancellation is handled via asyncio.Task.cancel() which raises CancelledError
|
|
637
|
+
async for chunk in stream:
|
|
638
|
+
# Save chunk if stream capture is enabled
|
|
639
|
+
_save_stream_chunk(capture_filename, chunk)
|
|
640
|
+
# Process streaming events for tool calls
|
|
641
|
+
if chunk.choices:
|
|
642
|
+
choice = chunk.choices[0]
|
|
643
|
+
delta = choice.delta
|
|
644
|
+
|
|
645
|
+
reasoning_text = self._extract_reasoning_text(
|
|
646
|
+
reasoning=getattr(delta, "reasoning", None),
|
|
647
|
+
reasoning_content=getattr(delta, "reasoning_content", None),
|
|
648
|
+
)
|
|
649
|
+
reasoning_active = self._handle_reasoning_delta(
|
|
650
|
+
reasoning_mode=reasoning_mode,
|
|
651
|
+
reasoning_text=reasoning_text,
|
|
652
|
+
reasoning_active=reasoning_active,
|
|
653
|
+
reasoning_segments=reasoning_segments,
|
|
654
|
+
)
|
|
655
|
+
|
|
656
|
+
# Handle tool call streaming
|
|
657
|
+
if delta.tool_calls:
|
|
658
|
+
self._handle_tool_delta(
|
|
659
|
+
delta_tool_calls=delta.tool_calls,
|
|
660
|
+
tool_call_started=tool_call_started,
|
|
661
|
+
streams_arguments=streams_arguments,
|
|
662
|
+
model=model,
|
|
663
|
+
notified_tool_indices=notified_tool_indices,
|
|
664
|
+
)
|
|
665
|
+
|
|
666
|
+
# Handle text content streaming
|
|
667
|
+
if delta.content:
|
|
668
|
+
estimated_tokens, reasoning_active = self._emit_text_delta(
|
|
669
|
+
content=delta.content,
|
|
670
|
+
model=model,
|
|
671
|
+
estimated_tokens=estimated_tokens,
|
|
672
|
+
streams_arguments=streams_arguments,
|
|
673
|
+
reasoning_active=reasoning_active,
|
|
674
|
+
)
|
|
675
|
+
accumulated_content += delta.content
|
|
676
|
+
|
|
677
|
+
# Fire "stop" event when tool calls complete
|
|
678
|
+
if choice.finish_reason == "tool_calls":
|
|
679
|
+
self._finalize_tool_calls_on_stop(
|
|
680
|
+
tool_call_started=tool_call_started,
|
|
681
|
+
streams_arguments=streams_arguments,
|
|
682
|
+
model=model,
|
|
683
|
+
notified_tool_indices=notified_tool_indices,
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
# Extract other fields from the chunk
|
|
687
|
+
if chunk.choices:
|
|
688
|
+
choice = chunk.choices[0]
|
|
689
|
+
if choice.delta.role:
|
|
690
|
+
role = choice.delta.role
|
|
691
|
+
if choice.delta.tool_calls:
|
|
692
|
+
# Accumulate tool call deltas
|
|
693
|
+
for delta_tool_call in choice.delta.tool_calls:
|
|
694
|
+
if delta_tool_call.index is not None:
|
|
695
|
+
if delta_tool_call.index not in tool_calls_map:
|
|
696
|
+
tool_calls_map[delta_tool_call.index] = {
|
|
697
|
+
"id": delta_tool_call.id,
|
|
698
|
+
"type": delta_tool_call.type or "function",
|
|
699
|
+
"function": {
|
|
700
|
+
"name": delta_tool_call.function.name
|
|
701
|
+
if delta_tool_call.function
|
|
702
|
+
else None,
|
|
703
|
+
"arguments": "",
|
|
704
|
+
},
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
# Always update if we have new data (needed for OpenRouter Gemini)
|
|
708
|
+
if delta_tool_call.id:
|
|
709
|
+
tool_calls_map[delta_tool_call.index]["id"] = delta_tool_call.id
|
|
710
|
+
if delta_tool_call.function:
|
|
711
|
+
if delta_tool_call.function.name:
|
|
712
|
+
tool_calls_map[delta_tool_call.index]["function"]["name"] = (
|
|
713
|
+
delta_tool_call.function.name
|
|
714
|
+
)
|
|
715
|
+
# Handle arguments - they might come as None, empty string, or actual content
|
|
716
|
+
if delta_tool_call.function.arguments is not None:
|
|
717
|
+
tool_calls_map[delta_tool_call.index]["function"][
|
|
718
|
+
"arguments"
|
|
719
|
+
] += delta_tool_call.function.arguments
|
|
720
|
+
|
|
721
|
+
if choice.delta.function_call:
|
|
722
|
+
function_call = choice.delta.function_call
|
|
723
|
+
if choice.finish_reason:
|
|
724
|
+
finish_reason = choice.finish_reason
|
|
725
|
+
|
|
726
|
+
# Extract usage data if available
|
|
727
|
+
if hasattr(chunk, "usage") and chunk.usage:
|
|
728
|
+
usage_data = chunk.usage
|
|
729
|
+
|
|
730
|
+
# Convert accumulated tool calls to proper format.
|
|
731
|
+
tool_calls = None
|
|
732
|
+
if tool_calls_map:
|
|
733
|
+
tool_calls = []
|
|
734
|
+
for idx in sorted(tool_calls_map.keys()):
|
|
735
|
+
tool_call_data = tool_calls_map[idx]
|
|
736
|
+
# Only add tool calls that have valid data
|
|
737
|
+
if tool_call_data["id"] and tool_call_data["function"]["name"]:
|
|
738
|
+
tool_calls.append(
|
|
739
|
+
ChatCompletionMessageToolCall(
|
|
740
|
+
id=tool_call_data["id"],
|
|
741
|
+
type=tool_call_data["type"],
|
|
742
|
+
function={
|
|
743
|
+
"name": tool_call_data["function"]["name"],
|
|
744
|
+
"arguments": tool_call_data["function"]["arguments"],
|
|
745
|
+
},
|
|
746
|
+
)
|
|
747
|
+
)
|
|
748
|
+
|
|
749
|
+
# Create a ChatCompletionMessage manually
|
|
750
|
+
message = ChatCompletionMessage(
|
|
751
|
+
content=accumulated_content,
|
|
752
|
+
role=role,
|
|
753
|
+
tool_calls=tool_calls if tool_calls else None,
|
|
754
|
+
function_call=function_call,
|
|
755
|
+
refusal=None,
|
|
756
|
+
annotations=None,
|
|
757
|
+
audio=None,
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
reasoning_active = False
|
|
761
|
+
|
|
762
|
+
from types import SimpleNamespace
|
|
763
|
+
|
|
764
|
+
final_completion = SimpleNamespace()
|
|
765
|
+
final_completion.choices = [SimpleNamespace()]
|
|
766
|
+
final_completion.choices[0].message = message
|
|
767
|
+
final_completion.choices[0].finish_reason = finish_reason
|
|
768
|
+
final_completion.usage = usage_data
|
|
769
|
+
|
|
770
|
+
# Log final usage information
|
|
771
|
+
if usage_data:
|
|
772
|
+
actual_tokens = getattr(usage_data, "completion_tokens", estimated_tokens)
|
|
773
|
+
token_str = str(actual_tokens).rjust(5)
|
|
774
|
+
data = {
|
|
775
|
+
"progress_action": ProgressAction.STREAMING,
|
|
776
|
+
"model": model,
|
|
777
|
+
"agent_name": self.name,
|
|
778
|
+
"chat_turn": self.chat_turn(),
|
|
779
|
+
"details": token_str.strip(),
|
|
780
|
+
}
|
|
781
|
+
self.logger.info("Streaming progress", data=data)
|
|
782
|
+
|
|
783
|
+
self.logger.info(
|
|
784
|
+
f"Streaming complete - Model: {model}, Input tokens: {getattr(usage_data, 'prompt_tokens', 0)}, Output tokens: {actual_tokens}"
|
|
785
|
+
)
|
|
786
|
+
|
|
787
|
+
final_message = final_completion.choices[0].message if final_completion.choices else None
|
|
788
|
+
tool_calls = getattr(final_message, "tool_calls", None) if final_message else None
|
|
789
|
+
self._emit_tool_notification_fallback(
|
|
790
|
+
tool_calls,
|
|
791
|
+
notified_tool_indices,
|
|
792
|
+
streams_arguments=streams_arguments,
|
|
793
|
+
model=model,
|
|
794
|
+
)
|
|
795
|
+
|
|
796
|
+
return final_completion, reasoning_segments
|
|
797
|
+
|
|
798
|
+
async def _openai_completion(
|
|
799
|
+
self,
|
|
800
|
+
message: list[OpenAIMessage] | None,
|
|
801
|
+
request_params: RequestParams | None = None,
|
|
802
|
+
tools: list[Tool] | None = None,
|
|
803
|
+
) -> PromptMessageExtended:
|
|
804
|
+
"""
|
|
805
|
+
Process a query using an LLM and available tools.
|
|
806
|
+
The default implementation uses OpenAI's ChatCompletion as the LLM.
|
|
807
|
+
Override this method to use a different LLM.
|
|
808
|
+
"""
|
|
809
|
+
|
|
810
|
+
request_params = self.get_request_params(request_params=request_params)
|
|
811
|
+
|
|
812
|
+
response_content_blocks: list[ContentBlock] = []
|
|
813
|
+
model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
|
|
814
|
+
|
|
815
|
+
# TODO -- move this in to agent context management / agent group handling
|
|
816
|
+
messages: list[ChatCompletionMessageParam] = []
|
|
817
|
+
system_prompt = self.instruction or request_params.systemPrompt
|
|
818
|
+
if system_prompt:
|
|
819
|
+
messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt))
|
|
820
|
+
|
|
821
|
+
# The caller supplies the full history; convert it directly
|
|
822
|
+
if message:
|
|
823
|
+
messages.extend(message)
|
|
824
|
+
|
|
825
|
+
available_tools: list[ChatCompletionToolParam] | None = [
|
|
826
|
+
{
|
|
827
|
+
"type": "function",
|
|
828
|
+
"function": {
|
|
829
|
+
"name": tool.name,
|
|
830
|
+
"description": tool.description if tool.description else "",
|
|
831
|
+
"parameters": self.adjust_schema(tool.inputSchema),
|
|
832
|
+
},
|
|
833
|
+
}
|
|
834
|
+
for tool in tools or []
|
|
835
|
+
]
|
|
836
|
+
|
|
837
|
+
if not available_tools:
|
|
838
|
+
if self.provider in [Provider.DEEPSEEK, Provider.ALIYUN]:
|
|
839
|
+
available_tools = None # deepseek/aliyun does not allow empty array
|
|
840
|
+
else:
|
|
841
|
+
available_tools = []
|
|
842
|
+
|
|
843
|
+
# we do NOT send "stop sequences" as this causes errors with mutlimodal processing
|
|
844
|
+
arguments: dict[str, Any] = self._prepare_api_request(
|
|
845
|
+
messages, available_tools, request_params
|
|
846
|
+
)
|
|
847
|
+
if not self._reasoning and request_params.stopSequences:
|
|
848
|
+
arguments["stop"] = request_params.stopSequences
|
|
849
|
+
|
|
850
|
+
self.logger.debug(f"OpenAI completion requested for: {arguments}")
|
|
851
|
+
|
|
852
|
+
self._log_chat_progress(self.chat_turn(), model=self.default_request_params.model)
|
|
853
|
+
model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
|
|
854
|
+
|
|
855
|
+
# Generate stream capture filename once (before streaming starts)
|
|
856
|
+
capture_filename = _stream_capture_filename(self.chat_turn())
|
|
857
|
+
_save_stream_request(capture_filename, arguments)
|
|
858
|
+
|
|
859
|
+
# Use basic streaming API with context manager to properly close aiohttp session
|
|
860
|
+
try:
|
|
861
|
+
async with self._openai_client() as client:
|
|
862
|
+
stream = await client.chat.completions.create(**arguments)
|
|
863
|
+
# Process the stream
|
|
864
|
+
response, streamed_reasoning = await self._process_stream(
|
|
865
|
+
stream, model_name, capture_filename
|
|
866
|
+
)
|
|
867
|
+
except asyncio.CancelledError as e:
|
|
868
|
+
reason = str(e) if e.args else "cancelled"
|
|
869
|
+
self.logger.info(f"OpenAI completion cancelled: {reason}")
|
|
870
|
+
# Return a response indicating cancellation
|
|
871
|
+
return Prompt.assistant(
|
|
872
|
+
TextContent(type="text", text=""),
|
|
873
|
+
stop_reason=LlmStopReason.CANCELLED,
|
|
874
|
+
)
|
|
875
|
+
except APIError as error:
|
|
876
|
+
self.logger.error("APIError during OpenAI completion", exc_info=error)
|
|
877
|
+
raise error
|
|
878
|
+
except Exception:
|
|
879
|
+
streamed_reasoning = []
|
|
880
|
+
raise
|
|
881
|
+
# Track usage if response is valid and has usage data
|
|
882
|
+
if (
|
|
883
|
+
hasattr(response, "usage")
|
|
884
|
+
and response.usage
|
|
885
|
+
and not isinstance(response, BaseException)
|
|
886
|
+
):
|
|
887
|
+
try:
|
|
888
|
+
turn_usage = TurnUsage.from_openai(response.usage, model_name)
|
|
889
|
+
self._finalize_turn_usage(turn_usage)
|
|
890
|
+
except Exception as e:
|
|
891
|
+
self.logger.warning(f"Failed to track usage: {e}")
|
|
892
|
+
|
|
893
|
+
self.logger.debug(
|
|
894
|
+
"OpenAI completion response:",
|
|
895
|
+
data=response,
|
|
896
|
+
)
|
|
897
|
+
|
|
898
|
+
if isinstance(response, AuthenticationError):
|
|
899
|
+
raise ProviderKeyError(
|
|
900
|
+
"Rejected OpenAI API key",
|
|
901
|
+
"The configured OpenAI API key was rejected.\n"
|
|
902
|
+
"Please check that your API key is valid and not expired.",
|
|
903
|
+
) from response
|
|
904
|
+
elif isinstance(response, BaseException):
|
|
905
|
+
self.logger.error(f"Error: {response}")
|
|
906
|
+
|
|
907
|
+
choice = response.choices[0]
|
|
908
|
+
message = choice.message
|
|
909
|
+
normalized_role = self._normalize_role(getattr(message, "role", None))
|
|
910
|
+
# prep for image/audio gen models
|
|
911
|
+
if message.content:
|
|
912
|
+
response_content_blocks.append(TextContent(type="text", text=message.content))
|
|
913
|
+
|
|
914
|
+
# ParsedChatCompletionMessage is compatible with ChatCompletionMessage
|
|
915
|
+
# since it inherits from it, so we can use it directly
|
|
916
|
+
# Convert to dict and remove None values
|
|
917
|
+
message_dict = message.model_dump()
|
|
918
|
+
message_dict = {k: v for k, v in message_dict.items() if v is not None}
|
|
919
|
+
if normalized_role:
|
|
920
|
+
try:
|
|
921
|
+
message.role = normalized_role
|
|
922
|
+
except Exception:
|
|
923
|
+
pass
|
|
924
|
+
|
|
925
|
+
if model_name in (
|
|
926
|
+
"deepseek-r1-distill-llama-70b",
|
|
927
|
+
"openai/gpt-oss-120b",
|
|
928
|
+
"openai/gpt-oss-20b",
|
|
929
|
+
):
|
|
930
|
+
message_dict.pop("reasoning", None)
|
|
931
|
+
message_dict.pop("channel", None)
|
|
932
|
+
|
|
933
|
+
message_dict["role"] = normalized_role or message_dict.get("role", "assistant")
|
|
934
|
+
|
|
935
|
+
messages.append(message_dict)
|
|
936
|
+
stop_reason = LlmStopReason.END_TURN
|
|
937
|
+
requested_tool_calls: dict[str, CallToolRequest] | None = None
|
|
938
|
+
if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls:
|
|
939
|
+
requested_tool_calls = {}
|
|
940
|
+
stop_reason = LlmStopReason.TOOL_USE
|
|
941
|
+
for tool_call in message.tool_calls:
|
|
942
|
+
tool_call_request = CallToolRequest(
|
|
943
|
+
method="tools/call",
|
|
944
|
+
params=CallToolRequestParams(
|
|
945
|
+
name=tool_call.function.name,
|
|
946
|
+
arguments={}
|
|
947
|
+
if not tool_call.function.arguments
|
|
948
|
+
or tool_call.function.arguments.strip() == ""
|
|
949
|
+
else from_json(tool_call.function.arguments, allow_partial=True),
|
|
950
|
+
),
|
|
951
|
+
)
|
|
952
|
+
requested_tool_calls[tool_call.id] = tool_call_request
|
|
953
|
+
elif choice.finish_reason == "length":
|
|
954
|
+
stop_reason = LlmStopReason.MAX_TOKENS
|
|
955
|
+
# We have reached the max tokens limit
|
|
956
|
+
self.logger.debug(" Stopping because finish_reason is 'length'")
|
|
957
|
+
elif choice.finish_reason == "content_filter":
|
|
958
|
+
stop_reason = LlmStopReason.SAFETY
|
|
959
|
+
self.logger.debug(" Stopping because finish_reason is 'content_filter'")
|
|
960
|
+
|
|
961
|
+
# Update diagnostic snapshot (never read again)
|
|
962
|
+
# This provides a snapshot of what was sent to the provider for debugging
|
|
963
|
+
self.history.set(messages)
|
|
964
|
+
|
|
965
|
+
self._log_chat_finished(model=self.default_request_params.model)
|
|
966
|
+
|
|
967
|
+
reasoning_blocks: list[ContentBlock] | None = None
|
|
968
|
+
if streamed_reasoning:
|
|
969
|
+
reasoning_blocks = [TextContent(type="text", text="".join(streamed_reasoning))]
|
|
970
|
+
|
|
971
|
+
return PromptMessageExtended(
|
|
972
|
+
role="assistant",
|
|
973
|
+
content=response_content_blocks,
|
|
974
|
+
tool_calls=requested_tool_calls,
|
|
975
|
+
channels={REASONING: reasoning_blocks} if reasoning_blocks else None,
|
|
976
|
+
stop_reason=stop_reason,
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
 979 | +    def _stream_failure_response(self, error: APIError, model_name: str) -> PromptMessageExtended:
 980 | +        """Convert streaming API errors into a graceful assistant reply."""
 981 | +
 982 | +        provider_label = (
 983 | +            self.provider.value if isinstance(self.provider, Provider) else str(self.provider)
 984 | +        )
 985 | +        detail = getattr(error, "message", None) or str(error)
 986 | +        detail = detail.strip() if isinstance(detail, str) else ""
 987 | +
 988 | +        parts: list[str] = [f"{provider_label} request failed"]
 989 | +        if model_name:
 990 | +            parts.append(f"for model '{model_name}'")
 991 | +        code = getattr(error, "code", None)
 992 | +        if code:
 993 | +            parts.append(f"(code: {code})")
 994 | +        status = getattr(error, "status_code", None)
 995 | +        if status:
 996 | +            parts.append(f"(status={status})")
 997 | +
 998 | +        message = " ".join(parts)
 999 | +        if detail:
1000 | +            message = f"{message}: {detail}"
1001 | +
1002 | +        user_summary = " ".join(message.split()) if message else ""
1003 | +        if user_summary and len(user_summary) > 280:
1004 | +            user_summary = user_summary[:277].rstrip() + "..."
1005 | +
1006 | +        if user_summary:
1007 | +            assistant_text = f"I hit an internal error while calling the model: {user_summary}"
1008 | +            if not assistant_text.endswith((".", "!", "?")):
1009 | +                assistant_text += "."
1010 | +            assistant_text += " See fast-agent-error for additional details."
1011 | +        else:
1012 | +            assistant_text = (
1013 | +                "I hit an internal error while calling the model; see fast-agent-error for details."
1014 | +            )
1015 | +
1016 | +        assistant_block = text_content(assistant_text)
1017 | +        error_block = text_content(message)
1018 | +
1019 | +        return PromptMessageExtended(
1020 | +            role="assistant",
1021 | +            content=[assistant_block],
1022 | +            channels={FAST_AGENT_ERROR_CHANNEL: [error_block]},
1023 | +            stop_reason=LlmStopReason.ERROR,
1024 | +        )
1025 | +
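`_stream_failure_response` folds the provider label, model, error code and status into a single line, truncates the visible assistant text at roughly 280 characters, and keeps the full message on the `fast-agent-error` channel. A rough stand-alone sketch of the message assembly (the `FakeAPIError` class and `summarize` helper are illustrative stand-ins, not part of the package):

```python
class FakeAPIError:
    message = "Rate limit reached"
    code = "rate_limit_exceeded"
    status_code = 429


def summarize(error, provider_label: str, model_name: str) -> str:
    parts = [f"{provider_label} request failed"]
    if model_name:
        parts.append(f"for model '{model_name}'")
    if getattr(error, "code", None):
        parts.append(f"(code: {error.code})")
    if getattr(error, "status_code", None):
        parts.append(f"(status={error.status_code})")
    message = " ".join(parts)
    detail = (getattr(error, "message", None) or str(error)).strip()
    return f"{message}: {detail}" if detail else message


print(summarize(FakeAPIError(), "openai", "gpt-4o"))
# openai request failed for model 'gpt-4o' (code: rate_limit_exceeded) (status=429): Rate limit reached
```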
1026 | +    def _handle_retry_failure(self, error: Exception) -> PromptMessageExtended | None:
1027 | +        """Return the legacy error-channel response when retries are exhausted."""
1028 | +        if isinstance(error, APIError):
1029 | +            model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
1030 | +            return self._stream_failure_response(error, model_name)
1031 | +        return None
1032 | +
1033 | +    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
1034 | +        return True
1035 | +
1036 | +    async def _apply_prompt_provider_specific(
1037 | +        self,
1038 | +        multipart_messages: list["PromptMessageExtended"],
1039 | +        request_params: RequestParams | None = None,
1040 | +        tools: list[Tool] | None = None,
1041 | +        is_template: bool = False,
1042 | +    ) -> PromptMessageExtended:
1043 | +        """
1044 | +        Provider-specific prompt application.
1045 | +        Templates are handled by the agent; messages already include them.
1046 | +        """
1047 | +        # Determine effective params
1048 | +        req_params = self.get_request_params(request_params)
1049 | +
1050 | +        last_message = multipart_messages[-1]
1051 | +
1052 | +        # If the last message is from the assistant, no inference required
1053 | +        if last_message.role == "assistant":
1054 | +            return last_message
1055 | +
1056 | +        # Convert the supplied history/messages directly
1057 | +        converted_messages = self._convert_to_provider_format(multipart_messages)
1058 | +        if not converted_messages:
1059 | +            converted_messages = [{"role": "user", "content": ""}]
1060 | +
1061 | +        return await self._openai_completion(converted_messages, req_params, tools)
1062 | +
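`_apply_prompt_provider_specific` short-circuits when the conversation already ends with an assistant turn and falls back to a single empty user message when conversion produces nothing, presumably so the Chat Completions API never receives an empty `messages` array. A toy illustration with plain dicts standing in for `PromptMessageExtended` (the dict shapes are an assumption for demonstration only):

```python
history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]
if history[-1]["role"] == "assistant":
    result = history[-1]  # conversation already ends with the assistant: nothing to infer

converted: list[dict] = []  # e.g. every message was filtered out during conversion
if not converted:
    # fall back to a single empty user message rather than sending an empty list
    converted = [{"role": "user", "content": ""}]
```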
1063 | +    def _prepare_api_request(
1064 | +        self, messages, tools: list[ChatCompletionToolParam] | None, request_params: RequestParams
1065 | +    ) -> dict[str, str]:
1066 | +        # Create base arguments dictionary
1067 | +
1068 | +        # overriding model via request params not supported (intentional)
1069 | +        base_args = {
1070 | +            "model": self.default_request_params.model,
1071 | +            "messages": messages,
1072 | +            "tools": tools,
1073 | +            "stream": True,  # Enable basic streaming
1074 | +            "stream_options": {"include_usage": True},  # Required for usage data in streaming
1075 | +        }
1076 | +
1077 | +        if self._reasoning:
1078 | +            base_args.update(
1079 | +                {
1080 | +                    "max_completion_tokens": request_params.maxTokens,
1081 | +                    "reasoning_effort": self._reasoning_effort,
1082 | +                }
1083 | +            )
1084 | +        else:
1085 | +            base_args["max_tokens"] = request_params.maxTokens
1086 | +        if tools:
1087 | +            base_args["parallel_tool_calls"] = request_params.parallel_tool_calls
1088 | +
1089 | +        arguments: dict[str, str] = self.prepare_provider_arguments(
1090 | +            base_args, request_params, self.OPENAI_EXCLUDE_FIELDS.union(self.BASE_EXCLUDE_FIELDS)
1091 | +        )
1092 | +        return arguments
1093 | +
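`_prepare_api_request` always streams with usage reporting enabled and switches the token-limit field depending on whether the configured model is a reasoning model. A sketch of the resulting payload shape; the standalone `build_args` helper and the model names are illustrative, not the package's API:

```python
def build_args(model: str, max_tokens: int, reasoning: bool, effort: str = "medium") -> dict:
    args = {
        "model": model,
        "stream": True,  # streaming is always on
        "stream_options": {"include_usage": True},  # usage data arrives in the final chunk
    }
    if reasoning:
        # reasoning models use max_completion_tokens plus a reasoning_effort hint
        args["max_completion_tokens"] = max_tokens
        args["reasoning_effort"] = effort
    else:
        # non-reasoning models keep the classic max_tokens field
        args["max_tokens"] = max_tokens
    return args


print(build_args("o3-mini", 4096, reasoning=True))
print(build_args("gpt-4o", 4096, reasoning=False))
```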
1094 | +    @staticmethod
1095 | +    def _extract_reasoning_text(reasoning: Any = None, reasoning_content: Any | None = None) -> str:
1096 | +        """Extract text from provider-specific reasoning payloads.
1097 | +
1098 | +        Priority: explicit `reasoning` field (string/object/list) > `reasoning_content` list.
1099 | +        """
1100 | +
1101 | +        def _coerce_text(value: Any) -> str:
1102 | +            if value is None:
1103 | +                return ""
1104 | +            if isinstance(value, str):
1105 | +                return value
1106 | +            if isinstance(value, dict):
1107 | +                return str(value.get("text") or value)
1108 | +            text_attr = None
1109 | +            try:
1110 | +                text_attr = getattr(value, "text", None)
1111 | +            except Exception:
1112 | +                text_attr = None
1113 | +            if text_attr:
1114 | +                return str(text_attr)
1115 | +            return str(value)
1116 | +
1117 | +        if reasoning is not None:
1118 | +            if isinstance(reasoning, (list, tuple)):
1119 | +                combined = "".join(_coerce_text(item) for item in reasoning)
1120 | +            else:
1121 | +                combined = _coerce_text(reasoning)
1122 | +            if combined.strip():
1123 | +                return combined
1124 | +
1125 | +        if reasoning_content:
1126 | +            parts: list[str] = []
1127 | +            for item in reasoning_content:
1128 | +                text = _coerce_text(item)
1129 | +                if text:
1130 | +                    parts.append(text)
1131 | +            combined = "".join(parts)
1132 | +            if combined.strip():
1133 | +                return combined
1134 | +
1135 | +        return ""
1136 | +
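The docstring above describes a simple priority: a non-empty `reasoning` payload wins, otherwise the `reasoning_content` items are joined. A condensed stand-alone version exercising that order (simplified for demonstration; the real helper also reads a `.text` attribute off arbitrary objects):

```python
def coerce(value) -> str:
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        return str(value.get("text") or value)
    return str(value)


def extract(reasoning=None, reasoning_content=None) -> str:
    if reasoning is not None:
        combined = (
            "".join(coerce(i) for i in reasoning)
            if isinstance(reasoning, (list, tuple))
            else coerce(reasoning)
        )
        if combined.strip():
            return combined
    if reasoning_content:
        combined = "".join(coerce(i) for i in reasoning_content)
        if combined.strip():
            return combined
    return ""


print(extract(reasoning="chain of thought"))  # explicit field wins
print(extract(reasoning="", reasoning_content=[{"text": "step 1"}, {"text": ", step 2"}]))
# empty reasoning falls through to the joined reasoning_content items
```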
1137 | +    def _convert_extended_messages_to_provider(
1138 | +        self, messages: list[PromptMessageExtended]
1139 | +    ) -> list[ChatCompletionMessageParam]:
1140 | +        """
1141 | +        Convert PromptMessageExtended list to OpenAI ChatCompletionMessageParam format.
1142 | +        This is called fresh on every API call from _convert_to_provider_format().
1143 | +
1144 | +        Args:
1145 | +            messages: List of PromptMessageExtended objects
1146 | +
1147 | +        Returns:
1148 | +            List of OpenAI ChatCompletionMessageParam objects
1149 | +        """
1150 | +        converted: list[ChatCompletionMessageParam] = []
1151 | +        reasoning_mode = ModelDatabase.get_reasoning(self.default_request_params.model)
1152 | +
1153 | +        for msg in messages:
1154 | +            # convert_to_openai returns a list of messages
1155 | +            openai_msgs = OpenAIConverter.convert_to_openai(msg)
1156 | +
1157 | +            if reasoning_mode == "reasoning_content" and msg.channels:
1158 | +                reasoning_blocks = msg.channels.get(REASONING) if msg.channels else None
1159 | +                if reasoning_blocks:
1160 | +                    reasoning_texts = [get_text(block) for block in reasoning_blocks]
1161 | +                    reasoning_texts = [txt for txt in reasoning_texts if txt]
1162 | +                    if reasoning_texts:
1163 | +                        reasoning_content = "\n\n".join(reasoning_texts)
1164 | +                        for oai_msg in openai_msgs:
1165 | +                            oai_msg["reasoning_content"] = reasoning_content
1166 | +
1167 | +            # gpt-oss: per docs, reasoning should be dropped on subsequent sampling
1168 | +            # UNLESS tool calling is involved. For tool calls, prefix the assistant
1169 | +            # message content with the reasoning text.
1170 | +            if reasoning_mode == "gpt_oss" and msg.channels and msg.tool_calls:
1171 | +                reasoning_blocks = msg.channels.get(REASONING) if msg.channels else None
1172 | +                if reasoning_blocks:
1173 | +                    reasoning_texts = [get_text(block) for block in reasoning_blocks]
1174 | +                    reasoning_texts = [txt for txt in reasoning_texts if txt]
1175 | +                    if reasoning_texts:
1176 | +                        reasoning_text = "\n\n".join(reasoning_texts)
1177 | +                        for oai_msg in openai_msgs:
1178 | +                            existing_content = oai_msg.get("content", "") or ""
1179 | +                            oai_msg["content"] = reasoning_text + existing_content
1180 | +
1181 | +            converted.extend(openai_msgs)
1182 | +
1183 | +        return converted
1184 | +
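Two replay strategies are implemented above: models flagged as `reasoning_content` in the model database get the prior reasoning attached as an extra `reasoning_content` field, while `gpt_oss` models only see it again when the assistant turn carried tool calls, prefixed onto the message content. A plain-dict sketch of the difference (the dict shapes are simplified stand-ins for the converter output):

```python
assistant_msg = {"role": "assistant", "content": "The answer is 42."}
reasoning_text = "Check the constant.\n\nIt is defined as 42."

# "reasoning_content" models: reasoning rides along as an extra message field
with_field = dict(assistant_msg, reasoning_content=reasoning_text)

# "gpt_oss" models, only when the turn carried tool calls: reasoning is
# prefixed onto the visible assistant content instead
with_prefix = dict(assistant_msg, content=reasoning_text + assistant_msg["content"])

print(with_field)
print(with_prefix)
```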
1185 | +    def adjust_schema(self, inputSchema: dict) -> dict:
1186 | +        # return inputSchema
1187 | +        if self.provider not in [Provider.OPENAI, Provider.AZURE]:
1188 | +            return inputSchema
1189 | +
1190 | +        if "properties" in inputSchema:
1191 | +            return inputSchema
1192 | +
1193 | +        result = inputSchema.copy()
1194 | +        result["properties"] = {}
1195 | +        return result
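`adjust_schema` only touches schemas for the OpenAI and Azure providers, adding an empty `properties` map when a tool declares none, presumably so that parameter-less tool schemas remain valid for those endpoints. A minimal stand-alone illustration (provider check omitted; the helper name mirrors the method above but is not the package's API):

```python
def adjust_schema(input_schema: dict) -> dict:
    if "properties" in input_schema:
        return input_schema
    result = input_schema.copy()
    result["properties"] = {}
    return result


print(adjust_schema({"type": "object"}))
# {'type': 'object', 'properties': {}} — a no-argument tool now declares an empty properties map
print(adjust_schema({"type": "object", "properties": {"q": {"type": "string"}}}))
# schemas that already declare properties pass through unchanged
```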