causaliq-knowledge 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,220 @@
+ """Abstract base class for LLM clients.
+
+ This module defines the common interface that all LLM vendor clients
+ must implement. This provides a consistent API regardless of the
+ underlying LLM provider.
+ """
+
+ import json
+ import logging
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class LLMConfig:
+     """Base configuration for all LLM clients.
+
+     This dataclass defines common configuration options shared by all
+     LLM provider clients. Vendor-specific clients may extend this with
+     additional options.
+
+     Attributes:
+         model: Model identifier (provider-specific format).
+         temperature: Sampling temperature (0.0=deterministic, 1.0=creative).
+         max_tokens: Maximum tokens in the response.
+         timeout: Request timeout in seconds.
+         api_key: API key for authentication (optional, can use env var).
+     """
+
+     model: str
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 30.0
+     api_key: Optional[str] = None
+
+
+ @dataclass
+ class LLMResponse:
+     """Standard response from any LLM client.
+
+     This dataclass provides a unified response format across all LLM providers,
+     abstracting away provider-specific response structures.
+
+     Attributes:
+         content: The text content of the response.
+         model: The model that generated the response.
+         input_tokens: Number of input/prompt tokens used.
+         output_tokens: Number of output/completion tokens generated.
+         cost: Estimated cost of the request (if available).
+         raw_response: The original provider-specific response (for debugging).
+     """
+
+     content: str
+     model: str
+     input_tokens: int = 0
+     output_tokens: int = 0
+     cost: float = 0.0
+     raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+     def parse_json(self) -> Optional[Dict[str, Any]]:
+         """Parse content as JSON, handling common formatting issues.
+
+         LLMs sometimes wrap JSON in markdown code blocks. This method
+         handles those cases and attempts to extract valid JSON.
+
+         Returns:
+             Parsed JSON as dict, or None if parsing fails.
+         """
+         try:
+             # Clean up potential markdown code blocks
+             text = self.content.strip()
+             if text.startswith("```json"):
+                 text = text[7:]
+             elif text.startswith("```"):
+                 text = text[3:]
+             if text.endswith("```"):
+                 text = text[:-3]
+
+             return json.loads(text.strip())  # type: ignore[no-any-return]
+         except json.JSONDecodeError as e:
+             logger.warning(f"Failed to parse JSON response: {e}")
+             return None
+
+
+ class BaseLLMClient(ABC):
+     """Abstract base class for LLM clients.
+
+     All LLM vendor clients (OpenAI, Anthropic, Groq, Gemini, Llama, etc.)
+     must implement this interface to ensure consistent behavior across
+     the codebase.
+
+     This abstraction allows:
+     - Easy addition of new LLM providers
+     - Consistent API for all providers
+     - Provider-agnostic code in higher-level modules
+     - Simplified testing with mock implementations
+
+     Example:
+         >>> class MyClient(BaseLLMClient):
+         ...     def completion(self, messages, **kwargs):
+         ...         # Implementation here
+         ...         pass
+         ...
+         >>> client = MyClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     @abstractmethod
+     def __init__(self, config: LLMConfig) -> None:
+         """Initialize the client with configuration.
+
+         Args:
+             config: Configuration for the LLM client.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def provider_name(self) -> str:
+         """Return the name of the LLM provider.
+
+         Returns:
+             Provider name (e.g., "openai", "anthropic", "groq").
+         """
+         pass
+
+     @abstractmethod
+     def completion(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> LLMResponse:
+         """Make a chat completion request.
+
+         This is the core method that sends a request to the LLM provider
+         and returns a standardized response.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+                 Roles can be: "system", "user", "assistant".
+             **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                 that override the config defaults.
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+
+         Raises:
+             ValueError: If the API request fails or returns an error.
+         """
+         pass
+
+     def complete_json(
+         self, messages: List[Dict[str, str]], **kwargs: Any
+     ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+         """Make a completion request and parse response as JSON.
+
+         Convenience method that calls completion() and attempts to parse
+         the response content as JSON.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Provider-specific options passed to completion().
+
+         Returns:
+             Tuple of (parsed JSON dict or None, raw LLMResponse).
+         """
+         response = self.completion(messages, **kwargs)
+         parsed = response.parse_json()
+         return parsed, response
+
+     @property
+     @abstractmethod
+     def call_count(self) -> int:
+         """Return the number of API calls made by this client.
+
+         Returns:
+             Total number of completion calls made.
+         """
+         pass
+
+     @abstractmethod
+     def is_available(self) -> bool:
+         """Check if the LLM provider is available and configured.
+
+         This method checks whether the client can make API calls:
+         - For cloud providers: checks if API key is set
+         - For local providers: checks if server is running
+
+         Returns:
+             True if the provider is available and ready for requests.
+         """
+         pass
+
+     @abstractmethod
+     def list_models(self) -> List[str]:
+         """List available models from the provider.
+
+         Queries the provider's API to get the list of models accessible
+         with the current API key or configuration. Results are filtered
+         by the user's subscription/access level.
+
+         Returns:
+             List of model identifiers available for use.
+
+         Raises:
+             ValueError: If the API request fails.
+         """
+         pass
+
+     @property
+     def model_name(self) -> str:
+         """Return the model name being used.
+
+         Returns:
+             Model identifier string.
+         """
+         return getattr(self, "config", LLMConfig(model="unknown")).model
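The class docstring above lists "Simplified testing with mock implementations" as one goal of the abstraction. The sketch below shows one way a test double could satisfy the interface; MockLLMClient and its canned JSON payload are illustrative names only, not part of this package, and the snippet assumes the wheel is installed so that causaliq_knowledge.llm.base_client is importable.

from typing import Any, Dict, List

from causaliq_knowledge.llm.base_client import (
    BaseLLMClient,
    LLMConfig,
    LLMResponse,
)


class MockLLMClient(BaseLLMClient):
    """Hypothetical in-memory client that returns a canned JSON reply."""

    def __init__(self, config: LLMConfig) -> None:
        self.config = config
        self._calls = 0

    @property
    def provider_name(self) -> str:
        return "mock"

    def completion(
        self, messages: List[Dict[str, str]], **kwargs: Any
    ) -> LLMResponse:
        self._calls += 1
        return LLMResponse(content='{"answer": "yes"}', model=self.config.model)

    @property
    def call_count(self) -> int:
        return self._calls

    def is_available(self) -> bool:
        return True

    def list_models(self) -> List[str]:
        return [self.config.model]


# The inherited complete_json() helper then works unchanged:
client = MockLLMClient(LLMConfig(model="mock-model"))
parsed, raw = client.complete_json([{"role": "user", "content": "Hi"}])
print(parsed, client.call_count)  # {'answer': 'yes'} 1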
@@ -0,0 +1,108 @@
+ """Direct DeepSeek API client - OpenAI-compatible API."""
+
+ import logging
+ import os
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ from causaliq_knowledge.llm.openai_compat_client import (
+     OpenAICompatClient,
+     OpenAICompatConfig,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class DeepSeekConfig(OpenAICompatConfig):
+     """Configuration for DeepSeek API client.
+
+     Extends OpenAICompatConfig with DeepSeek-specific defaults.
+
+     Attributes:
+         model: DeepSeek model identifier (default: deepseek-chat).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 30.0).
+         api_key: DeepSeek API key (falls back to DEEPSEEK_API_KEY env var).
+     """
+
+     model: str = "deepseek-chat"
+     temperature: float = 0.1
+     max_tokens: int = 500
+     timeout: float = 30.0
+     api_key: Optional[str] = None
+
+     def __post_init__(self) -> None:
+         """Set API key from environment if not provided."""
+         if self.api_key is None:
+             self.api_key = os.getenv("DEEPSEEK_API_KEY")
+         if not self.api_key:
+             raise ValueError(
+                 "DEEPSEEK_API_KEY environment variable is required"
+             )
+
+
+ class DeepSeekClient(OpenAICompatClient):
+     """Direct DeepSeek API client.
+
+     DeepSeek uses an OpenAI-compatible API, making integration straightforward.
+     Known for excellent reasoning capabilities (R1) at low cost.
+
+     Available models:
+     - deepseek-chat: General purpose (DeepSeek-V3)
+     - deepseek-reasoner: Advanced reasoning (DeepSeek-R1)
+
+     Example:
+         >>> config = DeepSeekConfig(model="deepseek-chat")
+         >>> client = DeepSeekClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """
+
+     BASE_URL = "https://api.deepseek.com"
+     PROVIDER_NAME = "deepseek"
+     ENV_VAR = "DEEPSEEK_API_KEY"
+
+     def __init__(self, config: Optional[DeepSeekConfig] = None) -> None:
+         """Initialize DeepSeek client.
+
+         Args:
+             config: DeepSeek configuration. If None, uses defaults with
+                 API key from DEEPSEEK_API_KEY environment variable.
+         """
+         super().__init__(config)
+
+     def _default_config(self) -> DeepSeekConfig:
+         """Return default DeepSeek configuration."""
+         return DeepSeekConfig()
+
+     def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+         """Return DeepSeek pricing per 1M tokens.
+
+         Returns:
+             Dict mapping model prefixes to input/output costs.
+         """
+         # DeepSeek pricing as of Jan 2025
+         # Note: Cache hits are much cheaper but we use regular pricing
+         return {
+             "deepseek-reasoner": {"input": 0.55, "output": 2.19},
+             "deepseek-chat": {"input": 0.14, "output": 0.28},
+         }
+
+     def _filter_models(self, models: List[str]) -> List[str]:
+         """Filter to DeepSeek chat models only.
+
+         Args:
+             models: List of all model IDs from API.
+
+         Returns:
+             Filtered list of DeepSeek models.
+         """
+         filtered = []
+         for model_id in models:
+             # Include deepseek chat and reasoner models
+             if model_id.startswith("deepseek-"):
+                 filtered.append(model_id)
+         return filtered
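The _get_pricing() table is expressed per one million tokens. As a rough worked example (assuming the parent OpenAICompatClient applies these rates linearly to the token counts, which is not shown in this diff), a deepseek-chat call with 1,200 input and 350 output tokens would be estimated at:

# Hypothetical cost estimate using the per-1M-token rates above.
rates = {"input": 0.14, "output": 0.28}  # deepseek-chat, Jan 2025 pricing
input_tokens, output_tokens = 1200, 350
cost = (
    input_tokens / 1_000_000 * rates["input"]
    + output_tokens / 1_000_000 * rates["output"]
)
print(f"${cost:.6f}")  # $0.000266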
@@ -1,6 +1,5 @@
  """Direct Google Gemini API client - clean and reliable."""

- import json
  import logging
  import os
  from dataclasses import dataclass
@@ -8,12 +7,28 @@ from typing import Any, Dict, List, Optional

  import httpx

+ from causaliq_knowledge.llm.base_client import (
+     BaseLLMClient,
+     LLMConfig,
+     LLMResponse,
+ )
+
  logger = logging.getLogger(__name__)


  @dataclass
- class GeminiConfig:
-     """Configuration for Gemini API client."""
+ class GeminiConfig(LLMConfig):
+     """Configuration for Gemini API client.
+
+     Extends LLMConfig with Gemini-specific defaults.
+
+     Attributes:
+         model: Gemini model identifier (default: gemini-2.5-flash).
+         temperature: Sampling temperature (default: 0.1).
+         max_tokens: Maximum response tokens (default: 500).
+         timeout: Request timeout in seconds (default: 30.0).
+         api_key: Gemini API key (falls back to GEMINI_API_KEY env var).
+     """

      model: str = "gemini-2.5-flash"
      temperature: float = 0.1
@@ -29,48 +44,52 @@ class GeminiConfig:
              raise ValueError("GEMINI_API_KEY environment variable is required")


- @dataclass
- class GeminiResponse:
-     """Response from Gemini API."""
-
-     content: str
-     model: str
-     input_tokens: int = 0
-     output_tokens: int = 0
-     cost: float = 0.0  # Gemini free tier
-     raw_response: Optional[Dict] = None
-
-     def parse_json(self) -> Optional[Dict[str, Any]]:
-         """Parse content as JSON, handling common formatting issues."""
-         try:
-             # Clean up potential markdown code blocks
-             text = self.content.strip()
-             if text.startswith("```json"):
-                 text = text[7:]
-             elif text.startswith("```"):
-                 text = text[3:]
-             if text.endswith("```"):
-                 text = text[:-3]
-
-             return json.loads(text.strip())  # type: ignore[no-any-return]
-         except json.JSONDecodeError:
-             return None
+ class GeminiClient(BaseLLMClient):
+     """Direct Gemini API client.

+     Implements the BaseLLMClient interface for Google's Gemini API.
+     Uses httpx for HTTP requests.

- class GeminiClient:
-     """Direct Gemini API client."""
+     Example:
+         >>> config = GeminiConfig(model="gemini-2.5-flash")
+         >>> client = GeminiClient(config)
+         >>> msgs = [{"role": "user", "content": "Hello"}]
+         >>> response = client.completion(msgs)
+         >>> print(response.content)
+     """

      BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"

-     def __init__(self, config: Optional[GeminiConfig] = None):
-         """Initialize Gemini client."""
+     def __init__(self, config: Optional[GeminiConfig] = None) -> None:
+         """Initialize Gemini client.
+
+         Args:
+             config: Gemini configuration. If None, uses defaults with
+                 API key from GEMINI_API_KEY environment variable.
+         """
          self.config = config or GeminiConfig()
          self._total_calls = 0

+     @property
+     def provider_name(self) -> str:
+         """Return the provider name."""
+         return "gemini"
+
      def completion(
          self, messages: List[Dict[str, str]], **kwargs: Any
-     ) -> GeminiResponse:
-         """Make a chat completion request to Gemini."""
+     ) -> LLMResponse:
+         """Make a chat completion request to Gemini.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options (temperature, max_tokens).
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+
+         Raises:
+             ValueError: If the API request fails.
+         """

          # Convert OpenAI-style messages to Gemini format
          contents = []
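The rest of completion() is largely unchanged by this release (only its return type changes, in a later hunk). For orientation only: Gemini's generateContent endpoint expects each turn as an entry with a role ("user" or "model") and a parts list, so the conversion started above would build a request body roughly shaped like the sketch below; the exact conversion code in this package is outside the diff.

# Rough shape of a Gemini generateContent request body (illustrative only).
payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hello"}]},
    ],
    "generationConfig": {
        "temperature": 0.1,
        "maxOutputTokens": 500,
    },
}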
@@ -158,7 +177,7 @@ class GeminiClient:
              f"Gemini response: {input_tokens} in, {output_tokens} out"
          )

-         return GeminiResponse(
+         return LLMResponse(
              content=content,
              model=self.config.model,
              input_tokens=input_tokens,
@@ -191,13 +210,72 @@ class GeminiClient:

      def complete_json(
          self, messages: List[Dict[str, str]], **kwargs: Any
-     ) -> tuple[Optional[Dict[str, Any]], GeminiResponse]:
-         """Make a completion request and parse response as JSON."""
+     ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+         """Make a completion request and parse response as JSON.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Override config options passed to completion().
+
+         Returns:
+             Tuple of (parsed JSON dict or None, raw LLMResponse).
+         """
          response = self.completion(messages, **kwargs)
          parsed = response.parse_json()
          return parsed, response

      @property
      def call_count(self) -> int:
-         """Number of API calls made."""
+         """Return the number of API calls made."""
          return self._total_calls
+
+     def is_available(self) -> bool:
+         """Check if Gemini API is available.
+
+         Returns:
+             True if GEMINI_API_KEY is configured.
+         """
+         return bool(self.config.api_key)
+
+     def list_models(self) -> List[str]:
+         """List available models from Gemini API.
+
+         Queries the Gemini API to get models accessible with the current
+         API key. Filters to only include models that support generateContent.
+
+         Returns:
+             List of model identifiers (e.g., ['gemini-2.5-flash', ...]).
+
+         Raises:
+             ValueError: If the API request fails.
+         """
+         try:
+             with httpx.Client(timeout=self.config.timeout) as client:
+                 response = client.get(
+                     f"{self.BASE_URL}?key={self.config.api_key}",
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+             # Filter to models that support text generation
+             models = []
+             for model in data.get("models", []):
+                 methods = model.get("supportedGenerationMethods", [])
+                 if "generateContent" not in methods:
+                     continue
+                 # Extract model name (remove 'models/' prefix)
+                 name = model.get("name", "").replace("models/", "")
+                 # Skip embedding and TTS models
+                 if any(x in name.lower() for x in ["embed", "tts", "aqa"]):
+                     continue
+                 models.append(name)
+
+             return sorted(models)
+
+         except httpx.HTTPStatusError as e:
+             raise ValueError(
+                 f"Gemini API error: {e.response.status_code} - "
+                 f"{e.response.text}"
+             )
+         except Exception as e:
+             raise ValueError(f"Failed to list Gemini models: {e}")