todo-agent 0.3.2-py3-none-any.whl → 0.3.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- todo_agent/_version.py +2 -2
- todo_agent/core/exceptions.py +6 -6
- todo_agent/core/todo_manager.py +13 -8
- todo_agent/infrastructure/inference.py +113 -52
- todo_agent/infrastructure/llm_client.py +56 -22
- todo_agent/infrastructure/ollama_client.py +23 -13
- todo_agent/infrastructure/openrouter_client.py +20 -12
- todo_agent/infrastructure/prompts/system_prompt.txt +88 -438
- todo_agent/infrastructure/todo_shell.py +35 -11
- todo_agent/interface/cli.py +51 -33
- todo_agent/interface/formatters.py +7 -4
- todo_agent/interface/progress.py +30 -19
- todo_agent/interface/tools.py +25 -25
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.3.dist-info}/METADATA +1 -1
- todo_agent-0.3.3.dist-info/RECORD +30 -0
- todo_agent-0.3.2.dist-info/RECORD +0 -30
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.3.dist-info}/WHEEL +0 -0
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.3.dist-info}/entry_points.txt +0 -0
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {todo_agent-0.3.2.dist-info → todo_agent-0.3.3.dist-info}/top_level.txt +0 -0
todo_agent/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.2'
-__version_tuple__ = version_tuple = (0, 3, 2)
+__version__ = version = '0.3.3'
+__version_tuple__ = version_tuple = (0, 3, 3)
 
 __commit_id__ = commit_id = None
todo_agent/core/exceptions.py
CHANGED
@@ -35,7 +35,7 @@ class TodoShellError(TodoError):
 
 class ProviderError(Exception):
     """Base exception for LLM provider errors."""
-
+
     def __init__(self, message: str, error_type: str, provider: str):
         super().__init__(message)
         self.message = message
@@ -45,34 +45,34 @@ class ProviderError(Exception):
 
 class MalformedResponseError(ProviderError):
     """Provider returned malformed or invalid response."""
-
+
     def __init__(self, message: str, provider: str):
         super().__init__(message, "malformed_response", provider)
 
 
 class RateLimitError(ProviderError):
     """Provider rate limit exceeded."""
-
+
     def __init__(self, message: str, provider: str):
         super().__init__(message, "rate_limit", provider)
 
 
 class AuthenticationError(ProviderError):
     """Provider authentication failed."""
-
+
     def __init__(self, message: str, provider: str):
         super().__init__(message, "auth_error", provider)
 
 
 class TimeoutError(ProviderError):
     """Provider request timed out."""
-
+
     def __init__(self, message: str, provider: str):
         super().__init__(message, "timeout", provider)
 
 
 class GeneralProviderError(ProviderError):
     """General provider error."""
-
+
     def __init__(self, message: str, provider: str):
         super().__init__(message, "general_error", provider)
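Note (not part of the diff): the `-`/`+` pairs above are trailing-whitespace cleanup on otherwise blank lines. Functionally, these classes form a small error taxonomy keyed by `error_type`. A minimal usage sketch follows; it assumes `ProviderError.__init__` also stores `error_type` and `provider` as attributes, which falls just outside the visible hunk, and the helper itself is hypothetical:

from todo_agent.core.exceptions import (
    AuthenticationError,
    ProviderError,
    RateLimitError,
)


def describe_failure(exc: ProviderError) -> str:
    # Hypothetical helper, not part of the package.
    if isinstance(exc, RateLimitError):
        return f"{exc.provider} is rate-limiting requests: {exc.message}"
    if isinstance(exc, AuthenticationError):
        return f"Check the {exc.provider} API key: {exc.message}"
    return f"{exc.provider} failed ({exc.error_type}): {exc.message}"


print(describe_failure(RateLimitError("429 Too Many Requests", "openrouter")))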
todo_agent/core/todo_manager.py
CHANGED
@@ -102,9 +102,11 @@ class TodoManager:
         self.todo_shell.add(full_description)
         return f"Added task: {full_description}"
 
-    def list_tasks(self, filter: Optional[str] = None) -> str:
+    def list_tasks(
+        self, filter: Optional[str] = None, suppress_color: bool = True
+    ) -> str:
         """List tasks with optional filtering."""
-        result = self.todo_shell.list_tasks(filter)
+        result = self.todo_shell.list_tasks(filter, suppress_color=suppress_color)
         if not result.strip():
             return "No tasks found."
 
@@ -258,16 +260,16 @@ class TodoManager:
         operation_desc = ", ".join(operations)
         return f"Updated projects for task {task_number} ({operation_desc}): {result}"
 
-    def list_projects(self, **kwargs: Any) -> str:
+    def list_projects(self, suppress_color: bool = True, **kwargs: Any) -> str:
         """List all available projects in todo.txt."""
-        result = self.todo_shell.list_projects()
+        result = self.todo_shell.list_projects(suppress_color=suppress_color)
         if not result.strip():
             return "No projects found."
         return result
 
-    def list_contexts(self, **kwargs: Any) -> str:
+    def list_contexts(self, suppress_color: bool = True, **kwargs: Any) -> str:
         """List all available contexts in todo.txt."""
-        result = self.todo_shell.list_contexts()
+        result = self.todo_shell.list_contexts(suppress_color=suppress_color)
         if not result.strip():
             return "No contexts found."
         return result
@@ -280,6 +282,7 @@ class TodoManager:
         text_search: Optional[str] = None,
         date_from: Optional[str] = None,
         date_to: Optional[str] = None,
+        suppress_color: bool = True,
         **kwargs: Any,
     ) -> str:
         """List completed tasks with optional filtering.
@@ -329,7 +332,9 @@ class TodoManager:
         # Combine all filters
         combined_filter = " ".join(filter_parts) if filter_parts else None
 
-        result = self.todo_shell.list_completed(combined_filter)
+        result = self.todo_shell.list_completed(
+            combined_filter, suppress_color=suppress_color
+        )
         if not result.strip():
             return "No completed tasks found matching the criteria."
         return result
@@ -464,7 +469,7 @@ class TodoManager:
 
         # Use the move command to restore the task from done.txt to todo.txt
         result = self.todo_shell.move(task_number, "todo.txt", "done.txt")
-
+
         # Extract the task description from the result for confirmation
         # The result format is typically: "TODO: X moved from '.../done.txt' to '.../todo.txt'."
         if "moved from" in result and "to" in result:
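Note (not part of the diff): the common thread in this file is the new `suppress_color` keyword, defaulting to `True` and threaded through to `TodoShell`, so listing output can be handed to the LLM free of ANSI color codes while interactive callers can opt back in. A rough sketch of the intended call sites; the constructor wiring is assumed, not shown in this diff:

from todo_agent.core.todo_manager import TodoManager
from todo_agent.infrastructure.todo_shell import TodoShell

manager = TodoManager(TodoShell())  # wiring assumed, for illustration only

plain = manager.list_tasks()                        # default: suppress_color=True
colored = manager.list_tasks(suppress_color=False)  # keep ANSI colors for the terminal
recent = manager.list_completed(date_from="2025-01-01")  # also color-free by default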
todo_agent/infrastructure/inference.py
CHANGED
@@ -12,8 +12,11 @@ try:
     from todo_agent.infrastructure.config import Config
     from todo_agent.infrastructure.llm_client_factory import LLMClientFactory
     from todo_agent.infrastructure.logger import Logger
+    from todo_agent.interface.formatters import (
+        get_provider_error_message as _get_error_msg,
+    )
+    from todo_agent.interface.progress import NoOpProgress, ToolCallProgress
     from todo_agent.interface.tools import ToolCallHandler
-    from todo_agent.interface.progress import ToolCallProgress, NoOpProgress
 except ImportError:
     from core.conversation_manager import (  # type: ignore[no-redef]
         ConversationManager,
@@ -24,8 +27,14 @@ except ImportError:
         LLMClientFactory,
     )
     from infrastructure.logger import Logger  # type: ignore[no-redef]
+    from interface.formatters import (  # type: ignore[no-redef]
+        get_provider_error_message as _get_error_msg,
+    )
+    from interface.progress import (  # type: ignore[no-redef]
+        NoOpProgress,
+        ToolCallProgress,
+    )
     from interface.tools import ToolCallHandler  # type: ignore[no-redef]
-    from interface.progress import ToolCallProgress, NoOpProgress  # type: ignore[no-redef]
 
 
 class Inference:
@@ -68,6 +77,27 @@ class Inference:
         self.conversation_manager.set_system_prompt(system_prompt)
         self.logger.debug("System prompt loaded and set")
 
+    def current_tasks(self) -> str:
+        """
+        Get current tasks from the todo manager.
+
+        Returns:
+            Formatted string of current tasks or error message
+        """
+        try:
+            # Use the todo manager from the tool handler to get current tasks
+            tasks = self.tool_handler.todo_manager.list_tasks(suppress_color=True)
+
+            # If no tasks found, return a clear message
+            if not tasks.strip() or tasks == "No tasks found.":
+                return "No current tasks found."
+
+            return tasks
+
+        except Exception as e:
+            self.logger.warning(f"Failed to get current tasks: {e!s}")
+            return f"Error retrieving current tasks: {e!s}"
+
     def _load_system_prompt(self) -> str:
         """Load and format the system prompt from file."""
         # Get current datetime for interpolation
@@ -84,6 +114,9 @@ class Inference:
             self.logger.warning(f"Failed to get calendar output: {e!s}")
             calendar_output = "Calendar unavailable"
 
+        # Get current tasks
+        current_tasks = self.current_tasks()
+
         # Load system prompt from file
         prompt_file_path = os.path.join(
             os.path.dirname(__file__), "prompts", "system_prompt.txt"
@@ -93,10 +126,11 @@ class Inference:
             with open(prompt_file_path, encoding="utf-8") as f:
                 system_prompt_template = f.read()
 
-            # Format the template with current datetime and calendar
+            # Format the template with current datetime, calendar, and current tasks
             return system_prompt_template.format(
                 current_datetime=current_datetime,
                 calendar_output=calendar_output,
+                current_tasks=current_tasks,
            )
 
         except FileNotFoundError:
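Note (not part of the diff): `_load_system_prompt` fills the template with `str.format`, so `system_prompt.txt` must now carry a `{current_tasks}` placeholder alongside `{current_datetime}` and `{calendar_output}`. An illustrative stand-in; the real template is rewritten wholesale in this release and is not shown here:

# Hypothetical template fragment with the three placeholders .format() expects.
template = (
    "It is {current_datetime}.\n"
    "{calendar_output}\n"
    "The user's current tasks:\n"
    "{current_tasks}\n"
)

print(
    template.format(
        current_datetime="2025-01-15 09:30",
        calendar_output="Calendar unavailable",
        current_tasks="1 (A) Ship 0.3.3 +todo-agent",
    )
)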
@@ -106,9 +140,9 @@ class Inference:
             self.logger.error(f"Error loading system prompt: {e!s}")
             raise
 
-
-
-
+    def process_request(
+        self, user_input: str, progress_callback: Optional[ToolCallProgress] = None
+    ) -> tuple[str, float]:
         """
         Process a user request through the LLM with tool orchestration.
 
@@ -121,11 +155,11 @@ class Inference:
         """
         # Start timing the request
         start_time = time.time()
-
+
         # Initialize progress callback if not provided
         if progress_callback is None:
             progress_callback = NoOpProgress()
-
+
         # Notify progress callback that thinking has started
         progress_callback.on_thinking_start()
 
@@ -155,23 +189,19 @@ class Inference:
             error_type = response.get("error_type", "general_error")
             provider = response.get("provider", "unknown")
             self.logger.error(f"Provider error from {provider}: {error_type}")
-
-
-            try:
-                from todo_agent.interface.formatters import get_provider_error_message
-                error_message = get_provider_error_message(error_type)
-            except ImportError:
-                from interface.formatters import get_provider_error_message
-                error_message = get_provider_error_message(error_type)
-
+
+            error_message = _get_error_msg(error_type)
+
             # Add error message to conversation
-            self.conversation_manager.add_message(MessageRole.ASSISTANT, error_message)
-
+            self.conversation_manager.add_message(
+                MessageRole.ASSISTANT, error_message
+            )
+
             # Calculate thinking time and return
             end_time = time.time()
             thinking_time = end_time - start_time
             progress_callback.on_thinking_complete(thinking_time)
-
+
             return error_message, thinking_time
 
         # Extract actual token usage from API response
@@ -188,8 +218,7 @@ class Inference:
 
         # Handle multiple tool calls in sequence
         tool_call_count = 0
-
-
+
         while True:
             tool_calls = self.llm_client.extract_tool_calls(response)
 
@@ -200,9 +229,11 @@ class Inference:
                 self.logger.debug(
                     f"Executing tool call sequence #{tool_call_count} with {len(tool_calls)} tools"
                 )
-
+
                 # Notify progress callback of sequence start
-                progress_callback.on_sequence_complete(tool_call_count, 0)  # We don't know total yet
+                progress_callback.on_sequence_complete(
+                    tool_call_count, 0
+                )  # We don't know total yet
 
                 # Execute all tool calls and collect results
                 tool_results = []
@@ -217,13 +248,18 @@ class Inference:
                     self.logger.debug(f"Raw tool call: {tool_call}")
 
                     # Get progress description for the tool
-                    progress_description = self._get_tool_progress_description(tool_name)
-
+                    progress_description = self._get_tool_progress_description(
+                        tool_name, tool_call
+                    )
+
                     # Notify progress callback of tool call start
                     progress_callback.on_tool_call_start(
-                        tool_name, progress_description, tool_call_count, 0  # We don't know total yet
+                        tool_name,
+                        progress_description,
+                        tool_call_count,
+                        0,  # We don't know total yet
                     )
-
+
                     result = self.tool_handler.execute_tool(tool_call)
 
                     # Log tool execution result (success or error)
@@ -253,24 +289,22 @@ class Inference:
             if response.get("error", False):
                 error_type = response.get("error_type", "general_error")
                 provider = response.get("provider", "unknown")
-                self.logger.error(f"Provider error in continuation from {provider}: {error_type}")
-
-
-                try:
-                    from todo_agent.interface.formatters import get_provider_error_message
-                    error_message = get_provider_error_message(error_type)
-                except ImportError:
-                    from interface.formatters import get_provider_error_message
-                    error_message = get_provider_error_message(error_type)
-
+                self.logger.error(
+                    f"Provider error in continuation from {provider}: {error_type}"
+                )
+
+                error_message = _get_error_msg(error_type)
+
                 # Add error message to conversation
-                self.conversation_manager.add_message(MessageRole.ASSISTANT, error_message)
-
+                self.conversation_manager.add_message(
+                    MessageRole.ASSISTANT, error_message
+                )
+
                 # Calculate thinking time and return
                 end_time = time.time()
                 thinking_time = end_time - start_time
                 progress_callback.on_thinking_complete(thinking_time)
-
+
                 return error_message, thinking_time
 
             # Update with actual tokens from subsequent API calls
@@ -284,7 +318,7 @@ class Inference:
         # Calculate and log total thinking time
         end_time = time.time()
         thinking_time = end_time - start_time
-
+
         # Notify progress callback that thinking is complete
         progress_callback.on_thinking_complete(thinking_time)
 
@@ -321,22 +355,49 @@ class Inference:
             self.tool_handler.tools
         )
 
-    def _get_tool_progress_description(self, tool_name: str) -> str:
+    def _get_tool_progress_description(self, tool_name: str, tool_call: Dict[str, Any]) -> str:
         """
-        Get user-friendly progress description for a tool.
-
+        Get user-friendly progress description for a tool with parameter interpolation.
+
         Args:
             tool_name: Name of the tool
-
+            tool_call: The tool call dictionary containing parameters
+
         Returns:
-            Progress description string
+            Progress description string with interpolated parameters
         """
-        tool_def = next(
-
-
+        tool_def = next(
+            (
+                t
+                for t in self.tool_handler.tools
+                if t.get("function", {}).get("name") == tool_name
+            ),
+            None,
+        )
+
         if tool_def and "progress_description" in tool_def:
-
-
+            template = tool_def["progress_description"]
+
+            # Extract arguments from tool call
+            arguments = tool_call.get("function", {}).get("arguments", {})
+            if isinstance(arguments, str):
+                import json
+                try:
+                    arguments = json.loads(arguments)
+                except json.JSONDecodeError:
+                    arguments = {}
+
+            # Use .format() like the system prompt does
+            try:
+                return template.format(**arguments)
+            except KeyError as e:
+                # If a required parameter is missing, fall back to template
+                self.logger.warning(f"Missing parameter {e} for progress description of {tool_name}")
+                return template
+            except Exception as e:
+                self.logger.warning(f"Failed to interpolate progress description for {tool_name}: {e}")
+                return template
+
         # Fallback to generic description
         return f"🔧 Executing {tool_name}..."
 
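Note (not part of the diff): `_get_tool_progress_description` now treats a tool's `progress_description` as a `str.format` template over the tool call's JSON arguments, falling back to the raw template when interpolation fails. The same logic, reduced to a standalone sketch; the tool definition and call below are invented for illustration:

import json

# Hypothetical tool definition and tool call, shaped like those in tools.py.
tool_def = {
    "function": {"name": "add_task"},
    "progress_description": "Adding task: {description}...",
}
tool_call = {
    "function": {
        "name": "add_task",
        "arguments": '{"description": "water the plants"}',
    }
}

template = tool_def["progress_description"]
arguments = tool_call["function"]["arguments"]
if isinstance(arguments, str):
    try:
        arguments = json.loads(arguments)
    except json.JSONDecodeError:
        arguments = {}

try:
    print(template.format(**arguments))  # Adding task: water the plants...
except KeyError:
    print(template)  # missing parameter: fall back to the raw template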
todo_agent/infrastructure/llm_client.py
CHANGED
@@ -2,7 +2,6 @@
 Abstract LLM client interface for todo.sh agent.
 """
 
-import json
 import time
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
@@ -93,7 +92,9 @@ class LLMClient(ABC):
         pass
 
     @abstractmethod
-    def _get_request_payload(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """
         Get request payload for the API call.
 
@@ -117,7 +118,9 @@ class LLMClient(ABC):
         pass
 
     @abstractmethod
-    def _process_response(self, response_data: Dict[str, Any], start_time: float) -> None:
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
+    ) -> None:
         """
         Process and log response details.
 
@@ -135,7 +138,9 @@ class LLMClient(ABC):
         total_tokens = self.token_counter.count_request_tokens(messages, tools)
         self.logger.info(f"Request sent - Token count: {total_tokens}")
 
-    def _make_http_request(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def _make_http_request(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """
         Make HTTP request to the LLM API with common error handling.
 
@@ -155,7 +160,10 @@ class LLMClient(ABC):
 
         try:
             response = requests.post(  # nosec B113
-                endpoint, headers=headers, json=payload, timeout=self.get_request_timeout()
+                endpoint,
+                headers=headers,
+                json=payload,
+                timeout=self.get_request_timeout(),
             )
         except requests.exceptions.Timeout:
             self.logger.error(f"{self.get_provider_name()} API request timed out")
@@ -169,19 +177,29 @@ class LLMClient(ABC):
 
         if response.status_code != 200:
             self.logger.error(f"{self.get_provider_name()} API error: {response.text}")
-            error_type = self.classify_error(Exception(response.text), self.get_provider_name())
-            return self._create_error_response(error_type, response.text, response.status_code)
+            error_type = self.classify_error(
+                Exception(response.text), self.get_provider_name()
+            )
+            return self._create_error_response(
+                error_type, response.text, response.status_code
+            )
 
         try:
             response_data: Dict[str, Any] = response.json()
         except Exception as e:
-            self.logger.error(f"Failed to parse {self.get_provider_name()} response JSON: {e}")
-            return self._create_error_response("malformed_response", f"JSON parsing failed: {e}", response.status_code)
+            self.logger.error(
+                f"Failed to parse {self.get_provider_name()} response JSON: {e}"
+            )
+            return self._create_error_response(
+                "malformed_response", f"JSON parsing failed: {e}", response.status_code
+            )
 
         self._process_response(response_data, start_time)
         return response_data
 
-    def _create_error_response(self, error_type: str, raw_error: str, status_code: int = 0) -> Dict[str, Any]:
+    def _create_error_response(
+        self, error_type: str, raw_error: str, status_code: int = 0
+    ) -> Dict[str, Any]:
         """
         Create standardized error response.
 
@@ -198,7 +216,7 @@ class LLMClient(ABC):
             "error_type": error_type,
             "provider": self.get_provider_name(),
             "status_code": status_code,
-            "raw_error": raw_error
+            "raw_error": raw_error,
         }
 
     def _validate_tool_call(self, tool_call: Any, index: int) -> bool:
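Note (not part of the diff): this dict is what `Inference.process_request` checks with `response.get("error", False)`. The hunk shows only the tail of the dict; the `"error": True` key is implied by those checks rather than quoted from the source. Approximate shape, with example values:

# Shape inferred from the checks in inference.py; values are illustrative.
error_response = {
    "error": True,                # implied key, outside the visible hunk
    "error_type": "rate_limit",   # one of classify_error()'s outputs
    "provider": "openrouter",
    "status_code": 429,
    "raw_error": "429 Too Many Requests",
}

assert error_response.get("error", False)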
@@ -214,47 +232,63 @@ class LLMClient(ABC):
         """
         try:
             if not isinstance(tool_call, dict):
-                self.logger.warning(f"Tool call {index+1} is not a dictionary: {tool_call}")
+                self.logger.warning(
+                    f"Tool call {index + 1} is not a dictionary: {tool_call}"
+                )
                 return False
 
             function = tool_call.get("function", {})
             if not isinstance(function, dict):
-                self.logger.warning(f"Tool call {index+1} function is not a dictionary: {function}")
+                self.logger.warning(
+                    f"Tool call {index + 1} function is not a dictionary: {function}"
+                )
                 return False
 
             tool_name = function.get("name")
             if not tool_name:
-                self.logger.warning(f"Tool call {index+1} missing function name: {tool_call}")
+                self.logger.warning(
+                    f"Tool call {index + 1} missing function name: {tool_call}"
+                )
                 return False
 
             arguments = function.get("arguments", "{}")
             if arguments and not isinstance(arguments, str):
-                self.logger.warning(f"Tool call {index+1} arguments not a string: {arguments}")
+                self.logger.warning(
+                    f"Tool call {index + 1} arguments not a string: {arguments}"
+                )
                 return False
 
             return True
         except Exception as e:
-            self.logger.warning(f"Error validating tool call {index+1}: {e}")
+            self.logger.warning(f"Error validating tool call {index + 1}: {e}")
             return False
 
     def classify_error(self, error: Exception, provider: str) -> str:
         """
         Classify provider errors using simple string matching.
-
+
         Args:
             error: The exception that occurred
             provider: The provider name (e.g., 'openrouter', 'ollama')
-
+
         Returns:
             Error type string for message lookup
         """
         error_str = str(error).lower()
-
+
         if "malformed" in error_str or "invalid" in error_str or "parse" in error_str:
             return "malformed_response"
-        elif "rate limit" in error_str or "429" in error_str or "too many requests" in error_str:
+        elif (
+            "rate limit" in error_str
+            or "429" in error_str
+            or "too many requests" in error_str
+        ):
             return "rate_limit"
-        elif "unauthorized" in error_str or "401" in error_str or "authentication" in error_str:
+        elif (
+            "unauthorized" in error_str
+            or "401" in error_str
+            or "authentication" in error_str
+        ):
             return "auth_error"
         elif "timeout" in error_str or "timed out" in error_str:
             return "timeout"
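Note (not part of the diff): classification is plain substring matching on the lowercased exception text, and order matters — the malformed/invalid/parse check runs first, so a message mixing keywords (say, "invalid API key (401)") classifies as malformed_response rather than auth_error. A mirror of the chain for illustration; the final general_error fallback is assumed, since it sits outside the visible hunk:

def classify(error_text: str) -> str:
    # Mirror of the substring chain above, for illustration only.
    s = error_text.lower()
    if "malformed" in s or "invalid" in s or "parse" in s:
        return "malformed_response"
    if "rate limit" in s or "429" in s or "too many requests" in s:
        return "rate_limit"
    if "unauthorized" in s or "401" in s or "authentication" in s:
        return "auth_error"
    if "timeout" in s or "timed out" in s:
        return "timeout"
    return "general_error"  # assumed fallback, outside the visible hunk


assert classify("invalid API key (401)") == "malformed_response"  # order matters
assert classify("HTTP 429 Too Many Requests") == "rate_limit"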
@@ -278,7 +312,7 @@ class LLMClient(ABC):
     def get_request_timeout(self) -> int:
         """
         Get the request timeout in seconds for this provider.
-
+
         Returns:
             Timeout value in seconds (default: 30)
         """
todo_agent/infrastructure/ollama_client.py
CHANGED
@@ -10,7 +10,7 @@ from todo_agent.infrastructure.llm_client import LLMClient
 class OllamaClient(LLMClient):
     """Ollama API client implementation."""
 
-    def __init__(self, config):
+    def __init__(self, config: Any) -> None:
         """
         Initialize Ollama client.
 
@@ -26,7 +26,9 @@ class OllamaClient(LLMClient):
             "Content-Type": "application/json",
         }
 
-    def _get_request_payload(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """Get request payload for Ollama API."""
         return {
             "model": self.model,
@@ -39,10 +41,12 @@ class OllamaClient(LLMClient):
         """Get Ollama API endpoint."""
         return f"{self.base_url}/api/chat"
 
-    def _process_response(self, response_data: Dict[str, Any], start_time: float) -> None:
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
+    ) -> None:
         """Process and log Ollama response details."""
         import time
-
+
         end_time = time.time()
         latency_ms = (end_time - start_time) * 1000
 
@@ -88,21 +92,25 @@ class OllamaClient(LLMClient):
         """Extract tool calls from API response."""
         # Check for provider errors first
         if response.get("error", False):
-            self.logger.warning(f"Cannot extract tool calls from error response: {response.get('error_type')}")
+            self.logger.warning(
+                f"Cannot extract tool calls from error response: {response.get('error_type')}"
+            )
             return []
-
+
         tool_calls = []
 
         # Ollama response format is different from OpenRouter
         if "message" in response and "tool_calls" in response["message"]:
             raw_tool_calls = response["message"]["tool_calls"]
-
+
             # Validate each tool call using common validation
             for i, tool_call in enumerate(raw_tool_calls):
                 if self._validate_tool_call(tool_call, i):
                     tool_calls.append(tool_call)
-
-            self.logger.debug(f"Extracted {len(tool_calls)} valid tool calls from {len(raw_tool_calls)} total")
+
+            self.logger.debug(
+                f"Extracted {len(tool_calls)} valid tool calls from {len(raw_tool_calls)} total"
+            )
             for i, tool_call in enumerate(tool_calls):
                 tool_name = tool_call.get("function", {}).get("name", "unknown")
                 tool_call_id = tool_call.get("id", "unknown")
@@ -118,9 +126,11 @@ class OllamaClient(LLMClient):
         """Extract content from API response."""
         # Check for provider errors first
         if response.get("error", False):
-            self.logger.warning(f"Cannot extract content from error response: {response.get('error_type')}")
+            self.logger.warning(
+                f"Cannot extract content from error response: {response.get('error_type')}"
+            )
             return ""
-
+
         if "message" in response and "content" in response["message"]:
             content = response["message"]["content"]
             return content if isinstance(content, str) else str(content)
@@ -147,9 +157,9 @@ class OllamaClient(LLMClient):
     def get_request_timeout(self) -> int:
         """
         Get the request timeout in seconds for Ollama.
-
+
         Ollama can be slower than cloud providers, so we use a 2-minute timeout.
-
+
         Returns:
             Timeout value in seconds (120)
         """