sqlsaber 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +283 -176
- sqlsaber/agents/base.py +11 -11
- sqlsaber/agents/streaming.py +3 -3
- sqlsaber/cli/auth.py +142 -0
- sqlsaber/cli/commands.py +9 -4
- sqlsaber/cli/completers.py +3 -5
- sqlsaber/cli/database.py +9 -10
- sqlsaber/cli/display.py +5 -7
- sqlsaber/cli/interactive.py +2 -3
- sqlsaber/cli/memory.py +7 -9
- sqlsaber/cli/models.py +1 -2
- sqlsaber/cli/streaming.py +5 -31
- sqlsaber/clients/__init__.py +6 -0
- sqlsaber/clients/anthropic.py +285 -0
- sqlsaber/clients/base.py +31 -0
- sqlsaber/clients/exceptions.py +117 -0
- sqlsaber/clients/models.py +282 -0
- sqlsaber/clients/streaming.py +257 -0
- sqlsaber/config/api_keys.py +2 -3
- sqlsaber/config/auth.py +86 -0
- sqlsaber/config/database.py +20 -20
- sqlsaber/config/oauth_flow.py +274 -0
- sqlsaber/config/oauth_tokens.py +175 -0
- sqlsaber/config/settings.py +37 -22
- sqlsaber/database/connection.py +9 -9
- sqlsaber/database/schema.py +25 -25
- sqlsaber/mcp/mcp.py +3 -4
- sqlsaber/memory/manager.py +3 -5
- sqlsaber/memory/storage.py +7 -8
- sqlsaber/models/events.py +4 -4
- sqlsaber/models/types.py +10 -10
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/METADATA +1 -1
- sqlsaber-0.8.1.dist-info/RECORD +46 -0
- sqlsaber-0.7.0.dist-info/RECORD +0 -36
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/WHEEL +0 -0
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.7.0.dist-info → sqlsaber-0.8.1.dist-info}/licenses/LICENSE +0 -0
sqlsaber/agents/base.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import json
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
-
from typing import Any, AsyncIterator
|
|
6
|
+
from typing import Any, AsyncIterator
|
|
7
7
|
|
|
8
8
|
from uniplot import histogram, plot
|
|
9
9
|
|
|
@@ -24,7 +24,7 @@ class BaseSQLAgent(ABC):
|
|
|
24
24
|
def __init__(self, db_connection: BaseDatabaseConnection):
|
|
25
25
|
self.db = db_connection
|
|
26
26
|
self.schema_manager = SchemaManager(db_connection)
|
|
27
|
-
self.conversation_history:
|
|
27
|
+
self.conversation_history: list[dict[str, Any]] = []
|
|
28
28
|
|
|
29
29
|
@abstractmethod
|
|
30
30
|
async def query_stream(
|
|
@@ -59,7 +59,7 @@ class BaseSQLAgent(ABC):
|
|
|
59
59
|
else:
|
|
60
60
|
return "database" # Fallback
|
|
61
61
|
|
|
62
|
-
async def introspect_schema(self, table_pattern:
|
|
62
|
+
async def introspect_schema(self, table_pattern: str | None = None) -> str:
|
|
63
63
|
"""Introspect database schema to understand table structures."""
|
|
64
64
|
try:
|
|
65
65
|
# Pass table_pattern to get_schema_info for efficient filtering at DB level
|
|
@@ -96,7 +96,7 @@ class BaseSQLAgent(ABC):
|
|
|
96
96
|
except Exception as e:
|
|
97
97
|
return json.dumps({"error": f"Error listing tables: {str(e)}"})
|
|
98
98
|
|
|
99
|
-
async def execute_sql(self, query: str, limit:
|
|
99
|
+
async def execute_sql(self, query: str, limit: int | None = None) -> str:
|
|
100
100
|
"""Execute a SQL query against the database."""
|
|
101
101
|
try:
|
|
102
102
|
# Security check - only allow SELECT queries unless write is enabled
|
|
@@ -147,7 +147,7 @@ class BaseSQLAgent(ABC):
|
|
|
147
147
|
return json.dumps({"error": error_msg, "suggestions": suggestions})
|
|
148
148
|
|
|
149
149
|
async def process_tool_call(
|
|
150
|
-
self, tool_name: str, tool_input:
|
|
150
|
+
self, tool_name: str, tool_input: dict[str, Any]
|
|
151
151
|
) -> str:
|
|
152
152
|
"""Process a tool call and return the result."""
|
|
153
153
|
if tool_name == "list_tables":
|
|
@@ -170,7 +170,7 @@ class BaseSQLAgent(ABC):
|
|
|
170
170
|
else:
|
|
171
171
|
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
172
172
|
|
|
173
|
-
def _validate_write_operation(self, query: str) ->
|
|
173
|
+
def _validate_write_operation(self, query: str) -> str | None:
|
|
174
174
|
"""Validate if a write operation is allowed.
|
|
175
175
|
|
|
176
176
|
Returns:
|
|
@@ -206,12 +206,12 @@ class BaseSQLAgent(ABC):
|
|
|
206
206
|
|
|
207
207
|
async def plot_data(
|
|
208
208
|
self,
|
|
209
|
-
y_values:
|
|
210
|
-
x_values:
|
|
209
|
+
y_values: list[float],
|
|
210
|
+
x_values: list[float] | None = None,
|
|
211
211
|
plot_type: str = "line",
|
|
212
|
-
title:
|
|
213
|
-
x_label:
|
|
214
|
-
y_label:
|
|
212
|
+
title: str | None = None,
|
|
213
|
+
x_label: str | None = None,
|
|
214
|
+
y_label: str | None = None,
|
|
215
215
|
) -> str:
|
|
216
216
|
"""Create a terminal plot using uniplot.
|
|
217
217
|
|
sqlsaber/agents/streaming.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
"""Streaming utilities for agents."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class StreamingResponse:
|
|
7
7
|
"""Helper class to manage streaming response construction."""
|
|
8
8
|
|
|
9
|
-
def __init__(self, content:
|
|
9
|
+
def __init__(self, content: list[dict[str, Any]], stop_reason: str):
|
|
10
10
|
self.content = content
|
|
11
11
|
self.stop_reason = stop_reason
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
def build_tool_result_block(tool_use_id: str, content: str) ->
|
|
14
|
+
def build_tool_result_block(tool_use_id: str, content: str) -> dict[str, Any]:
|
|
15
15
|
"""Build a tool result block for the conversation."""
|
|
16
16
|
return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
|
sqlsaber/cli/auth.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Authentication CLI commands."""
|
|
2
|
+
|
|
3
|
+
import questionary
|
|
4
|
+
import typer
|
|
5
|
+
from rich.console import Console
|
|
6
|
+
|
|
7
|
+
from sqlsaber.config.auth import AuthConfigManager, AuthMethod
|
|
8
|
+
from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
|
|
9
|
+
|
|
10
|
+
# Global instances for CLI commands
|
|
11
|
+
console = Console()
|
|
12
|
+
config_manager = AuthConfigManager()
|
|
13
|
+
|
|
14
|
+
# Create the authentication management CLI app
|
|
15
|
+
auth_app = typer.Typer(
|
|
16
|
+
name="auth",
|
|
17
|
+
help="Manage authentication configuration",
|
|
18
|
+
add_completion=True,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@auth_app.command("setup")
|
|
23
|
+
def setup_auth():
|
|
24
|
+
"""Configure authentication method for SQLSaber."""
|
|
25
|
+
console.print("\n[bold]SQLSaber Authentication Setup[/bold]\n")
|
|
26
|
+
|
|
27
|
+
# Use questionary for selection
|
|
28
|
+
auth_choice = questionary.select(
|
|
29
|
+
"Choose your authentication method:",
|
|
30
|
+
choices=[
|
|
31
|
+
questionary.Choice(
|
|
32
|
+
title="Anthropic API Key",
|
|
33
|
+
value=AuthMethod.API_KEY,
|
|
34
|
+
description="You can create one by visiting https://console.anthropic.com",
|
|
35
|
+
),
|
|
36
|
+
questionary.Choice(
|
|
37
|
+
title="Claude Pro or Max Subscription",
|
|
38
|
+
value=AuthMethod.CLAUDE_PRO,
|
|
39
|
+
description="This does not require creating an API Key, but requires a subscription at https://claude.ai",
|
|
40
|
+
),
|
|
41
|
+
],
|
|
42
|
+
).ask()
|
|
43
|
+
|
|
44
|
+
if auth_choice is None:
|
|
45
|
+
console.print("[yellow]Setup cancelled.[/yellow]")
|
|
46
|
+
return
|
|
47
|
+
|
|
48
|
+
# Handle auth method setup
|
|
49
|
+
if auth_choice == AuthMethod.API_KEY:
|
|
50
|
+
console.print("\nTo configure your API key, you can either:")
|
|
51
|
+
console.print("• Set the ANTHROPIC_API_KEY environment variable")
|
|
52
|
+
console.print(
|
|
53
|
+
"• Let SQLsaber prompt you for the key when needed (stored securely)"
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
config_manager.set_auth_method(auth_choice)
|
|
57
|
+
console.print("\n[bold green]Authentication method saved![/bold green]")
|
|
58
|
+
|
|
59
|
+
elif auth_choice == AuthMethod.CLAUDE_PRO:
|
|
60
|
+
oauth_flow = AnthropicOAuthFlow()
|
|
61
|
+
try:
|
|
62
|
+
success = oauth_flow.authenticate()
|
|
63
|
+
if success:
|
|
64
|
+
config_manager.set_auth_method(auth_choice)
|
|
65
|
+
console.print(
|
|
66
|
+
"\n[bold green]Authentication setup complete![/bold green]"
|
|
67
|
+
)
|
|
68
|
+
else:
|
|
69
|
+
console.print(
|
|
70
|
+
"\n[yellow]OAuth authentication failed. Please try again.[/yellow]"
|
|
71
|
+
)
|
|
72
|
+
return
|
|
73
|
+
except Exception as e:
|
|
74
|
+
console.print(f"\n[red]Authentication setup failed: {str(e)}[/red]")
|
|
75
|
+
return
|
|
76
|
+
|
|
77
|
+
console.print(
|
|
78
|
+
"You can change this anytime by running [cyan]saber auth setup[/cyan] again."
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@auth_app.command("status")
|
|
83
|
+
def show_auth_status():
|
|
84
|
+
"""Show current authentication configuration."""
|
|
85
|
+
auth_method = config_manager.get_auth_method()
|
|
86
|
+
|
|
87
|
+
console.print("\n[bold blue]Authentication Status[/bold blue]")
|
|
88
|
+
|
|
89
|
+
if auth_method is None:
|
|
90
|
+
console.print("[yellow]No authentication method configured[/yellow]")
|
|
91
|
+
console.print("Run [cyan]saber auth setup[/cyan] to configure authentication.")
|
|
92
|
+
else:
|
|
93
|
+
if auth_method == AuthMethod.API_KEY:
|
|
94
|
+
console.print("[green]✓ API Key authentication configured[/green]")
|
|
95
|
+
console.print("Using Anthropic API key for authentication")
|
|
96
|
+
elif auth_method == AuthMethod.CLAUDE_PRO:
|
|
97
|
+
console.print("[green]✓ Claude Pro/Max subscription configured[/green]")
|
|
98
|
+
|
|
99
|
+
# Check OAuth token status
|
|
100
|
+
oauth_flow = AnthropicOAuthFlow()
|
|
101
|
+
if oauth_flow.has_valid_authentication():
|
|
102
|
+
console.print("OAuth token is valid and ready to use")
|
|
103
|
+
else:
|
|
104
|
+
console.print("[yellow]OAuth token missing or expired[/yellow]")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@auth_app.command("reset")
|
|
108
|
+
def reset_auth():
|
|
109
|
+
"""Reset authentication configuration."""
|
|
110
|
+
if not config_manager.has_auth_configured():
|
|
111
|
+
console.print("[yellow]No authentication configuration to reset.[/yellow]")
|
|
112
|
+
return
|
|
113
|
+
|
|
114
|
+
current_method = config_manager.get_auth_method()
|
|
115
|
+
method_name = (
|
|
116
|
+
"API Key" if current_method == AuthMethod.API_KEY else "Claude Pro/Max"
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
if questionary.confirm(
|
|
120
|
+
f"Are you sure you want to reset the current authentication method ({method_name})?",
|
|
121
|
+
default=False,
|
|
122
|
+
).ask():
|
|
123
|
+
# If Claude Pro, also remove OAuth tokens
|
|
124
|
+
if current_method == AuthMethod.CLAUDE_PRO:
|
|
125
|
+
oauth_flow = AnthropicOAuthFlow()
|
|
126
|
+
oauth_flow.remove_authentication()
|
|
127
|
+
|
|
128
|
+
# Clear the auth config by setting it to None
|
|
129
|
+
config = config_manager._load_config()
|
|
130
|
+
config["auth_method"] = None
|
|
131
|
+
config_manager._save_config(config)
|
|
132
|
+
console.print("[green]Authentication configuration reset.[/green]")
|
|
133
|
+
console.print(
|
|
134
|
+
"Run [cyan]saber auth setup[/cyan] to configure authentication again."
|
|
135
|
+
)
|
|
136
|
+
else:
|
|
137
|
+
console.print("Reset cancelled.")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def create_auth_app() -> typer.Typer:
|
|
141
|
+
"""Return the authentication management CLI app."""
|
|
142
|
+
return auth_app
|
sqlsaber/cli/commands.py
CHANGED
|
@@ -2,12 +2,12 @@
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
import typer
|
|
8
7
|
from rich.console import Console
|
|
9
8
|
|
|
10
9
|
from sqlsaber.agents.anthropic import AnthropicSQLAgent
|
|
10
|
+
from sqlsaber.cli.auth import create_auth_app
|
|
11
11
|
from sqlsaber.cli.database import create_db_app
|
|
12
12
|
from sqlsaber.cli.interactive import InteractiveSession
|
|
13
13
|
from sqlsaber.cli.memory import create_memory_app
|
|
@@ -29,7 +29,7 @@ config_manager = DatabaseConfigManager()
|
|
|
29
29
|
|
|
30
30
|
@app.callback()
|
|
31
31
|
def main_callback(
|
|
32
|
-
database:
|
|
32
|
+
database: str | None = typer.Option(
|
|
33
33
|
None,
|
|
34
34
|
"--database",
|
|
35
35
|
"-d",
|
|
@@ -49,11 +49,11 @@ def main_callback(
|
|
|
49
49
|
|
|
50
50
|
@app.command()
|
|
51
51
|
def query(
|
|
52
|
-
query_text:
|
|
52
|
+
query_text: str | None = typer.Argument(
|
|
53
53
|
None,
|
|
54
54
|
help="SQL query in natural language (if not provided, starts interactive mode)",
|
|
55
55
|
),
|
|
56
|
-
database:
|
|
56
|
+
database: str | None = typer.Option(
|
|
57
57
|
None,
|
|
58
58
|
"--database",
|
|
59
59
|
"-d",
|
|
@@ -128,6 +128,7 @@ def query(
|
|
|
128
128
|
|
|
129
129
|
finally:
|
|
130
130
|
# Clean up
|
|
131
|
+
await agent.close() # Close the agent's HTTP client
|
|
131
132
|
await db_conn.close()
|
|
132
133
|
console.print("\n[green]Goodbye![/green]")
|
|
133
134
|
|
|
@@ -135,6 +136,10 @@ def query(
|
|
|
135
136
|
asyncio.run(run_session())
|
|
136
137
|
|
|
137
138
|
|
|
139
|
+
# Add authentication management commands
|
|
140
|
+
auth_app = create_auth_app()
|
|
141
|
+
app.add_typer(auth_app, name="auth")
|
|
142
|
+
|
|
138
143
|
# Add database management commands after main callback is defined
|
|
139
144
|
db_app = create_db_app()
|
|
140
145
|
app.add_typer(db_app, name="db")
|
sqlsaber/cli/completers.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Command line completers for the CLI interface."""
|
|
2
2
|
|
|
3
|
-
from typing import List, Tuple
|
|
4
|
-
|
|
5
3
|
from prompt_toolkit.completion import Completer, Completion
|
|
6
4
|
|
|
7
5
|
|
|
@@ -37,13 +35,13 @@ class TableNameCompleter(Completer):
|
|
|
37
35
|
"""Custom completer for table names."""
|
|
38
36
|
|
|
39
37
|
def __init__(self):
|
|
40
|
-
self._table_cache:
|
|
38
|
+
self._table_cache: list[tuple[str, str]] = []
|
|
41
39
|
|
|
42
|
-
def update_cache(self, tables_data:
|
|
40
|
+
def update_cache(self, tables_data: list[tuple[str, str]]):
|
|
43
41
|
"""Update the cache with fresh table data."""
|
|
44
42
|
self._table_cache = tables_data
|
|
45
43
|
|
|
46
|
-
def _get_table_names(self) ->
|
|
44
|
+
def _get_table_names(self) -> list[tuple[str, str]]:
|
|
47
45
|
"""Get table names from cache."""
|
|
48
46
|
return self._table_cache
|
|
49
47
|
|
sqlsaber/cli/database.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import getpass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Optional
|
|
7
6
|
|
|
8
7
|
import questionary
|
|
9
8
|
import typer
|
|
@@ -34,24 +33,24 @@ def add_database(
|
|
|
34
33
|
"-t",
|
|
35
34
|
help="Database type (postgresql, mysql, sqlite)",
|
|
36
35
|
),
|
|
37
|
-
host:
|
|
38
|
-
port:
|
|
39
|
-
database:
|
|
36
|
+
host: str | None = typer.Option(None, "--host", "-h", help="Database host"),
|
|
37
|
+
port: int | None = typer.Option(None, "--port", "-p", help="Database port"),
|
|
38
|
+
database: str | None = typer.Option(
|
|
40
39
|
None, "--database", "--db", help="Database name"
|
|
41
40
|
),
|
|
42
|
-
username:
|
|
43
|
-
ssl_mode:
|
|
41
|
+
username: str | None = typer.Option(None, "--username", "-u", help="Username"),
|
|
42
|
+
ssl_mode: str | None = typer.Option(
|
|
44
43
|
None,
|
|
45
44
|
"--ssl-mode",
|
|
46
45
|
help="SSL mode (disable, allow, prefer, require, verify-ca, verify-full for PostgreSQL; DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY for MySQL)",
|
|
47
46
|
),
|
|
48
|
-
ssl_ca:
|
|
47
|
+
ssl_ca: str | None = typer.Option(
|
|
49
48
|
None, "--ssl-ca", help="SSL CA certificate file path"
|
|
50
49
|
),
|
|
51
|
-
ssl_cert:
|
|
50
|
+
ssl_cert: str | None = typer.Option(
|
|
52
51
|
None, "--ssl-cert", help="SSL client certificate file path"
|
|
53
52
|
),
|
|
54
|
-
ssl_key:
|
|
53
|
+
ssl_key: str | None = typer.Option(
|
|
55
54
|
None, "--ssl-key", help="SSL client private key file path"
|
|
56
55
|
),
|
|
57
56
|
interactive: bool = typer.Option(
|
|
@@ -310,7 +309,7 @@ def set_default_database(
|
|
|
310
309
|
|
|
311
310
|
@db_app.command("test")
|
|
312
311
|
def test_database(
|
|
313
|
-
name:
|
|
312
|
+
name: str | None = typer.Argument(
|
|
314
313
|
None,
|
|
315
314
|
help="Name of the database connection to test (uses default if not specified)",
|
|
316
315
|
),
|
sqlsaber/cli/display.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Display utilities for the CLI interface."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from typing import Optional
|
|
5
4
|
|
|
6
5
|
from rich.console import Console
|
|
7
6
|
from rich.markdown import Markdown
|
|
@@ -19,7 +18,7 @@ class DisplayManager:
|
|
|
19
18
|
self,
|
|
20
19
|
columns: list,
|
|
21
20
|
header_style: str = "bold blue",
|
|
22
|
-
title:
|
|
21
|
+
title: str | None = None,
|
|
23
22
|
) -> Table:
|
|
24
23
|
"""Create a Rich table with specified columns."""
|
|
25
24
|
table = Table(show_header=True, header_style=header_style, title=title)
|
|
@@ -32,12 +31,9 @@ class DisplayManager:
|
|
|
32
31
|
table.add_column(col)
|
|
33
32
|
return table
|
|
34
33
|
|
|
35
|
-
def show_tool_started(self, tool_name: str):
|
|
36
|
-
"""Display tool started message."""
|
|
37
|
-
self.console.print(f"\n[yellow]🔧 Using tool: {tool_name}[/yellow]")
|
|
38
|
-
|
|
39
34
|
def show_tool_executing(self, tool_name: str, tool_input: dict):
|
|
40
35
|
"""Display tool execution details."""
|
|
36
|
+
self.console.print(f"\n[yellow]🔧 Using tool: {tool_name}[/yellow]")
|
|
41
37
|
if tool_name == "list_tables":
|
|
42
38
|
self.console.print("[dim] → Discovering available tables[/dim]")
|
|
43
39
|
elif tool_name == "introspect_schema":
|
|
@@ -46,12 +42,14 @@ class DisplayManager:
|
|
|
46
42
|
elif tool_name == "execute_sql":
|
|
47
43
|
query = tool_input.get("query", "")
|
|
48
44
|
self.console.print("\n[bold green]Executing SQL:[/bold green]")
|
|
45
|
+
self.show_newline()
|
|
49
46
|
syntax = Syntax(query, "sql")
|
|
50
47
|
self.console.print(syntax)
|
|
51
48
|
|
|
52
49
|
def show_text_stream(self, text: str):
|
|
53
50
|
"""Display streaming text."""
|
|
54
|
-
|
|
51
|
+
if text is not None: # Extra safety check
|
|
52
|
+
self.console.print(text, end="", markup=False)
|
|
55
53
|
|
|
56
54
|
def show_query_results(self, results: list):
|
|
57
55
|
"""Display query results in a formatted table."""
|
sqlsaber/cli/interactive.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Interactive mode handling for the CLI."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
from typing import Optional
|
|
5
4
|
|
|
6
5
|
import questionary
|
|
7
6
|
from rich.console import Console
|
|
@@ -25,8 +24,8 @@ class InteractiveSession:
|
|
|
25
24
|
self.agent = agent
|
|
26
25
|
self.display = DisplayManager(console)
|
|
27
26
|
self.streaming_handler = StreamingQueryHandler(console)
|
|
28
|
-
self.current_task:
|
|
29
|
-
self.cancellation_token:
|
|
27
|
+
self.current_task: asyncio.Task | None = None
|
|
28
|
+
self.cancellation_token: asyncio.Event | None = None
|
|
30
29
|
self.table_completer = TableNameCompleter()
|
|
31
30
|
|
|
32
31
|
def show_welcome_message(self):
|
sqlsaber/cli/memory.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Memory management CLI commands."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
3
|
import typer
|
|
6
4
|
from rich.console import Console
|
|
7
5
|
from rich.table import Table
|
|
@@ -22,7 +20,7 @@ memory_app = typer.Typer(
|
|
|
22
20
|
)
|
|
23
21
|
|
|
24
22
|
|
|
25
|
-
def _get_database_name(database:
|
|
23
|
+
def _get_database_name(database: str | None = None) -> str:
|
|
26
24
|
"""Get the database name to use, either specified or default."""
|
|
27
25
|
if database:
|
|
28
26
|
db_config = config_manager.get_database(database)
|
|
@@ -46,7 +44,7 @@ def _get_database_name(database: Optional[str] = None) -> str:
|
|
|
46
44
|
@memory_app.command("add")
|
|
47
45
|
def add_memory(
|
|
48
46
|
content: str = typer.Argument(..., help="Memory content to add"),
|
|
49
|
-
database:
|
|
47
|
+
database: str | None = typer.Option(
|
|
50
48
|
None,
|
|
51
49
|
"--database",
|
|
52
50
|
"-d",
|
|
@@ -68,7 +66,7 @@ def add_memory(
|
|
|
68
66
|
|
|
69
67
|
@memory_app.command("list")
|
|
70
68
|
def list_memories(
|
|
71
|
-
database:
|
|
69
|
+
database: str | None = typer.Option(
|
|
72
70
|
None,
|
|
73
71
|
"--database",
|
|
74
72
|
"-d",
|
|
@@ -107,7 +105,7 @@ def list_memories(
|
|
|
107
105
|
@memory_app.command("show")
|
|
108
106
|
def show_memory(
|
|
109
107
|
memory_id: str = typer.Argument(..., help="Memory ID to show"),
|
|
110
|
-
database:
|
|
108
|
+
database: str | None = typer.Option(
|
|
111
109
|
None,
|
|
112
110
|
"--database",
|
|
113
111
|
"-d",
|
|
@@ -135,7 +133,7 @@ def show_memory(
|
|
|
135
133
|
@memory_app.command("remove")
|
|
136
134
|
def remove_memory(
|
|
137
135
|
memory_id: str = typer.Argument(..., help="Memory ID to remove"),
|
|
138
|
-
database:
|
|
136
|
+
database: str | None = typer.Option(
|
|
139
137
|
None,
|
|
140
138
|
"--database",
|
|
141
139
|
"-d",
|
|
@@ -170,7 +168,7 @@ def remove_memory(
|
|
|
170
168
|
|
|
171
169
|
@memory_app.command("clear")
|
|
172
170
|
def clear_memories(
|
|
173
|
-
database:
|
|
171
|
+
database: str | None = typer.Option(
|
|
174
172
|
None,
|
|
175
173
|
"--database",
|
|
176
174
|
"-d",
|
|
@@ -213,7 +211,7 @@ def clear_memories(
|
|
|
213
211
|
|
|
214
212
|
@memory_app.command("summary")
|
|
215
213
|
def memory_summary(
|
|
216
|
-
database:
|
|
214
|
+
database: str | None = typer.Option(
|
|
217
215
|
None,
|
|
218
216
|
"--database",
|
|
219
217
|
"-d",
|
sqlsaber/cli/models.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Model management CLI commands."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
from typing import Dict, List
|
|
5
4
|
|
|
6
5
|
import httpx
|
|
7
6
|
import questionary
|
|
@@ -28,7 +27,7 @@ class ModelManager:
|
|
|
28
27
|
DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
|
|
29
28
|
MODELS_API_URL = "https://models.dev/api.json"
|
|
30
29
|
|
|
31
|
-
async def fetch_available_models(self) ->
|
|
30
|
+
async def fetch_available_models(self) -> list[dict]:
|
|
32
31
|
"""Fetch available models from models.dev API."""
|
|
33
32
|
try:
|
|
34
33
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
sqlsaber/cli/streaming.py
CHANGED
|
@@ -23,8 +23,6 @@ class StreamingQueryHandler:
|
|
|
23
23
|
):
|
|
24
24
|
"""Execute a query with streaming display."""
|
|
25
25
|
|
|
26
|
-
has_content = False
|
|
27
|
-
explanation_started = False
|
|
28
26
|
status = self.console.status(
|
|
29
27
|
"[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
|
|
30
28
|
)
|
|
@@ -38,15 +36,10 @@ class StreamingQueryHandler:
|
|
|
38
36
|
break
|
|
39
37
|
|
|
40
38
|
if event.type == "tool_use":
|
|
41
|
-
# Stop any ongoing status, but don't mark has_content yet
|
|
42
39
|
self._stop_status(status)
|
|
43
40
|
|
|
44
|
-
if event.data["status"] == "
|
|
45
|
-
|
|
46
|
-
if explanation_started:
|
|
47
|
-
self.display.show_newline()
|
|
48
|
-
self.display.show_tool_started(event.data["name"])
|
|
49
|
-
elif event.data["status"] == "executing":
|
|
41
|
+
if event.data["status"] == "executing":
|
|
42
|
+
self.display.show_newline()
|
|
50
43
|
self.display.show_tool_executing(
|
|
51
44
|
event.data["name"], event.data["input"]
|
|
52
45
|
)
|
|
@@ -54,12 +47,6 @@ class StreamingQueryHandler:
|
|
|
54
47
|
elif event.type == "text":
|
|
55
48
|
# Always stop status when text streaming starts
|
|
56
49
|
self._stop_status(status)
|
|
57
|
-
|
|
58
|
-
if not explanation_started:
|
|
59
|
-
explanation_started = True
|
|
60
|
-
has_content = True
|
|
61
|
-
|
|
62
|
-
# Print text as it streams
|
|
63
50
|
self.display.show_text_stream(event.data)
|
|
64
51
|
|
|
65
52
|
elif event.type == "query_result":
|
|
@@ -70,46 +57,33 @@ class StreamingQueryHandler:
|
|
|
70
57
|
# Handle tool results - particularly list_tables and introspect_schema
|
|
71
58
|
if event.data.get("tool_name") == "list_tables":
|
|
72
59
|
self.display.show_table_list(event.data["result"])
|
|
73
|
-
has_content = True
|
|
74
60
|
elif event.data.get("tool_name") == "introspect_schema":
|
|
75
61
|
self.display.show_schema_info(event.data["result"])
|
|
76
|
-
has_content = True
|
|
77
62
|
|
|
78
63
|
elif event.type == "plot_result":
|
|
79
64
|
# Handle plot results
|
|
80
65
|
self.display.show_plot(event.data)
|
|
81
|
-
has_content = True
|
|
82
66
|
|
|
83
67
|
elif event.type == "processing":
|
|
84
|
-
#
|
|
85
|
-
if explanation_started:
|
|
86
|
-
self.display.show_newline() # Add newline after explanation text
|
|
68
|
+
self.display.show_newline() # Add newline after explanation text
|
|
87
69
|
self._stop_status(status)
|
|
88
70
|
status = self.display.show_processing(event.data)
|
|
89
71
|
status.start()
|
|
90
|
-
has_content = True
|
|
91
72
|
|
|
92
73
|
elif event.type == "error":
|
|
93
|
-
|
|
94
|
-
self._stop_status(status)
|
|
95
|
-
has_content = True
|
|
74
|
+
self._stop_status(status)
|
|
96
75
|
self.display.show_error(event.data)
|
|
97
76
|
|
|
98
77
|
except asyncio.CancelledError:
|
|
99
78
|
# Handle cancellation gracefully
|
|
100
79
|
self._stop_status(status)
|
|
101
|
-
|
|
102
|
-
self.display.show_newline()
|
|
80
|
+
self.display.show_newline()
|
|
103
81
|
self.console.print("[yellow]Query interrupted[/yellow]")
|
|
104
82
|
return
|
|
105
83
|
finally:
|
|
106
84
|
# Make sure status is stopped
|
|
107
85
|
self._stop_status(status)
|
|
108
86
|
|
|
109
|
-
# Add a newline after streaming completes if explanation was shown
|
|
110
|
-
if explanation_started:
|
|
111
|
-
self.display.show_newline() # Empty line for better readability
|
|
112
|
-
|
|
113
87
|
# Display the last assistant response as markdown
|
|
114
88
|
if hasattr(agent, "conversation_history") and agent.conversation_history:
|
|
115
89
|
last_message = agent.conversation_history[-1]
|