sqlsaber 0.14.0__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlsaber might be problematic. Click here for more details.

Files changed (38) hide show
  1. sqlsaber/agents/__init__.py +2 -4
  2. sqlsaber/agents/base.py +18 -221
  3. sqlsaber/agents/mcp.py +2 -2
  4. sqlsaber/agents/pydantic_ai_agent.py +170 -0
  5. sqlsaber/cli/auth.py +146 -79
  6. sqlsaber/cli/commands.py +22 -7
  7. sqlsaber/cli/database.py +1 -1
  8. sqlsaber/cli/interactive.py +65 -30
  9. sqlsaber/cli/models.py +58 -29
  10. sqlsaber/cli/streaming.py +114 -77
  11. sqlsaber/config/api_keys.py +9 -11
  12. sqlsaber/config/providers.py +116 -0
  13. sqlsaber/config/settings.py +50 -30
  14. sqlsaber/database/connection.py +3 -3
  15. sqlsaber/mcp/mcp.py +43 -51
  16. sqlsaber/models/__init__.py +0 -3
  17. sqlsaber/tools/__init__.py +25 -0
  18. sqlsaber/tools/base.py +85 -0
  19. sqlsaber/tools/enums.py +21 -0
  20. sqlsaber/tools/instructions.py +251 -0
  21. sqlsaber/tools/registry.py +130 -0
  22. sqlsaber/tools/sql_tools.py +275 -0
  23. sqlsaber/tools/visualization_tools.py +144 -0
  24. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/METADATA +20 -39
  25. sqlsaber-0.16.0.dist-info/RECORD +51 -0
  26. sqlsaber/agents/anthropic.py +0 -579
  27. sqlsaber/agents/streaming.py +0 -16
  28. sqlsaber/clients/__init__.py +0 -6
  29. sqlsaber/clients/anthropic.py +0 -285
  30. sqlsaber/clients/base.py +0 -31
  31. sqlsaber/clients/exceptions.py +0 -117
  32. sqlsaber/clients/models.py +0 -282
  33. sqlsaber/clients/streaming.py +0 -257
  34. sqlsaber/models/events.py +0 -28
  35. sqlsaber-0.14.0.dist-info/RECORD +0 -51
  36. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/WHEEL +0 -0
  37. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/entry_points.txt +0 -0
  38. {sqlsaber-0.14.0.dist-info → sqlsaber-0.16.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/models.py CHANGED
@@ -3,12 +3,13 @@
3
3
  import asyncio
4
4
  import sys
5
5
 
6
+ import cyclopts
6
7
  import httpx
7
8
  import questionary
8
- import cyclopts
9
9
  from rich.console import Console
10
10
  from rich.table import Table
11
11
 
12
+ from sqlsaber.config import providers
12
13
  from sqlsaber.config.settings import Config
13
14
 
14
15
  # Global instances for CLI commands
@@ -26,49 +27,75 @@ class ModelManager:
26
27
 
27
28
  DEFAULT_MODEL = "anthropic:claude-sonnet-4-20250514"
28
29
  MODELS_API_URL = "https://models.dev/api.json"
30
+ # Providers come from central registry
31
+ SUPPORTED_PROVIDERS = providers.all_keys()
32
+
33
+ async def fetch_available_models(
34
+ self, providers: list[str] | None = None
35
+ ) -> list[dict]:
36
+ """Fetch available models across providers from models.dev API.
29
37
 
30
- async def fetch_available_models(self) -> list[dict]:
31
- """Fetch available models from models.dev API."""
38
+ Returns list of dicts with keys: id (provider:model_id), provider, name, description, context_length, knowledge.
39
+ """
32
40
  try:
33
41
  async with httpx.AsyncClient(timeout=10.0) as client:
34
42
  response = await client.get(self.MODELS_API_URL)
35
43
  response.raise_for_status()
36
44
  data = response.json()
37
45
 
38
- # Filter for Anthropic models only
39
- anthropic_models = []
40
- anthropic_data = data.get("anthropic", {})
41
-
42
- if "models" in anthropic_data:
43
- for model_id, model_info in anthropic_data["models"].items():
44
- # Convert to our format (anthropic:model-name)
45
- formatted_id = f"anthropic:{model_id}"
46
-
47
- # Extract cost information for display
48
- cost_info = model_info.get("cost", {})
46
+ providers = providers or self.SUPPORTED_PROVIDERS
47
+ results: list[dict] = []
48
+
49
+ for provider in providers:
50
+ prov_data = data.get(provider, {})
51
+ models_obj = (
52
+ prov_data.get("models") or prov_data.get("Models") or {}
53
+ )
54
+ if not isinstance(models_obj, dict):
55
+ continue
56
+ for model_id, model_info in models_obj.items():
57
+ formatted_id = f"{provider}:{model_id}"
58
+ # cost
59
+ cost_info = (
60
+ model_info.get("cost", {})
61
+ if isinstance(model_info, dict)
62
+ else {}
63
+ )
49
64
  cost_display = ""
50
- if cost_info:
65
+ if isinstance(cost_info, dict) and cost_info:
51
66
  input_cost = cost_info.get("input", 0)
52
67
  output_cost = cost_info.get("output", 0)
53
68
  cost_display = f"${input_cost}/{output_cost} per 1M tokens"
69
+ # context
70
+ limit_info = (
71
+ model_info.get("limit", {})
72
+ if isinstance(model_info, dict)
73
+ else {}
74
+ )
75
+ context_length = (
76
+ limit_info.get("context", 0)
77
+ if isinstance(limit_info, dict)
78
+ else 0
79
+ )
54
80
 
55
- # Extract context length
56
- limit_info = model_info.get("limit", {})
57
- context_length = limit_info.get("context", 0)
58
-
59
- anthropic_models.append(
81
+ results.append(
60
82
  {
61
83
  "id": formatted_id,
62
- "name": model_info.get("name", model_id),
84
+ "provider": provider,
85
+ "name": model_info.get("name", model_id)
86
+ if isinstance(model_info, dict)
87
+ else model_id,
63
88
  "description": cost_display,
64
89
  "context_length": context_length,
65
- "knowledge": model_info.get("knowledge", ""),
90
+ "knowledge": model_info.get("knowledge", "")
91
+ if isinstance(model_info, dict)
92
+ else "",
66
93
  }
67
94
  )
68
95
 
69
- # Sort by name for better display
70
- anthropic_models.sort(key=lambda x: x["name"])
71
- return anthropic_models
96
+ # Sort by provider then by name
97
+ results.sort(key=lambda x: (x["provider"], x["name"]))
98
+ return results
72
99
  except Exception as e:
73
100
  console.print(f"[red]Error fetching models: {e}[/red]")
74
101
  return []
@@ -110,7 +137,8 @@ def list():
110
137
  )
111
138
  return
112
139
 
113
- table = Table(title="Available Anthropic Models")
140
+ table = Table(title="Available Models")
141
+ table.add_column("Provider", style="magenta")
114
142
  table.add_column("ID", style="cyan")
115
143
  table.add_column("Name", style="green")
116
144
  table.add_column("Description", style="white")
@@ -133,6 +161,7 @@ def list():
133
161
  )
