sqlsaber-0.6.0-py3-none-any.whl → sqlsaber-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. See the registry's advisory page for this release for more details.

@@ -1,25 +1,34 @@
1
- """Anthropic-specific SQL agent implementation."""
1
+ """Anthropic-specific SQL agent implementation using the custom client."""
2
2
 
3
3
  import asyncio
4
4
  import json
5
- from typing import Any, AsyncIterator, Dict, List
6
-
7
- from anthropic import AsyncAnthropic
5
+ from typing import Any, AsyncIterator
8
6
 
9
7
  from sqlsaber.agents.base import BaseSQLAgent
10
8
  from sqlsaber.agents.streaming import (
11
- StreamingResponse,
12
9
  build_tool_result_block,
13
10
  )
11
+ from sqlsaber.clients import AnthropicClient
12
+ from sqlsaber.clients.models import (
13
+ ContentBlock,
14
+ ContentType,
15
+ CreateMessageRequest,
16
+ Message,
17
+ MessageRole,
18
+ ToolDefinition,
19
+ )
14
20
  from sqlsaber.config.settings import Config
15
21
  from sqlsaber.database.connection import BaseDatabaseConnection
16
22
  from sqlsaber.memory.manager import MemoryManager
17
23
  from sqlsaber.models.events import StreamEvent
18
- from sqlsaber.models.types import ToolDefinition
19
24
 
20
25
 
21
26
  class AnthropicSQLAgent(BaseSQLAgent):
22
- """SQL Agent using Anthropic SDK directly."""
27
+ """SQL Agent using the custom Anthropic client."""
28
+
29
+ # Constants
30
+ MAX_TOKENS = 4096
31
+ DEFAULT_SQL_LIMIT = 100
23
32
 
