causaliq-knowledge 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

causaliq_knowledge/llm/__init__.py

@@ -1,15 +1,23 @@
  """LLM integration module for causaliq-knowledge."""
 
- from causaliq_knowledge.llm.gemini_client import (
-     GeminiClient,
-     GeminiConfig,
-     GeminiResponse,
+ from causaliq_knowledge.llm.anthropic_client import (
+     AnthropicClient,
+     AnthropicConfig,
  )
- from causaliq_knowledge.llm.groq_client import (
-     GroqClient,
-     GroqConfig,
-     GroqResponse,
+ from causaliq_knowledge.llm.base_client import (
+     BaseLLMClient,
+     LLMConfig,
+     LLMResponse,
  )
+ from causaliq_knowledge.llm.deepseek_client import (
+     DeepSeekClient,
+     DeepSeekConfig,
+ )
+ from causaliq_knowledge.llm.gemini_client import GeminiClient, GeminiConfig
+ from causaliq_knowledge.llm.groq_client import GroqClient, GroqConfig
+ from causaliq_knowledge.llm.mistral_client import MistralClient, MistralConfig
+ from causaliq_knowledge.llm.ollama_client import OllamaClient, OllamaConfig
+ from causaliq_knowledge.llm.openai_client import OpenAIClient, OpenAIConfig
  from causaliq_knowledge.llm.prompts import EdgeQueryPrompt, parse_edge_response
  from causaliq_knowledge.llm.provider import (
      CONSENSUS_STRATEGIES,
@@ -19,14 +27,35 @@ from causaliq_knowledge.llm.provider import (
  )
 
  __all__ = [
+     # Abstract base
+     "BaseLLMClient",
+     "LLMConfig",
+     "LLMResponse",
+     # Anthropic
+     "AnthropicClient",
+     "AnthropicConfig",
+     # Consensus
      "CONSENSUS_STRATEGIES",
+     # DeepSeek
+     "DeepSeekClient",
+     "DeepSeekConfig",
      "EdgeQueryPrompt",
+     # Gemini
      "GeminiClient",
      "GeminiConfig",
-     "GeminiResponse",
+     # Groq
      "GroqClient",
      "GroqConfig",
-     "GroqResponse",
+     # Mistral
+     "MistralClient",
+     "MistralConfig",
+     # Ollama (local)
+     "OllamaClient",
+     "OllamaConfig",
+     # OpenAI
+     "OpenAIClient",
+     "OpenAIConfig",
+     # Provider
      "LLMKnowledge",
      "highest_confidence",
      "parse_edge_response",

causaliq_knowledge/llm/anthropic_client.py (new file)

@@ -0,0 +1,256 @@
+ """Direct Anthropic API client - clean and reliable."""
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+ import httpx
+
+ from causaliq_knowledge.llm.base_client import (
+     BaseLLMClient,
+     LLMConfig,
+     LLMResponse,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class AnthropicConfig(LLMConfig):
+     """Configuration for Anthropic API client.
+
+     Extends LLMConfig with Anthropic-specific defaults.
+
+     Attributes:
+         model: Anthropic model identifier (default: claude-sonnet-4-20250514).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 30.0).
+         api_key: Anthropic API key (falls back to ANTHROPIC_API_KEY env var).
+     """
+
+     model: str = "claude-sonnet-4-20250514"
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 30.0
+     api_key: Optional[str] = None
+
+     def __post_init__(self) -> None:
+         """Set API key from environment if not provided."""
+         if self.api_key is None:
+             self.api_key = os.getenv("ANTHROPIC_API_KEY")
+         if not self.api_key:
+             raise ValueError(
+                 "ANTHROPIC_API_KEY environment variable is required"
+             )
+
+
+ class AnthropicClient(BaseLLMClient):
+     """Direct Anthropic API client.
+
+     Implements the BaseLLMClient interface for Anthropic's Claude API.
+     Uses httpx for HTTP requests.
+
+     Example:
+         >>> config = AnthropicConfig(model="claude-sonnet-4-20250514")
+         >>> client = AnthropicClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     BASE_URL = "https://api.anthropic.com/v1"
+     API_VERSION = "2023-06-01"
+
+     def __init__(self, config: Optional[AnthropicConfig] = None) -> None:
+         """Initialize Anthropic client.
+
+         Args:
+             config: Anthropic configuration. If None, uses defaults with
+                 API key from ANTHROPIC_API_KEY environment variable.
+         """
+         self.config = config or AnthropicConfig()
+         self._total_calls = 0
+
+     @property
+     def provider_name(self) -> str:
+         """Return the provider name."""
+         return "anthropic"
+
+     def completion(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> LLMResponse:
+         """Make a chat completion request to Anthropic.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options (temperature, max_tokens).
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+
+         Raises:
+             ValueError: If the API request fails.
+         """
+         # Anthropic uses separate system parameter, not in messages
+         system_content = None
+         filtered_messages = []
+
+         for msg in messages:
+             if msg["role"] == "system":
+                 system_content = msg["content"]
+             else:
+                 filtered_messages.append(msg)
+
+         # Build request payload in Anthropic's format
+         payload: Dict[str, Any] = {
+             "model": self.config.model,
+             "messages": filtered_messages,
+             "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
+             "temperature": kwargs.get("temperature", self.config.temperature),
+         }
+
+         # Add system prompt if present
+         if system_content:
+             payload["system"] = system_content
+
+         # api_key is guaranteed non-None after __post_init__ validation
+         headers: dict[str, str] = {
+             "x-api-key": self.config.api_key,  # type: ignore[dict-item]
+             "anthropic-version": self.API_VERSION,
+             "Content-Type": "application/json",
+         }
+
+         logger.debug(f"Calling Anthropic API with model: {self.config.model}")
+
+         try:
+             with httpx.Client(timeout=self.config.timeout) as client:
+                 response = client.post(
+                     f"{self.BASE_URL}/messages",
+                     json=payload,
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+
+             data = response.json()
+
+             # Extract response content from Anthropic format
+             content_blocks = data.get("content", [])
+             content = ""
+             for block in content_blocks:
+                 if block.get("type") == "text":
+                     content += block.get("text", "")
+
+             # Extract usage info
+             usage = data.get("usage", {})
+             input_tokens = usage.get("input_tokens", 0)
+             output_tokens = usage.get("output_tokens", 0)
+
+             self._total_calls += 1
+
+             logger.debug(
+                 f"Anthropic response: {input_tokens} in, "
+                 f"{output_tokens} out"
+             )
+
+             return LLMResponse(
+                 content=content,
+                 model=data.get("model", self.config.model),
+                 input_tokens=input_tokens,
+                 output_tokens=output_tokens,
+                 cost=0.0,  # Cost calculation not implemented
+                 raw_response=data,
+             )
+
+         except httpx.HTTPStatusError as e:
+             try:
+                 error_data = e.response.json()
+                 error_msg = error_data.get("error", {}).get(
+                     "message", e.response.text
+                 )
+             except Exception:
+                 error_msg = e.response.text
+
+             logger.error(
+                 f"Anthropic API HTTP error: {e.response.status_code} - "
+                 f"{error_msg}"
+             )
+             raise ValueError(
+                 f"Anthropic API error: {e.response.status_code} - {error_msg}"
+             )
+         except httpx.TimeoutException:
+             raise ValueError("Anthropic API request timed out")
+         except Exception as e:
+             logger.error(f"Anthropic API unexpected error: {e}")
+             raise ValueError(f"Anthropic API error: {str(e)}")
+
+     def complete_json(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+         """Make a completion request and parse response as JSON.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options passed to completion().
+
+         Returns:
+             Tuple of (parsed JSON dict or None, raw LLMResponse).
+         """
+         response = self.completion(messages, **kwargs)
+         parsed = response.parse_json()
+         return parsed, response
+
+     @property
+     def call_count(self) -> int:
+         """Return the number of API calls made."""
+         return self._total_calls
+
+     def is_available(self) -> bool:
+         """Check if Anthropic API is available.
+
+         Returns:
+             True if ANTHROPIC_API_KEY is configured.
+         """
+         return bool(self.config.api_key)
+
+     def list_models(self) -> List[str]:
+         """List available Claude models from Anthropic API.
+
+         Queries the Anthropic /v1/models endpoint to get available models.
+
+         Returns:
+             List of model identifiers
+             (e.g., ['claude-sonnet-4-20250514', ...]).
+         """
+         if not self.config.api_key:
+             return []
+
+         headers: dict[str, str] = {
+             "x-api-key": self.config.api_key,
+             "anthropic-version": self.API_VERSION,
+         }
+
+         try:
+             with httpx.Client(timeout=self.config.timeout) as client:
+                 response = client.get(
+                     f"{self.BASE_URL}/models",
+                     headers=headers,
+                 )
+                 response.raise_for_status()
+
+             data = response.json()
+             models = []
+             for model_info in data.get("data", []):
+                 model_id = model_info.get("id")
+                 if model_id:
+                     models.append(model_id)
+
+             return sorted(models)
+
+         except httpx.HTTPStatusError as e:
+             logger.warning(f"Anthropic API error listing models: {e}")
+             return []
+         except Exception as e:
+             logger.warning(f"Error listing Anthropic models: {e}")
+             return []
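
For orientation, a usage sketch of the client added above, adapted from its docstring example. It assumes ANTHROPIC_API_KEY is set in the environment; the prompt and parameter values are illustrative only:

from causaliq_knowledge.llm import AnthropicClient, AnthropicConfig

# AnthropicConfig raises ValueError if no API key is provided or found
# in the ANTHROPIC_API_KEY environment variable.
config = AnthropicConfig(temperature=0.0, max_tokens=200)
client = AnthropicClient(config)

messages = [
    {"role": "system", "content": "Reply with a single JSON object."},
    {"role": "user", "content": "Is smoking a cause of lung cancer?"},
]

# complete_json() calls completion() and then parse_json() on the result,
# stripping any ```json fences before decoding.
parsed, response = client.complete_json(messages)
print(response.model, response.input_tokens, response.output_tokens)
print(parsed)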

causaliq_knowledge/llm/base_client.py (new file)

@@ -0,0 +1,360 @@
+ """Abstract base class for LLM clients.
+
+ This module defines the common interface that all LLM vendor clients
+ must implement. This provides a consistent API regardless of the
+ underlying LLM provider.
+ """
+
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ import logging
+ import time
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ if TYPE_CHECKING:  # pragma: no cover
+     from causaliq_knowledge.cache import TokenCache
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class LLMConfig:
+     """Base configuration for all LLM clients.
+
+     This dataclass defines common configuration options shared by all
+     LLM provider clients. Vendor-specific clients may extend this with
+     additional options.
+
+     Attributes:
+         model: Model identifier (provider-specific format).
+         temperature: Sampling temperature (0.0=deterministic, 1.0=creative).
+         max_tokens: Maximum tokens in the response.
+         timeout: Request timeout in seconds.
+         api_key: API key for authentication (optional, can use env var).
+     """
+
+     model: str
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 30.0
+     api_key: Optional[str] = None
+
+
+ @dataclass
+ class LLMResponse:
+     """Standard response from any LLM client.
+
+     This dataclass provides a unified response format across all LLM providers,
+     abstracting away provider-specific response structures.
+
+     Attributes:
+         content: The text content of the response.
+         model: The model that generated the response.
+         input_tokens: Number of input/prompt tokens used.
+         output_tokens: Number of output/completion tokens generated.
+         cost: Estimated cost of the request (if available).
+         raw_response: The original provider-specific response (for debugging).
+     """
+
+     content: str
+     model: str
+     input_tokens: int = 0
+     output_tokens: int = 0
+     cost: float = 0.0
+     raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+     def parse_json(self) -> Optional[Dict[str, Any]]:
+         """Parse content as JSON, handling common formatting issues.
+
+         LLMs sometimes wrap JSON in markdown code blocks. This method
+         handles those cases and attempts to extract valid JSON.
+
+         Returns:
+             Parsed JSON as dict, or None if parsing fails.
+         """
+         try:
+             # Clean up potential markdown code blocks
+             text = self.content.strip()
+             if text.startswith("```json"):
+                 text = text[7:]
+             elif text.startswith("```"):
+                 text = text[3:]
+             if text.endswith("```"):
+                 text = text[:-3]
+
+             return json.loads(text.strip())  # type: ignore[no-any-return]
+         except json.JSONDecodeError as e:
+             logger.warning(f"Failed to parse JSON response: {e}")
+             return None
+
+
+ class BaseLLMClient(ABC):
+     """Abstract base class for LLM clients.
+
+     All LLM vendor clients (OpenAI, Anthropic, Groq, Gemini, Llama, etc.)
+     must implement this interface to ensure consistent behavior across
+     the codebase.
+
+     This abstraction allows:
+     - Easy addition of new LLM providers
+     - Consistent API for all providers
+     - Provider-agnostic code in higher-level modules
+     - Simplified testing with mock implementations
+
+     Example:
+         >>> class MyClient(BaseLLMClient):
+         ...     def completion(self, messages, **kwargs):
+         ...         # Implementation here
+         ...         pass
+         ...
+         >>> client = MyClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     @abstractmethod
+     def __init__(self, config: LLMConfig) -> None:
+         """Initialize the client with configuration.
+
+         Args:
+             config: Configuration for the LLM client.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def provider_name(self) -> str:
+         """Return the name of the LLM provider.
+
+         Returns:
+             Provider name (e.g., "openai", "anthropic", "groq").
+         """
+         pass
+
+     @abstractmethod
+     def completion(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> LLMResponse:
+         """Make a chat completion request.
+
+         This is the core method that sends a request to the LLM provider
+         and returns a standardized response.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+                 Roles can be: "system", "user", "assistant".
+             **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                 that override the config defaults.
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+
+         Raises:
+             ValueError: If the API request fails or returns an error.
+         """
+         pass
+
+     def complete_json(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+         """Make a completion request and parse response as JSON.
+
+         Convenience method that calls completion() and attempts to parse
+         the response content as JSON.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Provider-specific options passed to completion().
+
+         Returns:
+             Tuple of (parsed JSON dict or None, raw LLMResponse).
+         """
+         response = self.completion(messages, **kwargs)
+         parsed = response.parse_json()
+         return parsed, response
+
+     @property
+     @abstractmethod
+     def call_count(self) -> int:
+         """Return the number of API calls made by this client.
+
+         Returns:
+             Total number of completion calls made.
+         """
+         pass
+
+     @abstractmethod
+     def is_available(self) -> bool:
+         """Check if the LLM provider is available and configured.
+
+         This method checks whether the client can make API calls:
+         - For cloud providers: checks if API key is set
+         - For local providers: checks if server is running
+
+         Returns:
+             True if the provider is available and ready for requests.
+         """
+         pass
+
+     @abstractmethod
+     def list_models(self) -> List[str]:
+         """List available models from the provider.
+
+         Queries the provider's API to get the list of models accessible
+         with the current API key or configuration. Results are filtered
+         by the user's subscription/access level.
+
+         Returns:
+             List of model identifiers available for use.
+
+         Raises:
+             ValueError: If the API request fails.
+         """
+         pass
+
+     @property
+     def model_name(self) -> str:
+         """Return the model name being used.
+
+         Returns:
+             Model identifier string.
+         """
+         return getattr(self, "config", LLMConfig(model="unknown")).model
+
+     def _build_cache_key(
+         self,
+         messages: List[Dict[str, str]],
+         temperature: Optional[float] = None,
+         max_tokens: Optional[int] = None,
+     ) -> str:
+         """Build a deterministic cache key for the request.
+
+         Creates a SHA-256 hash from the model, messages, temperature, and
+         max_tokens. The hash is truncated to 16 hex characters (64 bits).
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             temperature: Sampling temperature (defaults to config value).
+             max_tokens: Maximum tokens (defaults to config value).
+
+         Returns:
+             16-character hex string cache key.
+         """
+         config = getattr(self, "config", LLMConfig(model="unknown"))
+         key_data = {
+             "model": config.model,
+             "messages": messages,
+             "temperature": (
+                 temperature if temperature is not None else config.temperature
+             ),
+             "max_tokens": (
+                 max_tokens if max_tokens is not None else config.max_tokens
+             ),
+         }
+         key_json = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
+         return hashlib.sha256(key_json.encode()).hexdigest()[:16]
+
+     def set_cache(
+         self,
+         cache: Optional["TokenCache"],
+         use_cache: bool = True,
+     ) -> None:
+         """Configure caching for this client.
+
+         Args:
+             cache: TokenCache instance for caching, or None to disable.
+             use_cache: Whether to use the cache (default True).
+         """
+         self._cache = cache
+         self._use_cache = use_cache
+
+     @property
+     def cache(self) -> Optional["TokenCache"]:
+         """Return the configured cache, if any."""
+         return getattr(self, "_cache", None)
+
+     @property
+     def use_cache(self) -> bool:
+         """Return whether caching is enabled."""
+         return getattr(self, "_use_cache", True)
+
+     def cached_completion(
+         self,
+         messages: List[Dict[str, str]],
+         **kwargs: Any,
+     ) -> LLMResponse:
+         """Make a completion request with caching.
+
+         If caching is enabled and a cached response exists, returns
+         the cached response without making an API call. Otherwise,
+         makes the API call and caches the result.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+         """
+         from causaliq_knowledge.llm.cache import LLMCacheEntry, LLMEntryEncoder
+
+         cache = self.cache
+         use_cache = self.use_cache
+
+         # Build cache key
+         temperature = kwargs.get("temperature")
+         max_tokens = kwargs.get("max_tokens")
+         cache_key = self._build_cache_key(messages, temperature, max_tokens)
+
+         # Check cache
+         if use_cache and cache is not None:
+             # Ensure encoder is registered
+             if not cache.has_encoder("llm"):
+                 cache.register_encoder("llm", LLMEntryEncoder())
+
+             if cache.exists(cache_key, "llm"):
+                 cached_data = cache.get_data(cache_key, "llm")
+                 if cached_data is not None:
+                     entry = LLMCacheEntry.from_dict(cached_data)
+                     return LLMResponse(
+                         content=entry.response.content,
+                         model=entry.model,
+                         input_tokens=entry.metadata.tokens.input,
+                         output_tokens=entry.metadata.tokens.output,
+                         cost=entry.metadata.cost_usd or 0.0,
+                     )
+
+         # Make API call with timing
+         start_time = time.perf_counter()
+         response = self.completion(messages, **kwargs)
+         latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+         # Store in cache
+         if use_cache and cache is not None:
+             config = getattr(self, "config", LLMConfig(model="unknown"))
+             entry = LLMCacheEntry.create(
+                 model=config.model,
+                 messages=messages,
+                 content=response.content,
+                 temperature=(
+                     temperature
+                     if temperature is not None
+                     else config.temperature
+                 ),
+                 max_tokens=(
+                     max_tokens if max_tokens is not None else config.max_tokens
+                 ),
+                 provider=self.provider_name,
+                 latency_ms=latency_ms,
+                 input_tokens=response.input_tokens,
+                 output_tokens=response.output_tokens,
+                 cost_usd=response.cost,
+             )
+             cache.put_data(cache_key, "llm", entry.to_dict())
+
+         return response
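
To make the contract concrete, here is a stub implementation of the interface above (a hypothetical EchoClient, the kind of mock the class docstring mentions for testing). Only the abstract members are implemented; inherited helpers such as complete_json() and, once a TokenCache is attached via set_cache(), cached_completion() would work on such a subclass unchanged:

from typing import Any, Dict, List

from causaliq_knowledge.llm import BaseLLMClient, LLMConfig, LLMResponse


class EchoClient(BaseLLMClient):
    """Toy client that echoes the last user message back."""

    def __init__(self, config: LLMConfig) -> None:
        self.config = config
        self._calls = 0

    @property
    def provider_name(self) -> str:
        return "echo"

    def completion(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> LLMResponse:
        self._calls += 1
        return LLMResponse(
            content=messages[-1]["content"], model=self.config.model
        )

    @property
    def call_count(self) -> int:
        return self._calls

    def is_available(self) -> bool:
        return True

    def list_models(self) -> List[str]:
        return [self.config.model]


client = EchoClient(LLMConfig(model="echo-1"))
reply = client.completion([{"role": "user", "content": '{"ok": true}'}])
print(reply.parse_json())  # -> {'ok': True}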