sqlsaber 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

@@ -0,0 +1,117 @@
1
+ """Exception classes for LLM client errors."""
2
+
3
+ from typing import Any
4
+
5
+
6
+ class LLMClientError(Exception):
7
+ """Base exception for LLM client errors."""
8
+
9
+ def __init__(
10
+ self,
11
+ message: str,
12
+ error_type: str | None = None,
13
+ status_code: int | None = None,
14
+ request_id: str | None = None,
15
+ ):
16
+ super().__init__(message)
17
+ self.error_type = error_type
18
+ self.status_code = status_code
19
+ self.request_id = request_id
20
+
21
+
22
class AuthenticationError(LLMClientError):
    """Raised when the API rejects the supplied credentials (invalid API key)."""

    def __init__(self, message: str = "Invalid API key", **kwargs):
        # The error_type is fixed for this subclass; kwargs may carry
        # status_code / request_id through to the base class.
        super().__init__(message, error_type="authentication_error", **kwargs)
27
+
28
+
29
class PermissionError(LLMClientError):
    """Raised when access to the requested resource is denied.

    NOTE: this name shadows the builtin ``PermissionError`` inside this module
    and for anyone who imports it by name.
    """

    def __init__(self, message: str = "Permission denied", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="permission_error", **kwargs)
34
+
35
+
36
class NotFoundError(LLMClientError):
    """Raised when the requested resource does not exist."""

    def __init__(self, message: str = "Resource not found", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="not_found_error", **kwargs)
41
+
42
+
43
class InvalidRequestError(LLMClientError):
    """Raised when the request's format or content is rejected as invalid."""

    def __init__(self, message: str = "Invalid request", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="invalid_request_error", **kwargs)
48
+
49
+
50
class RequestTooLargeError(LLMClientError):
    """Raised when the request body exceeds the maximum allowed size."""

    def __init__(self, message: str = "Request too large", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="request_too_large", **kwargs)
55
+
56
+
57
class RateLimitError(LLMClientError):
    """Raised when the caller has exceeded the API rate limit."""

    def __init__(self, message: str = "Rate limit exceeded", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="rate_limit_error", **kwargs)
62
+
63
+
64
class APIError(LLMClientError):
    """Raised for an internal (server-side) API error."""

    def __init__(self, message: str = "Internal API error", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="api_error", **kwargs)
69
+
70
+
71
class OverloadedError(LLMClientError):
    """Raised when the API reports that it is temporarily overloaded."""

    def __init__(self, message: str = "API temporarily overloaded", **kwargs):
        # The error_type is fixed; kwargs may carry status_code / request_id.
        super().__init__(message, error_type="overloaded_error", **kwargs)
76
+
77
+
78
# HTTP status code -> exception class used by create_exception_from_response.
# Status codes missing from this table fall back to the base LLMClientError.
STATUS_CODE_TO_EXCEPTION: dict[int, type[LLMClientError]] = {
    400: InvalidRequestError,  # bad request format/content
    401: AuthenticationError,  # invalid API key
    403: PermissionError,  # access denied
    404: NotFoundError,  # unknown resource
    413: RequestTooLargeError,  # body too large
    429: RateLimitError,  # rate limit hit
    500: APIError,  # internal server error
    529: OverloadedError,  # API overloaded
}
89
+
90
+
91
def create_exception_from_response(
    status_code: int,
    response_data: dict[str, Any],
    request_id: str | None = None,
) -> LLMClientError:
    """Build the most specific LLMClientError for a failed HTTP response.

    Args:
        status_code: HTTP status code of the failed response.
        response_data: Decoded JSON body. The message and error type are read
            from its ``"error"`` entry when that is an object; a bare string
            under ``"error"`` is used as the message.
        request_id: Request identifier to attach to the exception, if known.

    Returns:
        An instance of the class mapped to ``status_code`` in
        ``STATUS_CODE_TO_EXCEPTION``. For unmapped status codes, a plain
        ``LLMClientError`` that also carries the response's error type.
    """
    error_data = (
        response_data.get("error") if isinstance(response_data, dict) else None
    )
    if isinstance(error_data, str):
        # Some error bodies carry {"error": "<message>"} instead of an object.
        error_data = {"message": error_data}
    elif not isinstance(error_data, dict):
        # Missing, null, or otherwise unexpected "error" payload.
        error_data = {}

    # `or` rather than a .get() default so an explicit null/empty message
    # also falls back to a descriptive text.
    message = error_data.get("message") or f"HTTP {status_code} error"
    error_type = error_data.get("type")

    exception_class = STATUS_CODE_TO_EXCEPTION.get(status_code)
    if exception_class is None:
        # Only the base class accepts error_type; subclasses hard-code theirs.
        return LLMClientError(message, error_type, status_code, request_id)
    return exception_class(message, status_code=status_code, request_id=request_id)
