sqlsaber 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -3,10 +3,10 @@
3
3
  import asyncio
4
4
 
5
5
  import questionary
6
+ from pydantic_ai import Agent
6
7
  from rich.console import Console
7
8
  from rich.panel import Panel
8
9
 
9
- from sqlsaber.agents.base import BaseSQLAgent
10
10
  from sqlsaber.cli.completers import (
11
11
  CompositeCompleter,
12
12
  SlashCommandCompleter,
@@ -14,25 +14,44 @@ from sqlsaber.cli.completers import (
14
14
  )
15
15
  from sqlsaber.cli.display import DisplayManager
16
16
  from sqlsaber.cli.streaming import StreamingQueryHandler
17
+ from sqlsaber.database.schema import SchemaManager
17
18
 
18
19
 
19
20
  class InteractiveSession:
20
21
  """Manages interactive CLI sessions."""
21
22
 
22
- def __init__(self, console: Console, agent: BaseSQLAgent):
23
+ def __init__(self, console: Console, agent: Agent, db_conn, database_name: str):
23
24
  self.console = console
24
25
  self.agent = agent
26
+ self.db_conn = db_conn
27
+ self.database_name = database_name
25
28
  self.display = DisplayManager(console)
26
29
  self.streaming_handler = StreamingQueryHandler(console)
27
30
  self.current_task: asyncio.Task | None = None
28
31
  self.cancellation_token: asyncio.Event | None = None
29
32
  self.table_completer = TableNameCompleter()
33
+ self.message_history: list | None = []
30
34
 
31
35
  def show_welcome_message(self):
32
36
  """Display welcome message for interactive mode."""
33
37
  # Show database information
34
- db_name = getattr(self.agent, "database_name", None) or "Unknown"
35
- db_type = self.agent._get_database_type_name()
38
+ db_name = self.database_name or "Unknown"
39
+ from sqlsaber.database.connection import (
40
+ CSVConnection,
41
+ MySQLConnection,
42
+ PostgreSQLConnection,
43
+ SQLiteConnection,
44
+ )
45
+
46
+ db_type = (
47
+ "PostgreSQL"
48
+ if isinstance(self.db_conn, PostgreSQLConnection)
49
+ else "MySQL"
50
+ if isinstance(self.db_conn, MySQLConnection)
51
+ else "SQLite"
52
+ if isinstance(self.db_conn, (SQLiteConnection, CSVConnection))
53
+ else "database"
54
+ )
36
55
 
37
56
  self.console.print(
38
57
  Panel.fit(
@@ -44,26 +63,27 @@ class InteractiveSession:
44
63
  ███████ ██████ ███████ ███████ ██ ██ ██████ ███████ ██ ██
45
64
  ▀▀
46
65
  """
47
- "\n\n"
48
- "[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
49
- "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]\n\n"
50
- "[dim]Type '@' to get table name completions.[/dim]",
51
- border_style="green",
52
66
  )
53
67
  )
54
68
  self.console.print(
55
- f"[bold blue]Connected to:[/bold blue] {db_name} ({db_type})\n"
69
+ "\n",
70
+ "[dim] ≥ Use '/clear' to reset conversation",
71
+ "[dim] ≥ Use '/exit' or '/quit' to leave[/dim]",
72
+ "[dim] ≥ Use 'Ctrl+C' to interrupt and return to prompt\n\n",
73
+ "[dim] ≥ Start message with '#' to add something to agent's memory for this database",
74
+ "[dim] ≥ Type '@' to get table name completions",
75
+ "[dim] ≥ Press 'Esc-Enter' or 'Meta-Enter' to submit your question",
76
+ sep="\n",
56
77
  )
78
+
57
79
  self.console.print(
58
- "[dim]Press Esc-Enter or Meta-Enter to submit your query.[/dim]\n"
59
- "[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
80
+ f"[bold blue]\n\nConnected to:[/bold blue] {db_name} ({db_type})\n"
60
81
  )
61
82
 
62
83
  async def _update_table_cache(self):
63
84
  """Update the table completer cache with fresh data."""
64
85
  try:
65
- # Use the schema manager directly which has built-in caching
66
- tables_data = await self.agent.schema_manager.list_tables()
86
+ tables_data = await SchemaManager(self.db_conn).list_tables()
67
87
 
68
88
  # Parse the table information
69
89
  table_list = []
@@ -100,16 +120,20 @@ class InteractiveSession:
100
120
  # Create the query task
101
121
  query_task = asyncio.create_task(
102
122
  self.streaming_handler.execute_streaming_query(
103
- user_query, self.agent, self.cancellation_token
123
+ user_query, self.agent, self.cancellation_token, self.message_history
104
124
  )
105
125
  )
106
126
  self.current_task = query_task
107
127
 
108
128
  try:
109
- # Simply await the query task
110
- # Ctrl+C will be handled by the KeyboardInterrupt exception in run()
111
- await query_task
112
-
129
+ run_result = await query_task
130
+ # Persist message history from this run using pydantic-ai API
131
+ if run_result is not None:
132
+ try:
133
+ # Use all_messages() so the system prompt and all prior turns are preserved
134
+ self.message_history = run_result.all_messages()
135
+ except Exception:
136
+ pass
113
137
  finally:
114
138
  self.current_task = None
115
139
  self.cancellation_token = None
@@ -144,7 +168,8 @@ class InteractiveSession:
144
168
  break
145
169
 
146
170
  if user_query == "/clear":
147
- await self.agent.clear_history()
171
+ # Reset local history (pydantic-ai call will receive empty history on next run)
172
+ self.message_history = []
148
173
  self.console.print("[green]Conversation history cleared.[/green]\n")
149
174
  continue
150
175
 
@@ -153,18 +178,28 @@ class InteractiveSession:
153
178
  if memory_text.startswith("#"):
154
179
  memory_content = memory_text[1:].strip() # Remove # and trim
155
180
  if memory_content:
156
- # Add memory
157
- memory_id = self.agent.add_memory(memory_content)
158
- if memory_id:
159
- self.console.print(
160
- f"[green]✓ Memory added:[/green] {memory_content}"
161
- )
162
- self.console.print(
163
- f"[dim]Memory ID: {memory_id}[/dim]\n"
181
+ # Add memory via the agent's memory manager
182
+ try:
183
+ mm = getattr(
184
+ self.agent, "_sqlsaber_memory_manager", None
164
185
  )
165
- else:
186
+ if mm and self.database_name:
187
+ memory = mm.add_memory(
188
+ self.database_name, memory_content
189
+ )
190
+ self.console.print(
191
+ f"[green]✓ Memory added:[/green] {memory_content}"
192
+ )
193
+ self.console.print(
194
+ f"[dim]Memory ID: {memory.id}[/dim]\n"
195
+ )
196
+ else:
197
+ self.console.print(
198
+ "[yellow]Could not add memory (no database context)[/yellow]\n"
199
+ )
200
+ except Exception:
166
201
  self.console.print(
167
- "[yellow]Could not add memory (no database context)[/yellow]\n"
202
+ "[yellow]Could not add memory[/yellow]\n"
168
203
  )
169
204
  else:
170
205
  self.console.print(
sqlsaber/cli/models.py CHANGED
@@ -3,12 +3,13 @@
3
3
  import asyncio
4
4
  import sys
5
5
 
6
+ import cyclopts
6
7
  import httpx
7
8
  import questionary
8
- import cyclopts
9
9
  from rich.console import Console
10
10
  from rich.table import Table
11
11
 
12
+ from sqlsaber.config import providers
12
13
  from sqlsaber.config.settings import Config
13
14
 
14
15
  # Global instances for CLI commands
@@ -26,49 +27,75 @@ class ModelManager:
26
27
 
27
28
  DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
28
29
  MODELS_API_URL = "https://models.dev/api.json"
30
+ # Providers come from central registry
31
+ SUPPORTED_PROVIDERS = providers.all_keys()
32
+
33
+ async def fetch_available_models(
34
+ self, providers: list[str] | None = None
35
+ ) -> list[dict]:
36
+ """Fetch available models across providers from models.dev API.
29
37
 
30
- async def fetch_available_models(self) -> list[dict]:
31
- """Fetch available models from models.dev API."""
38
+ Returns list of dicts with keys: id (provider:model_id), provider, name, description, context_length, knowledge.
39
+ """
32
40
  try:
33
41
  async with httpx.AsyncClient(timeout=10.0) as client:
34
42
  response = await client.get(self.MODELS_API_URL)
35
43
  response.raise_for_status()
36
44
  data = response.json()
37
45
 
38
- # Filter for Anthropic models only
39
- anthropic_models = []
40
- anthropic_data = data.get("anthropic", {})
41
-
42
- if "models" in anthropic_data:
43
- for model_id, model_info in anthropic_data["models"].items():
44
- # Convert to our format (anthropic:model-name)
45
- formatted_id = f"anthropic:{model_id}"
46
-
47
- # Extract cost information for display
48
- cost_info = model_info.get("cost", {})
46
+ providers = providers or self.SUPPORTED_PROVIDERS
47
+ results: list[dict] = []
48
+
49
+ for provider in providers:
50
+ prov_data = data.get(provider, {})
51
+ models_obj = (
52
+ prov_data.get("models") or prov_data.get("Models") or {}
53
+ )
54
+ if not isinstance(models_obj, dict):
55
+ continue
56
+ for model_id, model_info in models_obj.items():
57
+ formatted_id = f"{provider}:{model_id}"
58
+ # cost
59
+ cost_info = (
60
+ model_info.get("cost", {})
61
+ if isinstance(model_info, dict)
62
+ else {}
63
+ )
49
64
  cost_display = ""
50
- if cost_info:
65
+ if isinstance(cost_info, dict) and cost_info:
51
66
  input_cost = cost_info.get("input", 0)
52
67
  output_cost = cost_info.get("output", 0)
53
68
  cost_display = f"${input_cost}/{output_cost} per 1M tokens"
69
+ # context
70
+ limit_info = (
71
+ model_info.get("limit", {})
72
+ if isinstance(model_info, dict)
73
+ else {}
74
+ )
75
+ context_length = (
76
+ limit_info.get("context", 0)
77
+ if isinstance(limit_info, dict)
78
+ else 0
79
+ )
54
80
 
55
- # Extract context length
56
- limit_info = model_info.get("limit", {})
57
- context_length = limit_info.get("context", 0)
58
-
59
- anthropic_models.append(
81
+ results.append(
60
82
  {
61
83
  "id": formatted_id,
62
- "name": model_info.get("name", model_id),
84
+ "provider": provider,
85
+ "name": model_info.get("name", model_id)
86
+ if isinstance(model_info, dict)
87
+ else model_id,
63
88
  "description": cost_display,
64
89
  "context_length": context_length,
65
- "knowledge": model_info.get("knowledge", ""),
90
+ "knowledge": model_info.get("knowledge", "")
91
+ if isinstance(model_info, dict)
92
+ else "",
66
93
  }
67
94
  )
68
95
 
69
- # Sort by name for better display
70
- anthropic_models.sort(key=lambda x: x["name"])
71
- return anthropic_models
96
+ # Sort by provider then by name
97
+ results.sort(key=lambda x: (x["provider"], x["name"]))
98
+ return results
72
99
  except Exception as e:
73
100
  console.print(f"[red]Error fetching models: {e}[/red]")
74
101
  return []
@@ -110,7 +137,8 @@ def list():
110
137
  )
111
138
  return
112
139
 
113
- table = Table(title="Available Anthropic Models")
140
+ table = Table(title="Available Models")
141
+ table.add_column("Provider", style="magenta")
114
142
  table.add_column("ID", style="cyan")
115
143
  table.add_column("Name", style="green")
116
144
  table.add_column("Description", style="white")
@@ -133,6 +161,7 @@ def list():
133
161
  )
134
162
 
135
163
  table.add_row(
164
+ model.get("provider", "-"),
136
165
  model["id"],
137
166
  model["name"],
138
167
  description,
@@ -161,8 +190,9 @@ def set():
161
190
  # Create choices for questionary
162
191
  choices = []
163
192
  for model in models:
164
- # Format: "ID - Name (Description)"
165
- choice_text = f"{model['id']} - {model['name']}"
193
+ # Format: "[provider] ID - Name (Description)"
194
+ prov = model.get("provider", "?")
195
+ choice_text = f"[{prov}] {model['id']} - {model['name']}"
166
196
  if model["description"]:
167
197
  choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
168
198
 
@@ -179,7 +209,6 @@ def set():
179
209
  selected_model = await questionary.select(
180
210
  "Select a model:",
181
211
  choices=choices,
182
- use_shortcuts=True,
183
212
  use_search_filter=True,
184
213
  use_jk_keys=False, # Disable j/k keys when using search filter
185
214
  default=choices[default_index] if choices else None,
sqlsaber/cli/streaming.py CHANGED
@@ -1,100 +1,137 @@
1
- """Streaming query handling for the CLI."""
1
+ """Streaming query handling for the CLI (pydantic-ai based)."""
2
2
 
3
3
  import asyncio
4
-
4
+ import json
5
+ from typing import AsyncIterable
6
+
7
+ from pydantic_ai import Agent, RunContext
8
+ from pydantic_ai.messages import (
9
+ AgentStreamEvent,
10
+ FunctionToolCallEvent,
11
+ FunctionToolResultEvent,
12
+ PartDeltaEvent,
13
+ PartStartEvent,
14
+ TextPart,
15
+ TextPartDelta,
16
+ ThinkingPart,
17
+ ThinkingPartDelta,
18
+ )
5
19
  from rich.console import Console
6
20
 
7
- from sqlsaber.agents.base import BaseSQLAgent
8
21
  from sqlsaber.cli.display import DisplayManager
9
22
 
10
23
 
11
24
  class StreamingQueryHandler:
12
- """Handles streaming query execution and display."""
25
+ """Handles streaming query execution and display using pydantic-ai events."""
13
26
 
14
27
  def __init__(self, console: Console):
15
28
  self.console = console
16
29
  self.display = DisplayManager(console)
17
30
 
31
+ self.status = self.console.status(
32
+ "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
33
+ )
34
+
35
+ async def _event_stream_handler(
36
+ self, ctx: RunContext, event_stream: AsyncIterable[AgentStreamEvent]
37
+ ) -> None:
38
+ async for event in event_stream:
39
+ if isinstance(event, PartStartEvent):
40
+ if isinstance(event.part, (TextPart, ThinkingPart)):
41
+ self.status.stop()
42
+ self.display.show_text_stream(event.part.content)
43
+
44
+ elif isinstance(event, PartDeltaEvent):
45
+ if isinstance(event.delta, (TextPartDelta, ThinkingPartDelta)):
46
+ delta = event.delta.content_delta or ""
47
+ if delta:
48
+ self.status.stop()
49
+ self.display.show_text_stream(delta)
50
+
51
+ elif isinstance(event, FunctionToolCallEvent):
52
+ # Show tool execution start
53
+ self.status.stop()
54
+ args = event.part.args_as_dict()
55
+ self.display.show_newline()
56
+ self.display.show_tool_executing(event.part.tool_name, args)
57
+
58
+ elif isinstance(event, FunctionToolResultEvent):
59
+ self.status.stop()
60
+ # Route tool result to appropriate display
61
+ tool_name = event.result.tool_name
62
+ content = event.result.content
63
+ if tool_name == "list_tables":
64
+ self.display.show_table_list(content)
65
+ elif tool_name == "introspect_schema":
66
+ self.display.show_schema_info(content)
67
+ elif tool_name == "execute_sql":
68
+ try:
69
+ data = json.loads(content)
70
+ if data.get("success") and data.get("results"):
71
+ self.display.show_query_results(data["results"]) # type: ignore[arg-type]
72
+ except json.JSONDecodeError:
73
+ # If not JSON, ignore here
74
+ pass
75
+ elif tool_name == "plot_data":
76
+ self.display.show_plot(
77
+ {"tool_name": tool_name, "result": content, "input": {}}
78
+ )
79
+
18
80
  async def execute_streaming_query(
19
81
  self,
20
82
  user_query: str,
21
- agent: BaseSQLAgent,
83
+ agent: Agent,
22
84
  cancellation_token: asyncio.Event | None = None,
85
+ message_history: list | None = None,
23
86
  ):
24
- """Execute a query with streaming display."""
25
-
26
- status = self.console.status(
27
- "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
28
- )
29
- status.start()
30
-
87
+ self.status.start()
31
88
  try:
32
- async for event in agent.query_stream(
33
- user_query, cancellation_token=cancellation_token
34
- ):
35
- if cancellation_token is not None and cancellation_token.is_set():
36
- break
37
-
38
- if event.type == "tool_use":
39
- self._stop_status(status)
40
-
41
- if event.data["status"] == "executing":
42
- self.display.show_newline()
43
- self.display.show_tool_executing(
44
- event.data["name"], event.data["input"]
45
- )
46
-
47
- elif event.type == "text":
48
- # Always stop status when text streaming starts
49
- self._stop_status(status)
50
- self.display.show_text_stream(event.data)
51
-
52
- elif event.type == "query_result":
53
- if event.data["results"]:
54
- self.display.show_query_results(event.data["results"])
55
-
56
- elif event.type == "tool_result":
57
- # Handle tool results - particularly list_tables and introspect_schema
58
- if event.data.get("tool_name") == "list_tables":
59
- self.display.show_table_list(event.data["result"])
60
- elif event.data.get("tool_name") == "introspect_schema":
61
- self.display.show_schema_info(event.data["result"])
62
-
63
- elif event.type == "plot_result":
64
- # Handle plot results
65
- self.display.show_plot(event.data)
66
-
67
- elif event.type == "processing":
68
- self.display.show_newline() # Add newline after explanation text
69
- self._stop_status(status)
70
- status = self.display.show_processing(event.data)
71
- status.start()
72
-
73
- elif event.type == "error":
74
- self._stop_status(status)
75
- self.display.show_error(event.data)
76
-
89
+ # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
90
+ prepared_prompt: str | list[str] = user_query
91
+ is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
92
+ no_history = not message_history
93
+ if is_oauth and no_history:
94
+ ib = getattr(agent, "_sqlsaber_instruction_builder", None)
95
+ mm = getattr(agent, "_sqlsaber_memory_manager", None)
96
+ db_type = getattr(agent, "_sqlsaber_db_type", "database")
97
+ db_name = getattr(agent, "_sqlsaber_database_name", None)
98
+ instructions = (
99
+ ib.build_instructions(db_type=db_type) if ib is not None else ""
100
+ )
101
+ mem = (
102
+ mm.format_memories_for_prompt(db_name)
103
+ if (mm is not None and db_name)
104
+ else ""
105
+ )
106
+ parts = [p for p in (instructions, mem) if p and str(p).strip()]
107
+ if parts:
108
+ injected = "\n\n".join(parts)
109
+ prepared_prompt = [injected, user_query]
110
+
111
+ # Run the agent with our event stream handler
112
+ run = await agent.run(
113
+ prepared_prompt,
114
+ message_history=message_history,
115
+ event_stream_handler=self._event_stream_handler,
116
+ )
117
+ # After the run completes, show the assistant's final text as markdown if available
118
+ try:
119
+ output = run.output
120
+ if isinstance(output, str) and output.strip():
121
+ self.display.show_newline()
122
+ self.display.show_markdown_response(
123
+ [{"type": "text", "text": output}]
124
+ )
125
+ except Exception as e:
126
+ self.display.show_error(str(e))
127
+ self.display.show_newline()
128
+ return run
77
129
  except asyncio.CancelledError:
78
- # Handle cancellation gracefully
79
- self._stop_status(status)
80
130
  self.display.show_newline()
81
131
  self.console.print("[yellow]Query interrupted[/yellow]")
82
- return
132
+ return None
83
133
  finally:
84
- # Make sure status is stopped
85
- self._stop_status(status)
86
-
87
- # Display the last assistant response as markdown
88
- if hasattr(agent, "conversation_history") and agent.conversation_history:
89
- last_message = agent.conversation_history[-1]
90
- if last_message.get("role") == "assistant" and last_message.get(
91
- "content"
92
- ):
93
- self.display.show_markdown_response(last_message["content"])
94
-
95
- def _stop_status(self, status):
96
- """Safely stop a status spinner."""
97
- try:
98
- status.stop()
99
- except Exception:
100
- pass # Status might already be stopped
134
+ try:
135
+ self.status.stop()
136
+ except Exception:
137
+ pass
@@ -6,6 +6,8 @@ import os
6
6
  import keyring
7
7
  from rich.console import Console
8
8
 
9
+ from sqlsaber.config import providers
10
+
9
11
  console = Console()
10
12
 
11
13
 
@@ -30,9 +32,7 @@ class APIKeyManager:
30
32
  try:
31
33
  api_key = keyring.get_password(service_name, provider)
32
34
  if api_key:
33
- console.print(
34
- f"Using stored {provider} API key from keyring", style="dim"
35
- )
35
+ console.print(f"Using stored {provider} API key", style="dim")
36
36
  return api_key
37
37
  except Exception as e:
38
38
  # Keyring access failed, continue to prompt
@@ -43,12 +43,9 @@ class APIKeyManager:
43
43
 
44
44
  def _get_env_var_name(self, provider: str) -> str:
45
45
  """Get the expected environment variable name for a provider."""
46
- if provider == "openai":
47
- return "OPENAI_API_KEY"
48
- elif provider == "anthropic":
49
- return "ANTHROPIC_API_KEY"
50
- else:
51
- return "AI_API_KEY"
46
+ # Normalize aliases to canonical provider keys
47
+ key = providers.canonical(provider) or provider
48
+ return providers.env_var_name(key)
52
49
 
53
50
  def _get_service_name(self, provider: str) -> str:
54
51
  """Get the keyring service name for a provider."""
@@ -60,7 +57,7 @@ class APIKeyManager:
60
57
  """Prompt user for API key and store it in keyring."""
61
58
  try:
62
59
  console.print(
63
- f"\n{provider.title()} API key not found in environment or keyring."
60
+ f"\n{provider.title()} API key not found in environment or your OS's credentials store."
64
61
  )
65
62
  console.print("You can either:")
66
63
  console.print(f" 1. Set the {env_var_name} environment variable")
@@ -85,7 +82,8 @@ class APIKeyManager:
85
82
  console.print("API key stored securely for future use", style="green")
86
83
  except Exception as e:
87
84
  console.print(
88
- f"Warning: Could not store API key in keyring: {e}", style="yellow"
85
+ f"Warning: Could not store API key in your operating system's credentials store: {e}",
86
+ style="yellow",
89
87
  )
90
88
  console.print(
91
89
  "You may need to enter it again next time", style="yellow"