massgen-0.1.0a3-py3-none-any.whl → massgen-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of massgen might be problematic.
Files changed (120)
  1. massgen/__init__.py +1 -1
  2. massgen/agent_config.py +17 -0
  3. massgen/api_params_handler/_api_params_handler_base.py +1 -0
  4. massgen/api_params_handler/_chat_completions_api_params_handler.py +15 -2
  5. massgen/api_params_handler/_claude_api_params_handler.py +8 -1
  6. massgen/api_params_handler/_gemini_api_params_handler.py +73 -0
  7. massgen/api_params_handler/_response_api_params_handler.py +8 -1
  8. massgen/backend/base.py +83 -0
  9. massgen/backend/{base_with_mcp.py → base_with_custom_tool_and_mcp.py} +286 -15
  10. massgen/backend/capabilities.py +6 -6
  11. massgen/backend/chat_completions.py +200 -103
  12. massgen/backend/claude.py +115 -18
  13. massgen/backend/claude_code.py +378 -14
  14. massgen/backend/docs/CLAUDE_API_RESEARCH.md +3 -3
  15. massgen/backend/gemini.py +1333 -1629
  16. massgen/backend/gemini_mcp_manager.py +545 -0
  17. massgen/backend/gemini_trackers.py +344 -0
  18. massgen/backend/gemini_utils.py +43 -0
  19. massgen/backend/grok.py +39 -6
  20. massgen/backend/response.py +147 -81
  21. massgen/cli.py +605 -110
  22. massgen/config_builder.py +376 -27
  23. massgen/configs/README.md +123 -80
  24. massgen/configs/basic/multi/three_agents_default.yaml +3 -3
  25. massgen/configs/basic/single/single_agent.yaml +1 -1
  26. massgen/configs/providers/openai/gpt5_nano.yaml +3 -3
  27. massgen/configs/tools/custom_tools/claude_code_custom_tool_example.yaml +32 -0
  28. massgen/configs/tools/custom_tools/claude_code_custom_tool_example_no_path.yaml +28 -0
  29. massgen/configs/tools/custom_tools/claude_code_custom_tool_with_mcp_example.yaml +40 -0
  30. massgen/configs/tools/custom_tools/claude_code_custom_tool_with_wrong_mcp_example.yaml +38 -0
  31. massgen/configs/tools/custom_tools/claude_code_wrong_custom_tool_with_mcp_example.yaml +38 -0
  32. massgen/configs/tools/custom_tools/claude_custom_tool_example.yaml +24 -0
  33. massgen/configs/tools/custom_tools/claude_custom_tool_example_no_path.yaml +22 -0
  34. massgen/configs/tools/custom_tools/claude_custom_tool_with_mcp_example.yaml +35 -0
  35. massgen/configs/tools/custom_tools/claude_custom_tool_with_wrong_mcp_example.yaml +33 -0
  36. massgen/configs/tools/custom_tools/claude_wrong_custom_tool_with_mcp_example.yaml +33 -0
  37. massgen/configs/tools/custom_tools/gemini_custom_tool_example.yaml +24 -0
  38. massgen/configs/tools/custom_tools/gemini_custom_tool_example_no_path.yaml +22 -0
  39. massgen/configs/tools/custom_tools/gemini_custom_tool_with_mcp_example.yaml +35 -0
  40. massgen/configs/tools/custom_tools/gemini_custom_tool_with_wrong_mcp_example.yaml +33 -0
  41. massgen/configs/tools/custom_tools/gemini_wrong_custom_tool_with_mcp_example.yaml +33 -0
  42. massgen/configs/tools/custom_tools/github_issue_market_analysis.yaml +94 -0
  43. massgen/configs/tools/custom_tools/gpt5_nano_custom_tool_example.yaml +24 -0
  44. massgen/configs/tools/custom_tools/gpt5_nano_custom_tool_example_no_path.yaml +22 -0
  45. massgen/configs/tools/custom_tools/gpt5_nano_custom_tool_with_mcp_example.yaml +35 -0
  46. massgen/configs/tools/custom_tools/gpt5_nano_custom_tool_with_wrong_mcp_example.yaml +33 -0
  47. massgen/configs/tools/custom_tools/gpt5_nano_wrong_custom_tool_with_mcp_example.yaml +33 -0
  48. massgen/configs/tools/custom_tools/gpt_oss_custom_tool_example.yaml +25 -0
  49. massgen/configs/tools/custom_tools/gpt_oss_custom_tool_example_no_path.yaml +23 -0
  50. massgen/configs/tools/custom_tools/gpt_oss_custom_tool_with_mcp_example.yaml +34 -0
  51. massgen/configs/tools/custom_tools/gpt_oss_custom_tool_with_wrong_mcp_example.yaml +34 -0
  52. massgen/configs/tools/custom_tools/gpt_oss_wrong_custom_tool_with_mcp_example.yaml +34 -0
  53. massgen/configs/tools/custom_tools/grok3_mini_custom_tool_example.yaml +24 -0
  54. massgen/configs/tools/custom_tools/grok3_mini_custom_tool_example_no_path.yaml +22 -0
  55. massgen/configs/tools/custom_tools/grok3_mini_custom_tool_with_mcp_example.yaml +35 -0
  56. massgen/configs/tools/custom_tools/grok3_mini_custom_tool_with_wrong_mcp_example.yaml +33 -0
  57. massgen/configs/tools/custom_tools/grok3_mini_wrong_custom_tool_with_mcp_example.yaml +33 -0
  58. massgen/configs/tools/custom_tools/qwen_api_custom_tool_example.yaml +25 -0
  59. massgen/configs/tools/custom_tools/qwen_api_custom_tool_example_no_path.yaml +23 -0
  60. massgen/configs/tools/custom_tools/qwen_api_custom_tool_with_mcp_example.yaml +36 -0
  61. massgen/configs/tools/custom_tools/qwen_api_custom_tool_with_wrong_mcp_example.yaml +34 -0
  62. massgen/configs/tools/custom_tools/qwen_api_wrong_custom_tool_with_mcp_example.yaml +34 -0
  63. massgen/configs/tools/custom_tools/qwen_local_custom_tool_example.yaml +24 -0
  64. massgen/configs/tools/custom_tools/qwen_local_custom_tool_example_no_path.yaml +22 -0
  65. massgen/configs/tools/custom_tools/qwen_local_custom_tool_with_mcp_example.yaml +35 -0
  66. massgen/configs/tools/custom_tools/qwen_local_custom_tool_with_wrong_mcp_example.yaml +33 -0
  67. massgen/configs/tools/custom_tools/qwen_local_wrong_custom_tool_with_mcp_example.yaml +33 -0
  68. massgen/configs/tools/filesystem/claude_code_context_sharing.yaml +1 -1
  69. massgen/configs/tools/planning/five_agents_discord_mcp_planning_mode.yaml +7 -29
  70. massgen/configs/tools/planning/five_agents_filesystem_mcp_planning_mode.yaml +5 -6
  71. massgen/configs/tools/planning/five_agents_notion_mcp_planning_mode.yaml +4 -4
  72. massgen/configs/tools/planning/five_agents_twitter_mcp_planning_mode.yaml +4 -4
  73. massgen/configs/tools/planning/gpt5_mini_case_study_mcp_planning_mode.yaml +2 -2
  74. massgen/configs/voting/gemini_gpt_voting_sensitivity.yaml +67 -0
  75. massgen/formatter/_chat_completions_formatter.py +104 -0
  76. massgen/formatter/_claude_formatter.py +120 -0
  77. massgen/formatter/_gemini_formatter.py +448 -0
  78. massgen/formatter/_response_formatter.py +88 -0
  79. massgen/frontend/coordination_ui.py +4 -2
  80. massgen/logger_config.py +35 -3
  81. massgen/message_templates.py +56 -6
  82. massgen/orchestrator.py +512 -16
  83. massgen/stream_chunk/base.py +3 -0
  84. massgen/tests/custom_tools_example.py +392 -0
  85. massgen/tests/mcp_test_server.py +17 -7
  86. massgen/tests/test_config_builder.py +423 -0
  87. massgen/tests/test_custom_tools.py +401 -0
  88. massgen/tests/test_intelligent_planning_mode.py +643 -0
  89. massgen/tests/test_tools.py +127 -0
  90. massgen/token_manager/token_manager.py +13 -4
  91. massgen/tool/README.md +935 -0
  92. massgen/tool/__init__.py +39 -0
  93. massgen/tool/_async_helpers.py +70 -0
  94. massgen/tool/_basic/__init__.py +8 -0
  95. massgen/tool/_basic/_two_num_tool.py +24 -0
  96. massgen/tool/_code_executors/__init__.py +10 -0
  97. massgen/tool/_code_executors/_python_executor.py +74 -0
  98. massgen/tool/_code_executors/_shell_executor.py +61 -0
  99. massgen/tool/_exceptions.py +39 -0
  100. massgen/tool/_file_handlers/__init__.py +10 -0
  101. massgen/tool/_file_handlers/_file_operations.py +218 -0
  102. massgen/tool/_manager.py +634 -0
  103. massgen/tool/_registered_tool.py +88 -0
  104. massgen/tool/_result.py +66 -0
  105. massgen/tool/_self_evolution/_github_issue_analyzer.py +369 -0
  106. massgen/tool/docs/builtin_tools.md +681 -0
  107. massgen/tool/docs/exceptions.md +794 -0
  108. massgen/tool/docs/execution_results.md +691 -0
  109. massgen/tool/docs/manager.md +887 -0
  110. massgen/tool/docs/workflow_toolkits.md +529 -0
  111. massgen/tool/workflow_toolkits/__init__.py +57 -0
  112. massgen/tool/workflow_toolkits/base.py +55 -0
  113. massgen/tool/workflow_toolkits/new_answer.py +126 -0
  114. massgen/tool/workflow_toolkits/vote.py +167 -0
  115. {massgen-0.1.0a3.dist-info → massgen-0.1.2.dist-info}/METADATA +87 -129
  116. {massgen-0.1.0a3.dist-info → massgen-0.1.2.dist-info}/RECORD +120 -44
  117. {massgen-0.1.0a3.dist-info → massgen-0.1.2.dist-info}/WHEEL +0 -0
  118. {massgen-0.1.0a3.dist-info → massgen-0.1.2.dist-info}/entry_points.txt +0 -0
  119. {massgen-0.1.0a3.dist-info → massgen-0.1.2.dist-info}/licenses/LICENSE +0 -0
  120. {massgen-0.1.0a3.dist-info → massgen-0.1.2.dist-info}/top_level.txt +0 -0
massgen/backend/gemini_trackers.py ADDED
@@ -0,0 +1,344 @@
+ # -*- coding: utf-8 -*-
+ """
+ MCP tracking utilities for the Gemini backend, handling deduplication across streaming chunks and extraction from SDK objects.
+ """
+
+ import hashlib
+ import json
+ import time
+ from typing import Any, Dict, Optional
+
+
+ class MCPResponseTracker:
+     """
+     Tracks MCP tool responses across streaming chunks to handle deduplication.
+
+     Similar to MCPCallTracker but for tracking tool responses to avoid duplicate output.
+     """
+
+     def __init__(self):
+         """Initialize the tracker with empty storage."""
+         self.processed_responses = set()  # Store hashes of processed responses
+         self.response_history = []  # Store all unique responses with timestamps
+
+     def get_response_hash(self, tool_name: str, tool_response: Any) -> str:
+         """
+         Generate a unique hash for a tool response based on name and response content.
+
+         Args:
+             tool_name: Name of the tool that responded
+             tool_response: Response from the tool
+
+         Returns:
+             MD5 hash string identifying this specific response
+         """
+         # Create a deterministic string representation
+         content = f"{tool_name}:{str(tool_response)}"
+         return hashlib.md5(content.encode()).hexdigest()
+
+     def is_new_response(self, tool_name: str, tool_response: Any) -> bool:
+         """
+         Check if this is a new tool response we haven't seen before.
+
+         Args:
+             tool_name: Name of the tool that responded
+             tool_response: Response from the tool
+
+         Returns:
+             True if this is a new response, False if already processed
+         """
+         response_hash = self.get_response_hash(tool_name, tool_response)
+         return response_hash not in self.processed_responses
+
+     def add_response(self, tool_name: str, tool_response: Any) -> Dict[str, Any]:
+         """
+         Add a new response to the tracker.
+
+         Args:
+             tool_name: Name of the tool that responded
+             tool_response: Response from the tool
+
+         Returns:
+             Dictionary containing response details and timestamp
+         """
+         response_hash = self.get_response_hash(tool_name, tool_response)
+         self.processed_responses.add(response_hash)
+
+         record = {
+             "tool_name": tool_name,
+             "response": tool_response,
+             "hash": response_hash,
+             "timestamp": time.time(),
+         }
+         self.response_history.append(record)
+         return record
+
+
+ class MCPCallTracker:
+     """
+     Tracks MCP tool calls across streaming chunks to handle deduplication.
+
+     Uses hashing to identify unique tool calls and timestamps to track when they occurred.
+     This ensures we don't double-count the same tool call appearing in multiple chunks.
+     """
+
+     def __init__(self):
+         """Initialize the tracker with empty storage."""
+         self.processed_calls = set()  # Store hashes of processed calls
+         self.call_history = []  # Store all unique calls with timestamps
+         self.last_chunk_calls = []  # Track calls from the last chunk for deduplication
+         self.dedup_window = 0.5  # Time window in seconds for deduplication
+
+     def get_call_hash(self, tool_name: str, tool_args: Dict[str, Any]) -> str:
+         """
+         Generate a unique hash for a tool call based on name and arguments.
+
+         Args:
+             tool_name: Name of the tool being called
+             tool_args: Arguments passed to the tool
+
+         Returns:
+             MD5 hash string identifying this specific call
+         """
+         # Create a deterministic string representation
+         content = f"{tool_name}:{json.dumps(tool_args, sort_keys=True)}"
+         return hashlib.md5(content.encode()).hexdigest()
+
+     def is_new_call(self, tool_name: str, tool_args: Dict[str, Any]) -> bool:
+         """
+         Check if this is a new tool call we haven't seen before.
+
+         Uses a time-window based approach: identical calls within the dedup_window
+         are considered duplicates (likely from streaming chunks), while those outside
+         the window are considered new calls (likely intentional repeated calls).
+
+         Args:
+             tool_name: Name of the tool being called
+             tool_args: Arguments passed to the tool
+
+         Returns:
+             True if this is a new call, False if we've seen it before
+         """
+         call_hash = self.get_call_hash(tool_name, tool_args)
+         current_time = time.time()
+
+         # Check if this call exists in recent history within the dedup window
+         for call in self.call_history[-10:]:  # Check last 10 calls for efficiency
+             if call.get("hash") == call_hash:
+                 time_diff = current_time - call.get("timestamp", 0)
+                 if time_diff < self.dedup_window:
+                     # This is likely a duplicate from streaming chunks
+                     return False
+                 # If outside the window, treat as a new intentional call
+
+         # Mark as processed
+         self.processed_calls.add(call_hash)
+         return True
+
+     def add_call(self, tool_name: str, tool_args: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Add a new tool call to the history.
+
+         Args:
+             tool_name: Name of the tool being called
+             tool_args: Arguments passed to the tool
+
+         Returns:
+             Dictionary containing the call details with timestamp and hash
+         """
+         call_record = {
+             "name": tool_name,
+             "arguments": tool_args,
+             "timestamp": time.time(),
+             "hash": self.get_call_hash(tool_name, tool_args),
+             "sequence": len(self.call_history),  # Add sequence number for ordering
+         }
+         self.call_history.append(call_record)
+
+         # Clean up old history to prevent memory growth
+         if len(self.call_history) > 100:
+             self.call_history = self.call_history[-50:]
+
+         return call_record
+
+     def get_summary(self) -> str:
+         """
+         Get a summary of all tracked tool calls.
+
+         Returns:
+             Human-readable summary of tool usage
+         """
+         if not self.call_history:
+             return "No MCP tools called"
+
+         tool_names = [call["name"] for call in self.call_history]
+         unique_tools = list(dict.fromkeys(tool_names))  # Preserve order
+         return f"Used {len(self.call_history)} MCP tool calls: {', '.join(unique_tools)}"
+
+
+ class MCPResponseExtractor:
+     """
+     Extracts MCP tool calls and responses from Gemini SDK stream chunks.
+
+     This class parses the internal SDK chunks to capture:
+     - function_call parts (tool invocations)
+     - function_response parts (tool results)
+     - Paired call-response data for tracking complete tool executions
+     """
+
+     def __init__(self):
+         """Initialize the extractor with empty storage."""
+         self.mcp_calls = []  # All tool calls
+         self.mcp_responses = []  # All tool responses
+         self.call_response_pairs = []  # Matched call-response pairs
+         self._pending_call = None  # Track current call awaiting response
+
+     def extract_function_call(self, function_call) -> Optional[Dict[str, Any]]:
+         """
+         Extract tool call information from SDK function_call object.
+
+         Tries multiple methods to extract data from different SDK versions:
+         1. Direct attributes (name, args)
+         2. Dictionary-like interface (get method)
+         3. __dict__ attributes
+         4. Protobuf _pb attributes
+         """
+         tool_name = None
+         tool_args = None
+
+         # Method 1: Direct attributes
+         tool_name = getattr(function_call, "name", None)
+         tool_args = getattr(function_call, "args", None)
+
+         # Method 2: Dictionary-like object
+         if tool_name is None:
+             try:
+                 if hasattr(function_call, "get"):
+                     tool_name = function_call.get("name", None)
+                     tool_args = function_call.get("args", None)
+             except Exception:
+                 pass
+
+         # Method 3: __dict__ inspection
+         if tool_name is None:
+             try:
+                 if hasattr(function_call, "__dict__"):
+                     fc_dict = function_call.__dict__
+                     tool_name = fc_dict.get("name", None)
+                     tool_args = fc_dict.get("args", None)
+             except Exception:
+                 pass
+
+         # Method 4: Protobuf _pb attribute
+         if tool_name is None:
+             try:
+                 if hasattr(function_call, "_pb"):
+                     pb = function_call._pb
+                     if hasattr(pb, "name"):
+                         tool_name = pb.name
+                     if hasattr(pb, "args"):
+                         tool_args = pb.args
+             except Exception:
+                 pass
+
+         if tool_name:
+             call_data = {
+                 "name": tool_name,
+                 "arguments": tool_args or {},
+                 "timestamp": time.time(),
+                 "raw": str(function_call)[:200],  # Truncate for logging
+             }
+             self.mcp_calls.append(call_data)
+             self._pending_call = call_data
+             return call_data
+
+         return None
+
+     def extract_function_response(self, function_response) -> Optional[Dict[str, Any]]:
+         """
+         Extract tool response information from SDK function_response object.
+
+         Uses same extraction methods as function_call for consistency.
+         """
+         tool_name = None
+         tool_response = None
+
+         # Method 1: Direct attributes
+         tool_name = getattr(function_response, "name", None)
+         tool_response = getattr(function_response, "response", None)
+
+         # Method 2: Dictionary-like object
+         if tool_name is None:
+             try:
+                 if hasattr(function_response, "get"):
+                     tool_name = function_response.get("name", None)
+                     tool_response = function_response.get("response", None)
+             except Exception:
+                 pass
+
+         # Method 3: __dict__ inspection
+         if tool_name is None:
+             try:
+                 if hasattr(function_response, "__dict__"):
+                     fr_dict = function_response.__dict__
+                     tool_name = fr_dict.get("name", None)
+                     tool_response = fr_dict.get("response", None)
+             except Exception:
+                 pass
+
+         # Method 4: Protobuf _pb attribute
+         if tool_name is None:
+             try:
+                 if hasattr(function_response, "_pb"):
+                     pb = function_response._pb
+                     if hasattr(pb, "name"):
+                         tool_name = pb.name
+                     if hasattr(pb, "response"):
+                         tool_response = pb.response
+             except Exception:
+                 pass
+
+         if tool_name:
+             response_data = {
+                 "name": tool_name,
+                 "response": tool_response or {},
+                 "timestamp": time.time(),
+                 "raw": str(function_response)[:500],  # Truncate for logging
+             }
+             self.mcp_responses.append(response_data)
+
+             # Pair with pending call if names match
+             if self._pending_call and self._pending_call["name"] == tool_name:
+                 self.call_response_pairs.append(
+                     {
+                         "call": self._pending_call,
+                         "response": response_data,
+                         "duration": response_data["timestamp"] - self._pending_call["timestamp"],
+                         "paired_at": time.time(),
+                     },
+                 )
+                 self._pending_call = None
+
+             return response_data
+
+         return None
+
+     def get_summary(self) -> Dict[str, Any]:
+         """
+         Get a summary of all extracted MCP tool interactions.
+         """
+         return {
+             "total_calls": len(self.mcp_calls),
+             "total_responses": len(self.mcp_responses),
+             "paired_interactions": len(self.call_response_pairs),
+             "pending_call": self._pending_call is not None,
+             "tool_names": list(set(call["name"] for call in self.mcp_calls)),
+             "average_duration": (sum(pair["duration"] for pair in self.call_response_pairs) / len(self.call_response_pairs) if self.call_response_pairs else 0),
+         }
+
+     def clear(self):
+         """Clear all stored data."""
+         self.mcp_calls.clear()
+         self.mcp_responses.clear()
+         self.call_response_pairs.clear()
+         self._pending_call = None
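The tracker classes above are plain Python helpers, so their deduplication behaviour is easy to exercise in isolation. The sketch below is illustrative only, with the import path assumed from the file list rather than taken from the diff: a recorded call is treated as a duplicate when the same tool name and arguments reappear within the 0.5-second dedup_window, while responses are deduplicated purely by content hash.

# Illustrative usage sketch; module path and example tool names are assumptions.
from massgen.backend.gemini_trackers import MCPCallTracker, MCPResponseTracker

calls = MCPCallTracker()
responses = MCPResponseTracker()

# Record a tool call the first time it appears in the stream.
if calls.is_new_call("search_docs", {"query": "mcp"}):
    calls.add_call("search_docs", {"query": "mcp"})

# The same call arriving in the next chunk, well inside dedup_window, is a duplicate.
assert calls.is_new_call("search_docs", {"query": "mcp"}) is False

# Responses are deduplicated by an MD5 of the tool name plus response content.
if responses.is_new_response("search_docs", {"hits": 3}):
    responses.add_response("search_docs", {"hits": 3})
assert responses.is_new_response("search_docs", {"hits": 3}) is False

print(calls.get_summary())  # e.g. "Used 1 MCP tool calls: search_docs"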
massgen/backend/gemini_utils.py ADDED
@@ -0,0 +1,43 @@
+ # -*- coding: utf-8 -*-
+ """
+ Gemini-specific structured output models for coordination actions (voting and answer submission).
+ """
+
+ import enum
+ from typing import Optional
+
+ try:
+     from pydantic import BaseModel, Field
+ except ImportError:
+     BaseModel = None
+     Field = None
+
+
+ class ActionType(enum.Enum):
+     """Action types for structured output."""
+
+     VOTE = "vote"
+     NEW_ANSWER = "new_answer"
+
+
+ class VoteAction(BaseModel):
+     """Structured output for voting action."""
+
+     action: ActionType = Field(default=ActionType.VOTE, description="Action type")
+     agent_id: str = Field(description="Anonymous agent ID to vote for (e.g., 'agent1', 'agent2')")
+     reason: str = Field(description="Brief reason why this agent has the best answer")
+
+
+ class NewAnswerAction(BaseModel):
+     """Structured output for new answer action."""
+
+     action: ActionType = Field(default=ActionType.NEW_ANSWER, description="Action type")
+     content: str = Field(description="Your improved answer. If any builtin tools like search or code execution were used, include how they are used here.")
+
+
+ class CoordinationResponse(BaseModel):
+     """Structured response for coordination actions."""
+
+     action_type: ActionType = Field(description="Type of action to take")
+     vote_data: Optional[VoteAction] = Field(default=None, description="Vote data if action is vote")
+     answer_data: Optional[NewAnswerAction] = Field(default=None, description="Answer data if action is new_answer")
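Because these are ordinary pydantic models when pydantic is importable, a coordination payload returned by Gemini's structured output can be validated directly. A minimal sketch, assuming pydantic is installed and the module is importable as massgen.backend.gemini_utils (the path is inferred from the file list, not shown in the diff):

# Minimal validation sketch; import path and payload values are assumptions.
from massgen.backend.gemini_utils import ActionType, CoordinationResponse

payload = {
    "action_type": "vote",
    "vote_data": {"agent_id": "agent2", "reason": "Most complete and best-sourced answer"},
}

decision = CoordinationResponse(**payload)  # constructor validation works on pydantic v1 and v2
assert decision.action_type is ActionType.VOTE
assert decision.vote_data is not None and decision.vote_data.agent_id == "agent2"
assert decision.answer_data is None  # only populated when action_type is "new_answer"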
massgen/backend/grok.py CHANGED
@@ -18,6 +18,7 @@ TODO for future releases:
  # -*- coding: utf-8 -*-
  from __future__ import annotations
 
+ import logging
  import os
  from typing import Any, Dict, List, Optional
 
@@ -26,6 +27,8 @@ from openai import AsyncOpenAI
  from ..logger_config import log_stream_chunk
  from .chat_completions import ChatCompletionsBackend
 
+ logger = logging.getLogger(__name__)
+
 
  class GrokBackend(ChatCompletionsBackend):
      """Grok backend using xAI's OpenAI-compatible API."""
@@ -41,12 +44,10 @@ class GrokBackend(ChatCompletionsBackend):
 
          return openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
-     def _build_base_api_params(self, messages: List[Dict[str, Any]], all_params: Dict[str, Any]) -> Dict[str, Any]:
-         """Build base API params for xAI's Grok API."""
-         api_params = super()._build_base_api_params(messages, all_params)
-
-         # Add Live Search parameters if enabled (Grok-specific)
+     def _add_grok_search_params(self, api_params: Dict[str, Any], all_params: Dict[str, Any]) -> Dict[str, Any]:
+         """Add Grok Live Search parameters to API params if web search is enabled."""
          enable_web_search = all_params.get("enable_web_search", False)
+
          if enable_web_search:
              # Check for conflict with manually specified search_parameters
              existing_extra = api_params.get("extra_body", {})
@@ -54,14 +55,46 @@ class GrokBackend(ChatCompletionsBackend):
                  error_message = "Conflict: Cannot use both 'enable_web_search: true' and manual 'extra_body.search_parameters'. Use one or the other."
                  log_stream_chunk("backend.grok", "error", error_message, self.agent_id)
                  raise ValueError(error_message)
+
              # Merge search_parameters into existing extra_body
              search_params = {"mode": "auto", "return_citations": True}
-             merged_extra = existing_extra.copy()
+             merged_extra = existing_extra.copy() if existing_extra else {}
              merged_extra["search_parameters"] = search_params
              api_params["extra_body"] = merged_extra
 
          return api_params
 
+     async def _stream_with_custom_and_mcp_tools(
+         self,
+         current_messages: List[Dict[str, Any]],
+         tools: List[Dict[str, Any]],
+         client,
+         **kwargs,
+     ):
+         """Override to add Grok-specific search parameters before API call."""
+         # Build API params using parent method
+         all_params = {**self.config, **kwargs}
+         api_params = await self.api_params_handler.build_api_params(current_messages, tools, all_params)
+
+         # Add provider tools (web search, code interpreter) if enabled
+         # Note: For Grok, get_provider_tools() won't add web_search function tool
+         provider_tools = self.api_params_handler.get_provider_tools(all_params)
+
+         if provider_tools:
+             if "tools" not in api_params:
+                 api_params["tools"] = []
+             api_params["tools"].extend(provider_tools)
+
+         # Add Grok-specific web search parameters via extra_body
+         api_params = self._add_grok_search_params(api_params, all_params)
+
+         # Start streaming
+         stream = await client.chat.completions.create(**api_params)
+
+         # Delegate to parent's stream processing
+         async for chunk in super()._process_stream(stream, all_params, self.agent_id):
+             yield chunk
+
      def get_provider_name(self) -> str:
          """Get the name of this provider."""
          return "Grok"