causaliq-knowledge 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +1 -1
- causaliq_knowledge/cli.py +244 -37
- causaliq_knowledge/llm/__init__.py +39 -10
- causaliq_knowledge/llm/anthropic_client.py +256 -0
- causaliq_knowledge/llm/base_client.py +220 -0
- causaliq_knowledge/llm/deepseek_client.py +108 -0
- causaliq_knowledge/llm/gemini_client.py +117 -39
- causaliq_knowledge/llm/groq_client.py +115 -40
- causaliq_knowledge/llm/mistral_client.py +122 -0
- causaliq_knowledge/llm/ollama_client.py +240 -0
- causaliq_knowledge/llm/openai_client.py +115 -0
- causaliq_knowledge/llm/openai_compat_client.py +287 -0
- causaliq_knowledge/llm/provider.py +99 -46
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.2.0.dist-info}/METADATA +8 -9
- causaliq_knowledge-0.2.0.dist-info/RECORD +22 -0
- causaliq_knowledge-0.1.0.dist-info/RECORD +0 -15
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.2.0.dist-info}/WHEEL +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.2.0.dist-info}/entry_points.txt +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.2.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/causaliq_knowledge/llm/base_client.py
@@ -0,0 +1,220 @@
+"""Abstract base class for LLM clients.
+
+This module defines the common interface that all LLM vendor clients
+must implement. This provides a consistent API regardless of the
+underlying LLM provider.
+"""
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LLMConfig:
+    """Base configuration for all LLM clients.
+
+    This dataclass defines common configuration options shared by all
+    LLM provider clients. Vendor-specific clients may extend this with
+    additional options.
+
+    Attributes:
+        model: Model identifier (provider-specific format).
+        temperature: Sampling temperature (0.0=deterministic, 1.0=creative).
+        max_tokens: Maximum tokens in the response.
+        timeout: Request timeout in seconds.
+        api_key: API key for authentication (optional, can use env var).
+    """
+
+    model: str
+    temperature: float = 0.1
+    max_tokens: int = 500
+    timeout: float = 30.0
+    api_key: Optional[str] = None
+
+
+@dataclass
+class LLMResponse:
+    """Standard response from any LLM client.
+
+    This dataclass provides a unified response format across all LLM providers,
+    abstracting away provider-specific response structures.
+
+    Attributes:
+        content: The text content of the response.
+        model: The model that generated the response.
+        input_tokens: Number of input/prompt tokens used.
+        output_tokens: Number of output/completion tokens generated.
+        cost: Estimated cost of the request (if available).
+        raw_response: The original provider-specific response (for debugging).
+    """
+
+    content: str
+    model: str
+    input_tokens: int = 0
+    output_tokens: int = 0
+    cost: float = 0.0
+    raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+    def parse_json(self) -> Optional[Dict[str, Any]]:
+        """Parse content as JSON, handling common formatting issues.
+
+        LLMs sometimes wrap JSON in markdown code blocks. This method
+        handles those cases and attempts to extract valid JSON.
+
+        Returns:
+            Parsed JSON as dict, or None if parsing fails.
+        """
+        try:
+            # Clean up potential markdown code blocks
+            text = self.content.strip()
+            if text.startswith("```json"):
+                text = text[7:]
+            elif text.startswith("```"):
+                text = text[3:]
+            if text.endswith("```"):
+                text = text[:-3]
+
+            return json.loads(text.strip())  # type: ignore[no-any-return]
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse JSON response: {e}")
+            return None
+
+
+class BaseLLMClient(ABC):
+    """Abstract base class for LLM clients.
+
+    All LLM vendor clients (OpenAI, Anthropic, Groq, Gemini, Llama, etc.)
+    must implement this interface to ensure consistent behavior across
+    the codebase.
+
+    This abstraction allows:
+    - Easy addition of new LLM providers
+    - Consistent API for all providers
+    - Provider-agnostic code in higher-level modules
+    - Simplified testing with mock implementations
+
+    Example:
+        >>> class MyClient(BaseLLMClient):
+        ...     def completion(self, messages, **kwargs):
+        ...         # Implementation here
+        ...         pass
+        ...
+        >>> client = MyClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
+
+    @abstractmethod
+    def __init__(self, config: LLMConfig) -> None:
+        """Initialize the client with configuration.
+
+        Args:
+            config: Configuration for the LLM client.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def provider_name(self) -> str:
+        """Return the name of the LLM provider.
+
+        Returns:
+            Provider name (e.g., "openai", "anthropic", "groq").
+        """
+        pass
+
+    @abstractmethod
+    def completion(
+        self, messages: List[Dict[str, str]], **kwargs: Any
+    ) -> LLMResponse:
+        """Make a chat completion request.
+
+        This is the core method that sends a request to the LLM provider
+        and returns a standardized response.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+                Roles can be: "system", "user", "assistant".
+            **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                that override the config defaults.
+
+        Returns:
+            LLMResponse with the generated content and metadata.
+
+        Raises:
+            ValueError: If the API request fails or returns an error.
+        """
+        pass
+
+    def complete_json(
+        self, messages: List[Dict[str, str]], **kwargs: Any
+    ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+        """Make a completion request and parse response as JSON.
+
+        Convenience method that calls completion() and attempts to parse
+        the response content as JSON.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Provider-specific options passed to completion().
+
+        Returns:
+            Tuple of (parsed JSON dict or None, raw LLMResponse).
+        """
+        response = self.completion(messages, **kwargs)
+        parsed = response.parse_json()
+        return parsed, response
+
+    @property
+    @abstractmethod
+    def call_count(self) -> int:
+        """Return the number of API calls made by this client.
+
+        Returns:
+            Total number of completion calls made.
+        """
+        pass
+
+    @abstractmethod
+    def is_available(self) -> bool:
+        """Check if the LLM provider is available and configured.
+
+        This method checks whether the client can make API calls:
+        - For cloud providers: checks if API key is set
+        - For local providers: checks if server is running
+
+        Returns:
+            True if the provider is available and ready for requests.
+        """
+        pass
+
+    @abstractmethod
+    def list_models(self) -> List[str]:
+        """List available models from the provider.
+
+        Queries the provider's API to get the list of models accessible
+        with the current API key or configuration. Results are filtered
+        by the user's subscription/access level.
+
+        Returns:
+            List of model identifiers available for use.
+
+        Raises:
+            ValueError: If the API request fails.
+        """
+        pass
+
+    @property
+    def model_name(self) -> str:
+        """Return the model name being used.
+
+        Returns:
+            Model identifier string.
+        """
+        return getattr(self, "config", LLMConfig(model="unknown")).model
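To make the new interface concrete, here is a minimal sketch (not part of the package) of a mock client that satisfies `BaseLLMClient`, along the lines of the docstring's note on "simplified testing with mock implementations". The class name `MockLLMClient` and its canned JSON reply are illustrative only.

```python
from typing import Any, Dict, List

from causaliq_knowledge.llm.base_client import (
    BaseLLMClient,
    LLMConfig,
    LLMResponse,
)


class MockLLMClient(BaseLLMClient):
    """Hypothetical in-memory client for exercising the interface offline."""

    def __init__(self, config: LLMConfig) -> None:
        self.config = config
        self._calls = 0

    @property
    def provider_name(self) -> str:
        return "mock"

    def completion(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> LLMResponse:
        # Count the call and return a fixed JSON payload instead of hitting an API.
        self._calls += 1
        return LLMResponse(content='{"answer": 42}', model=self.config.model)

    @property
    def call_count(self) -> int:
        return self._calls

    def is_available(self) -> bool:
        return True

    def list_models(self) -> List[str]:
        return [self.config.model]


client = MockLLMClient(LLMConfig(model="mock-model"))
parsed, response = client.complete_json([{"role": "user", "content": "Hi"}])
print(parsed)             # {'answer': 42}
print(client.call_count)  # 1
```

Because `complete_json()` and `model_name` are implemented on the base class, the mock only needs to supply the abstract members; everything else comes for free.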
--- /dev/null
+++ b/causaliq_knowledge/llm/deepseek_client.py
@@ -0,0 +1,108 @@
+"""Direct DeepSeek API client - OpenAI-compatible API."""
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from causaliq_knowledge.llm.openai_compat_client import (
+    OpenAICompatClient,
+    OpenAICompatConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DeepSeekConfig(OpenAICompatConfig):
+    """Configuration for DeepSeek API client.
+
+    Extends OpenAICompatConfig with DeepSeek-specific defaults.
+
+    Attributes:
+        model: DeepSeek model identifier (default: deepseek-chat).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: DeepSeek API key (falls back to DEEPSEEK_API_KEY env var).
+    """
+
+    model: str = "deepseek-chat"
+    temperature: float = 0.1
+    max_tokens: int = 500
+    timeout: float = 30.0
+    api_key: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        """Set API key from environment if not provided."""
+        if self.api_key is None:
+            self.api_key = os.getenv("DEEPSEEK_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "DEEPSEEK_API_KEY environment variable is required"
+            )
+
+
+class DeepSeekClient(OpenAICompatClient):
+    """Direct DeepSeek API client.
+
+    DeepSeek uses an OpenAI-compatible API, making integration straightforward.
+    Known for excellent reasoning capabilities (R1) at low cost.
+
+    Available models:
+    - deepseek-chat: General purpose (DeepSeek-V3)
+    - deepseek-reasoner: Advanced reasoning (DeepSeek-R1)
+
+    Example:
+        >>> config = DeepSeekConfig(model="deepseek-chat")
+        >>> client = DeepSeekClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
+
+    BASE_URL = "https://api.deepseek.com"
+    PROVIDER_NAME = "deepseek"
+    ENV_VAR = "DEEPSEEK_API_KEY"
+
+    def __init__(self, config: Optional[DeepSeekConfig] = None) -> None:
+        """Initialize DeepSeek client.
+
+        Args:
+            config: DeepSeek configuration. If None, uses defaults with
+                API key from DEEPSEEK_API_KEY environment variable.
+        """
+        super().__init__(config)
+
+    def _default_config(self) -> DeepSeekConfig:
+        """Return default DeepSeek configuration."""
+        return DeepSeekConfig()
+
+    def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+        """Return DeepSeek pricing per 1M tokens.
+
+        Returns:
+            Dict mapping model prefixes to input/output costs.
+        """
+        # DeepSeek pricing as of Jan 2025
+        # Note: Cache hits are much cheaper but we use regular pricing
+        return {
+            "deepseek-reasoner": {"input": 0.55, "output": 2.19},
+            "deepseek-chat": {"input": 0.14, "output": 0.28},
+        }
+
+    def _filter_models(self, models: List[str]) -> List[str]:
+        """Filter to DeepSeek chat models only.
+
+        Args:
+            models: List of all model IDs from API.
+
+        Returns:
+            Filtered list of DeepSeek models.
+        """
+        filtered = []
+        for model_id in models:
+            # Include deepseek chat and reasoner models
+            if model_id.startswith("deepseek-"):
+                filtered.append(model_id)
+        return filtered
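The pricing table returned by `_get_pricing()` is expressed in USD per 1M tokens. A quick back-of-envelope check of what that implies (the token counts below are made-up illustration values; the actual cost calculation lives in `OpenAICompatClient`, which is not shown in this diff):

```python
# Estimate the cost of one deepseek-chat request from the per-1M-token rates above.
PRICING = {"deepseek-chat": {"input": 0.14, "output": 0.28}}  # USD per 1M tokens

input_tokens, output_tokens = 1_000, 500  # hypothetical request size
rates = PRICING["deepseek-chat"]
cost = (input_tokens * rates["input"] + output_tokens * rates["output"]) / 1_000_000
print(f"Estimated cost: ${cost:.6f}")  # Estimated cost: $0.000280
```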
--- a/causaliq_knowledge/llm/gemini_client.py
+++ b/causaliq_knowledge/llm/gemini_client.py
@@ -1,6 +1,5 @@
 """Direct Google Gemini API client - clean and reliable."""
 
-import json
 import logging
 import os
 from dataclasses import dataclass
@@ -8,12 +7,28 @@ from typing import Any, Dict, List, Optional
 
 import httpx
 
+from causaliq_knowledge.llm.base_client import (
+    BaseLLMClient,
+    LLMConfig,
+    LLMResponse,
+)
+
 logger = logging.getLogger(__name__)
 
 
 @dataclass
-class GeminiConfig:
-    """Configuration for Gemini API client.
+class GeminiConfig(LLMConfig):
+    """Configuration for Gemini API client.
+
+    Extends LLMConfig with Gemini-specific defaults.
+
+    Attributes:
+        model: Gemini model identifier (default: gemini-2.5-flash).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: Gemini API key (falls back to GEMINI_API_KEY env var).
+    """
 
     model: str = "gemini-2.5-flash"
     temperature: float = 0.1
@@ -29,48 +44,52 @@ class GeminiConfig:
             raise ValueError("GEMINI_API_KEY environment variable is required")
 
 
-
-
-    """Response from Gemini API."""
-
-    content: str
-    model: str
-    input_tokens: int = 0
-    output_tokens: int = 0
-    cost: float = 0.0  # Gemini free tier
-    raw_response: Optional[Dict] = None
-
-    def parse_json(self) -> Optional[Dict[str, Any]]:
-        """Parse content as JSON, handling common formatting issues."""
-        try:
-            # Clean up potential markdown code blocks
-            text = self.content.strip()
-            if text.startswith("```json"):
-                text = text[7:]
-            elif text.startswith("```"):
-                text = text[3:]
-            if text.endswith("```"):
-                text = text[:-3]
-
-            return json.loads(text.strip())  # type: ignore[no-any-return]
-        except json.JSONDecodeError:
-            return None
+class GeminiClient(BaseLLMClient):
+    """Direct Gemini API client.
 
+    Implements the BaseLLMClient interface for Google's Gemini API.
+    Uses httpx for HTTP requests.
 
-
-
+    Example:
+        >>> config = GeminiConfig(model="gemini-2.5-flash")
+        >>> client = GeminiClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
 
     BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"
 
-    def __init__(self, config: Optional[GeminiConfig] = None):
-        """Initialize Gemini client.
+    def __init__(self, config: Optional[GeminiConfig] = None) -> None:
+        """Initialize Gemini client.
+
+        Args:
+            config: Gemini configuration. If None, uses defaults with
+                API key from GEMINI_API_KEY environment variable.
+        """
         self.config = config or GeminiConfig()
         self._total_calls = 0
 
+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return "gemini"
+
     def completion(
         self, messages: List[Dict[str, str]], **kwargs: Any
-    ) ->
-        """Make a chat completion request to Gemini.
+    ) -> LLMResponse:
+        """Make a chat completion request to Gemini.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options (temperature, max_tokens).
+
+        Returns:
+            LLMResponse with the generated content and metadata.
+
+        Raises:
+            ValueError: If the API request fails.
+        """
 
         # Convert OpenAI-style messages to Gemini format
         contents = []
@@ -158,7 +177,7 @@ class GeminiClient:
                 f"Gemini response: {input_tokens} in, {output_tokens} out"
             )
 
-        return
+        return LLMResponse(
             content=content,
             model=self.config.model,
             input_tokens=input_tokens,
@@ -191,13 +210,72 @@ class GeminiClient:
 
     def complete_json(
         self, messages: List[Dict[str, str]], **kwargs: Any
-    ) -> tuple[Optional[Dict[str, Any]],
-        """Make a completion request and parse response as JSON.
+    ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+        """Make a completion request and parse response as JSON.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options passed to completion().
+
+        Returns:
+            Tuple of (parsed JSON dict or None, raw LLMResponse).
+        """
         response = self.completion(messages, **kwargs)
         parsed = response.parse_json()
        return parsed, response
 
     @property
     def call_count(self) -> int:
-        """
+        """Return the number of API calls made."""
         return self._total_calls
+
+    def is_available(self) -> bool:
+        """Check if Gemini API is available.
+
+        Returns:
+            True if GEMINI_API_KEY is configured.
+        """
+        return bool(self.config.api_key)
+
+    def list_models(self) -> List[str]:
+        """List available models from Gemini API.
+
+        Queries the Gemini API to get models accessible with the current
+        API key. Filters to only include models that support generateContent.
+
+        Returns:
+            List of model identifiers (e.g., ['gemini-2.5-flash', ...]).
+
+        Raises:
+            ValueError: If the API request fails.
+        """
+        try:
+            with httpx.Client(timeout=self.config.timeout) as client:
+                response = client.get(
+                    f"{self.BASE_URL}?key={self.config.api_key}",
+                )
+                response.raise_for_status()
+                data = response.json()
+
+            # Filter to models that support text generation
+            models = []
+            for model in data.get("models", []):
+                methods = model.get("supportedGenerationMethods", [])
+                if "generateContent" not in methods:
+                    continue
+                # Extract model name (remove 'models/' prefix)
+                name = model.get("name", "").replace("models/", "")
+                # Skip embedding and TTS models
+                if any(x in name.lower() for x in ["embed", "tts", "aqa"]):
+                    continue
+                models.append(name)
+
+            return sorted(models)
+
+        except httpx.HTTPStatusError as e:
+            raise ValueError(
+                f"Gemini API error: {e.response.status_code} - "
+                f"{e.response.text}"
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to list Gemini models: {e}")
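A usage sketch for the updated GeminiClient, assuming GEMINI_API_KEY is set in the environment; the prompt text and printed values are illustrative only, and a live API call is made by list_models() and complete_json():

```python
from causaliq_knowledge.llm.gemini_client import GeminiClient, GeminiConfig

# GeminiConfig raises ValueError at construction if GEMINI_API_KEY is not set.
client = GeminiClient(GeminiConfig(model="gemini-2.5-flash", max_tokens=200))

if client.is_available():
    print(client.provider_name, client.model_name)  # gemini gemini-2.5-flash
    print(client.list_models())                     # e.g. ['gemini-2.5-flash', ...]

    parsed, response = client.complete_json(
        [{"role": "user", "content": 'Reply with JSON: {"ok": true}'}]
    )
    print(parsed)                                   # e.g. {'ok': True}
    print(response.input_tokens, response.output_tokens)
```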