todo-agent 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,12 +2,32 @@
2
2
  Abstract LLM client interface for todo.sh agent.
3
3
  """
4
4
 
5
+ import time
5
6
  from abc import ABC, abstractmethod
6
7
  from typing import Any, Dict, List
7
8
 
9
+ import requests
10
+
11
+ from todo_agent.infrastructure.logger import Logger
12
+ from todo_agent.infrastructure.token_counter import get_token_counter
13
+
8
14
 
9
15
  class LLMClient(ABC):
10
- """Abstract interface for LLM clients."""
16
+ """Abstract interface for LLM clients with common functionality."""
17
+
18
def __init__(self, config: Any, model: str, logger_name: str = "llm_client"):
    """
    Set up the state shared by every concrete LLM client.

    Args:
        config: Configuration object, stored as-is on the instance
        model: Model name; stored and used to obtain the token counter
        logger_name: Name handed to the Logger created for this client
    """
    self.config, self.model = config, model
    # Create the logger before the token counter, as the original did.
    client_logger = Logger(logger_name)
    self.logger = client_logger
    counter = get_token_counter(model)
    self.token_counter = counter
11
31
 
12
32
  @abstractmethod
13
33
  def chat_with_tools(
@@ -60,3 +80,240 @@ class LLMClient(ABC):
60
80
  Model name string
61
81
  """
62
82
  pass
