mbxai 0.6.10__tar.gz → 0.6.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mbxai-0.6.10 → mbxai-0.6.11}/PKG-INFO +1 -1
- {mbxai-0.6.10 → mbxai-0.6.11}/pyproject.toml +1 -1
- {mbxai-0.6.10 → mbxai-0.6.11}/setup.py +1 -1
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/__init__.py +1 -1
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/mcp/server.py +1 -1
- mbxai-0.6.11/src/mbxai/tools/client.py +425 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/uv.lock +7 -7
- mbxai-0.6.10/src/mbxai/tools/client.py +0 -305
- {mbxai-0.6.10 → mbxai-0.6.11}/.gitignore +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/LICENSE +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/README.md +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/core.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/mcp/__init__.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/mcp/client.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/mcp/example.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/openrouter/__init__.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/openrouter/client.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/openrouter/config.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/openrouter/models.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/tools/__init__.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/tools/example.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/src/mbxai/tools/types.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/tests/test_core.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/tests/test_mcp.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/tests/test_openrouter.py +0 -0
- {mbxai-0.6.10 → mbxai-0.6.11}/tests/test_tools.py +0 -0
@@ -0,0 +1,425 @@
|
|
1
|
+
"""
|
2
|
+
ToolClient implementation for MBX AI.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Any, Callable, TypeVar, cast
|
6
|
+
import logging
|
7
|
+
import inspect
|
8
|
+
import json
|
9
|
+
from pydantic import BaseModel
|
10
|
+
from ..openrouter import OpenRouterClient
|
11
|
+
from .types import Tool, ToolCall
|
12
|
+
import asyncio
|
13
|
+
|
14
|
+
# Module-level logger used by every ToolClient method in this file.
logger = logging.getLogger(__name__)

# Type variable for the Pydantic model class passed to ToolClient.parse
# (its ``response_format`` parameter is annotated ``type[T]``).
T = TypeVar("T", bound=BaseModel)
|
17
|
+
|
18
|
+
class ToolClient:
    """Client for handling tool calls with OpenRouter.

    All registered tools are advertised to the model on every request. When
    the model answers with tool calls, the calls are executed locally, each
    result is appended to the conversation as a ``role="tool"`` message, and
    the model is queried again. The loop ends once the model answers without
    tool calls.

    ``chat`` and ``parse`` append the assistant and tool messages to the
    ``messages`` list passed in, so the caller's list holds the whole exchange
    afterwards.
    """

    def __init__(self, openrouter_client: OpenRouterClient) -> None:
        """Initialize the ToolClient.

        Args:
            openrouter_client: The OpenRouter client to use
        """
        self._client = openrouter_client
        self._tools: dict[str, Tool] = {}

    def register_tool(
        self,
        name: str,
        description: str,
        function: Callable[..., Any],
        schema: dict[str, Any],
    ) -> None:
        """Register a new tool.

        Registering a tool under a name that is already taken replaces the
        previously registered tool.

        Args:
            name: The name of the tool
            description: A description of what the tool does
            function: The function to call when the tool is used; either a
                plain function or one returning an awaitable (e.g. ``async def``)
            schema: The JSON schema for the tool's parameters
        """
        tool = Tool(
            name=name,
            description=description,
            function=function,
            schema=schema,
        )
        self._tools[name] = tool
        logger.info(f"Registered tool: {name}")

    # ------------------------------------------------------------------
    # Logging / validation helpers
    # ------------------------------------------------------------------

    def _truncate_content(self, content: Any, max_length: int = 100) -> str:
        """Truncate content for logging.

        Falsy content (``None``, ``""``, ``[]``) is rendered as ``"None"``.
        Non-string content (e.g. a list of content parts) is converted with
        ``str()`` first so slicing and the ``"..."`` suffix always work.
        """
        if not content:
            return "None"
        text = content if isinstance(content, str) else str(content)
        if len(text) <= max_length:
            return text
        return text[:max_length] + "..."

    def _truncate_dict(self, data: dict[str, Any], max_length: int = 50) -> str:
        """Truncate dictionary values for logging.

        Strings and nested dicts are truncated recursively. Other values are
        kept as-is when their ``str()`` is short enough, otherwise replaced by
        a truncated string.
        """
        if not data:
            return "{}"
        truncated: dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, str):
                truncated[key] = self._truncate_content(value, max_length)
            elif isinstance(value, dict):
                truncated[key] = self._truncate_dict(value, max_length)
            else:
                rendered = str(value)
                truncated[key] = rendered[:max_length] + "..." if len(rendered) > max_length else value
        return str(truncated)

    def _validate_message_sequence(self, messages: list[dict[str, Any]], validate_responses: bool = True) -> None:
        """Validate the message sequence for tool calls and responses.

        Collects the ids of all assistant ``tool_calls`` and of all ``tool``
        responses. When ``validate_responses`` is true, every tool call id must
        have a matching response somewhere in ``messages`` (order is not
        checked).

        Args:
            messages: The message sequence to validate
            validate_responses: Whether to validate that all tool calls have responses

        Raises:
            ValueError: if ``validate_responses`` is true and a tool call has
                no response; the full sequence is logged first.
        """
        tool_call_ids: set[Any] = set()
        tool_response_ids: set[Any] = set()

        for i, msg in enumerate(messages):
            role = msg.get("role")
            # .get() so an explicit ``"tool_calls": None`` is treated as "no calls".
            if role == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    tool_call_ids.add(tc["id"])
                    logger.info(f"Found tool call {tc['id']} for {tc['function']['name']} in message {i}")
            elif role == "tool":
                tool_response_ids.add(msg["tool_call_id"])
                logger.info(f"Found tool response for call ID {msg['tool_call_id']} in message {i}")

        if not validate_responses:
            return

        missing_responses = tool_call_ids - tool_response_ids
        if not missing_responses:
            return

        logger.error(f"Missing tool responses for call IDs: {missing_responses}")
        logger.error("Message sequence:")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            if role == "assistant" and msg.get("tool_calls"):
                logger.error(f" Message {i} - Assistant with tool calls: {[tc['id'] for tc in msg['tool_calls']]}")
            elif role == "tool":
                logger.error(f" Message {i} - Tool response for call ID: {msg['tool_call_id']}")
            else:
                logger.error(f" Message {i} - {role}: {self._truncate_content(msg.get('content'))}")
        raise ValueError(f"Invalid message sequence: missing responses for tool calls {missing_responses}")

    def _log_messages(self, messages: list[dict[str, Any]], validate_responses: bool = True) -> None:
        """Log the messages being sent to OpenRouter, then validate the sequence.

        Args:
            messages: The messages to log
            validate_responses: Whether to validate that all tool calls have responses

        Raises:
            ValueError: propagated from ``_validate_message_sequence``.
        """
        logger.info("Sending messages to OpenRouter:")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = self._truncate_content(msg.get("content"))
            tool_calls = msg.get("tool_calls", [])
            tool_call_id = msg.get("tool_call_id")

            if tool_calls:
                tool_call_info = [
                    f"{tc['function']['name']}(id={tc['id']})"
                    for tc in tool_calls
                ]
                logger.info(f" Message {i} - {role}: content='{content}', tool_calls={tool_call_info}")
            elif tool_call_id:
                logger.info(f" Message {i} - {role}: content='{content}', tool_call_id={tool_call_id}")
            else:
                logger.info(f" Message {i} - {role}: content='{content}'")

        self._validate_message_sequence(messages, validate_responses)

    # ------------------------------------------------------------------
    # Tool-call handling (shared by chat and parse)
    # ------------------------------------------------------------------

    def _advertise_tools(self, kwargs: dict[str, Any]) -> None:
        """Add the registered tools to the request ``kwargs`` in place.

        Does nothing when no tools are registered. ``tool_choice`` is always
        ``"auto"``: forcing a tool call would make the tool loop repeat forever.
        """
        tools = [tool.to_openai_function() for tool in self._tools.values()]
        if tools:
            logger.info(f"Available tools: {[tool['function']['name'] for tool in tools]}")
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

    def _build_assistant_message(self, message: Any) -> dict[str, Any]:
        """Convert the model's reply into a chat-history message dict.

        Includes a ``tool_calls`` entry only when the reply requested tools.
        """
        assistant_message: dict[str, Any] = {
            "role": "assistant",
            "content": message.content or None,  # Ensure content is None if empty
        }
        if message.tool_calls:
            assistant_message["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
                for tool_call in message.tool_calls
            ]
        return assistant_message

    def _parse_tool_arguments(self, raw_arguments: Any) -> dict[str, Any]:
        """Decode a tool call's arguments into keyword arguments.

        ``None``, a blank string, or JSON ``null`` mean "no arguments". Models
        commonly send ``""`` for parameterless tools.

        Raises:
            ValueError: if the arguments are not valid JSON or do not decode
                to a JSON object.
        """
        if raw_arguments is None:
            return {}
        arguments = raw_arguments
        if isinstance(arguments, str):
            if not arguments.strip():
                return {}
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse tool arguments: {e}")
                raise ValueError(f"Invalid tool arguments format: {raw_arguments}") from e
        if arguments is None:
            return {}
        if not isinstance(arguments, dict):
            # ``**arguments`` needs a mapping; give a clear error instead of a TypeError.
            raise ValueError(f"Invalid tool arguments format: {raw_arguments}")
        return arguments

    def _serialize_tool_result(self, result: Any) -> str:
        """Render a tool result as message content.

        Strings are sent unchanged. Everything else is JSON-encoded, falling
        back to ``str()`` (with a warning) when the value is not
        JSON-serializable, instead of aborting the conversation.
        """
        if isinstance(result, str):
            return result
        try:
            return json.dumps(result)
        except (TypeError, ValueError) as e:
            logger.warning(f"Tool result is not JSON serializable ({e}); sending its str() instead")
            return str(result)

    async def _invoke_tool(self, tool: Tool, arguments: dict[str, Any]) -> Any:
        """Run one tool and return its raw result, or an ``{"error": ...}`` payload.

        Awaitable results (e.g. from ``async def`` tools) are awaited with a
        5-minute timeout. Failures and timeouts are reported back as error
        payloads so the model can react rather than the conversation aborting.
        """
        logger.info(f"Calling tool: {tool.name} with args: {self._truncate_dict(arguments)}")
        try:
            result = tool.function(**arguments)
            if inspect.isawaitable(result):
                result = await asyncio.wait_for(result, timeout=300.0)  # 5 minutes timeout
        except asyncio.TimeoutError:
            logger.error(f"Tool {tool.name} timed out after 5 minutes")
            return {"error": "Tool execution timed out after 5 minutes"}
        except Exception as e:
            logger.error(f"Error calling tool {tool.name}: {str(e)}")
            return {"error": f"Tool execution failed: {str(e)}"}
        logger.info(f"Tool {tool.name} completed successfully")
        return result

    async def _process_tool_calls(self, message: Any, messages: list[dict[str, Any]]) -> None:
        """Process all tool calls in a message.

        Every call is resolved (tool lookup and argument decoding) before any
        tool runs, so a bad call aborts before side-effecting tools execute.
        The tool responses are then appended to ``messages`` in call order.

        Args:
            message: The message containing tool calls
            messages: The list of messages to add responses to

        Raises:
            ValueError: for an unknown tool, malformed arguments, or an
                inconsistent message sequence.
        """
        if not message.tool_calls:
            return

        resolved: list[tuple[Any, Tool, dict[str, Any]]] = []
        for tool_call in message.tool_calls:
            tool = self._tools.get(tool_call.function.name)
            if tool is None:
                raise ValueError(f"Unknown tool: {tool_call.function.name}")
            resolved.append((tool_call, tool, self._parse_tool_arguments(tool_call.function.arguments)))

        tool_responses = []
        for tool_call, tool, arguments in resolved:
            result = await self._invoke_tool(tool, arguments)
            tool_responses.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": self._serialize_tool_result(result),
                }
            )
            logger.info(f"Created tool response for call ID {tool_call.id}")

        messages.extend(tool_responses)
        logger.info(f"Message count: {len(messages)}, Added {len(tool_responses)} tool responses to messages")

        self._validate_message_sequence(messages, validate_responses=True)
        self._log_messages(messages, validate_responses=False)

    async def _run_tool_loop(
        self,
        messages: list[dict[str, Any]],
        request: Callable[[], Any],
        *,
        stream: bool,
        log_before_request: bool,
    ) -> Any:
        """Query the model repeatedly, executing tool calls, until it stops calling tools.

        NOTE: there is no cap on the number of rounds; a model that keeps
        requesting tools keeps this loop running.

        Args:
            messages: Conversation history, appended to in place.
            request: Zero-argument callable issuing one completion request
                against the current ``messages``.
            stream: When true the first response is returned untouched (tool
                calls are not handled for streams).
            log_before_request: Log and validate ``messages`` before each request.

        Returns:
            The first response whose message contains no tool calls.
        """
        while True:
            if log_before_request:
                self._log_messages(messages)

            response = request()
            # NOTE(review): the client was awaited in 0.6.10 but not in 0.6.11;
            # accept both a sync and an async client.
            if inspect.isawaitable(response):
                response = await response

            if stream:
                return response

            message = response.choices[0].message
            messages.append(self._build_assistant_message(message))
            logger.info(f"Message count: {len(messages)}, Added assistant message with tool calls: {[tc.function.name for tc in message.tool_calls] if message.tool_calls else None}")

            if not message.tool_calls:
                return response

            await self._process_tool_calls(message, messages)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(
        self,
        messages: list[dict[str, Any]],
        *,
        model: str | None = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> Any:
        """Chat with the model, handling tool calls.

        Args:
            messages: The conversation messages; assistant and tool messages
                are appended in place.
            model: Optional model override
            stream: Whether to stream the response (tool calls are then not handled)
            **kwargs: Additional parameters for the chat completion

        Returns:
            The final completion response (the one without tool calls), or the
            raw stream when ``stream`` is true.

        Raises:
            ValueError: for an unknown tool or malformed tool arguments.
        """
        self._advertise_tools(kwargs)
        return await self._run_tool_loop(
            messages,
            lambda: self._client.chat_completion(
                messages=messages,
                model=model,
                stream=stream,
                **kwargs,
            ),
            stream=stream,
            log_before_request=False,
        )

    async def parse(
        self,
        messages: list[dict[str, Any]],
        response_format: type[T],
        *,
        model: str | None = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> Any:
        """Chat with the model and parse the response into a Pydantic model.

        Tool calls are handled exactly as in ``chat``. In addition, the message
        sequence is logged and validated before every request.

        Args:
            messages: The conversation messages
            response_format: The Pydantic model to parse the response into
            model: Optional model override
            stream: Whether to stream the response
            **kwargs: Additional parameters for the chat completion

        Returns:
            The parsed response from the model
        """
        self._advertise_tools(kwargs)
        return await self._run_tool_loop(
            messages,
            lambda: self._client.chat_completion_parse(
                messages=messages,
                response_format=response_format,
                model=model,
                stream=stream,
                **kwargs,
            ),
            stream=stream,
            log_before_request=True,
        )
|
@@ -292,11 +292,11 @@ wheels = [
|
|
292
292
|
|
293
293
|
[[package]]
|
294
294
|
name = "httpx-sse"
|
295
|
-
version = "0.6.
|
295
|
+
version = "0.6.11"
|
296
296
|
source = { registry = "https://pypi.org/simple" }
|
297
|
-
sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.6.
|
297
|
+
sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.6.11.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 }
|
298
298
|
wheels = [
|
299
|
-
{ url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.6.
|
299
|
+
{ url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.6.11-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 },
|
300
300
|
]
|
301
301
|
|
302
302
|
[[package]]
|
@@ -446,7 +446,7 @@ wheels = [
|
|
446
446
|
|
447
447
|
[[package]]
|
448
448
|
name = "mbxai"
|
449
|
-
version = "0.6.
|
449
|
+
version = "0.6.11"
|
450
450
|
source = { editable = "." }
|
451
451
|
dependencies = [
|
452
452
|
{ name = "fastapi" },
|
@@ -980,14 +980,14 @@ wheels = [
|
|
980
980
|
|
981
981
|
[[package]]
|
982
982
|
name = "typing-inspection"
|
983
|
-
version = "0.6.
|
983
|
+
version = "0.6.11"
|
984
984
|
source = { registry = "https://pypi.org/simple" }
|
985
985
|
dependencies = [
|
986
986
|
{ name = "typing-extensions" },
|
987
987
|
]
|
988
|
-
sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.6.
|
988
|
+
sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.6.11.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 }
|
989
989
|
wheels = [
|
990
|
-
{ url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.6.
|
990
|
+
{ url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.6.11-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 },
|
991
991
|
]
|
992
992
|
|
993
993
|
[[package]]
|
@@ -1,305 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
ToolClient implementation for MBX AI.
|
3
|
-
"""
|
4
|
-
|
5
|
-
from typing import Any, Callable, TypeVar, cast
|
6
|
-
import logging
|
7
|
-
import inspect
|
8
|
-
import json
|
9
|
-
from pydantic import BaseModel
|
10
|
-
from ..openrouter import OpenRouterClient
|
11
|
-
from .types import Tool, ToolCall
|
12
|
-
import asyncio
|
13
|
-
|
14
|
-
# Module-level logger used by the ToolClient methods in this file.
logger = logging.getLogger(__name__)