24
33
  def __init__(
25
34
  self, db_connection: BaseDatabaseConnection, database_name: str | None = None
@@ -27,9 +36,12 @@ class AnthropicSQLAgent(BaseSQLAgent):
27
36
  super().__init__(db_connection)
28
37
 
29
38
  config = Config()
30
- config.validate() # This will raise ValueError if API key is missing
39
+ config.validate() # This will raise ValueError if credentials are missing
31
40
 
32
- self.client = AsyncAnthropic(api_key=config.api_key)
41
+ if config.oauth_token:
42
+ self.client = AnthropicClient(oauth_token=config.oauth_token)
43
+ else:
44
+ self.client = AnthropicClient(api_key=config.api_key)
33
45
  self.model = config.model_name.replace("anthropic:", "")
34
46
 
35
47
  self.database_name = database_name
@@ -39,21 +51,21 @@ class AnthropicSQLAgent(BaseSQLAgent):
39
51
  self._last_results = None
40
52
  self._last_query = None
41
53
 
42
- # Define tools in Anthropic format
43
- self.tools: List[ToolDefinition] = [
44
- {
45
- "name": "list_tables",
46
- "description": "Get a list of all tables in the database with row counts. Use this first to discover available tables.",
47
- "input_schema": {
54
+ # Define tools in the new format
55
+ self.tools: list[ToolDefinition] = [
56
+ ToolDefinition(
57
+ name="list_tables",
58
+ description="Get a list of all tables in the database with row counts. Use this first to discover available tables.",
59
+ input_schema={
48
60
  "type": "object",
49
61
  "properties": {},
50
62
  "required": [],
51
63
  },
52
- },
53
- {
54
- "name": "introspect_schema",
55
- "description": "Introspect database schema to understand table structures.",
56
- "input_schema": {
64
+ ),
65
+ ToolDefinition(
66
+ name="introspect_schema",
67
+ description="Introspect database schema to understand table structures.",
68
+ input_schema={
57
69
  "type": "object",
58
70
  "properties": {
59
71
  "table_pattern": {
@@ -63,11 +75,11 @@ class AnthropicSQLAgent(BaseSQLAgent):
63
75
  },
64
76
  "required": [],
65
77
  },
66
- },
67
- {
68
- "name": "execute_sql",
69
- "description": "Execute a SQL query against the database.",
70
- "input_schema": {
78
+ ),
79
+ ToolDefinition(
80
+ name="execute_sql",
81
+ description="Execute a SQL query against the database.",
82
+ input_schema={
71
83
  "type": "object",
72
84
  "properties": {
73
85
  "query": {
@@ -76,17 +88,17 @@ class AnthropicSQLAgent(BaseSQLAgent):
76
88
  },
77
89
  "limit": {
78
90
  "type": "integer",
79
- "description": "Maximum number of rows to return (default: 100)",
80
- "default": 100,
91
+ "description": f"Maximum number of rows to return (default: {AnthropicSQLAgent.DEFAULT_SQL_LIMIT})",
92
+ "default": AnthropicSQLAgent.DEFAULT_SQL_LIMIT,
81
93
  },
82
94
  },
83
95
  "required": ["query"],
84
96
  },
85
- },
86
- {
87
- "name": "plot_data",
88
- "description": "Create a plot of query results.",
89
- "input_schema": {
97
+ ),
98
+ ToolDefinition(
99
+ name="plot_data",
100
+ description="Create a plot of query results.",
101
+ input_schema={
90
102
  "type": "object",
91
103
  "properties": {
92
104
  "y_values": {
@@ -120,7 +132,7 @@ class AnthropicSQLAgent(BaseSQLAgent):
120
132
  },
121
133
  "required": ["y_values"],
122
134
  },
123
- },
135
+ ),
124
136
  ]
125
137
 
126
138
  # Build system prompt with memories if available
@@ -128,8 +140,24 @@ class AnthropicSQLAgent(BaseSQLAgent):
128
140
 
129
141
  def _build_system_prompt(self) -> str:
130
142
  """Build system prompt with optional memory context."""
143
+ # For OAuth authentication, start with Claude Code identity
144
+ # Check if we're using OAuth by looking at the client
145
+ is_oauth = (
146
+ hasattr(self, "client")
147
+ and hasattr(self.client, "use_oauth")
148
+ and self.client.use_oauth
149
+ )
150
+
151
+ if is_oauth:
152
+ # For OAuth, keep system prompt minimal - just Claude Code identity
153
+ return "You are Claude Code, Anthropic's official CLI for Claude."
154
+ else:
155
+ return self._get_sql_assistant_instructions()
156
+
157
+ def _get_sql_assistant_instructions(self) -> str:
158
+ """Get the detailed SQL assistant instructions."""
131
159
  db_type = self._get_database_type_name()
132
- base_prompt = f"""You are a helpful SQL assistant that helps users query their {db_type} database.
160
+ instructions = f"""You are also a helpful SQL assistant that helps users query their {db_type} database.
133
161
 
134
162
  Your responsibilities:
135
163
  1. Understand user's natural language requests, think and convert them to SQL
@@ -161,9 +189,9 @@ Guidelines:
161
189
  self.database_name
162
190
  )
163
191
  if memory_context.strip():
164
- base_prompt += memory_context
192
+ instructions += memory_context
165
193
 
166
- return base_prompt
194
+ return instructions
167
195
 
168
196
  def add_memory(self, content: str) -> str | None:
169
197
  """Add a memory for the current database."""
@@ -197,83 +225,129 @@ Guidelines:
197
225
  return result
198
226
 
199
227
  async def process_tool_call(
200
- self, tool_name: str, tool_input: Dict[str, Any]
228
+ self, tool_name: str, tool_input: dict[str, Any]
201
229
  ) -> str:
202
230
  """Process a tool call and return the result."""
203
231
  # Use parent implementation for core tools
204
232
  return await super().process_tool_call(tool_name, tool_input)
205
233
 
206
- async def _process_stream_events(
207
- self,
208
- stream,
209
- content_blocks: List[Dict],
210
- tool_use_blocks: List[Dict],
211
- cancellation_token: asyncio.Event | None = None,
212
- ) -> AsyncIterator[StreamEvent]:
213
- """Process stream events and yield appropriate StreamEvents."""
214
- async for event in stream:
215
- # Only check cancellation if token is provided
216
- if cancellation_token is not None and cancellation_token.is_set():
217
- return
234
+ def _convert_user_message_to_message(
235
+ self, msg_content: str | list[dict[str, Any]]
236
+ ) -> Message:
237
+ """Convert user message content to Message object."""
238
+ if isinstance(msg_content, str):
239
+ return Message(MessageRole.USER, msg_content)
240
+
241
+ # Handle tool results format
242
+ tool_result_blocks = []
243
+ if isinstance(msg_content, list):
244
+ for item in msg_content:
245
+ if isinstance(item, dict) and item.get("type") == "tool_result":
246
+ tool_result_blocks.append(
247
+ ContentBlock(ContentType.TOOL_RESULT, item)
248
+ )
218
249
 
219
- if event.type == "content_block_start":
220
- if hasattr(event.content_block, "type"):
221
- if event.content_block.type == "tool_use":
222
- yield StreamEvent(
223
- "tool_use",
224
- {"name": event.content_block.name, "status": "started"},
250
+ if tool_result_blocks:
251
+ return Message(MessageRole.USER, tool_result_blocks)
252
+
253
+ # Fallback to string representation
254
+ return Message(MessageRole.USER, str(msg_content))
255
+
256
+ def _convert_assistant_message_to_message(
257
+ self, msg_content: str | list[dict[str, Any]]
258
+ ) -> Message:
259
+ """Convert assistant message content to Message object."""
260
+ if isinstance(msg_content, str):
261
+ return Message(MessageRole.ASSISTANT, msg_content)
262
+
263
+ if isinstance(msg_content, list):
264
+ content_blocks = []
265
+ for block in msg_content:
266
+ if isinstance(block, dict):
267
+ if block.get("type") == "text":
268
+ text_content = block.get("text", "")
269
+ if text_content: # Only add non-empty text blocks
270
+ content_blocks.append(
271
+ ContentBlock(ContentType.TEXT, text_content)
272
+ )
273
+ elif block.get("type") == "tool_use":
274
+ content_blocks.append(
275
+ ContentBlock(
276
+ ContentType.TOOL_USE,
277
+ {
278
+ "id": block["id"],
279
+ "name": block["name"],
280
+ "input": block["input"],
281
+ },
282
+ )
225
283
  )
226
- tool_use_blocks.append(
227
- {
228
- "id": event.content_block.id,
229
- "name": event.content_block.name,
230
- "input": {},
231
- }
232
- )
233
- elif event.content_block.type == "text":
234
- content_blocks.append({"type": "text", "text": ""})
235
-
236
- elif event.type == "content_block_delta":
237
- if hasattr(event.delta, "text"):
238
- yield StreamEvent("text", event.delta.text)
239
- if content_blocks and content_blocks[-1]["type"] == "text":
240
- content_blocks[-1]["text"] += event.delta.text
241
- elif hasattr(event.delta, "partial_json"):
242
- if tool_use_blocks:
243
- try:
244
- current_json = tool_use_blocks[-1].get("_partial", "")
245
- current_json += event.delta.partial_json
246
- tool_use_blocks[-1]["_partial"] = current_json
247
- tool_use_blocks[-1]["input"] = json.loads(current_json)
248
- except json.JSONDecodeError:
249
- pass
250
-
251
- elif event.type == "message_stop":
252
- break
253
-
254
- def _finalize_tool_blocks(self, tool_use_blocks: List[Dict]) -> str:
255
- """Finalize tool use blocks and return stop reason."""
256
- if tool_use_blocks:
257
- for block in tool_use_blocks:
258
- block["type"] = "tool_use"
259
- if "_partial" in block:
260
- del block["_partial"]
261
- return "tool_use"
262
- return "stop"
263
-
264
- async def _process_tool_results(
284
+ if content_blocks:
285
+ return Message(MessageRole.ASSISTANT, content_blocks)
286
+
287
+ # Fallback to string representation
288
+ return Message(MessageRole.ASSISTANT, str(msg_content))
289
+
290
+ def _convert_history_to_messages(self) -> list[Message]:
291
+ """Convert conversation history to Message objects."""
292
+ messages = []
293
+ for msg in self.conversation_history:
294
+ if msg["role"] == "user":
295
+ messages.append(self._convert_user_message_to_message(msg["content"]))
296
+ elif msg["role"] == "assistant":
297
+ messages.append(
298
+ self._convert_assistant_message_to_message(msg["content"])
299
+ )
300
+ return messages
301
+
302
+ def _convert_tool_results_to_message(
303
+ self, tool_results: list[dict[str, Any]]
304
+ ) -> Message:
305
+ """Convert tool results to a user Message object."""
306
+ tool_result_blocks = []
307
+ for tool_result in tool_results:
308
+ tool_result_blocks.append(
309
+ ContentBlock(ContentType.TOOL_RESULT, tool_result)
310
+ )
311
+ return Message(MessageRole.USER, tool_result_blocks)
312
+
313
+ def _convert_response_content_to_message(
314
+ self, content: list[dict[str, Any]]
315
+ ) -> Message:
316
+ """Convert response content to assistant Message object."""
317
+ content_blocks = []
318
+ for block in content:
319
+ if block.get("type") == "text":
320
+ text_content = block["text"]
321
+ if text_content: # Only add non-empty text blocks
322
+ content_blocks.append(ContentBlock(ContentType.TEXT, text_content))
323
+ elif block.get("type") == "tool_use":
324
+ content_blocks.append(
325
+ ContentBlock(
326
+ ContentType.TOOL_USE,
327
+ {
328
+ "id": block["id"],
329
+ "name": block["name"],
330
+ "input": block["input"],
331
+ },
332
+ )
333
+ )
334
+ return Message(MessageRole.ASSISTANT, content_blocks)
335
+
336
+ async def _execute_and_yield_tool_results(
265
337
  self,
266
- response: StreamingResponse,
338
+ response_content: list[dict[str, Any]],
267
339
  cancellation_token: asyncio.Event | None = None,
268
- ) -> AsyncIterator[StreamEvent]:
269
- """Process tool results and yield appropriate events."""
340
+ ) -> AsyncIterator[StreamEvent | list[dict[str, Any]]]:
341
+ """Execute tool calls and yield appropriate stream events."""
270
342
  tool_results = []
271
- for block in response.content:
272
- # Only check cancellation if token is provided
273
- if cancellation_token is not None and cancellation_token.is_set():
274
- return
275
343
 
344
+ for block in response_content:
276
345
  if block.get("type") == "tool_use":
346
+ # Check for cancellation before tool execution
347
+ if cancellation_token is not None and cancellation_token.is_set():
348
+ yield tool_results
349
+ return
350
+
277
351
  yield StreamEvent(
278
352
  "tool_use",
279
353
  {
@@ -316,7 +390,53 @@ Guidelines:
316
390
 
317
391
  tool_results.append(build_tool_result_block(block["id"], tool_result))
318
392
 
319
- yield StreamEvent("tool_result_data", tool_results)
393
+ yield tool_results
394
+
395
+ async def _handle_stream_events(
396
+ self,
397
+ stream_iterator: AsyncIterator[Any],
398
+ cancellation_token: asyncio.Event | None = None,
399
+ ) -> AsyncIterator[StreamEvent | Any]:
400
+ """Handle streaming events and yield stream events, return final response."""
401
+ response = None
402
+
403
+ async for event in stream_iterator:
404
+ if cancellation_token is not None and cancellation_token.is_set():
405
+ yield None
406
+ return
407
+
408
+ # Handle different event types
409
+ if hasattr(event, "type"):
410
+ if event.type == "content_block_start":
411
+ if hasattr(event.content_block, "type"):
412
+ if event.content_block.type == "tool_use":
413
+ yield StreamEvent(
414
+ "tool_use",
415
+ {
416
+ "name": event.content_block.name,
417
+ "status": "started",
418
+ },
419
+ )
420
+ elif event.type == "content_block_delta":
421
+ if hasattr(event.delta, "text"):
422
+ text = event.delta.text
423
+ if text is not None and text: # Only yield non-empty text
424
+ yield StreamEvent("text", text)
425
+ elif isinstance(event, dict) and event.get("type") == "response_ready":
426
+ response = event["data"]
427
+
428
+ yield response
429
+
430
+ def _create_message_request(self, messages: list[Message]) -> CreateMessageRequest:
431
+ """Create a CreateMessageRequest with standard parameters."""
432
+ return CreateMessageRequest(
433
+ model=self.model,
434
+ messages=messages,
435
+ max_tokens=self.MAX_TOKENS,
436
+ system=self.system_prompt,
437
+ tools=self.tools,
438
+ stream=True,
439
+ )
320
440
 
321
441
  async def query_stream(
322
442
  self,
@@ -329,32 +449,37 @@ Guidelines:
329
449
  self._last_results = None
330
450
  self._last_query = None
331
451
 
332
- # Build messages with history if requested
333
- if use_history:
334
- messages = self.conversation_history + [
335
- {"role": "user", "content": user_query}
336
- ]
337
- else:
338
- messages = [{"role": "user", "content": user_query}]
339
-
340
452
  try:
341
- # Create initial stream and get response
453
+ # Build messages with history if requested
454
+ messages = []
455
+ if use_history:
456
+ messages = self._convert_history_to_messages()
457
+
458
+ # For OAuth with no history, inject SQL assistant instructions as first user message
459
+ is_oauth = hasattr(self.client, "use_oauth") and self.client.use_oauth
460
+ if is_oauth and not messages:
461
+ instructions = self._get_sql_assistant_instructions()
462
+ messages.append(Message(MessageRole.USER, instructions))
463
+
464
+ # Add current user message
465
+ messages.append(Message(MessageRole.USER, user_query))
466
+
467
+ # Create initial request and get response
468
+ request = self._create_message_request(messages)
342
469
  response = None
343
- async for event in self._create_and_process_stream(
344
- messages, cancellation_token
470
+
471
+ async for event in self._handle_stream_events(
472
+ self.client.create_message_with_tools(request, cancellation_token),
473
+ cancellation_token,
345
474
  ):
346
- if cancellation_token is not None and cancellation_token.is_set():
347
- return
348
- if event.type == "response_ready":
349
- response = event.data
350
- else:
475
+ if isinstance(event, StreamEvent):
351
476
  yield event
477
+ else:
478
+ response = event
352
479
 
480
+ # Handle tool use cycles
353
481
  collected_content = []
354
-
355
- # Process tool calls if needed
356
482
  while response is not None and response.stop_reason == "tool_use":
357
- # Check for cancellation at the start of tool cycle
358
483
  if cancellation_token is not None and cancellation_token.is_set():
359
484
  return
360
485
 
@@ -363,82 +488,64 @@ Guidelines:
363
488
  {"role": "assistant", "content": response.content}
364
489
  )
365
490
 
366
- # Process tool results - DO NOT check cancellation during tool execution
367
- # as this would break the tool_use -> tool_result API contract
491
+ # Execute tools and get results
368
492
  tool_results = []
369
- async for event in self._process_tool_results(
370
- response, None
371
- ): # Pass None to disable cancellation checks
372
- if event.type == "tool_result_data":
373
- tool_results = event.data
374
- else:
493
+ async for event in self._execute_and_yield_tool_results(
494
+ response.content, cancellation_token
495
+ ):
496
+ if isinstance(event, StreamEvent):
375
497
  yield event
498
+ elif isinstance(event, list):
499
+ tool_results = event
376
500
 
377
501
  # Continue conversation with tool results
378
502
  collected_content.append({"role": "user", "content": tool_results})
379
503
  if use_history:
380
504
  self.conversation_history.extend(collected_content)
381
505
 
382
- # Check for cancellation AFTER tool results are complete
383
506
  if cancellation_token is not None and cancellation_token.is_set():
384
507
  return
385
508
 
386
- # Signal that we're processing the tool results
387
509
  yield StreamEvent("processing", "Analyzing results...")
388
510
 
511
+ # Build new messages with collected content
512
+ new_messages = messages.copy()
513
+ for content in collected_content:
514
+ if content["role"] == "user":
515
+ new_messages.append(
516
+ self._convert_tool_results_to_message(content["content"])
517
+ )
518
+ elif content["role"] == "assistant":
519
+ new_messages.append(
520
+ self._convert_response_content_to_message(
521
+ content["content"]
522
+ )
523
+ )
524
+
389
525
  # Get next response
526
+ request = self._create_message_request(new_messages)
390
527
  response = None
391
- async for event in self._create_and_process_stream(
392
- messages + collected_content, cancellation_token
528
+
529
+ async for event in self._handle_stream_events(
530
+ self.client.create_message_with_tools(request, cancellation_token),
531
+ cancellation_token,
393
532
  ):
394
- if cancellation_token is not None and cancellation_token.is_set():
395
- return
396
- if event.type == "response_ready":
397
- response = event.data
398
- else:
533
+ if isinstance(event, StreamEvent):
399
534
  yield event
535
+ else:
536
+ response = event
400
537
 
401
- # Update conversation history if using history
402
- if use_history:
403
- # Add final assistant response
404
- if response is not None:
405
- self.conversation_history.append(
406
- {"role": "assistant", "content": response.content}
407
- )
538
+ # Update conversation history with final response
539
+ if use_history and response is not None:
540
+ self.conversation_history.append(
541
+ {"role": "assistant", "content": response.content}
542
+ )
408
543
 
409
544
  except asyncio.CancelledError:
410
545
  return
411
546
  except Exception as e:
412
547
  yield StreamEvent("error", str(e))
413
548
 
414
- async def _create_and_process_stream(
415
- self, messages: List[Dict], cancellation_token: asyncio.Event | None = None
416
- ) -> AsyncIterator[StreamEvent]:
417
- """Create a stream and yield events while building response."""
418
- stream = await self.client.messages.create(
419
- model=self.model,
420
- max_tokens=4096,
421
- system=self.system_prompt,
422
- messages=messages,
423
- tools=self.tools,
424
- stream=True,
425
- )
426
-
427
- content_blocks = []
428
- tool_use_blocks = []
429
-
430
- async for event in self._process_stream_events(
431
- stream, content_blocks, tool_use_blocks, cancellation_token
432
- ):
433
- # Only check cancellation if token is provided
434
- if cancellation_token is not None and cancellation_token.is_set():
435
- return
436
- yield event
437
-
438
- # Finalize tool blocks and create response
439
- stop_reason = self._finalize_tool_blocks(tool_use_blocks)
440
- content_blocks.extend(tool_use_blocks)
441
-
442
- yield StreamEvent(
443
- "response_ready", StreamingResponse(content_blocks, stop_reason)
444
- )
549
+ async def close(self):
550
+ """Close the client."""
551
+ await self.client.close()