@@ -0,0 +1,282 @@
1
+ """Data models for LLM client requests and responses."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Any
6
+
7
+
8
class MessageRole(str, Enum):
    """Author of a conversation message; string-valued so it serializes as-is."""

    # Declaration order is kept stable (it defines iteration order).
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
14
+
15
+
16
class ContentType(str, Enum):
    """Kinds of content block that can appear in a message."""

    # Values match the "type" strings used in the serialized blocks.
    TEXT = "text"
    IMAGE = "image"
    TOOL_USE = "tool_use"
    TOOL_RESULT = "tool_result"
23
+
24
+
25
class ToolChoiceType(str, Enum):
    """How the model may choose among the offered tools."""

    # Serialized verbatim as the tool_choice "type" value.
    AUTO = "auto"
    ANY = "any"
    TOOL = "tool"
    NONE = "none"
32
+
33
+
34
class StopReason(str, Enum):
    """Reasons a message completion can end with."""

    # Values correspond to the raw "stop_reason" strings in responses.
    END_TURN = "end_turn"
    MAX_TOKENS = "max_tokens"
    STOP_SEQUENCE = "stop_sequence"
    TOOL_USE = "tool_use"
41
+
42
+
43
@dataclass
class ContentBlock:
    """A single content block within a message.

    ``content`` is a plain string for TEXT blocks and a dict for the other
    block types; ``to_dict`` maps it onto the request/response wire format.
    """

    type: ContentType
    content: str | dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert the block to its dictionary (wire) form.

        Returns:
            TEXT: ``{"type": "text", "text": content}``.
            TOOL_USE: ``id``, ``name`` and ``input`` from the content dict.
            TOOL_RESULT: ``tool_use_id`` and ``content`` from the content
                dict, plus ``is_error`` when the content dict provides it.
            Any other type: the content dict merged under the block's type.

        Raises:
            KeyError: if a TOOL_USE / TOOL_RESULT content dict lacks a
                required key.
        """
        if self.type == ContentType.TEXT:
            return {"type": "text", "text": self.content}
        if self.type == ContentType.TOOL_USE:
            return {
                "type": "tool_use",
                "id": self.content["id"],
                "name": self.content["name"],
                "input": self.content["input"],
            }
        if self.type == ContentType.TOOL_RESULT:
            result = {
                "type": "tool_result",
                "tool_use_id": self.content["tool_use_id"],
                "content": self.content["content"],
            }
            # Forward the optional error flag instead of silently dropping it,
            # so a failed tool call is reported as such.
            if "is_error" in self.content:
                result["is_error"] = self.content["is_error"]
            return result
        return {"type": self.type.value, **self.content}
69
+
70
+
71
@dataclass
class Message:
    """One conversation turn: a role plus plain text or a list of blocks."""

    role: MessageRole
    content: str | list[ContentBlock]

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the request shape ``{"role": ..., "content": ...}``.

        String content is passed through unchanged; block content is
        serialized block by block.
        """
        serialized = (
            self.content
            if isinstance(self.content, str)
            else [blk.to_dict() for blk in self.content]
        )
        return {"role": self.role.value, "content": serialized}
87
+
88
+
89
@dataclass
class ToolDefinition:
    """A tool the model may call, described by name, prose and input schema."""

    name: str
    description: str
    input_schema: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Serialize for the request's ``tools`` list (schema passed by reference)."""
        return dict(
            name=self.name,
            description=self.description,
            input_schema=self.input_schema,
        )
