sigma-terminal 3.4.1-py3-none-any.whl → 3.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sigma/__init__.py +4 -5
- sigma/analytics/__init__.py +11 -9
- sigma/app.py +384 -1194
- sigma/backtest/__init__.py +2 -0
- sigma/backtest/service.py +116 -0
- sigma/charts.py +2 -2
- sigma/cli.py +14 -12
- sigma/comparison.py +2 -2
- sigma/config.py +14 -4
- sigma/core/command_router.py +93 -0
- sigma/llm/__init__.py +3 -0
- sigma/llm/providers/anthropic_provider.py +196 -0
- sigma/llm/providers/base.py +29 -0
- sigma/llm/providers/google_provider.py +197 -0
- sigma/llm/providers/ollama_provider.py +156 -0
- sigma/llm/providers/openai_provider.py +168 -0
- sigma/llm/providers/sigma_cloud_provider.py +57 -0
- sigma/llm/rate_limit.py +40 -0
- sigma/llm/registry.py +66 -0
- sigma/llm/router.py +122 -0
- sigma/setup_agent.py +188 -0
- sigma/tools/__init__.py +23 -0
- sigma/tools/adapter.py +38 -0
- sigma/{tools.py → tools/library.py} +2 -1
- sigma/tools/registry.py +108 -0
- sigma/utils/extraction.py +83 -0
- sigma_terminal-3.5.0.dist-info/METADATA +184 -0
- sigma_terminal-3.5.0.dist-info/RECORD +46 -0
- sigma/llm.py +0 -786
- sigma/setup.py +0 -440
- sigma_terminal-3.4.1.dist-info/METADATA +0 -272
- sigma_terminal-3.4.1.dist-info/RECORD +0 -30
- /sigma/{backtest.py → backtest/simple_engine.py} +0 -0
- {sigma_terminal-3.4.1.dist-info → sigma_terminal-3.5.0.dist-info}/WHEEL +0 -0
- {sigma_terminal-3.4.1.dist-info → sigma_terminal-3.5.0.dist-info}/entry_points.txt +0 -0
- {sigma_terminal-3.4.1.dist-info → sigma_terminal-3.5.0.dist-info}/licenses/LICENSE +0 -0
sigma/llm.py
DELETED
@@ -1,786 +0,0 @@
"""LLM client implementations for all providers."""

import asyncio
import json
import re
import time
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, Optional

from sigma.config import LLMProvider, get_settings, ErrorCode, SigmaError


# Rate limiting configuration
class RateLimiter:
    """Simple rate limiter to prevent API flooding."""

    def __init__(self, requests_per_minute: int = 10, min_interval: float = 1.0):
        self.requests_per_minute = requests_per_minute
        self.min_interval = min_interval
        self.last_request_time = 0
        self.request_count = 0
        self.window_start = time.time()

    async def wait(self):
        """Wait if necessary to respect rate limits."""
        current_time = time.time()

        # Reset window if a minute has passed
        if current_time - self.window_start >= 60:
            self.window_start = current_time
            self.request_count = 0

        # Check if we've hit the rate limit
        if self.request_count >= self.requests_per_minute:
            wait_time = 60 - (current_time - self.window_start)
            if wait_time > 0:
                await asyncio.sleep(wait_time)
                self.window_start = time.time()
                self.request_count = 0

        # Ensure minimum interval between requests
        time_since_last = current_time - self.last_request_time
        if time_since_last < self.min_interval:
            await asyncio.sleep(self.min_interval - time_since_last)

        self.last_request_time = time.time()
        self.request_count += 1


# Global rate limiters per provider (generous limits to avoid rate limiting)
_rate_limiters = {
    "google": RateLimiter(requests_per_minute=60, min_interval=0.2),  # Gemini free tier is generous
    "openai": RateLimiter(requests_per_minute=60, min_interval=0.2),  # GPT-5 tier 1+
    "anthropic": RateLimiter(requests_per_minute=40, min_interval=0.3),  # Claude standard
    "groq": RateLimiter(requests_per_minute=30, min_interval=0.2),  # Groq free tier
    "xai": RateLimiter(requests_per_minute=30, min_interval=0.3),  # Grok standard
    "ollama": RateLimiter(requests_per_minute=120, min_interval=0.05),  # Local, no limits
}


class BaseLLM(ABC):
    """Base class for LLM clients."""

    provider_name: str = "base"

    @abstractmethod
    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a response."""
        pass

    async def _rate_limit(self):
        """Apply rate limiting."""
        limiter = _rate_limiters.get(self.provider_name)
        if limiter:
            await limiter.wait()


class GoogleLLM(BaseLLM):
    """Google Gemini client."""

    provider_name = "google"

    def __init__(self, api_key: str, model: str):
        from google import genai
        self.client = genai.Client(api_key=api_key)
        self.model_name = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        await self._rate_limit()
        from google.genai import types

        # Extract system prompt and build contents
        system_prompt = None
        contents = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                system_prompt = content
            elif role == "user":
                contents.append(types.Content(
                    role="user",
                    parts=[types.Part(text=content)]
                ))
            elif role == "assistant":
                contents.append(types.Content(
                    role="model",
                    parts=[types.Part(text=content)]
                ))

        # Build config
        config = types.GenerateContentConfig(
            system_instruction=system_prompt,
        )

        # Add tools if provided
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(types.FunctionDeclaration(
                        name=func["name"],
                        description=func.get("description", ""),
                        parameters=func.get("parameters", {}),
                    ))
            if function_declarations:
                config.tools = [types.Tool(function_declarations=function_declarations)]

        # Generate
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=config,
        )

        # Handle function calls
        if response.candidates:
            candidate = response.candidates[0]
            if candidate.content and candidate.content.parts:
                # Collect all function calls first
                function_calls = []
                for part in candidate.content.parts:
                    if hasattr(part, 'function_call') and part.function_call:
                        function_calls.append(part.function_call)

                # If there are function calls, process all of them
                if function_calls and on_tool_call:
                    # Add the model's response with function calls
                    contents.append(candidate.content)

                    # Execute all function calls and build responses
                    function_responses = []
                    for fc in function_calls:
                        args = dict(fc.args) if fc.args else {}
                        result = await on_tool_call(fc.name, args)
                        function_responses.append(types.Part(
                            function_response=types.FunctionResponse(
                                name=fc.name,
                                response={"result": str(result)}
                            )
                        ))

                    # Add all function responses in a single user message
                    contents.append(types.Content(
                        role="user",
                        parts=function_responses
                    ))

                    # Get final response
                    final_response = self.client.models.generate_content(
                        model=self.model_name,
                        contents=contents,
                        config=config,
                    )

                    # Check if there are more function calls in the response
                    if final_response.candidates:
                        final_candidate = final_response.candidates[0]
                        if final_candidate.content and final_candidate.content.parts:
                            for part in final_candidate.content.parts:
                                if hasattr(part, 'function_call') and part.function_call:
                                    # Recursive call to handle chained tool calls
                                    new_contents = contents + [final_candidate.content]
                                    fc = part.function_call
                                    args = dict(fc.args) if fc.args else {}
                                    result = await on_tool_call(fc.name, args)
                                    new_contents.append(types.Content(
                                        role="user",
                                        parts=[types.Part(
                                            function_response=types.FunctionResponse(
                                                name=fc.name,
                                                response={"result": str(result)}
                                            )
                                        )]
                                    ))
                                    final_final = self.client.models.generate_content(
                                        model=self.model_name,
                                        contents=new_contents,
                                        config=config,
                                    )
                                    return final_final.text or ""

                    return final_response.text or ""

        return response.text or ""


class OpenAILLM(BaseLLM):
    """OpenAI client."""

    provider_name = "openai"

    def __init__(self, api_key: str, model: str):
        from openai import AsyncOpenAI
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        await self._rate_limit()
        kwargs = {
            "model": self.model,
            "messages": messages,
        }

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

        response = await self.client.chat.completions.create(**kwargs)
        message = response.choices[0].message

        # Handle tool calls
        if message.tool_calls and on_tool_call:
            tool_results = []
            for tc in message.tool_calls:
                args = json.loads(tc.function.arguments)
                result = await on_tool_call(tc.function.name, args)
                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "content": json.dumps(result)
                })

            # Continue with tool results
            messages = messages + [message.model_dump()] + tool_results
            return await self.generate(messages, tools, on_tool_call)

        return message.content or ""


class AnthropicLLM(BaseLLM):
    """Anthropic Claude client."""

    provider_name = "anthropic"

    def __init__(self, api_key: str, model: str):
        from anthropic import AsyncAnthropic
        self.client = AsyncAnthropic(api_key=api_key)
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        await self._rate_limit()
        # Extract system message
        system = ""
        filtered_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system = msg["content"]
            else:
                filtered_messages.append(msg)

        kwargs = {
            "model": self.model,
            "max_tokens": 4096,
            "messages": filtered_messages,
        }

        if system:
            kwargs["system"] = system

        if tools:
            # Convert to Anthropic format
            kwargs["tools"] = [
                {
                    "name": t["function"]["name"],
                    "description": t["function"].get("description", ""),
                    "input_schema": t["function"].get("parameters", {})
                }
                for t in tools if t.get("type") == "function"
            ]

        response = await self.client.messages.create(**kwargs)

        # Handle tool use
        result_text = ""
        for block in response.content:
            if block.type == "text":
                result_text += block.text
            elif block.type == "tool_use" and on_tool_call:
                result = await on_tool_call(block.name, block.input)
                # Continue conversation
                filtered_messages.append({
                    "role": "assistant",
                    "content": response.content
                })
                filtered_messages.append({
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": block.id,
                        "content": json.dumps(result)
                    }]
                })
                return await self.generate(
                    [{"role": "system", "content": system}] + filtered_messages,
                    tools, on_tool_call
                )

        return result_text


class GroqLLM(BaseLLM):
    """Groq client."""

    provider_name = "groq"

    def __init__(self, api_key: str, model: str):
        from groq import AsyncGroq
        self.client = AsyncGroq(api_key=api_key)
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        await self._rate_limit()
        kwargs = {
            "model": self.model,
            "messages": messages,
        }

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

        response = await self.client.chat.completions.create(**kwargs)
        message = response.choices[0].message

        # Handle tool calls (similar to OpenAI)
        if message.tool_calls and on_tool_call:
            tool_results = []
            for tc in message.tool_calls:
                args = json.loads(tc.function.arguments)
                result = await on_tool_call(tc.function.name, args)
                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "content": json.dumps(result)
                })

            messages = messages + [{"role": "assistant", "tool_calls": message.tool_calls}] + tool_results
            return await self.generate(messages, tools, on_tool_call)

        return message.content or ""


class OllamaLLM(BaseLLM):
    """Ollama local client with native tool call support."""

    provider_name = "ollama"

    def __init__(self, host: str, model: str):
        self.host = host.rstrip("/")
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        await self._rate_limit()
        import aiohttp

        # Convert tools to Ollama format
        ollama_tools = None
        if tools:
            ollama_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    f = tool["function"]
                    ollama_tools.append({
                        "type": "function",
                        "function": {
                            "name": f["name"],
                            "description": f.get("description", ""),
                            "parameters": f.get("parameters", {})
                        }
                    })

        request_body = {
            "model": self.model,
            "messages": messages,
            "stream": False
        }

        if ollama_tools:
            request_body["tools"] = ollama_tools

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    f"{self.host}/api/chat",
                    json=request_body,
                    timeout=aiohttp.ClientTimeout(total=120)
                ) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        return f"Ollama error: {error_text}"

                    data = await resp.json()
                    message = data.get("message", {})

                    # Check for tool calls in response
                    tool_calls = message.get("tool_calls", [])

                    if tool_calls and on_tool_call:
                        # Process tool calls
                        updated_messages = messages.copy()
                        updated_messages.append(message)

                        for tc in tool_calls:
                            func = tc.get("function", {})
                            tool_name = func.get("name", "")
                            tool_args = func.get("arguments", {})

                            # If arguments is a string, parse it
                            if isinstance(tool_args, str):
                                try:
                                    tool_args = json.loads(tool_args)
                                except json.JSONDecodeError:
                                    tool_args = {}

                            # Execute the tool
                            result = await on_tool_call(tool_name, tool_args)

                            # Add tool result to messages
                            updated_messages.append({
                                "role": "tool",
                                "content": json.dumps(result) if not isinstance(result, str) else result
                            })

                        # Get final response with tool results
                        return await self._continue_with_tools(session, updated_messages, ollama_tools, on_tool_call)

                    # Check for text-based tool calls (fallback for older models)
                    content = message.get("content", "")
                    if "TOOL_CALL:" in content and on_tool_call:
                        result = await self._parse_text_tool_call(content, on_tool_call)
                        if result:
                            return result

                    return content

            except aiohttp.ClientError as e:
                return f"Connection error: {e}. Is Ollama running?"
            except asyncio.TimeoutError:
                return "Request timed out. Try a simpler query or check Ollama status."

    async def _continue_with_tools(
        self,
        session,
        messages: list[dict],
        tools: Optional[list[dict]],
        on_tool_call: Optional[Callable],
        depth: int = 0
    ) -> str:
        """Continue conversation after tool calls."""
        import aiohttp

        if depth > 5:  # Prevent infinite loops
            return "Maximum tool call depth reached."

        request_body = {
            "model": self.model,
            "messages": messages,
            "stream": False
        }
        if tools:
            request_body["tools"] = tools

        async with session.post(
            f"{self.host}/api/chat",
            json=request_body,
            timeout=aiohttp.ClientTimeout(total=120)
        ) as resp:
            data = await resp.json()
            message = data.get("message", {})

            # Check for more tool calls
            tool_calls = message.get("tool_calls", [])
            if tool_calls and on_tool_call:
                updated_messages = messages.copy()
                updated_messages.append(message)

                for tc in tool_calls:
                    func = tc.get("function", {})
                    tool_name = func.get("name", "")
                    tool_args = func.get("arguments", {})

                    if isinstance(tool_args, str):
                        try:
                            tool_args = json.loads(tool_args)
                        except json.JSONDecodeError:
                            tool_args = {}

                    result = await on_tool_call(tool_name, tool_args)
                    updated_messages.append({
                        "role": "tool",
                        "content": json.dumps(result) if not isinstance(result, str) else result
                    })

                return await self._continue_with_tools(session, updated_messages, tools, on_tool_call, depth + 1)

            return message.get("content", "")

    async def _parse_text_tool_call(self, content: str, on_tool_call: Callable) -> Optional[str]:
        """Parse text-based tool calls for older models."""
        # Pattern: TOOL_CALL: tool_name({"arg": "value"}) or TOOL_CALL: tool_name(arg=value)
        pattern = r'TOOL_CALL:\s*(\w+)\s*\(([^)]*)\)'
        match = re.search(pattern, content)

        if not match:
            return None

        tool_name = match.group(1)
        args_str = match.group(2).strip()

        # Try to parse arguments
        try:
            if args_str.startswith("{"):
                args = json.loads(args_str)
            else:
                # Parse key=value format
                args = {}
                for part in args_str.split(","):
                    if "=" in part:
                        k, v = part.split("=", 1)
                        args[k.strip()] = v.strip().strip('"\'')
        except:
            args = {"symbol": args_str} if args_str else {}

        # Execute tool
        result = await on_tool_call(tool_name, args)

        # Format result for response
        if isinstance(result, dict):
            result_str = json.dumps(result, indent=2)
        else:
            result_str = str(result)

        # Return combined response
        return f"Tool result:\n```json\n{result_str}\n```"

    def _format_tools_for_prompt(self, tools: list[dict]) -> str:
        """Format tools as text for prompt injection (legacy fallback)."""
        lines = ["You have access to these tools:"]
        for tool in tools:
            if tool.get("type") == "function":
                f = tool["function"]
                params = f.get("parameters", {}).get("properties", {})
                param_str = ", ".join(params.keys()) if params else ""
                lines.append(f"- {f['name']}({param_str}): {f.get('description', '')}")
        lines.append("\nTo use a tool, respond with: TOOL_CALL: tool_name(args)")
        lines.append("Example: TOOL_CALL: get_stock_quote(symbol=\"AAPL\")")
        return "\n".join(lines)


class XaiLLM(BaseLLM):
    """xAI Grok client (uses OpenAI-compatible API)."""

    provider_name = "xai"

    def __init__(self, api_key: str, model: str):
        from openai import AsyncOpenAI
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url="https://api.x.ai/v1"
        )
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        await self._rate_limit()

        try:
            kwargs = {
                "model": self.model,
                "messages": messages,
            }

            if tools:
                kwargs["tools"] = tools
                kwargs["tool_choice"] = "auto"

            response = await self.client.chat.completions.create(**kwargs)
            message = response.choices[0].message

            # Handle tool calls
            if message.tool_calls and on_tool_call:
                tool_results = []
                for tc in message.tool_calls:
                    args = json.loads(tc.function.arguments)
                    result = await on_tool_call(tc.function.name, args)
                    tool_results.append({
                        "tool_call_id": tc.id,
                        "role": "tool",
                        "content": json.dumps(result)
                    })

                # Continue with tool results
                messages = messages + [message.model_dump()] + tool_results
                return await self.generate(messages, tools, on_tool_call)

            return message.content or ""

        except Exception as e:
            error_str = str(e)
            if "401" in error_str or "invalid_api_key" in error_str:
                raise SigmaError(
                    ErrorCode.API_KEY_INVALID,
                    "xAI API key is invalid",
                    {"provider": "xai"}
                )
            elif "429" in error_str or "rate_limit" in error_str:
                raise SigmaError(
                    ErrorCode.API_KEY_RATE_LIMITED,
                    "xAI rate limit exceeded. Please wait.",
                    {"provider": "xai"}
                )
            raise


def _parse_api_error(error: Exception, provider: str) -> SigmaError:
    """Parse API errors into SigmaError with proper codes."""
    error_str = str(error).lower()

    if "401" in error_str or "invalid_api_key" in error_str or "api key not valid" in error_str:
        return SigmaError(
            ErrorCode.API_KEY_INVALID,
            f"{provider.title()} API key is invalid. Check your key at /keys",
            {"provider": provider, "original_error": str(error)[:200]}
        )
    elif "403" in error_str or "forbidden" in error_str:
        return SigmaError(
            ErrorCode.API_KEY_INVALID,
            f"{provider.title()} API key doesn't have access to this model",
            {"provider": provider}
        )
    elif "429" in error_str or "rate_limit" in error_str or "quota" in error_str:
        return SigmaError(
            ErrorCode.API_KEY_RATE_LIMITED,
            f"{provider.title()} rate limit exceeded. Wait a moment.",
            {"provider": provider}
        )
    elif "404" in error_str or "not found" in error_str or "does not exist" in error_str:
        return SigmaError(
            ErrorCode.MODEL_NOT_FOUND,
            f"Model not found. Try /models to see available options.",
            {"provider": provider}
        )
    elif "timeout" in error_str:
        return SigmaError(
            ErrorCode.TIMEOUT,
            "Request timed out. Try a simpler query.",
            {"provider": provider}
        )
    elif "connection" in error_str:
        return SigmaError(
            ErrorCode.CONNECTION_ERROR,
            f"Cannot connect to {provider.title()}. Check internet.",
            {"provider": provider}
        )
    else:
        return SigmaError(
            ErrorCode.PROVIDER_ERROR,
            f"{provider.title()} error: {str(error)[:150]}",
            {"provider": provider, "original_error": str(error)[:300]}
        )


def get_llm(provider: LLMProvider, model: Optional[str] = None) -> BaseLLM:
    """Get LLM client for a provider."""
    settings = get_settings()

    if model is None:
        model = settings.get_model(provider)

    if provider == LLMProvider.GOOGLE:
        api_key = settings.google_api_key
        if not api_key:
            raise SigmaError(
                ErrorCode.API_KEY_MISSING,
                "Google API key not configured. Use /keys to set up.",
                {"provider": "google", "hint": "Get key at: https://aistudio.google.com/apikey"}
            )
        return GoogleLLM(api_key, model)

    elif provider == LLMProvider.OPENAI:
        api_key = settings.openai_api_key
        if not api_key:
            raise SigmaError(
                ErrorCode.API_KEY_MISSING,
                "OpenAI API key not configured. Use /keys to set up.",
                {"provider": "openai", "hint": "Get key at: https://platform.openai.com/api-keys"}
            )
        return OpenAILLM(api_key, model)

    elif provider == LLMProvider.ANTHROPIC:
        api_key = settings.anthropic_api_key
        if not api_key:
            raise SigmaError(
                ErrorCode.API_KEY_MISSING,
                "Anthropic API key not configured. Use /keys to set up.",
                {"provider": "anthropic", "hint": "Get key at: https://console.anthropic.com/settings/keys"}
            )
        return AnthropicLLM(api_key, model)

    elif provider == LLMProvider.GROQ:
        api_key = settings.groq_api_key
        if not api_key:
            raise SigmaError(
                ErrorCode.API_KEY_MISSING,
                "Groq API key not configured. Use /keys to set up.",
                {"provider": "groq", "hint": "Get key at: https://console.groq.com/keys"}
            )
        return GroqLLM(api_key, model)

    elif provider == LLMProvider.XAI:
        api_key = settings.xai_api_key
        if not api_key:
            raise SigmaError(
                ErrorCode.API_KEY_MISSING,
                "xAI API key not configured. Use /keys to set up.",
                {"provider": "xai", "hint": "Get key at: https://console.x.ai"}
            )
        return XaiLLM(api_key, model)

    elif provider == LLMProvider.OLLAMA:
        return OllamaLLM(settings.ollama_host, model)

    else:
        raise SigmaError(
            ErrorCode.PROVIDER_UNAVAILABLE,
            f"Unsupported provider: {provider}",
            {"provider": str(provider)}
        )
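
For reference, below is a minimal sketch of how the removed 3.4.1 entry point was typically driven, inferred only from the deleted get_llm / BaseLLM.generate signatures shown above; the tool schema and callback are illustrative placeholders, not part of the package, and in 3.5.0 this surface is replaced by the new sigma/llm/ package (registry, router, rate_limit, providers).

# Hypothetical usage sketch of the removed sigma/llm.py (3.4.1 API).
# Based solely on the deleted code above; names marked "example_*" are invented.
import asyncio

from sigma.config import LLMProvider
from sigma.llm import get_llm  # module removed in 3.5.0

# Example tool schema in the OpenAI function-calling format the old clients expected.
EXAMPLE_TOOLS = [{
    "type": "function",
    "function": {
        "name": "get_stock_quote",
        "description": "Fetch the latest quote for a ticker symbol.",
        "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}},
    },
}]

async def example_on_tool_call(name: str, args: dict):
    # The old clients awaited this callback for each tool call and fed the
    # return value back to the model as a tool result.
    if name == "get_stock_quote":
        return {"symbol": args.get("symbol"), "price": 123.45}  # placeholder data
    return {"error": f"unknown tool {name}"}

async def main():
    # Raises SigmaError(API_KEY_MISSING) if the provider's key is not configured;
    # model defaults to settings.get_model(provider) when omitted.
    llm = get_llm(LLMProvider.OPENAI)
    reply = await llm.generate(
        [
            {"role": "system", "content": "You are a terse financial assistant."},
            {"role": "user", "content": "Quote AAPL."},
        ],
        tools=EXAMPLE_TOOLS,
        on_tool_call=example_on_tool_call,
    )
    print(reply)

asyncio.run(main())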