sqlsaber 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/__init__.py +3 -0
- sqlsaber/__main__.py +4 -0
- sqlsaber/agents/__init__.py +9 -0
- sqlsaber/agents/anthropic.py +451 -0
- sqlsaber/agents/base.py +67 -0
- sqlsaber/agents/streaming.py +26 -0
- sqlsaber/cli/__init__.py +7 -0
- sqlsaber/cli/commands.py +132 -0
- sqlsaber/cli/database.py +275 -0
- sqlsaber/cli/display.py +207 -0
- sqlsaber/cli/interactive.py +93 -0
- sqlsaber/cli/memory.py +239 -0
- sqlsaber/cli/models.py +231 -0
- sqlsaber/cli/streaming.py +94 -0
- sqlsaber/config/__init__.py +7 -0
- sqlsaber/config/api_keys.py +102 -0
- sqlsaber/config/database.py +252 -0
- sqlsaber/config/settings.py +115 -0
- sqlsaber/database/__init__.py +9 -0
- sqlsaber/database/connection.py +187 -0
- sqlsaber/database/schema.py +678 -0
- sqlsaber/memory/__init__.py +1 -0
- sqlsaber/memory/manager.py +77 -0
- sqlsaber/memory/storage.py +176 -0
- sqlsaber/models/__init__.py +13 -0
- sqlsaber/models/events.py +28 -0
- sqlsaber/models/types.py +40 -0
- sqlsaber-0.1.0.dist-info/METADATA +168 -0
- sqlsaber-0.1.0.dist-info/RECORD +32 -0
- sqlsaber-0.1.0.dist-info/WHEEL +4 -0
- sqlsaber-0.1.0.dist-info/entry_points.txt +4 -0
- sqlsaber-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Interactive mode handling for the CLI."""
|
|
2
|
+
|
|
3
|
+
import questionary
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
from rich.panel import Panel
|
|
6
|
+
|
|
7
|
+
from sqlsaber.agents.base import BaseSQLAgent
|
|
8
|
+
from sqlsaber.cli.display import DisplayManager
|
|
9
|
+
from sqlsaber.cli.streaming import StreamingQueryHandler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InteractiveSession:
    """Manages interactive CLI sessions.

    Reads multi-line queries from the user, handles the built-in commands
    (``exit``/``quit``/``q``, ``clear`` and ``#``-prefixed memories) and
    streams every other query through the agent.
    """

    # Inputs that end the session (compared case-insensitively after stripping).
    EXIT_COMMANDS = ("exit", "quit", "q")

    def __init__(self, console: Console, agent: BaseSQLAgent):
        self.console = console
        self.agent = agent
        self.display = DisplayManager(console)
        self.streaming_handler = StreamingQueryHandler(console)

    def show_welcome_message(self):
        """Display welcome message for interactive mode."""
        self.console.print(
            Panel.fit(
                "[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
                "Type your queries in natural language.\n\n"
                "Press Esc-Enter or Meta-Enter to submit your query.\n\n"
                "Type 'exit' or 'quit' to leave.",
                border_style="green",
            )
        )

        self.console.print(
            "[dim]Commands: 'clear' to reset conversation, 'exit' or 'quit' to leave[/dim]"
        )
        self.console.print(
            "[dim]Memory: Start a message with '#' to add it as a memory for this database[/dim]\n"
        )

    def _add_memory(self, memory_content: str) -> None:
        """Store *memory_content* through the agent and report the outcome.

        Args:
            memory_content: Text after the leading ``#``, already stripped.
        """
        from rich.markup import escape

        if not memory_content:
            self.console.print("[yellow]Empty memory content after '#'[/yellow]\n")
            return

        memory_id = self.agent.add_memory(memory_content)
        if memory_id:
            # Escape user text so bracketed words aren't parsed as Rich markup.
            self.console.print(
                f"[green]✓ Memory added:[/green] {escape(memory_content)}"
            )
            self.console.print(f"[dim]Memory ID: {memory_id}[/dim]\n")
        else:
            self.console.print(
                "[yellow]Could not add memory (no database context)[/yellow]\n"
            )

    async def run(self):
        """Run the interactive session loop until the user exits.

        The loop ends on an exit command or on end-of-input (Ctrl-D). Errors
        raised while handling a single query are printed and the loop continues.
        """
        self.show_welcome_message()

        while True:
            try:
                user_query = await questionary.text(
                    ">",
                    qmark="",
                    multiline=True,
                    instruction="",
                ).ask_async()

                # questionary's ask_async swallows Ctrl-C and returns None;
                # treat it like the KeyboardInterrupt branch instead of crashing.
                if user_query is None:
                    self.console.print(
                        "\n[yellow]Use 'exit' or 'quit' to leave.[/yellow]"
                    )
                    continue

                # Multiline input can carry stray whitespace/newlines, so match
                # commands on the stripped text.
                query = user_query.strip()
                command = query.lower()

                if command in self.EXIT_COMMANDS:
                    break

                if command == "clear":
                    self.agent.clear_history()
                    self.console.print("[green]Conversation history cleared.[/green]\n")
                    continue

                if not query:
                    continue

                # A leading '#' turns the message into a stored memory.
                if query.startswith("#"):
                    self._add_memory(query[1:].strip())
                    continue

                await self.streaming_handler.execute_streaming_query(
                    user_query, self.agent
                )
                self.display.show_newline()  # Empty line for readability

            except EOFError:
                # Ctrl-D / closed stdin: leave rather than re-prompting forever.
                break
            except KeyboardInterrupt:
                self.console.print("\n[yellow]Use 'exit' or 'quit' to leave.[/yellow]")
            except Exception as e:
                self.console.print(f"[bold red]Error:[/bold red] {str(e)}")
|
sqlsaber/cli/memory.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Memory management CLI commands."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
|
|
9
|
+
from sqlsaber.config.database import DatabaseConfigManager
|
|
10
|
+
from sqlsaber.memory.manager import MemoryManager
|
|
11
|
+
|
|
12
|
+
# Module-level singletons shared by every `memory` subcommand below.
console = Console()
config_manager = DatabaseConfigManager()
memory_manager = MemoryManager()

# Typer sub-application that hosts the memory management commands.
memory_app = typer.Typer(
    add_completion=True,
    help="Manage database-specific memories",
    name="memory",
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_database_name(database: Optional[str] = None) -> str:
    """Resolve which database connection a memory command should target.

    An explicitly named *database* must exist in the saved configuration and
    is returned as given. Without one, the configured default connection's
    name is returned. In either failure case an error is printed and
    ``typer.Exit(1)`` is raised.
    """
    if not database:
        fallback = config_manager.get_default_database()
        if not fallback:
            console.print(
                "[bold red]Error:[/bold red] No database connections configured."
            )
            console.print("Use 'sqlsaber db add <name>' to add a database connection.")
            raise typer.Exit(1)
        return fallback.name

    if not config_manager.get_database(database):
        console.print(
            f"[bold red]Error:[/bold red] Database connection '{database}' not found."
        )
        raise typer.Exit(1)
    return database
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@memory_app.command("add")
def add_memory(
    content: str = typer.Argument(..., help="Memory content to add"),
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """Add a new memory for the specified database.

    Surrounding whitespace is stripped and empty content is rejected, matching
    the interactive ``#`` prompt. Exits with status 1 on empty content or if
    the memory manager fails to store the memory.
    """
    from rich.markup import escape

    database_name = _get_database_name(database)

    content = content.strip()
    if not content:
        console.print("[bold red]Error:[/bold red] Memory content cannot be empty.")
        raise typer.Exit(1)

    # Only the storage call is guarded, so a display problem can't be
    # misreported as a failure to add the memory.
    try:
        memory = memory_manager.add_memory(database_name, content)
    except Exception as e:
        console.print(f"[bold red]Error adding memory:[/bold red] {e}")
        raise typer.Exit(1)

    console.print(f"[green]✓ Memory added for database '{database_name}'[/green]")
    console.print(f"[dim]Memory ID:[/dim] {memory.id}")
    # Escape user text so brackets like "[id]" aren't interpreted as Rich markup.
    console.print(f"[dim]Content:[/dim] {escape(memory.content)}")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@memory_app.command("list")
def list_memories(
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """List all memories for the specified database.

    Content longer than 80 characters is truncated to 77 characters plus
    "..." so the table stays readable; ``memory show`` prints the full text.
    """
    from rich.markup import escape

    database_name = _get_database_name(database)

    memories = memory_manager.get_memories(database_name)

    if not memories:
        console.print(
            f"[yellow]No memories found for database '{database_name}'[/yellow]"
        )
        console.print("Use 'sqlsaber memory add \"<content>\"' to add memories")
        return

    table = Table(title=f"Memories for Database: {database_name}")
    table.add_column("ID", style="cyan", width=36)
    table.add_column("Content", style="white")
    table.add_column("Created", style="dim")

    for memory in memories:
        # Truncate content if it's too long for display
        display_content = memory.content
        if len(display_content) > 80:
            display_content = display_content[:77] + "..."

        # String cells are rendered as Rich markup; escape user text (after
        # truncating, so a cut can't split an escape sequence).
        table.add_row(memory.id, escape(display_content), memory.formatted_timestamp())

    console.print(table)
    console.print(f"\n[dim]Total memories: {len(memories)}[/dim]")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@memory_app.command("show")
def show_memory(
    memory_id: str = typer.Argument(..., help="Memory ID to show"),
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """Show the full content of a specific memory.

    Exits with status 1 when no memory with *memory_id* exists for the database.
    """
    database_name = _get_database_name(database)

    memory = memory_manager.get_memory_by_id(database_name, memory_id)

    if not memory:
        console.print(
            f"[bold red]Error:[/bold red] Memory with ID '{memory_id}' not found for database '{database_name}'"
        )
        raise typer.Exit(1)

    console.print(f"[bold]Memory ID:[/bold] {memory.id}")
    console.print(f"[bold]Database:[/bold] {memory.database}")
    console.print(f"[bold]Created:[/bold] {memory.formatted_timestamp()}")
    console.print("[bold]Content:[/bold]")
    # Memory content is free-form user text: print it verbatim instead of as
    # Rich markup, so bracketed text isn't swallowed or rejected as a tag.
    console.print(f"{memory.content}", markup=False)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@memory_app.command("remove")
def remove_memory(
    memory_id: str = typer.Argument(..., help="Memory ID to remove"),
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """Remove a specific memory by ID.

    Exits with status 1 if the memory does not exist or the manager reports
    that removal failed.
    """
    from rich.markup import escape

    database_name = _get_database_name(database)

    # First check if memory exists
    memory = memory_manager.get_memory_by_id(database_name, memory_id)
    if not memory:
        console.print(
            f"[bold red]Error:[/bold red] Memory with ID '{memory_id}' not found for database '{database_name}'"
        )
        raise typer.Exit(1)

    # Show memory content before removal (escaped: it is user text, not markup)
    console.print("[yellow]Removing memory:[/yellow]")
    console.print(f"[dim]Content:[/dim] {escape(memory.content)}")

    if memory_manager.remove_memory(database_name, memory_id):
        console.print(
            f"[green]✓ Memory removed from database '{database_name}'[/green]"
        )
    else:
        console.print(
            f"[bold red]Error:[/bold red] Failed to remove memory '{memory_id}'"
        )
        raise typer.Exit(1)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@memory_app.command("clear")
def clear_memories(
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
    force: bool = typer.Option(
        False,
        "--force",
        "-f",
        help="Skip confirmation prompt",
    ),
):
    """Delete every stored memory for a database.

    Does nothing when there are no memories. Unless ``--force`` is given, the
    user is asked to confirm first; declining leaves the memories untouched.
    """
    database_name = _get_database_name(database)

    total = len(memory_manager.get_memories(database_name))
    if not total:
        console.print(
            f"[yellow]No memories to clear for database '{database_name}'[/yellow]"
        )
        return

    if not force:
        console.print(
            f"[yellow]About to clear {total} memories for database '{database_name}'[/yellow]"
        )
        if not typer.confirm("Are you sure you want to proceed?"):
            console.print("Operation cancelled")
            return

    removed = memory_manager.clear_memories(database_name)
    console.print(
        f"[green]✓ Cleared {removed} memories for database '{database_name}'[/green]"
    )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@memory_app.command("summary")
def memory_summary(
    database: Optional[str] = typer.Option(
        None,
        "--database",
        "-d",
        help="Database connection name (uses default if not specified)",
    ),
):
    """Show memory summary for the specified database.

    Prints the total count and the last five entries of ``summary["memories"]``
    as returned by ``memory_manager.get_memories_summary``.
    """
    from rich.markup import escape

    database_name = _get_database_name(database)

    summary = memory_manager.get_memories_summary(database_name)

    console.print(f"[bold]Memory Summary for Database: {summary['database']}[/bold]")
    console.print(f"[dim]Total memories:[/dim] {summary['total_memories']}")

    if summary["total_memories"] > 0:
        console.print("\n[bold]Recent memories:[/bold]")
        for memory in summary["memories"][-5:]:  # Show last 5 memories
            # Escape stored text so it is shown literally, not parsed as markup.
            console.print(
                f"[dim]{memory['timestamp']}[/dim] - {escape(memory['content'])}"
            )
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def create_memory_app() -> typer.Typer:
    """Expose this module's ``memory`` Typer application to the main CLI."""
    app = memory_app
    return app
|
sqlsaber/cli/models.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""Model management CLI commands."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Dict, List
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
import questionary
|
|
8
|
+
import typer
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
from rich.table import Table
|
|
11
|
+
|
|
12
|
+
from sqlsaber.config.settings import Config
|
|
13
|
+
|
|
14
|
+
# Console shared by every `models` subcommand below.
console = Console()

# Typer sub-application that hosts the model selection commands.
models_app = typer.Typer(
    add_completion=True,
    help="Select and manage models",
    name="models",
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ModelManager:
    """Manages AI model configuration and fetching."""

    DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
    MODELS_API_URL = "https://models.dev/api.json"

    @staticmethod
    def _parse_anthropic_models(data: Dict) -> List[Dict]:
        """Convert a models.dev catalogue payload into display-ready entries.

        Only the ``"anthropic"`` provider is read. Fields that are missing or
        JSON ``null`` fall back to defaults, so one incomplete entry does not
        discard the whole listing.

        Returns:
            Dicts with ``id`` (``"anthropic:<model>"``), ``name`` (falls back
            to the model id), ``description`` (pricing summary or ``""``),
            ``context_length`` (``0`` if unknown) and ``knowledge``, sorted by
            ``name``.
        """
        anthropic_data = data.get("anthropic") or {}
        anthropic_models = []

        for model_id, model_info in (anthropic_data.get("models") or {}).items():
            model_info = model_info or {}

            # Extract cost information for display
            cost_info = model_info.get("cost") or {}
            cost_display = ""
            if cost_info:
                input_cost = cost_info.get("input", 0)
                output_cost = cost_info.get("output", 0)
                cost_display = f"${input_cost}/{output_cost} per 1M tokens"

            # Extract context length
            limit_info = model_info.get("limit") or {}

            anthropic_models.append(
                {
                    # Convert to our format (anthropic:model-name)
                    "id": f"anthropic:{model_id}",
                    # A null name would make the sort below fail on None vs str.
                    "name": model_info.get("name") or model_id,
                    "description": cost_display,
                    "context_length": limit_info.get("context") or 0,
                    "knowledge": model_info.get("knowledge") or "",
                }
            )

        # Sort by name for better display
        anthropic_models.sort(key=lambda model: model["name"])
        return anthropic_models

    async def fetch_available_models(self) -> List[Dict]:
        """Fetch available Anthropic models from the models.dev API.

        Any failure (network, HTTP status, malformed payload) is printed and an
        empty list is returned, so callers only need to check for emptiness.
        """
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(self.MODELS_API_URL)
                response.raise_for_status()
                data = response.json()
            return self._parse_anthropic_models(data)
        except Exception as e:
            console.print(f"[red]Error fetching models: {e}[/red]")
            return []

    def get_current_model(self) -> str:
        """Get the currently configured model."""
        config = Config()
        return config.model_name

    def set_model(self, model_id: str) -> bool:
        """Persist *model_id* as the current model; return False on failure."""
        try:
            config = Config()
            config.set_model(model_id)
            return True
        except Exception as e:
            console.print(f"[red]Error setting model: {e}[/red]")
            return False

    def reset_model(self) -> bool:
        """Reset to ``DEFAULT_MODEL``; return False on failure."""
        return self.set_model(self.DEFAULT_MODEL)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Module-level ModelManager instance used by the `models` subcommands below.
model_manager = ModelManager()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@models_app.command("list")
def list_models():
    """Fetch the Anthropic model catalogue and print it as a table.

    The currently configured model is marked with a check in the last column.
    """

    async def fetch_and_display():
        console.print("[blue]Fetching available models...[/blue]")
        available = await model_manager.fetch_available_models()

        if not available:
            console.print(
                "[yellow]No models available or failed to fetch models[/yellow]"
            )
            return

        table = Table(title="Available Anthropic Models")
        column_specs = [
            ("ID", {"style": "cyan"}),
            ("Name", {"style": "green"}),
            ("Description", {"style": "white"}),
            ("Context", {"style": "yellow", "justify": "right"}),
            ("Current", {"style": "bold red", "justify": "center"}),
        ]
        for header, options in column_specs:
            table.add_column(header, **options)

        active = model_manager.get_current_model()

        for entry in available:
            marker = "✓" if entry["id"] == active else ""

            context_length = entry["context_length"]
            context_cell = f"{context_length:,}" if context_length else "N/A"

            # Keep the description column narrow.
            blurb = entry["description"]
            if len(blurb) > 50:
                blurb = blurb[:50] + "..."

            table.add_row(entry["id"], entry["name"], blurb, context_cell, marker)

        console.print(table)
        console.print(f"\n[dim]Current model: {active}[/dim]")

    asyncio.run(fetch_and_display())
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@models_app.command("set")
def set_model():
    """Interactively pick the AI model to use and persist the choice.

    Exits with status 1 if the model list cannot be fetched or saving fails.
    """

    async def interactive_set():
        console.print("[blue]Fetching available models...[/blue]")
        models = await model_manager.fetch_available_models()

        if not models:
            console.print("[red]Failed to fetch models. Cannot set model.[/red]")
            raise typer.Exit(1)

        def _label(model):
            # Format: "ID - Name (Description)", description capped at 50 chars.
            text = f"{model['id']} - {model['name']}"
            description = model["description"]
            if description:
                suffix = "..." if len(description) > 50 else ""
                text += f" ({description[:50]}{suffix})"
            return text

        choices = [{"name": _label(model), "value": model["id"]} for model in models]

        # Point the selector at the configured model when it is in the list.
        current = model_manager.get_current_model()
        default_index = next(
            (i for i, choice in enumerate(choices) if choice["value"] == current),
            0,
        )

        selected_model = await questionary.select(
            "Select a model:",
            choices=choices,
            use_shortcuts=True,
            use_search_filter=True,
            use_jk_keys=False,  # j/k would clash with typing in the search filter
            default=choices[default_index] if choices else None,
        ).ask_async()

        if not selected_model:
            console.print("[yellow]Operation cancelled[/yellow]")
            return

        if not model_manager.set_model(selected_model):
            console.print("[red]✗ Failed to set model[/red]")
            raise typer.Exit(1)

        console.print(f"[green]✓ Model set to: {selected_model}[/green]")

    asyncio.run(interactive_set())
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@models_app.command("current")
def current_model():
    """Print the model name currently stored in the configuration."""
    configured = model_manager.get_current_model()
    message = f"Current model: [cyan]{configured}[/cyan]"
    console.print(message)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
@models_app.command("reset")
def reset_model():
    """Reset to the default model after asking the user to confirm.

    Exits with status 1 if saving the default model fails.
    """
    default_model = ModelManager.DEFAULT_MODEL

    async def interactive_reset():
        confirmed = await questionary.confirm(
            f"Reset to default model ({default_model})?"
        ).ask_async()

        if not confirmed:
            console.print("[yellow]Operation cancelled[/yellow]")
            return

        if not model_manager.reset_model():
            console.print("[red]✗ Failed to reset model[/red]")
            raise typer.Exit(1)

        console.print(f"[green]✓ Model reset to default: {default_model}[/green]")

    asyncio.run(interactive_reset())
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def create_models_app() -> typer.Typer:
    """Expose this module's ``models`` Typer application to the main CLI."""
    app = models_app
    return app
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Streaming query handling for the CLI."""
|
|
2
|
+
|
|
3
|
+
from rich.console import Console
|
|
4
|
+
|
|
5
|
+
from sqlsaber.agents.base import BaseSQLAgent
|
|
6
|
+
from sqlsaber.cli.display import DisplayManager
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class StreamingQueryHandler:
    """Handles streaming query execution and display."""

    def __init__(self, console: Console):
        self.console = console
        self.display = DisplayManager(console)

    async def execute_streaming_query(self, user_query: str, agent: BaseSQLAgent):
        """Run *user_query* through *agent* and render its event stream.

        A "Crunching data..." spinner is shown until output arrives. It is
        stopped before tool/text/error output is printed. A ``processing``
        event replaces it with the status returned by
        ``display.show_processing``. Whatever spinner is active is stopped
        when the stream ends or raises.
        """
        explanation_started = False
        status = self.console.status(
            "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
        )
        status.start()

        try:
            async for event in agent.query_stream(user_query):
                if event.type == "tool_use":
                    self._stop_status(status)

                    if event.data["status"] == "started":
                        # If explanation was streaming, add newline before tool use
                        if explanation_started:
                            self.display.show_newline()
                        self.display.show_tool_started(event.data["name"])
                    elif event.data["status"] == "executing":
                        self.display.show_tool_executing(
                            event.data["name"], event.data["input"]
                        )

                elif event.type == "text":
                    # Always stop status when text streaming starts
                    self._stop_status(status)
                    explanation_started = True
                    self.display.show_text_stream(event.data)

                elif event.type == "query_result":
                    if event.data["results"]:
                        self.display.show_query_results(event.data["results"])

                elif event.type == "tool_result":
                    # Only list_tables and introspect_schema results are rendered.
                    tool_name = event.data.get("tool_name")
                    if tool_name == "list_tables":
                        self.display.show_table_list(event.data["result"])
                    elif tool_name == "introspect_schema":
                        self.display.show_schema_info(event.data["result"])

                elif event.type == "processing":
                    # Show status when processing tool results
                    if explanation_started:
                        self.display.show_newline()  # Add newline after explanation text
                    self._stop_status(status)
                    status = self.display.show_processing(event.data)
                    status.start()

                elif event.type == "error":
                    # Stop whichever spinner is active, including one started
                    # by a "processing" event, so the error isn't drawn under it.
                    self._stop_status(status)
                    self.display.show_error(event.data)

        finally:
            # Make sure status is stopped
            self._stop_status(status)

            # Add a newline after streaming completes if explanation was shown
            if explanation_started:
                self.display.show_newline()  # Empty line for better readability

    def _stop_status(self, status):
        """Stop a status spinner, ignoring errors (best-effort cleanup)."""
        try:
            status.stop()
        except Exception:
            pass  # Status might already be stopped
|