104
+
105
+
106
@dataclass
class ToolChoice:
    """Tool-choice configuration attached to a request."""

    type: ToolChoiceType
    name: str | None = None
    disable_parallel_tool_use: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize; ``name`` and the parallel-use flag appear only when set."""
        return {
            "type": self.type.value,
            **({"name": self.name} if self.name else {}),
            **(
                {"disable_parallel_tool_use": True}
                if self.disable_parallel_tool_use
                else {}
            ),
        }
122
+
123
+
124
@dataclass
class CreateMessageRequest:
    """Parameters for a create-message API call."""

    model: str
    messages: list[Message]
    max_tokens: int
    system: str | None = None
    tools: list[ToolDefinition] | None = None
    tool_choice: ToolChoice | None = None
    temperature: float | None = None
    stream: bool = False
    stop_sequences: list[str] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the request body.

        ``model``, ``messages`` and ``max_tokens`` are always present. The
        optional settings are omitted when unset or empty; ``temperature``
        is omitted only when it is ``None``, so ``0.0`` is still sent.
        """
        body: dict[str, Any] = {
            "model": self.model,
            "messages": [m.to_dict() for m in self.messages],
            "max_tokens": self.max_tokens,
        }

        if self.system:
            body["system"] = self.system
        if self.tools:
            body["tools"] = [tool_def.to_dict() for tool_def in self.tools]
        if self.tool_choice:
            body["tool_choice"] = self.tool_choice.to_dict()
        if self.temperature is not None:
            body["temperature"] = self.temperature
        if self.stream:
            body["stream"] = True
        if self.stop_sequences:
            body["stop_sequences"] = self.stop_sequences

        return body
160
+
161
+
162
@dataclass
class Usage:
    """Token counts reported for one API call."""

    input_tokens: int  # tokens consumed by the prompt
    output_tokens: int  # tokens generated in the reply
168
+
169
+
170
@dataclass
class MessageResponse:
    """A complete (non-streaming) message returned by the API."""

    id: str
    model: str
    role: MessageRole
    content: list[ContentBlock]
    # A StopReason member when the value is recognized; otherwise the raw
    # value from the response (an unrecognized string, or None).
    stop_reason: StopReason | str | None
    stop_sequence: str | None
    usage: Usage

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MessageResponse":
        """Build a response from the decoded API JSON.

        Only ``text`` and ``tool_use`` content blocks are kept; blocks of any
        other type are skipped.

        Raises:
            KeyError: if ``id``, ``model``, ``role``, ``content`` or the
                ``usage`` token counts are missing.
            ValueError: if ``role`` is not a MessageRole value.
        """
        content_blocks: list[ContentBlock] = []
        for block_data in data["content"]:
            block_type = block_data["type"]
            if block_type == "text":
                content_blocks.append(
                    ContentBlock(ContentType.TEXT, block_data["text"])
                )
            elif block_type == "tool_use":
                content_blocks.append(
                    ContentBlock(
                        ContentType.TOOL_USE,
                        {
                            "id": block_data["id"],
                            "name": block_data["name"],
                            "input": block_data["input"],
                        },
                    )
                )

        return cls(
            id=data["id"],
            model=data["model"],
            role=MessageRole(data["role"]),
            content=content_blocks,
            stop_reason=cls._parse_stop_reason(data.get("stop_reason")),
            stop_sequence=data.get("stop_sequence"),
            usage=Usage(
                input_tokens=data["usage"]["input_tokens"],
                output_tokens=data["usage"]["output_tokens"],
            ),
        )

    @staticmethod
    def _parse_stop_reason(raw: Any) -> "StopReason | str | None":
        """Map a raw stop_reason onto StopReason, keeping unknown values as-is."""
        try:
            return StopReason(raw)
        except ValueError:
            # A null or unrecognized value (e.g. one not yet listed in
            # StopReason) must not make the whole response unparseable.
            return raw
215
+
216
+
217
# Typed payloads for individual streaming events.
@dataclass
class StreamEventData:
    """Common (field-less) base for all stream event payloads."""
