causaliq-knowledge 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -0,0 +1,392 @@
+ """Response models and parsing for LLM graph generation.
+
+ This module provides Pydantic models for representing LLM-generated
+ causal graphs and functions for parsing LLM responses in both edge
+ list and adjacency matrix formats.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, Field, field_validator
+
+ logger = logging.getLogger(__name__)
+
+
+ class ProposedEdge(BaseModel):
+     """A proposed causal edge from LLM graph generation.
+
+     Represents a single directed edge in the proposed causal graph,
+     with confidence score and optional reasoning.
+
+     Attributes:
+         source: The name of the source variable (cause).
+         target: The name of the target variable (effect).
+         confidence: Confidence score from 0.0 to 1.0.
+         reasoning: Optional explanation for this specific edge.
+
+     Example:
+         >>> edge = ProposedEdge(
+         ...     source="smoking",
+         ...     target="lung_cancer",
+         ...     confidence=0.95,
+         ... )
+         >>> print(f"{edge.source} -> {edge.target}: {edge.confidence}")
+         smoking -> lung_cancer: 0.95
+     """
+
+     source: str = Field(..., description="Source variable name (cause)")
+     target: str = Field(..., description="Target variable name (effect)")
+     confidence: float = Field(
+         default=0.5,
+         ge=0.0,
+         le=1.0,
+         description="Confidence score from 0.0 to 1.0",
+     )
+     reasoning: Optional[str] = Field(
+         default=None,
+         description="Optional explanation for this edge",
+     )
+
+     @field_validator("confidence", mode="before")
+     @classmethod
+     def clamp_confidence(cls, v: Any) -> float:
+         """Clamp confidence values to [0.0, 1.0] range."""
+         if v is None:
+             return 0.5
+         try:
+             val = float(v)
+             return max(0.0, min(1.0, val))
+         except (TypeError, ValueError):
+             return 0.5
+
+
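The `clamp_confidence` validator runs in "before" mode, so malformed confidence values coming back from an LLM are coerced rather than rejected: out-of-range numbers are clamped to [0.0, 1.0] and non-numeric values fall back to the 0.5 default. A minimal sketch of that behaviour (illustrative only, not part of the package; it assumes `ProposedEdge` is imported from the new module, whose file path is not shown in this diff):

```python
# Illustrative sketch -- assumes ProposedEdge is importable from the new module.
edge = ProposedEdge(source="rain", target="wet_grass", confidence=1.7)
print(edge.confidence)  # 1.0 -- values above the range are clamped

edge = ProposedEdge(source="rain", target="wet_grass", confidence="high")
print(edge.confidence)  # 0.5 -- non-numeric values fall back to the default
```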
+ @dataclass
+ class GenerationMetadata:
+     """Metadata about a graph generation request.
+
+     Attributes:
+         model: The LLM model used for generation.
+         provider: The LLM provider (e.g., "groq", "gemini").
+         timestamp: When the graph was generated.
+         latency_ms: Request latency in milliseconds.
+         input_tokens: Number of input tokens used.
+         output_tokens: Number of output tokens generated.
+         cost_usd: Estimated cost in USD (if available).
+         from_cache: Whether the response was from cache.
+     """
+
+     model: str
+     provider: str = ""
+     timestamp: datetime = field(
+         default_factory=lambda: datetime.now(timezone.utc)
+     )
+     latency_ms: int = 0
+     input_tokens: int = 0
+     output_tokens: int = 0
+     cost_usd: float = 0.0
+     from_cache: bool = False
+
+
+ @dataclass
+ class GeneratedGraph:
+     """A complete causal graph generated by an LLM.
+
+     Represents the full output from an LLM graph generation query,
+     including all proposed edges, metadata, and the LLM's reasoning.
+
+     Attributes:
+         edges: List of proposed causal edges.
+         variables: List of variable names in the graph.
+         reasoning: Overall reasoning provided by the LLM.
+         metadata: Generation metadata (model, timing, etc.).
+         raw_response: The original LLM response for debugging.
+
+     Example:
+         >>> edge1 = ProposedEdge(
+         ...     source="age", target="income", confidence=0.7
+         ... )
+         >>> edge2 = ProposedEdge(
+         ...     source="education", target="income", confidence=0.9
+         ... )
+         >>> graph = GeneratedGraph(
+         ...     edges=[edge1, edge2],
+         ...     variables=["age", "education", "income"],
+         ...     reasoning="Age and education both influence income.",
+         ...     metadata=GenerationMetadata(model="llama-3.1-8b-instant"),
+         ... )
+         >>> print(f"Generated {len(graph.edges)} edges")
+         Generated 2 edges
+     """
+
+     edges: List[ProposedEdge]
+     variables: List[str]
+     reasoning: str = ""
+     metadata: Optional[GenerationMetadata] = None
+     raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+     def get_adjacency_matrix(self) -> List[List[float]]:
+         """Convert edges to an adjacency matrix.
+
+         Creates a square matrix where entry (i, j) represents the
+         confidence that variable i causes variable j.
+
+         Returns:
+             Square matrix of confidence scores.
+         """
+         n = len(self.variables)
+         var_to_idx = {var: i for i, var in enumerate(self.variables)}
+         matrix = [[0.0] * n for _ in range(n)]
+
+         for edge in self.edges:
+             src_idx = var_to_idx.get(edge.source)
+             tgt_idx = var_to_idx.get(edge.target)
+             if src_idx is not None and tgt_idx is not None:
+                 matrix[src_idx][tgt_idx] = edge.confidence
+
+         return matrix
+
+     def get_edge_list(self) -> List[tuple[str, str, float]]:
+         """Get edges as a list of tuples.
+
+         Returns:
+             List of (source, target, confidence) tuples.
+         """
+         return [
+             (edge.source, edge.target, edge.confidence) for edge in self.edges
+         ]
+
+     def filter_by_confidence(self, threshold: float = 0.5) -> "GeneratedGraph":
+         """Return a new graph with only edges at or above the threshold.
+
+         Args:
+             threshold: Minimum confidence score to include.
+
+         Returns:
+             New GeneratedGraph with filtered edges.
+         """
+         filtered_edges = [e for e in self.edges if e.confidence >= threshold]
+         return GeneratedGraph(
+             edges=filtered_edges,
+             variables=self.variables,
+             reasoning=self.reasoning,
+             metadata=self.metadata,
+             raw_response=self.raw_response,
+         )
+
+
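A short usage sketch (not part of the diff) showing how the GeneratedGraph helpers compose: edges become a row-to-column confidence matrix, and filter_by_confidence returns a new graph rather than mutating in place. ProposedEdge and GeneratedGraph are assumed to be imported from the new module above.

```python
# Sketch only: ProposedEdge / GeneratedGraph imported from the new module.
graph = GeneratedGraph(
    edges=[
        ProposedEdge(source="age", target="income", confidence=0.7),
        ProposedEdge(source="education", target="income", confidence=0.9),
    ],
    variables=["age", "education", "income"],
)

print(graph.get_adjacency_matrix())
# [[0.0, 0.0, 0.7], [0.0, 0.0, 0.9], [0.0, 0.0, 0.0]]

strong = graph.filter_by_confidence(threshold=0.8)
print(strong.get_edge_list())
# [('education', 'income', 0.9)]
```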
+ def parse_edge_list_response(
+     response: Dict[str, Any],
+     variables: List[str],
+ ) -> GeneratedGraph:
+     """Parse an edge list format response from the LLM.
+
+     Expected JSON format:
+         {
+             "edges": [
+                 {"source": "var1", "target": "var2", "confidence": 0.8},
+                 ...
+             ],
+             "reasoning": "explanation"
+         }
+
+     Args:
+         response: Parsed JSON response from the LLM.
+         variables: List of valid variable names for validation.
+
+     Returns:
+         GeneratedGraph with parsed edges.
+
+     Raises:
+         ValueError: If the response format is invalid.
+     """
+     if not isinstance(response, dict):
+         raise ValueError(f"Expected dict response, got {type(response)}")
+
+     edges_data = response.get("edges", [])
+     if not isinstance(edges_data, list):
+         raise ValueError(
+             f"Expected 'edges' to be a list, got {type(edges_data)}"
+         )
+
+     reasoning = response.get("reasoning", "")
+     if not isinstance(reasoning, str):
+         reasoning = str(reasoning)
+
+     valid_vars = set(variables)
+     edges = []
+
+     for i, edge_data in enumerate(edges_data):
+         if not isinstance(edge_data, dict):
+             logger.warning(f"Skipping invalid edge at index {i}: not a dict")
+             continue
+
+         source = edge_data.get("source", "")
+         target = edge_data.get("target", "")
+
+         # Validate variable names
+         if source not in valid_vars:
+             logger.warning(f"Unknown source variable: {source}")
+             continue
+         if target not in valid_vars:
+             logger.warning(f"Unknown target variable: {target}")
+             continue
+         if source == target:
+             logger.warning(f"Skipping self-loop: {source} -> {target}")
+             continue
+
+         confidence = edge_data.get("confidence", 0.5)
+         edge_reasoning = edge_data.get("reasoning")
+
+         edges.append(
+             ProposedEdge(
+                 source=source,
+                 target=target,
+                 confidence=confidence,
+                 reasoning=edge_reasoning,
+             )
+         )
+
+     return GeneratedGraph(
+         edges=edges,
+         variables=variables,
+         reasoning=reasoning,
+         raw_response=response,
+     )
+
+
+ def parse_adjacency_matrix_response(
+     response: Dict[str, Any],
+     expected_variables: List[str],
+ ) -> GeneratedGraph:
+     """Parse an adjacency matrix format response from the LLM.
+
+     Expected JSON format:
+         {
+             "variables": ["var1", "var2", "var3"],
+             "adjacency_matrix": [
+                 [0.0, 0.8, 0.0],
+                 [0.0, 0.0, 0.6],
+                 [0.0, 0.0, 0.0]
+             ],
+             "reasoning": "explanation"
+         }
+
+     Args:
+         response: Parsed JSON response from the LLM.
+         expected_variables: List of expected variable names for validation.
+
+     Returns:
+         GeneratedGraph with edges extracted from the matrix.
+
+     Raises:
+         ValueError: If the response format is invalid.
+     """
+     if not isinstance(response, dict):
+         raise ValueError(f"Expected dict response, got {type(response)}")
+
+     variables = response.get("variables", [])
+     if not isinstance(variables, list):
+         raise ValueError(
+             f"Expected 'variables' to be a list, got {type(variables)}"
+         )
+
+     matrix = response.get("adjacency_matrix", [])
+     if not isinstance(matrix, list):
+         raise ValueError(
+             f"Expected 'adjacency_matrix' to be a list, got {type(matrix)}"
+         )
+
+     reasoning = response.get("reasoning", "")
+     if not isinstance(reasoning, str):
+         reasoning = str(reasoning)
+
+     # Validate matrix dimensions
+     n = len(variables)
+     if len(matrix) != n:
+         raise ValueError(
+             f"Matrix has {len(matrix)} rows but {n} variables declared"
+         )
+
+     for i, row in enumerate(matrix):
+         if not isinstance(row, list) or len(row) != n:
+             raise ValueError(f"Matrix row {i} has incorrect dimensions")
+
+     # Validate variables against expected
+     valid_vars = set(expected_variables)
+     for var in variables:
+         if var not in valid_vars:
+             logger.warning(f"Unexpected variable in response: {var}")
+
+     # Extract edges from matrix
+     edges = []
+     for i, source in enumerate(variables):
+         for j, target in enumerate(variables):
+             confidence = matrix[i][j]
+             try:
+                 confidence = float(confidence)
+             except (TypeError, ValueError):
+                 continue
+
+             if confidence > 0.0 and i != j:  # Skip zeros and self-loops
+                 edges.append(
+                     ProposedEdge(
+                         source=source,
+                         target=target,
+                         confidence=max(0.0, min(1.0, confidence)),
+                     )
+                 )
+
+     return GeneratedGraph(
+         edges=edges,
+         variables=variables,
+         reasoning=reasoning,
+         raw_response=response,
+     )
+
+
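For the matrix path, a hedged sketch (illustrative, not part of the diff): every strictly positive off-diagonal entry becomes a ProposedEdge, and values are clamped to [0.0, 1.0]. The variables and numbers below are invented for illustration.

```python
# Sketch only: parse_adjacency_matrix_response imported from the new module.
response = {
    "variables": ["rain", "sprinkler", "wet_grass"],
    "adjacency_matrix": [
        [0.0, 0.0, 0.9],
        [0.0, 0.0, 0.7],
        [0.0, 0.0, 0.0],
    ],
    "reasoning": "Rain and the sprinkler both wet the grass.",
}

graph = parse_adjacency_matrix_response(
    response, expected_variables=["rain", "sprinkler", "wet_grass"]
)
print(graph.get_edge_list())
# [('rain', 'wet_grass', 0.9), ('sprinkler', 'wet_grass', 0.7)]
```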
+ def parse_graph_response(
+     response_text: str,
+     variables: List[str],
+     output_format: str = "edge_list",
+ ) -> GeneratedGraph:
+     """Parse an LLM response into a GeneratedGraph.
+
+     Handles JSON extraction from markdown code blocks and parses
+     according to the specified output format.
+
+     Args:
+         response_text: Raw text response from the LLM.
+         variables: List of valid variable names.
+         output_format: Expected format ("edge_list" or "adjacency_matrix").
+
+     Returns:
+         GeneratedGraph with parsed edges and metadata.
+
+     Raises:
+         ValueError: If JSON parsing fails or format is invalid.
+     """
+     # Clean up potential markdown code blocks
+     text = response_text.strip()
+     if text.startswith("```json"):
+         text = text[7:]
+     elif text.startswith("```"):
+         text = text[3:]
+     if text.endswith("```"):
+         text = text[:-3]
+     text = text.strip()
+
+     try:
+         response = json.loads(text)
+     except json.JSONDecodeError as e:
+         raise ValueError(f"Failed to parse JSON response: {e}")
+
+     if output_format == "adjacency_matrix":
+         return parse_adjacency_matrix_response(response, variables)
+     else:
+         return parse_edge_list_response(response, variables)
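An end-to-end sketch of the public entry point (illustrative, not part of the diff): parse_graph_response strips an optional markdown fence before handing the JSON to the edge-list parser. The payload below is invented for illustration.

```python
import json

# Sketch only: parse_graph_response imported from the new module.
payload = {
    "edges": [
        {"source": "smoking", "target": "lung_cancer", "confidence": 0.95},
        {"source": "asbestos", "target": "lung_cancer", "confidence": 0.8},
    ],
    "reasoning": "Both exposures are established causes of lung cancer.",
}
# Simulate an LLM reply wrapped in a markdown code fence.
response_text = "```json\n" + json.dumps(payload, indent=2) + "\n```"

graph = parse_graph_response(
    response_text,
    variables=["smoking", "asbestos", "lung_cancer"],
    output_format="edge_list",
)
print(len(graph.edges))        # 2
print(graph.edges[0].source)   # smoking
```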
@@ -0,0 +1,154 @@
+ """Prompt detail filtering for model specifications.
+
+ This module provides functionality to extract filtered views
+ (minimal, standard, rich) from model specifications for LLM prompts.
+ """
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Any
+
+ from causaliq_knowledge.graph.models import ModelSpec, VariableSpec
+
+
+ class PromptDetail(str, Enum):
+     """Level of detail for variable information in prompts."""
+
+     MINIMAL = "minimal"
+     STANDARD = "standard"
+     RICH = "rich"
+
+
+ class ViewFilter:
+     """Filter model specifications to extract specific detail levels.
+
+     This class extracts variable information according to the view
+     definitions in the model specification (minimal, standard, rich).
+
+     By default, llm_name is substituted for name in the output to prevent
+     LLM memorisation of benchmark networks. Use use_llm_names=False to
+     output benchmark names directly (for memorisation testing).
+
+     Example:
+         >>> spec = ModelLoader.load("model.json")
+         >>> view_filter = ViewFilter(spec)
+         >>> minimal_vars = view_filter.filter_variables(PromptDetail.MINIMAL)
+         >>> # Returns list of dicts with llm_name as 'name' field
+     """
+
+     def __init__(self, spec: ModelSpec, *, use_llm_names: bool = True) -> None:
+         """Initialise the view filter.
+
+         Args:
+             spec: The model specification to filter.
+             use_llm_names: If True (default), output llm_name as 'name'.
+                 If False, output benchmark name as 'name'.
+         """
+         self._spec = spec
+         self._use_llm_names = use_llm_names
+
+     @property
+     def spec(self) -> ModelSpec:
+         """Return the model specification."""
+         return self._spec
+
+     def get_include_fields(self, level: PromptDetail) -> list[str]:
+         """Get the fields to include for a given detail level.
+
+         Args:
+             level: The prompt detail level (minimal, standard, rich).
+
+         Returns:
+             List of field names to include.
+         """
+         if level == PromptDetail.MINIMAL:
+             return self._spec.prompt_details.minimal.include_fields
+         elif level == PromptDetail.STANDARD:
+             return self._spec.prompt_details.standard.include_fields
+         elif level == PromptDetail.RICH:
+             return self._spec.prompt_details.rich.include_fields
+         else:  # pragma: no cover
+             raise ValueError(f"Unknown prompt detail level: {level}")
+
+     def filter_variable(
+         self,
+         variable: VariableSpec,
+         level: PromptDetail,
+     ) -> dict[str, Any]:
+         """Filter a single variable to include only specified fields.
+
+         Args:
+             variable: The variable specification to filter.
+             level: The prompt detail level determining which fields to include.
+
+         Returns:
+             Dictionary with only the fields specified by the detail level.
+             Enum values are converted to their string representations.
+             If use_llm_names is True, the 'name' field contains llm_name.
+         """
+         include_fields = self.get_include_fields(level)
+         # Use mode="json" to convert enums to their string values
+         var_dict = variable.model_dump(mode="json")
+
+         # If using llm_names, substitute llm_name for name in output
+         if self._use_llm_names and "name" in include_fields:
+             var_dict["name"] = var_dict.get("llm_name", var_dict["name"])
+
+         # Never include llm_name in output (it's internal)
+         return {
+             key: value
+             for key, value in var_dict.items()
+             if key in include_fields
+             and key != "llm_name"
+             and value is not None
+         }
+
+     def filter_variables(self, level: PromptDetail) -> list[dict[str, Any]]:
+         """Filter all variables to the specified detail level.
+
+         Args:
+             level: The prompt detail level (minimal, standard, rich).
+
+         Returns:
+             List of filtered variable dictionaries.
+         """
+         return [
+             self.filter_variable(var, level) for var in self._spec.variables
+         ]
+
+     def get_variable_names(self) -> list[str]:
+         """Get all variable names for LLM output.
+
+         Returns benchmark names if use_llm_names is False,
+         otherwise returns llm_names.
+
+         Returns:
+             List of variable names.
+         """
+         if self._use_llm_names:
+             return self._spec.get_llm_names()
+         return self._spec.get_variable_names()
+
+     def get_domain(self) -> str:
+         """Get the domain from the specification.
+
+         Returns:
+             The domain string.
+         """
+         return self._spec.domain
+
+     def get_context_summary(self, level: PromptDetail) -> dict[str, Any]:
+         """Get a complete context summary for LLM prompts.
+
+         Args:
+             level: The prompt detail level for variable filtering.
+
+         Returns:
+             Dictionary with domain, dataset_id, and filtered variables.
+         """
+         return {
+             "domain": self._spec.domain,
+             "dataset_id": self._spec.dataset_id,
+             "variables": self.filter_variables(level),
+         }
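A hedged usage sketch mirroring the ViewFilter docstring. ModelLoader, its import path, and the structure of model.json are assumptions here; they live elsewhere in the package and are not shown in this diff.

```python
# Sketch only: ModelLoader and its import path are assumptions -- it is
# referenced in the ViewFilter docstring but not defined in this diff.
spec = ModelLoader.load("model.json")

view_filter = ViewFilter(spec)  # llm_name substituted for 'name' by default
context = view_filter.get_context_summary(PromptDetail.STANDARD)
# {"domain": ..., "dataset_id": ..., "variables": [...]}

# For memorisation testing, keep the benchmark variable names instead:
benchmark_filter = ViewFilter(spec, use_llm_names=False)
names = benchmark_filter.get_variable_names()
```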
@@ -297,6 +297,8 @@ class BaseLLMClient(ABC):
          Args:
              messages: List of message dicts with "role" and "content" keys.
              **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                 Also accepts request_id (str) for identifying requests in
+                 exports. Note: request_id is NOT part of the cache key.
 
          Returns:
              LLMResponse with the generated content and metadata.
@@ -306,6 +308,9 @@ class BaseLLMClient(ABC):
          cache = self.cache
          use_cache = self.use_cache
 
+         # Extract request_id (not part of cache key)
+         request_id = kwargs.pop("request_id", "")
+
          # Build cache key
          temperature = kwargs.get("temperature")
          max_tokens = kwargs.get("max_tokens")
@@ -354,6 +359,7 @@ class BaseLLMClient(ABC):
                  input_tokens=response.input_tokens,
                  output_tokens=response.output_tokens,
                  cost_usd=response.cost,
+                 request_id=request_id,
              )
              cache.put_data(cache_key, "llm", entry.to_dict())
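The request_id handling above means the identifier travels with the cached entry (for exports) but never influences cache lookups: it is popped from kwargs before the cache key is built. A hedged sketch of the intended call pattern; the client class SomeClient and its generate method are hypothetical placeholders, since the full method signature is not shown in this hunk.

```python
# Sketch only: "SomeClient" and "generate" are hypothetical placeholders for
# a concrete BaseLLMClient subclass and its completion method.
client = SomeClient(model="llama-3.1-8b-instant", use_cache=True)

messages = [{"role": "user", "content": "Propose edges for: rain, wet_grass"}]

# request_id is recorded with the cached entry but excluded from the cache
# key, so these two calls should resolve to the same cached response.
first = client.generate(messages, temperature=0.0, request_id="run-001")
second = client.generate(messages, temperature=0.0, request_id="run-002")
```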