sigma-terminal 3.2.0-py3-none-any.whl → 3.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sigma/__init__.py +2 -2
- sigma/app.py +347 -103
- sigma/cli.py +37 -6
- sigma/config.py +174 -3
- sigma/llm.py +256 -18
- sigma/setup.py +84 -16
- sigma_terminal-3.3.0.dist-info/METADATA +583 -0
- {sigma_terminal-3.2.0.dist-info → sigma_terminal-3.3.0.dist-info}/RECORD +11 -11
- sigma_terminal-3.2.0.dist-info/METADATA +0 -298
- {sigma_terminal-3.2.0.dist-info → sigma_terminal-3.3.0.dist-info}/WHEEL +0 -0
- {sigma_terminal-3.2.0.dist-info → sigma_terminal-3.3.0.dist-info}/entry_points.txt +0 -0
- {sigma_terminal-3.2.0.dist-info → sigma_terminal-3.3.0.dist-info}/licenses/LICENSE +0 -0
sigma/cli.py
CHANGED
@@ -1,4 +1,4 @@
-"""CLI entry point for Sigma v3.2.0."""
+"""CLI entry point for Sigma v3.3.0."""
 
 import argparse
 import json
@@ -10,10 +10,13 @@ from rich.table import Table
 from rich.panel import Panel
 
 from .app import launch
-from .config import get_settings, save_api_key, save_setting, AVAILABLE_MODELS, LLMProvider
+from .config import (
+    get_settings, save_api_key, save_setting, AVAILABLE_MODELS, LLMProvider,
+    is_first_run, mark_first_run_complete, detect_lean_installation, detect_ollama
+)
 
 
-__version__ = "3.2.0"
+__version__ = "3.3.0"
 
 console = Console()
 
@@ -28,7 +31,7 @@ def show_banner():
 [bold white]███████║██║╚██████╔╝██║ ╚═╝ ██║██║ ██║[/bold white]
 [bold white]╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝[/bold white]
 
-[dim]v3.2.0[/dim] [bold cyan]σ[/bold cyan] [bold]Finance Research Agent[/bold]
+[dim]v3.3.0[/dim] [bold cyan]σ[/bold cyan] [bold]Finance Research Agent[/bold]
 """
     console.print(banner)
 
@@ -37,7 +40,7 @@ def main():
     """Main CLI entry point."""
     parser = argparse.ArgumentParser(
         prog="sigma",
-        description="Sigma v3.2.0 - Finance Research Agent",
+        description="Sigma v3.3.0 - Finance Research Agent",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog="""
 Examples:
@@ -126,7 +129,35 @@ Examples:
 
     if args.setup:
         from .setup import run_setup
-
+        result = run_setup()
+        if result:
+            mark_first_run_complete()
+            # After manual setup, ask if user wants to launch
+            from rich.prompt import Confirm
+            if Confirm.ask("\n[bold cyan]σ[/bold cyan] Launch Sigma now?", default=True):
+                launch()
+        return 0 if result else 1
+
+    # Auto-launch setup on first run, then start interactive
+    if is_first_run():
+        console.print("\n[bold cyan]σ[/bold cyan] [bold]Welcome to Sigma![/bold]")
+        console.print("[dim]First time setup detected. Launching setup wizard...[/dim]\n")
+        from .setup import run_setup
+        result = run_setup()
+        mark_first_run_complete() # Always mark complete to not ask again
+
+        if result:
+            console.print("\n[bold green]✓ Setup complete![/bold green]")
+            console.print("[dim]Launching Sigma...[/dim]\n")
+            import time
+            time.sleep(1) # Brief pause for user to see message
+            launch()
+            return 0
+        else:
+            console.print("\n[yellow]Setup skipped.[/yellow] You can run [bold]sigma --setup[/bold] later.")
+            console.print("[dim]Launching Sigma anyway...[/dim]\n")
+            launch()
+            return 0
 
     if args.list_models:
         console.print("\n[bold]Available Models by Provider:[/bold]\n")
sigma/config.py
CHANGED
@@ -1,15 +1,17 @@
-"""Configuration management for Sigma v3.2.0."""
+"""Configuration management for Sigma v3.3.0."""
 
 import os
+import shutil
+import subprocess
 from enum import Enum
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
 
 from pydantic import Field
 from pydantic_settings import BaseSettings
 
 
-__version__ = "3.2.0"
+__version__ = "3.3.0"
 
 
 class LLMProvider(str, Enum):
@@ -35,6 +37,167 @@ AVAILABLE_MODELS = {
 # Config directory
 CONFIG_DIR = Path.home() / ".sigma"
 CONFIG_FILE = CONFIG_DIR / "config.env"
+FIRST_RUN_MARKER = CONFIG_DIR / ".first_run_complete"
+
+
+def is_first_run() -> bool:
+    """Check if this is the first run of the application."""
+    return not FIRST_RUN_MARKER.exists()
+
+
+def mark_first_run_complete() -> None:
+    """Mark that the first run setup has been completed."""
+    CONFIG_DIR.mkdir(parents=True, exist_ok=True)
+    FIRST_RUN_MARKER.touch()
+
+
+def detect_lean_installation() -> Tuple[bool, Optional[str], Optional[str]]:
+    """
+    Auto-detect LEAN/QuantConnect installation.
+    Returns: (is_installed, cli_path, lean_directory)
+    """
+    lean_cli_path = None
+    lean_directory = None
+
+    # Check if lean CLI is available in PATH
+    lean_cli = shutil.which("lean")
+    if lean_cli:
+        lean_cli_path = lean_cli
+
+    # Check common installation paths for LEAN directory
+    common_paths = [
+        Path.home() / "Lean",
+        Path.home() / ".lean",
+        Path.home() / "QuantConnect" / "Lean",
+        Path("/opt/lean"),
+        Path.home() / "Projects" / "Lean",
+        Path.home() / ".local" / "share" / "lean",
+    ]
+
+    for path in common_paths:
+        if path.exists():
+            # Check for LEAN directory structure
+            if (path / "Launcher").exists() or (path / "Algorithm.Python").exists() or (path / "lean.json").exists():
+                lean_directory = str(path)
+                break
+
+    # Check if lean is installed via pip (check both pip and pip3)
+    if not lean_cli_path:
+        for pip_cmd in ["pip3", "pip"]:
+            try:
+                result = subprocess.run(
+                    [pip_cmd, "show", "lean"],
+                    capture_output=True,
+                    text=True,
+                    timeout=10
+                )
+                if result.returncode == 0:
+                    # Parse location from pip show output
+                    for line in result.stdout.split("\n"):
+                        if line.startswith("Location:"):
+                            # lean is installed via pip
+                            lean_cli_path = "lean"
+                            break
+                if lean_cli_path:
+                    break
+            except (subprocess.TimeoutExpired, FileNotFoundError, Exception):
+                continue
+
+    # Also check if lean command works directly
+    if not lean_cli_path:
+        try:
+            result = subprocess.run(
+                ["lean", "--version"],
+                capture_output=True,
+                text=True,
+                timeout=5
+            )
+            if result.returncode == 0:
+                lean_cli_path = "lean"
+        except (FileNotFoundError, subprocess.TimeoutExpired, Exception):
+            pass
+
+    is_installed = lean_cli_path is not None or lean_directory is not None
+    return is_installed, lean_cli_path, lean_directory
+
+
+async def install_lean_cli() -> Tuple[bool, str]:
+    """
+    Install LEAN CLI via pip.
+    Returns: (success, message)
+    """
+    import asyncio
+
+    try:
+        # Try pip3 first, then pip
+        for pip_cmd in ["pip3", "pip"]:
+            try:
+                process = await asyncio.create_subprocess_exec(
+                    pip_cmd, "install", "lean",
+                    stdout=asyncio.subprocess.PIPE,
+                    stderr=asyncio.subprocess.PIPE
+                )
+                stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=120)
+
+                if process.returncode == 0:
+                    return True, "LEAN CLI installed successfully!"
+            except (FileNotFoundError, asyncio.TimeoutError):
+                continue
+
+        return False, "Failed to install LEAN CLI. Please install manually: pip install lean"
+    except Exception as e:
+        return False, f"Installation error: {str(e)}"
+
+
+def install_lean_cli_sync() -> Tuple[bool, str]:
+    """
+    Install LEAN CLI via pip (synchronous version).
+    Returns: (success, message)
+    """
+    try:
+        # Try pip3 first, then pip
+        for pip_cmd in ["pip3", "pip"]:
+            try:
+                result = subprocess.run(
+                    [pip_cmd, "install", "lean"],
+                    capture_output=True,
+                    text=True,
+                    timeout=120
+                )
+
+                if result.returncode == 0:
+                    return True, "LEAN CLI installed successfully!"
+            except (FileNotFoundError, subprocess.TimeoutExpired):
+                continue
+
+        return False, "Failed to install LEAN CLI. Please install manually: pip install lean"
+    except Exception as e:
+        return False, f"Installation error: {str(e)}"
+
+
+def detect_ollama() -> Tuple[bool, Optional[str]]:
+    """
+    Auto-detect Ollama installation and available models.
+    Returns: (is_running, host_url)
+    """
+    import urllib.request
+    import urllib.error
+
+    hosts_to_check = [
+        "http://localhost:11434",
+        "http://127.0.0.1:11434",
+    ]
+
+    for host in hosts_to_check:
+        try:
+            req = urllib.request.Request(f"{host}/api/tags", method="GET")
+            with urllib.request.urlopen(req, timeout=2) as resp:
+                if resp.status == 200:
+                    return True, host
+        except (urllib.error.URLError, OSError):
+            continue
+
+    return False, None
 
 
 class Settings(BaseSettings):
@@ -62,6 +225,11 @@ class Settings(BaseSettings):
     # Ollama settings
     ollama_host: str = "http://localhost:11434"
 
+    # LEAN settings
+    lean_cli_path: Optional[str] = Field(default=None, alias="LEAN_CLI_PATH")
+    lean_directory: Optional[str] = Field(default=None, alias="LEAN_DIRECTORY")
+    lean_enabled: bool = Field(default=False, alias="LEAN_ENABLED")
+
     # Data API keys
     alpha_vantage_api_key: str = "6ER128DD3NQUPTVC" # Built-in free key
     exa_api_key: Optional[str] = None
@@ -183,6 +351,9 @@ def save_setting(key: str, value: str) -> None:
         "output_dir": "OUTPUT_DIR",
         "cache_enabled": "CACHE_ENABLED",
         "lean_cli_path": "LEAN_CLI_PATH",
+        "lean_directory": "LEAN_DIRECTORY",
+        "lean_enabled": "LEAN_ENABLED",
+        "ollama_host": "OLLAMA_HOST",
     }
 
     config_key = setting_map.get(key, key.upper())
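The new helpers in config.py are plain synchronous functions with tuple returns, so they can be exercised directly. A minimal check script (return shapes as declared in the diff above; the script itself is illustrative and assumes sigma-terminal 3.3.0 is installed):

# Illustrative check of the detection helpers added in 3.3.0.
from sigma.config import detect_lean_installation, detect_ollama, is_first_run

installed, cli_path, lean_dir = detect_lean_installation()  # (bool, str | None, str | None)
print(f"LEAN installed: {installed}, CLI: {cli_path}, directory: {lean_dir}")

running, host = detect_ollama()                             # (bool, str | None)
print(f"Ollama running: {running}, host: {host}")

print(f"First run: {is_first_run()}")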
sigma/llm.py
CHANGED
@@ -1,15 +1,68 @@
 """LLM client implementations for all providers."""
 
+import asyncio
 import json
+import re
+import time
 from abc import ABC, abstractmethod
 from typing import Any, AsyncIterator, Callable, Optional
 
 from sigma.config import LLMProvider, get_settings
 
 
+# Rate limiting configuration
+class RateLimiter:
+    """Simple rate limiter to prevent API flooding."""
+
+    def __init__(self, requests_per_minute: int = 10, min_interval: float = 1.0):
+        self.requests_per_minute = requests_per_minute
+        self.min_interval = min_interval
+        self.last_request_time = 0
+        self.request_count = 0
+        self.window_start = time.time()
+
+    async def wait(self):
+        """Wait if necessary to respect rate limits."""
+        current_time = time.time()
+
+        # Reset window if a minute has passed
+        if current_time - self.window_start >= 60:
+            self.window_start = current_time
+            self.request_count = 0
+
+        # Check if we've hit the rate limit
+        if self.request_count >= self.requests_per_minute:
+            wait_time = 60 - (current_time - self.window_start)
+            if wait_time > 0:
+                await asyncio.sleep(wait_time)
+            self.window_start = time.time()
+            self.request_count = 0
+
+        # Ensure minimum interval between requests
+        time_since_last = current_time - self.last_request_time
+        if time_since_last < self.min_interval:
+            await asyncio.sleep(self.min_interval - time_since_last)
+
+        self.last_request_time = time.time()
+        self.request_count += 1
+
+
+# Global rate limiters per provider
+_rate_limiters = {
+    "google": RateLimiter(requests_per_minute=15, min_interval=0.5),
+    "openai": RateLimiter(requests_per_minute=20, min_interval=0.3),
+    "anthropic": RateLimiter(requests_per_minute=15, min_interval=0.5),
+    "groq": RateLimiter(requests_per_minute=30, min_interval=0.2),
+    "xai": RateLimiter(requests_per_minute=10, min_interval=1.0),
+    "ollama": RateLimiter(requests_per_minute=60, min_interval=0.1), # Local, can be faster
+}
+
+
 class BaseLLM(ABC):
     """Base class for LLM clients."""
 
+    provider_name: str = "base"
+
     @abstractmethod
     async def generate(
         self,
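The RateLimiter above enforces two things per provider: a per-minute request budget and a minimum gap between consecutive calls. A small demo of those semantics (assumes sigma.llm is importable in the current environment; the timing script itself is illustrative):

# Illustrative timing demo for the RateLimiter added in sigma/llm.py.
import asyncio
import time

from sigma.llm import RateLimiter  # assumption: sigma-terminal 3.3.0 is installed

async def demo():
    limiter = RateLimiter(requests_per_minute=3, min_interval=0.5)
    start = time.time()
    for i in range(3):
        await limiter.wait()       # spaces calls at least 0.5s apart
        print(f"request {i} at t={time.time() - start:.2f}s")
    # A fourth wait() inside the same minute would sleep until the 60s window resets.

asyncio.run(demo())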
@@ -19,11 +72,19 @@ class BaseLLM(ABC):
     ) -> str:
         """Generate a response."""
         pass
+
+    async def _rate_limit(self):
+        """Apply rate limiting."""
+        limiter = _rate_limiters.get(self.provider_name)
+        if limiter:
+            await limiter.wait()
 
 
 class GoogleLLM(BaseLLM):
     """Google Gemini client."""
 
+    provider_name = "google"
+
     def __init__(self, api_key: str, model: str):
         from google import genai
         self.client = genai.Client(api_key=api_key)
@@ -35,6 +96,7 @@ class GoogleLLM(BaseLLM):
         tools: Optional[list[dict]] = None,
         on_tool_call: Optional[Callable] = None,
     ) -> str:
+        await self._rate_limit()
         from google.genai import types
 
         # Extract system prompt and build contents
@@ -159,6 +221,8 @@ class GoogleLLM(BaseLLM):
 class OpenAILLM(BaseLLM):
     """OpenAI client."""
 
+    provider_name = "openai"
+
     def __init__(self, api_key: str, model: str):
         from openai import AsyncOpenAI
         self.client = AsyncOpenAI(api_key=api_key)
@@ -170,6 +234,7 @@ class OpenAILLM(BaseLLM):
         tools: Optional[list[dict]] = None,
         on_tool_call: Optional[Callable] = None,
     ) -> str:
+        await self._rate_limit()
         kwargs = {
             "model": self.model,
             "messages": messages,
@@ -204,6 +269,8 @@ class OpenAILLM(BaseLLM):
 class AnthropicLLM(BaseLLM):
     """Anthropic Claude client."""
 
+    provider_name = "anthropic"
+
     def __init__(self, api_key: str, model: str):
         from anthropic import AsyncAnthropic
         self.client = AsyncAnthropic(api_key=api_key)
@@ -215,6 +282,7 @@ class AnthropicLLM(BaseLLM):
         tools: Optional[list[dict]] = None,
         on_tool_call: Optional[Callable] = None,
     ) -> str:
+        await self._rate_limit()
         # Extract system message
         system = ""
         filtered_messages = []
@@ -277,6 +345,8 @@ class AnthropicLLM(BaseLLM):
 class GroqLLM(BaseLLM):
     """Groq client."""
 
+    provider_name = "groq"
+
     def __init__(self, api_key: str, model: str):
         from groq import AsyncGroq
         self.client = AsyncGroq(api_key=api_key)
@@ -288,6 +358,7 @@ class GroqLLM(BaseLLM):
         tools: Optional[list[dict]] = None,
         on_tool_call: Optional[Callable] = None,
     ) -> str:
+        await self._rate_limit()
         kwargs = {
             "model": self.model,
             "messages": messages,
@@ -319,7 +390,9 @@ class GroqLLM(BaseLLM):
 
 
 class OllamaLLM(BaseLLM):
-    """Ollama local client."""
+    """Ollama local client with native tool call support."""
+
+    provider_name = "ollama"
 
     def __init__(self, host: str, model: str):
         self.host = host.rstrip("/")
@@ -331,35 +404,200 @@ class OllamaLLM(BaseLLM):
         tools: Optional[list[dict]] = None,
         on_tool_call: Optional[Callable] = None,
     ) -> str:
+        await self._rate_limit()
         import aiohttp
 
-        #
+        # Convert tools to Ollama format
+        ollama_tools = None
         if tools:
-
-
-
-
-
-
-
-
+            ollama_tools = []
+            for tool in tools:
+                if tool.get("type") == "function":
+                    f = tool["function"]
+                    ollama_tools.append({
+                        "type": "function",
+                        "function": {
+                            "name": f["name"],
+                            "description": f.get("description", ""),
+                            "parameters": f.get("parameters", {})
+                        }
+                    })
+
+        request_body = {
+            "model": self.model,
+            "messages": messages,
+            "stream": False
+        }
+
+        if ollama_tools:
+            request_body["tools"] = ollama_tools
 
         async with aiohttp.ClientSession() as session:
-
-
-
-
-
-
+            try:
+                async with session.post(
+                    f"{self.host}/api/chat",
+                    json=request_body,
+                    timeout=aiohttp.ClientTimeout(total=120)
+                ) as resp:
+                    if resp.status != 200:
+                        error_text = await resp.text()
+                        return f"Ollama error: {error_text}"
+
+                    data = await resp.json()
+                    message = data.get("message", {})
+
+                    # Check for tool calls in response
+                    tool_calls = message.get("tool_calls", [])
+
+                    if tool_calls and on_tool_call:
+                        # Process tool calls
+                        updated_messages = messages.copy()
+                        updated_messages.append(message)
+
+                        for tc in tool_calls:
+                            func = tc.get("function", {})
+                            tool_name = func.get("name", "")
+                            tool_args = func.get("arguments", {})
+
+                            # If arguments is a string, parse it
+                            if isinstance(tool_args, str):
+                                try:
+                                    tool_args = json.loads(tool_args)
+                                except json.JSONDecodeError:
+                                    tool_args = {}
+
+                            # Execute the tool
+                            result = await on_tool_call(tool_name, tool_args)
+
+                            # Add tool result to messages
+                            updated_messages.append({
+                                "role": "tool",
+                                "content": json.dumps(result) if not isinstance(result, str) else result
+                            })
+
+                        # Get final response with tool results
+                        return await self._continue_with_tools(session, updated_messages, ollama_tools, on_tool_call)
+
+                    # Check for text-based tool calls (fallback for older models)
+                    content = message.get("content", "")
+                    if "TOOL_CALL:" in content and on_tool_call:
+                        result = await self._parse_text_tool_call(content, on_tool_call)
+                        if result:
+                            return result
+
+                    return content
+
+            except aiohttp.ClientError as e:
+                return f"Connection error: {e}. Is Ollama running?"
+            except asyncio.TimeoutError:
+                return "Request timed out. Try a simpler query or check Ollama status."
+
+    async def _continue_with_tools(
+        self,
+        session,
+        messages: list[dict],
+        tools: Optional[list[dict]],
+        on_tool_call: Optional[Callable],
+        depth: int = 0
+    ) -> str:
+        """Continue conversation after tool calls."""
+        import aiohttp
+
+        if depth > 5: # Prevent infinite loops
+            return "Maximum tool call depth reached."
+
+        request_body = {
+            "model": self.model,
+            "messages": messages,
+            "stream": False
+        }
+        if tools:
+            request_body["tools"] = tools
+
+        async with session.post(
+            f"{self.host}/api/chat",
+            json=request_body,
+            timeout=aiohttp.ClientTimeout(total=120)
+        ) as resp:
+            data = await resp.json()
+            message = data.get("message", {})
+
+            # Check for more tool calls
+            tool_calls = message.get("tool_calls", [])
+            if tool_calls and on_tool_call:
+                updated_messages = messages.copy()
+                updated_messages.append(message)
+
+                for tc in tool_calls:
+                    func = tc.get("function", {})
+                    tool_name = func.get("name", "")
+                    tool_args = func.get("arguments", {})
+
+                    if isinstance(tool_args, str):
+                        try:
+                            tool_args = json.loads(tool_args)
+                        except json.JSONDecodeError:
+                            tool_args = {}
+
+                    result = await on_tool_call(tool_name, tool_args)
+                    updated_messages.append({
+                        "role": "tool",
+                        "content": json.dumps(result) if not isinstance(result, str) else result
+                    })
+
+                return await self._continue_with_tools(session, updated_messages, tools, on_tool_call, depth + 1)
+
+            return message.get("content", "")
+
+    async def _parse_text_tool_call(self, content: str, on_tool_call: Callable) -> Optional[str]:
+        """Parse text-based tool calls for older models."""
+        # Pattern: TOOL_CALL: tool_name({"arg": "value"}) or TOOL_CALL: tool_name(arg=value)
+        pattern = r'TOOL_CALL:\s*(\w+)\s*\(([^)]*)\)'
+        match = re.search(pattern, content)
+
+        if not match:
+            return None
+
+        tool_name = match.group(1)
+        args_str = match.group(2).strip()
+
+        # Try to parse arguments
+        try:
+            if args_str.startswith("{"):
+                args = json.loads(args_str)
+            else:
+                # Parse key=value format
+                args = {}
+                for part in args_str.split(","):
+                    if "=" in part:
+                        k, v = part.split("=", 1)
+                        args[k.strip()] = v.strip().strip('"\'')
+        except:
+            args = {"symbol": args_str} if args_str else {}
+
+        # Execute tool
+        result = await on_tool_call(tool_name, args)
+
+        # Format result for response
+        if isinstance(result, dict):
+            result_str = json.dumps(result, indent=2)
+        else:
+            result_str = str(result)
+
+        # Return combined response
+        return f"Tool result:\n```json\n{result_str}\n```"
 
     def _format_tools_for_prompt(self, tools: list[dict]) -> str:
-        """Format tools as text for prompt injection."""
+        """Format tools as text for prompt injection (legacy fallback)."""
         lines = ["You have access to these tools:"]
         for tool in tools:
             if tool.get("type") == "function":
                 f = tool["function"]
-
+                params = f.get("parameters", {}).get("properties", {})
+                param_str = ", ".join(params.keys()) if params else ""
+                lines.append(f"- {f['name']}({param_str}): {f.get('description', '')}")
         lines.append("\nTo use a tool, respond with: TOOL_CALL: tool_name(args)")
+        lines.append("Example: TOOL_CALL: get_stock_quote(symbol=\"AAPL\")")
         return "\n".join(lines)
 
 
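For reference, the text fallback in _parse_text_tool_call targets strings such as TOOL_CALL: get_stock_quote(symbol="AAPL"), the same form that _format_tools_for_prompt now shows as an example. A self-contained check of that pattern (the regex and parsing rules are copied from the diff above; the sample strings are ours):

# Standalone check of the TOOL_CALL fallback pattern from sigma/llm.py.
import json
import re

pattern = r'TOOL_CALL:\s*(\w+)\s*\(([^)]*)\)'

samples = [
    'TOOL_CALL: get_stock_quote(symbol="AAPL")',       # key=value form
    'TOOL_CALL: get_stock_quote({"symbol": "AAPL"})',  # JSON form
]

for sample in samples:
    m = re.search(pattern, sample)
    name, args_str = m.group(1), m.group(2).strip()
    if args_str.startswith("{"):
        args = json.loads(args_str)                    # JSON arguments
    else:
        args = {}
        for part in args_str.split(","):               # key=value arguments
            if "=" in part:
                k, v = part.split("=", 1)
                args[k.strip()] = v.strip().strip('"\'')
    print(name, args)                                  # -> get_stock_quote {'symbol': 'AAPL'}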