todo-agent 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- todo_agent/_version.py +2 -2
- todo_agent/core/conversation_manager.py +1 -1
- todo_agent/core/exceptions.py +54 -3
- todo_agent/core/todo_manager.py +127 -56
- todo_agent/infrastructure/calendar_utils.py +2 -4
- todo_agent/infrastructure/inference.py +158 -52
- todo_agent/infrastructure/llm_client.py +258 -1
- todo_agent/infrastructure/ollama_client.py +77 -76
- todo_agent/infrastructure/openrouter_client.py +77 -72
- todo_agent/infrastructure/prompts/system_prompt.txt +88 -396
- todo_agent/infrastructure/todo_shell.py +37 -27
- todo_agent/interface/cli.py +129 -19
- todo_agent/interface/formatters.py +25 -0
- todo_agent/interface/progress.py +69 -0
- todo_agent/interface/tools.py +142 -23
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/METADATA +3 -3
- todo_agent-0.3.3.dist-info/RECORD +30 -0
- todo_agent-0.3.1.dist-info/RECORD +0 -29
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/WHEEL +0 -0
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/entry_points.txt +0 -0
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {todo_agent-0.3.1.dist-info → todo_agent-0.3.3.dist-info}/top_level.txt +0 -0
todo_agent/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
28
28
|
commit_id: COMMIT_ID
|
29
29
|
__commit_id__: COMMIT_ID
|
30
30
|
|
31
|
-
__version__ = version = '0.3.
|
32
|
-
__version_tuple__ = version_tuple = (0, 3,
|
31
|
+
__version__ = version = '0.3.3'
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 3)
|
33
33
|
|
34
34
|
__commit_id__ = commit_id = None
|
@@ -35,7 +35,7 @@ class ConversationManager:
|
|
35
35
|
"""Manages conversation state and memory for LLM interactions."""
|
36
36
|
|
37
37
|
def __init__(
|
38
|
-
self, max_tokens: int =
|
38
|
+
self, max_tokens: int = 64000, max_messages: int = 100, model: str = "gpt-4"
|
39
39
|
):
|
40
40
|
self.history: List[ConversationMessage] = []
|
41
41
|
self.max_tokens = max_tokens
|
todo_agent/core/exceptions.py
CHANGED
@@ -12,16 +12,67 @@ class TodoError(Exception):
|
|
12
12
|
class TaskNotFoundError(TodoError):
|
13
13
|
"""Task not found in todo file."""
|
14
14
|
|
15
|
-
|
15
|
+
def __init__(self, message: str = "Task not found"):
|
16
|
+
super().__init__(message)
|
17
|
+
self.message = message
|
16
18
|
|
17
19
|
|
18
20
|
class InvalidTaskFormatError(TodoError):
|
19
21
|
"""Invalid task format."""
|
20
22
|
|
21
|
-
|
23
|
+
def __init__(self, message: str = "Invalid task format"):
|
24
|
+
super().__init__(message)
|
25
|
+
self.message = message
|
22
26
|
|
23
27
|
|
24
28
|
class TodoShellError(TodoError):
|
25
29
|
"""Subprocess execution error."""
|
26
30
|
|
27
|
-
|
31
|
+
def __init__(self, message: str = "Todo.sh command failed"):
|
32
|
+
super().__init__(message)
|
33
|
+
self.message = message
|
34
|
+
|
35
|
+
|
36
|
+
class ProviderError(Exception):
|
37
|
+
"""Base exception for LLM provider errors."""
|
38
|
+
|
39
|
+
def __init__(self, message: str, error_type: str, provider: str):
|
40
|
+
super().__init__(message)
|
41
|
+
self.message = message
|
42
|
+
self.error_type = error_type
|
43
|
+
self.provider = provider
|
44
|
+
|
45
|
+
|
46
|
+
class MalformedResponseError(ProviderError):
|
47
|
+
"""Provider returned malformed or invalid response."""
|
48
|
+
|
49
|
+
def __init__(self, message: str, provider: str):
|
50
|
+
super().__init__(message, "malformed_response", provider)
|
51
|
+
|
52
|
+
|
53
|
+
class RateLimitError(ProviderError):
|
54
|
+
"""Provider rate limit exceeded."""
|
55
|
+
|
56
|
+
def __init__(self, message: str, provider: str):
|
57
|
+
super().__init__(message, "rate_limit", provider)
|
58
|
+
|
59
|
+
|
60
|
+
class AuthenticationError(ProviderError):
|
61
|
+
"""Provider authentication failed."""
|
62
|
+
|
63
|
+
def __init__(self, message: str, provider: str):
|
64
|
+
super().__init__(message, "auth_error", provider)
|
65
|
+
|
66
|
+
|
67
|
+
class TimeoutError(ProviderError):
|
68
|
+
"""Provider request timed out."""
|
69
|
+
|
70
|
+
def __init__(self, message: str, provider: str):
|
71
|
+
super().__init__(message, "timeout", provider)
|
72
|
+
|
73
|
+
|
74
|
+
class GeneralProviderError(ProviderError):
|
75
|
+
"""General provider error."""
|
76
|
+
|
77
|
+
def __init__(self, message: str, provider: str):
|
78
|
+
super().__init__(message, "general_error", provider)
|
todo_agent/core/todo_manager.py
CHANGED
@@ -19,7 +19,6 @@ class TodoManager:
|
|
19
19
|
project: Optional[str] = None,
|
20
20
|
context: Optional[str] = None,
|
21
21
|
due: Optional[str] = None,
|
22
|
-
recurring: Optional[str] = None,
|
23
22
|
duration: Optional[str] = None,
|
24
23
|
) -> str:
|
25
24
|
"""Add new task with explicit project/context parameters."""
|
@@ -56,33 +55,6 @@ class TodoManager:
|
|
56
55
|
f"Invalid due date format '{due}'. Must be YYYY-MM-DD."
|
57
56
|
)
|
58
57
|
|
59
|
-
if recurring:
|
60
|
-
# Validate recurring format
|
61
|
-
if not recurring.startswith("rec:"):
|
62
|
-
raise ValueError(
|
63
|
-
f"Invalid recurring format '{recurring}'. Must start with 'rec:'."
|
64
|
-
)
|
65
|
-
# Basic validation of recurring syntax
|
66
|
-
parts = recurring.split(":")
|
67
|
-
if len(parts) < 2 or len(parts) > 3:
|
68
|
-
raise ValueError(
|
69
|
-
f"Invalid recurring format '{recurring}'. Expected 'rec:frequency' or 'rec:frequency:interval'."
|
70
|
-
)
|
71
|
-
frequency = parts[1]
|
72
|
-
if frequency not in ["daily", "weekly", "monthly", "yearly"]:
|
73
|
-
raise ValueError(
|
74
|
-
f"Invalid frequency '{frequency}'. Must be one of: daily, weekly, monthly, yearly."
|
75
|
-
)
|
76
|
-
if len(parts) == 3:
|
77
|
-
try:
|
78
|
-
interval = int(parts[2])
|
79
|
-
if interval < 1:
|
80
|
-
raise ValueError("Interval must be at least 1.")
|
81
|
-
except ValueError:
|
82
|
-
raise ValueError(
|
83
|
-
f"Invalid interval '{parts[2]}'. Must be a positive integer."
|
84
|
-
)
|
85
|
-
|
86
58
|
if duration is not None:
|
87
59
|
# Validate duration format (e.g., "30m", "2h", "1d")
|
88
60
|
if not duration or not isinstance(duration, str):
|
@@ -124,18 +96,17 @@ class TodoManager:
|
|
124
96
|
if due:
|
125
97
|
full_description = f"{full_description} due:{due}"
|
126
98
|
|
127
|
-
if recurring:
|
128
|
-
full_description = f"{full_description} {recurring}"
|
129
|
-
|
130
99
|
if duration:
|
131
100
|
full_description = f"{full_description} duration:{duration}"
|
132
101
|
|
133
102
|
self.todo_shell.add(full_description)
|
134
103
|
return f"Added task: {full_description}"
|
135
104
|
|
136
|
-
def list_tasks(
|
105
|
+
def list_tasks(
|
106
|
+
self, filter: Optional[str] = None, suppress_color: bool = True
|
107
|
+
) -> str:
|
137
108
|
"""List tasks with optional filtering."""
|
138
|
-
result = self.todo_shell.list_tasks(filter)
|
109
|
+
result = self.todo_shell.list_tasks(filter, suppress_color=suppress_color)
|
139
110
|
if not result.strip():
|
140
111
|
return "No tasks found."
|
141
112
|
|
@@ -148,24 +119,6 @@ class TodoManager:
|
|
148
119
|
result = self.todo_shell.complete(task_number)
|
149
120
|
return f"Completed task {task_number}: {result}"
|
150
121
|
|
151
|
-
def get_overview(self, **kwargs: Any) -> str:
|
152
|
-
"""Show current task statistics."""
|
153
|
-
tasks = self.todo_shell.list_tasks()
|
154
|
-
completed = self.todo_shell.list_completed()
|
155
|
-
|
156
|
-
task_count = (
|
157
|
-
len([line for line in tasks.split("\n") if line.strip()])
|
158
|
-
if tasks.strip()
|
159
|
-
else 0
|
160
|
-
)
|
161
|
-
completed_count = (
|
162
|
-
len([line for line in completed.split("\n") if line.strip()])
|
163
|
-
if completed.strip()
|
164
|
-
else 0
|
165
|
-
)
|
166
|
-
|
167
|
-
return f"Task Overview:\n- Active tasks: {task_count}\n- Completed tasks: {completed_count}"
|
168
|
-
|
169
122
|
def replace_task(self, task_number: int, new_description: str) -> str:
|
170
123
|
"""Replace entire task content."""
|
171
124
|
result = self.todo_shell.replace(task_number, new_description)
|
@@ -307,16 +260,16 @@ class TodoManager:
|
|
307
260
|
operation_desc = ", ".join(operations)
|
308
261
|
return f"Updated projects for task {task_number} ({operation_desc}): {result}"
|
309
262
|
|
310
|
-
def list_projects(self, **kwargs: Any) -> str:
|
263
|
+
def list_projects(self, suppress_color: bool = True, **kwargs: Any) -> str:
|
311
264
|
"""List all available projects in todo.txt."""
|
312
|
-
result = self.todo_shell.list_projects()
|
265
|
+
result = self.todo_shell.list_projects(suppress_color=suppress_color)
|
313
266
|
if not result.strip():
|
314
267
|
return "No projects found."
|
315
268
|
return result
|
316
269
|
|
317
|
-
def list_contexts(self, **kwargs: Any) -> str:
|
270
|
+
def list_contexts(self, suppress_color: bool = True, **kwargs: Any) -> str:
|
318
271
|
"""List all available contexts in todo.txt."""
|
319
|
-
result = self.todo_shell.list_contexts()
|
272
|
+
result = self.todo_shell.list_contexts(suppress_color=suppress_color)
|
320
273
|
if not result.strip():
|
321
274
|
return "No contexts found."
|
322
275
|
return result
|
@@ -329,6 +282,7 @@ class TodoManager:
|
|
329
282
|
text_search: Optional[str] = None,
|
330
283
|
date_from: Optional[str] = None,
|
331
284
|
date_to: Optional[str] = None,
|
285
|
+
suppress_color: bool = True,
|
332
286
|
**kwargs: Any,
|
333
287
|
) -> str:
|
334
288
|
"""List completed tasks with optional filtering.
|
@@ -378,7 +332,9 @@ class TodoManager:
|
|
378
332
|
# Combine all filters
|
379
333
|
combined_filter = " ".join(filter_parts) if filter_parts else None
|
380
334
|
|
381
|
-
result = self.todo_shell.list_completed(
|
335
|
+
result = self.todo_shell.list_completed(
|
336
|
+
combined_filter, suppress_color=suppress_color
|
337
|
+
)
|
382
338
|
if not result.strip():
|
383
339
|
return "No completed tasks found matching the criteria."
|
384
340
|
return result
|
@@ -406,3 +362,118 @@ class TodoManager:
|
|
406
362
|
week_number = now.isocalendar()[1]
|
407
363
|
timezone = now.astimezone().tzinfo
|
408
364
|
return f"Current date and time: {now.strftime('%Y-%m-%d %H:%M:%S')} {timezone} ({now.strftime('%A, %B %d, %Y at %I:%M %p')}) - Week {week_number}"
|
365
|
+
|
366
|
+
def created_completed_task(
|
367
|
+
self,
|
368
|
+
description: str,
|
369
|
+
completion_date: Optional[str] = None,
|
370
|
+
project: Optional[str] = None,
|
371
|
+
context: Optional[str] = None,
|
372
|
+
) -> str:
|
373
|
+
"""
|
374
|
+
Create a task and immediately mark it as completed.
|
375
|
+
|
376
|
+
This is a convenience method for handling "I did X on [date]" statements.
|
377
|
+
The task is created with the specified completion date and immediately marked complete.
|
378
|
+
|
379
|
+
Args:
|
380
|
+
description: The task description of what was completed
|
381
|
+
completion_date: Completion date in YYYY-MM-DD format (defaults to today)
|
382
|
+
project: Optional project name (without the + symbol)
|
383
|
+
context: Optional context name (without the @ symbol)
|
384
|
+
|
385
|
+
Returns:
|
386
|
+
Confirmation message with the completed task details
|
387
|
+
"""
|
388
|
+
# Set default completion date to today if not provided
|
389
|
+
if not completion_date:
|
390
|
+
completion_date = datetime.now().strftime("%Y-%m-%d")
|
391
|
+
|
392
|
+
# Validate completion date format
|
393
|
+
try:
|
394
|
+
datetime.strptime(completion_date, "%Y-%m-%d")
|
395
|
+
except ValueError:
|
396
|
+
raise ValueError(
|
397
|
+
f"Invalid completion date format '{completion_date}'. Must be YYYY-MM-DD."
|
398
|
+
)
|
399
|
+
|
400
|
+
# Build the task description with project and context
|
401
|
+
full_description = description
|
402
|
+
|
403
|
+
if project:
|
404
|
+
# Remove any existing + symbols to prevent duplication
|
405
|
+
clean_project = project.strip().lstrip("+")
|
406
|
+
if not clean_project:
|
407
|
+
raise ValueError(
|
408
|
+
"Project name cannot be empty after removing + symbol."
|
409
|
+
)
|
410
|
+
full_description = f"{full_description} +{clean_project}"
|
411
|
+
|
412
|
+
if context:
|
413
|
+
# Remove any existing @ symbols to prevent duplication
|
414
|
+
clean_context = context.strip().lstrip("@")
|
415
|
+
if not clean_context:
|
416
|
+
raise ValueError(
|
417
|
+
"Context name cannot be empty after removing @ symbol."
|
418
|
+
)
|
419
|
+
full_description = f"{full_description} @{clean_context}"
|
420
|
+
|
421
|
+
# Add the task first
|
422
|
+
self.todo_shell.add(full_description)
|
423
|
+
|
424
|
+
# Get the task number by finding the newly added task
|
425
|
+
tasks = self.todo_shell.list_tasks()
|
426
|
+
task_lines = [line.strip() for line in tasks.split("\n") if line.strip()]
|
427
|
+
if not task_lines:
|
428
|
+
raise RuntimeError("Failed to add task - no tasks found after addition")
|
429
|
+
|
430
|
+
# Find the task that matches our description (it should be the last one added)
|
431
|
+
# Look for the task that contains our description
|
432
|
+
task_number = None
|
433
|
+
for i, line in enumerate(task_lines, 1): # Start from 1 for todo.sh numbering
|
434
|
+
if description in line:
|
435
|
+
task_number = i
|
436
|
+
break
|
437
|
+
|
438
|
+
if task_number is None:
|
439
|
+
# Fallback: use the last task number if we can't find a match
|
440
|
+
task_number = len(task_lines)
|
441
|
+
# Log a warning that we're using fallback logic
|
442
|
+
import logging
|
443
|
+
|
444
|
+
logging.warning(
|
445
|
+
f"Could not find exact match for '{description}', using fallback task number {task_number}"
|
446
|
+
)
|
447
|
+
|
448
|
+
# Mark it as complete
|
449
|
+
self.todo_shell.complete(task_number)
|
450
|
+
|
451
|
+
return f"Created and completed task: {full_description} (completed on {completion_date})"
|
452
|
+
|
453
|
+
def restore_completed_task(self, task_number: int) -> str:
|
454
|
+
"""
|
455
|
+
Restore a completed task from done.txt back to todo.txt.
|
456
|
+
|
457
|
+
This method moves a completed task from done.txt back to todo.txt,
|
458
|
+
effectively restoring it to active status.
|
459
|
+
|
460
|
+
Args:
|
461
|
+
task_number: The line number of the completed task in done.txt to restore
|
462
|
+
|
463
|
+
Returns:
|
464
|
+
Confirmation message with the restored task details
|
465
|
+
"""
|
466
|
+
# Validate task number
|
467
|
+
if task_number <= 0:
|
468
|
+
raise ValueError("Task number must be a positive integer")
|
469
|
+
|
470
|
+
# Use the move command to restore the task from done.txt to todo.txt
|
471
|
+
result = self.todo_shell.move(task_number, "todo.txt", "done.txt")
|
472
|
+
|
473
|
+
# Extract the task description from the result for confirmation
|
474
|
+
# The result format is typically: "TODO: X moved from '.../done.txt' to '.../todo.txt'."
|
475
|
+
if "moved from" in result and "to" in result:
|
476
|
+
# Try to extract the task description if possible
|
477
|
+
return f"Restored completed task {task_number} to active status: {result}"
|
478
|
+
else:
|
479
|
+
return f"Restored completed task {task_number} to active status"
|
@@ -15,10 +15,8 @@ def get_calendar_output() -> str:
|
|
15
15
|
Formatted calendar string showing three months side by side
|
16
16
|
"""
|
17
17
|
try:
|
18
|
-
# Use cal
|
19
|
-
result = subprocess.run(
|
20
|
-
["cal", "-3"], capture_output=True, text=True, check=True
|
21
|
-
)
|
18
|
+
# Use cal to get current month calendar
|
19
|
+
result = subprocess.run(["cal"], capture_output=True, text=True, check=True)
|
22
20
|
return result.stdout.strip()
|
23
21
|
except (subprocess.SubprocessError, FileNotFoundError):
|
24
22
|
# Fallback to Python calendar module
|
@@ -12,6 +12,10 @@ try:
|
|
12
12
|
from todo_agent.infrastructure.config import Config
|
13
13
|
from todo_agent.infrastructure.llm_client_factory import LLMClientFactory
|
14
14
|
from todo_agent.infrastructure.logger import Logger
|
15
|
+
from todo_agent.interface.formatters import (
|
16
|
+
get_provider_error_message as _get_error_msg,
|
17
|
+
)
|
18
|
+
from todo_agent.interface.progress import NoOpProgress, ToolCallProgress
|
15
19
|
from todo_agent.interface.tools import ToolCallHandler
|
16
20
|
except ImportError:
|
17
21
|
from core.conversation_manager import ( # type: ignore[no-redef]
|
@@ -23,6 +27,13 @@ except ImportError:
|
|
23
27
|
LLMClientFactory,
|
24
28
|
)
|
25
29
|
from infrastructure.logger import Logger # type: ignore[no-redef]
|
30
|
+
from interface.formatters import ( # type: ignore[no-redef]
|
31
|
+
get_provider_error_message as _get_error_msg,
|
32
|
+
)
|
33
|
+
from interface.progress import ( # type: ignore[no-redef]
|
34
|
+
NoOpProgress,
|
35
|
+
ToolCallProgress,
|
36
|
+
)
|
26
37
|
from interface.tools import ToolCallHandler # type: ignore[no-redef]
|
27
38
|
|
28
39
|
|
@@ -66,11 +77,29 @@ class Inference:
|
|
66
77
|
self.conversation_manager.set_system_prompt(system_prompt)
|
67
78
|
self.logger.debug("System prompt loaded and set")
|
68
79
|
|
80
|
+
def current_tasks(self) -> str:
|
81
|
+
"""
|
82
|
+
Get current tasks from the todo manager.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
Formatted string of current tasks or error message
|
86
|
+
"""
|
87
|
+
try:
|
88
|
+
# Use the todo manager from the tool handler to get current tasks
|
89
|
+
tasks = self.tool_handler.todo_manager.list_tasks(suppress_color=True)
|
90
|
+
|
91
|
+
# If no tasks found, return a clear message
|
92
|
+
if not tasks.strip() or tasks == "No tasks found.":
|
93
|
+
return "No current tasks found."
|
94
|
+
|
95
|
+
return tasks
|
96
|
+
|
97
|
+
except Exception as e:
|
98
|
+
self.logger.warning(f"Failed to get current tasks: {e!s}")
|
99
|
+
return f"Error retrieving current tasks: {e!s}"
|
100
|
+
|
69
101
|
def _load_system_prompt(self) -> str:
|
70
102
|
"""Load and format the system prompt from file."""
|
71
|
-
# Generate tools section programmatically
|
72
|
-
tools_section = self._generate_tools_section()
|
73
|
-
|
74
103
|
# Get current datetime for interpolation
|
75
104
|
now = datetime.now()
|
76
105
|
timezone_info = time.tzname[time.daylight]
|
@@ -85,6 +114,9 @@ class Inference:
|
|
85
114
|
self.logger.warning(f"Failed to get calendar output: {e!s}")
|
86
115
|
calendar_output = "Calendar unavailable"
|
87
116
|
|
117
|
+
# Get current tasks
|
118
|
+
current_tasks = self.current_tasks()
|
119
|
+
|
88
120
|
# Load system prompt from file
|
89
121
|
prompt_file_path = os.path.join(
|
90
122
|
os.path.dirname(__file__), "prompts", "system_prompt.txt"
|
@@ -94,11 +126,11 @@ class Inference:
|
|
94
126
|
with open(prompt_file_path, encoding="utf-8") as f:
|
95
127
|
system_prompt_template = f.read()
|
96
128
|
|
97
|
-
# Format the template with
|
129
|
+
# Format the template with current datetime, calendar, and current tasks
|
98
130
|
return system_prompt_template.format(
|
99
|
-
tools_section=tools_section,
|
100
131
|
current_datetime=current_datetime,
|
101
132
|
calendar_output=calendar_output,
|
133
|
+
current_tasks=current_tasks,
|
102
134
|
)
|
103
135
|
|
104
136
|
except FileNotFoundError:
|
@@ -108,58 +140,15 @@ class Inference:
|
|
108
140
|
self.logger.error(f"Error loading system prompt: {e!s}")
|
109
141
|
raise
|
110
142
|
|
111
|
-
def
|
112
|
-
|
113
|
-
|
114
|
-
"Discovery Tools": [
|
115
|
-
"list_projects",
|
116
|
-
"list_contexts",
|
117
|
-
"list_tasks",
|
118
|
-
"list_completed_tasks",
|
119
|
-
],
|
120
|
-
"Modification Tools": [
|
121
|
-
"add_task",
|
122
|
-
"complete_task",
|
123
|
-
"replace_task",
|
124
|
-
"append_to_task",
|
125
|
-
"prepend_to_task",
|
126
|
-
],
|
127
|
-
"Management Tools": [
|
128
|
-
"delete_task",
|
129
|
-
"set_priority",
|
130
|
-
"remove_priority",
|
131
|
-
"move_task",
|
132
|
-
],
|
133
|
-
"Maintenance Tools": ["archive_tasks", "deduplicate_tasks", "get_overview"],
|
134
|
-
}
|
135
|
-
|
136
|
-
tools_section = []
|
137
|
-
for category, tool_names in tool_categories.items():
|
138
|
-
tools_section.append(f"\n**{category}:**")
|
139
|
-
for tool_name in tool_names:
|
140
|
-
tool_info = next(
|
141
|
-
(
|
142
|
-
t
|
143
|
-
for t in self.tool_handler.tools
|
144
|
-
if t["function"]["name"] == tool_name
|
145
|
-
),
|
146
|
-
None,
|
147
|
-
)
|
148
|
-
if tool_info:
|
149
|
-
# Get first sentence of description for concise overview
|
150
|
-
first_sentence = (
|
151
|
-
tool_info["function"]["description"].split(".")[0] + "."
|
152
|
-
)
|
153
|
-
tools_section.append(f"- {tool_name}(): {first_sentence}")
|
154
|
-
|
155
|
-
return "\n".join(tools_section)
|
156
|
-
|
157
|
-
def process_request(self, user_input: str) -> tuple[str, float]:
|
143
|
+
def process_request(
|
144
|
+
self, user_input: str, progress_callback: Optional[ToolCallProgress] = None
|
145
|
+
) -> tuple[str, float]:
|
158
146
|
"""
|
159
147
|
Process a user request through the LLM with tool orchestration.
|
160
148
|
|
161
149
|
Args:
|
162
150
|
user_input: Natural language user request
|
151
|
+
progress_callback: Optional progress callback for tool call tracking
|
163
152
|
|
164
153
|
Returns:
|
165
154
|
Tuple of (formatted response for user, thinking time in seconds)
|
@@ -167,6 +156,13 @@ class Inference:
|
|
167
156
|
# Start timing the request
|
168
157
|
start_time = time.time()
|
169
158
|
|
159
|
+
# Initialize progress callback if not provided
|
160
|
+
if progress_callback is None:
|
161
|
+
progress_callback = NoOpProgress()
|
162
|
+
|
163
|
+
# Notify progress callback that thinking has started
|
164
|
+
progress_callback.on_thinking_start()
|
165
|
+
|
170
166
|
try:
|
171
167
|
self.logger.debug(
|
172
168
|
f"Starting request processing for: {user_input[:30]}{'...' if len(user_input) > 30 else ''}"
|
@@ -188,6 +184,26 @@ class Inference:
|
|
188
184
|
messages=messages, tools=self.tool_handler.tools
|
189
185
|
)
|
190
186
|
|
187
|
+
# Check for provider errors
|
188
|
+
if response.get("error", False):
|
189
|
+
error_type = response.get("error_type", "general_error")
|
190
|
+
provider = response.get("provider", "unknown")
|
191
|
+
self.logger.error(f"Provider error from {provider}: {error_type}")
|
192
|
+
|
193
|
+
error_message = _get_error_msg(error_type)
|
194
|
+
|
195
|
+
# Add error message to conversation
|
196
|
+
self.conversation_manager.add_message(
|
197
|
+
MessageRole.ASSISTANT, error_message
|
198
|
+
)
|
199
|
+
|
200
|
+
# Calculate thinking time and return
|
201
|
+
end_time = time.time()
|
202
|
+
thinking_time = end_time - start_time
|
203
|
+
progress_callback.on_thinking_complete(thinking_time)
|
204
|
+
|
205
|
+
return error_message, thinking_time
|
206
|
+
|
191
207
|
# Extract actual token usage from API response
|
192
208
|
usage = response.get("usage", {})
|
193
209
|
actual_prompt_tokens = usage.get("prompt_tokens", 0)
|
@@ -202,6 +218,7 @@ class Inference:
|
|
202
218
|
|
203
219
|
# Handle multiple tool calls in sequence
|
204
220
|
tool_call_count = 0
|
221
|
+
|
205
222
|
while True:
|
206
223
|
tool_calls = self.llm_client.extract_tool_calls(response)
|
207
224
|
|
@@ -213,6 +230,11 @@ class Inference:
|
|
213
230
|
f"Executing tool call sequence #{tool_call_count} with {len(tool_calls)} tools"
|
214
231
|
)
|
215
232
|
|
233
|
+
# Notify progress callback of sequence start
|
234
|
+
progress_callback.on_sequence_complete(
|
235
|
+
tool_call_count, 0
|
236
|
+
) # We don't know total yet
|
237
|
+
|
216
238
|
# Execute all tool calls and collect results
|
217
239
|
tool_results = []
|
218
240
|
for i, tool_call in enumerate(tool_calls):
|
@@ -225,6 +247,19 @@ class Inference:
|
|
225
247
|
self.logger.debug(f"Tool Call ID: {tool_call_id}")
|
226
248
|
self.logger.debug(f"Raw tool call: {tool_call}")
|
227
249
|
|
250
|
+
# Get progress description for the tool
|
251
|
+
progress_description = self._get_tool_progress_description(
|
252
|
+
tool_name, tool_call
|
253
|
+
)
|
254
|
+
|
255
|
+
# Notify progress callback of tool call start
|
256
|
+
progress_callback.on_tool_call_start(
|
257
|
+
tool_name,
|
258
|
+
progress_description,
|
259
|
+
tool_call_count,
|
260
|
+
0, # We don't know total yet
|
261
|
+
)
|
262
|
+
|
228
263
|
result = self.tool_handler.execute_tool(tool_call)
|
229
264
|
|
230
265
|
# Log tool execution result (success or error)
|
@@ -250,6 +285,28 @@ class Inference:
|
|
250
285
|
messages=messages, tools=self.tool_handler.tools
|
251
286
|
)
|
252
287
|
|
288
|
+
# Check for provider errors in continuation
|
289
|
+
if response.get("error", False):
|
290
|
+
error_type = response.get("error_type", "general_error")
|
291
|
+
provider = response.get("provider", "unknown")
|
292
|
+
self.logger.error(
|
293
|
+
f"Provider error in continuation from {provider}: {error_type}"
|
294
|
+
)
|
295
|
+
|
296
|
+
error_message = _get_error_msg(error_type)
|
297
|
+
|
298
|
+
# Add error message to conversation
|
299
|
+
self.conversation_manager.add_message(
|
300
|
+
MessageRole.ASSISTANT, error_message
|
301
|
+
)
|
302
|
+
|
303
|
+
# Calculate thinking time and return
|
304
|
+
end_time = time.time()
|
305
|
+
thinking_time = end_time - start_time
|
306
|
+
progress_callback.on_thinking_complete(thinking_time)
|
307
|
+
|
308
|
+
return error_message, thinking_time
|
309
|
+
|
253
310
|
# Update with actual tokens from subsequent API calls
|
254
311
|
usage = response.get("usage", {})
|
255
312
|
actual_prompt_tokens = usage.get("prompt_tokens", 0)
|
@@ -262,6 +319,9 @@ class Inference:
|
|
262
319
|
end_time = time.time()
|
263
320
|
thinking_time = end_time - start_time
|
264
321
|
|
322
|
+
# Notify progress callback that thinking is complete
|
323
|
+
progress_callback.on_thinking_complete(thinking_time)
|
324
|
+
|
265
325
|
# Add final assistant response to conversation with thinking time
|
266
326
|
final_content = self.llm_client.extract_content(response)
|
267
327
|
self.conversation_manager.add_message(
|
@@ -295,6 +355,52 @@ class Inference:
|
|
295
355
|
self.tool_handler.tools
|
296
356
|
)
|
297
357
|
|
358
|
+
def _get_tool_progress_description(self, tool_name: str, tool_call: Dict[str, Any]) -> str:
|
359
|
+
"""
|
360
|
+
Get user-friendly progress description for a tool with parameter interpolation.
|
361
|
+
|
362
|
+
Args:
|
363
|
+
tool_name: Name of the tool
|
364
|
+
tool_call: The tool call dictionary containing parameters
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
Progress description string with interpolated parameters
|
368
|
+
"""
|
369
|
+
tool_def = next(
|
370
|
+
(
|
371
|
+
t
|
372
|
+
for t in self.tool_handler.tools
|
373
|
+
if t.get("function", {}).get("name") == tool_name
|
374
|
+
),
|
375
|
+
None,
|
376
|
+
)
|
377
|
+
|
378
|
+
if tool_def and "progress_description" in tool_def:
|
379
|
+
template = tool_def["progress_description"]
|
380
|
+
|
381
|
+
# Extract arguments from tool call
|
382
|
+
arguments = tool_call.get("function", {}).get("arguments", {})
|
383
|
+
if isinstance(arguments, str):
|
384
|
+
import json
|
385
|
+
try:
|
386
|
+
arguments = json.loads(arguments)
|
387
|
+
except json.JSONDecodeError:
|
388
|
+
arguments = {}
|
389
|
+
|
390
|
+
# Use .format() like the system prompt does
|
391
|
+
try:
|
392
|
+
return template.format(**arguments)
|
393
|
+
except KeyError as e:
|
394
|
+
# If a required parameter is missing, fall back to template
|
395
|
+
self.logger.warning(f"Missing parameter {e} for progress description of {tool_name}")
|
396
|
+
return template
|
397
|
+
except Exception as e:
|
398
|
+
self.logger.warning(f"Failed to interpolate progress description for {tool_name}: {e}")
|
399
|
+
return template
|
400
|
+
|
401
|
+
# Fallback to generic description
|
402
|
+
return f"🔧 Executing {tool_name}..."
|
403
|
+
|
298
404
|
def clear_conversation(self) -> None:
|
299
405
|
"""Clear conversation history."""
|
300
406
|
self.conversation_manager.clear_conversation()
|