sqlsaber 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic.
- sqlsaber/agents/anthropic.py +283 -176
- sqlsaber/agents/base.py +11 -11
- sqlsaber/agents/streaming.py +3 -3
- sqlsaber/cli/auth.py +142 -0
- sqlsaber/cli/commands.py +9 -4
- sqlsaber/cli/completers.py +170 -0
- sqlsaber/cli/database.py +9 -10
- sqlsaber/cli/display.py +27 -7
- sqlsaber/cli/interactive.py +49 -34
- sqlsaber/cli/memory.py +7 -9
- sqlsaber/cli/models.py +1 -2
- sqlsaber/cli/streaming.py +12 -30
- sqlsaber/clients/__init__.py +6 -0
- sqlsaber/clients/anthropic.py +285 -0
- sqlsaber/clients/base.py +31 -0
- sqlsaber/clients/exceptions.py +117 -0
- sqlsaber/clients/models.py +282 -0
- sqlsaber/clients/streaming.py +257 -0
- sqlsaber/config/api_keys.py +2 -3
- sqlsaber/config/auth.py +86 -0
- sqlsaber/config/database.py +20 -20
- sqlsaber/config/oauth_flow.py +274 -0
- sqlsaber/config/oauth_tokens.py +175 -0
- sqlsaber/config/settings.py +34 -23
- sqlsaber/database/connection.py +9 -9
- sqlsaber/database/schema.py +41 -24
- sqlsaber/mcp/mcp.py +3 -4
- sqlsaber/memory/manager.py +3 -5
- sqlsaber/memory/storage.py +7 -8
- sqlsaber/models/events.py +4 -4
- sqlsaber/models/types.py +10 -10
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/METADATA +9 -8
- sqlsaber-0.8.0.dist-info/RECORD +46 -0
- sqlsaber-0.6.0.dist-info/RECORD +0 -35
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/interactive.py
CHANGED
@@ -1,46 +1,21 @@
 """Interactive mode handling for the CLI."""
 
 import asyncio
-from typing import Optional
 
 import questionary
-from prompt_toolkit.completion import Completer, Completion
 from rich.console import Console
 from rich.panel import Panel
 
 from sqlsaber.agents.base import BaseSQLAgent
+from sqlsaber.cli.completers import (
+    CompositeCompleter,
+    SlashCommandCompleter,
+    TableNameCompleter,
+)
 from sqlsaber.cli.display import DisplayManager
 from sqlsaber.cli.streaming import StreamingQueryHandler
 
 
-class SlashCommandCompleter(Completer):
-    """Custom completer for slash commands."""
-
-    def get_completions(self, document, complete_event):
-        """Get completions for slash commands."""
-        # Only provide completions if the line starts with "/"
-        text = document.text
-        if text.startswith("/"):
-            # Get the partial command after the slash
-            partial_cmd = text[1:]
-
-            # Define available commands with descriptions
-            commands = [
-                ("clear", "Clear conversation history"),
-                ("exit", "Exit the interactive session"),
-                ("quit", "Exit the interactive session"),
-            ]
-
-            # Yield completions that match the partial command
-            for cmd, description in commands:
-                if cmd.startswith(partial_cmd):
-                    yield Completion(
-                        cmd,
-                        start_position=-len(partial_cmd),
-                        display_meta=description,
-                    )
-
-
 class InteractiveSession:
     """Manages interactive CLI sessions."""
 
@@ -49,8 +24,9 @@ class InteractiveSession:
         self.agent = agent
         self.display = DisplayManager(console)
         self.streaming_handler = StreamingQueryHandler(console)
-        self.current_task: Optional[asyncio.Task] = None
-        self.cancellation_token: Optional[asyncio.Event] = None
+        self.current_task: asyncio.Task | None = None
+        self.cancellation_token: asyncio.Event | None = None
+        self.table_completer = TableNameCompleter()
 
     def show_welcome_message(self):
         """Display welcome message for interactive mode."""
@@ -63,7 +39,8 @@ class InteractiveSession:
                 "[bold green]SQLSaber - Use the agent Luke![/bold green]\n\n"
                 "[bold]Your agentic SQL assistant.[/bold]\n\n\n"
                 "[dim]Use '/clear' to reset conversation, '/exit' or '/quit' to leave.[/dim]\n\n"
-                "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]",
+                "[dim]Start a message with '#' to add something to agent's memory for this database.[/dim]\n\n"
+                "[dim]Type '@' to get table name completions.[/dim]",
                 border_style="green",
             )
         )
@@ -75,6 +52,39 @@ class InteractiveSession:
             "[dim]Press Ctrl+C during query execution to interrupt and return to prompt.[/dim]\n"
         )
 
+    async def _update_table_cache(self):
+        """Update the table completer cache with fresh data."""
+        try:
+            # Use the schema manager directly which has built-in caching
+            tables_data = await self.agent.schema_manager.list_tables()
+
+            # Parse the table information
+            table_list = []
+            if isinstance(tables_data, dict) and "tables" in tables_data:
+                for table in tables_data["tables"]:
+                    if isinstance(table, dict):
+                        name = table.get("name", "")
+                        schema = table.get("schema", "")
+                        full_name = table.get("full_name", "")
+
+                        # Use full_name if available, otherwise construct it
+                        if full_name:
+                            table_name = full_name
+                        elif schema and schema != "main":
+                            table_name = f"{schema}.{name}"
+                        else:
+                            table_name = name
+
+                        # No description needed - cleaner completions
+                        table_list.append((table_name, ""))
+
+            # Update the completer cache
+            self.table_completer.update_cache(table_list)
+
+        except Exception:
+            # If there's an error, just use empty cache
+            self.table_completer.update_cache([])
+
     async def _execute_query_with_cancellation(self, user_query: str):
         """Execute a query with cancellation support."""
         # Create cancellation token
@@ -101,6 +111,9 @@ class InteractiveSession:
         """Run the interactive session loop."""
         self.show_welcome_message()
 
+        # Initialize table cache
+        await self._update_table_cache()
+
         while True:
             try:
                 user_query = await questionary.text(
@@ -108,7 +121,9 @@ class InteractiveSession:
                     qmark="",
                     multiline=True,
                     instruction="",
-                    completer=SlashCommandCompleter(),
+                    completer=CompositeCompleter(
+                        SlashCommandCompleter(), self.table_completer
+                    ),
                 ).ask_async()
 
                 if not user_query:
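The completion module itself (sqlsaber/cli/completers.py, +170 lines) is not expanded in this diff, so only its call sites are visible here: CompositeCompleter is constructed from sub-completers, and TableNameCompleter exposes update_cache() taking (name, description) tuples. Below is a minimal sketch of how such completers could be written against the prompt_toolkit Completer interface used by the removed code above; it is an illustrative assumption, not the package's actual implementation.

from prompt_toolkit.completion import Completer, Completion


class TableNameCompleter(Completer):
    """Sketch: completes table names after an '@' trigger, from a cached list."""

    def __init__(self):
        self._tables: list[tuple[str, str]] = []

    def update_cache(self, tables: list[tuple[str, str]]) -> None:
        # Called with (name, description) pairs, as in _update_table_cache above
        self._tables = tables

    def get_completions(self, document, complete_event):
        text = document.text_before_cursor
        at_pos = text.rfind("@")
        if at_pos == -1:
            return
        partial = text[at_pos + 1 :]
        for name, meta in self._tables:
            if name.startswith(partial):
                yield Completion(name, start_position=-len(partial), display_meta=meta)


class CompositeCompleter(Completer):
    """Sketch: delegates to several completers and merges their suggestions."""

    def __init__(self, *completers: Completer):
        self.completers = completers

    def get_completions(self, document, complete_event):
        for completer in self.completers:
            yield from completer.get_completions(document, complete_event)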
sqlsaber/cli/memory.py
CHANGED
@@ -1,7 +1,5 @@
 """Memory management CLI commands."""
 
-from typing import Optional
-
 import typer
 from rich.console import Console
 from rich.table import Table
@@ -22,7 +20,7 @@ memory_app = typer.Typer(
 )
 
 
-def _get_database_name(database: Optional[str] = None) -> str:
+def _get_database_name(database: str | None = None) -> str:
     """Get the database name to use, either specified or default."""
     if database:
         db_config = config_manager.get_database(database)
@@ -46,7 +44,7 @@ def _get_database_name(database: Optional[str] = None) -> str:
 @memory_app.command("add")
 def add_memory(
     content: str = typer.Argument(..., help="Memory content to add"),
-    database: Optional[str] = typer.Option(
+    database: str | None = typer.Option(
         None,
         "--database",
         "-d",
@@ -68,7 +66,7 @@ def add_memory(
 
 @memory_app.command("list")
 def list_memories(
-    database: Optional[str] = typer.Option(
+    database: str | None = typer.Option(
         None,
         "--database",
         "-d",
@@ -107,7 +105,7 @@ def list_memories(
 @memory_app.command("show")
 def show_memory(
     memory_id: str = typer.Argument(..., help="Memory ID to show"),
-    database: Optional[str] = typer.Option(
+    database: str | None = typer.Option(
         None,
         "--database",
         "-d",
@@ -135,7 +133,7 @@ def show_memory(
 @memory_app.command("remove")
 def remove_memory(
     memory_id: str = typer.Argument(..., help="Memory ID to remove"),
-    database: Optional[str] = typer.Option(
+    database: str | None = typer.Option(
         None,
         "--database",
         "-d",
@@ -170,7 +168,7 @@ def remove_memory(
 
 @memory_app.command("clear")
 def clear_memories(
-    database: Optional[str] = typer.Option(
+    database: str | None = typer.Option(
         None,
         "--database",
         "-d",
@@ -213,7 +211,7 @@ def clear_memories(
 
 @memory_app.command("summary")
 def memory_summary(
-    database: Optional[str] = typer.Option(
+    database: str | None = typer.Option(
         None,
         "--database",
         "-d",
sqlsaber/cli/models.py
CHANGED
@@ -1,7 +1,6 @@
 """Model management CLI commands."""
 
 import asyncio
-from typing import Dict, List
 
 import httpx
 import questionary
@@ -28,7 +27,7 @@ class ModelManager:
     DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
     MODELS_API_URL = "https://models.dev/api.json"
 
-    async def fetch_available_models(self) -> List[Dict]:
+    async def fetch_available_models(self) -> list[dict]:
         """Fetch available models from models.dev API."""
         try:
             async with httpx.AsyncClient(timeout=10.0) as client:
sqlsaber/cli/streaming.py
CHANGED
@@ -23,8 +23,6 @@ class StreamingQueryHandler:
     ):
         """Execute a query with streaming display."""
 
-        has_content = False
-        explanation_started = False
         status = self.console.status(
             "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
         )
@@ -38,15 +36,10 @@
                     break
 
                 if event.type == "tool_use":
-                    # Stop any ongoing status, but don't mark has_content yet
                     self._stop_status(status)
 
-                    if event.data["status"] == "started":
-
-                        if explanation_started:
-                            self.display.show_newline()
-                        self.display.show_tool_started(event.data["name"])
-                    elif event.data["status"] == "executing":
+                    if event.data["status"] == "executing":
+                        self.display.show_newline()
                         self.display.show_tool_executing(
                             event.data["name"], event.data["input"]
                         )
@@ -54,12 +47,6 @@
                 elif event.type == "text":
                     # Always stop status when text streaming starts
                     self._stop_status(status)
-
-                    if not explanation_started:
-                        explanation_started = True
-                        has_content = True
-
-                    # Print text as it streams
                     self.display.show_text_stream(event.data)
 
                 elif event.type == "query_result":
@@ -70,45 +57,40 @@
                     # Handle tool results - particularly list_tables and introspect_schema
                     if event.data.get("tool_name") == "list_tables":
                         self.display.show_table_list(event.data["result"])
-                        has_content = True
                     elif event.data.get("tool_name") == "introspect_schema":
                         self.display.show_schema_info(event.data["result"])
-                        has_content = True
 
                 elif event.type == "plot_result":
                     # Handle plot results
                     self.display.show_plot(event.data)
-                    has_content = True
 
                 elif event.type == "processing":
-                    #
-                    if explanation_started:
-                        self.display.show_newline()  # Add newline after explanation text
+                    self.display.show_newline()  # Add newline after explanation text
                     self._stop_status(status)
                     status = self.display.show_processing(event.data)
                    status.start()
-                    has_content = True
 
                 elif event.type == "error":
-
-                    self._stop_status(status)
-                    has_content = True
+                    self._stop_status(status)
                     self.display.show_error(event.data)
 
         except asyncio.CancelledError:
             # Handle cancellation gracefully
             self._stop_status(status)
-
-            self.display.show_newline()
+            self.display.show_newline()
             self.console.print("[yellow]Query interrupted[/yellow]")
             return
         finally:
             # Make sure status is stopped
             self._stop_status(status)
 
-        #
-        if
-
+        # Display the last assistant response as markdown
+        if hasattr(agent, "conversation_history") and agent.conversation_history:
+            last_message = agent.conversation_history[-1]
+            if last_message.get("role") == "assistant" and last_message.get(
+                "content"
+            ):
+                self.display.show_markdown_response(last_message["content"])
 
     def _stop_status(self, status):
         """Safely stop a status spinner."""
sqlsaber/clients/anthropic.py
ADDED
@@ -0,0 +1,285 @@
+"""Anthropic API client implementation."""
+
+import asyncio
+import json
+import logging
+from typing import Any, AsyncIterator
+
+import httpx
+
+from .base import BaseLLMClient
+from .exceptions import LLMClientError, create_exception_from_response
+from .models import CreateMessageRequest
+from .streaming import AnthropicStreamAdapter, StreamingResponse
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicClient(BaseLLMClient):
+    """Client for Anthropic's Claude API."""
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        oauth_token: str | None = None,
+        base_url: str | None = None,
+    ):
+        """Initialize the Anthropic client.
+
+        Args:
+            api_key: Anthropic API key
+            base_url: Base URL for the API (defaults to Anthropic's API)
+        """
+        super().__init__(api_key or "", base_url)
+
+        if not api_key and not oauth_token:
+            raise ValueError("Either api_key or oauth_token must be provided")
+
+        self.oauth_token = oauth_token
+        self.use_oauth = oauth_token is not None
+        self.base_url = base_url or "https://api.anthropic.com"
+        self.client: httpx.AsyncClient | None = None
+
+    def _get_client(self) -> httpx.AsyncClient:
+        """Get or create the HTTP client."""
+        if self.client is None or self.client.is_closed:
+            # Configure timeouts and connection limits for reliability
+            timeout = httpx.Timeout(
+                connect=10.0,  # Connection timeout
+                read=60.0,  # Read timeout for streaming
+                write=10.0,  # Write timeout
+                pool=10.0,  # Pool timeout
+            )
+            limits = httpx.Limits(
+                max_keepalive_connections=20, max_connections=100, keepalive_expiry=30.0
+            )
+            self.client = httpx.AsyncClient(
+                timeout=timeout, limits=limits, follow_redirects=True
+            )
+        return self.client
+
+    def _get_headers(self) -> dict[str, str]:
+        """Get the standard headers for API requests."""
+        if self.use_oauth:
+            # OAuth headers for Claude Pro authentication (matching Claude Code CLI)
+            return {
+                "Authorization": f"Bearer {self.oauth_token}",
+                "Content-Type": "application/json",
+                "anthropic-version": "2023-06-01",
+                "anthropic-beta": "oauth-2025-04-20",
+                "User-Agent": "ClaudeCode/1.0 (Anthropic Claude Code CLI)",
+                "Accept": "application/json",
+                "X-Client-Name": "claude-code",
+                "X-Client-Version": "1.0.0",
+            }
+        else:
+            # API key headers for standard authentication
+            return {
+                "x-api-key": self.api_key,
+                "anthropic-version": "2023-06-01",
+                "content-type": "application/json",
+            }
+
+    async def create_message_with_tools(
+        self,
+        request: CreateMessageRequest,
+        cancellation_token: asyncio.Event | None = None,
+    ) -> AsyncIterator[Any]:
+        """Create a message with tool support and stream the response.
+
+        This method handles the full message creation flow including tool use,
+        similar to what the current AnthropicSQLAgent expects.
+
+        Args:
+            request: The message creation request
+            cancellation_token: Optional event to signal cancellation
+
+        Yields:
+            Stream events and final StreamingResponse
+        """
+        request.stream = True
+
+        client = self._get_client()
+        url = f"{self.base_url}/v1/messages"
+        headers = self._get_headers()
+        data = request.to_dict()
+
+        try:
+            async with client.stream(
+                "POST", url, headers=headers, json=data
+            ) as response:
+                request_id = response.headers.get("request-id")
+
+                if response.status_code != 200:
+                    response_content = await response.aread()
+                    response_data = json.loads(response_content.decode())
+                    raise create_exception_from_response(
+                        response.status_code, response_data, request_id
+                    )
+
+                # Use stream adapter to convert raw events and track state
+                adapter = AnthropicStreamAdapter()
+                raw_stream = self._process_sse_stream(response, cancellation_token)
+
+                async for event in adapter.process_stream(
+                    raw_stream, cancellation_token
+                ):
+                    yield event
+
+                # Create final response object with proper state
+                response_obj = StreamingResponse(
+                    content=adapter.get_content_blocks(),
+                    stop_reason=adapter.get_stop_reason(),
+                )
+
+                # Yield special event with response
+                yield {"type": "response_ready", "data": response_obj}
+
+        except asyncio.CancelledError:
+            # Handle cancellation gracefully
+            logger.debug("Stream cancelled")
+            return
+        except Exception as e:
+            if not isinstance(e, LLMClientError):
+                raise LLMClientError(f"Stream processing error: {str(e)}")
+            raise
+
+    def _handle_ping_event(self, event_data: str) -> dict[str, Any]:
+        """Handle ping event data.
+
+        Args:
+            event_data: Raw event data string
+
+        Returns:
+            Parsed ping event
+        """
+        try:
+            return {"type": "ping", "data": json.loads(event_data)}
+        except json.JSONDecodeError:
+            return {"type": "ping", "data": {}}
+
+    def _handle_error_event(self, event_data: str) -> None:
+        """Handle error event data.
+
+        Args:
+            event_data: Raw event data string
+
+        Raises:
+            LLMClientError: Always raises with error details
+        """
+        try:
+            error_data = json.loads(event_data)
+            raise LLMClientError(
+                error_data.get("message", "Stream error"),
+                error_data.get("type", "stream_error"),
+            )
+        except json.JSONDecodeError:
+            raise LLMClientError("Stream error with invalid JSON")
+
+    def _parse_event_data(
+        self, event_type: str | None, event_data: str
+    ) -> dict[str, Any] | None:
+        """Parse event data based on event type.
+
+        Args:
+            event_type: Type of the event
+            event_data: Raw event data string
+
+        Returns:
+            Parsed event or None if parsing failed
+        """
+        try:
+            parsed_data = json.loads(event_data)
+            return {"type": event_type, "data": parsed_data}
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse stream data for event {event_type}: {e}")
+            return None
+
+    def _process_sse_line(
+        self, line: str, event_type: str | None
+    ) -> tuple[str | None, dict[str, Any] | None]:
+        """Process a single SSE line.
+
+        Args:
+            line: Line to process
+            event_type: Current event type
+
+        Returns:
+            Tuple of (new_event_type, event_to_yield)
+        """
+        if line.startswith("event: "):
+            return line[7:], None
+        elif line.startswith("data: "):
+            event_data = line[6:]
+
+            if event_type == "ping":
+                return event_type, self._handle_ping_event(event_data)
+            elif event_type == "error":
+                self._handle_error_event(event_data)
+                return event_type, None  # Never reached due to exception
+            else:
+                parsed_event = self._parse_event_data(event_type, event_data)
+                return event_type, parsed_event
+
+        return event_type, None
+
+    async def _process_sse_stream(
+        self,
+        response: httpx.Response,
+        cancellation_token: asyncio.Event | None = None,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Process server-sent events from the response stream.
+
+        Args:
+            response: The HTTP response object
+            cancellation_token: Optional event to signal cancellation
+
+        Yields:
+            Parsed stream events
+
+        Raises:
+            LLMClientError: If stream processing fails
+        """
+        buffer = ""
+        event_type = None
+
+        try:
+            async for chunk in response.aiter_bytes():
+                if cancellation_token is not None and cancellation_token.is_set():
+                    return
+
+                try:
+                    buffer += chunk.decode("utf-8")
+                except UnicodeDecodeError as e:
+                    logger.warning(f"Failed to decode chunk: {e}")
+                    continue
+
+                while "\n" in buffer:
+                    line, buffer = buffer.split("\n", 1)
+                    line = line.strip()
+
+                    if not line:
+                        continue
+
+                    event_type, event_to_yield = self._process_sse_line(
+                        line, event_type
+                    )
+                    if event_to_yield is not None:
+                        yield event_to_yield
+
+        except httpx.TimeoutException as e:
+            raise LLMClientError(f"Stream timeout error: {str(e)}")
+        except httpx.NetworkError as e:
+            raise LLMClientError(f"Network error during streaming: {str(e)}")
+        except httpx.HTTPError as e:
+            raise LLMClientError(f"HTTP error during streaming: {str(e)}")
+        except asyncio.TimeoutError:
+            raise LLMClientError("Stream timeout")
+        except Exception as e:
+            raise LLMClientError(f"Unexpected error during streaming: {str(e)}")
+
+    async def close(self):
+        """Close the HTTP client."""
+        if self.client and not self.client.is_closed:
+            await self.client.aclose()
+            self.client = None
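The new client is consumed as an async generator: callers iterate stream events and pick up the final StreamingResponse from the terminating "response_ready" event, while an optional asyncio.Event lets them abort the SSE loop between chunks. A short usage sketch under stated assumptions: CreateMessageRequest is defined in sqlsaber/clients/models.py, which is not expanded here, so the request object and the shape of intermediate events are left abstract.

from sqlsaber.clients.anthropic import AnthropicClient


async def stream_once(request) -> None:
    # BaseLLMClient supplies the async context manager; close() shuts down the httpx client.
    async with AnthropicClient(api_key="<your-api-key>") as client:
        async for event in client.create_message_with_tools(request):
            if isinstance(event, dict) and event.get("type") == "response_ready":
                response = event["data"]  # StreamingResponse built by the adapter
                print("stop_reason:", response.stop_reason)
            # other yielded items are the adapter's incremental stream events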
sqlsaber/clients/base.py
ADDED
@@ -0,0 +1,31 @@
+"""Abstract base class for LLM clients."""
+
+from abc import ABC
+
+
+class BaseLLMClient(ABC):
+    """Abstract base class for LLM API clients."""
+
+    def __init__(self, api_key: str, base_url: str | None = None):
+        """Initialize the client with API key and optional base URL.
+
+        Args:
+            api_key: API key for authentication
+            base_url: Base URL for the API (optional, uses default if not provided)
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+
+    async def close(self):
+        """Close the client and clean up resources."""
+        # Default implementation does nothing
+        # Subclasses can override to clean up HTTP sessions, etc.
+        pass
+
+    async def __aenter__(self):
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit."""
+        await self.close()
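Because BaseLLMClient has no abstract methods and wires __aenter__/__aexit__ to close(), any subclass gets "async with" lifecycle management for free. A brief sketch with a hypothetical DummyClient (not part of the package) to illustrate:

from sqlsaber.clients.base import BaseLLMClient


class DummyClient(BaseLLMClient):
    """Hypothetical subclass used only to show the lifecycle."""

    async def close(self):
        print("releasing resources")


async def demo() -> None:
    async with DummyClient(api_key="unused") as client:
        print("base_url:", client.base_url)
    # close() has been awaited here, even if the body raised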