223
+
224
+
225
@dataclass
class TextDeltaData(StreamEventData):
    """Payload of a text delta: the next fragment of generated text."""

    text: str  # text fragment to append
230
+
231
+
232
@dataclass
class ToolUseStartData(StreamEventData):
    """Payload announcing the start of a tool call."""

    id: str  # identifier of the tool_use block
    name: str  # name of the tool being called
238
+
239
+
240
@dataclass
class ToolUseInputData(StreamEventData):
    """Payload carrying a fragment of a tool call's input JSON."""

    partial_json: str  # raw, possibly incomplete, JSON text
245
+
246
+
247
@dataclass
class MessageStartData(StreamEventData):
    """Payload of the message-start event."""

    message: dict[str, Any]  # raw message object as received
252
+
253
+
254
@dataclass
class MessageDeltaData(StreamEventData):
    """Payload of a message-level delta event."""

    delta: dict[str, Any]  # raw delta object
    usage: dict[str, Any] | None = None  # raw usage object, when provided
260
+
261
+
262
@dataclass
class ContentBlockStartData(StreamEventData):
    """Payload announcing a new content block at a given index."""

    index: int  # position of the block within the message
    content_block: dict[str, Any]  # raw initial block object
268
+
269
+
270
@dataclass
class ContentBlockDeltaData(StreamEventData):
    """Payload carrying an incremental update to one content block."""

    index: int  # position of the block being updated
    delta: dict[str, Any]  # raw delta object
276
+
277
+
278
@dataclass
class ContentBlockStopData(StreamEventData):
    """Payload marking the end of a content block."""

    index: int  # position of the block that finished
