causaliq-knowledge 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +6 -3
- causaliq_knowledge/action.py +480 -0
- causaliq_knowledge/cache/__init__.py +18 -0
- causaliq_knowledge/cache/encoders/__init__.py +13 -0
- causaliq_knowledge/cache/encoders/base.py +90 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +430 -0
- causaliq_knowledge/cache/token_cache.py +666 -0
- causaliq_knowledge/cli/__init__.py +15 -0
- causaliq_knowledge/cli/cache.py +478 -0
- causaliq_knowledge/cli/generate.py +410 -0
- causaliq_knowledge/cli/main.py +172 -0
- causaliq_knowledge/cli/models.py +309 -0
- causaliq_knowledge/graph/__init__.py +78 -0
- causaliq_knowledge/graph/generator.py +457 -0
- causaliq_knowledge/graph/loader.py +222 -0
- causaliq_knowledge/graph/models.py +426 -0
- causaliq_knowledge/graph/params.py +175 -0
- causaliq_knowledge/graph/prompts.py +445 -0
- causaliq_knowledge/graph/response.py +392 -0
- causaliq_knowledge/graph/view_filter.py +154 -0
- causaliq_knowledge/llm/base_client.py +147 -1
- causaliq_knowledge/llm/cache.py +443 -0
- causaliq_knowledge/py.typed +0 -0
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
- causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +1 -1
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
- causaliq_knowledge/cli.py +0 -414
- causaliq_knowledge-0.2.0.dist-info/RECORD +0 -22
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
causaliq_knowledge/graph/response.py

@@ -0,0 +1,392 @@
+"""Response models and parsing for LLM graph generation.
+
+This module provides Pydantic models for representing LLM-generated
+causal graphs and functions for parsing LLM responses in both edge
+list and adjacency matrix formats.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+logger = logging.getLogger(__name__)
+
+
+class ProposedEdge(BaseModel):
+    """A proposed causal edge from LLM graph generation.
+
+    Represents a single directed edge in the proposed causal graph,
+    with confidence score and optional reasoning.
+
+    Attributes:
+        source: The name of the source variable (cause).
+        target: The name of the target variable (effect).
+        confidence: Confidence score from 0.0 to 1.0.
+        reasoning: Optional explanation for this specific edge.
+
+    Example:
+        >>> edge = ProposedEdge(
+        ...     source="smoking",
+        ...     target="lung_cancer",
+        ...     confidence=0.95,
+        ... )
+        >>> print(f"{edge.source} -> {edge.target}: {edge.confidence}")
+        smoking -> lung_cancer: 0.95
+    """
+
+    source: str = Field(..., description="Source variable name (cause)")
+    target: str = Field(..., description="Target variable name (effect)")
+    confidence: float = Field(
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Confidence score from 0.0 to 1.0",
+    )
+    reasoning: Optional[str] = Field(
+        default=None,
+        description="Optional explanation for this edge",
+    )
+
+    @field_validator("confidence", mode="before")
+    @classmethod
+    def clamp_confidence(cls, v: Any) -> float:
+        """Clamp confidence values to [0.0, 1.0] range."""
+        if v is None:
+            return 0.5
+        try:
+            val = float(v)
+            return max(0.0, min(1.0, val))
+        except (TypeError, ValueError):
+            return 0.5
+
+
+@dataclass
+class GenerationMetadata:
+    """Metadata about a graph generation request.
+
+    Attributes:
+        model: The LLM model used for generation.
+        provider: The LLM provider (e.g., "groq", "gemini").
+        timestamp: When the graph was generated.
+        latency_ms: Request latency in milliseconds.
+        input_tokens: Number of input tokens used.
+        output_tokens: Number of output tokens generated.
+        cost_usd: Estimated cost in USD (if available).
+        from_cache: Whether the response was from cache.
+    """
+
+    model: str
+    provider: str = ""
+    timestamp: datetime = field(
+        default_factory=lambda: datetime.now(timezone.utc)
+    )
+    latency_ms: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    cost_usd: float = 0.0
+    from_cache: bool = False
+
+
+@dataclass
+class GeneratedGraph:
+    """A complete causal graph generated by an LLM.
+
+    Represents the full output from an LLM graph generation query,
+    including all proposed edges, metadata, and the LLM's reasoning.
+
+    Attributes:
+        edges: List of proposed causal edges.
+        variables: List of variable names in the graph.
+        reasoning: Overall reasoning provided by the LLM.
+        metadata: Generation metadata (model, timing, etc.).
+        raw_response: The original LLM response for debugging.
+
+    Example:
+        >>> edge1 = ProposedEdge(
+        ...     source="age", target="income", confidence=0.7
+        ... )
+        >>> edge2 = ProposedEdge(
+        ...     source="education", target="income", confidence=0.9
+        ... )
+        >>> graph = GeneratedGraph(
+        ...     edges=[edge1, edge2],
+        ...     variables=["age", "education", "income"],
+        ...     reasoning="Age and education both influence income.",
+        ...     metadata=GenerationMetadata(model="llama-3.1-8b-instant"),
+        ... )
+        >>> print(f"Generated {len(graph.edges)} edges")
+        Generated 2 edges
+    """
+
+    edges: List[ProposedEdge]
+    variables: List[str]
+    reasoning: str = ""
+    metadata: Optional[GenerationMetadata] = None
+    raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+    def get_adjacency_matrix(self) -> List[List[float]]:
+        """Convert edges to an adjacency matrix.
+
+        Creates a square matrix where entry (i,j) represents the
+        confidence that variable i causes variable j.
+
+        Returns:
+            Square matrix of confidence scores.
+        """
+        n = len(self.variables)
+        var_to_idx = {var: i for i, var in enumerate(self.variables)}
+        matrix = [[0.0] * n for _ in range(n)]
+
+        for edge in self.edges:
+            src_idx = var_to_idx.get(edge.source)
+            tgt_idx = var_to_idx.get(edge.target)
+            if src_idx is not None and tgt_idx is not None:
+                matrix[src_idx][tgt_idx] = edge.confidence
+
+        return matrix
+
+    def get_edge_list(self) -> List[tuple[str, str, float]]:
+        """Get edges as a list of tuples.
+
+        Returns:
+            List of (source, target, confidence) tuples.
+        """
+        return [
+            (edge.source, edge.target, edge.confidence) for edge in self.edges
+        ]
+
+    def filter_by_confidence(self, threshold: float = 0.5) -> "GeneratedGraph":
+        """Return a new graph with only edges above the threshold.
+
+        Args:
+            threshold: Minimum confidence score to include.
+
+        Returns:
+            New GeneratedGraph with filtered edges.
+        """
+        filtered_edges = [e for e in self.edges if e.confidence >= threshold]
+        return GeneratedGraph(
+            edges=filtered_edges,
+            variables=self.variables,
+            reasoning=self.reasoning,
+            metadata=self.metadata,
+            raw_response=self.raw_response,
+        )
+
+
+def parse_edge_list_response(
+    response: Dict[str, Any],
+    variables: List[str],
+) -> GeneratedGraph:
+    """Parse an edge list format response from the LLM.
+
+    Expected JSON format:
+    {
+        "edges": [
+            {"source": "var1", "target": "var2", "confidence": 0.8},
+            ...
+        ],
+        "reasoning": "explanation"
+    }
+
+    Args:
+        response: Parsed JSON response from the LLM.
+        variables: List of valid variable names for validation.
+
+    Returns:
+        GeneratedGraph with parsed edges.
+
+    Raises:
+        ValueError: If the response format is invalid.
+    """
+    if not isinstance(response, dict):
+        raise ValueError(f"Expected dict response, got {type(response)}")
+
+    edges_data = response.get("edges", [])
+    if not isinstance(edges_data, list):
+        raise ValueError(
+            f"Expected 'edges' to be a list, got {type(edges_data)}"
+        )
+
+    reasoning = response.get("reasoning", "")
+    if not isinstance(reasoning, str):
+        reasoning = str(reasoning)
+
+    valid_vars = set(variables)
+    edges = []
+
+    for i, edge_data in enumerate(edges_data):
+        if not isinstance(edge_data, dict):
+            logger.warning(f"Skipping invalid edge at index {i}: not a dict")
+            continue
+
+        source = edge_data.get("source", "")
+        target = edge_data.get("target", "")
+
+        # Validate variable names
+        if source not in valid_vars:
+            logger.warning(f"Unknown source variable: {source}")
+            continue
+        if target not in valid_vars:
+            logger.warning(f"Unknown target variable: {target}")
+            continue
+        if source == target:
+            logger.warning(f"Skipping self-loop: {source} -> {target}")
+            continue
+
+        confidence = edge_data.get("confidence", 0.5)
+        edge_reasoning = edge_data.get("reasoning")
+
+        edges.append(
+            ProposedEdge(
+                source=source,
+                target=target,
+                confidence=confidence,
+                reasoning=edge_reasoning,
+            )
+        )
+
+    return GeneratedGraph(
+        edges=edges,
+        variables=variables,
+        reasoning=reasoning,
+        raw_response=response,
+    )
+
+
+def parse_adjacency_matrix_response(
+    response: Dict[str, Any],
+    expected_variables: List[str],
+) -> GeneratedGraph:
+    """Parse an adjacency matrix format response from the LLM.
+
+    Expected JSON format:
+    {
+        "variables": ["var1", "var2", "var3"],
+        "adjacency_matrix": [
+            [0.0, 0.8, 0.0],
+            [0.0, 0.0, 0.6],
+            [0.0, 0.0, 0.0]
+        ],
+        "reasoning": "explanation"
+    }
+
+    Args:
+        response: Parsed JSON response from the LLM.
+        expected_variables: List of expected variable names for validation.
+
+    Returns:
+        GeneratedGraph with edges extracted from the matrix.
+
+    Raises:
+        ValueError: If the response format is invalid.
+    """
+    if not isinstance(response, dict):
+        raise ValueError(f"Expected dict response, got {type(response)}")
+
+    variables = response.get("variables", [])
+    if not isinstance(variables, list):
+        raise ValueError(
+            f"Expected 'variables' to be a list, got {type(variables)}"
+        )
+
+    matrix = response.get("adjacency_matrix", [])
+    if not isinstance(matrix, list):
+        raise ValueError(
+            f"Expected 'adjacency_matrix' to be a list, got {type(matrix)}"
+        )
+
+    reasoning = response.get("reasoning", "")
+    if not isinstance(reasoning, str):
+        reasoning = str(reasoning)
+
+    # Validate matrix dimensions
+    n = len(variables)
+    if len(matrix) != n:
+        raise ValueError(
+            f"Matrix has {len(matrix)} rows but {n} variables declared"
+        )
+
+    for i, row in enumerate(matrix):
+        if not isinstance(row, list) or len(row) != n:
+            raise ValueError(f"Matrix row {i} has incorrect dimensions")
+
+    # Validate variables against expected
+    valid_vars = set(expected_variables)
+    for var in variables:
+        if var not in valid_vars:
+            logger.warning(f"Unexpected variable in response: {var}")
+
+    # Extract edges from matrix
+    edges = []
+    for i, source in enumerate(variables):
+        for j, target in enumerate(variables):
+            confidence = matrix[i][j]
+            try:
+                confidence = float(confidence)
+            except (TypeError, ValueError):
+                continue
+
+            if confidence > 0.0 and i != j:  # Skip zeros and self-loops
+                edges.append(
+                    ProposedEdge(
+                        source=source,
+                        target=target,
+                        confidence=max(0.0, min(1.0, confidence)),
+                    )
+                )
+
+    return GeneratedGraph(
+        edges=edges,
+        variables=variables,
+        reasoning=reasoning,
+        raw_response=response,
+    )
+
+
+def parse_graph_response(
+    response_text: str,
+    variables: List[str],
+    output_format: str = "edge_list",
+) -> GeneratedGraph:
+    """Parse an LLM response into a GeneratedGraph.
+
+    Handles JSON extraction from markdown code blocks and parses
+    according to the specified output format.
+
+    Args:
+        response_text: Raw text response from the LLM.
+        variables: List of valid variable names.
+        output_format: Expected format ("edge_list" or "adjacency_matrix").
+
+    Returns:
+        GeneratedGraph with parsed edges and metadata.
+
+    Raises:
+        ValueError: If JSON parsing fails or format is invalid.
+    """
+    # Clean up potential markdown code blocks
+    text = response_text.strip()
+    if text.startswith("```json"):
+        text = text[7:]
+    elif text.startswith("```"):
+        text = text[3:]
+    if text.endswith("```"):
+        text = text[:-3]
+    text = text.strip()
+
+    try:
+        response = json.loads(text)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse JSON response: {e}")
+
+    if output_format == "adjacency_matrix":
+        return parse_adjacency_matrix_response(response, variables)
+    else:
+        return parse_edge_list_response(response, variables)
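Taken together, the helpers in response.py turn raw LLM text into a GeneratedGraph that can be queried as an edge list or an adjacency matrix. A minimal usage sketch of that flow, using a hand-written response string rather than real LLM output:

```python
# Illustrative only: the response string below is hand-written, not a real
# LLM reply; markdown code fences around the JSON would also be stripped by
# parse_graph_response before json.loads is called.
from causaliq_knowledge.graph.response import parse_graph_response

variables = ["smoking", "genetics", "lung_cancer"]
response_text = (
    '{"edges": ['
    '{"source": "smoking", "target": "lung_cancer", "confidence": 0.9}, '
    '{"source": "genetics", "target": "lung_cancer", "confidence": 0.4}], '
    '"reasoning": "Both factors raise lung cancer risk."}'
)

graph = parse_graph_response(response_text, variables)  # edge_list default
print(graph.get_edge_list())
# [('smoking', 'lung_cancer', 0.9), ('genetics', 'lung_cancer', 0.4)]
print(graph.filter_by_confidence(0.5).get_edge_list())  # drops the 0.4 edge
print(graph.get_adjacency_matrix())  # 3x3 matrix in `variables` order
```

Note that unknown variables, self-loops, and malformed edge entries are logged and skipped rather than raising, so a partially valid response still yields a usable graph.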
causaliq_knowledge/graph/view_filter.py

@@ -0,0 +1,154 @@
+"""Prompt detail filtering for model specifications.
+
+This module provides functionality to extract filtered views
+(minimal, standard, rich) from model specifications for LLM prompts.
+"""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any
+
+from causaliq_knowledge.graph.models import ModelSpec, VariableSpec
+
+
+class PromptDetail(str, Enum):
+    """Level of detail for variable information in prompts."""
+
+    MINIMAL = "minimal"
+    STANDARD = "standard"
+    RICH = "rich"
+
+
+class ViewFilter:
+    """Filter model specifications to extract specific detail levels.
+
+    This class extracts variable information according to the view
+    definitions in the model specification (minimal, standard, rich).
+
+    By default, llm_name is substituted for name in the output to prevent
+    LLM memorisation of benchmark networks. Use use_llm_names=False to
+    output benchmark names directly (for memorisation testing).
+
+    Example:
+        >>> spec = ModelLoader.load("model.json")
+        >>> view_filter = ViewFilter(spec)
+        >>> minimal_vars = view_filter.filter_variables(PromptDetail.MINIMAL)
+        >>> # Returns list of dicts with llm_name as 'name' field
+    """
+
+    def __init__(self, spec: ModelSpec, *, use_llm_names: bool = True) -> None:
+        """Initialise the view filter.
+
+        Args:
+            spec: The model specification to filter.
+            use_llm_names: If True (default), output llm_name as 'name'.
+                If False, output benchmark name as 'name'.
+        """
+        self._spec = spec
+        self._use_llm_names = use_llm_names
+
+    @property
+    def spec(self) -> ModelSpec:
+        """Return the model specification."""
+        return self._spec
+
+    def get_include_fields(self, level: PromptDetail) -> list[str]:
+        """Get the fields to include for a given detail level.
+
+        Args:
+            level: The prompt detail level (minimal, standard, rich).
+
+        Returns:
+            List of field names to include.
+        """
+        if level == PromptDetail.MINIMAL:
+            return self._spec.prompt_details.minimal.include_fields
+        elif level == PromptDetail.STANDARD:
+            return self._spec.prompt_details.standard.include_fields
+        elif level == PromptDetail.RICH:
+            return self._spec.prompt_details.rich.include_fields
+        else:  # pragma: no cover
+            raise ValueError(f"Unknown prompt detail level: {level}")
+
+    def filter_variable(
+        self,
+        variable: VariableSpec,
+        level: PromptDetail,
+    ) -> dict[str, Any]:
+        """Filter a single variable to include only specified fields.
+
+        Args:
+            variable: The variable specification to filter.
+            level: The prompt detail level determining which fields to include.
+
+        Returns:
+            Dictionary with only the fields specified by the detail level.
+            Enum values are converted to their string representations.
+            If use_llm_names is True, the 'name' field contains llm_name.
+        """
+        include_fields = self.get_include_fields(level)
+        # Use mode="json" to convert enums to their string values
+        var_dict = variable.model_dump(mode="json")
+
+        # If using llm_names, substitute llm_name for name in output
+        if self._use_llm_names and "name" in include_fields:
+            var_dict["name"] = var_dict.get("llm_name", var_dict["name"])
+
+        # Never include llm_name in output (it's internal)
+        return {
+            key: value
+            for key, value in var_dict.items()
+            if key in include_fields
+            and key != "llm_name"
+            and value is not None
+        }
+
+    def filter_variables(self, level: PromptDetail) -> list[dict[str, Any]]:
+        """Filter all variables to the specified detail level.
+
+        Args:
+            level: The prompt detail level (minimal, standard, rich).
+
+        Returns:
+            List of filtered variable dictionaries.
+        """
+        return [
+            self.filter_variable(var, level) for var in self._spec.variables
+        ]
+
+    def get_variable_names(self) -> list[str]:
+        """Get all variable names for LLM output.
+
+        Returns benchmark names if use_llm_names is False,
+        otherwise returns llm_names.
+
+        Returns:
+            List of variable names.
+        """
+        if self._use_llm_names:
+            return self._spec.get_llm_names()
+        return self._spec.get_variable_names()
+
+    def get_domain(self) -> str:
+        """Get the domain from the specification.
+
+        Returns:
+            The domain string.
+        """
+        return self._spec.domain
+
+    def get_context_summary(self, level: PromptDetail) -> dict[str, Any]:
+        """Get a complete context summary for LLM prompts.
+
+        Args:
+            level: The prompt detail level for variable filtering.
+
+        Returns:
+            Dictionary with domain and filtered variables.
+        """
+        return {
+            "domain": self._spec.domain,
+            "dataset_id": self._spec.dataset_id,
+            "variables": self.filter_variables(level),
+        }
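A hedged sketch of how ViewFilter might be driven, based only on the docstrings above: the "model.json" path is illustrative, and the ModelLoader import location is an assumption (graph/loader.py appears in the file list but its contents are not shown in this diff).

```python
# Sketch under assumptions: ModelLoader's import path and the
# "model.json" file are illustrative, not confirmed by this diff.
from causaliq_knowledge.graph.loader import ModelLoader  # assumed location
from causaliq_knowledge.graph.view_filter import PromptDetail, ViewFilter

spec = ModelLoader.load("model.json")

# Default behaviour: variables are emitted under their llm_name aliases,
# reducing the chance the LLM recognises a benchmark network by name.
anonymised = ViewFilter(spec)
context = anonymised.get_context_summary(PromptDetail.STANDARD)
# context holds "domain", "dataset_id", and the filtered "variables" list

# Memorisation testing: keep the original benchmark variable names.
named = ViewFilter(spec, use_llm_names=False)
benchmark_names = named.get_variable_names()
```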
causaliq_knowledge/llm/base_client.py

@@ -5,11 +5,18 @@ must implement. This provides a consistent API regardless of the
 underlying LLM provider.
 """
 
+from __future__ import annotations
+
+import hashlib
 import json
 import logging
+import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+if TYPE_CHECKING:  # pragma: no cover
+    from causaliq_knowledge.cache import TokenCache
 
 logger = logging.getLogger(__name__)
 

@@ -218,3 +225,142 @@ class BaseLLMClient(ABC):
         Model identifier string.
         """
         return getattr(self, "config", LLMConfig(model="unknown")).model
+
+    def _build_cache_key(
+        self,
+        messages: List[Dict[str, str]],
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+    ) -> str:
+        """Build a deterministic cache key for the request.
+
+        Creates a SHA-256 hash from the model, messages, temperature, and
+        max_tokens. The hash is truncated to 16 hex characters (64 bits).
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            temperature: Sampling temperature (defaults to config value).
+            max_tokens: Maximum tokens (defaults to config value).
+
+        Returns:
+            16-character hex string cache key.
+        """
+        config = getattr(self, "config", LLMConfig(model="unknown"))
+        key_data = {
+            "model": config.model,
+            "messages": messages,
+            "temperature": (
+                temperature if temperature is not None else config.temperature
+            ),
+            "max_tokens": (
+                max_tokens if max_tokens is not None else config.max_tokens
+            ),
+        }
+        key_json = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
+        return hashlib.sha256(key_json.encode()).hexdigest()[:16]
+
+    def set_cache(
+        self,
+        cache: Optional["TokenCache"],
+        use_cache: bool = True,
+    ) -> None:
+        """Configure caching for this client.
+
+        Args:
+            cache: TokenCache instance for caching, or None to disable.
+            use_cache: Whether to use the cache (default True).
+        """
+        self._cache = cache
+        self._use_cache = use_cache
+
+    @property
+    def cache(self) -> Optional["TokenCache"]:
+        """Return the configured cache, if any."""
+        return getattr(self, "_cache", None)
+
+    @property
+    def use_cache(self) -> bool:
+        """Return whether caching is enabled."""
+        return getattr(self, "_use_cache", True)
+
+    def cached_completion(
+        self,
+        messages: List[Dict[str, str]],
+        **kwargs: Any,
+    ) -> LLMResponse:
+        """Make a completion request with caching.
+
+        If caching is enabled and a cached response exists, returns
+        the cached response without making an API call. Otherwise,
+        makes the API call and caches the result.
+
+        Args:
+            messages: List of message dicts with "role" and "content" keys.
+            **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                Also accepts request_id (str) for identifying requests in
+                exports. Note: request_id is NOT part of the cache key.
+
+        Returns:
+            LLMResponse with the generated content and metadata.
+        """
+        from causaliq_knowledge.llm.cache import LLMCacheEntry, LLMEntryEncoder
+
+        cache = self.cache
+        use_cache = self.use_cache
+
+        # Extract request_id (not part of cache key)
+        request_id = kwargs.pop("request_id", "")
+
+        # Build cache key
+        temperature = kwargs.get("temperature")
+        max_tokens = kwargs.get("max_tokens")
+        cache_key = self._build_cache_key(messages, temperature, max_tokens)
+
+        # Check cache
+        if use_cache and cache is not None:
+            # Ensure encoder is registered
+            if not cache.has_encoder("llm"):
+                cache.register_encoder("llm", LLMEntryEncoder())
+
+            if cache.exists(cache_key, "llm"):
+                cached_data = cache.get_data(cache_key, "llm")
+                if cached_data is not None:
+                    entry = LLMCacheEntry.from_dict(cached_data)
+                    return LLMResponse(
+                        content=entry.response.content,
+                        model=entry.model,
+                        input_tokens=entry.metadata.tokens.input,
+                        output_tokens=entry.metadata.tokens.output,
+                        cost=entry.metadata.cost_usd or 0.0,
+                    )
+
+        # Make API call with timing
+        start_time = time.perf_counter()
+        response = self.completion(messages, **kwargs)
+        latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+        # Store in cache
+        if use_cache and cache is not None:
+            config = getattr(self, "config", LLMConfig(model="unknown"))
+            entry = LLMCacheEntry.create(
+                model=config.model,
+                messages=messages,
+                content=response.content,
+                temperature=(
+                    temperature
+                    if temperature is not None
+                    else config.temperature
+                ),
+                max_tokens=(
+                    max_tokens if max_tokens is not None else config.max_tokens
+                ),
+                provider=self.provider_name,
+                latency_ms=latency_ms,
+                input_tokens=response.input_tokens,
+                output_tokens=response.output_tokens,
+                cost_usd=response.cost,
+                request_id=request_id,
+            )
+            cache.put_data(cache_key, "llm", entry.to_dict())
+
+        return response
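The caching additions to BaseLLMClient combine into a simple call pattern. The sketch below is hedged: make_client() is a hypothetical stand-in for constructing one of the package's concrete clients, and the TokenCache constructor argument is an assumption, since neither appears in this diff.

```python
# Hypothetical usage sketch: make_client() and the TokenCache constructor
# arguments are placeholders, not APIs confirmed by this diff.
from causaliq_knowledge.cache import TokenCache

client = make_client()              # hypothetical concrete BaseLLMClient
cache = TokenCache("llm_cache.db")  # assumed constructor signature
client.set_cache(cache, use_cache=True)

messages = [{"role": "user", "content": "Does smoking cause lung cancer?"}]

# The first call hits the provider, measures latency, and stores an
# LLMCacheEntry under a 16-hex-character SHA-256 key derived from the
# model, messages, temperature, and max_tokens. The second call with the
# same inputs is served from the cache; request_id identifies the request
# in exports but is deliberately excluded from the cache key.
first = client.cached_completion(messages, temperature=0.0, request_id="q1")
second = client.cached_completion(messages, temperature=0.0, request_id="q2")
assert first.content == second.content
```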