fast-agent-mcp 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,459 @@
+ import json
+ from typing import List
+
+ # Import necessary types and the client from google.genai
+ from google import genai
+ from google.genai import (
+     errors,  # For error handling
+     types,
+ )
+ from mcp.types import (
+     CallToolRequest,
+     CallToolRequestParams,
+     CallToolResult,
+     EmbeddedResource,
+     ImageContent,
+     TextContent,
+ )
+ from rich.text import Text
+
+ from mcp_agent.core.exceptions import ProviderKeyError
+ from mcp_agent.core.prompt import Prompt
+ from mcp_agent.core.request_params import RequestParams
+ from mcp_agent.llm.augmented_llm import AugmentedLLM
+ from mcp_agent.llm.provider_types import Provider
+
+ # Import the converter class
+ from mcp_agent.llm.providers.google_converter import GoogleConverter
+ from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+
+ # Default model and other Google-specific defaults
+ DEFAULT_GOOGLE_MODEL = "gemini-2.0-flash"
+
+ # Suppress this warning for now
+ # TODO: Find out where we're passing null
+ # warnings.filterwarnings(
+ #     "ignore",
+ #     message="null is not a valid Type",
+ #     category=UserWarning,
+ #     module="google.genai._common",
+ # )
+
+
+ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
+     """
+     Google LLM provider using the native google.genai library.
+     """
+
+     async def _apply_prompt_provider_specific_structured(
+         self,
+         multipart_messages,
+         model,
+         request_params=None,
+     ):
+         """
+         Handles structured output for Gemini models using response_schema
+         and response_mime_type.
+         """
+         # Prepare request params
+         request_params = self.get_request_params(request_params)
+         # Build the response schema for Gemini: prefer the JSON schema dict,
+         # falling back to the Pydantic model type itself, since google.genai
+         # accepts either form for response_schema.
+         try:
+             response_schema = model.model_json_schema()
+         except Exception:
+             response_schema = model
+
+         # Set config for structured output
+         generate_content_config = self._converter.convert_request_params_to_google_config(
+             request_params
+         )
+         generate_content_config.response_mime_type = "application/json"
+         generate_content_config.response_schema = response_schema
+
+         # Convert messages to google.genai format
+         conversation_history = self._converter.convert_to_google_content(multipart_messages)
+
+         # Call the Gemini API
+         try:
+             api_response = await self._google_client.aio.models.generate_content(
+                 model=request_params.model,
+                 contents=conversation_history,
+                 config=generate_content_config,
+             )
+         except Exception as e:
+             self.logger.error(f"Error during Gemini structured call: {e}")
+             # Return None and a placeholder assistant message
+             return None, Prompt.assistant(f"Error: {e}")
+
+         # Parse the response as JSON and validate it against the model
+         if not api_response.candidates or not api_response.candidates[0].content.parts:
+             return None, Prompt.assistant("No structured response returned.")
+
+         # Extract the JSON text from the first text part
+         text = None
+         for part in api_response.candidates[0].content.parts:
+             if part.text:
+                 text = part.text
+                 break
+         if text is None:
+             return None, Prompt.assistant("No structured text returned.")
+
+         try:
+             json_data = json.loads(text)
+             validated_model = model.model_validate(json_data)
+             # Update LLM history with the user and assistant messages so
+             # history tracking stays correct
+             for msg in multipart_messages:
+                 self.history.append(msg)
+             assistant_msg = Prompt.assistant(text)
+             self.history.append(assistant_msg)
+             return validated_model, assistant_msg
+         except Exception as e:
+             self.logger.warning(f"Failed to parse structured response: {e}")
+             # Still update history for consistency
+             for msg in multipart_messages:
+                 self.history.append(msg)
+             assistant_msg = Prompt.assistant(text)
+             self.history.append(assistant_msg)
+             return None, assistant_msg
+
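For illustration, a hypothetical usage sketch of the structured-output path above; the `WeatherReport` model, the `demo` function, and the prompt text are assumptions for this sketch, not part of this diff (`Prompt.user` is assumed to mirror the `Prompt.assistant` helper used in the module):

```python
from pydantic import BaseModel

from mcp_agent.core.prompt import Prompt


class WeatherReport(BaseModel):
    city: str
    temperature_c: float


async def demo(llm: "GoogleNativeAugmentedLLM") -> None:
    # Ask for JSON that validates against WeatherReport
    report, assistant_msg = await llm._apply_prompt_provider_specific_structured(
        [Prompt.user("What's the weather in Paris?")],
        WeatherReport,
    )
    if report is not None:
        print(report.city, report.temperature_c)
```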
+     # Google-specific parameter exclusions
+     GOOGLE_EXCLUDE_FIELDS = {
+         # Fields that should not be passed directly from RequestParams to
+         # the google.genai config
+         AugmentedLLM.PARAM_MESSAGES,  # Handled by contents
+         AugmentedLLM.PARAM_MODEL,  # Handled during client/call setup
+         AugmentedLLM.PARAM_SYSTEM_PROMPT,  # Handled by system_instruction in config
+         # AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS,  # Handled by tool_config in config
+         AugmentedLLM.PARAM_USE_HISTORY,  # Handled by the AugmentedLLM base / this class's logic
+         AugmentedLLM.PARAM_MAX_ITERATIONS,  # Handled by this class's loop
+         # Add any other params not applicable to google.genai
+     }.union(AugmentedLLM.BASE_EXCLUDE_FIELDS)
+
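A minimal sketch (assuming `RequestParams` is a Pydantic model and the `PARAM_*` constants are field-name strings, both assumptions) of how an exclude set like `GOOGLE_EXCLUDE_FIELDS` is typically applied when mapping request parameters onto a provider config:

```python
from mcp_agent.core.request_params import RequestParams


def to_google_config_kwargs(request_params: RequestParams) -> dict:
    # Drop fields handled elsewhere (messages, model, system prompt, ...)
    return request_params.model_dump(
        exclude=GoogleNativeAugmentedLLM.GOOGLE_EXCLUDE_FIELDS,
        exclude_none=True,
    )
```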
+     def __init__(self, *args, **kwargs) -> None:
+         super().__init__(*args, provider=Provider.GOOGLE, **kwargs)
+         # Initialize the google.genai client
+         self._google_client = self._initialize_google_client()
+         # Initialize the converter
+         self._converter = GoogleConverter()
+
+     def _initialize_google_client(self) -> genai.Client:
+         """
+         Initializes the google.genai client.
+
+         Reads the Google API key or Vertex AI configuration from the
+         context config.
+         """
+         try:
+             # Prefer Vertex AI when it is enabled in the config; Vertex AI
+             # authenticates via project/location rather than an API key.
+             if (
+                 self.context
+                 and self.context.config
+                 and hasattr(self.context.config, "google")
+                 and hasattr(self.context.config.google, "vertex_ai")
+                 and self.context.config.google.vertex_ai.enabled
+             ):
+                 vertex_config = self.context.config.google.vertex_ai
+                 return genai.Client(
+                     vertexai=True,
+                     project=vertex_config.project_id,
+                     location=vertex_config.location,
+                     # Add other Vertex AI specific options if needed
+                     # http_options=types.HttpOptions(api_version='v1')  # Example for v1 API
+                 )
+
+             # Default to the Gemini Developer API, which requires an API key
+             api_key = self._api_key()
+             if not api_key:
+                 raise ProviderKeyError(
+                     "Google API key not found.", "Please configure your Google API key."
+                 )
+             return genai.Client(
+                 api_key=api_key,
+                 # http_options=types.HttpOptions(api_version='v1')  # Example for v1 API
+             )
+         except ProviderKeyError:
+             raise
+         except Exception as e:
+             # Surface other initialization errors as a ProviderKeyError
+             raise ProviderKeyError("Failed to initialize Google GenAI client.", str(e)) from e
+
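For reference, a minimal sketch of the two `genai.Client` modes this initializer chooses between; the API key, project, and location values are placeholders, and the surrounding fast-agent config plumbing is not shown:

```python
from google import genai

# Gemini Developer API mode (the default branch): authenticates with an API key.
client = genai.Client(api_key="YOUR_API_KEY")  # placeholder key

# Vertex AI mode (the vertex_ai branch): authenticates via GCP project/location.
vertex_client = genai.Client(
    vertexai=True,
    project="my-gcp-project",  # placeholder project id
    location="us-central1",    # placeholder location
)
```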
+     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+         """Initialize Google-specific default parameters."""
+         chosen_model = kwargs.get("model", DEFAULT_GOOGLE_MODEL)
+
+         return RequestParams(
+             model=chosen_model,
+             systemPrompt=self.instruction,  # Mapped to system_instruction in _google_completion
+             parallel_tool_calls=True,  # The native API supports parallel tool calls by default
+             max_iterations=20,
+             use_history=True,
+         )
+
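A short sketch of overriding these defaults per request; every field name shown (`model`, `maxTokens`, `max_iterations`, `use_history`) appears elsewhere in this module, but the specific values are illustrative only:

```python
from mcp_agent.core.request_params import RequestParams

# Per-request override: a single non-conversational call with a tighter budget
params = RequestParams(
    model=DEFAULT_GOOGLE_MODEL,
    maxTokens=2048,
    max_iterations=5,   # fewer tool-use round trips
    use_history=False,  # don't load or persist conversation history
)
```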
+     async def _google_completion(
+         self,
+         request_params: RequestParams | None = None,
+     ) -> List[TextContent | ImageContent | EmbeddedResource]:
+         """
+         Process a query using Google's generate_content API and the available tools.
+         """
+         request_params = self.get_request_params(request_params=request_params)
+         responses: List[TextContent | ImageContent | EmbeddedResource] = []
+
+         # Load the full conversation history if use_history is true
+         if request_params.use_history:
+             # Get history from self.history and convert it to google.genai format
+             conversation_history = self._converter.convert_to_google_content(
+                 self.history.get(include_completion_history=True)
+             )
+         else:
+             # If not using history, start with an empty list
+             conversation_history = []
+
+         self.logger.debug(f"Google completion requested with messages: {conversation_history}")
+         # Log chat progress at the start of the completion
+         self._log_chat_progress(self.chat_turn(), model=request_params.model)
+
+         # Track how many messages were in history before this turn
+         initial_history_length = len(conversation_history)
+
+         for i in range(request_params.max_iterations):
+             # 1. Get the available tools and convert them to google.genai format
+             aggregator_response = await self.aggregator.list_tools()
+             available_tools = self._converter.convert_to_google_tools(
+                 aggregator_response.tools
+             )
+
+             # 2. Prepare the generate_content arguments
+             generate_content_config = self._converter.convert_request_params_to_google_config(
+                 request_params
+             )
+
+             # Add tools and tool_config to the config if tools are available
+             if available_tools:
+                 generate_content_config.tools = available_tools
+                 # AUTO mode lets the model decide when to call tools
+                 generate_content_config.tool_config = types.ToolConfig(
+                     function_calling_config=types.FunctionCallingConfig(mode="AUTO")
+                 )
+
+             # 3. Call the google.genai API using the async client
+             try:
+                 api_response = await self._google_client.aio.models.generate_content(
+                     model=request_params.model,
+                     contents=conversation_history,  # The current turn's conversation history
+                     config=generate_content_config,
+                 )
+                 self.logger.debug("Google generate_content response:", data=api_response)
+             except errors.APIError as e:
+                 # Handle specific Google API errors
+                 self.logger.error(f"Google API Error: {e.code} - {e.message}")
+                 raise ProviderKeyError(f"Google API Error: {e.code}", e.message) from e
+             except Exception as e:
+                 self.logger.error(f"Error during Google generate_content call: {e}")
+                 raise
+
+             # 4. Process the API response
+             if not api_response.candidates:
+                 # No response from the model; we're done
+                 self.logger.debug(f"Iteration {i}: No candidates returned.")
+                 break
+
+             candidate = api_response.candidates[0]  # Process the first candidate
+
+             # Convert the model's response content to fast-agent types
+             model_response_content_parts = self._converter.convert_from_google_content(
+                 candidate.content
+             )
+
+             # Add the model's response to the internal conversation history of
+             # this completion call, so multi-turn tool use within a single
+             # _google_completion call sees the full exchange.
+             conversation_history.append(candidate.content)
+
+             # Extract and process text content and tool calls
+             assistant_message_parts = []
+             tool_calls_to_execute = []
+
+             for part in model_response_content_parts:
+                 if isinstance(part, TextContent):
+                     # Add text content to the final responses to be returned
+                     responses.append(part)
+                     # Collect text for the assistant message display
+                     assistant_message_parts.append(part)
+                 elif isinstance(part, CallToolRequestParams):
+                     # A function call requested by the model
+                     tool_calls_to_execute.append(part)
+
+             # Display the assistant message if there is text content
+             if assistant_message_parts:
+                 # Combine the text parts for display
+                 assistant_text = "".join(
+                     [p.text for p in assistant_message_parts if isinstance(p, TextContent)]
+                 )
+                 # If there are tool calls, indicate that in the display
+                 if tool_calls_to_execute:
+                     tool_names = ", ".join([tc.name for tc in tool_calls_to_execute])
+                     display_text = Text(
+                         f"{assistant_text}\nAssistant requested tool calls: {tool_names}",
+                         style="dim green italic",
+                     )
+                     await self.show_assistant_message(display_text, tool_names)
+                 else:
+                     await self.show_assistant_message(Text(assistant_text))
+
+             # 5. Handle tool calls, if any
+             if tool_calls_to_execute:
+                 tool_results = []
+                 for tool_call_params in tool_calls_to_execute:
+                     # Convert to a CallToolRequest and execute it
+                     tool_call_request = CallToolRequest(
+                         method="tools/call", params=tool_call_params
+                     )
+                     self.show_tool_call(
+                         aggregator_response.tools,  # fast-agent tool definitions for display
+                         tool_call_request.params.name,
+                         str(tool_call_request.params.arguments),  # Stringified for display
+                     )
+
+                     # Execute the tool call; google.genai does not provide a
+                     # tool_call_id, so pass None.
+                     result = await self.call_tool(tool_call_request, None)
+                     # Use show_oai_tool_result for display consistency
+                     self.show_oai_tool_result(str(result.content))
+
+                     tool_results.append((tool_call_params.name, result))
+
+                     # Add the tool result content to the responses to be returned
+                     responses.extend(result.content)
+
+                 # Convert the tool results back to google.genai format and add
+                 # them to conversation_history for the next turn
+                 tool_response_google_contents = self._converter.convert_function_results_to_google(
+                     tool_results
+                 )
+                 conversation_history.extend(tool_response_google_contents)
+
+                 self.logger.debug(f"Iteration {i}: Tool call results processed.")
+             else:
+                 # No tool calls: check the finish reason to decide whether to stop.
+                 # google.genai finish reasons include STOP, MAX_TOKENS, SAFETY,
+                 # RECITATION, and OTHER.
+                 if candidate.finish_reason in ["STOP", "MAX_TOKENS", "SAFETY"]:
+                     self.logger.debug(
+                         f"Iteration {i}: Stopping because finish_reason is '{candidate.finish_reason}'"
+                     )
+                     # Display a message if stopping due to max tokens
+                     if (
+                         candidate.finish_reason == "MAX_TOKENS"
+                         and request_params
+                         and request_params.maxTokens is not None
+                     ):
+                         message_text = Text(
+                             f"the assistant has reached the maximum token limit ({request_params.maxTokens})",
+                             style="dim green italic",
+                         )
+                         await self.show_assistant_message(message_text)
+                     break  # Exit the loop when a stopping condition is met
+                 # No tool calls and no explicit stop reason: break to avoid an
+                 # infinite loop if the model neither stops nor calls tools.
+                 self.logger.debug(
+                     f"Iteration {i}: No tool calls and no explicit stop reason, breaking."
+                 )
+                 break
+
+         # 6. Update history after all iterations are done (or max_iterations is
+         # reached), adding only the messages generated during this turn.
+         if request_params.use_history:
+             new_google_messages = conversation_history[initial_history_length:]
+             new_multipart_messages = self._converter.convert_from_google_content_list(
+                 new_google_messages
+             )
+             self.history.extend(new_multipart_messages)
+
+         self._log_chat_finished(model=request_params.model)
+         # Return the accumulated responses (fast-agent content types)
+         return responses
+
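The converter hides the wire format, but for orientation this is roughly the google.genai shape a tool-result turn takes; the tool name and payload are hypothetical, and the exact output of `convert_function_results_to_google` is an assumption here, while `types.Part.from_function_response` is the documented google.genai helper:

```python
from google.genai import types

# A tool-result turn in google.genai terms
part = types.Part.from_function_response(
    name="fetch_weather",                  # hypothetical tool name
    response={"result": "22C and sunny"},  # tool output as a dict
)
tool_turn = types.Content(role="user", parts=[part])
# tool_turn is then appended to the conversation history for the next iteration
```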
+     async def _apply_prompt_provider_specific(
+         self,
+         multipart_messages: List[PromptMessageMultipart],
+         request_params: RequestParams | None = None,
+         is_template: bool = False,
+     ) -> PromptMessageMultipart:
+         """
+         Applies the prompt messages and calls the LLM for a completion if needed.
+         """
+         request_params = self.get_request_params(request_params=request_params)
+
+         # Add the incoming messages to history before calling completion, so
+         # the current user message is part of the history for the API call
+         self.history.extend(multipart_messages, is_prompt=is_template)
+
+         last_message_role = multipart_messages[-1].role if multipart_messages else None
+
+         if last_message_role == "user":
+             # The last message is from the user, so call the LLM for a response;
+             # _google_completion loads history internally and updates it afterwards
+             responses = await self._google_completion(request_params=request_params)
+
+             # Return the combined responses as an assistant message
+             return Prompt.assistant(*responses)
+         else:
+             # The last message is not from the user (e.g., assistant), so no
+             # completion is needed; the messages have already been added to
+             # history above. Return the last message as-is.
+             return multipart_messages[-1]
+
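A hypothetical end-to-end sketch of this method; `ask` is illustrative only, and `Prompt.user` is assumed to mirror the `Prompt.assistant` helper used above:

```python
from mcp_agent.core.prompt import Prompt


async def ask(llm: "GoogleNativeAugmentedLLM") -> None:
    reply = await llm._apply_prompt_provider_specific(
        [Prompt.user("Summarise the latest tool results.")]
    )
    # reply is a PromptMessageMultipart assistant message
    print(reply.role, reply.content)
```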
+     async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
+         """
+         Hook called before a tool call.
+
+         Args:
+             tool_call_id: The ID of the tool call.
+             request: The CallToolRequest object.
+
+         Returns:
+             The (possibly modified) CallToolRequest object.
+         """
+         # Currently a pass-through; add Google-specific logic here if needed
+         return request
+
+     async def post_tool_call(
+         self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
+     ):
+         """
+         Hook called after a tool call.
+
+         Args:
+             tool_call_id: The ID of the tool call.
+             request: The original CallToolRequest object.
+             result: The CallToolResult object.
+
+         Returns:
+             The (possibly modified) CallToolResult object.
+         """
+         # Currently a pass-through; add Google-specific logic here if needed
+         return result
@@ -6,9 +6,9 @@ GOOGLE_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
  DEFAULT_GOOGLE_MODEL = "gemini-2.0-flash"
 
 
- class GoogleAugmentedLLM(OpenAIAugmentedLLM):
+ class GoogleOaiAugmentedLLM(OpenAIAugmentedLLM):
      def __init__(self, *args, **kwargs) -> None:
-         super().__init__(*args, provider=Provider.GOOGLE, **kwargs)
+         super().__init__(*args, provider=Provider.GOOGLE_OAI, **kwargs)
 
      def _initialize_default_params(self, kwargs: dict) -> RequestParams:
          """Initialize Google OpenAI Compatibility default parameters"""