sqlsaber 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/anthropic.py +283 -176
- sqlsaber/agents/base.py +11 -11
- sqlsaber/agents/streaming.py +3 -3
- sqlsaber/cli/auth.py +142 -0
- sqlsaber/cli/commands.py +9 -4
- sqlsaber/cli/completers.py +170 -0
- sqlsaber/cli/database.py +9 -10
- sqlsaber/cli/display.py +27 -7
- sqlsaber/cli/interactive.py +49 -34
- sqlsaber/cli/memory.py +7 -9
- sqlsaber/cli/models.py +1 -2
- sqlsaber/cli/streaming.py +12 -30
- sqlsaber/clients/__init__.py +6 -0
- sqlsaber/clients/anthropic.py +285 -0
- sqlsaber/clients/base.py +31 -0
- sqlsaber/clients/exceptions.py +117 -0
- sqlsaber/clients/models.py +282 -0
- sqlsaber/clients/streaming.py +257 -0
- sqlsaber/config/api_keys.py +2 -3
- sqlsaber/config/auth.py +86 -0
- sqlsaber/config/database.py +20 -20
- sqlsaber/config/oauth_flow.py +274 -0
- sqlsaber/config/oauth_tokens.py +175 -0
- sqlsaber/config/settings.py +34 -23
- sqlsaber/database/connection.py +9 -9
- sqlsaber/database/schema.py +41 -24
- sqlsaber/mcp/mcp.py +3 -4
- sqlsaber/memory/manager.py +3 -5
- sqlsaber/memory/storage.py +7 -8
- sqlsaber/models/events.py +4 -4
- sqlsaber/models/types.py +10 -10
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/METADATA +9 -8
- sqlsaber-0.8.0.dist-info/RECORD +46 -0
- sqlsaber-0.6.0.dist-info/RECORD +0 -35
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.6.0.dist-info → sqlsaber-0.8.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Exception classes for LLM client errors."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LLMClientError(Exception):
|
|
7
|
+
"""Base exception for LLM client errors."""
|
|
8
|
+
|
|
9
|
+
def __init__(
|
|
10
|
+
self,
|
|
11
|
+
message: str,
|
|
12
|
+
error_type: str | None = None,
|
|
13
|
+
status_code: int | None = None,
|
|
14
|
+
request_id: str | None = None,
|
|
15
|
+
):
|
|
16
|
+
super().__init__(message)
|
|
17
|
+
self.error_type = error_type
|
|
18
|
+
self.status_code = status_code
|
|
19
|
+
self.request_id = request_id
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AuthenticationError(LLMClientError):
|
|
23
|
+
"""Authentication failed - invalid API key."""
|
|
24
|
+
|
|
25
|
+
def __init__(self, message: str = "Invalid API key", **kwargs):
|
|
26
|
+
super().__init__(message, "authentication_error", **kwargs)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PermissionError(LLMClientError):
|
|
30
|
+
"""Permission denied for the requested resource."""
|
|
31
|
+
|
|
32
|
+
def __init__(self, message: str = "Permission denied", **kwargs):
|
|
33
|
+
super().__init__(message, "permission_error", **kwargs)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class NotFoundError(LLMClientError):
|
|
37
|
+
"""Requested resource not found."""
|
|
38
|
+
|
|
39
|
+
def __init__(self, message: str = "Resource not found", **kwargs):
|
|
40
|
+
super().__init__(message, "not_found_error", **kwargs)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class InvalidRequestError(LLMClientError):
|
|
44
|
+
"""Invalid request format or content."""
|
|
45
|
+
|
|
46
|
+
def __init__(self, message: str = "Invalid request", **kwargs):
|
|
47
|
+
super().__init__(message, "invalid_request_error", **kwargs)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RequestTooLargeError(LLMClientError):
|
|
51
|
+
"""Request exceeds maximum allowed size."""
|
|
52
|
+
|
|
53
|
+
def __init__(self, message: str = "Request too large", **kwargs):
|
|
54
|
+
super().__init__(message, "request_too_large", **kwargs)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class RateLimitError(LLMClientError):
|
|
58
|
+
"""Rate limit exceeded."""
|
|
59
|
+
|
|
60
|
+
def __init__(self, message: str = "Rate limit exceeded", **kwargs):
|
|
61
|
+
super().__init__(message, "rate_limit_error", **kwargs)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class APIError(LLMClientError):
|
|
65
|
+
"""Internal API error."""
|
|
66
|
+
|
|
67
|
+
def __init__(self, message: str = "Internal API error", **kwargs):
|
|
68
|
+
super().__init__(message, "api_error", **kwargs)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class OverloadedError(LLMClientError):
|
|
72
|
+
"""API is temporarily overloaded."""
|
|
73
|
+
|
|
74
|
+
def __init__(self, message: str = "API temporarily overloaded", **kwargs):
|
|
75
|
+
super().__init__(message, "overloaded_error", **kwargs)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Mapping of HTTP status codes to exception classes
STATUS_CODE_TO_EXCEPTION = {
    400: InvalidRequestError,
    401: AuthenticationError,
    403: PermissionError,
    404: NotFoundError,
    413: RequestTooLargeError,
    429: RateLimitError,
    500: APIError,
    529: OverloadedError,
}