134
162
 
135
163
  table.add_row(
164
+ model.get("provider", "-"),
136
165
  model["id"],
137
166
  model["name"],
138
167
  description,
@@ -161,8 +190,9 @@ def set():
161
190
  # Create choices for questionary
162
191
  choices = []
163
192
  for model in models:
164
- # Format: "ID - Name (Description)"
165
- choice_text = f"{model['id']} - {model['name']}"
193
+ # Format: "[provider] ID - Name (Description)"
194
+ prov = model.get("provider", "?")
195
+ choice_text = f"[{prov}] {model['id']} - {model['name']}"
166
196
  if model["description"]:
167
197
  choice_text += f" ({model['description'][:50]}{'...' if len(model['description']) > 50 else ''})"
168
198
 
@@ -179,7 +209,6 @@ def set():
179
209
  selected_model = await questionary.select(
180
210
  "Select a model:",
181
211
  choices=choices,
182
- use_shortcuts=True,
183
212
  use_search_filter=True,
184
213
  use_jk_keys=False, # Disable j/k keys when using search filter
185
214
  default=choices[default_index] if choices else None,
sqlsaber/cli/streaming.py CHANGED
@@ -1,100 +1,137 @@
1
- """Streaming query handling for the CLI."""
1
+ """Streaming query handling for the CLI (pydantic-ai based)."""
2
2
 
3
3
  import asyncio
4
-
4
+ import json
5
+ from typing import AsyncIterable
6
+
7
+ from pydantic_ai import Agent, RunContext
8
+ from pydantic_ai.messages import (
9
+ AgentStreamEvent,
10
+ FunctionToolCallEvent,
11
+ FunctionToolResultEvent,
12
+ PartDeltaEvent,
13
+ PartStartEvent,
14
+ TextPart,
15
+ TextPartDelta,
16
+ ThinkingPart,
17
+ ThinkingPartDelta,
18
+ )
5
19
  from rich.console import Console
6
20
 
7
- from sqlsaber.agents.base import BaseSQLAgent
8
21
  from sqlsaber.cli.display import DisplayManager
9
22
 
10
23
 
11
24
  class StreamingQueryHandler:
12
- """Handles streaming query execution and display."""
25
+ """Handles streaming query execution and display using pydantic-ai events."""
13
26
 
14
27
  def __init__(self, console: Console):
15
28
  self.console = console
16
29
  self.display = DisplayManager(console)
17
30
 
31
+ self.status = self.console.status(
32
+ "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
33
+ )
34
+
35
+ async def _event_stream_handler(
36
+ self, ctx: RunContext, event_stream: AsyncIterable[AgentStreamEvent]
37
+ ) -> None:
38
+ async for event in event_stream:
39
+ if isinstance(event, PartStartEvent):
40
+ if isinstance(event.part, (TextPart, ThinkingPart)):
41
+ self.status.stop()
42
+ self.display.show_text_stream(event.part.content)
43
+
44
+ elif isinstance(event, PartDeltaEvent):
45
+ if isinstance(event.delta, (TextPartDelta, ThinkingPartDelta)):
46
+ delta = event.delta.content_delta or ""
47
+ if delta:
48
+ self.status.stop()
49
+ self.display.show_text_stream(delta)
50
+
51
+ elif isinstance(event, FunctionToolCallEvent):
52
+ # Show tool execution start
53
+ self.status.stop()
54
+ args = event.part.args_as_dict()
55
+ self.display.show_newline()
56
+ self.display.show_tool_executing(event.part.tool_name, args)
57
+
58
+ elif isinstance(event, FunctionToolResultEvent):
59
+ self.status.stop()
60
+ # Route tool result to appropriate display
61
+ tool_name = event.result.tool_name
62
+ content = event.result.content
63
+ if tool_name == "list_tables":
64
+ self.display.show_table_list(content)
65
+ elif tool_name == "introspect_schema":
66
+ self.display.show_schema_info(content)
67
+ elif tool_name == "execute_sql":
68
+ try:
69
+ data = json.loads(content)
70
+ if data.get("success") and data.get("results"):
71
+ self.display.show_query_results(data["results"]) # type: ignore[arg-type]
72
+ except json.JSONDecodeError:
73
+ # If not JSON, ignore here
74
+ pass
75
+ elif tool_name == "plot_data":
76
+ self.display.show_plot(
77
+ {"tool_name": tool_name, "result": content, "input": {}}
78
+ )
79
+
18
80
  async def execute_streaming_query(
19
81
  self,
20
82
  user_query: str,
21
- agent: BaseSQLAgent,
83
+ agent: Agent,
22
84
  cancellation_token: asyncio.Event | None = None,
85
+ message_history: list | None = None,
23
86
  ):
24
- """Execute a query with streaming display."""
25
-
26
- status = self.console.status(
27
- "[yellow]Crunching data...[/yellow]", spinner="bouncingBall"
28
- )
29
- status.start()
30
-
87
+ self.status.start()
31
88
  try:
32
- async for event in agent.query_stream(
33
- user_query, cancellation_token=cancellation_token
34
- ):
35
- if cancellation_token is not None and cancellation_token.is_set():
36
- break
37
-
38
- if event.type == "tool_use":
39
- self._stop_status(status)
40
-
41
- if event.data["status"] == "executing":
42
- self.display.show_newline()
43
- self.display.show_tool_executing(
44
- event.data["name"], event.data["input"]
45
- )
46
-
47
- elif event.type == "text":
48
- # Always stop status when text streaming starts
49
- self._stop_status(status)
50
- self.display.show_text_stream(event.data)
51
-
52
- elif event.type == "query_result":
53
- if event.data["results"]:
54
- self.display.show_query_results(event.data["results"])
55
-
56
- elif event.type == "tool_result":
57
- # Handle tool results - particularly list_tables and introspect_schema
58
- if event.data.get("tool_name") == "list_tables":
59
- self.display.show_table_list(event.data["result"])
60
- elif event.data.get("tool_name") == "introspect_schema":
61
- self.display.show_schema_info(event.data["result"])
62
-
63
- elif event.type == "plot_result":
64
- # Handle plot results
65
- self.display.show_plot(event.data)
66
-
67
- elif event.type == "processing":
68
- self.display.show_newline() # Add newline after explanation text
69
- self._stop_status(status)
70
- status = self.display.show_processing(event.data)
71
- status.start()
72
-
73
- elif event.type == "error":
74
- self._stop_status(status)
75
- self.display.show_error(event.data)
76
-
89
+ # If Anthropic OAuth, inject SQLsaber instructions before the first user prompt
90
+ prepared_prompt: str | list[str] = user_query
91
+ is_oauth = bool(getattr(agent, "_sqlsaber_is_oauth", False))
92
+ no_history = not message_history
93
+ if is_oauth and no_history:
94
+ ib = getattr(agent, "_sqlsaber_instruction_builder", None)
95
+ mm = getattr(agent, "_sqlsaber_memory_manager", None)
96
+ db_type = getattr(agent, "_sqlsaber_db_type", "database")
97
+ db_name = getattr(agent, "_sqlsaber_database_name", None)
98
+ instructions = (
99
+ ib.build_instructions(db_type=db_type) if ib is not None else ""
100
+ )
101
+ mem = (
102
+ mm.format_memories_for_prompt(db_name)
103
+ if (mm is not None and db_name)
104
+ else ""
105
+ )
106
+ parts = [p for p in (instructions, mem) if p and str(p).strip()]
107
+ if parts:
108
+ injected = "\n\n".join(parts)
109
+ prepared_prompt = [injected, user_query]
110
+
111
+ # Run the agent with our event stream handler
112
+ run = await agent.run(
113
+ prepared_prompt,
114
+ message_history=message_history,
115
+ event_stream_handler=self._event_stream_handler,
116
+ )
117
+ # After the run completes, show the assistant's final text as markdown if available
118
+ try:
119
+ output = run.output
120
+ if isinstance(output, str) and output.strip():
121
+ self.display.show_newline()
122
+ self.display.show_markdown_response(
123
+ [{"type": "text", "text": output}]
124
+ )
125
+ except Exception as e:
126
+ self.display.show_error(str(e))
127
+ self.display.show_newline()
128
+ return run
77
129
  except asyncio.CancelledError:
78
- # Handle cancellation gracefully
79
- self._stop_status(status)
80
130
  self.display.show_newline()
81
131
  self.console.print("[yellow]Query interrupted[/yellow]")
82
- return
132
+ return None
83
133
  finally:
84
- # Make sure status is stopped
85
- self._stop_status(status)
86
-
87
- # Display the last assistant response as markdown
88
- if hasattr(agent, "conversation_history") and agent.conversation_history:
89
- last_message = agent.conversation_history[-1]
90
- if last_message.get("role") == "assistant" and last_message.get(
91
- "content"
92
- ):
93
- self.display.show_markdown_response(last_message["content"])
94
-
95
- def _stop_status(self, status):
96
- """Safely stop a status spinner."""
97
- try:
98
- status.stop()
99
- except Exception:
100
- pass # Status might already be stopped
134
+ try:
135
+ self.status.stop()
136
+ except Exception:
137
+ pass
@@ -6,6 +6,8 @@ import os
6
6
  import keyring
7
7
  from rich.console import Console
8
8
 
9
+ from sqlsaber.config import providers
10
+
9
11
  console = Console()
10
12
 
11
13
 
@@ -30,9 +32,7 @@ class APIKeyManager:
30
32
  try:
31
33
  api_key = keyring.get_password(service_name, provider)
32
34
  if api_key:
33
- console.print(
34
- f"Using stored {provider} API key from keyring", style="dim"
35
- )
35
+ console.print(f"Using stored {provider} API key", style="dim")
36
36
  return api_key
37
37
  except Exception as e:
38
38
  # Keyring access failed, continue to prompt
@@ -43,12 +43,9 @@ class APIKeyManager:
43
43
 
44
44
  def _get_env_var_name(self, provider: str) -> str:
45
45
  """Get the expected environment variable name for a provider."""
46
- if provider == "openai":
47
- return "OPENAI_API_KEY"
48
- elif provider == "anthropic":
49
- return "ANTHROPIC_API_KEY"
50
- else:
51
- return "AI_API_KEY"
46
+ # Normalize aliases to canonical provider keys
47
+ key = providers.canonical(provider) or provider
48
+ return providers.env_var_name(key)
52
49
 
53
50
  def _get_service_name(self, provider: str) -> str:
54
51
  """Get the keyring service name for a provider."""
@@ -60,7 +57,7 @@ class APIKeyManager:
60
57
  """Prompt user for API key and store it in keyring."""
61
58
  try:
62
59
  console.print(
63
- f"\n{provider.title()} API key not found in environment or keyring."
60
+ f"\n{provider.title()} API key not found in environment or your OS's credentials store."
64
61
  )
65
62
  console.print("You can either:")
66
63
  console.print(f" 1. Set the {env_var_name} environment variable")
@@ -85,7 +82,8 @@ class APIKeyManager:
85
82
  console.print("API key stored securely for future use", style="green")
86
83
  except Exception as e:
87
84
  console.print(
88
- f"Warning: Could not store API key in keyring: {e}", style="yellow"
85
+ f"Warning: Could not store API key in your operating system's credentials store: {e}",
86
+ style="yellow",
89
87
  )
90
88
  console.print(
91
89
  "You may need to enter it again next time", style="yellow"
@@ -0,0 +1,116 @@
1
+ """Central registry for supported AI providers.
2
+
3
+ This module defines a single source of truth for providers used across the
4
+ codebase (CLI, config, agents). Update this file to add or modify providers.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Iterable, List, Optional
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class ProviderSpec:
15
+ """Specification for a provider."""
16
+
17
+ key: str
18
+ env_var: str
19
+ supports_oauth: bool = False
20
+ aliases: tuple[str, ...] = ()
21
+
22
+
23
+ # Ordered definition -> used for CLI display order
24
+ _PROVIDERS: List[ProviderSpec] = [
25
+ ProviderSpec(
26
+ key="anthropic",
27
+ env_var="ANTHROPIC_API_KEY",
28
+ supports_oauth=True,
29
+ aliases=(),
30
+ ),
31
+ ProviderSpec(
32
+ key="openai",
33
+ env_var="OPENAI_API_KEY",
34
+ aliases=(),
35
+ ),
36
+ ProviderSpec(
37
+ key="google",
38
+ env_var="GOOGLE_API_KEY",
39
+ # Historically some model IDs start with "google-gla"; treat as alias
40
+ aliases=("google-gla",),
41
+ ),
42
+ ProviderSpec(
43
+ key="groq",
44
+ env_var="GROQ_API_KEY",
45
+ aliases=(),
46
+ ),
47
+ ProviderSpec(
48
+ key="mistral",
49
+ env_var="MISTRAL_API_KEY",
50
+ aliases=(),
51
+ ),
52
+ ProviderSpec(
53
+ key="cohere",
54
+ env_var="COHERE_API_KEY",
55
+ aliases=(),
56
+ ),
57
+ ProviderSpec(
58
+ key="huggingface",
59
+ env_var="HUGGINGFACE_API_KEY",
60
+ aliases=(),
61
+ ),
62
+ ]
63
+
64
+
65
+ # Fast lookup maps
66
+ _BY_KEY: Dict[str, ProviderSpec] = {p.key: p for p in _PROVIDERS}
67
+ _ALIAS_TO_KEY: Dict[str, str] = {
68
+ alias: p.key for p in _PROVIDERS for alias in p.aliases
69
+ }
70
+
71
+
72
+ def all_keys() -> List[str]:
73
+ """Return provider keys in display order."""
74
+ return [p.key for p in _PROVIDERS]
75
+
76
+
77
+ def env_var_name(key: str) -> str:
78
+ """Return the expected environment variable for a provider.
79
+
80
+ Falls back to a generic name if the provider is unknown.
81
+ """
82
+ spec = _BY_KEY.get(key)
83
+ return spec.env_var if spec else "AI_API_KEY"
84
+
85
+
86
+ def supports_oauth(key: str) -> bool:
87
+ """Return True if the provider supports OAuth in SQLsaber."""
88
+ spec = _BY_KEY.get(key)
89
+ return bool(spec and spec.supports_oauth)
90
+
91
+
92
+ def canonical(key_or_alias: str) -> Optional[str]:
93
+ """Return the canonical provider key for a provider or alias.
94
+
95
+ Returns None if not recognized.
96
+ """
97
+ if key_or_alias in _BY_KEY:
98
+ return key_or_alias
99
+ return _ALIAS_TO_KEY.get(key_or_alias)
100
+
101
+
102
+ def provider_from_model(model_name: str) -> Optional[str]:
103
+ """Infer the canonical provider key from a model identifier.
104
+
105
+ Accepts either "provider:model_id" or a bare provider string. Aliases are
106
+ normalized to their canonical provider key.
107
+ """
108
+ if not model_name:
109
+ return None
110
+ provider_raw = model_name.split(":", 1)[0]
111
+ return canonical(provider_raw)
112
+
113
+
114
+ def specs() -> Iterable[ProviderSpec]:
115
+ """Iterate provider specifications (in display order)."""
116
+ return tuple(_PROVIDERS)
@@ -9,6 +9,7 @@ from typing import Any
9
9
 
10
10
  import platformdirs
11
11
 
12
+ from sqlsaber.config import providers
12
13
  from sqlsaber.config.api_keys import APIKeyManager
13
14
  from sqlsaber.config.auth import AuthConfigManager, AuthMethod
14
15
  from sqlsaber.config.oauth_flow import AnthropicOAuthFlow
@@ -84,47 +85,66 @@ class Config:
84
85
  self.model_name = self.model_config_manager.get_model()
85
86
  self.api_key_manager = APIKeyManager()
86
87
  self.auth_config_manager = AuthConfigManager()
87
- self.oauth_flow = AnthropicOAuthFlow()
88
88
 
89
- # Get authentication credentials based on configured method
89
+ # Authentication method (API key or Anthropic OAuth)
90
90
  self.auth_method = self.auth_config_manager.get_auth_method()
91
- self.api_key = None
92
- self.oauth_token = None
93
-
94
- if self.auth_method == AuthMethod.CLAUDE_PRO:
95
- # Try to get OAuth token and refresh if needed
96
- try:
97
- token = self.oauth_flow.refresh_token_if_needed()
98
- if token:
99
- self.oauth_token = token.access_token
100
- except Exception:
101
- # OAuth token unavailable, will need to re-authenticate
102
- pass
91
+
92
+ # Optional Anthropic OAuth access token (only relevant for provider=='anthropic')
93
+ if self.auth_method == AuthMethod.CLAUDE_PRO and self.model_name.startswith(
94
+ "anthropic"
95
+ ):
96
+ self.oauth_token = self.get_oauth_access_token()
103
97
  else:
104
- # Use API key authentication (default or explicitly configured)
105
98
  self.api_key = self._get_api_key()
99
+ # self.oauth_token = None
106
100
 
107
101
  def _get_api_key(self) -> str | None:
108
102
  """Get API key for the model provider using cascading logic."""
109
- model = self.model_name
110
- if model.startswith("anthropic:"):
111
- return self.api_key_manager.get_api_key("anthropic")
103
+ model = self.model_name or ""
104
+ prov = providers.provider_from_model(model)
105
+ if prov in set(providers.all_keys()):
106
+ return self.api_key_manager.get_api_key(prov) # type: ignore[arg-type]
107
+ return None
112
108
 
113
109
  def set_model(self, model: str) -> None:
114
110
  """Set the model and update configuration."""
115
111
  self.model_config_manager.set_model(model)
116
112
  self.model_name = model
117
113
 
114
+ def get_oauth_access_token(self) -> str | None:
115
+ """Return a valid Anthropic OAuth access token if configured, else None.
116
+
117
+ Uses the stored refresh token (if present) to refresh as needed.
118
+ Only relevant when provider is 'anthropic'.
119
+ """
120
+ if not self.model_name.startswith("anthropic"):
121
+ return None
122
+ try:
123
+ flow = AnthropicOAuthFlow()
124
+ token = flow.refresh_token_if_needed()
125
+ return token.access_token if token else None
126
+ except Exception:
127
+ return None
128
+
118
129
  def validate(self):
119
- """Validate that necessary configuration is present."""
120
- # 1. Claude-Pro flow → require OAuth token only
121
- if self.auth_method == AuthMethod.CLAUDE_PRO:
122
- if not self.oauth_token:
123
- raise ValueError(
124
- "OAuth token not available. Run 'saber auth setup' to authenticate with Claude Pro."
125
- )
126
- return # OAuth path satisfied – nothing more to check
127
-
128
- # 2. Default / API-key flow → require API key
129
- if not self.api_key:
130
- raise ValueError("Anthropic API key not found.")
130
+ """Validate that necessary configuration is present.
131
+
132
+ Also ensure provider env var is set from keyring if needed for API-key flows.
133
+ """
134
+ model = self.model_name or ""
135
+ provider_key = providers.provider_from_model(model)
136
+ env_var = providers.env_var_name(provider_key or "") if provider_key else None
137
+ if env_var:
138
+ # Anthropic special-case: allow OAuth in lieu of API key only when explicitly configured
139
+ if (
140
+ provider_key == "anthropic"
141
+ and self.auth_method == AuthMethod.CLAUDE_PRO
142
+ and self.oauth_token
143
+ ):
144
+ return
145
+ # If we don't have a key resolved from env/keyring, raise
146
+ if not self.api_key:
147
+ raise ValueError(f"{provider_key.capitalize()} API key not found.")
148
+ # Hydrate env var for downstream SDKs if missing
149
+ if not os.getenv(env_var):
150
+ os.environ[env_var] = self.api_key
@@ -1,10 +1,10 @@
1
1
  """Database connection management."""
2
2
 
3
- from abc import ABC, abstractmethod
4
- from typing import Any
5
- from urllib.parse import urlparse, parse_qs
6
3
  import ssl
4
+ from abc import ABC, abstractmethod
7
5
  from pathlib import Path
6
+ from typing import Any
7
+ from urllib.parse import parse_qs, urlparse
8
8
 
9
9
  import aiomysql
10
10
  import aiosqlite