causaliq-knowledge 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their respective public registries and is provided for informational purposes only.
- causaliq_knowledge/__init__.py +3 -3
- causaliq_knowledge/cache/__init__.py +18 -0
- causaliq_knowledge/cache/encoders/__init__.py +13 -0
- causaliq_knowledge/cache/encoders/base.py +90 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +418 -0
- causaliq_knowledge/cache/token_cache.py +632 -0
- causaliq_knowledge/cli.py +588 -38
- causaliq_knowledge/llm/__init__.py +39 -10
- causaliq_knowledge/llm/anthropic_client.py +256 -0
- causaliq_knowledge/llm/base_client.py +360 -0
- causaliq_knowledge/llm/cache.py +380 -0
- causaliq_knowledge/llm/deepseek_client.py +108 -0
- causaliq_knowledge/llm/gemini_client.py +117 -39
- causaliq_knowledge/llm/groq_client.py +115 -40
- causaliq_knowledge/llm/mistral_client.py +122 -0
- causaliq_knowledge/llm/ollama_client.py +240 -0
- causaliq_knowledge/llm/openai_client.py +115 -0
- causaliq_knowledge/llm/openai_compat_client.py +287 -0
- causaliq_knowledge/llm/provider.py +99 -46
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/METADATA +9 -10
- causaliq_knowledge-0.3.0.dist-info/RECORD +28 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/WHEEL +1 -1
- causaliq_knowledge-0.1.0.dist-info/RECORD +0 -15
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/entry_points.txt +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/top_level.txt +0 -0
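
The new modules form two layers: a provider-agnostic client hierarchy (base_client, openai_compat_client and the per-provider clients) and a token-dictionary cache (token_cache, encoders, llm/cache). A minimal sketch of how these pieces might compose, using only names visible in the hunks below; the TokenCache(":memory:") form is copied from the LLMEntryEncoder docstring, and the composition itself is an assumption rather than documented usage:

    # Sketch only: combine a provider response with the new cache types.
    from causaliq_knowledge.cache import TokenCache
    from causaliq_knowledge.llm.cache import LLMCacheEntry, LLMEntryEncoder

    with TokenCache(":memory:") as cache:          # in-memory token dictionary
        encoder = LLMEntryEncoder()
        entry = LLMCacheEntry.create(
            model="gpt-4",
            messages=[{"role": "user", "content": "Does smoking cause cancer?"}],
            content="Yes ...",                     # text returned by any provider client
            provider="openai",
            input_tokens=12,
            output_tokens=40,
        )
        blob = encoder.encode_entry(entry, cache)  # tokenised bytes (50-70% smaller)
        restored = encoder.decode_entry(blob, cache)
        assert restored.response.content == entry.response.content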
causaliq_knowledge/llm/cache.py

@@ -0,0 +1,380 @@
+"""
+LLM-specific cache encoder and data structures.
+
+This module provides the LLMEntryEncoder for caching LLM requests and
+responses with rich metadata for analysis.
+
+Note: This module stays in causaliq-knowledge (LLM-specific).
+The base cache infrastructure will migrate to causaliq-core.
+"""
+
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from causaliq_knowledge.cache.encoders import JsonEncoder
+
+if TYPE_CHECKING:  # pragma: no cover
+    from causaliq_knowledge.cache.token_cache import TokenCache
+
+
+@dataclass
+class LLMTokenUsage:
+    """Token usage statistics for an LLM request.
+
+    Attributes:
+        input: Number of tokens in the prompt.
+        output: Number of tokens in the completion.
+        total: Total tokens (input + output).
+    """
+
+    input: int = 0
+    output: int = 0
+    total: int = 0
+
+
+@dataclass
+class LLMMetadata:
+    """Metadata for a cached LLM response.
+
+    Attributes:
+        provider: LLM provider name (openai, anthropic, etc.).
+        timestamp: When the original request was made (ISO format).
+        latency_ms: Response time in milliseconds.
+        tokens: Token usage statistics.
+        cost_usd: Estimated cost of the request in USD.
+        cache_hit: Whether this was served from cache.
+    """
+
+    provider: str = ""
+    timestamp: str = ""
+    latency_ms: int = 0
+    tokens: LLMTokenUsage = field(default_factory=LLMTokenUsage)
+    cost_usd: float = 0.0
+    cache_hit: bool = False
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for JSON serialisation."""
+        return {
+            "provider": self.provider,
+            "timestamp": self.timestamp,
+            "latency_ms": self.latency_ms,
+            "tokens": asdict(self.tokens),
+            "cost_usd": self.cost_usd,
+            "cache_hit": self.cache_hit,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> LLMMetadata:
+        """Create from dictionary."""
+        tokens_data = data.get("tokens", {})
+        return cls(
+            provider=data.get("provider", ""),
+            timestamp=data.get("timestamp", ""),
+            latency_ms=data.get("latency_ms", 0),
+            tokens=LLMTokenUsage(
+                input=tokens_data.get("input", 0),
+                output=tokens_data.get("output", 0),
+                total=tokens_data.get("total", 0),
+            ),
+            cost_usd=data.get("cost_usd", 0.0),
+            cache_hit=data.get("cache_hit", False),
+        )
+
+
+@dataclass
+class LLMResponse:
+    """LLM response data for caching.
+
+    Attributes:
+        content: The full text response from the LLM.
+        finish_reason: Why generation stopped (stop, length, etc.).
+        model_version: Actual model version used.
+    """
+
+    content: str = ""
+    finish_reason: str = "stop"
+    model_version: str = ""
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for JSON serialisation."""
+        return {
+            "content": self.content,
+            "finish_reason": self.finish_reason,
+            "model_version": self.model_version,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> LLMResponse:
+        """Create from dictionary."""
+        return cls(
+            content=data.get("content", ""),
+            finish_reason=data.get("finish_reason", "stop"),
+            model_version=data.get("model_version", ""),
+        )
+
+
+@dataclass
+class LLMCacheEntry:
+    """Complete LLM cache entry with request, response, and metadata.
+
+    Attributes:
+        model: The model name requested.
+        messages: The conversation messages.
+        temperature: Sampling temperature.
+        max_tokens: Maximum tokens in response.
+        response: The LLM response data.
+        metadata: Rich metadata for analysis.
+    """
+
+    model: str = ""
+    messages: list[dict[str, Any]] = field(default_factory=list)
+    temperature: float = 0.0
+    max_tokens: int | None = None
+    response: LLMResponse = field(default_factory=LLMResponse)
+    metadata: LLMMetadata = field(default_factory=LLMMetadata)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for JSON serialisation."""
+        return {
+            "cache_key": {
+                "model": self.model,
+                "messages": self.messages,
+                "temperature": self.temperature,
+                "max_tokens": self.max_tokens,
+            },
+            "response": self.response.to_dict(),
+            "metadata": self.metadata.to_dict(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> LLMCacheEntry:
+        """Create from dictionary."""
+        cache_key = data.get("cache_key", {})
+        return cls(
+            model=cache_key.get("model", ""),
+            messages=cache_key.get("messages", []),
+            temperature=cache_key.get("temperature", 0.0),
+            max_tokens=cache_key.get("max_tokens"),
+            response=LLMResponse.from_dict(data.get("response", {})),
+            metadata=LLMMetadata.from_dict(data.get("metadata", {})),
+        )
+
+    @classmethod
+    def create(
+        cls,
+        model: str,
+        messages: list[dict[str, Any]],
+        content: str,
+        *,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        finish_reason: str = "stop",
+        model_version: str = "",
+        provider: str = "",
+        latency_ms: int = 0,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+        cost_usd: float = 0.0,
+    ) -> LLMCacheEntry:
+        """Create a cache entry with common parameters.
+
+        Args:
+            model: The model name requested.
+            messages: The conversation messages.
+            content: The response content.
+            temperature: Sampling temperature.
+            max_tokens: Maximum tokens in response.
+            finish_reason: Why generation stopped.
+            model_version: Actual model version.
+            provider: LLM provider name.
+            latency_ms: Response time in milliseconds.
+            input_tokens: Number of input tokens.
+            output_tokens: Number of output tokens.
+            cost_usd: Estimated cost in USD.
+
+        Returns:
+            Configured LLMCacheEntry.
+        """
+        return cls(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response=LLMResponse(
+                content=content,
+                finish_reason=finish_reason,
+                model_version=model_version or model,
+            ),
+            metadata=LLMMetadata(
+                provider=provider,
+                timestamp=datetime.now(timezone.utc).isoformat(),
+                latency_ms=latency_ms,
+                tokens=LLMTokenUsage(
+                    input=input_tokens,
+                    output=output_tokens,
+                    total=input_tokens + output_tokens,
+                ),
+                cost_usd=cost_usd,
+                cache_hit=False,
+            ),
+        )
+
+
+class LLMEntryEncoder(JsonEncoder):
+    """Encoder for LLM cache entries.
+
+    Extends JsonEncoder with LLM-specific convenience methods for
+    encoding/decoding LLMCacheEntry objects.
+
+    The encoder stores data in the standard JSON tokenised format,
+    achieving 50-70% compression through the shared token dictionary.
+
+    Example:
+        >>> from causaliq_knowledge.cache import TokenCache
+        >>> from causaliq_knowledge.llm.cache import (
+        ...     LLMEntryEncoder, LLMCacheEntry,
+        ... )
+        >>> with TokenCache(":memory:") as cache:
+        ...     encoder = LLMEntryEncoder()
+        ...     entry = LLMCacheEntry.create(
+        ...         model="gpt-4",
+        ...         messages=[{"role": "user", "content": "Hello"}],
+        ...         content="Hi there!",
+        ...         provider="openai",
+        ...     )
+        ...     blob = encoder.encode(entry.to_dict(), cache)
+        ...     data = encoder.decode(blob, cache)
+        ...     restored = LLMCacheEntry.from_dict(data)
+    """
+
+    def encode_entry(self, entry: LLMCacheEntry, cache: TokenCache) -> bytes:
+        """Encode an LLMCacheEntry to bytes.
+
+        Convenience method that handles to_dict conversion.
+
+        Args:
+            entry: The cache entry to encode.
+            cache: TokenCache for token dictionary.
+
+        Returns:
+            Encoded bytes.
+        """
+        return self.encode(entry.to_dict(), cache)
+
+    def decode_entry(self, blob: bytes, cache: TokenCache) -> LLMCacheEntry:
+        """Decode bytes to an LLMCacheEntry.
+
+        Convenience method that handles from_dict conversion.
+
+        Args:
+            blob: Encoded bytes.
+            cache: TokenCache for token dictionary.
+
+        Returns:
+            Decoded LLMCacheEntry.
+        """
+        data = self.decode(blob, cache)
+        return LLMCacheEntry.from_dict(data)
+
+    def generate_export_filename(
+        self, entry: LLMCacheEntry, cache_key: str
+    ) -> str:
+        """Generate a human-readable filename for export.
+
+        Creates a filename from model name and query details, with a
+        short hash suffix for uniqueness.
+
+        For edge queries, extracts node names for format:
+            {model}_{node_a}_{node_b}_edge_{hash}.json
+
+        For other queries, uses prompt excerpt:
+            {model}_{prompt_excerpt}_{hash}.json
+
+        Args:
+            entry: The cache entry to generate filename for.
+            cache_key: The cache key (hash) for uniqueness suffix.
+
+        Returns:
+            Human-readable filename with .json extension.
+
+        Example:
+            >>> encoder = LLMEntryEncoder()
+            >>> entry = LLMCacheEntry.create(
+            ...     model="gpt-4",
+            ...     messages=[{"role": "user", "content": "smoking and lung"}],
+            ...     content="Yes...",
+            ... )
+            >>> encoder.generate_export_filename(entry, "a1b2c3d4e5f6")
+            'gpt4_smoking_lung_edge_a1b2.json'
+        """
+        import re
+
+        # Sanitize model name (alphanumeric only, lowercase)
+        model = re.sub(r"[^a-z0-9]", "", entry.model.lower())
+        if len(model) > 15:
+            model = model[:15]
+
+        # Extract user message content
+        prompt = ""
+        for msg in entry.messages:
+            if msg.get("role") == "user":
+                prompt = msg.get("content", "")
+                break
+
+        # Try to extract node names for edge queries
+        # Look for patterns like "X and Y", "X cause Y", "between X and Y"
+        prompt_lower = prompt.lower()
+        slug = ""
+
+        # Pattern: "between X and Y" or "X and Y"
+        match = re.search(r"(?:between\s+)?(\w+)\s+and\s+(\w+)", prompt_lower)
+        if match:
+            node_a = match.group(1)[:15]
+            node_b = match.group(2)[:15]
+            slug = f"{node_a}_{node_b}_edge"
+
+        # Fallback: extract first significant words from prompt
+        if not slug:
+            # Remove common words, keep alphanumeric
+            cleaned = re.sub(r"[^a-z0-9\s]", "", prompt_lower)
+            words = [
+                w
+                for w in cleaned.split()
+                if w
+                not in ("the", "a", "an", "is", "are", "does", "do", "can")
+            ]
+            slug = "_".join(words[:4])
+            if len(slug) > 30:
+                slug = slug[:30].rstrip("_")
+
+        # Short hash suffix for uniqueness (4 chars)
+        hash_suffix = cache_key[:4] if cache_key else "0000"
+
+        # Build filename
+        parts = [p for p in [model, slug, hash_suffix] if p]
+        return "_".join(parts) + ".json"
+
+    def export_entry(self, entry: LLMCacheEntry, path: Path) -> None:
+        """Export an LLMCacheEntry to a JSON file.
+
+        Args:
+            entry: The cache entry to export.
+            path: Destination file path.
+        """
+        self.export(entry.to_dict(), path)
+
+    def import_entry(self, path: Path) -> LLMCacheEntry:
+        """Import an LLMCacheEntry from a JSON file.
+
+        Args:
+            path: Source file path.
+
+        Returns:
+            Imported LLMCacheEntry.
+        """
+        data = self.import_(path)
+        return LLMCacheEntry.from_dict(data)
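
For sharing or inspection, the encoder also supports a file-based round trip. A short sketch pairing generate_export_filename with export_entry/import_entry; the cache-key string and output directory are placeholders:

    # Sketch: export a cache entry to a readable JSON file and re-import it.
    from pathlib import Path

    from causaliq_knowledge.llm.cache import LLMCacheEntry, LLMEntryEncoder

    encoder = LLMEntryEncoder()
    entry = LLMCacheEntry.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "smoking and lung cancer"}],
        content="Yes...",
        provider="openai",
    )
    name = encoder.generate_export_filename(entry, "a1b2c3d4e5f6")
    # name is of the form 'gpt4_smoking_lung_edge_a1b2.json'
    encoder.export_entry(entry, Path(".") / name)
    restored = encoder.import_entry(Path(".") / name)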
causaliq_knowledge/llm/deepseek_client.py

@@ -0,0 +1,108 @@
+"""Direct DeepSeek API client - OpenAI-compatible API."""
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from causaliq_knowledge.llm.openai_compat_client import (
+    OpenAICompatClient,
+    OpenAICompatConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DeepSeekConfig(OpenAICompatConfig):
+    """Configuration for DeepSeek API client.
+
+    Extends OpenAICompatConfig with DeepSeek-specific defaults.
+
+    Attributes:
+        model: DeepSeek model identifier (default: deepseek-chat).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: DeepSeek API key (falls back to DEEPSEEK_API_KEY env var).
+    """
+
+    model: str = "deepseek-chat"
+    temperature: float = 0.1
+    max_tokens: int = 500
+    timeout: float = 30.0
+    api_key: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        """Set API key from environment if not provided."""
+        if self.api_key is None:
+            self.api_key = os.getenv("DEEPSEEK_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "DEEPSEEK_API_KEY environment variable is required"
+            )
+
+
+class DeepSeekClient(OpenAICompatClient):
+    """Direct DeepSeek API client.
+
+    DeepSeek uses an OpenAI-compatible API, making integration straightforward.
+    Known for excellent reasoning capabilities (R1) at low cost.
+
+    Available models:
+    - deepseek-chat: General purpose (DeepSeek-V3)
+    - deepseek-reasoner: Advanced reasoning (DeepSeek-R1)
+
+    Example:
+        >>> config = DeepSeekConfig(model="deepseek-chat")
+        >>> client = DeepSeekClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
+
+    BASE_URL = "https://api.deepseek.com"
+    PROVIDER_NAME = "deepseek"
+    ENV_VAR = "DEEPSEEK_API_KEY"
+
+    def __init__(self, config: Optional[DeepSeekConfig] = None) -> None:
+        """Initialize DeepSeek client.
+
+        Args:
+            config: DeepSeek configuration. If None, uses defaults with
+                API key from DEEPSEEK_API_KEY environment variable.
+        """
+        super().__init__(config)
+
+    def _default_config(self) -> DeepSeekConfig:
+        """Return default DeepSeek configuration."""
+        return DeepSeekConfig()
+
+    def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+        """Return DeepSeek pricing per 1M tokens.
+
+        Returns:
+            Dict mapping model prefixes to input/output costs.
+        """
+        # DeepSeek pricing as of Jan 2025
+        # Note: Cache hits are much cheaper but we use regular pricing
+        return {
+            "deepseek-reasoner": {"input": 0.55, "output": 2.19},
+            "deepseek-chat": {"input": 0.14, "output": 0.28},
+        }
+
+    def _filter_models(self, models: List[str]) -> List[str]:
+        """Filter to DeepSeek chat models only.
+
+        Args:
+            models: List of all model IDs from API.
+
+        Returns:
+            Filtered list of DeepSeek models.
+        """
+        filtered = []
+        for model_id in models:
+            # Include deepseek chat and reasoner models
+            if model_id.startswith("deepseek-"):
+                filtered.append(model_id)
+        return filtered
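
Two details above are easy to miss: DeepSeekConfig.__post_init__ falls back to the DEEPSEEK_API_KEY environment variable (and raises if neither it nor an explicit key is set), and _get_pricing is expressed in USD per one million tokens. A hedged sketch; the cost line is plain arithmetic on the table above, not a method of the client:

    import os

    from causaliq_knowledge.llm.deepseek_client import DeepSeekClient, DeepSeekConfig

    os.environ.setdefault("DEEPSEEK_API_KEY", "sk-placeholder")  # not a real key
    config = DeepSeekConfig()          # model="deepseek-chat", key pulled from the env
    client = DeepSeekClient(config)

    # Worked example from the pricing table (USD per 1M tokens, deepseek-chat):
    # 1,000 input tokens and 500 output tokens.
    cost = (1_000 / 1_000_000) * 0.14 + (500 / 1_000_000) * 0.28
    print(f"{cost:.5f} USD")           # 0.00028 USD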
causaliq_knowledge/llm/gemini_client.py

@@ -1,6 +1,5 @@
 """Direct Google Gemini API client - clean and reliable."""

-import json
 import logging
 import os
 from dataclasses import dataclass

@@ -8,12 +7,28 @@ from typing import Any, Dict, List, Optional

 import httpx

+from causaliq_knowledge.llm.base_client import (
+    BaseLLMClient,
+    LLMConfig,
+    LLMResponse,
+)
+
 logger = logging.getLogger(__name__)


 @dataclass
-class GeminiConfig:
-    """Configuration for Gemini API client.
+class GeminiConfig(LLMConfig):
+    """Configuration for Gemini API client.
+
+    Extends LLMConfig with Gemini-specific defaults.
+
+    Attributes:
+        model: Gemini model identifier (default: gemini-2.5-flash).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: Gemini API key (falls back to GEMINI_API_KEY env var).
+    """

     model: str = "gemini-2.5-flash"
     temperature: float = 0.1

@@ -29,48 +44,52 @@ class GeminiConfig:
             raise ValueError("GEMINI_API_KEY environment variable is required")


-
-
-    """Response from Gemini API."""
-
-    content: str
-    model: str
-    input_tokens: int = 0
-    output_tokens: int = 0
-    cost: float = 0.0  # Gemini free tier
-    raw_response: Optional[Dict] = None
-
-    def parse_json(self) -> Optional[Dict[str, Any]]:
-        """Parse content as JSON, handling common formatting issues."""
-        try:
-            # Clean up potential markdown code blocks
-            text = self.content.strip()
-            if text.startswith("```json"):
-                text = text[7:]
-            elif text.startswith("```"):
-                text = text[3:]
-            if text.endswith("```"):
-                text = text[:-3]
-
-            return json.loads(text.strip())  # type: ignore[no-any-return]
-        except json.JSONDecodeError:
-            return None
+class GeminiClient(BaseLLMClient):
+    """Direct Gemini API client.

+    Implements the BaseLLMClient interface for Google's Gemini API.
+    Uses httpx for HTTP requests.

-
-
+    Example:
+        >>> config = GeminiConfig(model="gemini-2.5-flash")
+        >>> client = GeminiClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """

     BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"

-    def __init__(self, config: Optional[GeminiConfig] = None):
-        """Initialize Gemini client.
+    def __init__(self, config: Optional[GeminiConfig] = None) -> None:
+        """Initialize Gemini client.
+
+        Args:
+            config: Gemini configuration. If None, uses defaults with
+                API key from GEMINI_API_KEY environment variable.
+        """
         self.config = config or GeminiConfig()
         self._total_calls = 0

+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return "gemini"
+
     def completion(
         self, messages: List[Dict[str, str]], **kwargs: Any
-    ) ->
-        """Make a chat completion request to Gemini.
+    ) -> LLMResponse:
+        """Make a chat completion request to Gemini.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options (temperature, max_tokens).
+
+        Returns:
+            LLMResponse with the generated content and metadata.
+
+        Raises:
+            ValueError: If the API request fails.
+        """

         # Convert OpenAI-style messages to Gemini format
         contents = []

@@ -158,7 +177,7 @@ class GeminiClient:
             f"Gemini response: {input_tokens} in, {output_tokens} out"
         )

-        return
+        return LLMResponse(
             content=content,
             model=self.config.model,
             input_tokens=input_tokens,

@@ -191,13 +210,72 @@ class GeminiClient:

     def complete_json(
         self, messages: List[Dict[str, str]], **kwargs: Any
-    ) -> tuple[Optional[Dict[str, Any]],
-        """Make a completion request and parse response as JSON.
+    ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+        """Make a completion request and parse response as JSON.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options passed to completion().
+
+        Returns:
+            Tuple of (parsed JSON dict or None, raw LLMResponse).
+        """
         response = self.completion(messages, **kwargs)
         parsed = response.parse_json()
         return parsed, response

     @property
     def call_count(self) -> int:
-        """
+        """Return the number of API calls made."""
         return self._total_calls
+
+    def is_available(self) -> bool:
+        """Check if Gemini API is available.
+
+        Returns:
+            True if GEMINI_API_KEY is configured.
+        """
+        return bool(self.config.api_key)
+
+    def list_models(self) -> List[str]:
+        """List available models from Gemini API.
+
+        Queries the Gemini API to get models accessible with the current
+        API key. Filters to only include models that support generateContent.
+
+        Returns:
+            List of model identifiers (e.g., ['gemini-2.5-flash', ...]).
+
+        Raises:
+            ValueError: If the API request fails.
+        """
+        try:
+            with httpx.Client(timeout=self.config.timeout) as client:
+                response = client.get(
+                    f"{self.BASE_URL}?key={self.config.api_key}",
+                )
+                response.raise_for_status()
+                data = response.json()
+
+            # Filter to models that support text generation
+            models = []
+            for model in data.get("models", []):
+                methods = model.get("supportedGenerationMethods", [])
+                if "generateContent" not in methods:
+                    continue
+                # Extract model name (remove 'models/' prefix)
+                name = model.get("name", "").replace("models/", "")
+                # Skip embedding and TTS models
+                if any(x in name.lower() for x in ["embed", "tts", "aqa"]):
+                    continue
+                models.append(name)
+
+            return sorted(models)
+
+        except httpx.HTTPStatusError as e:
+            raise ValueError(
+                f"Gemini API error: {e.response.status_code} - "
+                f"{e.response.text}"
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to list Gemini models: {e}")