todo-agent 0.3.2-py3-none-any.whl → 0.3.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- todo_agent/_version.py +2 -2
- todo_agent/core/exceptions.py +6 -6
- todo_agent/core/todo_manager.py +315 -182
- todo_agent/infrastructure/inference.py +120 -52
- todo_agent/infrastructure/llm_client.py +56 -22
- todo_agent/infrastructure/ollama_client.py +23 -13
- todo_agent/infrastructure/openrouter_client.py +20 -12
- todo_agent/infrastructure/prompts/system_prompt.txt +190 -438
- todo_agent/infrastructure/todo_shell.py +94 -11
- todo_agent/interface/cli.py +51 -33
- todo_agent/interface/formatters.py +7 -4
- todo_agent/interface/progress.py +30 -19
- todo_agent/interface/tools.py +73 -30
- todo_agent/main.py +17 -1
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.5.dist-info}/METADATA +1 -1
- todo_agent-0.3.5.dist-info/RECORD +30 -0
- todo_agent-0.3.2.dist-info/RECORD +0 -30
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.5.dist-info}/WHEEL +0 -0
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.5.dist-info}/entry_points.txt +0 -0
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.5.dist-info}/top_level.txt +0 -0
todo_agent/infrastructure/inference.py

@@ -12,8 +12,11 @@ try:
     from todo_agent.infrastructure.config import Config
     from todo_agent.infrastructure.llm_client_factory import LLMClientFactory
     from todo_agent.infrastructure.logger import Logger
+    from todo_agent.interface.formatters import (
+        get_provider_error_message as _get_error_msg,
+    )
+    from todo_agent.interface.progress import NoOpProgress, ToolCallProgress
     from todo_agent.interface.tools import ToolCallHandler
-    from todo_agent.interface.progress import ToolCallProgress, NoOpProgress
 except ImportError:
     from core.conversation_manager import (  # type: ignore[no-redef]
         ConversationManager,

@@ -24,8 +27,14 @@ except ImportError:
         LLMClientFactory,
     )
     from infrastructure.logger import Logger  # type: ignore[no-redef]
+    from interface.formatters import (  # type: ignore[no-redef]
+        get_provider_error_message as _get_error_msg,
+    )
+    from interface.progress import (  # type: ignore[no-redef]
+        NoOpProgress,
+        ToolCallProgress,
+    )
     from interface.tools import ToolCallHandler  # type: ignore[no-redef]
-    from interface.progress import ToolCallProgress, NoOpProgress  # type: ignore[no-redef]
 
 
 class Inference:

@@ -68,6 +77,27 @@ class Inference:
         self.conversation_manager.set_system_prompt(system_prompt)
         self.logger.debug("System prompt loaded and set")
 
+    def current_tasks(self) -> str:
+        """
+        Get current tasks from the todo manager.
+
+        Returns:
+            Formatted string of current tasks or error message
+        """
+        try:
+            # Use the todo manager from the tool handler to get current tasks
+            tasks = self.tool_handler.todo_manager.list_tasks(suppress_color=True)
+
+            # If no tasks found, return a clear message
+            if not tasks.strip() or tasks == "No tasks found.":
+                return "No current tasks found."
+
+            return tasks
+
+        except Exception as e:
+            self.logger.warning(f"Failed to get current tasks: {e!s}")
+            return f"Error retrieving current tasks: {e!s}"
+
     def _load_system_prompt(self) -> str:
         """Load and format the system prompt from file."""
         # Get current datetime for interpolation

@@ -84,6 +114,9 @@ class Inference:
             self.logger.warning(f"Failed to get calendar output: {e!s}")
             calendar_output = "Calendar unavailable"
 
+        # Get current tasks
+        current_tasks = self.current_tasks()
+
         # Load system prompt from file
         prompt_file_path = os.path.join(
             os.path.dirname(__file__), "prompts", "system_prompt.txt"

@@ -93,10 +126,11 @@
             with open(prompt_file_path, encoding="utf-8") as f:
                 system_prompt_template = f.read()
 
-            # Format the template with current datetime and
+            # Format the template with current datetime, calendar, and current tasks
             return system_prompt_template.format(
                 current_datetime=current_datetime,
                 calendar_output=calendar_output,
+                current_tasks=current_tasks,
             )
 
         except FileNotFoundError:
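The 0.3.5 system prompt gains a {current_tasks} placeholder alongside {current_datetime} and {calendar_output}, filled via str.format from the new current_tasks() helper. A minimal sketch of that interpolation contract (the template text here is illustrative, not the shipped system_prompt.txt):

    # Stand-in for prompts/system_prompt.txt (hypothetical content)
    template = (
        "Today is {current_datetime}.\n"
        "Calendar:\n{calendar_output}\n"
        "Current tasks:\n{current_tasks}\n"
    )

    # Mirrors _load_system_prompt(): every placeholder must be supplied,
    # otherwise str.format raises KeyError.
    print(template.format(
        current_datetime="2025-01-15 09:00",
        calendar_output="Calendar unavailable",
        current_tasks="1 Write release notes @work",
    ))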
@@ -106,9 +140,9 @@
             self.logger.error(f"Error loading system prompt: {e!s}")
             raise
 
-
-
-
+    def process_request(
+        self, user_input: str, progress_callback: Optional[ToolCallProgress] = None
+    ) -> tuple[str, float]:
         """
         Process a user request through the LLM with tool orchestration.
 

@@ -121,11 +155,11 @@
         """
         # Start timing the request
         start_time = time.time()
-
+
         # Initialize progress callback if not provided
         if progress_callback is None:
             progress_callback = NoOpProgress()
-
+
         # Notify progress callback that thinking has started
         progress_callback.on_thinking_start()
 

@@ -155,23 +189,19 @@
             error_type = response.get("error_type", "general_error")
             provider = response.get("provider", "unknown")
             self.logger.error(f"Provider error from {provider}: {error_type}")
-
-
-
-                from todo_agent.interface.formatters import get_provider_error_message
-                error_message = get_provider_error_message(error_type)
-            except ImportError:
-                from interface.formatters import get_provider_error_message
-                error_message = get_provider_error_message(error_type)
-
+
+            error_message = _get_error_msg(error_type)
+
             # Add error message to conversation
-            self.conversation_manager.add_message(
-
+            self.conversation_manager.add_message(
+                MessageRole.ASSISTANT, error_message
+            )
+
             # Calculate thinking time and return
             end_time = time.time()
             thinking_time = end_time - start_time
             progress_callback.on_thinking_complete(thinking_time)
-
+
             return error_message, thinking_time
 
         # Extract actual token usage from API response

@@ -188,8 +218,7 @@
 
         # Handle multiple tool calls in sequence
         tool_call_count = 0
-
-
+
         while True:
             tool_calls = self.llm_client.extract_tool_calls(response)
 

@@ -200,9 +229,11 @@
             self.logger.debug(
                 f"Executing tool call sequence #{tool_call_count} with {len(tool_calls)} tools"
            )
-
+
             # Notify progress callback of sequence start
-            progress_callback.on_sequence_complete(
+            progress_callback.on_sequence_complete(
+                tool_call_count, 0
+            )  # We don't know total yet
 
             # Execute all tool calls and collect results
             tool_results = []

@@ -217,13 +248,18 @@
                 self.logger.debug(f"Raw tool call: {tool_call}")
 
                 # Get progress description for the tool
-                progress_description = self._get_tool_progress_description(
-
+                progress_description = self._get_tool_progress_description(
+                    tool_name, tool_call
+                )
+
                 # Notify progress callback of tool call start
                 progress_callback.on_tool_call_start(
-                    tool_name,
+                    tool_name,
+                    progress_description,
+                    tool_call_count,
+                    0,  # We don't know total yet
                 )
-
+
                 result = self.tool_handler.execute_tool(tool_call)
 
                 # Log tool execution result (success or error)

@@ -253,24 +289,22 @@
             if response.get("error", False):
                 error_type = response.get("error_type", "general_error")
                 provider = response.get("provider", "unknown")
-                self.logger.error(
-
-
-
-
-
-                except ImportError:
-                    from interface.formatters import get_provider_error_message
-                    error_message = get_provider_error_message(error_type)
-
+                self.logger.error(
+                    f"Provider error in continuation from {provider}: {error_type}"
+                )
+
+                error_message = _get_error_msg(error_type)
+
                 # Add error message to conversation
-                self.conversation_manager.add_message(
-
+                self.conversation_manager.add_message(
+                    MessageRole.ASSISTANT, error_message
+                )
+
                 # Calculate thinking time and return
                 end_time = time.time()
                 thinking_time = end_time - start_time
                 progress_callback.on_thinking_complete(thinking_time)
-
+
                 return error_message, thinking_time
 
             # Update with actual tokens from subsequent API calls
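The progress callbacks threaded through process_request form a small observer protocol. A sketch of a print-only implementation, using just the calls and argument orders visible in these hunks (parameter names are assumed; the real interface is ToolCallProgress in todo_agent/interface/progress.py):

    class PrintProgress:
        def on_thinking_start(self) -> None:
            print("thinking...")

        def on_sequence_complete(self, sequence: int, total: int) -> None:
            # total is passed as 0 when the number of sequences is not yet known
            print(f"sequence #{sequence} started")

        def on_tool_call_start(
            self, tool_name: str, description: str, sequence: int, total: int
        ) -> None:
            print(f"  {description or tool_name}")

        def on_thinking_complete(self, thinking_time: float) -> None:
            print(f"done in {thinking_time:.1f}s")

    # process_request(user_input, progress_callback=PrintProgress()) would then
    # narrate each tool-call sequence; passing None falls back to NoOpProgress.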
@@ -284,7 +318,7 @@
         # Calculate and log total thinking time
         end_time = time.time()
         thinking_time = end_time - start_time
-
+
         # Notify progress callback that thinking is complete
         progress_callback.on_thinking_complete(thinking_time)
 

@@ -321,22 +355,56 @@
             self.tool_handler.tools
         )
 
-    def _get_tool_progress_description(
+    def _get_tool_progress_description(
+        self, tool_name: str, tool_call: Dict[str, Any]
+    ) -> str:
         """
-        Get user-friendly progress description for a tool.
-
+        Get user-friendly progress description for a tool with parameter interpolation.
+
         Args:
             tool_name: Name of the tool
-
+            tool_call: The tool call dictionary containing parameters
+
         Returns:
-            Progress description string
+            Progress description string with interpolated parameters
         """
-        tool_def = next(
-
-
+        tool_def = next(
+            (
+                t
+                for t in self.tool_handler.tools
+                if t.get("function", {}).get("name") == tool_name
+            ),
+            None,
+        )
+
         if tool_def and "progress_description" in tool_def:
-
-
+            template = tool_def["progress_description"]
+
+            # Extract arguments from tool call
+            arguments = tool_call.get("function", {}).get("arguments", {})
+            if isinstance(arguments, str):
+                import json
+
+                try:
+                    arguments = json.loads(arguments)
+                except json.JSONDecodeError:
+                    arguments = {}
+
+            # Use .format() like the system prompt does
+            try:
+                return template.format(**arguments)
+            except KeyError as e:
+                # If a required parameter is missing, fall back to template
+                self.logger.warning(
+                    f"Missing parameter {e} for progress description of {tool_name}"
+                )
+                return template
+            except Exception as e:
+                self.logger.warning(
+                    f"Failed to interpolate progress description for {tool_name}: {e}"
+                )
+                return template
+
         # Fallback to generic description
         return f"🔧 Executing {tool_name}..."
 
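The new interpolation logic is plain str.format over the tool-call arguments, with the raw template as a fallback. A standalone sketch of that behavior (the tool definition and arguments are made up for illustration):

    import json

    # Hypothetical tool definition; only "progress_description" matters here.
    tool_def = {
        "function": {"name": "add_task"},
        "progress_description": "📝 Adding task: {description}...",
    }

    # Providers may return arguments as a JSON string, which the diff handles.
    arguments = '{"description": "buy milk"}'
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            arguments = {}

    template = tool_def["progress_description"]
    try:
        print(template.format(**arguments))  # 📝 Adding task: buy milk...
    except KeyError:
        print(template)  # missing parameter: show the raw template instead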
todo_agent/infrastructure/llm_client.py

@@ -2,7 +2,6 @@
 Abstract LLM client interface for todo.sh agent.
 """
 
-import json
 import time
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List

@@ -93,7 +92,9 @@ class LLMClient(ABC):
         pass
 
     @abstractmethod
-    def _get_request_payload(
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """
         Get request payload for the API call.
 

@@ -117,7 +118,9 @@
         pass
 
     @abstractmethod
-    def _process_response(
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
+    ) -> None:
         """
         Process and log response details.
 

@@ -135,7 +138,9 @@
         total_tokens = self.token_counter.count_request_tokens(messages, tools)
         self.logger.info(f"Request sent - Token count: {total_tokens}")
 
-    def _make_http_request(
+    def _make_http_request(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """
         Make HTTP request to the LLM API with common error handling.
 

@@ -155,7 +160,10 @@
 
         try:
             response = requests.post(  # nosec B113
-                endpoint,
+                endpoint,
+                headers=headers,
+                json=payload,
+                timeout=self.get_request_timeout(),
             )
         except requests.exceptions.Timeout:
             self.logger.error(f"{self.get_provider_name()} API request timed out")

@@ -169,19 +177,29 @@
 
         if response.status_code != 200:
             self.logger.error(f"{self.get_provider_name()} API error: {response.text}")
-            error_type = self.classify_error(
-
+            error_type = self.classify_error(
+                Exception(response.text), self.get_provider_name()
+            )
+            return self._create_error_response(
+                error_type, response.text, response.status_code
+            )
 
         try:
             response_data: Dict[str, Any] = response.json()
         except Exception as e:
-            self.logger.error(
-
+            self.logger.error(
+                f"Failed to parse {self.get_provider_name()} response JSON: {e}"
+            )
+            return self._create_error_response(
+                "malformed_response", f"JSON parsing failed: {e}", response.status_code
+            )
 
         self._process_response(response_data, start_time)
         return response_data
 
-    def _create_error_response(
+    def _create_error_response(
+        self, error_type: str, raw_error: str, status_code: int = 0
+    ) -> Dict[str, Any]:
         """
         Create standardized error response.
 
@@ -198,7 +216,7 @@
             "error_type": error_type,
             "provider": self.get_provider_name(),
             "status_code": status_code,
-            "raw_error": raw_error
+            "raw_error": raw_error,
         }
 
     def _validate_tool_call(self, tool_call: Any, index: int) -> bool:
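All HTTP and parsing failures in _make_http_request funnel into _create_error_response, and every extractor checks the result before reading the payload. A standalone sketch of the shape and the check (the "error": True key is inferred from the response.get("error", False) guards elsewhere in this diff):

    from typing import Any, Dict

    def create_error_response(
        provider: str, error_type: str, raw_error: str, status_code: int = 0
    ) -> Dict[str, Any]:
        # Stand-in for LLMClient._create_error_response
        return {
            "error": True,
            "error_type": error_type,
            "provider": provider,
            "status_code": status_code,
            "raw_error": raw_error,
        }

    response = create_error_response("openrouter", "rate_limit", "429 Too Many Requests", 429)
    if response.get("error", False):
        print(f"{response['provider']} failed: {response['error_type']}")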
@@ -214,47 +232,63 @@
         """
         try:
             if not isinstance(tool_call, dict):
-                self.logger.warning(
+                self.logger.warning(
+                    f"Tool call {index + 1} is not a dictionary: {tool_call}"
+                )
                 return False
 
             function = tool_call.get("function", {})
             if not isinstance(function, dict):
-                self.logger.warning(
+                self.logger.warning(
+                    f"Tool call {index + 1} function is not a dictionary: {function}"
+                )
                 return False
 
             tool_name = function.get("name")
             if not tool_name:
-                self.logger.warning(
+                self.logger.warning(
+                    f"Tool call {index + 1} missing function name: {tool_call}"
+                )
                 return False
 
             arguments = function.get("arguments", "{}")
             if arguments and not isinstance(arguments, str):
-                self.logger.warning(
+                self.logger.warning(
+                    f"Tool call {index + 1} arguments not a string: {arguments}"
+                )
                 return False
 
             return True
         except Exception as e:
-            self.logger.warning(f"Error validating tool call {index+1}: {e}")
+            self.logger.warning(f"Error validating tool call {index + 1}: {e}")
             return False
 
     def classify_error(self, error: Exception, provider: str) -> str:
         """
         Classify provider errors using simple string matching.
-
+
         Args:
             error: The exception that occurred
             provider: The provider name (e.g., 'openrouter', 'ollama')
-
+
         Returns:
             Error type string for message lookup
         """
         error_str = str(error).lower()
-
+
         if "malformed" in error_str or "invalid" in error_str or "parse" in error_str:
             return "malformed_response"
-        elif
+        elif (
+            "rate limit" in error_str
+            or "429" in error_str
+            or "too many requests" in error_str
+        ):
             return "rate_limit"
-        elif
+        elif (
+            "unauthorized" in error_str
+            or "401" in error_str
+            or "authentication" in error_str
+        ):
             return "auth_error"
         elif "timeout" in error_str or "timed out" in error_str:
             return "timeout"
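classify_error is deliberately simple substring matching over the stringified exception; this hunk only reflows the conditions for line length. A condensed sketch of the rules visible above (the hunk cuts off at the timeout branch, so the final fallback is an assumption, chosen to match the "general_error" default used by callers):

    def classify_error(error: Exception) -> str:
        error_str = str(error).lower()
        if "malformed" in error_str or "invalid" in error_str or "parse" in error_str:
            return "malformed_response"
        if "rate limit" in error_str or "429" in error_str or "too many requests" in error_str:
            return "rate_limit"
        if "unauthorized" in error_str or "401" in error_str or "authentication" in error_str:
            return "auth_error"
        if "timeout" in error_str or "timed out" in error_str:
            return "timeout"
        return "general_error"  # assumed fallback; not visible in the hunk

    print(classify_error(Exception("HTTP 429: Too Many Requests")))  # rate_limit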
@@ -278,7 +312,7 @@
     def get_request_timeout(self) -> int:
         """
         Get the request timeout in seconds for this provider.
-
+
         Returns:
             Timeout value in seconds (default: 30)
         """
todo_agent/infrastructure/ollama_client.py

@@ -10,7 +10,7 @@ from todo_agent.infrastructure.llm_client import LLMClient
 class OllamaClient(LLMClient):
     """Ollama API client implementation."""
 
-    def __init__(self, config):
+    def __init__(self, config: Any) -> None:
         """
         Initialize Ollama client.
 

@@ -26,7 +26,9 @@
             "Content-Type": "application/json",
         }
 
-    def _get_request_payload(
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """Get request payload for Ollama API."""
         return {
             "model": self.model,

@@ -39,10 +41,12 @@
         """Get Ollama API endpoint."""
         return f"{self.base_url}/api/chat"
 
-    def _process_response(
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
+    ) -> None:
         """Process and log Ollama response details."""
         import time
-
+
         end_time = time.time()
         latency_ms = (end_time - start_time) * 1000
 

@@ -88,21 +92,25 @@
         """Extract tool calls from API response."""
         # Check for provider errors first
         if response.get("error", False):
-            self.logger.warning(
+            self.logger.warning(
+                f"Cannot extract tool calls from error response: {response.get('error_type')}"
+            )
             return []
-
+
         tool_calls = []
 
         # Ollama response format is different from OpenRouter
         if "message" in response and "tool_calls" in response["message"]:
             raw_tool_calls = response["message"]["tool_calls"]
-
+
             # Validate each tool call using common validation
             for i, tool_call in enumerate(raw_tool_calls):
                 if self._validate_tool_call(tool_call, i):
                     tool_calls.append(tool_call)
-
-            self.logger.debug(
+
+            self.logger.debug(
+                f"Extracted {len(tool_calls)} valid tool calls from {len(raw_tool_calls)} total"
+            )
             for i, tool_call in enumerate(tool_calls):
                 tool_name = tool_call.get("function", {}).get("name", "unknown")
                 tool_call_id = tool_call.get("id", "unknown")

@@ -118,9 +126,11 @@
         """Extract content from API response."""
         # Check for provider errors first
         if response.get("error", False):
-            self.logger.warning(
+            self.logger.warning(
+                f"Cannot extract content from error response: {response.get('error_type')}"
+            )
             return ""
-
+
         if "message" in response and "content" in response["message"]:
             content = response["message"]["content"]
             return content if isinstance(content, str) else str(content)

@@ -147,9 +157,9 @@
     def get_request_timeout(self) -> int:
         """
         Get the request timeout in seconds for Ollama.
-
+
         Ollama can be slower than cloud providers, so we use a 2-minute timeout.
-
+
         Returns:
             Timeout value in seconds (120)
         """
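Each client now declares its own timeout, which the shared _make_http_request feeds into requests.post. A minimal sketch of the override pattern (class names here are illustrative stand-ins for LLMClient, OllamaClient, and OpenRouterClient):

    class BaseClient:
        def get_request_timeout(self) -> int:
            return 30  # base-class default

    class LocalClient(BaseClient):
        def get_request_timeout(self) -> int:
            return 120  # local model servers can be slow

    class CloudClient(BaseClient):
        def get_request_timeout(self) -> int:
            return 30  # cloud APIs typically respond quickly

    # The shared request path then does, roughly:
    #   requests.post(endpoint, headers=headers, json=payload,
    #                 timeout=self.get_request_timeout())
    print(LocalClient().get_request_timeout(), CloudClient().get_request_timeout())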
todo_agent/infrastructure/openrouter_client.py

@@ -10,7 +10,7 @@ from todo_agent.infrastructure.llm_client import LLMClient
 class OpenRouterClient(LLMClient):
     """LLM API communication and response handling."""
 
-    def __init__(self, config):
+    def __init__(self, config: Any) -> None:
         """
         Initialize OpenRouter client.
 

@@ -28,7 +28,9 @@
             "Content-Type": "application/json",
         }
 
-    def _get_request_payload(
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """Get request payload for OpenRouter API."""
         return {
             "model": self.model,

@@ -41,10 +43,12 @@
         """Get OpenRouter API endpoint."""
         return f"{self.base_url}/chat/completions"
 
-    def _process_response(
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
+    ) -> None:
         """Process and log OpenRouter response details."""
         import time
-
+
         end_time = time.time()
         latency_ms = (end_time - start_time) * 1000
 

@@ -120,20 +124,22 @@
         """Extract tool calls from API response."""
         # Check for provider errors first
         if response.get("error", False):
-            self.logger.warning(
+            self.logger.warning(
+                f"Cannot extract tool calls from error response: {response.get('error_type')}"
+            )
             return []
-
+
         tool_calls = []
         if response.get("choices"):
             choice = response["choices"][0]
             if "message" in choice and "tool_calls" in choice["message"]:
                 raw_tool_calls = choice["message"]["tool_calls"]
-
+
                 # Validate each tool call using common validation
                 for i, tool_call in enumerate(raw_tool_calls):
                     if self._validate_tool_call(tool_call, i):
                         tool_calls.append(tool_call)
-
+
                 self.logger.debug(
                     f"Extracted {len(tool_calls)} valid tool calls from {len(raw_tool_calls)} total"
                 )
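The Ollama and OpenRouter extractors differ only in where the provider puts the assistant message: Ollama returns it at the top level, OpenRouter uses the OpenAI-style choices array. A sketch of the two shapes these methods navigate (the payloads are illustrative stubs):

    ollama_response = {
        "message": {
            "content": "done",
            "tool_calls": [{"function": {"name": "list_tasks", "arguments": "{}"}}],
        }
    }
    openrouter_response = {
        "choices": [{"message": {
            "content": "done",
            "tool_calls": [{"function": {"name": "list_tasks", "arguments": "{}"}}],
        }}]
    }

    # Ollama: message at the top level of the response
    print(ollama_response["message"]["tool_calls"][0]["function"]["name"])
    # OpenRouter: first element of the choices array
    print(openrouter_response["choices"][0]["message"]["tool_calls"][0]["function"]["name"])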
@@ -153,9 +159,9 @@
         """Extract content from API response."""
         # Check for provider errors first
         if response.get("error", False):
-            self.logger.warning(
+            self.logger.warning(
+                f"Cannot extract content from error response: {response.get('error_type')}"
+            )
             return ""
-
+
         if response.get("choices"):
             choice = response["choices"][0]
             if "message" in choice and "content" in choice["message"]:

@@ -184,9 +192,9 @@
     def get_request_timeout(self) -> int:
         """
         Get the request timeout in seconds for OpenRouter.
-
+
         Cloud APIs typically respond quickly, so we use a 30-second timeout.
-
+
         Returns:
             Timeout value in seconds (30)
         """