fast-agent-mcp 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_agent/config.py CHANGED
@@ -60,7 +60,7 @@ class MCPServerSettings(BaseModel):
     description: str | None = None
     """The description of the server."""
 
-    transport: Literal["stdio", "sse"] = "stdio"
+    transport: Literal["stdio", "sse", "http"] = "stdio"
     """The transport mechanism."""
 
     command: str | None = None
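The new `"http"` value enables MCP's streamable HTTP transport alongside stdio and SSE. A minimal sketch of a server entry using it (hedged: the `url` field is the existing `MCPServerSettings` field used by the sse transport, not shown in this hunk, and the endpoint address is a placeholder):

```python
from mcp_agent.config import MCPServerSettings

# Sketch only: assumes the existing `url` field also serves the new
# "http" transport; the address is illustrative.
server = MCPServerSettings(transport="http", url="http://localhost:8000/mcp")
```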
@@ -198,6 +198,16 @@ class OpenTelemetrySettings(BaseModel):
     """Sample rate for tracing (1.0 = sample everything)"""
 
 
+class TensorZeroSettings(BaseModel):
+    """
+    Settings for using TensorZero via its OpenAI-compatible API.
+    """
+
+    base_url: Optional[str] = None
+    api_key: Optional[str] = None
+    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
+
+
 class LoggerSettings(BaseModel):
     """
     Logger settings for the fast-agent application.
@@ -239,6 +249,8 @@ class LoggerSettings(BaseModel):
     """Show MCP Server tool calls on the console"""
     truncate_tools: bool = True
     """Truncate display of long tool calls"""
+    enable_markup: bool = True
+    """Enable markup in console output. Disable for outputs that may conflict with rich console formatting"""
 
 
 class Settings(BaseSettings):
@@ -287,6 +299,9 @@ class Settings(BaseSettings):
     generic: GenericSettings | None = None
     """Settings for using Generic models in the fast-agent application"""
 
+    tensorzero: Optional[TensorZeroSettings] = None
+    """Settings for using TensorZero inference gateway"""
+
     logger: LoggerSettings | None = LoggerSettings()
     """Logger settings for the fast-agent application"""
 
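Together with the `TensorZeroSettings` class above, the gateway can now be configured programmatically, or under a `tensorzero` key in the settings file. A hedged sketch; the URL is illustrative and matches the default the provider falls back to below:

```python
from mcp_agent.config import Settings, TensorZeroSettings

# Sketch only: field names follow the diff, the value is a placeholder.
settings = Settings(
    tensorzero=TensorZeroSettings(base_url="http://localhost:3000"),
)
```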
@@ -131,8 +131,8 @@ class FastAgent:
         )
         parser.add_argument(
             "--transport",
-            choices=["sse", "stdio"],
-            default="sse",
+            choices=["sse", "http", "stdio"],
+            default="http",
             help="Transport protocol to use when running as a server (sse or stdio)",
         )
         parser.add_argument(
@@ -2,7 +2,7 @@
 Request parameters definitions for LLM interactions.
 """
 
-from typing import Any, List
+from typing import Any, Dict, List
 
 from mcp import SamplingMessage
 from mcp.types import CreateMessageRequestParams
@@ -25,26 +25,30 @@ class RequestParams(CreateMessageRequestParams):
 
     model: str | None = None
     """
-    The model to use for the LLM generation.
+    The model to use for the LLM generation. This can only be set during Agent creation.
     If specified, this overrides the 'modelPreferences' selection criteria.
     """
 
     use_history: bool = True
     """
-    Include the message history in the generate request.
+    Agent/LLM maintains conversation history. Does not include applied Prompts
     """
 
-    max_iterations: int = 10
+    max_iterations: int = 20
     """
-    The maximum number of iterations to run the LLM for.
+    The maximum number of tool calls allowed in a conversation turn
     """
 
     parallel_tool_calls: bool = True
     """
-    Whether to allow multiple tool calls per iteration.
-    Also known as multi-step tool use.
+    Whether to allow simultaneous tool calls
     """
     response_format: Any | None = None
     """
     Override response format for structured calls. Prefer sending pydantic model - only use in exceptional circumstances
     """
+
+    template_vars: Dict[str, Any] = Field(default_factory=dict)
+    """
+    Optional dictionary of template variables for dynamic templates. Currently only works for TensorZero inference backend
+    """
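The new `template_vars` field feeds TensorZero's templated functions; `_prepare_t0_system_params` in the new provider (below) copies it into the gateway's `system` input. A hedged usage sketch with placeholder variable names:

```python
from mcp_agent.core.request_params import RequestParams

# Sketch only: keys must match the variables your TensorZero
# function's templates declare.
params = RequestParams(
    template_vars={"assistant_name": "Zero", "tone": "concise"},
)
```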
@@ -88,7 +88,7 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]:
 
     return ProgressEvent(
         action=ProgressAction(progress_action),
-        target=target,
+        target=target or "unknown",
         details=details,
         agent_name=event_data.get("agent_name"),
     )
@@ -76,20 +76,14 @@ def deep_merge(dict1: Dict[Any, Any], dict2: Dict[Any, Any]) -> Dict[Any, Any]:
         Dict: The updated `dict1`.
     """
     for key in dict2:
-        if (
-            key in dict1
-            and isinstance(dict1[key], dict)
-            and isinstance(dict2[key], dict)
-        ):
+        if key in dict1 and isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
             deep_merge(dict1[key], dict2[key])
         else:
             dict1[key] = dict2[key]
     return dict1
 
 
-class AugmentedLLM(
-    ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]
-):
+class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
     # Common parameter names used across providers
     PARAM_MESSAGES = "messages"
     PARAM_MODEL = "model"
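For reference, `deep_merge` recurses into mappings present in both arguments and lets `dict2` win everywhere else, mutating `dict1` in place:

```python
# Illustrative behavior of deep_merge (values are made up):
a = {"logger": {"level": "info", "truncate_tools": True}}
b = {"logger": {"level": "debug"}}
deep_merge(a, b)
assert a == {"logger": {"level": "debug", "truncate_tools": True}}
```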
@@ -100,7 +94,7 @@ class AugmentedLLM(
     PARAM_METADATA = "metadata"
     PARAM_USE_HISTORY = "use_history"
     PARAM_MAX_ITERATIONS = "max_iterations"
-
+    PARAM_TEMPLATE_VARS = "template_vars"
     # Base set of fields that should always be excluded
     BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
 
@@ -15,6 +15,7 @@ from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_google import GoogleAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM
+from mcp_agent.llm.providers.augmented_llm_tensorzero import TensorZeroAugmentedLLM
 from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
 
 # from mcp_agent.workflows.llm.augmented_llm_deepseek import DeekSeekAugmentedLLM
@@ -28,6 +29,7 @@ LLMClass = Union[
     Type[PlaybackLLM],
     Type[DeepSeekAugmentedLLM],
     Type[OpenRouterAugmentedLLM],
+    Type[TensorZeroAugmentedLLM],
 ]
 
 
@@ -110,6 +112,7 @@ class ModelFactory:
         Provider.GENERIC: GenericAugmentedLLM,
         Provider.GOOGLE: GoogleAugmentedLLM,  # type: ignore
         Provider.OPENROUTER: OpenRouterAugmentedLLM,
+        Provider.TENSORZERO: TensorZeroAugmentedLLM,
     }
 
     # Mapping of special model names to their specific LLM classes
@@ -142,6 +145,11 @@ class ModelFactory:
             provider = Provider(potential_provider)
             model_parts = model_parts[1:]
 
+        if provider == Provider.TENSORZERO and not model_parts:
+            raise ModelConfigError(
+                f"TensorZero provider requires a function name after the provider "
+                f"(e.g., tensorzero.my-function), got: {model_string}"
+            )
         # Join remaining parts as model name
         model_name = ".".join(model_parts)
 
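The guard above means a TensorZero model string must name a function after the provider prefix. Illustrative strings (the function name is a placeholder):

```python
"tensorzero.my_chat_function"  # provider "tensorzero", function "my_chat_function"
"tensorzero"                   # raises ModelConfigError: no function name given
```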
@@ -15,3 +15,4 @@ class Provider(Enum):
     DEEPSEEK = "deepseek"
     GENERIC = "generic"
     OPENROUTER = "openrouter"
+    TENSORZERO = "tensorzero"  # For TensorZero Gateway
@@ -62,6 +62,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         AugmentedLLM.PARAM_USE_HISTORY,
         AugmentedLLM.PARAM_MAX_ITERATIONS,
         AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS,
+        AugmentedLLM.PARAM_TEMPLATE_VARS,
     }
 
     def __init__(self, *args, **kwargs) -> None:
@@ -56,6 +56,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS,
         AugmentedLLM.PARAM_USE_HISTORY,
         AugmentedLLM.PARAM_MAX_ITERATIONS,
+        AugmentedLLM.PARAM_TEMPLATE_VARS,
     }
 
     def __init__(self, provider: Provider = Provider.OPENAI, *args, **kwargs) -> None:
@@ -143,7 +144,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
                 function={
                     "name": tool.name,
                     "description": tool.description if tool.description else "",
-                    "parameters": tool.inputSchema,
+                    "parameters": self.adjust_schema(tool.inputSchema),
                 },
             )
             for tool in response.tools
@@ -350,3 +351,15 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
             base_args, request_params, self.OPENAI_EXCLUDE_FIELDS.union(self.BASE_EXCLUDE_FIELDS)
         )
         return arguments
+
+    def adjust_schema(self, inputSchema: Dict) -> Dict:
+        # return inputSchema
+        if not Provider.OPENAI == self.provider:
+            return inputSchema
+
+        if "properties" in inputSchema:
+            return inputSchema
+
+        result = inputSchema.copy()
+        result["properties"] = {}
+        return result
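`adjust_schema` works around OpenAI's stricter validation of tool parameter schemas, which can reject an `object` schema that omits `properties` (as MCP tools taking no arguments may advertise). Illustrative input and output:

```python
# What an argument-less MCP tool may advertise:
schema = {"type": "object"}

# What adjust_schema returns for the OpenAI provider:
fixed = {"type": "object", "properties": {}}
```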
mcp_agent/llm/providers/augmented_llm_tensorzero.py ADDED
@@ -0,0 +1,442 @@
+import json
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextContent,
+)
+from tensorzero import AsyncTensorZeroGateway
+from tensorzero.types import (
+    ChatInferenceResponse,
+    JsonInferenceResponse,
+    TensorZeroError,
+)
+
+from mcp_agent.agents.agent import Agent
+from mcp_agent.core.exceptions import ModelConfigError
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.augmented_llm import AugmentedLLM
+from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.multipart_converter_tensorzero import TensorZeroConverter
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+
+
+class TensorZeroAugmentedLLM(AugmentedLLM[Dict[str, Any], Any]):
+    """
+    AugmentedLLM implementation for TensorZero using its native API.
+    Uses the Converter pattern for message formatting.
+    Implements multi-turn tool calling logic, storing API dicts in history.
+    """
+
+    def __init__(
+        self,
+        agent: Agent,
+        model: str,
+        request_params: Optional[RequestParams] = None,
+        **kwargs: Any,
+    ):
+        self._t0_gateway: Optional[AsyncTensorZeroGateway] = None
+        self._t0_function_name: str = model
+        self._t0_episode_id: Optional[str] = kwargs.get("episode_id")
+
+        super().__init__(
+            agent=agent,
+            model=model,
+            provider=Provider.TENSORZERO,
+            request_params=request_params,
+            **kwargs,
+        )
+
+        self.history: Memory[Dict[str, Any]] = SimpleMemory[Dict[str, Any]]()
+
+        self.logger.info(
+            f"TensorZero LLM provider initialized for function '{self._t0_function_name}'. History type: {type(self.history)}"
+        )
+
+    @staticmethod
+    def block_to_dict(block: Any) -> Dict[str, Any]:
+        if hasattr(block, "model_dump"):
+            try:
+                dumped = block.model_dump(mode="json")
+                if dumped:
+                    return dumped
+            except Exception:
+                pass
+        if hasattr(block, "__dict__"):
+            try:
+                block_vars = vars(block)
+                if block_vars:
+                    return block_vars
+            except Exception:
+                pass
+        if isinstance(block, (str, int, float, bool, list, dict, type(None))):
+            return {"type": "raw", "content": block}
+
+        # Basic attribute extraction as fallback
+        d = {"type": getattr(block, "type", "unknown")}
+        for attr in ["id", "name", "text", "arguments"]:
+            if hasattr(block, attr):
+                d[attr] = getattr(block, attr)
+        if len(d) == 1 and d.get("type") == "unknown":
+            d["content"] = str(block)
+        return d
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        func_name = kwargs.get("model", self._t0_function_name or "unknown_t0_function")
+        return RequestParams(
+            model=func_name,
+            systemPrompt=self.instruction,
+            maxTokens=4096,
+            use_history=True,
+            max_iterations=10,  # Max iterations for tool use loop
+            parallel_tool_calls=True,
+        )
+
+    async def _initialize_gateway(self) -> AsyncTensorZeroGateway:
+        if self._t0_gateway is None:
+            self.logger.debug("Initializing AsyncTensorZeroGateway client...")
+            try:
+                base_url: Optional[str] = None
+                default_url = "http://localhost:3000"
+
+                if (
+                    self.context
+                    and self.context.config
+                    and hasattr(self.context.config, "tensorzero")
+                    and self.context.config.tensorzero
+                ):
+                    base_url = getattr(self.context.config.tensorzero, "base_url", None)
+
+                if not base_url:
+                    if not self.context:
+                        # Handle case where context itself is missing, log and use default
+                        self.logger.warning(
+                            f"LLM context not found. Cannot read TensorZero Gateway base URL configuration. "
+                            f"Using default: {default_url}"
+                        )
+                    else:
+                        self.logger.warning(
+                            f"TensorZero Gateway base URL not configured in context.config.tensorzero.base_url. "
+                            f"Using default: {default_url}"
+                        )
+
+                    base_url = default_url
+
+                self._t0_gateway = await AsyncTensorZeroGateway.build_http(gateway_url=base_url)  # type: ignore
+                self.logger.info(f"TensorZero Gateway client initialized for URL: {base_url}")
+            except Exception as e:
+                self.logger.error(f"Failed to initialize TensorZero Gateway: {e}")
+                raise ModelConfigError(f"Failed to initialize TensorZero Gateway lazily: {e}")
+
+        return self._t0_gateway
+
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        request_params: Optional[RequestParams] = None,
+        is_template: bool = False,
+    ) -> PromptMessageMultipart:
+        gateway = await self._initialize_gateway()
+        merged_params = self.get_request_params(request_params)
+
+        # [1] Retrieve history
+        current_api_messages: List[Dict[str, Any]] = []
+        if merged_params.use_history:
+            try:
+                current_api_messages = self.history.get() or []
+                self.logger.debug(
+                    f"Retrieved {len(current_api_messages)} API dict messages from history."
+                )
+            except Exception as e:
+                self.logger.error(f"Error retrieving history: {e}")
+
+        # [2] Convert *new* incoming PromptMessageMultipart messages to API dicts
+        for msg in multipart_messages:
+            msg_dict = TensorZeroConverter.convert_mcp_to_t0_message(msg)
+            if msg_dict:
+                current_api_messages.append(msg_dict)
+
+        t0_system_vars = self._prepare_t0_system_params(merged_params)
+        if t0_system_vars:
+            t0_api_input_dict = {"system": t0_system_vars}
+        else:
+            t0_api_input_dict = {}
+        available_tools: Optional[List[Dict[str, Any]]] = await self._prepare_t0_tools()
+
+        # [3] Initialize storage arrays for the text content of the assistant message reply and, optionally, tool calls and results, and begin inference loop
+        final_assistant_message: List[Union[TextContent, ImageContent, EmbeddedResource]] = []
+        last_executed_results: Optional[List[CallToolResult]] = None
+
+        for i in range(merged_params.max_iterations):
+            use_parallel_calls = merged_params.parallel_tool_calls if available_tools else False
+            current_t0_episode_id = self._t0_episode_id
+
+            try:
+                self.logger.debug(
+                    f"Calling TensorZero inference (Iteration {i + 1}/{merged_params.max_iterations})..."
+                )
+                t0_api_input_dict["messages"] = current_api_messages  # type: ignore
+
+                # [4] Call the TensorZero inference API
+                response_iter_or_completion = await gateway.inference(
+                    function_name=self._t0_function_name,
+                    input=t0_api_input_dict,
+                    additional_tools=available_tools,
+                    parallel_tool_calls=use_parallel_calls,
+                    stream=False,
+                    episode_id=current_t0_episode_id,
+                )
+
+                if not isinstance(
+                    response_iter_or_completion, (ChatInferenceResponse, JsonInferenceResponse)
+                ):
+                    self.logger.error(
+                        f"Unexpected TensorZero response type: {type(response_iter_or_completion)}"
+                    )
+                    final_assistant_message = [
+                        TextContent(type="text", text="Unexpected response type")
+                    ]
+                    break  # Exit loop
+
+                # [5] quick check to confirm that episode_id is present and being used correctly by TensorZero
+                completion = response_iter_or_completion
+                if completion.episode_id:
+                    self._t0_episode_id = str(completion.episode_id)
+                    if (
+                        self._t0_episode_id != current_t0_episode_id
+                        and current_t0_episode_id is not None
+                    ):
+                        raise Exception(
+                            f"Episode ID mismatch: {self._t0_episode_id} != {current_t0_episode_id}"
+                        )
+
+                # [6] Adapt TensorZero inference response to a format compatible with the broader framework
+                (
+                    content_parts_this_turn,  # Text/Image content ONLY
+                    executed_results_this_iter,  # Results from THIS iteration
+                    raw_tool_call_blocks,
+                ) = await self._adapt_t0_native_completion(completion, available_tools)
+
+                last_executed_results = (
+                    executed_results_this_iter  # Track results from this iteration
+                )
+
+                # [7] If a text message was returned from the assistant, format that message using the multipart_converter_tensorzero.py helper methods and add this to the current list of API messages
+                assistant_api_content = []
+                for part in content_parts_this_turn:
+                    api_part = TensorZeroConverter._convert_content_part(part)
+                    if api_part:
+                        assistant_api_content.append(api_part)
+                if raw_tool_call_blocks:
+                    assistant_api_content.extend(
+                        [self.block_to_dict(b) for b in raw_tool_call_blocks]
+                    )
+
+                if assistant_api_content:
+                    assistant_api_message_dict = {
+                        "role": "assistant",
+                        "content": assistant_api_content,
+                    }
+                    current_api_messages.append(assistant_api_message_dict)
+                elif executed_results_this_iter:
+                    self.logger.debug(
+                        "Assistant turn contained only tool calls, no API message added."
+                    )
+
+                final_assistant_message = content_parts_this_turn
+
+                # [8] If there were no tool calls we're done. If not, format the tool results and add them to the current list of API messages
+                if not executed_results_this_iter:
+                    self.logger.debug(f"Iteration {i + 1}: No tool calls detected. Finishing loop.")
+                    break
+                else:
+                    user_message_with_results = (
+                        TensorZeroConverter.convert_tool_results_to_t0_user_message(
+                            executed_results_this_iter
+                        )
+                    )
+                    if user_message_with_results:
+                        current_api_messages.append(user_message_with_results)
+                    else:
+                        self.logger.error("Converter failed to format tool results, breaking loop.")
+                        break
+
+                # Check max iterations: TODO: implement logic in the future to handle this dynamically, checking for the presence of a tool call in the last iteration
+                if i == merged_params.max_iterations - 1:
+                    self.logger.warning(f"Max iterations ({merged_params.max_iterations}) reached.")
+                    break
+
+            # --- Error Handling for Inference Call ---
+            except TensorZeroError as e:
+                error_details = getattr(e, "detail", str(e.args[0] if e.args else e))
+                self.logger.error(f"TensorZero Error (HTTP {e.status_code}): {error_details}")
+                error_content = TextContent(type="text", text=f"TensorZero Error: {error_details}")
+                return PromptMessageMultipart(role="assistant", content=[error_content])
+            except Exception as e:
+                import traceback
+
+                self.logger.error(f"Unexpected Error: {e}\n{traceback.format_exc()}")
+                error_content = TextContent(type="text", text=f"Unexpected error: {e}")
+                return PromptMessageMultipart(role="assistant", content=[error_content])
+
+        # [9] Construct the final assistant message and update history
+        final_message_to_return = PromptMessageMultipart(
+            role="assistant", content=final_assistant_message
+        )
+
+        if merged_params.use_history:
+            try:
+                # Store the final list of API DICTIONARIES in history
+                self.history.set(current_api_messages)
+                self.logger.debug(
+                    f"Updated self.history with {len(current_api_messages)} API message dicts."
+                )
+            except Exception as e:
+                self.logger.error(f"Failed to update self.history after loop: {e}")
+
+        # [10] Post final assistant message to display
+        display_text = final_message_to_return.all_text()
+        if display_text and display_text != "<no text>":
+            title = f"ASSISTANT/{self._t0_function_name}"
+            await self.show_assistant_message(message_text=display_text, title=title)
+
+        elif not final_assistant_message and last_executed_results:
+            self.logger.debug("Final assistant turn involved only tool calls, no text to display.")
+
+        return final_message_to_return
+
+    def _prepare_t0_system_params(self, merged_params: RequestParams) -> Dict[str, Any]:
+        """Prepares the 'system' dictionary part of the main input."""
+        t0_func_input = merged_params.template_vars.copy()
+
+        metadata_args = None
+        if merged_params.metadata and isinstance(merged_params.metadata, dict):
+            metadata_args = merged_params.metadata.get("tensorzero_arguments")
+        if isinstance(metadata_args, dict):
+            t0_func_input.update(metadata_args)
+            self.logger.debug(f"Merged tensorzero_arguments from metadata: {metadata_args}")
+        return t0_func_input
+
+    async def _prepare_t0_tools(self) -> Optional[List[Dict[str, Any]]]:
+        """Fetches and formats tools for the additional_tools parameter."""
+        formatted_tools: List[Dict[str, Any]] = []
+        try:
+            tools_response = await self.aggregator.list_tools()
+            if tools_response and hasattr(tools_response, "tools") and tools_response.tools:
+                for mcp_tool in tools_response.tools:
+                    if (
+                        not isinstance(mcp_tool.inputSchema, dict)
+                        or mcp_tool.inputSchema.get("type") != "object"
+                    ):
+                        self.logger.warning(
+                            f"Tool '{mcp_tool.name}' has invalid parameters schema. Skipping."
+                        )
+                        continue
+                    t0_tool_dict = {
+                        "name": mcp_tool.name,
+                        "description": mcp_tool.description if mcp_tool.description else "",
+                        "parameters": mcp_tool.inputSchema,
+                    }
+                    formatted_tools.append(t0_tool_dict)
+            return formatted_tools if formatted_tools else None
+        except Exception as e:
+            self.logger.error(f"Failed to fetch or format tools: {e}")
+            return None
+
+    async def _adapt_t0_native_completion(
+        self,
+        completion: Union[ChatInferenceResponse, JsonInferenceResponse],
+        available_tools_for_display: Optional[List[Dict[str, Any]]] = None,
+    ) -> Tuple[
+        List[Union[TextContent, ImageContent, EmbeddedResource]],  # Text/Image content ONLY
+        List[CallToolResult],  # Executed results
+        List[Any],  # Raw tool_call blocks
+    ]:
+        content_parts_this_turn: List[Union[TextContent, ImageContent, EmbeddedResource]] = []
+        executed_tool_results: List[CallToolResult] = []
+        raw_tool_call_blocks_from_t0: List[Any] = []
+
+        if isinstance(completion, ChatInferenceResponse) and hasattr(completion, "content"):
+            for block in completion.content:
+                block_type = getattr(block, "type", "UNKNOWN")
+
+                if block_type == "text":
+                    text_val = getattr(block, "text", None)
+                    if text_val is not None:
+                        content_parts_this_turn.append(TextContent(type="text", text=text_val))
+
+                elif block_type == "tool_call":
+                    raw_tool_call_blocks_from_t0.append(block)
+                    tool_use_id = getattr(block, "id", None)
+                    tool_name = getattr(block, "name", None)
+                    tool_input_raw = getattr(block, "arguments", None)
+                    tool_input = {}
+                    if isinstance(tool_input_raw, dict):
+                        tool_input = tool_input_raw
+                    elif isinstance(tool_input_raw, str):
+                        try:
+                            tool_input = json.loads(tool_input_raw)
+                        except json.JSONDecodeError:
+                            tool_input = {}
+                    elif tool_input_raw is not None:
+                        tool_input = {}
+
+                    if tool_use_id and tool_name:
+                        self.show_tool_call(
+                            available_tools_for_display, tool_name, json.dumps(tool_input)
+                        )
+                        mcp_tool_request = CallToolRequest(
+                            method="tools/call",
+                            params=CallToolRequestParams(name=tool_name, arguments=tool_input),
+                        )
+                        try:
+                            result: CallToolResult = await self.call_tool(
+                                mcp_tool_request, tool_use_id
+                            )
+                            setattr(result, "_t0_tool_use_id_temp", tool_use_id)
+                            setattr(result, "_t0_tool_name_temp", tool_name)
+                            setattr(result, "_t0_is_error_temp", False)
+                            executed_tool_results.append(result)
+                            self.show_oai_tool_result(str(result))
+                        except Exception as e:
+                            self.logger.error(
+                                f"Error executing tool {tool_name} (id: {tool_use_id}): {e}"
+                            )
+                            error_text = f"Error executing tool {tool_name}: {str(e)}"
+                            error_result = CallToolResult(
+                                isError=True, content=[TextContent(type="text", text=error_text)]
+                            )
+                            setattr(error_result, "_t0_tool_use_id_temp", tool_use_id)
+                            setattr(error_result, "_t0_tool_name_temp", tool_name)
+                            setattr(error_result, "_t0_is_error_temp", True)
+                            executed_tool_results.append(error_result)
+                            self.show_oai_tool_result(f"ERROR: {error_text}")
+
+                elif block_type == "thought":
+                    thought_text = getattr(block, "text", None)
+                    self.logger.debug(f"TensorZero thought: {thought_text}")
+                else:
+                    self.logger.warning(
+                        f"TensorZero Adapt: Skipping unknown block type: {block_type}"
+                    )
+
+        elif isinstance(completion, JsonInferenceResponse):
+            # `completion.output.raw` should always be present unless the LLM provider returns unexpected data
+            if completion.output.raw:
+                content_parts_this_turn.append(TextContent(type="text", text=completion.output.raw))
+
+        return content_parts_this_turn, executed_tool_results, raw_tool_call_blocks_from_t0
+
+    async def shutdown(self):
+        """Close the TensorZero gateway client if initialized."""
+        if self._t0_gateway:
+            try:
+                await self._t0_gateway.close()
+                self.logger.debug("TensorZero Gateway client closed.")
+            except Exception as e:
+                self.logger.error(f"Error closing TensorZero Gateway client: {e}")
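End to end, the new provider is selected with a `tensorzero.<function_name>` model string. A hedged sketch using fast-agent's decorator API (the function name and template variables are placeholders, and the "demo" agent name is assumed):

```python
import asyncio

from mcp_agent.core.fastagent import FastAgent
from mcp_agent.core.request_params import RequestParams

fast = FastAgent("tensorzero-demo")

@fast.agent(
    name="demo",
    # Provider prefix + TensorZero function name, parsed by ModelFactory above.
    model="tensorzero.my_chat_function",
    request_params=RequestParams(template_vars={"assistant_name": "Zero"}),
)
async def main() -> None:
    async with fast.run() as agents:
        await agents.demo.send("Hello!")

if __name__ == "__main__":
    asyncio.run(main())
```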