sigma-terminal 2.0.1-py3-none-any.whl → 3.2.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sigma/__init__.py +182 -6
- sigma/__main__.py +2 -2
- sigma/analytics/__init__.py +636 -0
- sigma/app.py +563 -898
- sigma/backtest.py +372 -0
- sigma/charts.py +407 -0
- sigma/cli.py +434 -0
- sigma/comparison.py +611 -0
- sigma/config.py +195 -0
- sigma/core/__init__.py +4 -17
- sigma/core/engine.py +493 -0
- sigma/core/intent.py +595 -0
- sigma/core/models.py +516 -125
- sigma/data/__init__.py +681 -0
- sigma/data/models.py +130 -0
- sigma/llm.py +401 -0
- sigma/monitoring.py +666 -0
- sigma/portfolio.py +697 -0
- sigma/reporting.py +658 -0
- sigma/robustness.py +675 -0
- sigma/setup.py +305 -402
- sigma/strategy.py +753 -0
- sigma/tools/backtest.py +23 -5
- sigma/tools.py +617 -0
- sigma/visualization.py +766 -0
- sigma_terminal-3.2.0.dist-info/METADATA +298 -0
- sigma_terminal-3.2.0.dist-info/RECORD +30 -0
- sigma_terminal-3.2.0.dist-info/entry_points.txt +6 -0
- sigma_terminal-3.2.0.dist-info/licenses/LICENSE +25 -0
- sigma/core/agent.py +0 -205
- sigma/core/config.py +0 -119
- sigma/core/llm.py +0 -794
- sigma/tools/__init__.py +0 -5
- sigma/tools/charts.py +0 -400
- sigma/tools/financial.py +0 -1457
- sigma/ui/__init__.py +0 -1
- sigma_terminal-2.0.1.dist-info/METADATA +0 -222
- sigma_terminal-2.0.1.dist-info/RECORD +0 -19
- sigma_terminal-2.0.1.dist-info/entry_points.txt +0 -2
- sigma_terminal-2.0.1.dist-info/licenses/LICENSE +0 -42
- {sigma_terminal-2.0.1.dist-info → sigma_terminal-3.2.0.dist-info}/WHEEL +0 -0
sigma/core/llm.py
DELETED
@@ -1,794 +0,0 @@
"""LLM provider implementations."""

import json
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Optional

import httpx

from sigma.core.config import LLMProvider, get_settings
from sigma.core.models import Message, MessageRole, ToolCall


class BaseLLM(ABC):
    """Base LLM provider."""

    def __init__(self, model: Optional[str] = None):
        self.settings = get_settings()
        self.model = model or self.settings.get_model(self.provider)

    @property
    @abstractmethod
    def provider(self) -> LLMProvider:
        """Provider type."""
        pass

    @abstractmethod
    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        """Generate response."""
        pass

    @abstractmethod
    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        """Stream response."""
        pass


class OpenAILLM(BaseLLM):
    """OpenAI provider."""

    @property
    def provider(self) -> LLMProvider:
        return LLMProvider.OPENAI

    def _convert_messages(self, messages: list[Message]) -> list[dict]:
        """Convert to OpenAI format."""
        result = []
        for msg in messages:
            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
            if msg.tool_calls:
                m["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}
                    }
                    for tc in msg.tool_calls
                ]
            if msg.tool_call_id:
                m["tool_call_id"] = msg.tool_call_id
            if msg.name:
                m["name"] = msg.name
            result.append(m)
        return result

    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        api_key = self.settings.get_api_key(LLMProvider.OPENAI)
        if not api_key:
            raise ValueError("OpenAI API key not set")

        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": temperature or self.settings.temperature,
            "max_tokens": self.settings.max_tokens,
        }

        if tools:
            data["tools"] = [{"type": "function", "function": t} for t in tools]
            data["tool_choice"] = "auto"

        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=data,
            )
            resp.raise_for_status()
            result = resp.json()

        choice = result["choices"][0]
        msg = choice["message"]
        content = msg.get("content", "") or ""

        tool_calls = []
        if "tool_calls" in msg and msg["tool_calls"]:
            for tc in msg["tool_calls"]:
                tool_calls.append(ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                ))

        return content, tool_calls

    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        api_key = self.settings.get_api_key(LLMProvider.OPENAI)
        if not api_key:
            raise ValueError("OpenAI API key not set")

        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": temperature or self.settings.temperature,
            "max_tokens": self.settings.max_tokens,
            "stream": True,
        }

        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST",
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=data,
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        payload = line[6:]
                        if payload == "[DONE]":
                            break
                        chunk = json.loads(payload)
                        delta = chunk["choices"][0].get("delta", {})
                        if "content" in delta and delta["content"]:
                            yield delta["content"]


class AnthropicLLM(BaseLLM):
    """Anthropic provider."""

    @property
    def provider(self) -> LLMProvider:
        return LLMProvider.ANTHROPIC

    def _convert_messages(self, messages: list[Message]) -> tuple[Optional[str], list[dict]]:
        """Convert to Anthropic format."""
        system = None
        result = []

        for msg in messages:
            if msg.role == MessageRole.SYSTEM:
                system = msg.content
                continue

            if msg.role == MessageRole.TOOL:
                result.append({
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": msg.tool_call_id,
                        "content": msg.content,
                    }]
                })
            elif msg.tool_calls:
                content: list[dict[str, Any]] = []
                if msg.content:
                    content.append({"type": "text", "text": msg.content})
                for tc in msg.tool_calls:
                    content.append({
                        "type": "tool_use",
                        "id": tc.id,
                        "name": tc.name,
                        "input": tc.arguments,
                    })
                result.append({"role": "assistant", "content": content})
            else:
                result.append({"role": msg.role.value, "content": msg.content})

        return system, result

    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        api_key = self.settings.get_api_key(LLMProvider.ANTHROPIC)
        if not api_key:
            raise ValueError("Anthropic API key not set")

        headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }

        system, msgs = self._convert_messages(messages)
        data: dict[str, Any] = {
            "model": self.model,
            "messages": msgs,
            "max_tokens": self.settings.max_tokens,
            "temperature": temperature or self.settings.temperature,
        }

        if system:
            data["system"] = system

        if tools:
            data["tools"] = [
                {
                    "name": t["name"],
                    "description": t.get("description", ""),
                    "input_schema": t.get("parameters", {}),
                }
                for t in tools
            ]

        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                "https://api.anthropic.com/v1/messages",
                headers=headers,
                json=data,
            )
            resp.raise_for_status()
            result = resp.json()

        content = ""
        tool_calls = []

        for block in result.get("content", []):
            if block["type"] == "text":
                content += block["text"]
            elif block["type"] == "tool_use":
                tool_calls.append(ToolCall(
                    id=block["id"],
                    name=block["name"],
                    arguments=block["input"],
                ))

        return content, tool_calls

    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        api_key = self.settings.get_api_key(LLMProvider.ANTHROPIC)
        if not api_key:
            raise ValueError("Anthropic API key not set")

        headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }

        system, msgs = self._convert_messages(messages)
        data: dict[str, Any] = {
            "model": self.model,
            "messages": msgs,
            "max_tokens": self.settings.max_tokens,
            "temperature": temperature or self.settings.temperature,
            "stream": True,
        }

        if system:
            data["system"] = system

        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST",
                "https://api.anthropic.com/v1/messages",
                headers=headers,
                json=data,
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        event = json.loads(line[6:])
                        if event["type"] == "content_block_delta":
                            delta = event.get("delta", {})
                            if delta.get("type") == "text_delta":
                                yield delta.get("text", "")


class GoogleLLM(BaseLLM):
    """Google Gemini provider using REST API."""

    @property
    def provider(self) -> LLMProvider:
        return LLMProvider.GOOGLE

    def _convert_messages(self, messages: list[Message]) -> tuple[Optional[str], list[dict]]:
        """Convert to Gemini format."""
        system = None
        contents = []

        for msg in messages:
            if msg.role == MessageRole.SYSTEM:
                system = msg.content
                continue

            role = "model" if msg.role == MessageRole.ASSISTANT else "user"

            if msg.role == MessageRole.TOOL:
                contents.append({
                    "role": "user",
                    "parts": [{
                        "functionResponse": {
                            "name": msg.name or "tool",
                            "response": {"result": msg.content}
                        }
                    }]
                })
            elif msg.tool_calls:
                parts: list[dict[str, Any]] = []
                if msg.content:
                    parts.append({"text": msg.content})
                for tc in msg.tool_calls:
                    parts.append({
                        "functionCall": {
                            "name": tc.name,
                            "args": tc.arguments
                        }
                    })
                contents.append({"role": role, "parts": parts})
            else:
                contents.append({
                    "role": role,
                    "parts": [{"text": msg.content}]
                })

        return system, contents

    def _convert_tools(self, tools: list[dict]) -> list[dict]:
        """Convert tools to Gemini format."""
        declarations = []
        for t in tools:
            decl: dict[str, Any] = {
                "name": t["name"],
                "description": t.get("description", ""),
            }
            if "parameters" in t and t["parameters"]:
                params = t["parameters"].copy()
                # Gemini doesn't want 'additionalProperties'
                params.pop("additionalProperties", None)
                decl["parameters"] = params
            declarations.append(decl)
        return [{"functionDeclarations": declarations}]

    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        api_key = self.settings.get_api_key(LLMProvider.GOOGLE)
        if not api_key:
            raise ValueError("Google API key not set")

        system, contents = self._convert_messages(messages)

        data: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature or self.settings.temperature,
                "maxOutputTokens": self.settings.max_tokens,
            }
        }

        if system:
            data["systemInstruction"] = {"parts": [{"text": system}]}

        if tools:
            data["tools"] = self._convert_tools(tools)

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={api_key}"

        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(url, json=data)
            resp.raise_for_status()
            result = resp.json()

        content = ""
        tool_calls = []

        candidates = result.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for i, part in enumerate(parts):
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    fc = part["functionCall"]
                    tool_calls.append(ToolCall(
                        id=f"call_{i}",
                        name=fc["name"],
                        arguments=fc.get("args", {}),
                    ))

        return content, tool_calls

    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        api_key = self.settings.get_api_key(LLMProvider.GOOGLE)
        if not api_key:
            raise ValueError("Google API key not set")

        system, contents = self._convert_messages(messages)

        data: dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature or self.settings.temperature,
                "maxOutputTokens": self.settings.max_tokens,
            }
        }

        if system:
            data["systemInstruction"] = {"parts": [{"text": system}]}

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:streamGenerateContent?key={api_key}&alt=sse"

        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream("POST", url, json=data) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        chunk = json.loads(line[6:])
                        candidates = chunk.get("candidates", [])
                        if candidates:
                            parts = candidates[0].get("content", {}).get("parts", [])
                            for part in parts:
                                if "text" in part:
                                    yield part["text"]


class OllamaLLM(BaseLLM):
    """Ollama provider."""

    @property
    def provider(self) -> LLMProvider:
        return LLMProvider.OLLAMA

    def _convert_messages(self, messages: list[Message]) -> list[dict]:
        """Convert to Ollama format."""
        result = []
        for msg in messages:
            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
            if msg.tool_calls:
                m["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.name, "arguments": tc.arguments}
                    }
                    for tc in msg.tool_calls
                ]
            result.append(m)
        return result

    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "stream": False,
            "options": {
                "temperature": temperature or self.settings.temperature,
                "num_predict": self.settings.max_tokens,
            }
        }

        if tools:
            data["tools"] = [{"type": "function", "function": t} for t in tools]

        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(
                f"{self.settings.ollama_base_url}/api/chat",
                json=data,
            )
            resp.raise_for_status()
            result = resp.json()

        msg = result.get("message", {})
        content = msg.get("content", "") or ""

        tool_calls = []
        if "tool_calls" in msg and msg["tool_calls"]:
            for i, tc in enumerate(msg["tool_calls"]):
                fn = tc.get("function", {})
                tool_calls.append(ToolCall(
                    id=f"call_{i}",
                    name=fn.get("name", ""),
                    arguments=fn.get("arguments", {}),
                ))

        return content, tool_calls

    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "stream": True,
            "options": {
                "temperature": temperature or self.settings.temperature,
                "num_predict": self.settings.max_tokens,
            }
        }

        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST",
                f"{self.settings.ollama_base_url}/api/chat",
                json=data,
            ) as resp:
                async for line in resp.aiter_lines():
                    if line:
                        chunk = json.loads(line)
                        msg = chunk.get("message", {})
                        if "content" in msg:
                            yield msg["content"]


class GroqLLM(BaseLLM):
    """Groq provider."""

    @property
    def provider(self) -> LLMProvider:
        return LLMProvider.GROQ

    def _convert_messages(self, messages: list[Message]) -> list[dict]:
        """Convert to Groq format (OpenAI compatible)."""
        result = []
        for msg in messages:
            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
            if msg.tool_calls:
                m["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}
                    }
                    for tc in msg.tool_calls
                ]
            if msg.tool_call_id:
                m["tool_call_id"] = msg.tool_call_id
            if msg.name:
                m["name"] = msg.name
            result.append(m)
        return result

    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        api_key = self.settings.get_api_key(LLMProvider.GROQ)
        if not api_key:
            raise ValueError("Groq API key not set")

        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": temperature or self.settings.temperature,
            "max_tokens": self.settings.max_tokens,
        }

        if tools:
            data["tools"] = [{"type": "function", "function": t} for t in tools]
            data["tool_choice"] = "auto"

        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers=headers,
                json=data,
            )
            resp.raise_for_status()
            result = resp.json()

        choice = result["choices"][0]
        msg = choice["message"]
        content = msg.get("content", "") or ""

        tool_calls = []
        if "tool_calls" in msg and msg["tool_calls"]:
            for tc in msg["tool_calls"]:
                tool_calls.append(ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                ))

        return content, tool_calls

    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        api_key = self.settings.get_api_key(LLMProvider.GROQ)
        if not api_key:
            raise ValueError("Groq API key not set")

        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": temperature or self.settings.temperature,
            "max_tokens": self.settings.max_tokens,
            "stream": True,
        }

        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST",
                "https://api.groq.com/openai/v1/chat/completions",
                headers=headers,
                json=data,
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        payload = line[6:]
                        if payload == "[DONE]":
                            break
                        chunk = json.loads(payload)
                        delta = chunk["choices"][0].get("delta", {})
                        if "content" in delta and delta["content"]:
                            yield delta["content"]


class XaiLLM(BaseLLM):
    """xAI Grok provider."""

    @property
    def provider(self) -> LLMProvider:
        return LLMProvider.XAI

    def _convert_messages(self, messages: list[Message]) -> list[dict]:
        """Convert to xAI format (OpenAI compatible)."""
        result = []
        for msg in messages:
            m: dict[str, Any] = {"role": msg.role.value, "content": msg.content}
            if msg.tool_calls:
                m["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}
                    }
                    for tc in msg.tool_calls
                ]
            if msg.tool_call_id:
                m["tool_call_id"] = msg.tool_call_id
            result.append(m)
        return result

    async def generate(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> tuple[str, list[ToolCall]]:
        api_key = self.settings.get_api_key(LLMProvider.XAI)
        if not api_key:
            raise ValueError("xAI API key not set")

        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": temperature or self.settings.temperature,
            "max_tokens": self.settings.max_tokens,
        }

        if tools:
            data["tools"] = [{"type": "function", "function": t} for t in tools]

        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                "https://api.x.ai/v1/chat/completions",
                headers=headers,
                json=data,
            )
            resp.raise_for_status()
            result = resp.json()

        choice = result["choices"][0]
        msg = choice["message"]
        content = msg.get("content", "") or ""

        tool_calls = []
        if "tool_calls" in msg and msg["tool_calls"]:
            for tc in msg["tool_calls"]:
                tool_calls.append(ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                ))

        return content, tool_calls

    async def stream(
        self,
        messages: list[Message],
        tools: Optional[list[dict]] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        api_key = self.settings.get_api_key(LLMProvider.XAI)
        if not api_key:
            raise ValueError("xAI API key not set")

        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        data: dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": temperature or self.settings.temperature,
            "max_tokens": self.settings.max_tokens,
            "stream": True,
        }

        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST",
                "https://api.x.ai/v1/chat/completions",
                headers=headers,
                json=data,
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        payload = line[6:]
                        if payload == "[DONE]":
                            break
                        chunk = json.loads(payload)
                        delta = chunk["choices"][0].get("delta", {})
                        if "content" in delta and delta["content"]:
                            yield delta["content"]


def get_llm(provider: Optional[LLMProvider] = None, model: Optional[str] = None) -> BaseLLM:
    """Get LLM instance."""
    settings = get_settings()
    provider = provider or settings.default_provider

    providers = {
        LLMProvider.OPENAI: OpenAILLM,
        LLMProvider.ANTHROPIC: AnthropicLLM,
        LLMProvider.GOOGLE: GoogleLLM,
        LLMProvider.OLLAMA: OllamaLLM,
        LLMProvider.GROQ: GroqLLM,
        LLMProvider.XAI: XaiLLM,
    }

    cls = providers.get(provider)
    if not cls:
        raise ValueError(f"Unknown provider: {provider}")

    return cls(model=model)
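For orientation: the removed module's public surface was the `get_llm` factory, which returned a provider-specific `BaseLLM` exposing async `generate` (full response plus tool calls) and `stream` (text deltas only). A minimal usage sketch follows; it assumes `Message` accepts `role` and `content` keyword arguments and that `MessageRole` defines a `USER` member, which the conversion code above implies but the diff does not show. In 3.2.0 this role is taken over by the new top-level `sigma/llm.py`.

    import asyncio

    from sigma.core.llm import get_llm  # module removed in 3.2.0
    from sigma.core.models import Message, MessageRole


    async def main() -> None:
        # With no arguments, get_llm falls back to settings.default_provider.
        llm = get_llm()

        # One-shot generation: returns the text plus any requested tool calls.
        content, tool_calls = await llm.generate(
            [Message(role=MessageRole.USER, content="Hello")]  # MessageRole.USER assumed
        )
        print(content, tool_calls)

        # Streaming yields text chunks only; tool calls are not surfaced here.
        async for chunk in llm.stream(
            [Message(role=MessageRole.USER, content="Hello")]
        ):
            print(chunk, end="")


    asyncio.run(main())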