causaliq-knowledge 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +3 -3
- causaliq_knowledge/cache/__init__.py +18 -0
- causaliq_knowledge/cache/encoders/__init__.py +13 -0
- causaliq_knowledge/cache/encoders/base.py +90 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +418 -0
- causaliq_knowledge/cache/token_cache.py +632 -0
- causaliq_knowledge/cli.py +588 -38
- causaliq_knowledge/llm/__init__.py +39 -10
- causaliq_knowledge/llm/anthropic_client.py +256 -0
- causaliq_knowledge/llm/base_client.py +360 -0
- causaliq_knowledge/llm/cache.py +380 -0
- causaliq_knowledge/llm/deepseek_client.py +108 -0
- causaliq_knowledge/llm/gemini_client.py +117 -39
- causaliq_knowledge/llm/groq_client.py +115 -40
- causaliq_knowledge/llm/mistral_client.py +122 -0
- causaliq_knowledge/llm/ollama_client.py +240 -0
- causaliq_knowledge/llm/openai_client.py +115 -0
- causaliq_knowledge/llm/openai_compat_client.py +287 -0
- causaliq_knowledge/llm/provider.py +99 -46
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/METADATA +9 -10
- causaliq_knowledge-0.3.0.dist-info/RECORD +28 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/WHEEL +1 -1
- causaliq_knowledge-0.1.0.dist-info/RECORD +0 -15
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/entry_points.txt +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.1.0.dist-info → causaliq_knowledge-0.3.0.dist-info}/top_level.txt +0 -0
causaliq_knowledge/llm/groq_client.py
@@ -1,6 +1,5 @@
 """Direct Groq API client - clean and reliable."""
 
-import json
 import logging
 import os
 from dataclasses import dataclass
@@ -8,12 +7,28 @@ from typing import Any, Dict, List, Optional
 
 import httpx
 
+from causaliq_knowledge.llm.base_client import (
+    BaseLLMClient,
+    LLMConfig,
+    LLMResponse,
+)
+
 logger = logging.getLogger(__name__)
 
 
 @dataclass
-class GroqConfig:
-    """Configuration for Groq API client.
+class GroqConfig(LLMConfig):
+    """Configuration for Groq API client.
+
+    Extends LLMConfig with Groq-specific defaults.
+
+    Attributes:
+        model: Groq model identifier (default: llama-3.1-8b-instant).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: Groq API key (falls back to GROQ_API_KEY env var).
+    """
 
     model: str = "llama-3.1-8b-instant"
     temperature: float = 0.1
@@ -29,50 +44,52 @@ class GroqConfig:
             raise ValueError("GROQ_API_KEY environment variable is required")
 
 
-
-
-    """Response from Groq API."""
-
-    content: str
-    model: str
-    input_tokens: int = 0
-    output_tokens: int = 0
-    cost: float = 0.0  # Groq free tier
-    raw_response: Optional[Dict] = None
-
-    def parse_json(self) -> Optional[Dict[str, Any]]:
-        """Parse content as JSON, handling common formatting issues."""
-        try:
-            # Clean up potential markdown code blocks
-            text = self.content.strip()
-            if text.startswith("```json"):
-                text = text[7:]
-            elif text.startswith("```"):
-                text = text[3:]
-            if text.endswith("```"):
-                text = text[:-3]
-
-            return json.loads(text.strip())  # type: ignore[no-any-return]
-        except json.JSONDecodeError as e:
-            logger.warning(f"Failed to parse JSON response: {e}")
-            return None
+class GroqClient(BaseLLMClient):
+    """Direct Groq API client.
 
+    Implements the BaseLLMClient interface for Groq's API.
+    Uses httpx for HTTP requests.
 
-
-
+    Example:
+        >>> config = GroqConfig(model="llama-3.1-8b-instant")
+        >>> client = GroqClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
 
     BASE_URL = "https://api.groq.com/openai/v1"
 
-    def __init__(self, config: Optional[GroqConfig] = None):
-        """Initialize Groq client.
+    def __init__(self, config: Optional[GroqConfig] = None) -> None:
+        """Initialize Groq client.
+
+        Args:
+            config: Groq configuration. If None, uses defaults with
+                API key from GROQ_API_KEY environment variable.
+        """
         self.config = config or GroqConfig()
         self._total_calls = 0
 
+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return "groq"
+
     def completion(
         self, messages: List[Dict[str, str]], **kwargs: Any
-    ) ->
-        """Make a chat completion request to Groq.
+    ) -> LLMResponse:
+        """Make a chat completion request to Groq.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options (temperature, max_tokens).
 
+        Returns:
+            LLMResponse with the generated content and metadata.
+
+        Raises:
+            ValueError: If the API request fails.
+        """
         # Build request payload
         payload = {
             "model": self.config.model,
@@ -112,7 +129,7 @@ class GroqClient:
                     f"Groq response: {input_tokens} in, {output_tokens} out"
                 )
 
-                return
+                return LLMResponse(
                     content=content,
                     model=data.get("model", self.config.model),
                     input_tokens=input_tokens,
@@ -136,13 +153,71 @@ class GroqClient:
 
     def complete_json(
         self, messages: List[Dict[str, str]], **kwargs: Any
-    ) -> tuple[Optional[Dict[str, Any]],
-        """Make a completion request and parse response as JSON.
+    ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+        """Make a completion request and parse response as JSON.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options passed to completion().
+
+        Returns:
+            Tuple of (parsed JSON dict or None, raw LLMResponse).
+        """
         response = self.completion(messages, **kwargs)
         parsed = response.parse_json()
        return parsed, response
 
     @property
     def call_count(self) -> int:
-        """
+        """Return the number of API calls made."""
         return self._total_calls
+
+    def is_available(self) -> bool:
+        """Check if Groq API is available.
+
+        Returns:
+            True if GROQ_API_KEY is configured.
+        """
+        return bool(self.config.api_key)
+
+    def list_models(self) -> List[str]:
+        """List available models from Groq API.
+
+        Queries the Groq API to get models accessible with the current
+        API key. Filters to only include text generation models.
+
+        Returns:
+            List of model identifiers (e.g., ['llama-3.1-8b-instant', ...]).
+
+        Raises:
+            ValueError: If the API request fails.
+        """
+        try:
+            with httpx.Client(timeout=self.config.timeout) as client:
+                response = client.get(
+                    f"{self.BASE_URL}/models",
+                    headers={"Authorization": f"Bearer {self.config.api_key}"},
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                # Filter and sort models
+                models = []
+                for model in data.get("data", []):
+                    model_id = model.get("id", "")
+                    # Skip whisper (audio), guard, and safeguard models
+                    if any(
+                        x in model_id.lower()
+                        for x in ["whisper", "guard", "embed"]
+                    ):
+                        continue
+                    models.append(model_id)
+
+                return sorted(models)
+
+        except httpx.HTTPStatusError as e:
+            raise ValueError(
+                f"Groq API error: {e.response.status_code} - {e.response.text}"
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to list Groq models: {e}")
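For orientation, here is a minimal sketch of how the refactored Groq client is meant to be called, based on the docstring example and method signatures in the diff above. The prompt text and printed fields are illustrative assumptions, and it requires the GROQ_API_KEY environment variable to be set.

```python
# Sketch of using the refactored GroqClient (causaliq-knowledge 0.3.0).
# Assumes GROQ_API_KEY is set; the prompt wording is illustrative only.
from causaliq_knowledge.llm.groq_client import GroqClient, GroqConfig

client = GroqClient(GroqConfig(model="llama-3.1-8b-instant", temperature=0.0))

if client.is_available():  # True when GROQ_API_KEY is configured
    messages = [
        {"role": "user", "content": 'Reply with JSON: {"answer": "yes"}'}
    ]
    # complete_json() returns (parsed dict or None, raw LLMResponse)
    parsed, response = client.complete_json(messages)
    print(parsed, response.input_tokens, response.output_tokens)
    print(client.list_models())  # chat models visible to this API key
```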
causaliq_knowledge/llm/mistral_client.py (new file)
@@ -0,0 +1,122 @@
+"""Direct Mistral AI API client - OpenAI-compatible API."""
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from causaliq_knowledge.llm.openai_compat_client import (
+    OpenAICompatClient,
+    OpenAICompatConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MistralConfig(OpenAICompatConfig):
+    """Configuration for Mistral AI API client.
+
+    Extends OpenAICompatConfig with Mistral-specific defaults.
+
+    Attributes:
+        model: Mistral model identifier (default: mistral-small-latest).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: Mistral API key (falls back to MISTRAL_API_KEY env var).
+    """
+
+    model: str = "mistral-small-latest"
+    temperature: float = 0.1
+    max_tokens: int = 500
+    timeout: float = 30.0
+    api_key: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        """Set API key from environment if not provided."""
+        if self.api_key is None:
+            self.api_key = os.getenv("MISTRAL_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "MISTRAL_API_KEY environment variable is required"
+            )
+
+
+class MistralClient(OpenAICompatClient):
+    """Direct Mistral AI API client.
+
+    Mistral AI is a French company providing high-quality LLMs with an
+    OpenAI-compatible API.
+
+    Available models:
+    - mistral-small-latest: Fast, cost-effective
+    - mistral-medium-latest: Balanced performance
+    - mistral-large-latest: Most capable
+    - codestral-latest: Optimized for code
+
+    Example:
+        >>> config = MistralConfig(model="mistral-small-latest")
+        >>> client = MistralClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
+
+    BASE_URL = "https://api.mistral.ai/v1"
+    PROVIDER_NAME = "mistral"
+    ENV_VAR = "MISTRAL_API_KEY"
+
+    def __init__(self, config: Optional[MistralConfig] = None) -> None:
+        """Initialize Mistral client.
+
+        Args:
+            config: Mistral configuration. If None, uses defaults with
+                API key from MISTRAL_API_KEY environment variable.
+        """
+        super().__init__(config)
+
+    def _default_config(self) -> MistralConfig:
+        """Return default Mistral configuration."""
+        return MistralConfig()
+
+    def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+        """Return Mistral pricing per 1M tokens.
+
+        Returns:
+            Dict mapping model prefixes to input/output costs.
+        """
+        # Mistral pricing as of Jan 2025
+        return {
+            "mistral-large": {"input": 2.00, "output": 6.00},
+            "mistral-medium": {"input": 2.70, "output": 8.10},
+            "mistral-small": {"input": 0.20, "output": 0.60},
+            "codestral": {"input": 0.20, "output": 0.60},
+            "open-mistral-nemo": {"input": 0.15, "output": 0.15},
+            "open-mixtral-8x22b": {"input": 2.00, "output": 6.00},
+            "open-mixtral-8x7b": {"input": 0.70, "output": 0.70},
+            "ministral-3b": {"input": 0.04, "output": 0.04},
+            "ministral-8b": {"input": 0.10, "output": 0.10},
+        }
+
+    def _filter_models(self, models: List[str]) -> List[str]:
+        """Filter to Mistral chat models only.
+
+        Args:
+            models: List of all model IDs from API.
+
+        Returns:
+            Filtered list of Mistral models.
+        """
+        filtered = []
+        for model_id in models:
+            # Include mistral and codestral models
+            if any(
+                prefix in model_id.lower()
+                for prefix in ["mistral", "codestral", "ministral", "mixtral"]
+            ):
+                # Exclude embedding models
+                if "embed" in model_id.lower():
+                    continue
+                filtered.append(model_id)
+        return filtered
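The pricing table above maps model-name prefixes to USD costs per million tokens. The cost calculation itself lives in OpenAICompatClient, which is not part of this diff, so the standalone helper below is only an illustrative sketch of how such a prefix-keyed table can be applied to reported token counts; the function name and lookup rule are assumptions, not the package's implementation.

```python
# Illustrative only: demonstrates applying a prefix-keyed per-1M-token
# pricing table like the one returned by MistralClient._get_pricing().
from typing import Dict

PRICING: Dict[str, Dict[str, float]] = {
    "mistral-large": {"input": 2.00, "output": 6.00},
    "mistral-small": {"input": 0.20, "output": 0.60},
}


def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate USD cost by matching the longest pricing prefix of `model`."""
    matches = [prefix for prefix in PRICING if model.startswith(prefix)]
    if not matches:
        return 0.0
    rates = PRICING[max(matches, key=len)]  # prefer the most specific prefix
    return (
        input_tokens * rates["input"] + output_tokens * rates["output"]
    ) / 1_000_000


# e.g. 1200 prompt tokens + 350 completion tokens on mistral-small-latest
print(f"${estimate_cost('mistral-small-latest', 1200, 350):.6f}")
```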
causaliq_knowledge/llm/ollama_client.py (new file)
@@ -0,0 +1,240 @@
+"""Local Ollama API client for running Llama models locally."""
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import httpx
+
+from causaliq_knowledge.llm.base_client import (
+    BaseLLMClient,
+    LLMConfig,
+    LLMResponse,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class OllamaConfig(LLMConfig):
+    """Configuration for Ollama API client.
+
+    Extends LLMConfig with Ollama-specific defaults.
+
+    Attributes:
+        model: Ollama model identifier (default: llama3.2:1b).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 120.0, local).
+        api_key: Not used for Ollama (local server).
+        base_url: Ollama server URL (default: http://localhost:11434).
+    """
+
+    model: str = "llama3.2:1b"
+    temperature: float = 0.1
+    max_tokens: int = 500
+    timeout: float = 120.0  # Local inference can be slow
+    api_key: Optional[str] = None  # Not needed for local Ollama
+    base_url: str = "http://localhost:11434"
+
+
+class OllamaClient(BaseLLMClient):
+    """Local Ollama API client.
+
+    Implements the BaseLLMClient interface for locally running Ollama server.
+    Uses httpx for HTTP requests to the local Ollama API.
+
+    Ollama provides an OpenAI-compatible API for running open-source models
+    like Llama locally without requiring API keys or internet access.
+
+    Example:
+        >>> config = OllamaConfig(model="llama3.2:1b")
+        >>> client = OllamaClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
+
+    def __init__(self, config: Optional[OllamaConfig] = None) -> None:
+        """Initialize Ollama client.
+
+        Args:
+            config: Ollama configuration. If None, uses defaults connecting
+                to localhost:11434 with llama3.2:1b model.
+        """
+        self.config = config or OllamaConfig()
+        self._total_calls = 0
+
+    @property
+    def provider_name(self) -> str:
+        """Return the provider name."""
+        return "ollama"
+
+    def completion(
+        self, messages: List[Dict[str, str]], **kwargs: Any
+    ) -> LLMResponse:
+        """Make a chat completion request to Ollama.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options (temperature, max_tokens).
+
+        Returns:
+            LLMResponse with the generated content and metadata.
+
+        Raises:
+            ValueError: If the API request fails or Ollama is not running.
+        """
+        # Build request payload (Ollama uses similar format to OpenAI)
+        payload: Dict[str, Any] = {
+            "model": self.config.model,
+            "messages": messages,
+            "stream": False,
+            "options": {
+                "temperature": kwargs.get(
+                    "temperature", self.config.temperature
+                ),
+                "num_predict": kwargs.get(
+                    "max_tokens", self.config.max_tokens
+                ),
+            },
+        }
+
+        url = f"{self.config.base_url}/api/chat"
+        headers = {"Content-Type": "application/json"}
+
+        logger.debug(f"Calling Ollama API with model: {self.config.model}")
+
+        try:
+            with httpx.Client(timeout=self.config.timeout) as client:
+                response = client.post(url, json=payload, headers=headers)
+                response.raise_for_status()
+
+                data = response.json()
+
+                # Extract response content
+                content = data.get("message", {}).get("content", "")
+
+                # Extract token counts (Ollama provides these)
+                input_tokens = data.get("prompt_eval_count", 0)
+                output_tokens = data.get("eval_count", 0)
+
+                self._total_calls += 1
+
+                logger.debug(
+                    f"Ollama response: {input_tokens} in, {output_tokens} out"
+                )
+
+                return LLMResponse(
+                    content=content,
+                    model=data.get("model", self.config.model),
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    cost=0.0,  # Local inference is free
+                    raw_response=data,
+                )
+
+        except httpx.ConnectError:
+            raise ValueError(
+                "Could not connect to Ollama. "
+                "Make sure Ollama is running (run 'ollama serve' or start "
+                "the Ollama app). "
+                f"Tried to connect to: {self.config.base_url}"
+            )
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 404:
+                raise ValueError(
+                    f"Model '{self.config.model}' not found. "
+                    f"Run 'ollama pull {self.config.model}' to download it."
+                )
+            logger.error(
+                f"Ollama API error: {e.response.status_code} - "
+                f"{e.response.text}"
+            )
+            raise ValueError(
+                f"Ollama API error: {e.response.status_code} - "
+                f"{e.response.text}"
+            )
+        except httpx.TimeoutException:
+            raise ValueError(
+                "Ollama API request timed out. Local inference can be slow - "
+                "try increasing the timeout in OllamaConfig."
+            )
+        except Exception as e:
+            logger.error(f"Ollama API unexpected error: {e}")
+            raise ValueError(f"Ollama API error: {str(e)}")
+
+    def complete_json(
+        self, messages: List[Dict[str, str]], **kwargs: Any
+    ) -> tuple[Optional[Dict[str, Any]], LLMResponse]:
+        """Make a completion request and parse response as JSON.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Override config options passed to completion().
+
+        Returns:
+            Tuple of (parsed JSON dict or None, raw LLMResponse).
+        """
+        response = self.completion(messages, **kwargs)
+        parsed = response.parse_json()
+        return parsed, response
+
+    @property
+    def call_count(self) -> int:
+        """Return the number of API calls made."""
+        return self._total_calls
+
+    def is_available(self) -> bool:
+        """Check if Ollama server is running and model is available.
+
+        Returns:
+            True if Ollama is running and the configured model exists.
+        """
+        try:
+            with httpx.Client(timeout=5.0) as client:
+                # Check if server is running
+                response = client.get(f"{self.config.base_url}/api/tags")
+                if response.status_code != 200:
+                    return False
+
+                # Check if model is available
+                data = response.json()
+                models = [m.get("name", "") for m in data.get("models", [])]
+                # Ollama model names can have :latest suffix
+                model_name = self.config.model
+                return any(
+                    m == model_name or m.startswith(f"{model_name}:")
+                    for m in models
+                )
+        except Exception:
+            return False
+
+    def list_models(self) -> List[str]:
+        """List installed models from Ollama.
+
+        Queries the local Ollama server to get installed models.
+        Unlike cloud providers, this returns only models the user
+        has explicitly pulled/installed.
+
+        Returns:
+            List of model identifiers (e.g., ['llama3.2:1b', ...]).
+
+        Raises:
+            ValueError: If Ollama server is not running.
+        """
+        try:
+            with httpx.Client(timeout=5.0) as client:
+                response = client.get(f"{self.config.base_url}/api/tags")
+                response.raise_for_status()
+                data = response.json()
+
+                models = [m.get("name", "") for m in data.get("models", [])]
+                return sorted(models)
+
+        except httpx.ConnectError:
+            raise ValueError(
+                "Ollama server not running. Start with: ollama serve"
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to list Ollama models: {e}")
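OllamaClient.completion() is a thin wrapper over Ollama's local /api/chat endpoint: it posts the payload shown above with streaming disabled and reads prompt_eval_count / eval_count for token accounting. The sketch below issues the same request directly with httpx, assuming a local Ollama server on the default port with llama3.2:1b already pulled.

```python
# Minimal sketch of the request that OllamaClient.completion() builds,
# sent directly with httpx. Assumes a local Ollama server on the default
# port and that 'ollama pull llama3.2:1b' has been run.
import httpx

payload = {
    "model": "llama3.2:1b",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,  # single JSON reply instead of a token stream
    "options": {"temperature": 0.1, "num_predict": 500},
}

with httpx.Client(timeout=120.0) as client:
    reply = client.post("http://localhost:11434/api/chat", json=payload)
    reply.raise_for_status()
    data = reply.json()

print(data["message"]["content"])
# Ollama reports token usage in these fields, which the client maps to
# LLMResponse.input_tokens / output_tokens.
print(data.get("prompt_eval_count", 0), data.get("eval_count", 0))
```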
causaliq_knowledge/llm/openai_client.py (new file)
@@ -0,0 +1,115 @@
+"""Direct OpenAI API client - clean and reliable."""
+
+import logging
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from causaliq_knowledge.llm.openai_compat_client import (
+    OpenAICompatClient,
+    OpenAICompatConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class OpenAIConfig(OpenAICompatConfig):
+    """Configuration for OpenAI API client.
+
+    Extends OpenAICompatConfig with OpenAI-specific defaults.
+
+    Attributes:
+        model: OpenAI model identifier (default: gpt-4o-mini).
+        temperature: Sampling temperature (default: 0.1).
+        max_tokens: Maximum response tokens (default: 500).
+        timeout: Request timeout in seconds (default: 30.0).
+        api_key: OpenAI API key (falls back to OPENAI_API_KEY env var).
+    """
+
+    model: str = "gpt-4o-mini"
+    temperature: float = 0.1
+    max_tokens: int = 500
+    timeout: float = 30.0
+    api_key: Optional[str] = None
+
+    def __post_init__(self) -> None:
+        """Set API key from environment if not provided."""
+        if self.api_key is None:
+            self.api_key = os.getenv("OPENAI_API_KEY")
+        if not self.api_key:
+            raise ValueError("OPENAI_API_KEY environment variable is required")
+
+
+class OpenAIClient(OpenAICompatClient):
+    """Direct OpenAI API client.
+
+    Implements the BaseLLMClient interface for OpenAI's API.
+    Uses httpx for HTTP requests.
+
+    Example:
+        >>> config = OpenAIConfig(model="gpt-4o-mini")
+        >>> client = OpenAIClient(config)
+        >>> msgs = [{"role": "user", "content": "Hello"}]
+        >>> response = client.completion(msgs)
+        >>> print(response.content)
+    """
+
+    BASE_URL = "https://api.openai.com/v1"
+    PROVIDER_NAME = "openai"
+    ENV_VAR = "OPENAI_API_KEY"
+
+    def __init__(self, config: Optional[OpenAIConfig] = None) -> None:
+        """Initialize OpenAI client.
+
+        Args:
+            config: OpenAI configuration. If None, uses defaults with
+                API key from OPENAI_API_KEY environment variable.
+        """
+        super().__init__(config)
+
+    def _default_config(self) -> OpenAIConfig:
+        """Return default OpenAI configuration."""
+        return OpenAIConfig()
+
+    def _get_pricing(self) -> Dict[str, Dict[str, float]]:
+        """Return OpenAI pricing per 1M tokens.
+
+        Returns:
+            Dict mapping model prefixes to input/output costs.
+        """
+        # Order matters - more specific prefixes must come first
+        return {
+            "gpt-4o-mini": {"input": 0.15, "output": 0.60},
+            "gpt-4o": {"input": 2.50, "output": 10.00},
+            "gpt-4-turbo": {"input": 10.00, "output": 30.00},
+            "gpt-4": {"input": 30.00, "output": 60.00},
+            "gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
+            "o1-mini": {"input": 3.00, "output": 12.00},
+            "o1": {"input": 15.00, "output": 60.00},
+        }
+
+    def _filter_models(self, models: List[str]) -> List[str]:
+        """Filter to OpenAI chat models only.
+
+        Args:
+            models: List of all model IDs from API.
+
+        Returns:
+            Filtered list of GPT and o1/o3 models.
+        """
+        filtered = []
+        for model_id in models:
+            # Include GPT and o1/o3 models
+            if any(
+                prefix in model_id
+                for prefix in ["gpt-4", "gpt-3.5", "o1", "o3"]
+            ):
+                # Exclude instruct variants and specific exclusions
+                if any(
+                    x in model_id.lower()
+                    for x in ["instruct", "vision", "audio", "realtime"]
+                ):
+                    continue
+                filtered.append(model_id)
+        return filtered
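OpenAIClient and MistralClient differ from their shared OpenAICompatClient base only in class constants (BASE_URL, PROVIDER_NAME, ENV_VAR), default configuration, pricing, and model filtering. As a sketch of that extension pattern, the hypothetical provider below mirrors the same hooks; OpenAICompatClient itself is not shown in this diff, so treat this as an illustration of the extension points rather than verified working code.

```python
# Hedged sketch of the subclassing pattern used by OpenAIClient and
# MistralClient in this diff. "Together" is a hypothetical example provider;
# the endpoint URL and pricing values are assumptions for illustration.
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

from causaliq_knowledge.llm.openai_compat_client import (
    OpenAICompatClient,
    OpenAICompatConfig,
)


@dataclass
class TogetherConfig(OpenAICompatConfig):  # hypothetical provider config
    model: str = "example-chat-model"
    api_key: Optional[str] = None

    def __post_init__(self) -> None:
        if self.api_key is None:
            self.api_key = os.getenv("TOGETHER_API_KEY")
        if not self.api_key:
            raise ValueError("TOGETHER_API_KEY environment variable is required")


class TogetherClient(OpenAICompatClient):  # hypothetical provider client
    BASE_URL = "https://api.together.xyz/v1"  # assumed endpoint
    PROVIDER_NAME = "together"
    ENV_VAR = "TOGETHER_API_KEY"

    def _default_config(self) -> TogetherConfig:
        return TogetherConfig()

    def _get_pricing(self) -> Dict[str, Dict[str, float]]:
        # Per-1M-token prices keyed by model prefix, as in the diff above
        return {"example-chat-model": {"input": 0.10, "output": 0.10}}

    def _filter_models(self, models: List[str]) -> List[str]:
        # Keep chat models only, dropping embedding models
        return [m for m in models if "embed" not in m.lower()]
```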