# Type variable for the Pydantic model class passed to ToolClient.parse
# (its ``response_format`` parameter is annotated ``type[T]``).
T = TypeVar("T", bound=BaseModel)
|
17
|
-
|
18
|
-
class ToolClient:
    """Base class for tool clients.

    Keeps a registry of ``Tool`` objects and an OpenRouter client. ``chat``
    handles at most one round of legacy OpenAI ``function_call`` output.
    ``parse`` runs ``chat`` first and then requests a structured response.
    """

    def __init__(self, openrouter_client: OpenRouterClient):
        """Initialize the tool client.

        Args:
            openrouter_client: Client used for every completion request.
        """
        self._openrouter_client = openrouter_client
        self._tools: dict[str, Tool] = {}

    def register_tool(self, tool: Tool) -> None:
        """Register a tool under ``tool.name``, replacing any tool already registered under that name."""
        self._tools[tool.name] = tool

    @staticmethod
    async def _resolve(value: Any) -> Any:
        """Return ``value``, awaiting it first when it is awaitable.

        NOTE(review): it is not visible here whether the OpenRouter client's
        methods are sync or async; this keeps the calls below working with either.
        """
        if inspect.isawaitable(value):
            return await value
        return value

    @staticmethod
    def _decode_arguments(raw_arguments: Any) -> dict[str, Any]:
        """Decode function/tool-call arguments into keyword arguments.

        ``None``, a blank string, or JSON ``null`` mean "no arguments".
        Anything else must be a JSON object (or an already-decoded ``dict``).

        Raises:
            ValueError: if the arguments are not valid JSON or not an object.
        """
        if raw_arguments is None:
            return {}
        arguments = raw_arguments
        if isinstance(arguments, str):
            if not arguments.strip():
                return {}
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse tool arguments: {e}")
                raise ValueError(f"Invalid tool arguments format: {raw_arguments}") from e
        if arguments is None:
            return {}
        if not isinstance(arguments, dict):
            raise ValueError(f"Invalid tool arguments format: {raw_arguments}")
        return arguments

    async def invoke_tool(self, tool_name: str, **kwargs: Any) -> Any:
        """Invoke a tool by name.

        Works for both plain and ``async`` tool functions: the return value is
        awaited only when it is awaitable (awaiting a plain value raises TypeError).

        Raises:
            ValueError: if no tool is registered under ``tool_name`` or the
                tool has no function.
        """
        tool = self._tools.get(tool_name)
        if not tool:
            raise ValueError(f"Tool {tool_name} not found")

        if not tool.function:
            raise ValueError(f"Tool {tool_name} has no function implementation")

        return await self._resolve(tool.function(**kwargs))

    async def chat(self, messages: list[dict[str, Any]], model: str | None = None, **kwargs: Any) -> Any:
        """Process a chat request with tools.

        Sends ``messages`` with the registered tools advertised as legacy
        ``functions``. If the model answers with a ``function_call``, the tool
        is invoked once. The call and its result are appended to ``messages``,
        and a follow-up request (without ``functions``) produces the returned
        response. Only one function-call round is handled.

        Args:
            messages: Conversation so far; appended to in place when a tool runs.
            model: Optional model override.
            **kwargs: Extra parameters forwarded to both completion requests.

        Raises:
            ValueError: on an empty or malformed response, an unknown tool, or
                undecodable function-call arguments.
        """
        # Convert tools to OpenAI function format
        functions = [tool.to_openai_function() for tool in self._tools.values()]

        # Only advertise functions when some are registered.
        first_request_kwargs = dict(kwargs)
        if functions:
            first_request_kwargs["functions"] = functions

        # Make the chat request
        response = await self._resolve(
            self._openrouter_client.chat_completion(
                messages=messages,
                model=model,
                **first_request_kwargs,
            )
        )

        # Validate response
        if not response:
            raise ValueError("No response received from OpenRouter")

        if not response.choices:
            raise ValueError("Response missing choices")

        choice = response.choices[0]
        if not choice:
            raise ValueError("Empty choice in response")

        message = choice.message
        if not message:
            raise ValueError("Choice missing message")

        function_call = getattr(message, "function_call", None)
        if not function_call:
            return response

        # The model asked for a function: run it and record call + result.
        tool_name = function_call.name
        tool_args = self._decode_arguments(function_call.arguments)
        tool_response = await self.invoke_tool(tool_name, **tool_args)

        messages.append({
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": tool_name,
                "arguments": function_call.arguments,
            },
        })
        messages.append({
            "role": "function",
            "name": tool_name,
            # Strings are sent as-is (matching _process_tool_calls), not JSON-quoted.
            "content": tool_response if isinstance(tool_response, str) else json.dumps(tool_response),
        })

        # Get final response
        final_response = await self._resolve(
            self._openrouter_client.chat_completion(
                messages=messages,
                model=model,
                **kwargs,
            )
        )

        if not final_response or not final_response.choices:
            raise ValueError("No response received after tool execution")

        return final_response

    def _truncate_content(self, content: Any, max_length: int = 100) -> str:
        """Truncate content for logging; non-string content is ``str()``-ed first."""
        if not content:
            return "None"
        text = content if isinstance(content, str) else str(content)
        if len(text) <= max_length:
            return text
        return text[:max_length] + "..."

    def _truncate_dict(self, data: dict[str, Any], max_length: int = 50) -> str:
        """Truncate dictionary values for logging (strings and nested dicts recursively)."""
        if not data:
            return "{}"
        truncated: dict[str, Any] = {}
        for k, v in data.items():
            if isinstance(v, str):
                truncated[k] = self._truncate_content(v, max_length)
            elif isinstance(v, dict):
                truncated[k] = self._truncate_dict(v, max_length)
            else:
                rendered = str(v)
                truncated[k] = rendered[:max_length] + "..." if len(rendered) > max_length else v
        return str(truncated)

    def _validate_message_sequence(self, messages: list[dict[str, Any]], validate_responses: bool = True) -> None:
        """Validate the message sequence for tool calls and responses.

        Args:
            messages: The message sequence to validate
            validate_responses: Whether to validate that all tool calls have responses

        Raises:
            ValueError: if ``validate_responses`` is true and some tool call id
                has no ``tool`` response anywhere in ``messages``.
        """
        tool_call_ids: set[Any] = set()
        tool_response_ids: set[Any] = set()

        for i, msg in enumerate(messages):
            role = msg.get("role")
            # .get() so an explicit ``"tool_calls": None`` counts as "no calls".
            if role == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    tool_call_ids.add(tc["id"])
                    logger.info(f"Found tool call {tc['id']} for {tc['function']['name']} in message {i}")
            elif role == "tool":
                tool_response_ids.add(msg["tool_call_id"])
                logger.info(f"Found tool response for call ID {msg['tool_call_id']} in message {i}")

        # Only validate responses if requested
        if validate_responses:
            missing_responses = tool_call_ids - tool_response_ids
            if missing_responses:
                logger.error(f"Missing tool responses for call IDs: {missing_responses}")
                logger.error("Message sequence:")
                for i, msg in enumerate(messages):
                    role = msg.get("role", "unknown")
                    if role == "assistant" and msg.get("tool_calls"):
                        logger.error(f" Message {i} - Assistant with tool calls: {[tc['id'] for tc in msg['tool_calls']]}")
                    elif role == "tool":
                        logger.error(f" Message {i} - Tool response for call ID: {msg['tool_call_id']}")
                    else:
                        logger.error(f" Message {i} - {role}: {self._truncate_content(msg.get('content'))}")
                raise ValueError(f"Invalid message sequence: missing responses for tool calls {missing_responses}")

    def _log_messages(self, messages: list[dict[str, Any]], validate_responses: bool = True) -> None:
        """Log the messages being sent to OpenRouter, then validate the sequence.

        Args:
            messages: The messages to log
            validate_responses: Whether to validate that all tool calls have responses
        """
        logger.info("Sending messages to OpenRouter:")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = self._truncate_content(msg.get("content"))
            tool_calls = msg.get("tool_calls", [])
            tool_call_id = msg.get("tool_call_id")

            if tool_calls:
                tool_call_info = [
                    f"{tc['function']['name']}(id={tc['id']})"
                    for tc in tool_calls
                ]
                logger.info(f" Message {i} - {role}: content='{content}', tool_calls={tool_call_info}")
            elif tool_call_id:
                logger.info(f" Message {i} - {role}: content='{content}', tool_call_id={tool_call_id}")
            else:
                logger.info(f" Message {i} - {role}: content='{content}'")

        # Validate message sequence
        self._validate_message_sequence(messages, validate_responses)

    async def _process_tool_calls(self, message: Any, messages: list[dict[str, Any]]) -> None:
        """Process all tool calls in a message.

        Every call is resolved (tool lookup and argument decoding) before any
        tool runs, so a bad call aborts before side-effecting tools execute.

        Args:
            message: The message containing tool calls
            messages: The list of messages to add responses to

        Raises:
            ValueError: for an unknown tool or malformed arguments.
        """
        if not getattr(message, "tool_calls", None):
            return

        resolved = []
        for tool_call in message.tool_calls:
            tool = self._tools.get(tool_call.function.name)
            if tool is None:
                raise ValueError(f"Unknown tool: {tool_call.function.name}")
            resolved.append((tool_call, tool, self._decode_arguments(tool_call.function.arguments)))

        tool_responses = []
        for tool_call, tool, arguments in resolved:
            logger.info(f"Calling tool: {tool.name} with args: {self._truncate_dict(arguments)}")
            result = await self._resolve(tool.function(**arguments))

            # Convert result to JSON string if it's not already
            if not isinstance(result, str):
                result = json.dumps(result)

            tool_responses.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result,
            })
            logger.info(f"Created tool response for call ID {tool_call.id}")

        # Add all tool responses to the messages
        messages.extend(tool_responses)
        logger.info(f"Message count: {len(messages)}, Added {len(tool_responses)} tool responses to messages")

        # Validate the message sequence
        self._validate_message_sequence(messages, validate_responses=True)

        # Log the messages we're about to send
        self._log_messages(messages, validate_responses=False)

    async def parse(
        self,
        messages: list[dict[str, Any]],
        response_format: type[T],
        *,
        model: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Chat with the model and parse the response into a Pydantic model.

        Runs ``chat`` first (handling one function-call round), then issues a
        separate ``chat_completion_parse`` request over the same messages.

        Args:
            messages: The conversation messages
            response_format: The Pydantic model to parse the response into
            model: Optional model override
            **kwargs: Additional parameters for the chat completion

        Returns:
            The parsed response from the model

        Raises:
            ValueError: if the intermediate response is empty, still contains
                tool calls, has no content, or parsing fails.
        """
        # First use our own chat function to handle any tool calls
        response = await self.chat(
            messages=messages,
            model=model,
            **kwargs,
        )

        if not response or not response.choices:
            raise ValueError("No response received from OpenRouter")

        choice = response.choices[0]
        if not choice:
            raise ValueError("Empty choice in response")

        message = choice.message
        if not message:
            raise ValueError("Choice missing message")

        # If we still have tool calls, something went wrong
        if getattr(message, "tool_calls", None):
            raise ValueError("Unexpected tool calls in final response")

        # Ensure we have content to parse
        if not message.content:
            raise ValueError("No content in final response to parse")

        # Now that we have the final response, parse it into the desired format
        try:
            return await self._resolve(
                self._openrouter_client.chat_completion_parse(
                    messages=messages,
                    response_format=response_format,
                    model=model,
                    **kwargs,
                )
            )
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            raise ValueError(f"Failed to parse response as {response_format.__name__}: {str(e)}") from e
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|