@@ -0,0 +1,257 @@
1
+ """Streaming adapters and utilities for LLM clients."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from typing import Any, AsyncIterator
7
+
8
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
9
+
10
+
11
+ class AnthropicStreamAdapter:
12
+ """Adapter to convert raw Anthropic stream events to standardized format.
13
+
14
+ This adapter converts the raw SSE events from Anthropic API into objects
15
+ that match the structure expected by the current AnthropicSQLAgent.
16
+ """
17
+
18
+ def __init__(self):
19
+ self.content_blocks: list[dict[str, Any]] = []
20
+ self.tool_use_blocks: list[dict[str, Any]] = []
21
+
22
+ async def process_stream(
23
+ self,
24
+ raw_stream: AsyncIterator[dict[str, Any]],
25
+ cancellation_token: asyncio.Event | None = None,
26
+ ) -> AsyncIterator[Any]:
27
+ """Process raw stream events and yield adapted events.
28
+
29
+ Args:
30
+ raw_stream: Raw stream events from the API
31
+ cancellation_token: Optional cancellation token
32
+
33
+ Yields:
34
+ Adapted stream events that match the SDK format
35
+ """
36
+ async for raw_event in raw_stream:
37
+ # Check for cancellation
38
+ if cancellation_token is not None and cancellation_token.is_set():
39
+ return
40
+
41
+ # Convert raw event to SDK-like event
42
+ adapted_event = self._adapt_event(raw_event)
43
+ if adapted_event:
44
+ yield adapted_event
45
+
46
+ def _adapt_event(self, raw_event: dict[str, Any]) -> Any | None:
47
+ """Adapt a raw stream event to match SDK format.
48
+
49
+ Args:
50
+ raw_event: Raw event from the API
51
+
52
+ Returns:
53
+ Adapted event object or None if event should be filtered
54
+ """
55
+ event_type = raw_event.get("type")
56
+ event_data = raw_event.get("data", {})
57
+
58
+ if event_type == "ping":
59
+ # Create a ping event object
60
+ return PingEvent()
61
+
62
+ elif event_type == "message_start":
63
+ # Create message start event
64
+ return MessageStartEvent(event_data.get("message", {}))
65
+
66
+ elif event_type == "content_block_start":
67
+ # Create content block start event
68
+ index = event_data.get("index", 0)
69
+ content_block = event_data.get("content_block", {})
70
+
71
+ # Initialize content blocks list if needed
72
+ while len(self.content_blocks) <= index:
73
+ self.content_blocks.append({"type": "text", "text": ""})
74
+
75
+ if content_block.get("type") == "tool_use":
76
+ # Add to tool use blocks tracking
77
+ tool_block = {
78
+ "id": content_block.get("id"),
79
+ "name": content_block.get("name"),
80
+ "input": {},
81
+ "_partial": "",
82
+ }
83
+ self.tool_use_blocks.append(tool_block)
84
+
85
+ return ContentBlockStartEvent(index, content_block)
86
+
87
+ elif event_type == "content_block_delta":
88
+ # Create content block delta event
89
+ index = event_data.get("index", 0)
90
+ delta = event_data.get("delta", {})
91
+
92
+ # Update content blocks tracking
93
+ if index < len(self.content_blocks):
94
+ if delta.get("type") == "text_delta":
95
+ self.content_blocks[index]["text"] += delta.get("text", "")
96
+ elif delta.get("type") == "input_json_delta":
97
+ # Update tool use input tracking
98
+ if self.tool_use_blocks:
99
+ current_tool = self.tool_use_blocks[-1]
100
+ current_tool["_partial"] += delta.get("partial_json", "")
101
+ try:
102
+ current_tool["input"] = json.loads(current_tool["_partial"])
103
+ except json.JSONDecodeError:
104
+ pass # Partial JSON, continue accumulating
105
+
106
+ return ContentBlockDeltaEvent(index, delta)
107
+
108
+ elif event_type == "content_block_stop":
109
+ # Create content block stop event
110
+ index = event_data.get("index", 0)
111
+ return ContentBlockStopEvent(index)
112
+
113
+ elif event_type == "message_delta":
114
+ # Create message delta event
115
+ delta = event_data.get("delta", {})
116
+ usage = event_data.get("usage", {})
117
+ return MessageDeltaEvent(delta, usage)
118
+
119
+ elif event_type == "message_stop":
120
+ # Finalize tool blocks
121
+ self._finalize_tool_blocks()
122
+ return MessageStopEvent()
123
+
124
+ elif event_type == "error":
125
+ # Create error event
126
+ return ErrorEvent(event_data)
127
+
128
+ else:
129
+ # Unknown event type, log and ignore
130
+ logger.debug(f"Unknown event type: {event_type}")
131
+ return None
132
+
133
+ def _finalize_tool_blocks(self):
134
+ """Finalize tool use blocks by cleaning up and adding to content blocks."""
135
+ for block in self.tool_use_blocks:
136
+ block["type"] = "tool_use"
137
+ if "_partial" in block:
138
+ del block["_partial"]
139
+ self.content_blocks.append(block)
140
+
141
+ def get_stop_reason(self) -> str:
142
+ """Get the stop reason based on current state."""
143
+ if self.tool_use_blocks:
144
+ return "tool_use"
145
+ return "stop"
146
+
147
+ def get_content_blocks(self) -> list[dict[str, Any]]:
148
+ """Get the current content blocks."""
149
+ return self.content_blocks.copy()
150
+
151
+
152
# Lightweight event objects shaped like the SDK's streaming events.
class BaseStreamEvent:
    """Common base for adapted stream events; exposes the event name as ``type``."""

    def __init__(self, event_type: str) -> None:
        # Event name, e.g. "ping" or "message_stop".
        self.type = event_type
158
+
159
+
160
class PingEvent(BaseStreamEvent):
    """Keep-alive ping; carries no payload."""

    def __init__(self):
        super().__init__(event_type="ping")
165
+
166
+
167
class MessageStartEvent(BaseStreamEvent):
    """Start of a message; ``message`` holds the raw message dict."""

    def __init__(self, message: dict[str, Any]):
        self.message = message
        super().__init__(event_type="message_start")
173
+
174
+
175
class ContentBlockStartEvent(BaseStreamEvent):
    """Start of a content block, wrapped for attribute-style access."""

    def __init__(self, index: int, content_block: dict[str, Any]):
        super().__init__(event_type="content_block_start")
        # Position of the block and its SDK-like wrapper.
        self.index, self.content_block = index, MockContentBlock(content_block)
182
+
183
+
184
class ContentBlockDeltaEvent(BaseStreamEvent):
    """Incremental update to a content block; ``delta`` is a MockDelta view."""

    def __init__(self, index: int, delta: dict[str, Any]):
        super().__init__(event_type="content_block_delta")
        self.index, self.delta = index, MockDelta(delta)
191
+
192
+
193
class ContentBlockStopEvent(BaseStreamEvent):
    """End of the content block at ``index``."""

    def __init__(self, index: int):
        self.index = index
        super().__init__(event_type="content_block_stop")
199
+
200
+
201
class MessageDeltaEvent(BaseStreamEvent):
    """Message-level update; ``delta`` and ``usage`` are the raw dicts."""

    def __init__(self, delta: dict[str, Any], usage: dict[str, Any]):
        super().__init__(event_type="message_delta")
        self.delta, self.usage = delta, usage
208
+
209
+
210
class MessageStopEvent(BaseStreamEvent):
    """End of the message; carries no payload."""

    def __init__(self):
        super().__init__(event_type="message_stop")
215
+
216
+
217
class ErrorEvent(BaseStreamEvent):
    """Error reported in-stream; ``error`` holds the raw error payload."""

    def __init__(self, error_data: dict[str, Any]):
        self.error = error_data
        super().__init__(event_type="error")
223
+
224
+
225
# Attribute-style wrappers standing in for the SDK's object types.
class MockContentBlock:
    """Attribute view over a raw content-block dict (missing keys become None)."""

    def __init__(self, data: dict[str, Any]):
        lookup = data.get
        self.type = lookup("type")
        self.id = lookup("id")
        self.name = lookup("name")
        # tool_use input defaults to an empty dict rather than None.
        self.input = lookup("input", {})
234
+
235
+
236
class MockDelta:
    """Attribute view over a raw delta dict, mirroring the SDK's delta objects.

    ``text`` (for text deltas) and ``partial_json`` (for input-JSON deltas)
    are set explicitly; any other attribute is looked up in the raw dict and
    is None when absent.
    """

    def __init__(self, data: dict[str, Any]):
        self.data = data
        delta_type = data.get("type")
        if delta_type == "text_delta":
            self.text = data.get("text", "")
        elif delta_type == "input_json_delta":
            self.partial_json = data.get("partial_json", "")

    def __getattr__(self, name):
        """Return ``data[name]`` (or None) for attributes not set explicitly.

        Dunder names, and any lookup made before ``data`` exists, raise
        AttributeError. This lets protocols that probe optional hooks on
        partially built instances (copy, pickle) work. Previously they hit
        infinite recursion through ``self.data``.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        # Read via __dict__ so a missing "data" cannot re-enter __getattr__.
        data = self.__dict__.get("data")
        if data is None:
            raise AttributeError(name)
        return data.get(name)
