causaliq-knowledge 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (31)
  1. causaliq_knowledge/__init__.py +6 -3
  2. causaliq_knowledge/action.py +480 -0
  3. causaliq_knowledge/cache/__init__.py +18 -0
  4. causaliq_knowledge/cache/encoders/__init__.py +13 -0
  5. causaliq_knowledge/cache/encoders/base.py +90 -0
  6. causaliq_knowledge/cache/encoders/json_encoder.py +430 -0
  7. causaliq_knowledge/cache/token_cache.py +666 -0
  8. causaliq_knowledge/cli/__init__.py +15 -0
  9. causaliq_knowledge/cli/cache.py +478 -0
  10. causaliq_knowledge/cli/generate.py +410 -0
  11. causaliq_knowledge/cli/main.py +172 -0
  12. causaliq_knowledge/cli/models.py +309 -0
  13. causaliq_knowledge/graph/__init__.py +78 -0
  14. causaliq_knowledge/graph/generator.py +457 -0
  15. causaliq_knowledge/graph/loader.py +222 -0
  16. causaliq_knowledge/graph/models.py +426 -0
  17. causaliq_knowledge/graph/params.py +175 -0
  18. causaliq_knowledge/graph/prompts.py +445 -0
  19. causaliq_knowledge/graph/response.py +392 -0
  20. causaliq_knowledge/graph/view_filter.py +154 -0
  21. causaliq_knowledge/llm/base_client.py +147 -1
  22. causaliq_knowledge/llm/cache.py +443 -0
  23. causaliq_knowledge/py.typed +0 -0
  24. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
  25. causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
  26. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +1 -1
  27. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
  28. causaliq_knowledge/cli.py +0 -414
  29. causaliq_knowledge-0.2.0.dist-info/RECORD +0 -22
  30. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
  31. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
causaliq_knowledge/graph/response.py
@@ -0,0 +1,392 @@
+ """Response models and parsing for LLM graph generation.
+
+ This module provides Pydantic models for representing LLM-generated
+ causal graphs and functions for parsing LLM responses in both edge
+ list and adjacency matrix formats.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, Field, field_validator
+
+ logger = logging.getLogger(__name__)
+
+
+ class ProposedEdge(BaseModel):
+     """A proposed causal edge from LLM graph generation.
+
+     Represents a single directed edge in the proposed causal graph,
+     with confidence score and optional reasoning.
+
+     Attributes:
+         source: The name of the source variable (cause).
+         target: The name of the target variable (effect).
+         confidence: Confidence score from 0.0 to 1.0.
+         reasoning: Optional explanation for this specific edge.
+
+     Example:
+         >>> edge = ProposedEdge(
+         ...     source="smoking",
+         ...     target="lung_cancer",
+         ...     confidence=0.95,
+         ... )
+         >>> print(f"{edge.source} -> {edge.target}: {edge.confidence}")
+         smoking -> lung_cancer: 0.95
+     """
+
+     source: str = Field(..., description="Source variable name (cause)")
+     target: str = Field(..., description="Target variable name (effect)")
+     confidence: float = Field(
+         default=0.5,
+         ge=0.0,
+         le=1.0,
+         description="Confidence score from 0.0 to 1.0",
+     )
+     reasoning: Optional[str] = Field(
+         default=None,
+         description="Optional explanation for this edge",
+     )
+
+     @field_validator("confidence", mode="before")
+     @classmethod
+     def clamp_confidence(cls, v: Any) -> float:
+         """Clamp confidence values to [0.0, 1.0] range."""
+         if v is None:
+             return 0.5
+         try:
+             val = float(v)
+             return max(0.0, min(1.0, val))
+         except (TypeError, ValueError):
+             return 0.5
+
+
+ @dataclass
+ class GenerationMetadata:
+     """Metadata about a graph generation request.
+
+     Attributes:
+         model: The LLM model used for generation.
+         provider: The LLM provider (e.g., "groq", "gemini").
+         timestamp: When the graph was generated.
+         latency_ms: Request latency in milliseconds.
+         input_tokens: Number of input tokens used.
+         output_tokens: Number of output tokens generated.
+         cost_usd: Estimated cost in USD (if available).
+         from_cache: Whether the response was from cache.
+     """
+
+     model: str
+     provider: str = ""
+     timestamp: datetime = field(
+         default_factory=lambda: datetime.now(timezone.utc)
+     )
+     latency_ms: int = 0
+     input_tokens: int = 0
+     output_tokens: int = 0
+     cost_usd: float = 0.0
+     from_cache: bool = False
+
+
+ @dataclass
+ class GeneratedGraph:
+     """A complete causal graph generated by an LLM.
+
+     Represents the full output from an LLM graph generation query,
+     including all proposed edges, metadata, and the LLM's reasoning.
+
+     Attributes:
+         edges: List of proposed causal edges.
+         variables: List of variable names in the graph.
+         reasoning: Overall reasoning provided by the LLM.
+         metadata: Generation metadata (model, timing, etc.).
+         raw_response: The original LLM response for debugging.
+
+     Example:
+         >>> edge1 = ProposedEdge(
+         ...     source="age", target="income", confidence=0.7
+         ... )
+         >>> edge2 = ProposedEdge(
+         ...     source="education", target="income", confidence=0.9
+         ... )
+         >>> graph = GeneratedGraph(
+         ...     edges=[edge1, edge2],
+         ...     variables=["age", "education", "income"],
+         ...     reasoning="Age and education both influence income.",
+         ...     metadata=GenerationMetadata(model="llama-3.1-8b-instant"),
+         ... )
+         >>> print(f"Generated {len(graph.edges)} edges")
+         Generated 2 edges
+     """
+
+     edges: List[ProposedEdge]
+     variables: List[str]
+     reasoning: str = ""
+     metadata: Optional[GenerationMetadata] = None
+     raw_response: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+     def get_adjacency_matrix(self) -> List[List[float]]:
+         """Convert edges to an adjacency matrix.
+
+         Creates a square matrix where entry (i,j) represents the
+         confidence that variable i causes variable j.
+
+         Returns:
+             Square matrix of confidence scores.
+         """
+         n = len(self.variables)
+         var_to_idx = {var: i for i, var in enumerate(self.variables)}
+         matrix = [[0.0] * n for _ in range(n)]
+
+         for edge in self.edges:
+             src_idx = var_to_idx.get(edge.source)
+             tgt_idx = var_to_idx.get(edge.target)
+             if src_idx is not None and tgt_idx is not None:
+                 matrix[src_idx][tgt_idx] = edge.confidence
+
+         return matrix
+
+     def get_edge_list(self) -> List[tuple[str, str, float]]:
+         """Get edges as a list of tuples.
+
+         Returns:
+             List of (source, target, confidence) tuples.
+         """
+         return [
+             (edge.source, edge.target, edge.confidence) for edge in self.edges
+         ]
+
+     def filter_by_confidence(self, threshold: float = 0.5) -> "GeneratedGraph":
+         """Return a new graph with only edges above the threshold.
+
+         Args:
+             threshold: Minimum confidence score to include.
+
+         Returns:
+             New GeneratedGraph with filtered edges.
+         """
+         filtered_edges = [e for e in self.edges if e.confidence >= threshold]
+         return GeneratedGraph(
+             edges=filtered_edges,
+             variables=self.variables,
+             reasoning=self.reasoning,
+             metadata=self.metadata,
+             raw_response=self.raw_response,
+         )
+
+
+ def parse_edge_list_response(
+     response: Dict[str, Any],
+     variables: List[str],
+ ) -> GeneratedGraph:
+     """Parse an edge list format response from the LLM.
+
+     Expected JSON format:
+         {
+             "edges": [
+                 {"source": "var1", "target": "var2", "confidence": 0.8},
+                 ...
+             ],
+             "reasoning": "explanation"
+         }
+
+     Args:
+         response: Parsed JSON response from the LLM.
+         variables: List of valid variable names for validation.
+
+     Returns:
+         GeneratedGraph with parsed edges.
+
+     Raises:
+         ValueError: If the response format is invalid.
+     """
+     if not isinstance(response, dict):
+         raise ValueError(f"Expected dict response, got {type(response)}")
+
+     edges_data = response.get("edges", [])
+     if not isinstance(edges_data, list):
+         raise ValueError(
+             f"Expected 'edges' to be a list, got {type(edges_data)}"
+         )
+
+     reasoning = response.get("reasoning", "")
+     if not isinstance(reasoning, str):
+         reasoning = str(reasoning)
+
+     valid_vars = set(variables)
+     edges = []
+
+     for i, edge_data in enumerate(edges_data):
+         if not isinstance(edge_data, dict):
+             logger.warning(f"Skipping invalid edge at index {i}: not a dict")
+             continue
+
+         source = edge_data.get("source", "")
+         target = edge_data.get("target", "")
+
+         # Validate variable names
+         if source not in valid_vars:
+             logger.warning(f"Unknown source variable: {source}")
+             continue
+         if target not in valid_vars:
+             logger.warning(f"Unknown target variable: {target}")
+             continue
+         if source == target:
+             logger.warning(f"Skipping self-loop: {source} -> {target}")
+             continue
+
+         confidence = edge_data.get("confidence", 0.5)
+         edge_reasoning = edge_data.get("reasoning")
+
+         edges.append(
+             ProposedEdge(
+                 source=source,
+                 target=target,
+                 confidence=confidence,
+                 reasoning=edge_reasoning,
+             )
+         )
+
+     return GeneratedGraph(
+         edges=edges,
+         variables=variables,
+         reasoning=reasoning,
+         raw_response=response,
+     )
+
+
+ def parse_adjacency_matrix_response(
+     response: Dict[str, Any],
+     expected_variables: List[str],
+ ) -> GeneratedGraph:
+     """Parse an adjacency matrix format response from the LLM.
+
+     Expected JSON format:
+         {
+             "variables": ["var1", "var2", "var3"],
+             "adjacency_matrix": [
+                 [0.0, 0.8, 0.0],
+                 [0.0, 0.0, 0.6],
+                 [0.0, 0.0, 0.0]
+             ],
+             "reasoning": "explanation"
+         }
+
+     Args:
+         response: Parsed JSON response from the LLM.
+         expected_variables: List of expected variable names for validation.
+
+     Returns:
+         GeneratedGraph with edges extracted from the matrix.
+
+     Raises:
+         ValueError: If the response format is invalid.
+     """
+     if not isinstance(response, dict):
+         raise ValueError(f"Expected dict response, got {type(response)}")
+
+     variables = response.get("variables", [])
+     if not isinstance(variables, list):
+         raise ValueError(
+             f"Expected 'variables' to be a list, got {type(variables)}"
+         )
+
+     matrix = response.get("adjacency_matrix", [])
+     if not isinstance(matrix, list):
+         raise ValueError(
+             f"Expected 'adjacency_matrix' to be a list, got {type(matrix)}"
+         )
+
+     reasoning = response.get("reasoning", "")
+     if not isinstance(reasoning, str):
+         reasoning = str(reasoning)
+
+     # Validate matrix dimensions
+     n = len(variables)
+     if len(matrix) != n:
+         raise ValueError(
+             f"Matrix has {len(matrix)} rows but {n} variables declared"
+         )
+
+     for i, row in enumerate(matrix):
+         if not isinstance(row, list) or len(row) != n:
+             raise ValueError(f"Matrix row {i} has incorrect dimensions")
+
+     # Validate variables against expected
+     valid_vars = set(expected_variables)
+     for var in variables:
+         if var not in valid_vars:
+             logger.warning(f"Unexpected variable in response: {var}")
+
+     # Extract edges from matrix
+     edges = []
+     for i, source in enumerate(variables):
+         for j, target in enumerate(variables):
+             confidence = matrix[i][j]
+             try:
+                 confidence = float(confidence)
+             except (TypeError, ValueError):
+                 continue
+
+             if confidence > 0.0 and i != j: # Skip zeros and self-loops
+                 edges.append(
+                     ProposedEdge(
+                         source=source,
+                         target=target,
+                         confidence=max(0.0, min(1.0, confidence)),
+                     )
+                 )
+
+     return GeneratedGraph(
+         edges=edges,
+         variables=variables,
+         reasoning=reasoning,
+         raw_response=response,
+     )
+
+
+ def parse_graph_response(
+     response_text: str,
+     variables: List[str],
+     output_format: str = "edge_list",
+ ) -> GeneratedGraph:
+     """Parse an LLM response into a GeneratedGraph.
+
+     Handles JSON extraction from markdown code blocks and parses
+     according to the specified output format.
+
+     Args:
+         response_text: Raw text response from the LLM.
+         variables: List of valid variable names.
+         output_format: Expected format ("edge_list" or "adjacency_matrix").
+
+     Returns:
+         GeneratedGraph with parsed edges and metadata.
+
+     Raises:
+         ValueError: If JSON parsing fails or format is invalid.
+     """
+     # Clean up potential markdown code blocks
+     text = response_text.strip()
+     if text.startswith("```json"):
+         text = text[7:]
+     elif text.startswith("```"):
+         text = text[3:]
+     if text.endswith("```"):
+         text = text[:-3]
+     text = text.strip()
+
+     try:
+         response = json.loads(text)
+     except json.JSONDecodeError as e:
+         raise ValueError(f"Failed to parse JSON response: {e}")
+
+     if output_format == "adjacency_matrix":
+         return parse_adjacency_matrix_response(response, variables)
+     else:
+         return parse_edge_list_response(response, variables)
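
For orientation, the new parsing API in response.py can be exercised without an LLM call. The sketch below is illustrative only (the JSON payload and variable names are invented) and is not part of the packaged code:

# Illustrative usage of parse_graph_response; payload and variables are made up.
from causaliq_knowledge.graph.response import parse_graph_response

raw = (
    '{"edges": ['
    '{"source": "smoking", "target": "lung_cancer", "confidence": 0.95},'
    '{"source": "asbestos", "target": "lung_cancer", "confidence": 0.7}],'
    '"reasoning": "Both exposures are established causes of lung cancer."}'
)

graph = parse_graph_response(
    raw,
    variables=["smoking", "asbestos", "lung_cancer"],
    output_format="edge_list",
)
print(graph.get_edge_list())
# [('smoking', 'lung_cancer', 0.95), ('asbestos', 'lung_cancer', 0.7)]
print(graph.filter_by_confidence(0.8).get_edge_list())
# [('smoking', 'lung_cancer', 0.95)]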
causaliq_knowledge/graph/view_filter.py
@@ -0,0 +1,154 @@
+ """Prompt detail filtering for model specifications.
+
+ This module provides functionality to extract filtered views
+ (minimal, standard, rich) from model specifications for LLM prompts.
+ """
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Any
+
+ from causaliq_knowledge.graph.models import ModelSpec, VariableSpec
+
+
+ class PromptDetail(str, Enum):
+     """Level of detail for variable information in prompts."""
+
+     MINIMAL = "minimal"
+     STANDARD = "standard"
+     RICH = "rich"
+
+
+ class ViewFilter:
+     """Filter model specifications to extract specific detail levels.
+
+     This class extracts variable information according to the view
+     definitions in the model specification (minimal, standard, rich).
+
+     By default, llm_name is substituted for name in the output to prevent
+     LLM memorisation of benchmark networks. Use use_llm_names=False to
+     output benchmark names directly (for memorisation testing).
+
+     Example:
+         >>> spec = ModelLoader.load("model.json")
+         >>> view_filter = ViewFilter(spec)
+         >>> minimal_vars = view_filter.filter_variables(PromptDetail.MINIMAL)
+         >>> # Returns list of dicts with llm_name as 'name' field
+     """
+
+     def __init__(self, spec: ModelSpec, *, use_llm_names: bool = True) -> None:
+         """Initialise the view filter.
+
+         Args:
+             spec: The model specification to filter.
+             use_llm_names: If True (default), output llm_name as 'name'.
+                 If False, output benchmark name as 'name'.
+         """
+         self._spec = spec
+         self._use_llm_names = use_llm_names
+
+     @property
+     def spec(self) -> ModelSpec:
+         """Return the model specification."""
+         return self._spec
+
+     def get_include_fields(self, level: PromptDetail) -> list[str]:
+         """Get the fields to include for a given detail level.
+
+         Args:
+             level: The prompt detail level (minimal, standard, rich).
+
+         Returns:
+             List of field names to include.
+         """
+         if level == PromptDetail.MINIMAL:
+             return self._spec.prompt_details.minimal.include_fields
+         elif level == PromptDetail.STANDARD:
+             return self._spec.prompt_details.standard.include_fields
+         elif level == PromptDetail.RICH:
+             return self._spec.prompt_details.rich.include_fields
+         else: # pragma: no cover
+             raise ValueError(f"Unknown prompt detail level: {level}")
+
+     def filter_variable(
+         self,
+         variable: VariableSpec,
+         level: PromptDetail,
+     ) -> dict[str, Any]:
+         """Filter a single variable to include only specified fields.
+
+         Args:
+             variable: The variable specification to filter.
+             level: The prompt detail level determining which fields to include.
+
+         Returns:
+             Dictionary with only the fields specified by the detail level.
+             Enum values are converted to their string representations.
+             If use_llm_names is True, the 'name' field contains llm_name.
+         """
+         include_fields = self.get_include_fields(level)
+         # Use mode="json" to convert enums to their string values
+         var_dict = variable.model_dump(mode="json")
+
+         # If using llm_names, substitute llm_name for name in output
+         if self._use_llm_names and "name" in include_fields:
+             var_dict["name"] = var_dict.get("llm_name", var_dict["name"])
+
+         # Never include llm_name in output (it's internal)
+         return {
+             key: value
+             for key, value in var_dict.items()
+             if key in include_fields
+             and key != "llm_name"
+             and value is not None
+         }
+
+     def filter_variables(self, level: PromptDetail) -> list[dict[str, Any]]:
+         """Filter all variables to the specified detail level.
+
+         Args:
+             level: The prompt detail level (minimal, standard, rich).
+
+         Returns:
+             List of filtered variable dictionaries.
+         """
+         return [
+             self.filter_variable(var, level) for var in self._spec.variables
+         ]
+
+     def get_variable_names(self) -> list[str]:
+         """Get all variable names for LLM output.
+
+         Returns benchmark names if use_llm_names is False,
+         otherwise returns llm_names.
+
+         Returns:
+             List of variable names.
+         """
+         if self._use_llm_names:
+             return self._spec.get_llm_names()
+         return self._spec.get_variable_names()
+
+     def get_domain(self) -> str:
+         """Get the domain from the specification.
+
+         Returns:
+             The domain string.
+         """
+         return self._spec.domain
+
+     def get_context_summary(self, level: PromptDetail) -> dict[str, Any]:
+         """Get a complete context summary for LLM prompts.
+
+         Args:
+             level: The prompt detail level for variable filtering.
+
+         Returns:
+             Dictionary with domain and filtered variables.
+         """
+         return {
+             "domain": self._spec.domain,
+             "dataset_id": self._spec.dataset_id,
+             "variables": self.filter_variables(level),
+         }
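
For orientation, ViewFilter can be used as in its own docstring example. The sketch below assumes ModelLoader is exported from causaliq_knowledge/graph/loader.py (listed above but not shown in this diff) and that "model.json" is a model specification on disk; both are illustrative assumptions:

# Illustrative usage following the ViewFilter docstring; ModelLoader's import
# path and "model.json" are assumptions, not confirmed by this diff.
from causaliq_knowledge.graph.loader import ModelLoader
from causaliq_knowledge.graph.view_filter import PromptDetail, ViewFilter

spec = ModelLoader.load("model.json")
view_filter = ViewFilter(spec)  # llm_name substituted for name by default
minimal_vars = view_filter.filter_variables(PromptDetail.MINIMAL)
context = view_filter.get_context_summary(PromptDetail.RICH)

# For memorisation testing, expose benchmark names instead of llm_names.
bench_filter = ViewFilter(spec, use_llm_names=False)
names = bench_filter.get_variable_names()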
causaliq_knowledge/llm/base_client.py
@@ -5,11 +5,18 @@ must implement. This provides a consistent API regardless of the
  underlying LLM provider.
  """
 
+ from __future__ import annotations
+
+ import hashlib
  import json
  import logging
+ import time
  from abc import ABC, abstractmethod
  from dataclasses import dataclass, field
- from typing import Any, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+ if TYPE_CHECKING: # pragma: no cover
+     from causaliq_knowledge.cache import TokenCache
 
  logger = logging.getLogger(__name__)
 
@@ -218,3 +225,142 @@ class BaseLLMClient(ABC):
              Model identifier string.
          """
          return getattr(self, "config", LLMConfig(model="unknown")).model
+
+     def _build_cache_key(
+         self,
+         messages: List[Dict[str, str]],
+         temperature: Optional[float] = None,
+         max_tokens: Optional[int] = None,
+     ) -> str:
+         """Build a deterministic cache key for the request.
+
+         Creates a SHA-256 hash from the model, messages, temperature, and
+         max_tokens. The hash is truncated to 16 hex characters (64 bits).
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             temperature: Sampling temperature (defaults to config value).
+             max_tokens: Maximum tokens (defaults to config value).
+
+         Returns:
+             16-character hex string cache key.
+         """
+         config = getattr(self, "config", LLMConfig(model="unknown"))
+         key_data = {
+             "model": config.model,
+             "messages": messages,
+             "temperature": (
+                 temperature if temperature is not None else config.temperature
+             ),
+             "max_tokens": (
+                 max_tokens if max_tokens is not None else config.max_tokens
+             ),
+         }
+         key_json = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
+         return hashlib.sha256(key_json.encode()).hexdigest()[:16]
+
+     def set_cache(
+         self,
+         cache: Optional["TokenCache"],
+         use_cache: bool = True,
+     ) -> None:
+         """Configure caching for this client.
+
+         Args:
+             cache: TokenCache instance for caching, or None to disable.
+             use_cache: Whether to use the cache (default True).
+         """
+         self._cache = cache
+         self._use_cache = use_cache
+
+     @property
+     def cache(self) -> Optional["TokenCache"]:
+         """Return the configured cache, if any."""
+         return getattr(self, "_cache", None)
+
+     @property
+     def use_cache(self) -> bool:
+         """Return whether caching is enabled."""
+         return getattr(self, "_use_cache", True)
+
+     def cached_completion(
+         self,
+         messages: List[Dict[str, str]],
+         **kwargs: Any,
+     ) -> LLMResponse:
+         """Make a completion request with caching.
+
+         If caching is enabled and a cached response exists, returns
+         the cached response without making an API call. Otherwise,
+         makes the API call and caches the result.
+
+         Args:
+             messages: List of message dicts with "role" and "content" keys.
+             **kwargs: Provider-specific options (temperature, max_tokens, etc.)
+                 Also accepts request_id (str) for identifying requests in
+                 exports. Note: request_id is NOT part of the cache key.
+
+         Returns:
+             LLMResponse with the generated content and metadata.
+         """
+         from causaliq_knowledge.llm.cache import LLMCacheEntry, LLMEntryEncoder
+
+         cache = self.cache
+         use_cache = self.use_cache
+
+         # Extract request_id (not part of cache key)
+         request_id = kwargs.pop("request_id", "")
+
+         # Build cache key
+         temperature = kwargs.get("temperature")
+         max_tokens = kwargs.get("max_tokens")
+         cache_key = self._build_cache_key(messages, temperature, max_tokens)
+
+         # Check cache
+         if use_cache and cache is not None:
+             # Ensure encoder is registered
+             if not cache.has_encoder("llm"):
+                 cache.register_encoder("llm", LLMEntryEncoder())
+
+             if cache.exists(cache_key, "llm"):
+                 cached_data = cache.get_data(cache_key, "llm")
+                 if cached_data is not None:
+                     entry = LLMCacheEntry.from_dict(cached_data)
+                     return LLMResponse(
+                         content=entry.response.content,
+                         model=entry.model,
+                         input_tokens=entry.metadata.tokens.input,
+                         output_tokens=entry.metadata.tokens.output,
+                         cost=entry.metadata.cost_usd or 0.0,
+                     )
+
+         # Make API call with timing
+         start_time = time.perf_counter()
+         response = self.completion(messages, **kwargs)
+         latency_ms = int((time.perf_counter() - start_time) * 1000)
+
+         # Store in cache
+         if use_cache and cache is not None:
+             config = getattr(self, "config", LLMConfig(model="unknown"))
+             entry = LLMCacheEntry.create(
+                 model=config.model,
+                 messages=messages,
+                 content=response.content,
+                 temperature=(
+                     temperature
+                     if temperature is not None
+                     else config.temperature
+                 ),
+                 max_tokens=(
+                     max_tokens if max_tokens is not None else config.max_tokens
+                 ),
+                 provider=self.provider_name,
+                 latency_ms=latency_ms,
+                 input_tokens=response.input_tokens,
+                 output_tokens=response.output_tokens,
+                 cost_usd=response.cost,
+                 request_id=request_id,
+             )
+             cache.put_data(cache_key, "llm", entry.to_dict())
+
+         return response
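
For orientation, the cache key used by cached_completion can be reproduced outside the client. The sketch below mirrors the _build_cache_key logic shown above; the model name, message content, temperature and max_tokens values are invented for illustration and are not part of the packaged code:

# Illustrative reproduction of the cache-key derivation: SHA-256 over the
# compact, key-sorted JSON of model, messages, temperature and max_tokens,
# truncated to 16 hex characters. request_id is deliberately excluded, so
# identical requests share one cache entry.
import hashlib
import json

key_data = {
    "model": "llama-3.1-8b-instant",  # example value, not prescribed by the diff
    "messages": [{"role": "user", "content": "Propose causal edges."}],
    "temperature": 0.0,
    "max_tokens": 1024,
}
key_json = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
cache_key = hashlib.sha256(key_json.encode()).hexdigest()[:16]
print(cache_key)  # deterministic 16-character hex key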