def create_exception_from_response(
    status_code: int,
    response_data: dict[str, Any],
    request_id: str | None = None,
) -> LLMClientError:
    """Build the exception that matches an HTTP error response.

    Args:
        status_code: HTTP status code of the failed response.
        response_data: Decoded response body. The structured form
            ``{"error": {"type": ..., "message": ...}}`` is read directly; a
            plain-string ``"error"`` value is used as the message; anything
            else (missing, ``None``, non-dict body) falls back to a generic
            ``"HTTP <code> error"`` message.
        request_id: Request identifier to attach to the exception, if known.

    Returns:
        An instance of the class registered for ``status_code`` in
        ``STATUS_CODE_TO_EXCEPTION``, or a plain ``LLMClientError`` for
        unmapped codes. The exception is returned, not raised.
    """
    fallback_message = f"HTTP {status_code} error"
    error_data = (
        response_data.get("error") if isinstance(response_data, dict) else None
    )

    if isinstance(error_data, dict):
        # An explicit null/empty message is no more useful than a missing one.
        message = error_data.get("message") or fallback_message
        error_type = error_data.get("type")
    elif isinstance(error_data, str) and error_data:
        # Some gateways/proxies return ``{"error": "<message>"}``.
        message = error_data
        error_type = None
    else:
        message = fallback_message
        error_type = None

    exception_class = STATUS_CODE_TO_EXCEPTION.get(status_code)
    if exception_class is None:
        # Unmapped status: only the base class accepts an explicit error_type.
        return LLMClientError(message, error_type, status_code, request_id)

    # Subclasses hard-code their error_type and accept only message + kwargs.
    return exception_class(
        message,
        status_code=status_code,
        request_id=request_id,
    )
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
"""Data models for LLM client requests and responses."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MessageRole(str, Enum):
    """Author of a conversation message.

    Members are ``str`` subclasses whose values are the literal role strings
    placed in API payloads.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ContentType(str, Enum):
    """Kind of a content block; values are the API's block ``type`` strings."""

    TEXT = "text"
    IMAGE = "image"
    TOOL_USE = "tool_use"
    TOOL_RESULT = "tool_result"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ToolChoiceType(str, Enum):
    """Strategy for tool selection; values are the ``tool_choice.type`` strings."""

    AUTO = "auto"
    ANY = "any"
    TOOL = "tool"
    NONE = "none"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class StopReason(str, Enum):
    """Why generation stopped; values mirror the API's ``stop_reason`` field.

    ``PAUSE_TURN`` and ``REFUSAL`` are documented stop reasons of the Messages
    API. They are listed here so that ``StopReason(value)`` can parse such
    responses instead of raising ``ValueError``. They are appended at the end
    so the order of the original members is unchanged.
    """

    END_TURN = "end_turn"
    MAX_TOKENS = "max_tokens"
    STOP_SEQUENCE = "stop_sequence"
    TOOL_USE = "tool_use"
    PAUSE_TURN = "pause_turn"
    REFUSAL = "refusal"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class ContentBlock:
    """One block of message content.

    For ``TEXT`` blocks ``content`` is the text itself. For every other type
    it is a dict of that block's fields.
    """

    type: ContentType
    content: str | dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Serialize the block into its API request form.

        Raises:
            KeyError: a TOOL_USE / TOOL_RESULT dict lacks a required field.
            TypeError: a non-text block's ``content`` is not a mapping.
        """
        if self.type == ContentType.TEXT:
            return {"type": "text", "text": self.content}
        if self.type == ContentType.TOOL_USE:
            return {
                "type": "tool_use",
                "id": self.content["id"],
                "name": self.content["name"],
                "input": self.content["input"],
            }
        if self.type == ContentType.TOOL_RESULT:
            result = {
                "type": "tool_result",
                "tool_use_id": self.content["tool_use_id"],
                "content": self.content["content"],
            }
            # Forward the optional error flag so a failed tool call can be
            # reported to the model as an error instead of as normal output.
            if "is_error" in self.content:
                result["is_error"] = self.content["is_error"]
            return result
        # Any other block type is passed through with its fields flattened.
        return {"type": self.type.value, **self.content}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class Message:
    """A single conversation turn: a role plus either text or content blocks."""

    role: MessageRole
    content: str | list[ContentBlock]

    def to_dict(self) -> dict[str, Any]:
        """Serialize the message for an API request.

        String content is passed through unchanged. Block content is
        serialized block by block with ``ContentBlock.to_dict``.
        """
        role = self.role.value
        payload = self.content
        if not isinstance(payload, str):
            payload = [block.to_dict() for block in payload]
        return {"role": role, "content": payload}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class ToolDefinition:
    """Name, description and JSON input schema of a tool offered to the model."""

    name: str
    description: str
    input_schema: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Return the tool in the shape used by the request's ``tools`` list."""
        return dict(
            name=self.name,
            description=self.description,
            input_schema=self.input_schema,
        )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class ToolChoice:
    """How the model may pick tools for a request."""

    type: ToolChoiceType
    name: str | None = None
    disable_parallel_tool_use: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize for the request.

        ``name`` is emitted only when non-empty, and
        ``disable_parallel_tool_use`` only when it is set.
        """
        return {
            "type": self.type.value,
            **({"name": self.name} if self.name else {}),
            **(
                {"disable_parallel_tool_use": True}
                if self.disable_parallel_tool_use
                else {}
            ),
        }
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
class CreateMessageRequest:
    """Parameters of a create-message API call."""

    model: str
    messages: list[Message]
    max_tokens: int
    system: str | None = None
    tools: list[ToolDefinition] | None = None
    tool_choice: ToolChoice | None = None
    temperature: float | None = None
    stream: bool = False
    stop_sequences: list[str] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Build the JSON request body.

        ``model``, ``messages`` and ``max_tokens`` are always present. The
        optional fields are added only when set: truthy for most of them,
        ``is not None`` for ``temperature`` so that ``0.0`` is still sent.
        """
        data: dict[str, Any] = {
            "model": self.model,
            "messages": [message.to_dict() for message in self.messages],
            "max_tokens": self.max_tokens,
        }

        # (key, include?, value builder): builders run only when included.
        optional_fields = (
            ("system", self.system, lambda: self.system),
            ("tools", self.tools, lambda: [tool.to_dict() for tool in self.tools]),
            ("tool_choice", self.tool_choice, lambda: self.tool_choice.to_dict()),
            ("temperature", self.temperature is not None, lambda: self.temperature),
            ("stream", self.stream, lambda: True),
            ("stop_sequences", self.stop_sequences, lambda: self.stop_sequences),
        )
        for key, include, build in optional_fields:
            if include:
                data[key] = build()

        return data
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@dataclass
class Usage:
    """Token counts reported for one API call."""

    # Tokens counted on the request side.
    input_tokens: int
    # Tokens counted on the generated side.
    output_tokens: int
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclass
class MessageResponse:
    """Parsed reply to a create-message request."""

    id: str
    model: str
    role: MessageRole
    content: list[ContentBlock]
    stop_reason: StopReason
    stop_sequence: str | None
    usage: Usage

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MessageResponse":
        """Construct an instance from the decoded API response.

        Only ``text`` and ``tool_use`` content blocks are converted; blocks of
        any other type are dropped. Missing required keys raise ``KeyError``.
        Unrecognised ``role`` or ``stop_reason`` values raise ``ValueError``
        from the enum constructors.
        """
        blocks: list[ContentBlock] = []
        for raw_block in data["content"]:
            block_type = raw_block["type"]
            if block_type == "text":
                blocks.append(ContentBlock(ContentType.TEXT, raw_block["text"]))
            elif block_type == "tool_use":
                tool_call = {
                    field: raw_block[field] for field in ("id", "name", "input")
                }
                blocks.append(ContentBlock(ContentType.TOOL_USE, tool_call))

        return cls(
            id=data["id"],
            model=data["model"],
            role=MessageRole(data["role"]),
            content=blocks,
            stop_reason=StopReason(data["stop_reason"]),
            stop_sequence=data.get("stop_sequence"),
            usage=Usage(
                input_tokens=data["usage"]["input_tokens"],
                output_tokens=data["usage"]["output_tokens"],
            ),
        )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# Stream event types and data models for streaming
