todo-agent 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- todo_agent/__init__.py +14 -0
- todo_agent/_version.py +34 -0
- todo_agent/core/__init__.py +16 -0
- todo_agent/core/conversation_manager.py +310 -0
- todo_agent/core/exceptions.py +27 -0
- todo_agent/core/todo_manager.py +194 -0
- todo_agent/infrastructure/__init__.py +11 -0
- todo_agent/infrastructure/config.py +59 -0
- todo_agent/infrastructure/inference.py +221 -0
- todo_agent/infrastructure/llm_client.py +62 -0
- todo_agent/infrastructure/llm_client_factory.py +48 -0
- todo_agent/infrastructure/logger.py +128 -0
- todo_agent/infrastructure/ollama_client.py +152 -0
- todo_agent/infrastructure/openrouter_client.py +173 -0
- todo_agent/infrastructure/prompts/system_prompt.txt +51 -0
- todo_agent/infrastructure/todo_shell.py +151 -0
- todo_agent/infrastructure/token_counter.py +184 -0
- todo_agent/interface/__init__.py +10 -0
- todo_agent/interface/cli.py +210 -0
- todo_agent/interface/tools.py +578 -0
- todo_agent/main.py +54 -0
- todo_agent-0.1.0.dist-info/METADATA +282 -0
- todo_agent-0.1.0.dist-info/RECORD +27 -0
- todo_agent-0.1.0.dist-info/WHEEL +5 -0
- todo_agent-0.1.0.dist-info/entry_points.txt +2 -0
- todo_agent-0.1.0.dist-info/licenses/LICENSE +674 -0
- todo_agent-0.1.0.dist-info/top_level.txt +1 -0
todo_agent/infrastructure/inference.py
@@ -0,0 +1,221 @@
+"""
+LLM inference engine for todo.sh agent.
+"""
+
+import os
+import time
+from typing import Any, Dict, List, Optional
+
+try:
+    from todo_agent.infrastructure.config import Config
+    from todo_agent.infrastructure.llm_client_factory import LLMClientFactory
+    from todo_agent.infrastructure.logger import Logger
+    from todo_agent.core.conversation_manager import ConversationManager, MessageRole
+    from todo_agent.interface.tools import ToolCallHandler
+except ImportError:
+    from infrastructure.config import Config
+    from infrastructure.llm_client_factory import LLMClientFactory
+    from infrastructure.logger import Logger
+    from core.conversation_manager import ConversationManager, MessageRole
+    from interface.tools import ToolCallHandler
+
+
+class Inference:
+    """LLM inference engine that orchestrates tool calling and conversation management."""
+
+    def __init__(self, config: Config, tool_handler: ToolCallHandler, logger: Optional[Logger] = None):
+        """
+        Initialize the inference engine.
+
+        Args:
+            config: Configuration object
+            tool_handler: Tool call handler for executing tools
+            logger: Optional logger instance
+        """
+        self.config = config
+        self.tool_handler = tool_handler
+        self.logger = logger or Logger("inference")
+
+        # Initialize LLM client using factory
+        self.llm_client = LLMClientFactory.create_client(config, self.logger)
+
+        # Initialize conversation manager
+        self.conversation_manager = ConversationManager()
+
+        # Set up system prompt
+        self._setup_system_prompt()
+
+        self.logger.info(f"Inference engine initialized with {config.provider} provider using model: {self.llm_client.get_model_name()}")
+
+    def _setup_system_prompt(self) -> None:
+        """Set up the system prompt for the LLM."""
+        system_prompt = self._load_system_prompt()
+        self.conversation_manager.set_system_prompt(system_prompt)
+        self.logger.debug("System prompt loaded and set")
+
+    def _load_system_prompt(self) -> str:
+        """Load and format the system prompt from file."""
+        # Generate tools section programmatically
+        tools_section = self._generate_tools_section()
+
+        # Get current datetime for interpolation
+        from datetime import datetime
+        current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+        # Load system prompt from file
+        prompt_file_path = os.path.join(os.path.dirname(__file__), "prompts", "system_prompt.txt")
+
+        try:
+            with open(prompt_file_path, 'r', encoding='utf-8') as f:
+                system_prompt_template = f.read()
+
+            # Format the template with the tools section and current datetime
+            return system_prompt_template.format(
+                tools_section=tools_section,
+                current_datetime=current_datetime
+            )
+
+        except FileNotFoundError:
+            self.logger.error(f"System prompt file not found: {prompt_file_path}")
+            raise
+        except Exception as e:
+            self.logger.error(f"Error loading system prompt: {str(e)}")
+            raise
+
+    def _generate_tools_section(self) -> str:
+        """Generate the AVAILABLE TOOLS section with strategic categorization."""
+        tool_categories = {
+            "Discovery Tools": ["list_projects", "list_contexts", "list_tasks", "list_completed_tasks"],
+            "Modification Tools": ["add_task", "complete_task", "replace_task", "append_to_task", "prepend_to_task"],
+            "Management Tools": ["delete_task", "set_priority", "remove_priority", "move_task"],
+            "Maintenance Tools": ["archive_tasks", "deduplicate_tasks", "get_overview"]
+        }
+
+        tools_section = []
+        for category, tool_names in tool_categories.items():
+            tools_section.append(f"\n**{category}:**")
+            for tool_name in tool_names:
+                tool_info = next((t for t in self.tool_handler.tools if t["function"]["name"] == tool_name), None)
+                if tool_info:
+                    # Get first sentence of description for concise overview
+                    first_sentence = tool_info["function"]["description"].split('.')[0] + '.'
+                    tools_section.append(f"- {tool_name}(): {first_sentence}")
+
+        return '\n'.join(tools_section)
+
+    def process_request(self, user_input: str) -> tuple[str, float]:
+        """
+        Process a user request through the LLM with tool orchestration.
+
+        Args:
+            user_input: Natural language user request
+
+        Returns:
+            Tuple of (formatted response for user, thinking time in seconds)
+        """
+        # Start timing the request
+        start_time = time.time()
+
+        try:
+            self.logger.debug(f"Starting request processing for: {user_input[:30]}{'...' if len(user_input) > 30 else ''}")
+
+            # Add user message to conversation
+            self.conversation_manager.add_message(MessageRole.USER, user_input)
+            self.logger.debug("Added user message to conversation")
+
+            # Get conversation history for LLM
+            messages = self.conversation_manager.get_messages()
+            self.logger.debug(f"Retrieved {len(messages)} messages from conversation history")
+
+            # Send to LLM with function calling enabled
+            self.logger.debug("Sending request to LLM with tools")
+            response = self.llm_client.chat_with_tools(
+                messages=messages, tools=self.tool_handler.tools
+            )
+
+            # Handle multiple tool calls in sequence
+            tool_call_count = 0
+            while True:
+                tool_calls = self.llm_client.extract_tool_calls(response)
+
+                if not tool_calls:
+                    break
+
+                tool_call_count += 1
+                self.logger.debug(f"Executing tool call sequence #{tool_call_count} with {len(tool_calls)} tools")
+
+                # Execute all tool calls and collect results
+                tool_results = []
+                for i, tool_call in enumerate(tool_calls):
+                    tool_name = tool_call.get("function", {}).get("name", "unknown")
+                    tool_call_id = tool_call.get("id", "unknown")
+                    self.logger.debug(f"=== TOOL EXECUTION #{i+1}/{len(tool_calls)} ===")
+                    self.logger.debug(f"Tool: {tool_name}")
+                    self.logger.debug(f"Tool Call ID: {tool_call_id}")
+                    self.logger.debug(f"Raw tool call: {tool_call}")
+
+                    result = self.tool_handler.execute_tool(tool_call)
+
+                    # Log tool execution result (success or error)
+                    if result.get("error", False):
+                        self.logger.warning(f"Tool {tool_name} failed: {result.get('user_message', result.get('output', 'Unknown error'))}")
+                    else:
+                        self.logger.debug(f"Tool {tool_name} succeeded")
+
+                    self.logger.debug(f"Tool result: {result}")
+                    tool_results.append(result)
+
+                # Add tool call sequence to conversation
+                self.conversation_manager.add_tool_call_sequence(
+                    tool_calls, tool_results
+                )
+                self.logger.debug("Added tool call sequence to conversation")
+
+                # Continue conversation with tool results
+                messages = self.conversation_manager.get_messages()
+                response = self.llm_client.chat_with_tools(
+                    messages=messages, tools=self.tool_handler.tools
+                )
+
+            # Calculate and log total thinking time
+            end_time = time.time()
+            thinking_time = end_time - start_time
+
+            # Add final assistant response to conversation with thinking time
+            final_content = self.llm_client.extract_content(response)
+            self.conversation_manager.add_message(MessageRole.ASSISTANT, final_content, thinking_time=thinking_time)
+
+            self.logger.info(f"Request completed successfully with {tool_call_count} tool call sequences in {thinking_time:.2f}s")
+
+            # Return final user-facing response and thinking time
+            return final_content, thinking_time
+
+        except Exception as e:
+            # Calculate and log thinking time even for failed requests
+            end_time = time.time()
+            thinking_time = end_time - start_time
+            self.logger.error(f"Error processing request after {thinking_time:.2f}s: {str(e)}")
+            return f"Error: {str(e)}", thinking_time
+
+    def get_conversation_summary(self) -> Dict[str, Any]:
+        """
+        Get conversation statistics and summary.
+
+        Returns:
+            Dictionary with conversation metrics
+        """
+        return self.conversation_manager.get_conversation_summary()
+
+    def clear_conversation(self) -> None:
+        """Clear conversation history."""
+        self.conversation_manager.clear_conversation()
+        self.logger.info("Conversation history cleared")
+
+    def get_conversation_manager(self) -> ConversationManager:
+        """
+        Get the conversation manager instance.
+
+        Returns:
+            Conversation manager instance
+        """
+        return self.conversation_manager
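
The loop in process_request re-invokes the model after every batch of tool executions and only returns once a response arrives with no tool calls. A minimal sketch of how a caller might wire this up; the no-argument Config() and ToolCallHandler() constructors are assumptions, since config.py and tools.py are not excerpted here:

from todo_agent.infrastructure.config import Config
from todo_agent.infrastructure.inference import Inference
from todo_agent.interface.tools import ToolCallHandler

config = Config()                 # assumed no-arg constructor (config.py not shown)
tool_handler = ToolCallHandler()  # assumed no-arg constructor (tools.py not shown)
engine = Inference(config, tool_handler)

# process_request returns (reply, thinking_time) per its signature above
reply, seconds = engine.process_request("add 'buy milk' to my grocery list")
print(f"[{seconds:.2f}s] {reply}")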

todo_agent/infrastructure/llm_client.py
@@ -0,0 +1,62 @@
+"""
+Abstract LLM client interface for todo.sh agent.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+
+class LLMClient(ABC):
+    """Abstract interface for LLM clients."""
+
+    @abstractmethod
+    def chat_with_tools(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """
+        Send chat message with function calling enabled.
+
+        Args:
+            messages: List of message dictionaries
+            tools: List of tool definitions
+
+        Returns:
+            API response dictionary
+        """
+        pass
+
+    @abstractmethod
+    def extract_tool_calls(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Extract tool calls from API response.
+
+        Args:
+            response: API response dictionary
+
+        Returns:
+            List of tool call dictionaries
+        """
+        pass
+
+    @abstractmethod
+    def extract_content(self, response: Dict[str, Any]) -> str:
+        """
+        Extract content from API response.
+
+        Args:
+            response: API response dictionary
+
+        Returns:
+            Extracted content string
+        """
+        pass
+
+    @abstractmethod
+    def get_model_name(self) -> str:
+        """
+        Get the model name being used by this client.
+
+        Returns:
+            Model name string
+        """
+        pass
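
Since LLMClient only fixes four method signatures, a stub is enough to exercise the orchestration layer without a live backend. A sketch of a hypothetical test double (not part of the package); it borrows the Ollama-style top-level "message" response shape purely for convenience:

from typing import Any, Dict, List

from todo_agent.infrastructure.llm_client import LLMClient


class EchoClient(LLMClient):
    """Hypothetical test double: never requests tools, echoes the last user message."""

    def chat_with_tools(
        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        last_user = next(
            (m["content"] for m in reversed(messages) if m.get("role") == "user"), ""
        )
        return {"message": {"role": "assistant", "content": f"echo: {last_user}", "tool_calls": []}}

    def extract_tool_calls(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
        return response["message"].get("tool_calls", [])

    def extract_content(self, response: Dict[str, Any]) -> str:
        return response["message"].get("content", "")

    def get_model_name(self) -> str:
        return "echo-stub"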

todo_agent/infrastructure/llm_client_factory.py
@@ -0,0 +1,48 @@
+"""
+Factory for creating LLM clients based on configuration.
+"""
+
+from typing import Optional
+
+try:
+    from todo_agent.infrastructure.config import Config
+    from todo_agent.infrastructure.llm_client import LLMClient
+    from todo_agent.infrastructure.openrouter_client import OpenRouterClient
+    from todo_agent.infrastructure.ollama_client import OllamaClient
+    from todo_agent.infrastructure.logger import Logger
+except ImportError:
+    from infrastructure.config import Config
+    from infrastructure.llm_client import LLMClient
+    from infrastructure.openrouter_client import OpenRouterClient
+    from infrastructure.ollama_client import OllamaClient
+    from infrastructure.logger import Logger
+
+
+class LLMClientFactory:
+    """Factory for creating LLM clients based on configuration."""
+
+    @staticmethod
+    def create_client(config: Config, logger: Optional[Logger] = None) -> LLMClient:
+        """
+        Create appropriate LLM client based on configuration.
+
+        Args:
+            config: Configuration object
+            logger: Optional logger instance
+
+        Returns:
+            LLM client instance
+
+        Raises:
+            ValueError: If provider is not supported
+        """
+        logger = logger or Logger("llm_client_factory")
+
+        if config.provider == "openrouter":
+            logger.info(f"Creating OpenRouter client with model: {config.openrouter_model}")
+            return OpenRouterClient(config)
+        elif config.provider == "ollama":
+            logger.info(f"Creating Ollama client with model: {config.ollama_model}")
+            return OllamaClient(config)
+        else:
+            raise ValueError(f"Unsupported LLM provider: {config.provider}")
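
Dispatch happens purely on the config.provider string, so swapping backends is a configuration change rather than a code change, and an unknown provider fails fast. A usage sketch, assuming provider is a plain assignable attribute on Config (the excerpt only shows it being read):

from todo_agent.infrastructure.config import Config
from todo_agent.infrastructure.llm_client_factory import LLMClientFactory

config = Config()               # assumed no-arg constructor (config.py not shown)
config.provider = "ollama"      # assumed to be an assignable attribute

client = LLMClientFactory.create_client(config)
print(client.get_model_name())  # e.g. the configured ollama_model

config.provider = "not-a-provider"
try:
    LLMClientFactory.create_client(config)
except ValueError as exc:
    print(exc)                  # Unsupported LLM provider: not-a-provider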

todo_agent/infrastructure/logger.py
@@ -0,0 +1,128 @@
+"""
+Logging infrastructure for todo.sh LLM agent.
+"""
+
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+
+class Logger:
+    """Custom logger that respects LOG_LEVEL environment variable and logs to screen and file."""
+
+    def __init__(self, name: str = "todo_agent"):
+        """
+        Initialize the logger.
+
+        Args:
+            name: Logger name, defaults to "todo_agent"
+        """
+        self.name = name
+        self.logger = logging.getLogger(name)
+        self.logger.setLevel(logging.DEBUG)
+
+        # Clear any existing handlers
+        self.logger.handlers.clear()
+
+        # Create logs directory if it doesn't exist
+        self._ensure_logs_directory()
+
+        # Set up file handler (always active)
+        self._setup_file_handler()
+
+        # Set up console handler with appropriate log level
+        self._setup_console_handler()
+
+    def _ensure_logs_directory(self):
+        """Ensure the logs directory exists in TODO_DIR."""
+        logs_dir = self._get_logs_directory()
+        logs_dir.mkdir(exist_ok=True)
+
+    def _get_logs_directory(self) -> Path:
+        """Get the logs directory path from TODO_DIR environment variable."""
+        todo_dir = os.getenv("TODO_DIR")
+        if todo_dir:
+            return Path(todo_dir) / "logs"
+        else:
+            # Fallback to local logs directory if TODO_DIR is not set
+            return Path("logs")
+
+    def _get_log_level(self) -> int:
+        """Get log level from LOG_LEVEL environment variable."""
+        log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
+
+        # Map string values to logging constants
+        level_map = {
+            "DEBUG": logging.DEBUG,
+            "INFO": logging.INFO,
+            "WARNING": logging.WARNING,
+            "ERROR": logging.ERROR,
+            "CRITICAL": logging.CRITICAL
+        }
+
+        return level_map.get(log_level_str, logging.INFO)
+
+    def _should_log_to_console(self) -> bool:
+        """Check if we should log to console based on DEBUG environment variable."""
+        return os.getenv("DEBUG") is not None
+
+    def _setup_file_handler(self):
+        """Set up file handler for logging to file."""
+        # Create log filename with timestamp
+        timestamp = datetime.now().strftime("%Y%m%d")
+        logs_dir = self._get_logs_directory()
+        log_file = logs_dir / f"todo_agent_{timestamp}.log"
+
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setLevel(logging.DEBUG)
+
+        # Create formatter for file logging
+        file_formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        )
+        file_handler.setFormatter(file_formatter)
+
+        self.logger.addHandler(file_handler)
+
+    def _setup_console_handler(self):
+        """Set up console handler for logging to screen with appropriate log level."""
+        # Only add console handler if DEBUG environment variable is set
+        if not self._should_log_to_console():
+            return
+
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(self._get_log_level())
+
+        # Create formatter for console logging (more concise)
+        console_formatter = logging.Formatter(
+            '%(levelname)s - %(message)s'
+        )
+        console_handler.setFormatter(console_formatter)
+
+        self.logger.addHandler(console_handler)
+
+    def debug(self, message: str):
+        """Log a debug message."""
+        self.logger.debug(message)
+
+    def info(self, message: str):
+        """Log an info message."""
+        self.logger.info(message)
+
+    def warning(self, message: str):
+        """Log a warning message."""
+        self.logger.warning(message)
+
+    def error(self, message: str):
+        """Log an error message."""
+        self.logger.error(message)
+
+    def critical(self, message: str):
+        """Log a critical message."""
+        self.logger.critical(message)
+
+    def exception(self, message: str):
+        """Log an exception message with traceback."""
+        self.logger.exception(message)
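
Two environment variables interact here: console output appears only if DEBUG is set to anything at all, and LOG_LEVEL then sets the console threshold, while the file handler always records at DEBUG into $TODO_DIR/logs/todo_agent_YYYYMMDD.log. A small sketch of the resulting behavior; note that _ensure_logs_directory assumes TODO_DIR itself already exists:

import os

os.makedirs("/tmp/todo", exist_ok=True)  # TODO_DIR must exist; only logs/ is created
os.environ["TODO_DIR"] = "/tmp/todo"
os.environ["DEBUG"] = "1"                # any value enables console output
os.environ["LOG_LEVEL"] = "WARNING"      # console threshold; the file still gets DEBUG

from todo_agent.infrastructure.logger import Logger

log = Logger("example")
log.debug("written to /tmp/todo/logs/todo_agent_YYYYMMDD.log only")
log.warning("written to the file and echoed to the console")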

todo_agent/infrastructure/ollama_client.py
@@ -0,0 +1,152 @@
+"""
+LLM client for Ollama API communication.
+"""
+
+import json
+import time
+from typing import Any, Dict, List
+
+import requests
+
+try:
+    from todo_agent.infrastructure.config import Config
+    from todo_agent.infrastructure.logger import Logger
+    from todo_agent.infrastructure.token_counter import get_token_counter
+    from todo_agent.infrastructure.llm_client import LLMClient
+except ImportError:
+    from infrastructure.config import Config
+    from infrastructure.logger import Logger
+    from infrastructure.token_counter import get_token_counter
+    from infrastructure.llm_client import LLMClient
+
+
+class OllamaClient(LLMClient):
+    """Ollama API client implementation."""
+
+    def __init__(self, config: Config):
+        """
+        Initialize Ollama client.
+
+        Args:
+            config: Configuration object
+        """
+        self.config = config
+        self.base_url = config.ollama_base_url
+        self.model = config.ollama_model
+        self.logger = Logger("ollama_client")
+        self.token_counter = get_token_counter(self.model)
+
+    def _estimate_tokens(self, text: str) -> int:
+        """
+        Estimate token count for text using accurate tokenization.
+
+        Args:
+            text: Text to count tokens for
+
+        Returns:
+            Number of tokens
+        """
+        return self.token_counter.count_tokens(text)
+
+    def _log_request_details(self, payload: Dict[str, Any], start_time: float):
+        """Log request details including accurate token count."""
+        # Count tokens for messages
+        messages = payload.get("messages", [])
+        tools = payload.get("tools", [])
+
+        total_tokens = self.token_counter.count_request_tokens(messages, tools)
+
+        self.logger.info(f"Request sent - Token count: {total_tokens}")
+        # self.logger.debug(f"Raw request payload: {json.dumps(payload, indent=2)}")
+
+    def _log_response_details(self, response: Dict[str, Any], start_time: float):
+        """Log response details including latency."""
+        end_time = time.time()
+        latency_ms = (end_time - start_time) * 1000
+
+        self.logger.info(f"Response received - Latency: {latency_ms:.2f}ms")
+
+        # Log tool call details if present
+        if "message" in response and "tool_calls" in response["message"]:
+            tool_calls = response["message"]["tool_calls"]
+            self.logger.info(f"Response contains {len(tool_calls)} tool calls")
+            for i, tool_call in enumerate(tool_calls):
+                tool_name = tool_call.get("function", {}).get("name", "unknown")
+                self.logger.info(f"  Tool call {i+1}: {tool_name}")
+        elif "message" in response and "content" in response["message"]:
+            content = response["message"]["content"]
+            self.logger.debug(f"Response contains content: {content[:100]}{'...' if len(content) > 100 else ''}")
+
+        self.logger.debug(f"Raw response: {json.dumps(response, indent=2)}")
+
+    def chat_with_tools(
+        self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """
+        Send chat message with function calling enabled.
+
+        Args:
+            messages: List of message dictionaries
+            tools: List of tool definitions
+
+        Returns:
+            API response dictionary
+        """
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        payload = {
+            "model": self.model,
+            "messages": messages,
+            "tools": tools,
+            "stream": False,
+        }
+
+        start_time = time.time()
+        self._log_request_details(payload, start_time)
+
+        response = requests.post(
+            f"{self.base_url}/api/chat", headers=headers, json=payload
+        )
+
+        if response.status_code != 200:
+            self.logger.error(f"Ollama API error: {response.text}")
+            raise Exception(f"Ollama API error: {response.text}")
+
+        response_data = response.json()
+        self._log_response_details(response_data, start_time)
+
+        return response_data
+
+    def extract_tool_calls(self, response: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Extract tool calls from API response."""
+        tool_calls = []
+
+        # Ollama response format is different from OpenRouter
+        if "message" in response and "tool_calls" in response["message"]:
+            tool_calls = response["message"]["tool_calls"]
+            self.logger.debug(f"Extracted {len(tool_calls)} tool calls from response")
+            for i, tool_call in enumerate(tool_calls):
+                tool_name = tool_call.get("function", {}).get("name", "unknown")
+                tool_call_id = tool_call.get("id", "unknown")
+                self.logger.debug(f"Tool call {i+1}: {tool_name} (ID: {tool_call_id})")
+        else:
+            self.logger.debug("No tool calls found in response")
+
+        return tool_calls
+
+    def extract_content(self, response: Dict[str, Any]) -> str:
+        """Extract content from API response."""
+        if "message" in response and "content" in response["message"]:
+            return response["message"]["content"]
+        return ""
+
+    def get_model_name(self) -> str:
+        """
+        Get the model name being used by this client.
+
+        Returns:
+            Model name string
+        """
+        return self.model
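
Unlike OpenAI-style APIs that nest replies under choices[0].message, Ollama's /api/chat returns a top-level message object, which is what both extractors read. A rough sketch of the non-streaming response shape this client is written against (illustrative values only; Ollama tool calls typically carry no "id" field, hence the "unknown" fallback in extract_tool_calls):

response = {
    "model": "llama3.1",                      # illustrative model name
    "message": {
        "role": "assistant",
        "content": "",                        # empty while tools are being requested
        "tool_calls": [
            {
                "function": {
                    "name": "add_task",
                    "arguments": {"description": "buy milk"},  # an object, not a JSON string
                }
            }
        ],
    },
    "done": True,
}

# extract_tool_calls(response) -> the tool_calls list above
# extract_content(response)    -> "" until the final turn carries prose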