causaliq-knowledge 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +5 -2
- causaliq_knowledge/action.py +480 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +15 -3
- causaliq_knowledge/cache/token_cache.py +36 -2
- causaliq_knowledge/cli/__init__.py +15 -0
- causaliq_knowledge/cli/cache.py +478 -0
- causaliq_knowledge/cli/generate.py +410 -0
- causaliq_knowledge/cli/main.py +172 -0
- causaliq_knowledge/cli/models.py +309 -0
- causaliq_knowledge/graph/__init__.py +78 -0
- causaliq_knowledge/graph/generator.py +457 -0
- causaliq_knowledge/graph/loader.py +222 -0
- causaliq_knowledge/graph/models.py +426 -0
- causaliq_knowledge/graph/params.py +175 -0
- causaliq_knowledge/graph/prompts.py +445 -0
- causaliq_knowledge/graph/response.py +392 -0
- causaliq_knowledge/graph/view_filter.py +154 -0
- causaliq_knowledge/llm/base_client.py +6 -0
- causaliq_knowledge/llm/cache.py +124 -61
- causaliq_knowledge/py.typed +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
- causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
- causaliq_knowledge/cli.py +0 -757
- causaliq_knowledge-0.3.0.dist-info/RECORD +0 -28
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
causaliq_knowledge/graph/response.py
@@ -0,0 +1,392 @@
+"""Response models and parsing for LLM graph generation.
+
+This module provides Pydantic models for representing LLM-generated
+causal graphs and functions for parsing LLM responses in both edge
+list and adjacency matrix formats.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+logger = logging.getLogger(__name__)
+
+
+class ProposedEdge(BaseModel):
+    """A proposed causal edge from LLM graph generation.
+
+    Represents a single directed edge in the proposed causal graph,
+    with confidence score and optional reasoning.
+
+    Attributes:
+        source: The name of the source variable (cause).
+        target: The name of the target variable (effect).
+        confidence: Confidence score from 0.0 to 1.0.
+        reasoning: Optional explanation for this specific edge.
+
+    Example:
+        >>> edge = ProposedEdge(
+        ...     source="smoking",
+        ...     target="lung_cancer",
+        ...     confidence=0.95,
+        ... )
+        >>> print(f"{edge.source} -> {edge.target}: {edge.confidence}")
+        smoking -> lung_cancer: 0.95
+    """
+
+    source: str = Field(..., description="Source variable name (cause)")
+    target: str = Field(..., description="Target variable name (effect)")
+    confidence: float = Field(
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Confidence score from 0.0 to 1.0",
+    )
+    reasoning: Optional[str] = Field(
+        default=None,
+        description="Optional explanation for this edge",
+    )
+
+    @field_validator("confidence", mode="before")
+    @classmethod
+    def clamp_confidence(cls, v: Any) -> float:
+        """Clamp confidence values to [0.0, 1.0] range."""
+        if v is None:
+            return 0.5
+        try:
+            val = float(v)
+            return max(0.0, min(1.0, val))
+        except (TypeError, ValueError):
+            return 0.5
+
+
+@dataclass
+class GenerationMetadata:
+    """Metadata about a graph generation request.
+
+    Attributes:
+        model: The LLM model used for generation.
+        provider: The LLM provider (e.g., "groq", "gemini").
+        timestamp: When the graph was generated.
+        latency_ms: Request latency in milliseconds.
+        input_tokens: Number of input tokens used.
+        output_tokens: Number of output tokens generated.
+        cost_usd: Estimated cost in USD (if available).
+        from_cache: Whether the response was from cache.
+    """
+
+    model: str
+    provider: str = ""
+    timestamp: datetime = field(
+        default_factory=lambda: datetime.now(timezone.utc)
+    )
+    latency_ms: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    cost_usd: float = 0.0
+    from_cache: bool = False
+
+
+@dataclass
+class GeneratedGraph:
+    """A complete causal graph generated by an LLM.
+
+    Represents the full output from an LLM graph generation query,
+    including all proposed edges, metadata, and the LLM's reasoning.
+
+    Attributes:
+        edges: List of proposed causal edges.
+        variables: List of variable names in the graph.
+        reasoning: Overall reasoning provided by the LLM.
+        metadata: Generation metadata (model, timing, etc.).
+        raw_response: The original LLM response for debugging.
+
+    Example:
+        >>> edge1 = ProposedEdge(
+        ...     source="age", target="income", confidence=0.7
+        ... )
+        >>> edge2 = ProposedEdge(
+        ...     source="education", target="income", confidence=0.9
+        ... )
+        >>> graph = GeneratedGraph(
+        ...     edges=[edge1, edge2],
+        ...     variables=["age", "education", "income"],
+        ...     reasoning="Age and education both influence income.",
+        ...     metadata=GenerationMetadata(model="llama-3.1-8b-instant"),
+        ... )
+        >>> print(f"Generated {len(graph.edges)} edges")
+        Generated 2 edges
+    """
+
+    edges: List[ProposedEdge]
+    variables: List[str]
+    reasoning: str = ""
+    metadata: Optional[GenerationMetadata] = None
+    raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+    def get_adjacency_matrix(self) -> List[List[float]]:
+        """Convert edges to an adjacency matrix.
+
+        Creates a square matrix where entry (i,j) represents the
+        confidence that variable i causes variable j.
+
+        Returns:
+            Square matrix of confidence scores.
+        """
+        n = len(self.variables)
+        var_to_idx = {var: i for i, var in enumerate(self.variables)}
+        matrix = [[0.0] * n for _ in range(n)]
+
+        for edge in self.edges:
+            src_idx = var_to_idx.get(edge.source)
+            tgt_idx = var_to_idx.get(edge.target)
+            if src_idx is not None and tgt_idx is not None:
+                matrix[src_idx][tgt_idx] = edge.confidence
+
+        return matrix
+
+    def get_edge_list(self) -> List[tuple[str, str, float]]:
+        """Get edges as a list of tuples.
+
+        Returns:
+            List of (source, target, confidence) tuples.
+        """
+        return [
+            (edge.source, edge.target, edge.confidence) for edge in self.edges
+        ]
+
+    def filter_by_confidence(self, threshold: float = 0.5) -> "GeneratedGraph":
+        """Return a new graph with only edges above the threshold.
+
+        Args:
+            threshold: Minimum confidence score to include.
+
+        Returns:
+            New GeneratedGraph with filtered edges.
+        """
+        filtered_edges = [e for e in self.edges if e.confidence >= threshold]
+        return GeneratedGraph(
+            edges=filtered_edges,
+            variables=self.variables,
+            reasoning=self.reasoning,
+            metadata=self.metadata,
+            raw_response=self.raw_response,
+        )
+
+
+def parse_edge_list_response(
+    response: Dict[str, Any],
+    variables: List[str],
+) -> GeneratedGraph:
+    """Parse an edge list format response from the LLM.
+
+    Expected JSON format:
+        {
+            "edges": [
+                {"source": "var1", "target": "var2", "confidence": 0.8},
+                ...
+            ],
+            "reasoning": "explanation"
+        }
+
+    Args:
+        response: Parsed JSON response from the LLM.
+        variables: List of valid variable names for validation.
+
+    Returns:
+        GeneratedGraph with parsed edges.
+
+    Raises:
+        ValueError: If the response format is invalid.
+    """
+    if not isinstance(response, dict):
+        raise ValueError(f"Expected dict response, got {type(response)}")
+
+    edges_data = response.get("edges", [])
+    if not isinstance(edges_data, list):
+        raise ValueError(
+            f"Expected 'edges' to be a list, got {type(edges_data)}"
+        )
+
+    reasoning = response.get("reasoning", "")
+    if not isinstance(reasoning, str):
+        reasoning = str(reasoning)
+
+    valid_vars = set(variables)
+    edges = []
+
+    for i, edge_data in enumerate(edges_data):
+        if not isinstance(edge_data, dict):
+            logger.warning(f"Skipping invalid edge at index {i}: not a dict")
+            continue
+
+        source = edge_data.get("source", "")
+        target = edge_data.get("target", "")
+
+        # Validate variable names
+        if source not in valid_vars:
+            logger.warning(f"Unknown source variable: {source}")
+            continue
+        if target not in valid_vars:
+            logger.warning(f"Unknown target variable: {target}")
+            continue
+        if source == target:
+            logger.warning(f"Skipping self-loop: {source} -> {target}")
+            continue
+
+        confidence = edge_data.get("confidence", 0.5)
+        edge_reasoning = edge_data.get("reasoning")
+
+        edges.append(
+            ProposedEdge(
+                source=source,
+                target=target,
+                confidence=confidence,
+                reasoning=edge_reasoning,
+            )
+        )
+
+    return GeneratedGraph(
+        edges=edges,
+        variables=variables,
+        reasoning=reasoning,
+        raw_response=response,
+    )
+
+
+def parse_adjacency_matrix_response(
+    response: Dict[str, Any],
+    expected_variables: List[str],
+) -> GeneratedGraph:
+    """Parse an adjacency matrix format response from the LLM.
+
+    Expected JSON format:
+        {
+            "variables": ["var1", "var2", "var3"],
+            "adjacency_matrix": [
+                [0.0, 0.8, 0.0],
+                [0.0, 0.0, 0.6],
+                [0.0, 0.0, 0.0]
+            ],
+            "reasoning": "explanation"
+        }
+
+    Args:
+        response: Parsed JSON response from the LLM.
+        expected_variables: List of expected variable names for validation.
+
+    Returns:
+        GeneratedGraph with edges extracted from the matrix.
+
+    Raises:
+        ValueError: If the response format is invalid.
+    """
+    if not isinstance(response, dict):
+        raise ValueError(f"Expected dict response, got {type(response)}")
+
+    variables = response.get("variables", [])
+    if not isinstance(variables, list):
+        raise ValueError(
+            f"Expected 'variables' to be a list, got {type(variables)}"
+        )
+
+    matrix = response.get("adjacency_matrix", [])
+    if not isinstance(matrix, list):
+        raise ValueError(
+            f"Expected 'adjacency_matrix' to be a list, got {type(matrix)}"
+        )
+
+    reasoning = response.get("reasoning", "")
+    if not isinstance(reasoning, str):
+        reasoning = str(reasoning)
+
+    # Validate matrix dimensions
+    n = len(variables)
+    if len(matrix) != n:
+        raise ValueError(
+            f"Matrix has {len(matrix)} rows but {n} variables declared"
+        )
+
+    for i, row in enumerate(matrix):
+        if not isinstance(row, list) or len(row) != n:
+            raise ValueError(f"Matrix row {i} has incorrect dimensions")
+
+    # Validate variables against expected
+    valid_vars = set(expected_variables)
+    for var in variables:
+        if var not in valid_vars:
+            logger.warning(f"Unexpected variable in response: {var}")
+
+    # Extract edges from matrix
+    edges = []
+    for i, source in enumerate(variables):
+        for j, target in enumerate(variables):
+            confidence = matrix[i][j]
+            try:
+                confidence = float(confidence)
+            except (TypeError, ValueError):
+                continue
+
+            if confidence > 0.0 and i != j:  # Skip zeros and self-loops
+                edges.append(
+                    ProposedEdge(
+                        source=source,
+                        target=target,
+                        confidence=max(0.0, min(1.0, confidence)),
+                    )
+                )
+
+    return GeneratedGraph(
+        edges=edges,
+        variables=variables,
+        reasoning=reasoning,
+        raw_response=response,
+    )
+
+
+def parse_graph_response(
+    response_text: str,
+    variables: List[str],
+    output_format: str = "edge_list",
+) -> GeneratedGraph:
+    """Parse an LLM response into a GeneratedGraph.
+
+    Handles JSON extraction from markdown code blocks and parses
+    according to the specified output format.
+
+    Args:
+        response_text: Raw text response from the LLM.
+        variables: List of valid variable names.
+        output_format: Expected format ("edge_list" or "adjacency_matrix").
+
+    Returns:
+        GeneratedGraph with parsed edges and metadata.
+
+    Raises:
+        ValueError: If JSON parsing fails or format is invalid.
+    """
+    # Clean up potential markdown code blocks
+    text = response_text.strip()
+    if text.startswith("```json"):
+        text = text[7:]
+    elif text.startswith("```"):
+        text = text[3:]
+    if text.endswith("```"):
+        text = text[:-3]
+    text = text.strip()
+
+    try:
+        response = json.loads(text)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse JSON response: {e}")
+
+    if output_format == "adjacency_matrix":
+        return parse_adjacency_matrix_response(response, variables)
+    else:
+        return parse_edge_list_response(response, variables)
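
A minimal usage sketch for the new response module (not part of the diff; the variable names and raw reply below are illustrative):

    from causaliq_knowledge.graph.response import parse_graph_response

    # Illustrative raw LLM reply; replies wrapped in ```json fences are also handled.
    raw = (
        '{"edges": [{"source": "rain", "target": "wet_grass", '
        '"confidence": 0.9}], "reasoning": "Rain wets the grass."}'
    )

    graph = parse_graph_response(raw, variables=["rain", "sprinkler", "wet_grass"])
    strong = graph.filter_by_confidence(0.5)   # keep edges with confidence >= 0.5
    print(strong.get_edge_list())              # [('rain', 'wet_grass', 0.9)]
    print(graph.get_adjacency_matrix())        # 3x3 matrix with 0.9 at row 0, column 2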
causaliq_knowledge/graph/view_filter.py
@@ -0,0 +1,154 @@
+"""Prompt detail filtering for model specifications.
+
+This module provides functionality to extract filtered views
+(minimal, standard, rich) from model specifications for LLM prompts.
+"""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any
+
+from causaliq_knowledge.graph.models import ModelSpec, VariableSpec
+
+
+class PromptDetail(str, Enum):
+    """Level of detail for variable information in prompts."""
+
+    MINIMAL = "minimal"
+    STANDARD = "standard"
+    RICH = "rich"
+
+
+class ViewFilter:
+    """Filter model specifications to extract specific detail levels.
+
+    This class extracts variable information according to the view
+    definitions in the model specification (minimal, standard, rich).
+
+    By default, llm_name is substituted for name in the output to prevent
+    LLM memorisation of benchmark networks. Use use_llm_names=False to
+    output benchmark names directly (for memorisation testing).
+
+    Example:
+        >>> spec = ModelLoader.load("model.json")
+        >>> view_filter = ViewFilter(spec)
+        >>> minimal_vars = view_filter.filter_variables(PromptDetail.MINIMAL)
+        >>> # Returns list of dicts with llm_name as 'name' field
+    """
+
+    def __init__(self, spec: ModelSpec, *, use_llm_names: bool = True) -> None:
+        """Initialise the view filter.
+
+        Args:
+            spec: The model specification to filter.
+            use_llm_names: If True (default), output llm_name as 'name'.
+                If False, output benchmark name as 'name'.
+        """
+        self._spec = spec
+        self._use_llm_names = use_llm_names
+
+    @property
+    def spec(self) -> ModelSpec:
+        """Return the model specification."""
+        return self._spec
+
+    def get_include_fields(self, level: PromptDetail) -> list[str]:
+        """Get the fields to include for a given detail level.
+
+        Args:
+            level: The prompt detail level (minimal, standard, rich).
+
+        Returns:
+            List of field names to include.
+        """
+        if level == PromptDetail.MINIMAL:
+            return self._spec.prompt_details.minimal.include_fields
+        elif level == PromptDetail.STANDARD:
+            return self._spec.prompt_details.standard.include_fields
+        elif level == PromptDetail.RICH:
+            return self._spec.prompt_details.rich.include_fields
+        else:  # pragma: no cover
+            raise ValueError(f"Unknown prompt detail level: {level}")
+
+    def filter_variable(
+        self,
+        variable: VariableSpec,
+        level: PromptDetail,
+    ) -> dict[str, Any]:
+        """Filter a single variable to include only specified fields.
+
+        Args:
+            variable: The variable specification to filter.
+            level: The prompt detail level determining which fields to include.
+
+        Returns:
+            Dictionary with only the fields specified by the detail level.
+            Enum values are converted to their string representations.
+            If use_llm_names is True, the 'name' field contains llm_name.
+        """
+        include_fields = self.get_include_fields(level)
+        # Use mode="json" to convert enums to their string values
+        var_dict = variable.model_dump(mode="json")
+
+        # If using llm_names, substitute llm_name for name in output
+        if self._use_llm_names and "name" in include_fields:
+            var_dict["name"] = var_dict.get("llm_name", var_dict["name"])
+
+        # Never include llm_name in output (it's internal)
+        return {
+            key: value
+            for key, value in var_dict.items()
+            if key in include_fields
+            and key != "llm_name"
+            and value is not None
+        }
+
+    def filter_variables(self, level: PromptDetail) -> list[dict[str, Any]]:
+        """Filter all variables to the specified detail level.
+
+        Args:
+            level: The prompt detail level (minimal, standard, rich).
+
+        Returns:
+            List of filtered variable dictionaries.
+        """
+        return [
+            self.filter_variable(var, level) for var in self._spec.variables
+        ]
+
+    def get_variable_names(self) -> list[str]:
+        """Get all variable names for LLM output.
+
+        Returns benchmark names if use_llm_names is False,
+        otherwise returns llm_names.
+
+        Returns:
+            List of variable names.
+        """
+        if self._use_llm_names:
+            return self._spec.get_llm_names()
+        return self._spec.get_variable_names()
+
+    def get_domain(self) -> str:
+        """Get the domain from the specification.
+
+        Returns:
+            The domain string.
+        """
+        return self._spec.domain
+
+    def get_context_summary(self, level: PromptDetail) -> dict[str, Any]:
+        """Get a complete context summary for LLM prompts.
+
+        Args:
+            level: The prompt detail level for variable filtering.
+
+        Returns:
+            Dictionary with domain and filtered variables.
+        """
+        return {
+            "domain": self._spec.domain,
+            "dataset_id": self._spec.dataset_id,
+            "variables": self.filter_variables(level),
+        }
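
A sketch of driving the new ViewFilter (not part of the diff). ModelLoader's import path and the model.json path are assumptions for illustration:

    from causaliq_knowledge.graph.loader import ModelLoader   # import path assumed
    from causaliq_knowledge.graph.view_filter import PromptDetail, ViewFilter

    spec = ModelLoader.load("model.json")   # illustrative specification file
    vf = ViewFilter(spec)                   # llm_name substitution is on by default
    context = vf.get_context_summary(PromptDetail.STANDARD)
    # context -> {"domain": ..., "dataset_id": ..., "variables": [filtered dicts]}

    # For memorisation testing, keep the benchmark variable names instead:
    benchmark_names = ViewFilter(spec, use_llm_names=False).get_variable_names()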
causaliq_knowledge/llm/base_client.py
@@ -297,6 +297,8 @@ class BaseLLMClient(ABC):
         Args:
             messages: List of message dicts with "role" and "content" keys.
             **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                Also accepts request_id (str) for identifying requests in
+                exports. Note: request_id is NOT part of the cache key.
 
         Returns:
             LLMResponse with the generated content and metadata.
@@ -306,6 +308,9 @@ class BaseLLMClient(ABC):
         cache = self.cache
         use_cache = self.use_cache
 
+        # Extract request_id (not part of cache key)
+        request_id = kwargs.pop("request_id", "")
+
         # Build cache key
         temperature = kwargs.get("temperature")
         max_tokens = kwargs.get("max_tokens")
@@ -354,6 +359,7 @@ class BaseLLMClient(ABC):
                input_tokens=response.input_tokens,
                output_tokens=response.output_tokens,
                cost_usd=response.cost,
+               request_id=request_id,
            )
            cache.put_data(cache_key, "llm", entry.to_dict())
 
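
The practical effect of the new request_id kwarg, sketched against a hypothetical concrete client (the client class and completion method name are assumptions; this diff only shows the request_id handling):

    client = SomeProviderClient(model="llama-3.1-8b-instant")   # hypothetical BaseLLMClient subclass
    messages = [{"role": "user", "content": "Propose causal edges for these variables ..."}]

    # request_id is popped before the cache key is built, so these two calls
    # resolve to the same cache entry; the id only labels the entry for exports.
    first = client.complete(messages, temperature=0.0, request_id="graph-run-001")   # method name assumed
    second = client.complete(messages, temperature=0.0, request_id="graph-run-002")  # served from cache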