sigma-terminal 2.0.1__py3-none-any.whl → 3.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sigma/__init__.py +182 -6
- sigma/__main__.py +2 -2
- sigma/analytics/__init__.py +636 -0
- sigma/app.py +563 -898
- sigma/backtest.py +372 -0
- sigma/charts.py +407 -0
- sigma/cli.py +434 -0
- sigma/comparison.py +611 -0
- sigma/config.py +195 -0
- sigma/core/__init__.py +4 -17
- sigma/core/engine.py +493 -0
- sigma/core/intent.py +595 -0
- sigma/core/models.py +516 -125
- sigma/data/__init__.py +681 -0
- sigma/data/models.py +130 -0
- sigma/llm.py +401 -0
- sigma/monitoring.py +666 -0
- sigma/portfolio.py +697 -0
- sigma/reporting.py +658 -0
- sigma/robustness.py +675 -0
- sigma/setup.py +305 -402
- sigma/strategy.py +753 -0
- sigma/tools/backtest.py +23 -5
- sigma/tools.py +617 -0
- sigma/visualization.py +766 -0
- sigma_terminal-3.2.0.dist-info/METADATA +298 -0
- sigma_terminal-3.2.0.dist-info/RECORD +30 -0
- sigma_terminal-3.2.0.dist-info/entry_points.txt +6 -0
- sigma_terminal-3.2.0.dist-info/licenses/LICENSE +25 -0
- sigma/core/agent.py +0 -205
- sigma/core/config.py +0 -119
- sigma/core/llm.py +0 -794
- sigma/tools/__init__.py +0 -5
- sigma/tools/charts.py +0 -400
- sigma/tools/financial.py +0 -1457
- sigma/ui/__init__.py +0 -1
- sigma_terminal-2.0.1.dist-info/METADATA +0 -222
- sigma_terminal-2.0.1.dist-info/RECORD +0 -19
- sigma_terminal-2.0.1.dist-info/entry_points.txt +0 -2
- sigma_terminal-2.0.1.dist-info/licenses/LICENSE +0 -42
- {sigma_terminal-2.0.1.dist-info → sigma_terminal-3.2.0.dist-info}/WHEEL +0 -0
sigma/data/models.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Data models for the data module."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime, date
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AssetClass(Enum):
    """Kinds of financial instrument the data module distinguishes.

    Every member's value is the lowercase spelling of its name, so a member
    can be recovered from its string form with ``AssetClass("equity")``.
    """

    EQUITY = "equity"
    ETF = "etf"
    CRYPTO = "crypto"
    FOREX = "forex"
    INDEX = "index"
    COMMODITY = "commodity"
    BOND = "bond"
    OPTION = "option"
    FUTURE = "future"
    # Fallback member for instruments that fit none of the above.
    UNKNOWN = "unknown"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DataSource(Enum):
    """Identifiers for the places a piece of data can originate from.

    Every member's value is the lowercase spelling of its name, so a member
    can be recovered from its string form with ``DataSource("yfinance")``.
    """

    YFINANCE = "yfinance"
    ALPHA_VANTAGE = "alpha_vantage"
    POLYGON = "polygon"
    QUANDL = "quandl"
    FRED = "fred"
    CACHE = "cache"
    COMPUTED = "computed"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class DataLineage:
    """Provenance record for a piece of data.

    ``source`` and ``symbol`` are required; every other field has a default.
    Mutable defaults are created per instance through ``default_factory``.
    """

    source: DataSource
    symbol: str
    # Filled with ``datetime.now()`` (naive local time) at construction.
    timestamp: datetime = field(default_factory=datetime.now)
    version: str = "1.0"
    transformations: List[str] = field(default_factory=list)
    quality_score: float = 1.0
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class DataQualityReport:
    """Outcome of a data-quality assessment.

    All fields have defaults, so ``DataQualityReport()`` describes a report
    with perfect scores, no findings and ``passed`` set to ``True``.
    """

    # The four dimensions averaged by ``overall_score``.
    completeness: float = 1.0  # % of non-null values
    accuracy: float = 1.0  # estimated accuracy
    timeliness: float = 1.0  # freshness score
    consistency: float = 1.0  # internal consistency

    # Free-text findings; a fresh list is created for every instance.
    issues: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    # Counters and supporting detail.
    total_records: int = 0
    missing_count: int = 0
    missing_pct: float = 0.0
    stale_ticks: int = 0
    outliers_detected: int = 0
    timezone_issues: int = 0
    date_range: Optional[Tuple[Any, Any]] = None
    gaps: List[Any] = field(default_factory=list)
    passed: bool = True

    @property
    def overall_score(self) -> float:
        """Return the unweighted mean of the four quality dimensions."""
        combined = (
            self.completeness
            + self.accuracy
            + self.timeliness
            + self.consistency
        )
        return combined / 4
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class CorporateAction:
    """A single corporate-action event for one symbol.

    ``action_type``, ``date`` and ``symbol`` are required. The numeric fields
    are optional and default to ``None``.
    """

    action_type: str  # split, dividend, merger, spinoff
    # NOTE: the field shares its name with the ``date`` type it is annotated
    # with. This works only because the field has no default value.
    date: date
    symbol: str
    ratio: Optional[float] = None  # for splits
    amount: Optional[float] = None  # for dividends
    adjustment_factor: Optional[float] = None  # adjustment multiplier
    details: Dict[str, Any] = field(default_factory=dict)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class PriceBar:
    """One price bar: a timestamp with open/high/low/close prices and volume.

    ``adjusted_close`` is optional and defaults to ``None``.
    """

    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: int
    adjusted_close: Optional[float] = None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class Fundamental:
    """One fundamental-data observation for a symbol.

    ``metrics`` is a free-form mapping that starts out empty; a fresh dict is
    created for every instance.
    """

    symbol: str
    period: str  # quarterly, annual
    # NOTE: the field shares its name with the ``date`` type it is annotated
    # with. This works only because the field has no default value.
    date: date
    metrics: Dict[str, Any] = field(default_factory=dict)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def detect_asset_class(symbol: str) -> AssetClass:
    """Guess the asset class of a ticker from its spelling alone.

    The symbol is upper-cased and then matched against these rules, in order:

    1. Six letters made of two major currency codes (``EURUSD``) -> FOREX.
    2. Ends in ``USD`` (``BTC-USD``, ``BTCUSD``) or is one of a few bare
       coin tickers -> CRYPTO.
    3. Starts with ``^`` -> INDEX.
    4. Is one of a fixed list of well-known ETF tickers -> ETF.
    5. Anything else -> EQUITY.

    Args:
        symbol: Ticker symbol in any letter case.

    Returns:
        The first matching ``AssetClass``; ``AssetClass.EQUITY`` by default.
    """
    symbol = symbol.upper()

    # Forex must be tested before crypto: a pair such as "EURUSD" also ends
    # in "USD" and would otherwise be reported as crypto, which made the
    # forex rule unreachable for every pair quoted against the dollar.
    major_currencies = {"USD", "EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "NZD"}
    if (
        len(symbol) == 6
        and symbol.isalpha()
        and symbol[:3] in major_currencies
        and symbol[3:] in major_currencies
    ):
        return AssetClass.FOREX

    # Crypto patterns. A "-USD" suffix is already covered by "USD".
    if symbol.endswith("USD") or symbol in {"BTC", "ETH", "DOGE", "SOL", "ADA"}:
        return AssetClass.CRYPTO

    # Index patterns
    if symbol.startswith("^"):
        return AssetClass.INDEX

    # Common ETFs (union of the two lists the rules previously carried).
    etfs = {
        "SPY", "QQQ", "DIA", "IWM", "VTI", "VOO",
        "VEA", "VWO", "BND", "GLD", "SLV", "USO",
    }
    if symbol in etfs:
        return AssetClass.ETF

    # Default to equity
    return AssetClass.EQUITY
|
sigma/llm.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
"""LLM client implementations for all providers."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Any, AsyncIterator, Callable, Optional
|
|
6
|
+
|
|
7
|
+
from sigma.config import LLMProvider, get_settings
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseLLM(ABC):
    """Abstract interface that every provider-specific LLM client implements."""

    @abstractmethod
    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a response for a conversation.

        Args:
            messages: Conversation so far, as a list of dicts. The clients in
                this module read the ``"role"`` and ``"content"`` keys.
            tools: Optional tool definitions offered to the model.
            on_tool_call: Optional callback for tool requests. The clients in
                this module await it as ``on_tool_call(name, args)``.

        Returns:
            The model's reply as text. This base implementation has no body
            and yields ``None`` if a subclass delegates to it.
        """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GoogleLLM(BaseLLM):
    """Google Gemini client built on the ``google-genai`` SDK."""

    # Upper bound on model -> tools -> model round trips inside one
    # generate() call, so a model that keeps asking for tools cannot keep
    # the call running forever.
    MAX_TOOL_ROUNDS = 8

    def __init__(self, api_key: str, model: str):
        """Create the SDK client.

        Args:
            api_key: API key handed to ``genai.Client``.
            model: Model name sent with every request.
        """
        # Imported lazily so the SDK is only required when this provider is used.
        from google import genai

        self.client = genai.Client(api_key=api_key)
        self.model_name = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a reply, running requested tools along the way.

        Args:
            messages: Dicts with ``"role"`` and ``"content"`` keys. ``system``
                messages become the system instruction (the last one wins),
                ``user`` and ``assistant`` messages become conversation turns,
                and any other role is skipped.
            tools: Tool definitions; only entries whose ``"type"`` is
                ``"function"`` are forwarded to the model.
            on_tool_call: Awaited as ``on_tool_call(name, args)`` for every
                function call the model makes. Without it, function calls are
                not executed.

        Returns:
            The text of the last model response, or ``""`` if it has none.
        """
        from google.genai import types

        # Split the system prompt from the conversation turns.
        system_prompt = None
        contents = []
        for msg in messages:
            role = msg["role"]
            text = msg["content"]
            if role == "system":
                system_prompt = text
            elif role == "user":
                contents.append(
                    types.Content(role="user", parts=[types.Part(text=text)])
                )
            elif role == "assistant":
                # The request uses the role name "model" for assistant turns.
                contents.append(
                    types.Content(role="model", parts=[types.Part(text=text)])
                )

        config = types.GenerateContentConfig(system_instruction=system_prompt)

        if tools:
            declarations = [
                types.FunctionDeclaration(
                    name=tool["function"]["name"],
                    description=tool["function"].get("description", ""),
                    parameters=tool["function"].get("parameters", {}),
                )
                for tool in tools
                if tool.get("type") == "function"
            ]
            if declarations:
                config.tools = [types.Tool(function_declarations=declarations)]

        response = await self._request(contents, config)

        # Keep answering function calls until the model stops asking or the
        # round budget is used up. Every call in a response is executed and
        # all results go back in a single user turn.
        for _ in range(self.MAX_TOOL_ROUNDS):
            calls = self._function_calls(response)
            if not calls or not on_tool_call:
                break

            # Echo the model turn that requested the tools.
            contents.append(response.candidates[0].content)

            replies = []
            for call in calls:
                args = dict(call.args) if call.args else {}
                result = await on_tool_call(call.name, args)
                replies.append(
                    types.Part(
                        function_response=types.FunctionResponse(
                            name=call.name,
                            response={"result": str(result)},
                        )
                    )
                )
            contents.append(types.Content(role="user", parts=replies))

            response = await self._request(contents, config)

        return response.text or ""

    async def _request(self, contents, config):
        """Send one request through the SDK's async client.

        The synchronous ``client.models`` call would block the event loop
        for the duration of the request.
        """
        return await self.client.aio.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=config,
        )

    @staticmethod
    def _function_calls(response) -> list:
        """Return the function calls found in the first candidate of *response*."""
        if not response.candidates:
            return []
        content = response.candidates[0].content
        if not (content and content.parts):
            return []
        return [
            part.function_call
            for part in content.parts
            if getattr(part, "function_call", None)
        ]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class OpenAILLM(BaseLLM):
    """OpenAI chat-completions client."""

    # Upper bound on model -> tools -> model round trips inside one
    # generate() call; replaces the previous unbounded recursion.
    MAX_TOOL_ROUNDS = 8

    def __init__(self, api_key: str, model: str):
        """Create the SDK client.

        Args:
            api_key: API key handed to ``AsyncOpenAI``.
            model: Model name sent with every request.
        """
        # Imported lazily so the SDK is only required when this provider is used.
        from openai import AsyncOpenAI

        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a reply, running requested tools along the way.

        Args:
            messages: Conversation passed to the API as-is. The caller's list
                is never modified.
            tools: Tool definitions passed to the API as-is, with
                ``tool_choice`` set to ``"auto"``.
            on_tool_call: Awaited as ``on_tool_call(name, args)`` for every
                tool call the model makes. Without it, tool calls are not
                executed.

        Returns:
            The content of the last assistant message, or ``""`` if empty.

        Raises:
            json.JSONDecodeError: If the model sends tool arguments that are
                not valid JSON.
        """
        request = {"model": self.model, "messages": list(messages)}
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"

        rounds_left = self.MAX_TOOL_ROUNDS
        while True:
            response = await self.client.chat.completions.create(**request)
            message = response.choices[0].message

            if not message.tool_calls or not on_tool_call or rounds_left <= 0:
                return message.content or ""
            rounds_left -= 1

            tool_results = []
            for call in message.tool_calls:
                # An empty argument string means "no arguments".
                raw_args = call.function.arguments
                args = json.loads(raw_args) if raw_args else {}
                result = await on_tool_call(call.function.name, args)
                tool_results.append({
                    "tool_call_id": call.id,
                    "role": "tool",
                    # default=str keeps non-JSON results (dates, custom
                    # objects) from raising TypeError.
                    "content": json.dumps(result, default=str),
                })

            # exclude_none drops null-valued response fields so they are not
            # echoed back to the API as explicit nulls.
            assistant_turn = message.model_dump(exclude_none=True)
            request["messages"] = (
                request["messages"] + [assistant_turn] + tool_results
            )
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class AnthropicLLM(BaseLLM):
    """Anthropic Claude client."""

    # Upper bound on model -> tools -> model round trips inside one
    # generate() call; replaces the previous unbounded recursion.
    MAX_TOOL_ROUNDS = 8

    # Value sent as ``max_tokens`` with every request.
    MAX_TOKENS = 4096

    def __init__(self, api_key: str, model: str):
        """Create the SDK client.

        Args:
            api_key: API key handed to ``AsyncAnthropic``.
            model: Model name sent with every request.
        """
        # Imported lazily so the SDK is only required when this provider is used.
        from anthropic import AsyncAnthropic

        self.client = AsyncAnthropic(api_key=api_key)
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a reply, running requested tools along the way.

        Args:
            messages: Dicts with ``"role"`` and ``"content"`` keys. ``system``
                messages are lifted into the ``system`` request field (the
                last one wins); all others are sent as conversation turns.
                The caller's list is never modified.
            tools: Tool definitions in the ``{"type": "function",
                "function": {...}}`` shape; they are converted to the
                ``name`` / ``description`` / ``input_schema`` shape.
            on_tool_call: Awaited as ``on_tool_call(name, input)`` for every
                ``tool_use`` block. Without it, tool requests are not executed.

        Returns:
            The concatenated text blocks of the last response.
        """
        # Separate the system prompt from the conversation turns.
        system = ""
        conversation = []
        for msg in messages:
            if msg["role"] == "system":
                system = msg["content"]
            else:
                conversation.append(msg)

        request = {
            "model": self.model,
            "max_tokens": self.MAX_TOKENS,
            "messages": conversation,
        }
        if system:
            request["system"] = system
        if tools:
            request["tools"] = [
                {
                    "name": t["function"]["name"],
                    "description": t["function"].get("description", ""),
                    "input_schema": t["function"].get("parameters", {}),
                }
                for t in tools
                if t.get("type") == "function"
            ]

        rounds_left = self.MAX_TOOL_ROUNDS
        while True:
            response = await self.client.messages.create(**request)
            tool_uses = [b for b in response.content if b.type == "tool_use"]

            if not tool_uses or not on_tool_call or rounds_left <= 0:
                return "".join(
                    b.text for b in response.content if b.type == "text"
                )
            rounds_left -= 1

            # Answer every tool_use block of this turn, not just the first:
            # the follow-up user message must carry one tool_result per
            # tool_use id.
            tool_results = []
            for block in tool_uses:
                result = await on_tool_call(block.name, block.input)
                tool_results.append({
                    "type": "tool_result",
                    "tool_use_id": block.id,
                    # default=str keeps non-JSON results from raising.
                    "content": json.dumps(result, default=str),
                })

            # ``conversation`` is the list referenced by request["messages"],
            # so the next request sees both new turns.
            conversation.append({"role": "assistant", "content": response.content})
            conversation.append({"role": "user", "content": tool_results})
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class GroqLLM(BaseLLM):
    """Groq chat-completions client."""

    # Upper bound on model -> tools -> model round trips inside one
    # generate() call; replaces the previous unbounded recursion.
    MAX_TOOL_ROUNDS = 8

    def __init__(self, api_key: str, model: str):
        """Create the SDK client.

        Args:
            api_key: API key handed to ``AsyncGroq``.
            model: Model name sent with every request.
        """
        # Imported lazily so the SDK is only required when this provider is used.
        from groq import AsyncGroq

        self.client = AsyncGroq(api_key=api_key)
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a reply, running requested tools along the way.

        Args:
            messages: Conversation passed to the API as-is. The caller's list
                is never modified.
            tools: Tool definitions passed to the API as-is, with
                ``tool_choice`` set to ``"auto"``.
            on_tool_call: Awaited as ``on_tool_call(name, args)`` for every
                tool call the model makes. Without it, tool calls are not
                executed.

        Returns:
            The content of the last assistant message, or ``""`` if empty.

        Raises:
            json.JSONDecodeError: If the model sends tool arguments that are
                not valid JSON.
        """
        request = {"model": self.model, "messages": list(messages)}
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"

        rounds_left = self.MAX_TOOL_ROUNDS
        while True:
            response = await self.client.chat.completions.create(**request)
            message = response.choices[0].message

            if not message.tool_calls or not on_tool_call or rounds_left <= 0:
                return message.content or ""
            rounds_left -= 1

            tool_results = []
            for call in message.tool_calls:
                # An empty argument string means "no arguments".
                raw_args = call.function.arguments
                args = json.loads(raw_args) if raw_args else {}
                result = await on_tool_call(call.function.name, args)
                tool_results.append({
                    "tool_call_id": call.id,
                    "role": "tool",
                    # default=str keeps non-JSON results (dates, custom
                    # objects) from raising TypeError.
                    "content": json.dumps(result, default=str),
                })

            # NOTE(review): the SDK's tool-call objects are passed through
            # unchanged, as before; this presumably relies on the SDK
            # serialising them. Confirm against the Groq client.
            assistant_turn = {"role": "assistant", "tool_calls": message.tool_calls}
            request["messages"] = (
                request["messages"] + [assistant_turn] + tool_results
            )
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class OllamaLLM(BaseLLM):
    """Ollama local client talking to the ``/api/chat`` HTTP endpoint."""

    def __init__(self, host: str, model: str):
        """Store the connection details.

        Args:
            host: Base URL of the Ollama server; trailing slashes are removed.
            model: Model name sent with every request.
        """
        self.host = host.rstrip("/")
        self.model = model

    async def generate(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        on_tool_call: Optional[Callable] = None,
    ) -> str:
        """Generate a reply with a single, non-streaming chat request.

        Tools are not sent as structured definitions. Their names and
        descriptions are appended to the first system message, or inserted
        as a new system message when there is none.

        Args:
            messages: Dicts with ``"role"`` and ``"content"`` keys. The
                caller's list and dicts are never modified.
            tools: Tool definitions to describe in the system prompt.
            on_tool_call: Accepted for interface compatibility; this client
                never invokes it, so any ``TOOL_CALL:`` text in the reply is
                returned to the caller unprocessed.

        Returns:
            ``message.content`` from the JSON response, or ``""`` when the
            response has no such field.
        """
        import aiohttp

        outgoing = messages
        if tools:
            tool_desc = self._format_tools_for_prompt(tools)
            # Work on copies. Editing the caller's messages in place made the
            # tool description pile up in the system prompt every time the
            # same history was passed in again.
            outgoing = [dict(msg) for msg in messages]
            for msg in outgoing:
                if msg["role"] == "system":
                    msg["content"] = f"{msg['content']}\n\n{tool_desc}"
                    break
            else:
                outgoing.insert(0, {"role": "system", "content": tool_desc})

        payload = {"model": self.model, "messages": outgoing, "stream": False}
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.host}/api/chat", json=payload) as resp:
                data = await resp.json()
                return data.get("message", {}).get("content", "")

    def _format_tools_for_prompt(self, tools: list[dict]) -> str:
        """Format tools as text for prompt injection.

        Only entries whose ``"type"`` is ``"function"`` are listed, one per
        line as ``- name: description``.
        """
        lines = ["You have access to these tools:"]
        lines.extend(
            f"- {tool['function']['name']}: {tool['function'].get('description', '')}"
            for tool in tools
            if tool.get("type") == "function"
        )
        lines.append("\nTo use a tool, respond with: TOOL_CALL: tool_name(args)")
        return "\n".join(lines)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def get_llm(provider: LLMProvider, model: Optional[str] = None) -> BaseLLM:
    """Build the LLM client that matches *provider*.

    Args:
        provider: Which backend to use.
        model: Model name; when ``None`` it is taken from
            ``settings.get_model(provider)``.

    Returns:
        A ready-to-use client instance for the provider.

    Raises:
        ValueError: If the provider needs an API key that is not configured,
            or if the provider is not supported.
    """
    settings = get_settings()
    chosen_model = settings.get_model(provider) if model is None else model

    # (provider, settings attribute holding the key, error text, client class)
    keyed_providers = (
        (LLMProvider.GOOGLE, "google_api_key",
         "Google API key not configured", GoogleLLM),
        (LLMProvider.OPENAI, "openai_api_key",
         "OpenAI API key not configured", OpenAILLM),
        (LLMProvider.ANTHROPIC, "anthropic_api_key",
         "Anthropic API key not configured", AnthropicLLM),
        (LLMProvider.GROQ, "groq_api_key",
         "Groq API key not configured", GroqLLM),
    )
    for candidate, key_attr, missing_key_error, client_cls in keyed_providers:
        if provider == candidate:
            api_key = getattr(settings, key_attr)
            if not api_key:
                raise ValueError(missing_key_error)
            return client_cls(api_key, chosen_model)

    # Ollama takes a host rather than an API key.
    if provider == LLMProvider.OLLAMA:
        return OllamaLLM(settings.ollama_host, chosen_model)

    raise ValueError(f"Unsupported provider: {provider}")
|