sqlsaber 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -1,46 +1,21 @@
1
1
  """Interactive mode handling for the CLI."""
2
2
 
3
3
  import asyncio
4
- from typing import Optional
5
4
 
6
5
  import questionary
7
- from prompt_toolkit.completion import Completer, Completion
8
6
  from rich.console import Console
9
7
  from rich.panel import Panel
10
8
 
11
9
  from sqlsaber.agents.base import BaseSQLAgent
10
+ from sqlsaber.cli.completers import (
11
+ CompositeCompleter,
12
+ SlashCommandCompleter,
13
+ TableNameCompleter,
14
+ )
12
15
  from sqlsaber.cli.display import DisplayManager
13
16
  from sqlsaber.cli.streaming import StreamingQueryHandler
14
17
 
15
18
 
16
class SlashCommandCompleter(Completer):
    """Completer that suggests the interactive session's slash commands."""

    def get_completions(self, document, complete_event):
        """Yield a completion for every slash command matching the typed prefix.

        Nothing is yielded unless the whole input starts with "/".
        """
        typed = document.text

        # Input that does not begin with "/" is not a command
        if not typed.startswith("/"):
            return

        # Text after the slash is the partial command name to match against
        prefix = typed[1:]

        # Command name -> description shown next to the completion
        available = {
            "clear": "Clear conversation history",
            "exit": "Exit the interactive session",
            "quit": "Exit the interactive session",
        }

        for command, summary in available.items():
            if not command.startswith(prefix):
                continue
            # Negative start_position replaces what was typed after the slash
            yield Completion(
                command,
                start_position=-len(prefix),
                display_meta=summary,
            )
42
-
43
-
44
19
  class InteractiveSession:
45
20
  """Manages interactive CLI sessions."""
46
21
 
@@ -49,8 +24,9 @@ class InteractiveSession:
49
24
  self.agent = agent
50
25
  self.display = DisplayManager(console)
51
26
  self.streaming_handler = StreamingQueryHandler(console)
52
- self.current_task: Optional[asyncio.Task] = None
53
- self.cancellation_token: Optional[asyncio.Event] = None
27
+ self.current_task: asyncio.Task | None = None
28
+ self.cancellation_token: asyncio.Event | None = None
29
+ self.table_completer = TableNameCompleter()
54
30
 
55
31
  def show_welcome_message(self):
56
32
  """Display welcome message for interactive mode."""
@@ -63,7 +39,8 @@ class InteractiveSession:
63
39
  "[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
64
40
  "[bold]Your agentic SQL assistant.[/bold]\n\n\n"
65
41
  "[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
66
- "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]",
42
+ "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]\n\n"
43
+ "[dim]Type '@' to get table name completions.[/dim]",
67
44
  border_style="green",
68
45
  )
69
46
  )
@@ -75,6 +52,39 @@ class InteractiveSession:
75
52
  "[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
76
53
  )
77
54
 
55
+ async def _update_table_cache(self):
56
+ """Update the table completer cache with fresh data."""
57
+ try:
58
+ # Use the schema manager directly which has built-in caching
59
+ tables_data = await self.agent.schema_manager.list_tables()
60
+
61
+ # Parse the table information
62
+ table_list = []
63
+ if isinstance(tables_data, dict) and "tables" in tables_data:
64
+ for table in tables_data["tables"]:
65
+ if isinstance(table, dict):
66
+ name = table.get("name", "")
67
+ schema = table.get("schema", "")
68
+ full_name = table.get("full_name", "")
69
+
70
+ # Use full_name if available, otherwise construct it
71
+ if full_name:
72
+ table_name = full_name
73
+ elif schema and schema != "main":
74
+ table_name = f"{schema}.{name}"
75
+ else:
76
+ table_name = name
77
+
78
+ # No description needed - cleaner completions
79
+ table_list.append((table_name, ""))
80
+
81
+ # Update the completer cache
82
+ self.table_completer.update_cache(table_list)
83
+
84
+ except Exception:
85
+ # If there's an error, just use empty cache
86
+ self.table_completer.update_cache([])
87
+
78
88
  async def _execute_query_with_cancellation(self, user_query: str):
79
89
  """Execute a query with cancellation support."""
80
90
  # Create cancellation token
@@ -101,6 +111,9 @@ class InteractiveSession:
101
111
  """Run the interactive session loop."""
102
112
  self.show_welcome_message()
103
113
 
114
+ # Initialize table cache
115
+ await self._update_table_cache()
116
+
104
117
  while True:
105
118
  try:
