todo-agent 0.3.1-py3-none-any.whl → 0.3.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- todo_agent/_version.py +2 -2
- todo_agent/core/conversation_manager.py +1 -1
- todo_agent/core/exceptions.py +54 -3
- todo_agent/core/todo_manager.py +127 -56
- todo_agent/infrastructure/calendar_utils.py +2 -4
- todo_agent/infrastructure/inference.py +158 -52
- todo_agent/infrastructure/llm_client.py +258 -1
- todo_agent/infrastructure/ollama_client.py +77 -76
- todo_agent/infrastructure/openrouter_client.py +77 -72
- todo_agent/infrastructure/prompts/system_prompt.txt +88 -396
- todo_agent/infrastructure/todo_shell.py +37 -27
- todo_agent/interface/cli.py +129 -19
- todo_agent/interface/formatters.py +25 -0
- todo_agent/interface/progress.py +69 -0
- todo_agent/interface/tools.py +142 -23
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/METADATA +3 -3
- todo_agent-0.3.3.dist-info/RECORD +30 -0
- todo_agent-0.3.1.dist-info/RECORD +0 -29
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/WHEEL +0 -0
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/entry_points.txt +0 -0
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/top_level.txt +0 -0
todo_agent/infrastructure/llm_client.py

@@ -2,12 +2,32 @@
 Abstract LLM client interface for todo.sh agent.
 """

+import time
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List

+import requests
+
+from todo_agent.infrastructure.logger import Logger
+from todo_agent.infrastructure.token_counter import get_token_counter
+

 class LLMClient(ABC):
-    """Abstract interface for LLM clients."""
+    """Abstract interface for LLM clients with common functionality."""
+
+    def __init__(self, config: Any, model: str, logger_name: str = "llm_client"):
+        """
+        Initialize common LLM client functionality.
+
+        Args:
+            config: Configuration object
+            model: Model name for token counting
+            logger_name: Logger name for this client
+        """
+        self.config = config
+        self.model = model
+        self.logger = Logger(logger_name)
+        self.token_counter = get_token_counter(model)

     @abstractmethod
     def chat_with_tools(
@@ -60,3 +80,240 @@ class LLMClient(ABC):
             Model name string
         """
         pass