83
+
84
+ @abstractmethod
85
+ def _get_request_headers(self) -> Dict[str, str]:
86
+ """
87
+ Get request headers for the API call.
88
+
89
+ Returns:
90
+ Dictionary of headers
91
+ """
92
+ pass
93
+
94
+ @abstractmethod
95
+ def _get_request_payload(
96
+ self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
97
+ ) -> Dict[str, Any]:
98
+ """
99
+ Get request payload for the API call.
100
+
101
+ Args:
102
+ messages: List of message dictionaries
103
+ tools: List of tool definitions
104
+
105
+ Returns:
106
+ Request payload dictionary
107
+ """
108
+ pass
109
+
110
+ @abstractmethod
111
+ def _get_api_endpoint(self) -> str:
112
+ """
113
+ Get the API endpoint for requests.
114
+
115
+ Returns:
116
+ API endpoint URL
117
+ """
118
+ pass
119
+
120
+ @abstractmethod
121
+ def _process_response(
122
+ self, response_data: Dict[str, Any], start_time: float
123
+ ) -> None:
124
+ """
125
+ Process and log response details.
126
+
127
+ Args:
128
+ response_data: Response data from API
129
+ start_time: Request start time for latency calculation
130
+ """
131
+ pass
132
+
133
+ def _log_request_details(self, payload: Dict[str, Any], start_time: float) -> None:
134
+ """Log request details including accurate token count."""
135
+ messages = payload.get("messages", [])
136
+ tools = payload.get("tools", [])
137
+
138
+ total_tokens = self.token_counter.count_request_tokens(messages, tools)
139
+ self.logger.info(f"Request sent - Token count: {total_tokens}")
140
+
141
def _make_http_request(
    self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Make HTTP request to the LLM API with common error handling.

    Transport failures, non-200 statuses and undecodable bodies are not
    raised to the caller: each is logged and returned as the standardized
    error dictionary built by ``_create_error_response``.

    Args:
        messages: List of message dictionaries
        tools: List of tool definitions

    Returns:
        The decoded JSON object on success, otherwise an error dictionary
        whose "error" key is True
    """
    headers = self._get_request_headers()
    payload = self._get_request_payload(messages, tools)
    endpoint = self._get_api_endpoint()

    start_time = time.time()
    self._log_request_details(payload, start_time)

    try:
        response = requests.post(  # nosec B113
            endpoint,
            headers=headers,
            json=payload,
            timeout=self.get_request_timeout(),
        )
    except requests.exceptions.Timeout:
        self.logger.error(f"{self.get_provider_name()} API request timed out")
        return self._create_error_response("timeout", "Request timed out")
    except requests.exceptions.ConnectionError as e:
        self.logger.error(f"{self.get_provider_name()} API connection error: {e}")
        # Connection failures are reported as "timeout", matching the
        # mapping used by classify_error for user messaging.
        return self._create_error_response("timeout", f"Connection error: {e}")
    except requests.exceptions.RequestException as e:
        self.logger.error(f"{self.get_provider_name()} API request error: {e}")
        return self._create_error_response("general_error", f"Request error: {e}")

    if response.status_code != 200:
        self.logger.error(f"{self.get_provider_name()} API error: {response.text}")
        # Classify on the status code as well as the body: a 429/401 whose
        # body does not spell out the reason would otherwise fall through
        # to "general_error".
        error_type = self.classify_error(
            Exception(f"HTTP {response.status_code}: {response.text}"),
            self.get_provider_name(),
        )
        return self._create_error_response(
            error_type, response.text, response.status_code
        )

    try:
        response_data = response.json()
    except ValueError as e:
        # JSON decode errors raised by requests derive from ValueError.
        self.logger.error(
            f"Failed to parse {self.get_provider_name()} response JSON: {e}"
        )
        return self._create_error_response(
            "malformed_response", f"JSON parsing failed: {e}", response.status_code
        )

    if not isinstance(response_data, dict):
        # Valid JSON that is not an object (list, string, null, ...) cannot
        # be handled by _process_response or by callers expecting a dict.
        found = type(response_data).__name__
        self.logger.error(
            f"Unexpected {self.get_provider_name()} response type: {found}"
        )
        return self._create_error_response(
            "malformed_response",
            f"Expected JSON object, got {found}",
            response.status_code,
        )

    self._process_response(response_data, start_time)
    return response_data
199
+
200
+ def _create_error_response(
201
+ self, error_type: str, raw_error: str, status_code: int = 0
202
+ ) -> Dict[str, Any]:
203
+ """
204
+ Create standardized error response.
205
+
206
+ Args:
207
+ error_type: Type of error
208
+ raw_error: Raw error message
209
+ status_code: HTTP status code if available
210
+
211
+ Returns:
212
+ Standardized error response dictionary
213
+ """
214
+ return {
215
+ "error": True,
216
+ "error_type": error_type,
217
+ "provider": self.get_provider_name(),
218
+ "status_code": status_code,
219
+ "raw_error": raw_error,
220
+ }
221
+
222
+ def _validate_tool_call(self, tool_call: Any, index: int) -> bool:
223
+ """
224
+ Validate a tool call structure.
225
+
226
+ Args:
227
+ tool_call: Tool call to validate
228
+ index: Index of the tool call for logging
229
+
230
+ Returns:
231
+ True if valid, False otherwise
232
+ """
233
+ try:
234
+ if not isinstance(tool_call, dict):
235
+ self.logger.warning(
236
+ f"Tool call {index + 1} is not a dictionary: {tool_call}"
237
+ )
238
+ return False
239
+
240
+ function = tool_call.get("function", {})
241
+ if not isinstance(function, dict):
242
+ self.logger.warning(
243
+ f"Tool call {index + 1} function is not a dictionary: {function}"
244
+ )
245
+ return False
246
+
247
+ tool_name = function.get("name")
248
+ if not tool_name:
249
+ self.logger.warning(
250
+ f"Tool call {index + 1} missing function name: {tool_call}"
251
+ )
252
+ return False
253
+
254
+ arguments = function.get("arguments", "{}")
255
+ if arguments and not isinstance(arguments, str):
256
+ self.logger.warning(
257
+ f"Tool call {index + 1} arguments not a string: {arguments}"
258
+ )
259
+ return False
260
+
261
+ return True
262
+ except Exception as e:
263
+ self.logger.warning(f"Error validating tool call {index + 1}: {e}")
264
+ return False
265
+
266
def classify_error(self, error: Exception, provider: str) -> str:
    """
    Classify provider errors using simple string matching.

    The lower-cased error text is tested against keyword groups in order
    and the first matching group wins. Rate-limit and authentication
    keywords are tested before the generic "malformed" group, so a message
    such as "401 Unauthorized: invalid API key" is reported as an auth
    error rather than a malformed response.

    Args:
        error: The exception that occurred
        provider: The provider name (e.g., 'openrouter', 'ollama');
            not used by the matching itself

    Returns:
        Error type string for message lookup: "rate_limit", "auth_error",
        "malformed_response", "timeout" or "general_error"
    """
    error_str = str(error).lower()

    # Ordered from most specific to most generic; first hit wins.
    rules = (
        ("rate_limit", ("rate limit", "429", "too many requests")),
        ("auth_error", ("unauthorized", "401", "authentication")),
        ("malformed_response", ("malformed", "invalid", "parse")),
        # Connection problems and refusals are treated as timeouts for
        # user messaging.
        (
            "timeout",
            (
                "timeout",
                "timed out",
                "connection",
                "network",
                "dns",
                "refused",
                "unreachable",
            ),
        ),
    )
    for error_type, keywords in rules:
        if any(keyword in error_str for keyword in keywords):
            return error_type
    return "general_error"
301
+
302
@abstractmethod
def get_provider_name(self) -> str:
    """
    Identify the provider behind this client.

    Concrete clients must override this; the base body does nothing.

    Returns:
        Provider name string
    """
    ...
311
+
312
def get_request_timeout(self) -> int:
    """
    Report how long a request may run before timing out.

    Returns:
        Timeout value in seconds (default: 30)
    """
    default_timeout_seconds = 30
    return default_timeout_seconds
@@ -2,92 +2,76 @@
2
2
  LLM client for Ollama API communication.
3
3
  """
4
4
 
5
- import json
6
- import time
7
5
  from typing import Any, Dict, List
8
6
 
9
- import requests
10
-
11
- try:
12
- from todo_agent.infrastructure.config import Config
13
- from todo_agent.infrastructure.llm_client import LLMClient
14
- from todo_agent.infrastructure.logger import Logger
15
- from todo_agent.infrastructure.token_counter import get_token_counter
16
- except ImportError:
17
- from infrastructure.config import Config # type: ignore[no-redef]
18
- from infrastructure.llm_client import LLMClient # type: ignore[no-redef]
19
- from infrastructure.logger import Logger # type: ignore[no-redef]
20
- from infrastructure.token_counter import get_token_counter # type: ignore[no-redef]
7
+ from todo_agent.infrastructure.llm_client import LLMClient
21
8
 
22
9
 
23
10
  class OllamaClient(LLMClient):
24
11
  """Ollama API client implementation."""
25
12
 
26
- def __init__(self, config: Config):
13
+ def __init__(self, config: Any) -> None:
27
14
  """
28
15
  Initialize Ollama client.
29
16
 
30
17
  Args:
31
18
  config: Configuration object
32
19
  """
33
- self.config = config
20
+ super().__init__(config, config.ollama_model, "ollama_client")
34
21
  self.base_url = config.ollama_base_url
35
- self.model = config.ollama_model
36
- self.logger = Logger("ollama_client")
37
- self.token_counter = get_token_counter(self.model)
38
-
39
- def _estimate_tokens(self, text: str) -> int:
40
- """
41
- Estimate token count for text using accurate tokenization.
42
22
 
43
- Args:
44
- text: Text to count tokens for
45
-
46
- Returns:
47
- Number of tokens
48
- """
49
- return self.token_counter.count_tokens(text)
50
-
51
- def _log_request_details(self, payload: Dict[str, Any], start_time: float) -> None:
52
- """Log request details including accurate token count."""
53
- # Count tokens for messages
54
- messages = payload.get("messages", [])
55
- tools = payload.get("tools", [])
23
+ def _get_request_headers(self) -> Dict[str, str]:
24
+ """Get request headers for Ollama API."""
25
+ return {
26
+ "Content-Type": "application/json",
27
+ }
56
28
 
57
- total_tokens = self.token_counter.count_request_tokens(messages, tools)
29
+ def _get_request_payload(
30
+ self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
31
+ ) -> Dict[str, Any]:
32
+ """Get request payload for Ollama API."""
33
+ return {
34
+ "model": self.model,
35
+ "messages": messages,
36
+ "tools": tools,
37
+ "stream": False,
38
+ }
58
39
 
59
- self.logger.info(f"Request sent - Token count: {total_tokens}")
60
- # self.logger.debug(f"Raw request payload: {json.dumps(payload, indent=2)}")
40
+ def _get_api_endpoint(self) -> str:
41
+ """Get Ollama API endpoint."""
42
+ return f"{self.base_url}/api/chat"
61
43
 
62
- def _log_response_details(
63
- self, response: Dict[str, Any], start_time: float
44
+ def _process_response(
45
+ self, response_data: Dict[str, Any], start_time: float
64
46
  ) -> None:
65
- """Log response details including latency."""
47
+ """Process and log Ollama response details."""
48
+ import time
49
+
66
50
  end_time = time.time()
67
51
  latency_ms = (end_time - start_time) * 1000
68
52
 
69
53
  self.logger.info(f"Response received - Latency: {latency_ms:.2f}ms")
70
54
 
71
55
  # Log tool call details if present
72
- if "message" in response and "tool_calls" in response["message"]:
73
- tool_calls = response["message"]["tool_calls"]
56
+ if "message" in response_data and "tool_calls" in response_data["message"]:
57
+ tool_calls = response_data["message"]["tool_calls"]
74
58
  self.logger.info(f"Response contains {len(tool_calls)} tool calls")
75
59
 
76
60
  # Log thinking content (response body) if present
77
- content = response["message"].get("content", "")
61
+ content = response_data["message"].get("content", "")
78
62
  if content and content.strip():
79
63
  self.logger.info(f"LLM thinking before tool calls: {content}")
80
64
 
81
65
  for i, tool_call in enumerate(tool_calls):
82
66
  tool_name = tool_call.get("function", {}).get("name", "unknown")
83
67
  self.logger.info(f" Tool call {i + 1}: {tool_name}")
84
- elif "message" in response and "content" in response["message"]:
85
- content = response["message"]["content"]
68
+ elif "message" in response_data and "content" in response_data["message"]:
69
+ content = response_data["message"]["content"]
86
70
  self.logger.debug(
87
71
  f"Response contains content: {content[:100]}{'...' if len(content) > 100 else ''}"
88
72
  )
89
73
 
90
- self.logger.debug(f"Raw response: {json.dumps(response, indent=2)}")
74
+ self.logger.debug(f"Raw response: {response_data}")
91
75
 
92
76
  def chat_with_tools(
93
77
  self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
@@ -102,41 +86,31 @@ class OllamaClient(LLMClient):
102
86
  Returns:
103
87
  API response dictionary
104
88
  """
105
- headers = {
106
- "Content-Type": "application/json",
107
- }
108
-
109
- payload = {
110
- "model": self.model,
111
- "messages": messages,
112
- "tools": tools,
113
- "stream": False,
114
- }
115
-
116
- start_time = time.time()
117
- self._log_request_details(payload, start_time)
118
-
119
- response = requests.post( # nosec B113
120
- f"{self.base_url}/api/chat", headers=headers, json=payload
121
- )
122
-
123
- if response.status_code != 200:
124
- self.logger.error(f"Ollama API error: {response.text}")
125
- raise Exception(f"Ollama API error: {response.text}")
126
-
127
- response_data: Dict[str, Any] = response.json()
128
- self._log_response_details(response_data, start_time)
129
-
130
- return response_data
89
+ return self._make_http_request(messages, tools)
131
90
 
132
91
  def extract_tool_calls(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
133
92
  """Extract tool calls from API response."""
93
+ # Check for provider errors first
94
+ if response.get("error", False):
95
+ self.logger.warning(
96
+ f"Cannot extract tool calls from error response: {response.get('error_type')}"
97
+ )
98
+ return []
99
+
134
100
  tool_calls = []
135
101
 
136
102
  # Ollama response format is different from OpenRouter
137
103
  if "message" in response and "tool_calls" in response["message"]:
138
- tool_calls = response["message"]["tool_calls"]
139
- self.logger.debug(f"Extracted {len(tool_calls)} tool calls from response")
104
+ raw_tool_calls = response["message"]["tool_calls"]
105
+
106
+ # Validate each tool call using common validation
107
+ for i, tool_call in enumerate(raw_tool_calls):
108
+ if self._validate_tool_call(tool_call, i):
109
+ tool_calls.append(tool_call)
110
+
111
+ self.logger.debug(
112
+ f"Extracted {len(tool_calls)} valid tool calls from {len(raw_tool_calls)} total"
113
+ )
140
114
  for i, tool_call in enumerate(tool_calls):
141
115
  tool_name = tool_call.get("function", {}).get("name", "unknown")
142
116
  tool_call_id = tool_call.get("id", "unknown")
@@ -150,6 +124,13 @@ class OllamaClient(LLMClient):
150
124
 
151
125
  def extract_content(self, response: Dict[str, Any]) -> str:
152
126
  """Extract content from API response."""
127
+ # Check for provider errors first
128
+ if response.get("error", False):
129
+ self.logger.warning(
130
+ f"Cannot extract content from error response: {response.get('error_type')}"
131
+ )
132
+ return ""
133
+
153
134
  if "message" in response and "content" in response["message"]:
154
135
  content = response["message"]["content"]
155
136
  return content if isinstance(content, str) else str(content)
@@ -163,3 +144,23 @@ class OllamaClient(LLMClient):
163
144
  Model name string
164
145
  """
165
146
  return self.model
147
+
148
def get_provider_name(self) -> str:
    """
    Identify this client's provider.

    Returns:
        The string "ollama"
    """
    provider_name = "ollama"
    return provider_name
156
+
157
def get_request_timeout(self) -> int:
    """
    Report the request timeout used for Ollama.

    Ollama can be slower than cloud providers, so a 2-minute timeout is used.

    Returns:
        Timeout value in seconds (120)
    """
    two_minutes_in_seconds = 120
    return two_minutes_in_seconds