106
119
  user_query = await questionary.text(
@@ -108,7 +121,9 @@ class InteractiveSession:
108
121
  qmark="",
109
122
  multiline=True,
110
123
  instruction="",
111
- completer=SlashCommandCompleter(),
124
+ completer=CompositeCompleter(
125
+ SlashCommandCompleter(), self.table_completer
126
+ ),
112
127
  ).ask_async()
113
128
 
114
129
  if not user_query:
sqlsaber/cli/memory.py CHANGED
@@ -1,7 +1,5 @@
1
1
  """Memory management CLI commands."""
2
2
 
3
- from typing import Optional
4
-
5
3
  import typer
6
4
  from rich.console import Console
7
5
  from rich.table import Table
@@ -22,7 +20,7 @@ memory_app = typer.Typer(
22
20
  )
23
21
 
24
22
 
25
- def _get_database_name(database: Optional[str] = None) -> str:
23
+ def _get_database_name(database: str | None = None) -> str:
26
24
  """Get the database name to use, either specified or default."""
27
25
  if database:
28
26
  db_config = config_manager.get_database(database)
@@ -46,7 +44,7 @@ def _get_database_name(database: Optional[str] = None) -> str:
46
44
  @memory_app.command("add")
47
45
  def add_memory(
48
46
  content: str = typer.Argument(..., help="Memory content to add"),
49
- database: Optional[str] = typer.Option(
47
+ database: str | None = typer.Option(
50
48
  None,
51
49
  "--database",
52
50
  "-d",
@@ -68,7 +66,7 @@ def add_memory(
68
66
 
69
67
  @memory_app.command("list")
70
68
  def list_memories(
71
- database: Optional[str] = typer.Option(
69
+ database: str | None = typer.Option(
72
70
  None,
73
71
  "--database",
74
72
  "-d",
@@ -107,7 +105,7 @@ def list_memories(
107
105
  @memory_app.command("show")
108
106
  def show_memory(
109
107
  memory_id: str = typer.Argument(..., help="Memory ID to show"),
110
- database: Optional[str] = typer.Option(
108
+ database: str | None = typer.Option(
111
109
  None,
112
110
  "--database",
113
111
  "-d",
@@ -135,7 +133,7 @@ def show_memory(
135
133
  @memory_app.command("remove")
136
134
  def remove_memory(
137
135
  memory_id: str = typer.Argument(..., help="Memory ID to remove"),
138
- database: Optional[str] = typer.Option(
136
+ database: str | None = typer.Option(
139
137
  None,
140
138
  "--database",
141
139
  "-d",
@@ -170,7 +168,7 @@ def remove_memory(
170
168
 
171
169
  @memory_app.command("clear")
172
170
  def clear_memories(
173
- database: Optional[str] = typer.Option(
171
+ database: str | None = typer.Option(
174
172
  None,
175
173
  "--database",
176
174
  "-d",
@@ -213,7 +211,7 @@ def clear_memories(
213
211
 
214
212
  @memory_app.command("summary")
215
213
  def memory_summary(
216
- database: Optional[str] = typer.Option(
214
+ database: str | None = typer.Option(
217
215
  None,
218
216
  "--database",
219
217
  "-d",
sqlsaber/cli/models.py CHANGED
@@ -1,7 +1,6 @@
1
1
  """Model management CLI commands."""
2
2
 
3
3
  import asyncio
4
- from typing import Dict, List
5
4
 
6
5
  import httpx
7
6
  import questionary
@@ -28,7 +27,7 @@ class ModelManager:
28
27
  DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
29
28
  MODELS_API_URL = "https://models.dev/api.json"
30
29
 
31
- async def fetch_available_models(self) -> List[Dict]:
30
+ async def fetch_available_models(self) -> list[dict]:
32
31
  """Fetch available models from models.dev API."""
33
32
  try:
34
33
  async with httpx.AsyncClient(timeout=10.0) as client:
sqlsaber/cli/streaming.py CHANGED
@@ -23,8 +23,6 @@ class StreamingQueryHandler:
23
23
  ):
24
24
  """Execute a query with streaming display."""
25
25
 
26
- has_content = False
27
- explanation_started = False
28
26
  status = self.console.status(
29
27
  "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
30
28
  )
@@ -38,15 +36,10 @@ class StreamingQueryHandler:
38
36
  break
39
37
 
40
38
  if event.type == "tool_use":
41
- # Stop any ongoing status, but don't mark has_content yet
42
39
  self._stop_status(status)
43
40
 
44
- if event.data["status"] == "started":
45
- # If explanation was streaming, add newline before tool use
46
- if explanation_started:
47
- self.display.show_newline()
48
- self.display.show_tool_started(event.data["name"])
49
- elif event.data["status"] == "executing":
41
+ if event.data["status"] == "executing":
42
+ self.display.show_newline()
50
43
  self.display.show_tool_executing(
51
44
  event.data["name"], event.data["input"]
52
45
  )
@@ -54,12 +47,6 @@ class StreamingQueryHandler:
54
47
  elif event.type == "text":
55
48
  # Always stop status when text streaming starts
56
49
  self._stop_status(status)
57
-
58
- if not explanation_started:
59
- explanation_started = True
60
- has_content = True
61
-
62
- # Print text as it streams
63
50
  self.display.show_text_stream(event.data)
64
51
 
65
52
  elif event.type == "query_result":
@@ -70,45 +57,40 @@ class StreamingQueryHandler:
70
57
  # Handle tool results - particularly list_tables and introspect_schema
71
58
  if event.data.get("tool_name") == "list_tables":
72
59
  self.display.show_table_list(event.data["result"])
73
- has_content = True
74
60
  elif event.data.get("tool_name") == "introspect_schema":
75
61
  self.display.show_schema_info(event.data["result"])
76
- has_content = True
77
62
 
78
63
  elif event.type == "plot_result":
79
64
  # Handle plot results
80
65
  self.display.show_plot(event.data)
81
- has_content = True
82
66
 
83
67
  elif event.type == "processing":
84
- # Show status when processing tool results
85
- if explanation_started:
86
- self.display.show_newline() # Add newline after explanation text
68
+ self.display.show_newline() # Add newline after explanation text
87
69
  self._stop_status(status)
88
70
  status = self.display.show_processing(event.data)
89
71
  status.start()
90
- has_content = True
91
72
 
92
73
  elif event.type == "error":
93
- if not has_content:
94
- self._stop_status(status)
95
- has_content = True
74
+ self._stop_status(status)
96
75
  self.display.show_error(event.data)
97
76
 
98
77
  except asyncio.CancelledError:
99
78
  # Handle cancellation gracefully
100
79
  self._stop_status(status)
101
- if explanation_started:
102
- self.display.show_newline()
80
+ self.display.show_newline()
103
81
  self.console.print("[yellow]Query interrupted[/yellow]")
104
82
  return
105
83
  finally:
106
84
  # Make sure status is stopped
107
85
  self._stop_status(status)
108
86
 
109
- # Add a newline after streaming completes if explanation was shown
110
- if explanation_started:
111
- self.display.show_newline() # Empty line for better readability
87
+ # Display the last assistant response as markdown
88
+ if hasattr(agent, "conversation_history") and agent.conversation_history:
89
+ last_message = agent.conversation_history[-1]
90
+ if last_message.get("role") == "assistant" and last_message.get(
91
+ "content"
92
+ ):
93
+ self.display.show_markdown_response(last_message["content"])
112
94
 
113
95
  def _stop_status(self, status):
114
96
  """Safely stop a status spinner."""
@@ -0,0 +1,6 @@
1
+ """Client implementations for various LLM APIs."""
2
+
3
+ from .base import BaseLLMClient
4
+ from .anthropic import AnthropicClient
5
+
6
+ __all__ = ["BaseLLMClient", "AnthropicClient"]
@@ -0,0 +1,285 @@
1
+ """Anthropic API client implementation."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from typing import Any, AsyncIterator
7
+
8
+ import httpx
9
+
10
+ from .base import BaseLLMClient
11
+ from .exceptions import LLMClientError, create_exception_from_response
12
+ from .models import CreateMessageRequest
13
+ from .streaming import AnthropicStreamAdapter, StreamingResponse
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class AnthropicClient(BaseLLMClient):
    """Client for Anthropic's Claude API.

    Authenticates with either an API key (``x-api-key`` header) or an OAuth
    bearer token, streams ``/v1/messages`` responses as server-sent events,
    and lazily creates an ``httpx.AsyncClient`` that :meth:`close` releases.
    """

    def __init__(
        self,
        api_key: str | None = None,
        oauth_token: str | None = None,
        base_url: str | None = None,
    ):
        """Initialize the Anthropic client.

        Args:
            api_key: Anthropic API key
            oauth_token: OAuth bearer token; when non-empty it is used
                instead of the API key
            base_url: Base URL for the API (defaults to Anthropic's API)

        Raises:
            ValueError: If neither api_key nor oauth_token is provided
        """
        if not api_key and not oauth_token:
            raise ValueError("Either api_key or oauth_token must be provided")

        super().__init__(api_key or "", base_url)

        self.oauth_token = oauth_token
        # Truthiness (not "is not None") so an empty token combined with a
        # valid API key does not select OAuth with an empty bearer value
        self.use_oauth = bool(oauth_token)
        self.base_url = base_url or "https://api.anthropic.com"
        self.client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client.

        A new client is built on first use and again after the previous one
        has been closed.
        """
        if self.client is None or self.client.is_closed:
            # Configure timeouts and connection limits for reliability
            timeout = httpx.Timeout(
                connect=10.0,  # Connection timeout
                read=60.0,  # Read timeout for streaming
                write=10.0,  # Write timeout
                pool=10.0,  # Pool timeout
            )
            limits = httpx.Limits(
                max_keepalive_connections=20, max_connections=100, keepalive_expiry=30.0
            )
            self.client = httpx.AsyncClient(
                timeout=timeout, limits=limits, follow_redirects=True
            )
        return self.client

    def _get_headers(self) -> dict[str, str]:
        """Get the standard headers for API requests.

        Returns:
            OAuth bearer headers when ``use_oauth`` is set, otherwise API key
            headers.
        """
        if self.use_oauth:
            # OAuth headers for Claude Pro authentication (matching Claude Code CLI)
            return {
                "Authorization": f"Bearer {self.oauth_token}",
                "Content-Type": "application/json",
                "anthropic-version": "2023-06-01",
                "anthropic-beta": "oauth-2025-04-20",
                "User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
                "Accept": "application/json",
                "X-Client-Name": "claude-code",
                "X-Client-Version": "1.0.0",
            }

        # API key headers for standard authentication
        return {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }

    async def create_message_with_tools(
        self,
        request: CreateMessageRequest,
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[Any]:
        """Create a message with tool support and stream the response.

        The request is forced into streaming mode (``request.stream`` is set
        to True) before it is serialized.

        Args:
            request: The message creation request
            cancellation_token: Optional event to signal cancellation

        Yields:
            Stream events from the adapter, followed by a final
            ``{"type": "response_ready", "data": StreamingResponse}`` event

        Raises:
            LLMClientError: On a non-200 response or any failure while
                processing the stream
        """
        request.stream = True

        client = self._get_client()
        # Tolerate a configured base URL that ends with "/"
        url = f"{self.base_url.rstrip('/')}/v1/messages"
        headers = self._get_headers()
        data = request.to_dict()

        try:
            async with client.stream(
                "POST", url, headers=headers, json=data
            ) as response:
                request_id = response.headers.get("request-id")

                if response.status_code != 200:
                    response_content = await response.aread()
                    try:
                        response_data = json.loads(response_content.decode())
                    except ValueError as e:
                        # ValueError covers both JSONDecodeError and
                        # UnicodeDecodeError: the body is not a JSON error
                        # document, so report the HTTP status instead of an
                        # opaque parse failure
                        raise LLMClientError(
                            f"HTTP {response.status_code} error with non-JSON response body"
                        ) from e
                    raise create_exception_from_response(
                        response.status_code, response_data, request_id
                    )

                # Use stream adapter to convert raw events and track state
                adapter = AnthropicStreamAdapter()
                raw_stream = self._process_sse_stream(response, cancellation_token)

                async for event in adapter.process_stream(
                    raw_stream, cancellation_token
                ):
                    yield event

                # Create final response object with proper state
                response_obj = StreamingResponse(
                    content=adapter.get_content_blocks(),
                    stop_reason=adapter.get_stop_reason(),
                )

                # Yield special event with response
                yield {"type": "response_ready", "data": response_obj}

        except asyncio.CancelledError:
            # Handle cancellation gracefully: end the stream without raising
            logger.debug("Stream cancelled")
            return
        except LLMClientError:
            # Already a client error; propagate unchanged
            raise
        except Exception as e:
            raise LLMClientError(f"Stream processing error: {str(e)}") from e

    def _handle_ping_event(self, event_data: str) -> dict[str, Any]:
        """Handle ping event data.

        Args:
            event_data: Raw event data string

        Returns:
            Parsed ping event; its data is an empty dict when the payload is
            not valid JSON
        """
        try:
            return {"type": "ping", "data": json.loads(event_data)}
        except json.JSONDecodeError:
            return {"type": "ping", "data": {}}

    def _handle_error_event(self, event_data: str) -> None:
        """Handle error event data.

        Args:
            event_data: Raw event data string

        Raises:
            LLMClientError: Always raises with error details
        """
        try:
            error_data = json.loads(event_data)
        except json.JSONDecodeError as e:
            raise LLMClientError("Stream error with invalid JSON") from e

        # Anthropic nests the details under an "error" object; fall back to
        # top-level fields for payloads that carry them directly
        details = error_data.get("error") if isinstance(error_data, dict) else None
        if not isinstance(details, dict):
            details = error_data if isinstance(error_data, dict) else {}

        raise LLMClientError(
            details.get("message", "Stream error"),
            details.get("type", "stream_error"),
        )

    def _parse_event_data(
        self, event_type: str | None, event_data: str
    ) -> dict[str, Any] | None:
        """Parse event data based on event type.

        Args:
            event_type: Type of the event
            event_data: Raw event data string

        Returns:
            Parsed event or None if parsing failed
        """
        try:
            parsed_data = json.loads(event_data)
        except json.JSONDecodeError as e:
            logger.warning(
                "Failed to parse stream data for event %s: %s", event_type, e
            )
            return None
        return {"type": event_type, "data": parsed_data}

    def _process_sse_line(
        self, line: str, event_type: str | None
    ) -> tuple[str | None, dict[str, Any] | None]:
        """Process a single SSE line.

        Args:
            line: Line to process
            event_type: Current event type

        Returns:
            Tuple of (new_event_type, event_to_yield)

        Raises:
            LLMClientError: If the line carries data for an "error" event
        """
        if line.startswith("event: "):
            return line[7:], None

        if line.startswith("data: "):
            event_data = line[6:]

            if event_type == "ping":
                return event_type, self._handle_ping_event(event_data)
            if event_type == "error":
                self._handle_error_event(event_data)
                return event_type, None  # Never reached due to exception
            return event_type, self._parse_event_data(event_type, event_data)

        # Any other line (comments, unknown fields) leaves the state untouched
        return event_type, None

    async def _process_sse_stream(
        self,
        response: httpx.Response,
        cancellation_token: asyncio.Event | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Process server-sent events from the response stream.

        Args:
            response: The HTTP response object
            cancellation_token: Optional event to signal cancellation

        Yields:
            Parsed stream events

        Raises:
            LLMClientError: If stream processing fails
        """
        # Buffer raw bytes and decode only complete lines. A multi-byte UTF-8
        # character may be split across network chunks, so decoding each
        # chunk on its own would fail and lose data; the newline byte never
        # occurs inside a multi-byte sequence, so splitting on it is safe.
        buffer = b""
        event_type = None

        try:
            async for chunk in response.aiter_bytes():
                if cancellation_token is not None and cancellation_token.is_set():
                    return

                buffer += chunk

                while b"\n" in buffer:
                    raw_line, buffer = buffer.split(b"\n", 1)

                    try:
                        line = raw_line.decode("utf-8").strip()
                    except UnicodeDecodeError as e:
                        logger.warning("Failed to decode stream line: %s", e)
                        continue

                    if not line:
                        continue

                    event_type, event_to_yield = self._process_sse_line(
                        line, event_type
                    )
                    if event_to_yield is not None:
                        yield event_to_yield

        except LLMClientError:
            # Raised for SSE "error" events; keep its message and type instead
            # of re-wrapping it as an unexpected error below
            raise
        except httpx.TimeoutException as e:
            raise LLMClientError(f"Stream timeout error: {str(e)}") from e
        except httpx.NetworkError as e:
            raise LLMClientError(f"Network error during streaming: {str(e)}") from e
        except httpx.HTTPError as e:
            raise LLMClientError(f"HTTP error during streaming: {str(e)}") from e
        except asyncio.TimeoutError as e:
            raise LLMClientError("Stream timeout") from e
        except Exception as e:
            raise LLMClientError(
                f"Unexpected error during streaming: {str(e)}"
            ) from e

    async def close(self):
        """Close the HTTP client."""
        if self.client and not self.client.is_closed:
            await self.client.aclose()
            self.client = None
@@ -0,0 +1,31 @@
1
+ """Abstract base class for LLM clients."""
2
+
3
+ from abc import ABC
4
+
5
+
6
+ class BaseLLMClient(ABC):
7
+ """Abstract base class for LLM API clients."""
8
+
9
+ def __init__(self, api_key: str, base_url: str | None = None):
10
+ """Initialize the client with API key and optional base URL.
11
+
12
+ Args:
13
+ api_key: API key for authentication
14
+ base_url: Base URL for the API (optional, uses default if not provided)
15
+ """
16
+ self.api_key = api_key
17
+ self.base_url = base_url
18
+
19
+ async def close(self):
20
+ """Close the client and clean up resources."""
21
+ # Default implementation does nothing
22
+ # Subclasses can override to clean up HTTP sessions, etc.
23
+ pass
24
+
25
+ async def __aenter__(self):
26
+ """Async context manager entry."""
27
+ return self
28
+
29
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
30
+ """Async context manager exit."""
31
+ await self.close()