+
+    @abstractmethod
+    def _get_request_headers(self) -> Dict[str, str]:
+        """
+        Get request headers for the API call.
+
+        Returns:
+            Dictionary of headers
+        """
+        pass
+
+    @abstractmethod
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """
+        Get request payload for the API call.
+
+        Args:
+            messages: List of message dictionaries
+            tools: List of tool definitions
+
+        Returns:
+            Request payload dictionary
+        """
+        pass
+
+    @abstractmethod
+    def _get_api_endpoint(self) -> str:
+        """
+        Get the API endpoint for requests.
+
+        Returns:
+            API endpoint URL
+        """
+        pass
+
+    @abstractmethod
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
+    ) -> None:
+        """
+        Process and log response details.
+
+        Args:
+            response_data: Response data from API
+            start_time: Request start time for latency calculation
+        """
+        pass
+
+    def _log_request_details(self, payload: Dict[str, Any], start_time: float) -> None:
+        """Log request details including accurate token count."""
+        messages = payload.get("messages", [])
+        tools = payload.get("tools", [])
+
+        total_tokens = self.token_counter.count_request_tokens(messages, tools)
+        self.logger.info(f"Request sent - Token count: {total_tokens}")
+
+    def _make_http_request(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """
+        Make HTTP request to the LLM API with common error handling.
+
+        Args:
+            messages: List of message dictionaries
+            tools: List of tool definitions
+
+        Returns:
+            API response dictionary
+        """
+        headers = self._get_request_headers()
+        payload = self._get_request_payload(messages, tools)
+        endpoint = self._get_api_endpoint()
+
+        start_time = time.time()
+        self._log_request_details(payload, start_time)
+
+        try:
+            response = requests.post(  # nosec B113
+                endpoint,
+                headers=headers,
+                json=payload,
+                timeout=self.get_request_timeout(),
+            )
+        except requests.exceptions.Timeout:
+            self.logger.error(f"{self.get_provider_name()} API request timed out")
+            return self._create_error_response("timeout", "Request timed out")
+        except requests.exceptions.ConnectionError as e:
+            self.logger.error(f"{self.get_provider_name()} API connection error: {e}")
+            return self._create_error_response("timeout", f"Connection error: {e}")
+        except requests.exceptions.RequestException as e:
+            self.logger.error(f"{self.get_provider_name()} API request error: {e}")
+            return self._create_error_response("general_error", f"Request error: {e}")
+
+        if response.status_code != 200:
+            self.logger.error(f"{self.get_provider_name()} API error: {response.text}")
+            error_type = self.classify_error(
+                Exception(response.text), self.get_provider_name()
+            )
+            return self._create_error_response(
+                error_type, response.text, response.status_code
+            )
+
+        try:
+            response_data: Dict[str, Any] = response.json()
+        except Exception as e:
+            self.logger.error(
+                f"Failed to parse {self.get_provider_name()} response JSON: {e}"
+            )
+            return self._create_error_response(
+                "malformed_response", f"JSON parsing failed: {e}", response.status_code
+            )
+
+        self._process_response(response_data, start_time)
+        return response_data
+
+    def _create_error_response(
+        self, error_type: str, raw_error: str, status_code: int = 0
+    ) -> Dict[str, Any]:
+        """
+        Create standardized error response.
+
+        Args:
+            error_type: Type of error
+            raw_error: Raw error message
+            status_code: HTTP status code if available
+
+        Returns:
+            Standardized error response dictionary
+        """
+        return {
+            "error": True,
+            "error_type": error_type,
+            "provider": self.get_provider_name(),
+            "status_code": status_code,
+            "raw_error": raw_error,
+        }
+
+    def _validate_tool_call(self, tool_call: Any, index: int) -> bool:
+        """
+        Validate a tool call structure.
+
+        Args:
+            tool_call: Tool call to validate
+            index: Index of the tool call for logging
+
+        Returns:
+            True if valid, False otherwise
+        """
+        try:
+            if not isinstance(tool_call, dict):
+                self.logger.warning(
+                    f"Tool call {index + 1} is not a dictionary: {tool_call}"
+                )
+                return False
+
+            function = tool_call.get("function", {})
+            if not isinstance(function, dict):
+                self.logger.warning(
+                    f"Tool call {index + 1} function is not a dictionary: {function}"
+                )
+                return False
+
+            tool_name = function.get("name")
+            if not tool_name:
+                self.logger.warning(
+                    f"Tool call {index + 1} missing function name: {tool_call}"
+                )
+                return False
+
+            arguments = function.get("arguments", "{}")
+            if arguments and not isinstance(arguments, str):
+                self.logger.warning(
+                    f"Tool call {index + 1} arguments not a string: {arguments}"
+                )
+                return False
+
+            return True
+        except Exception as e:
+            self.logger.warning(f"Error validating tool call {index + 1}: {e}")
+            return False
+
+    def classify_error(self, error: Exception, provider: str) -> str:
+        """
+        Classify provider errors using simple string matching.
+
+        Args:
+            error: The exception that occurred
+            provider: The provider name (e.g., 'openrouter', 'ollama')
+
+        Returns:
+            Error type string for message lookup
+        """
+        error_str = str(error).lower()
+
+        if "malformed" in error_str or "invalid" in error_str or "parse" in error_str:
+            return "malformed_response"
+        elif (
+            "rate limit" in error_str
+            or "429" in error_str
+            or "too many requests" in error_str
+        ):
+            return "rate_limit"
+        elif (
+            "unauthorized" in error_str
+            or "401" in error_str
+            or "authentication" in error_str
+        ):
+            return "auth_error"
+        elif "timeout" in error_str or "timed out" in error_str:
+            return "timeout"
+        elif "connection" in error_str or "network" in error_str or "dns" in error_str:
+            return "timeout"  # Treat connection issues as timeouts for user messaging
+        elif "refused" in error_str or "unreachable" in error_str:
+            return "timeout"  # Connection refused is similar to timeout for users
+        else:
+            return "general_error"
+
+    @abstractmethod
+    def get_provider_name(self) -> str:
+        """
+        Get the provider name for this client.
+
+        Returns:
+            Provider name string
+        """
+        pass
+
+    def get_request_timeout(self) -> int:
+        """
+        Get the request timeout in seconds for this provider.
+
+        Returns:
+            Timeout value in seconds (default: 30)
+        """
+        return 30
todo_agent/infrastructure/ollama_client.py

@@ -2,92 +2,76 @@
 LLM client for Ollama API communication.
 """

-import json
-import time
 from typing import Any, Dict, List

-import requests
-
-try:
-    from todo_agent.infrastructure.config import Config
-    from todo_agent.infrastructure.llm_client import LLMClient
-    from todo_agent.infrastructure.logger import Logger
-    from todo_agent.infrastructure.token_counter import get_token_counter
-except ImportError:
-    from infrastructure.config import Config  # type: ignore[no-redef]
-    from infrastructure.llm_client import LLMClient  # type: ignore[no-redef]
-    from infrastructure.logger import Logger  # type: ignore[no-redef]
-    from infrastructure.token_counter import get_token_counter  # type: ignore[no-redef]
+from todo_agent.infrastructure.llm_client import LLMClient


 class OllamaClient(LLMClient):
     """Ollama API client implementation."""

-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Any) -> None:
         """
         Initialize Ollama client.

         Args:
             config: Configuration object
         """
-
+        super().__init__(config, config.ollama_model, "ollama_client")
         self.base_url = config.ollama_base_url
-        self.model = config.ollama_model
-        self.logger = Logger("ollama_client")
-        self.token_counter = get_token_counter(self.model)
-
-    def _estimate_tokens(self, text: str) -> int:
-        """
-        Estimate token count for text using accurate tokenization.

-
-
-
-
-
-        """
-        return self.token_counter.count_tokens(text)
-
-    def _log_request_details(self, payload: Dict[str, Any], start_time: float) -> None:
-        """Log request details including accurate token count."""
-        # Count tokens for messages
-        messages = payload.get("messages", [])
-        tools = payload.get("tools", [])
+    def _get_request_headers(self) -> Dict[str, str]:
+        """Get request headers for Ollama API."""
+        return {
+            "Content-Type": "application/json",
+        }

-
+    def _get_request_payload(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """Get request payload for Ollama API."""
+        return {
+            "model": self.model,
+            "messages": messages,
+            "tools": tools,
+            "stream": False,
+        }

-
-
+    def _get_api_endpoint(self) -> str:
+        """Get Ollama API endpoint."""
+        return f"{self.base_url}/api/chat"

-    def _log_response_details(
-        self, response: Dict[str, Any], start_time: float
+    def _process_response(
+        self, response_data: Dict[str, Any], start_time: float
     ) -> None:
-        """
+        """Process and log Ollama response details."""
+        import time
+
         end_time = time.time()
         latency_ms = (end_time - start_time) * 1000

         self.logger.info(f"Response received - Latency: {latency_ms:.2f}ms")

         # Log tool call details if present
-        if "message" in response and "tool_calls" in response["message"]:
-            tool_calls = response["message"]["tool_calls"]
+        if "message" in response_data and "tool_calls" in response_data["message"]:
+            tool_calls = response_data["message"]["tool_calls"]
             self.logger.info(f"Response contains {len(tool_calls)} tool calls")

             # Log thinking content (response body) if present
-            content = response["message"].get("content", "")
+            content = response_data["message"].get("content", "")
             if content and content.strip():
                 self.logger.info(f"LLM thinking before tool calls: {content}")

             for i, tool_call in enumerate(tool_calls):
                 tool_name = tool_call.get("function", {}).get("name", "unknown")
                 self.logger.info(f"  Tool call {i + 1}: {tool_name}")
-        elif "message" in response and "content" in response["message"]:
-            content = response["message"]["content"]
+        elif "message" in response_data and "content" in response_data["message"]:
+            content = response_data["message"]["content"]
             self.logger.debug(
                 f"Response contains content: {content[:100]}{'...' if len(content) > 100 else ''}"
             )

-        self.logger.debug(f"Raw response: {response}")
+        self.logger.debug(f"Raw response: {response_data}")

     def chat_with_tools(
         self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
@@ -102,41 +86,31 @@ class OllamaClient(LLMClient):
         Returns:
             API response dictionary
         """
-        headers = {
-            "Content-Type": "application/json",
-        }
-
-        payload = {
-            "model": self.model,
-            "messages": messages,
-            "tools": tools,
-            "stream": False,
-        }
-
-        start_time = time.time()
-        self._log_request_details(payload, start_time)
-
-        response = requests.post(  # nosec B113
-            f"{self.base_url}/api/chat", headers=headers, json=payload
-        )
-
-        if response.status_code != 200:
-            self.logger.error(f"Ollama API error: {response.text}")
-            raise Exception(f"Ollama API error: {response.text}")
-
-        response_data: Dict[str, Any] = response.json()
-        self._log_response_details(response_data, start_time)
-
-        return response_data
+        return self._make_http_request(messages, tools)

     def extract_tool_calls(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
         """Extract tool calls from API response."""
+        # Check for provider errors first
+        if response.get("error", False):
+            self.logger.warning(
+                f"Cannot extract tool calls from error response: {response.get('error_type')}"
+            )
+            return []
+
         tool_calls = []

         # Ollama response format is different from OpenRouter
         if "message" in response and "tool_calls" in response["message"]:
-            tool_calls = response["message"]["tool_calls"]
-
+            raw_tool_calls = response["message"]["tool_calls"]
+
+            # Validate each tool call using common validation
+            for i, tool_call in enumerate(raw_tool_calls):
+                if self._validate_tool_call(tool_call, i):
+                    tool_calls.append(tool_call)
+
+            self.logger.debug(
+                f"Extracted {len(tool_calls)} valid tool calls from {len(raw_tool_calls)} total"
+            )
             for i, tool_call in enumerate(tool_calls):
                 tool_name = tool_call.get("function", {}).get("name", "unknown")
                 tool_call_id = tool_call.get("id", "unknown")
@@ -150,6 +124,13 @@ class OllamaClient(LLMClient):

     def extract_content(self, response: Dict[str, Any]) -> str:
         """Extract content from API response."""
+        # Check for provider errors first
+        if response.get("error", False):
+            self.logger.warning(
+                f"Cannot extract content from error response: {response.get('error_type')}"
+            )
+            return ""
+
         if "message" in response and "content" in response["message"]:
             content = response["message"]["content"]
             return content if isinstance(content, str) else str(content)
@@ -163,3 +144,23 @@ class OllamaClient(LLMClient):
             Model name string
         """
         return self.model
+
+    def get_provider_name(self) -> str:
+        """
+        Get the provider name for this client.
+
+        Returns:
+            Provider name string
+        """
+        return "ollama"
+
+    def get_request_timeout(self) -> int:
+        """
+        Get the request timeout in seconds for Ollama.
+
+        Ollama can be slower than cloud providers, so we use a 2-minute timeout.
+
+        Returns:
+            Timeout value in seconds (120)
+        """
+        return 120