|
|
218
|
+
@dataclass
|
|
219
|
+
class StreamEventData:
|
|
220
|
+
"""Base class for stream event data."""
|
|
221
|
+
|
|
222
|
+
pass
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@dataclass
|
|
226
|
+
class TextDeltaData(StreamEventData):
|
|
227
|
+
"""Text delta stream event data."""
|
|
228
|
+
|
|
229
|
+
text: str
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@dataclass
|
|
233
|
+
class ToolUseStartData(StreamEventData):
|
|
234
|
+
"""Tool use start stream event data."""
|
|
235
|
+
|
|
236
|
+
id: str
|
|
237
|
+
name: str
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@dataclass
|
|
241
|
+
class ToolUseInputData(StreamEventData):
|
|
242
|
+
"""Tool use input stream event data."""
|
|
243
|
+
|
|
244
|
+
partial_json: str
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@dataclass
|
|
248
|
+
class MessageStartData(StreamEventData):
|
|
249
|
+
"""Message start stream event data."""
|
|
250
|
+
|
|
251
|
+
message: dict[str, Any]
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@dataclass
|
|
255
|
+
class MessageDeltaData(StreamEventData):
|
|
256
|
+
"""Message delta stream event data."""
|
|
257
|
+
|
|
258
|
+
delta: dict[str, Any]
|
|
259
|
+
usage: dict[str, Any] | None = None
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@dataclass
|
|
263
|
+
class ContentBlockStartData(StreamEventData):
|
|
264
|
+
"""Content block start stream event data."""
|
|
265
|
+
|
|
266
|
+
index: int
|
|
267
|
+
content_block: dict[str, Any]
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dataclass
|
|
271
|
+
class ContentBlockDeltaData(StreamEventData):
|
|
272
|
+
"""Content block delta stream event data."""
|
|
273
|
+
|
|
274
|
+
index: int
|
|
275
|
+
delta: dict[str, Any]
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@dataclass
|
|
279
|
+
class ContentBlockStopData(StreamEventData):
|
|
280
|
+
"""Content block stop stream event data."""
|
|
281
|
+
|
|
282
|
+
index: int
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Streaming adapters and utilities for LLM clients."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, AsyncIterator
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AnthropicStreamAdapter:
|
|
12
|
+
"""Adapter to convert raw Anthropic stream events to standardized format.
|
|
13
|
+
|
|
14
|
+
This adapter converts the raw SSE events from Anthropic API into objects
|
|
15
|
+
that match the structure expected by the current AnthropicSQLAgent.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def __init__(self):
|
|
19
|
+
self.content_blocks: list[dict[str, Any]] = []
|
|
20
|
+
self.tool_use_blocks: list[dict[str, Any]] = []
|
|
21
|
+
|
|
22
|
+
async def process_stream(
|
|
23
|
+
self,
|
|
24
|
+
raw_stream: AsyncIterator[dict[str, Any]],
|
|
25
|
+
cancellation_token: asyncio.Event | None = None,
|
|
26
|
+
) -> AsyncIterator[Any]:
|
|
27
|
+
"""Process raw stream events and yield adapted events.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
raw_stream: Raw stream events from the API
|
|
31
|
+
cancellation_token: Optional cancellation token
|
|
32
|
+
|
|
33
|
+
Yields:
|
|
34
|
+
Adapted stream events that match the SDK format
|
|
35
|
+
"""
|
|
36
|
+
async for raw_event in raw_stream:
|
|
37
|
+
# Check for cancellation
|
|
38
|
+
if cancellation_token is not None and cancellation_token.is_set():
|
|
39
|
+
return
|
|
40
|
+
|
|
41
|
+
# Convert raw event to SDK-like event
|
|
42
|
+
adapted_event = self._adapt_event(raw_event)
|
|
43
|
+
if adapted_event:
|
|
44
|
+
yield adapted_event
|
|
45
|
+
|
|
46
|
+
def _adapt_event(self, raw_event: dict[str, Any]) -> Any | None:
|
|
47
|
+
"""Adapt a raw stream event to match SDK format.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
raw_event: Raw event from the API
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
Adapted event object or None if event should be filtered
|
|
54
|
+
"""
|
|
55
|
+
event_type = raw_event.get("type")
|
|
56
|
+
event_data = raw_event.get("data", {})
|
|
57
|
+
|
|
58
|
+
if event_type == "ping":
|
|
59
|
+
# Create a ping event object
|
|
60
|
+
return PingEvent()
|
|
61
|
+
|
|
62
|
+
elif event_type == "message_start":
|
|
63
|
+
# Create message start event
|
|
64
|
+
return MessageStartEvent(event_data.get("message", {}))
|
|
65
|
+
|
|
66
|
+
elif event_type == "content_block_start":
|
|
67
|
+
# Create content block start event
|
|
68
|
+
index = event_data.get("index", 0)
|
|
69
|
+
content_block = event_data.get("content_block", {})
|
|
70
|
+
|
|
71
|
+
# Initialize content blocks list if needed
|
|
72
|
+
while len(self.content_blocks) <= index:
|
|
73
|
+
self.content_blocks.append({"type": "text", "text": ""})
|
|
74
|
+
|
|
75
|
+
if content_block.get("type") == "tool_use":
|
|
76
|
+
# Add to tool use blocks tracking
|
|
77
|
+
tool_block = {
|
|
78
|
+
"id": content_block.get("id"),
|
|
79
|
+
"name": content_block.get("name"),
|
|
80
|
+
"input": {},
|
|
81
|
+
"_partial": "",
|
|
82
|
+
}
|
|
83
|
+
self.tool_use_blocks.append(tool_block)
|
|
84
|
+
|
|
85
|
+
return ContentBlockStartEvent(index, content_block)
|
|
86
|
+
|
|
87
|
+
elif event_type == "content_block_delta":
|
|
88
|
+
# Create content block delta event
|
|
89
|
+
index = event_data.get("index", 0)
|
|
90
|
+
delta = event_data.get("delta", {})
|
|
91
|
+
|
|
92
|
+
# Update content blocks tracking
|
|
93
|
+
if index < len(self.content_blocks):
|
|
94
|
+
if delta.get("type") == "text_delta":
|
|
95
|
+
self.content_blocks[index]["text"] += delta.get("text", "")
|
|
96
|
+
elif delta.get("type") == "input_json_delta":
|
|
97
|
+
# Update tool use input tracking
|
|
98
|
+
if self.tool_use_blocks:
|
|
99
|
+
current_tool = self.tool_use_blocks[-1]
|
|
100
|
+
current_tool["_partial"] += delta.get("partial_json", "")
|
|
101
|
+
try:
|
|
102
|
+
current_tool["input"] = json.loads(current_tool["_partial"])
|
|
103
|
+
except json.JSONDecodeError:
|
|
104
|
+
pass # Partial JSON, continue accumulating
|
|
105
|
+
|
|
106
|
+
return ContentBlockDeltaEvent(index, delta)
|
|
107
|
+
|
|
108
|
+
elif event_type == "content_block_stop":
|
|
109
|
+
# Create content block stop event
|
|
110
|
+
index = event_data.get("index", 0)
|
|
111
|
+
return ContentBlockStopEvent(index)
|
|
112
|
+
|
|
113
|
+
elif event_type == "message_delta":
|
|
114
|
+
# Create message delta event
|
|
115
|
+
delta = event_data.get("delta", {})
|
|
116
|
+
usage = event_data.get("usage", {})
|
|
117
|
+
return MessageDeltaEvent(delta, usage)
|
|
118
|
+
|
|
119
|
+
elif event_type == "message_stop":
|
|
120
|
+
# Finalize tool blocks
|
|
121
|
+
self._finalize_tool_blocks()
|
|
122
|
+
return MessageStopEvent()
|
|
123
|
+
|
|
124
|
+
elif event_type == "error":
|
|
125
|
+
# Create error event
|
|
126
|
+
return ErrorEvent(event_data)
|
|
127
|
+
|
|
128
|
+
else:
|
|
129
|
+
# Unknown event type, log and ignore
|
|
130
|
+
logger.debug(f"Unknown event type: {event_type}")
|
|
131
|
+
return None
|
|
132
|
+
|
|
133
|
+
def _finalize_tool_blocks(self):
|
|
134
|
+
"""Finalize tool use blocks by cleaning up and adding to content blocks."""
|
|
135
|
+
for block in self.tool_use_blocks:
|
|
136
|
+
block["type"] = "tool_use"
|
|
137
|
+
if "_partial" in block:
|
|
138
|
+
del block["_partial"]
|
|
139
|
+
self.content_blocks.append(block)
|
|
140
|
+
|
|
141
|
+
def get_stop_reason(self) -> str:
|
|
142
|
+
"""Get the stop reason based on current state."""
|
|
143
|
+
if self.tool_use_blocks:
|
|
144
|
+
return "tool_use"
|
|
145
|
+
return "stop"
|
|
146
|
+
|
|
147
|
+
def get_content_blocks(self) -> list[dict[str, Any]]:
|
|
148
|
+
"""Get the current content blocks."""
|
|
149
|
+
return self.content_blocks.copy()
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Event objects shaped like the SDK's stream events.
class BaseStreamEvent:
    """Base stream event; stores the event kind in ``type``."""

    def __init__(self, event_type: str):
        self.type = event_type


class PingEvent(BaseStreamEvent):
    """``ping`` event; carries no payload."""

    def __init__(self):
        super().__init__(event_type="ping")


class MessageStartEvent(BaseStreamEvent):
    """``message_start`` event wrapping the initial message dict."""

    def __init__(self, message: dict[str, Any]):
        super().__init__(event_type="message_start")
        self.message = message


class ContentBlockStartEvent(BaseStreamEvent):
    """``content_block_start`` event for the block at ``index``."""

    def __init__(self, index: int, content_block: dict[str, Any]):
        super().__init__(event_type="content_block_start")
        self.index = index
        # Wrapped so consumers can use attribute access (``block.type``).
        self.content_block = MockContentBlock(content_block)


class ContentBlockDeltaEvent(BaseStreamEvent):
    """``content_block_delta`` event for the block at ``index``."""

    def __init__(self, index: int, delta: dict[str, Any]):
        super().__init__(event_type="content_block_delta")
        self.index = index
        # Wrapped so consumers can use attribute access (``delta.text``).
        self.delta = MockDelta(delta)


class ContentBlockStopEvent(BaseStreamEvent):
    """``content_block_stop`` event closing the block at ``index``."""

    def __init__(self, index: int):
        super().__init__(event_type="content_block_stop")
        self.index = index


class MessageDeltaEvent(BaseStreamEvent):
    """``message_delta`` event with the raw delta and usage dicts."""

    def __init__(self, delta: dict[str, Any], usage: dict[str, Any]):
        super().__init__(event_type="message_delta")
        self.delta = delta
        self.usage = usage


class MessageStopEvent(BaseStreamEvent):
    """``message_stop`` event; carries no payload."""

    def __init__(self):
        super().__init__(event_type="message_stop")


class ErrorEvent(BaseStreamEvent):
    """``error`` event holding the raw error payload in ``error``."""

    def __init__(self, error_data: dict[str, Any]):
        super().__init__(event_type="error")
        self.error = error_data
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# Lightweight stand-ins for SDK objects, built from raw event dicts.
class MockContentBlock:
    """Attribute view over a raw ``content_block`` dict.

    ``type``, ``id`` and ``name`` default to ``None`` when absent; ``input``
    defaults to an empty dict.
    """

    def __init__(self, data: dict[str, Any]):
        for attribute in ("type", "id", "name"):
            setattr(self, attribute, data.get(attribute))
        self.input = data.get("input", {})
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class MockDelta:
    """Attribute view over a raw ``delta`` dict, mirroring the SDK delta object.

    ``text`` (for ``text_delta``) and ``partial_json`` (for
    ``input_json_delta``) are set explicitly. Any other key of the raw dict
    can be read as an attribute. Unknown non-dunder names read as ``None``.
    """

    def __init__(self, data: dict[str, Any]):
        self.data = data
        # Set attributes based on delta type
        delta_type = data.get("type")
        if delta_type == "text_delta":
            self.text = data.get("text", "")
        elif delta_type == "input_json_delta":
            self.partial_json = data.get("partial_json", "")

    def __getattr__(self, name):
        """Fall back to the raw delta dict for attributes not set explicitly.

        This is only called when normal lookup fails. Dunder names raise
        ``AttributeError``, so protocol probes behave normally (``copy`` and
        ``pickle`` look up ``__setstate__``, ``__deepcopy__``, ...).
        ``data`` is read straight from ``__dict__``, so an instance created
        without ``__init__`` (as ``copy.copy`` does) cannot recurse forever.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        data = self.__dict__.get("data") or {}
        return data.get(name)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class StreamingResponse:
    """Result of a streamed completion: content block dicts plus a stop reason."""

    def __init__(self, content: list[dict[str, Any]], stop_reason: str):
        self.stop_reason = stop_reason
        self.content = content