250
+
251
+
252
class StreamingResponse:
    """Final result of a streamed turn: its content blocks and stop reason."""

    def __init__(self, content: list[dict[str, Any]], stop_reason: str):
        # Both values are stored as given (no copying).
        self.stop_reason = stop_reason
        self.content = content
@@ -2,7 +2,6 @@
2
2
 
3
3
  import getpass
4
4
  import os
5
- from typing import Optional
6
5
 
7
6
  import keyring
8
7
  from rich.console import Console
@@ -16,7 +15,7 @@ class APIKeyManager:
16
15
  def __init__(self):
17
16
  self.service_prefix = "sqlsaber"
18
17
 
19
- def get_api_key(self, provider: str) -> Optional[str]:
18
+ def get_api_key(self, provider: str) -> str | None:
20
19
  """Get API key for the specified provider using cascading logic."""
21
20
  env_var_name = self._get_env_var_name(provider)
22
21
  service_name = self._get_service_name(provider)
@@ -57,7 +56,7 @@ class APIKeyManager:
57
56
 
58
57
  def _prompt_and_store_key(
59
58
  self, provider: str, env_var_name: str, service_name: str
60
- ) -> Optional[str]:
59
+ ) -> str | None:
61
60
  """Prompt user for API key and store it in keyring."""
62
61
  try:
63
62
  console.print(