sqlsaber 0.15.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of sqlsaber might be problematic.

@@ -1,491 +0,0 @@
- """Anthropic-specific SQL agent implementation using the custom client."""
-
- import asyncio
- import json
- from typing import Any, AsyncIterator
-
- from sqlsaber.agents.base import BaseSQLAgent
- from sqlsaber.agents.streaming import (
-     build_tool_result_block,
- )
- from sqlsaber.clients import AnthropicClient
- from sqlsaber.clients.models import (
-     ContentBlock,
-     ContentType,
-     CreateMessageRequest,
-     Message,
-     MessageRole,
-     ToolDefinition,
- )
- from sqlsaber.config.settings import Config
- from sqlsaber.database.connection import BaseDatabaseConnection
- from sqlsaber.memory.manager import MemoryManager
- from sqlsaber.models.events import StreamEvent
- from sqlsaber.tools import tool_registry
- from sqlsaber.tools.instructions import InstructionBuilder
-
-
- class AnthropicSQLAgent(BaseSQLAgent):
-     """SQL Agent using the custom Anthropic client."""
-
-     # Constants
-     MAX_TOKENS = 4096
-     DEFAULT_SQL_LIMIT = 100
-
-     def __init__(
-         self, db_connection: BaseDatabaseConnection, database_name: str | None = None
-     ):
-         super().__init__(db_connection)
-
-         config = Config()
-         config.validate()  # This will raise ValueError if credentials are missing
-
-         if config.oauth_token:
-             self.client = AnthropicClient(oauth_token=config.oauth_token)
-         else:
-             self.client = AnthropicClient(api_key=config.api_key)
-         self.model = config.model_name.replace("anthropic:", "")
-
-         self.database_name = database_name
-         self.memory_manager = MemoryManager()
-
-         # Track last query results for streaming
-         self._last_results = None
-         self._last_query = None
-
-         # Get tool definitions from registry
-         self.tools: list[ToolDefinition] = tool_registry.get_tool_definitions()
-
-         # Initialize instruction builder
-         self.instruction_builder = InstructionBuilder(tool_registry)
-
-         # Build system prompt with memories if available
-         self.system_prompt = self._build_system_prompt()
-
-     def _build_system_prompt(self) -> str:
-         """Build system prompt with optional memory context."""
-         # For OAuth authentication, start with Claude Code identity
-         # Check if we're using OAuth by looking at the client
-         is_oauth = (
-             hasattr(self, "client")
-             and hasattr(self.client, "use_oauth")
-             and self.client.use_oauth
-         )
-
-         if is_oauth:
-             # For OAuth, keep system prompt minimal - just Claude Code identity
-             return "You are Claude Code, Anthropic's official CLI for Claude."
-         else:
-             return self._get_sql_assistant_instructions()
-
-     def _get_sql_assistant_instructions(self) -> str:
-         """Get the detailed SQL assistant instructions."""
-         db_type = self._get_database_type_name()
-
-         # Build dynamic instructions from available tools
-         instructions = self.instruction_builder.build_instructions(db_type=db_type)
-
-         # Add memory context if database name is available
-         if self.database_name:
-             memory_context = self.memory_manager.format_memories_for_prompt(
-                 self.database_name
-             )
-             if memory_context.strip():
-                 instructions += "\n\n" + memory_context
-
-         return instructions
-
-     def add_memory(self, content: str) -> str | None:
-         """Add a memory for the current database."""
-         if not self.database_name:
-             return None
-
-         memory = self.memory_manager.add_memory(self.database_name, content)
-         # Rebuild system prompt with new memory (includes dynamic instructions)
-         self.system_prompt = self._build_system_prompt()
-         return memory.id
-
-     async def _execute_sql_with_tracking(
-         self, query: str, limit: int | None = None
-     ) -> str:
-         """Execute SQL and track results for streaming."""
-         # Get the execute_sql tool and run it
-         tool = tool_registry.get_tool("execute_sql")
-         result = await tool.execute(query=query, limit=limit)
-
-         # Parse result to extract data for streaming
-         try:
-             result_data = json.loads(result)
-             if result_data.get("success") and "results" in result_data:
-                 # Store results for streaming
-                 actual_limit = (
-                     limit if limit is not None else len(result_data["results"])
-                 )
-                 self._last_results = result_data["results"][:actual_limit]
-                 self._last_query = query
-         except (json.JSONDecodeError, KeyError):
-             # If we can't parse the result, just continue without storing
-             pass
-
-         return result
-
-     async def process_tool_call(
-         self, tool_name: str, tool_input: dict[str, Any]
-     ) -> str:
-         """Process a tool call and return the result."""
-         # Special handling for execute_sql to track results
-         if tool_name == "execute_sql":
-             return await self._execute_sql_with_tracking(
-                 tool_input.get("query", ""),
-                 tool_input.get("limit", self.DEFAULT_SQL_LIMIT),
-             )
-
-         # Use parent implementation for all other tools
-         return await super().process_tool_call(tool_name, tool_input)
-
-     def _convert_user_message_to_message(
-         self, msg_content: str | list[dict[str, Any]]
-     ) -> Message:
-         """Convert user message content to Message object."""
-         if isinstance(msg_content, str):
-             return Message(MessageRole.USER, msg_content)
-
-         # Handle tool results format
-         tool_result_blocks = []
-         if isinstance(msg_content, list):
-             for item in msg_content:
-                 if isinstance(item, dict) and item.get("type") == "tool_result":
-                     tool_result_blocks.append(
-                         ContentBlock(ContentType.TOOL_RESULT, item)
-                     )
-
-         if tool_result_blocks:
-             return Message(MessageRole.USER, tool_result_blocks)
-
-         # Fallback to string representation
-         return Message(MessageRole.USER, str(msg_content))
-
-     def _convert_assistant_message_to_message(
-         self, msg_content: str | list[dict[str, Any]]
-     ) -> Message:
-         """Convert assistant message content to Message object."""
-         if isinstance(msg_content, str):
-             return Message(MessageRole.ASSISTANT, msg_content)
-
-         if isinstance(msg_content, list):
-             content_blocks = []
-             for block in msg_content:
-                 if isinstance(block, dict):
-                     if block.get("type") == "text":
-                         text_content = block.get("text", "")
-                         if text_content:  # Only add non-empty text blocks
-                             content_blocks.append(
-                                 ContentBlock(ContentType.TEXT, text_content)
-                             )
-                     elif block.get("type") == "tool_use":
-                         content_blocks.append(
-                             ContentBlock(
-                                 ContentType.TOOL_USE,
-                                 {
-                                     "id": block["id"],
-                                     "name": block["name"],
-                                     "input": block["input"],
-                                 },
-                             )
-                         )
-             if content_blocks:
-                 return Message(MessageRole.ASSISTANT, content_blocks)
-
-         # Fallback to string representation
-         return Message(MessageRole.ASSISTANT, str(msg_content))
-
-     def _convert_history_to_messages(self) -> list[Message]:
-         """Convert conversation history to Message objects."""
-         messages = []
-         for msg in self.conversation_history:
-             if msg["role"] == "user":
-                 messages.append(self._convert_user_message_to_message(msg["content"]))
-             elif msg["role"] == "assistant":
-                 messages.append(
-                     self._convert_assistant_message_to_message(msg["content"])
-                 )
-         return messages
-
-     def _convert_tool_results_to_message(
-         self, tool_results: list[dict[str, Any]]
-     ) -> Message:
-         """Convert tool results to a user Message object."""
-         tool_result_blocks = []
-         for tool_result in tool_results:
-             tool_result_blocks.append(
-                 ContentBlock(ContentType.TOOL_RESULT, tool_result)
-             )
-         return Message(MessageRole.USER, tool_result_blocks)
-
-     def _convert_response_content_to_message(
-         self, content: list[dict[str, Any]]
-     ) -> Message:
-         """Convert response content to assistant Message object."""
-         content_blocks = []
-         for block in content:
-             if block.get("type") == "text":
-                 text_content = block["text"]
-                 if text_content:  # Only add non-empty text blocks
-                     content_blocks.append(ContentBlock(ContentType.TEXT, text_content))
-             elif block.get("type") == "tool_use":
-                 content_blocks.append(
-                     ContentBlock(
-                         ContentType.TOOL_USE,
-                         {
-                             "id": block["id"],
-                             "name": block["name"],
-                             "input": block["input"],
-                         },
-                     )
-                 )
-         return Message(MessageRole.ASSISTANT, content_blocks)
-
-     async def _execute_and_yield_tool_results(
-         self,
-         response_content: list[dict[str, Any]],
-         cancellation_token: asyncio.Event | None = None,
-     ) -> AsyncIterator[StreamEvent | list[dict[str, Any]]]:
-         """Execute tool calls and yield appropriate stream events."""
-         tool_results = []
-
-         for block in response_content:
-             if block.get("type") == "tool_use":
-                 # Check for cancellation before tool execution
-                 if cancellation_token is not None and cancellation_token.is_set():
-                     yield tool_results
-                     return
-
-                 yield StreamEvent(
-                     "tool_use",
-                     {
-                         "name": block["name"],
-                         "input": block["input"],
-                         "status": "executing",
-                     },
-                 )
-
-                 tool_result = await self.process_tool_call(
-                     block["name"], block["input"]
-                 )
-
-                 # Yield specific events based on tool type
-                 if block["name"] == "execute_sql" and self._last_results:
-                     yield StreamEvent(
-                         "query_result",
-                         {
-                             "query": self._last_query,
-                             "results": self._last_results,
-                         },
-                     )
-                 elif block["name"] in ["list_tables", "introspect_schema"]:
-                     yield StreamEvent(
-                         "tool_result",
-                         {
-                             "tool_name": block["name"],
-                             "result": tool_result,
-                         },
-                     )
-                 elif block["name"] == "plot_data":
-                     yield StreamEvent(
-                         "plot_result",
-                         {
-                             "tool_name": block["name"],
-                             "input": block["input"],
-                             "result": tool_result,
-                         },
-                     )
-
-                 tool_results.append(build_tool_result_block(block["id"], tool_result))
-
-         yield tool_results
-
-     async def _handle_stream_events(
-         self,
-         stream_iterator: AsyncIterator[Any],
-         cancellation_token: asyncio.Event | None = None,
-     ) -> AsyncIterator[StreamEvent | Any]:
-         """Handle streaming events and yield stream events, return final response."""
-         response = None
-
-         async for event in stream_iterator:
-             if cancellation_token is not None and cancellation_token.is_set():
-                 yield None
-                 return
-
-             # Handle different event types
-             if hasattr(event, "type"):
-                 if event.type == "content_block_start":
-                     if hasattr(event.content_block, "type"):
-                         if event.content_block.type == "tool_use":
-                             yield StreamEvent(
-                                 "tool_use",
-                                 {
-                                     "name": event.content_block.name,
-                                     "status": "started",
-                                 },
-                             )
-                 elif event.type == "content_block_delta":
-                     if hasattr(event.delta, "text"):
-                         text = event.delta.text
-                         if text is not None and text:  # Only yield non-empty text
-                             yield StreamEvent("text", text)
-             elif isinstance(event, dict) and event.get("type") == "response_ready":
-                 response = event["data"]
-
-         yield response
-
-     def _create_message_request(self, messages: list[Message]) -> CreateMessageRequest:
-         """Create a CreateMessageRequest with standard parameters."""
-         return CreateMessageRequest(
-             model=self.model,
-             messages=messages,
-             max_tokens=self.MAX_TOKENS,
-             system=self.system_prompt,
-             tools=self.tools,
-             stream=True,
-         )
-
-     async def query_stream(
-         self,
-         user_query: str,
-         use_history: bool = True,
-         cancellation_token: asyncio.Event | None = None,
-     ) -> AsyncIterator[StreamEvent]:
-         """Process a user query and stream responses."""
-         # Initialize for tracking state
-         self._last_results = None
-         self._last_query = None
-
-         try:
-             # Ensure conversation is active for persistence
-             await self._ensure_conversation()
-
-             # Store user message in conversation history and persistence
-             if use_history:
-                 self.conversation_history.append(
-                     {"role": "user", "content": user_query}
-                 )
-                 await self._store_user_message(user_query)
-
-             # Build messages with history if requested
-             messages = []
-             if use_history:
-                 messages = self._convert_history_to_messages()
-
-             # For OAuth with no history, inject SQL assistant instructions as first user message
-             is_oauth = hasattr(self.client, "use_oauth") and self.client.use_oauth
-             if is_oauth and not messages:
-                 instructions = self._get_sql_assistant_instructions()
-                 messages.append(Message(MessageRole.USER, instructions))
-
-             # Add current user message if not already in messages from history
-             if not use_history:
-                 messages.append(Message(MessageRole.USER, user_query))
-
-             # Create initial request and get response
-             request = self._create_message_request(messages)
-             response = None
-
-             async for event in self._handle_stream_events(
-                 self.client.create_message_with_tools(request, cancellation_token),
-                 cancellation_token,
-             ):
-                 if isinstance(event, StreamEvent):
-                     yield event
-                 else:
-                     response = event
-
-             # Handle tool use cycles
-             collected_content = []
-             while response is not None and response.stop_reason == "tool_use":
-                 if cancellation_token is not None and cancellation_token.is_set():
-                     return
-
-                 # Add assistant's response to conversation
-                 assistant_content = {"role": "assistant", "content": response.content}
-                 collected_content.append(assistant_content)
-
-                 # Store the assistant message immediately (not from collected_content)
-                 if use_history:
-                     await self._store_assistant_message(response.content)
-
-                 # Execute tools and get results
-                 tool_results = []
-                 async for event in self._execute_and_yield_tool_results(
-                     response.content, cancellation_token
-                 ):
-                     if isinstance(event, StreamEvent):
-                         yield event
-                     elif isinstance(event, list):
-                         tool_results = event
-
-                 # Continue conversation with tool results
-                 tool_content = {"role": "user", "content": tool_results}
-                 collected_content.append(tool_content)
-
-                 # Store the tool message immediately and update history
-                 if use_history:
-                     # Only add the NEW messages to history (not the accumulated ones)
-                     # collected_content has [assistant1, tool1, assistant2, tool2, ...]
-                     # We only want to add the last 2 items that were just added
-                     new_messages_for_history = collected_content[
-                         -2:
-                     ]  # Last assistant + tool pair
-                     self.conversation_history.extend(new_messages_for_history)
-                     await self._store_tool_message(tool_results)
-
-                 if cancellation_token is not None and cancellation_token.is_set():
-                     return
-
-                 yield StreamEvent("processing", "Analyzing results...")
-
-                 # Build new messages with collected content
-                 new_messages = messages.copy()
-                 for content in collected_content:
-                     if content["role"] == "user":
-                         new_messages.append(
-                             self._convert_tool_results_to_message(content["content"])
-                         )
-                     elif content["role"] == "assistant":
-                         new_messages.append(
-                             self._convert_response_content_to_message(
-                                 content["content"]
-                             )
-                         )
-
-                 # Get next response
-                 request = self._create_message_request(new_messages)
-                 response = None
-
-                 async for event in self._handle_stream_events(
-                     self.client.create_message_with_tools(request, cancellation_token),
-                     cancellation_token,
-                 ):
-                     if isinstance(event, StreamEvent):
-                         yield event
-                     else:
-                         response = event
-
-             # Update conversation history with final response
-             if use_history and response is not None:
-                 self.conversation_history.append(
-                     {"role": "assistant", "content": response.content}
-                 )
-
-                 # Store final assistant message in persistence (only if not tool_use)
-                 if response.stop_reason != "tool_use":
-                     await self._store_assistant_message(response.content)
-
-         except asyncio.CancelledError:
-             return
-         except Exception as e:
-             yield StreamEvent("error", str(e))
-
-     async def close(self):
-         """Close the client."""
-         await self.client.close()
@@ -1,16 +0,0 @@
- """Streaming utilities for agents."""
-
- from typing import Any
-
-
- class StreamingResponse:
-     """Helper class to manage streaming response construction."""
-
-     def __init__(self, content: list[dict[str, Any]], stop_reason: str):
-         self.content = content
-         self.stop_reason = stop_reason
-
-
- def build_tool_result_block(tool_use_id: str, content: str) -> dict[str, Any]:
-     """Build a tool result block for the conversation."""
-     return {"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
@@ -1,6 +0,0 @@
- """Client implementations for various LLM APIs."""
-
- from .base import BaseLLMClient
- from .anthropic import AnthropicClient
-
- __all__ = ["BaseLLMClient", "AnthropicClient"]