|
sqlsaber/config/api_keys.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
import getpass
|
|
4
4
|
import os
|
|
5
|
-
from typing import Optional
|
|
6
5
|
|
|
7
6
|
import keyring
|
|
8
7
|
from rich.console import Console
|
|
@@ -16,7 +15,7 @@ class APIKeyManager:
|
|
|
16
15
|
def __init__(self):
|
|
17
16
|
self.service_prefix = "sqlsaber"
|
|
18
17
|
|
|
19
|
-
def get_api_key(self, provider: str) ->
|
|
18
|
+
def get_api_key(self, provider: str) -> str | None:
|
|
20
19
|
"""Get API key for the specified provider using cascading logic."""
|
|
21
20
|
env_var_name = self._get_env_var_name(provider)
|
|
22
21
|
service_name = self._get_service_name(provider)
|
|
@@ -57,7 +56,7 @@ class APIKeyManager:
|
|
|
57
56
|
|
|
58
57
|
def _prompt_and_store_key(
|
|
59
58
|
self, provider: str, env_var_name: str, service_name: str
|
|
60
|
-
) ->
|
|
59
|
+
) -> str | None:
|
|
61
60
|
"""Prompt user for API key and store it in keyring."""
|
|
62
61
|
try:
|
|
63
62
|
console.print(
|