causaliq-knowledge 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/causaliq_knowledge/__init__.py
+++ b/causaliq_knowledge/__init__.py
@@ -2,10 +2,11 @@
 causaliq-knowledge: LLM and human knowledge for causal discovery.
 """
 
+from causaliq_knowledge.action import CausalIQAction
 from causaliq_knowledge.base import KnowledgeProvider
 from causaliq_knowledge.models import EdgeDirection, EdgeKnowledge
 
-__version__ = "0.3.0"
+__version__ = "0.4.0"
 __author__ = "CausalIQ"
 __email__ = "info@causaliq.com"
 
@@ -17,7 +18,7 @@ __url__ = "https://github.com/causaliq/causaliq-knowledge"
 __license__ = "MIT"
 
 # Version tuple for programmatic access (major, minor, patch)
-VERSION = (0, 3, 0)
+VERSION = (0, 4, 0)
 
 __all__ = [
     "__version__",
@@ -29,5 +30,7 @@ __all__ = [
     "EdgeDirection",
     # Abstract interface
     "KnowledgeProvider",
+    # Workflow action (auto-discovered by causaliq-workflow)
+    "CausalIQAction",
     # Note: Import LLMKnowledge from causaliq_knowledge.llm
 ]
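
The practical effect of the `__init__.py` change is that the workflow action is importable from the package root. A minimal sketch of what the new export exposes, using only names visible in this diff (instantiating with no arguments is an assumption):

```python
# Sketch based on the diff above: CausalIQAction is re-exported at the
# package root and aliases GenerateGraphAction (defined in action.py below).
from causaliq_knowledge import CausalIQAction

action = CausalIQAction()
print(action.name)            # "causaliq-knowledge"
print(action.version)         # "0.4.0"
print(sorted(action.inputs))  # action, llm_cache, llm_model, ...
```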
--- /dev/null
+++ b/causaliq_knowledge/action.py
@@ -0,0 +1,480 @@
+"""CausalIQ workflow action for graph generation.
+
+This module provides the workflow action integration for causaliq-knowledge,
+allowing graph generation to be used as a step in CausalIQ workflows.
+
+The action is auto-discovered by causaliq-workflow when this package is
+imported, using the convention of exporting a class named 'CausalIQAction'.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from causaliq_workflow.action import (
+    ActionExecutionError,
+    ActionInput,
+    ActionValidationError,
+)
+from causaliq_workflow.action import CausalIQAction as BaseCausalIQAction
+from causaliq_workflow.logger import WorkflowLogger
+from causaliq_workflow.registry import WorkflowContext
+from pydantic import ValidationError
+
+from causaliq_knowledge.graph.params import GenerateGraphParams
+
+if TYPE_CHECKING:  # pragma: no cover
+    from causaliq_knowledge.graph.response import GeneratedGraph
+
+logger = logging.getLogger(__name__)
+
+# Re-export for convenience (unused imports are intentional for API surface)
+__all__ = [
+    "ActionExecutionError",
+    "ActionInput",
+    "ActionValidationError",
+    "BaseCausalIQAction",
+    "WorkflowContext",
+    "WorkflowLogger",
+    "GenerateGraphAction",
+    "CausalIQAction",
+    "SUPPORTED_ACTIONS",
+]
+
+
+# Supported actions within this package
+SUPPORTED_ACTIONS = {"generate_graph"}
+
+
+def _create_action_inputs() -> Dict[str, Any]:
+    """Create action input specifications.
+
+    Returns:
+        Dictionary of ActionInput specifications.
+    """
+    return {
+        "action": ActionInput(
+            name="action",
+            description="Action to perform (e.g., 'generate_graph')",
+            required=True,
+            type_hint="str",
+        ),
+        "model_spec": ActionInput(
+            name="model_spec",
+            description="Path to model specification JSON file",
+            required=True,
+            type_hint="str",
+        ),
+        "prompt_detail": ActionInput(
+            name="prompt_detail",
+            description="Detail level for prompts: minimal, standard, or rich",
+            required=False,
+            default="standard",
+            type_hint="str",
+        ),
+        "use_benchmark_names": ActionInput(
+            name="use_benchmark_names",
+            description="Use benchmark names instead of LLM names",
+            required=False,
+            default=False,
+            type_hint="bool",
+        ),
+        "llm_model": ActionInput(
+            name="llm_model",
+            description="LLM model identifier (e.g., groq/llama-3.1-8b)",
+            required=False,
+            default="groq/llama-3.1-8b-instant",
+            type_hint="str",
+        ),
+        "output": ActionInput(
+            name="output",
+            description="Output: .json file path or 'none' for stdout",
+            required=True,
+            type_hint="str",
+        ),
+        "llm_cache": ActionInput(
+            name="llm_cache",
+            description="Path to cache database (.db) or 'none' to disable",
+            required=True,
+            type_hint="str",
+        ),
+        "llm_temperature": ActionInput(
+            name="llm_temperature",
+            description="LLM sampling temperature (0.0-2.0)",
+            required=False,
+            default=0.1,
+            type_hint="float",
+        ),
+    }
+
+
+class GenerateGraphAction(BaseCausalIQAction):
+    """Workflow action for generating causal graphs from model specifications.
+
+    This action integrates causaliq-knowledge graph generation into
+    CausalIQ workflows, allowing LLM-based graph generation to be used
+    as workflow steps.
+
+    The action supports the 'generate_graph' operation, which:
+    - Loads a model specification from a JSON file
+    - Queries an LLM to propose causal relationships
+    - Returns the generated graph structure
+
+    Attributes:
+        name: Action identifier for workflow 'uses' field.
+        version: Action version.
+        description: Human-readable description.
+        inputs: Input parameter specifications.
+
+    Example workflow step:
+        ```yaml
+        steps:
+          - name: Generate causal graph
+            uses: causaliq-knowledge
+            with:
+              action: generate_graph
+              model_spec: "{{data_dir}}/cancer.json"
+              llm_cache: "{{data_dir}}/cancer_llm.db"
+              prompt_detail: standard
+              llm_model: groq/llama-3.1-8b-instant
+        ```
+    """
+
+    name: str = "causaliq-knowledge"
+    version: str = "0.4.0"
+    description: str = "Generate causal graphs using LLM knowledge"
+    author: str = "CausalIQ"
+
+    inputs: Dict[str, Any] = _create_action_inputs()
+    outputs: Dict[str, str] = {
+        "graph": "Generated graph structure as JSON",
+        "edge_count": "Number of edges in generated graph",
+        "variable_count": "Number of variables in the model",
+        "model_used": "LLM model used for generation",
+        "cached": "Whether the result was retrieved from cache",
+    }
+
+    def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
+        """Validate input values against specifications.
+
+        Args:
+            inputs: Dictionary of input values to validate.
+
+        Returns:
+            True if all inputs are valid.
+
+        Raises:
+            ActionValidationError: If validation fails.
+        """
+        # Check required 'action' parameter
+        if "action" not in inputs:
+            raise ActionValidationError(
+                "Missing required input: 'action'. "
+                f"Supported actions: {SUPPORTED_ACTIONS}"
+            )
+
+        action = inputs["action"]
+        if action not in SUPPORTED_ACTIONS:
+            raise ActionValidationError(
+                f"Unknown action: '{action}'. "
+                f"Supported actions: {SUPPORTED_ACTIONS}"
+            )
+
+        # For generate_graph, validate using GenerateGraphParams
+        if action == "generate_graph":
+            # Check required model_spec
+            if "model_spec" not in inputs:
+                raise ActionValidationError(
+                    "Missing required input: 'model_spec' for generate_graph"
+                )
+
+            # Build params dict (excluding 'action' which isn't a param)
+            params_dict = {k: v for k, v in inputs.items() if k != "action"}
+
+            try:
+                # Validate using Pydantic model
+                GenerateGraphParams.from_dict(params_dict)
+            except (ValidationError, ValueError) as e:
+                raise ActionValidationError(
+                    f"Invalid parameters for generate_graph: {e}"
+                )
+
+        return True
+
+    def run(
+        self,
+        inputs: Dict[str, Any],
+        mode: str = "dry-run",
+        context: Optional[Any] = None,
+        logger: Optional[Any] = None,
+    ) -> Dict[str, Any]:
+        """Execute the action with validated inputs.
+
+        Args:
+            inputs: Dictionary of input values keyed by input name.
+            mode: Execution mode ('dry-run', 'run', 'compare').
+            context: Workflow context for optimisation.
+            logger: Optional logger for task execution reporting.
+
+        Returns:
+            Dictionary containing:
+            - status: 'success' or 'skipped' (for dry-run)
+            - graph: Generated graph as JSON (if run mode)
+            - edge_count: Number of edges
+            - variable_count: Number of variables
+            - model_used: LLM model identifier
+            - cached: Whether result was from cache
+
+        Raises:
+            ActionExecutionError: If action execution fails.
+        """
+        # Validate inputs first
+        self.validate_inputs(inputs)
+
+        action = inputs["action"]
+
+        if action == "generate_graph":
+            return self._run_generate_graph(inputs, mode, context, logger)
+        else:  # pragma: no cover
+            # This shouldn't happen after validate_inputs
+            raise ActionExecutionError(f"Unknown action: {action}")
+
+    def _run_generate_graph(
+        self,
+        inputs: Dict[str, Any],
+        mode: str,
+        context: Optional[Any],
+        logger: Optional[Any],
+    ) -> Dict[str, Any]:
+        """Execute the generate_graph action.
+
+        Args:
+            inputs: Validated input parameters.
+            mode: Execution mode.
+            context: Workflow context.
+            logger: Optional workflow logger.
+
+        Returns:
+            Action result dictionary.
+        """
+        # Build params (excluding 'action')
+        params_dict = {k: v for k, v in inputs.items() if k != "action"}
+
+        try:
+            params = GenerateGraphParams.from_dict(params_dict)
+        except (ValidationError, ValueError) as e:
+            raise ActionExecutionError(f"Parameter validation failed: {e}")
+
+        # Check model_spec exists
+        if not params.model_spec.exists():
+            raise ActionExecutionError(
+                f"Model specification not found: {params.model_spec}"
+            )
+
+        # Dry-run mode: validate only, don't execute
+        if mode == "dry-run":
+            return self._dry_run_result(params)
+
+        # Run mode: execute graph generation
+        return self._execute_generate_graph(params)
+
+    def _dry_run_result(self, params: GenerateGraphParams) -> Dict[str, Any]:
+        """Return dry-run result without executing.
+
+        Args:
+            params: Validated parameters.
+
+        Returns:
+            Dry-run result dictionary.
+        """
+        return {
+            "status": "skipped",
+            "message": "Dry-run mode: would generate graph",
+            "model_spec": str(params.model_spec),
+            "llm_model": params.llm_model,
+            "prompt_detail": params.prompt_detail.value,
+            "output": params.output,
+        }
+
+    def _execute_generate_graph(
+        self, params: GenerateGraphParams
+    ) -> Dict[str, Any]:
+        """Execute graph generation.
+
+        Args:
+            params: Validated parameters.
+
+        Returns:
+            Result dictionary with generated graph.
+        """
+        # Import here to avoid slow startup and circular imports
+        from causaliq_knowledge.cache import TokenCache
+        from causaliq_knowledge.graph import ModelLoader
+        from causaliq_knowledge.graph.generator import (
+            GraphGenerator,
+            GraphGeneratorConfig,
+        )
+
+        try:
+            # Load model specification
+            spec = ModelLoader.load(params.model_spec)
+            logger.info(
+                f"Loaded model specification: {spec.dataset_id} "
+                f"({len(spec.variables)} variables)"
+            )
+        except Exception as e:
+            raise ActionExecutionError(
+                f"Failed to load model specification: {e}"
+            )
+
+        # Track mapping for name conversion
+        llm_to_benchmark_mapping: Dict[str, str] = {}
+
+        # Determine naming mode
+        use_llm_names = not params.use_benchmark_names
+        if use_llm_names and spec.uses_distinct_llm_names():
+            llm_to_benchmark_mapping = spec.get_llm_to_name_mapping()
+
+        # Set up cache
+        cache: Optional[TokenCache] = None
+        cache_path = params.get_effective_cache_path()
+        if cache_path is not None:
+            try:
+                cache = TokenCache(str(cache_path))
+                cache.open()
+            except Exception as e:
+                raise ActionExecutionError(f"Failed to open cache: {e}")
+
+        try:
+            # Import OutputFormat for generator config
+            from causaliq_knowledge.graph.prompts import OutputFormat
+
+            # Create generator - always use edge_list format
+            # Derive request_id from output filename stem
+            if params.output.lower() == "none":
+                request_id = "none"
+            else:
+                request_id = Path(params.output).stem
+
+            config = GraphGeneratorConfig(
+                temperature=params.llm_temperature,
+                output_format=OutputFormat.EDGE_LIST,
+                prompt_detail=params.prompt_detail,
+                use_llm_names=use_llm_names,
+                request_id=request_id,
+            )
+            generator = GraphGenerator(
+                model=params.llm_model, config=config, cache=cache
+            )
+
+            # Generate graph
+            graph = generator.generate_from_spec(
+                spec, level=params.prompt_detail
+            )
+
+            # Map LLM names back to benchmark names
+            if llm_to_benchmark_mapping:
+                graph = self._map_graph_names(graph, llm_to_benchmark_mapping)
+
+            # Get stats
+            stats = generator.get_stats()
+
+            # Build result
+            result = {
+                "status": "success",
+                "graph": self._graph_to_dict(graph),
+                "edge_count": len(graph.edges),
+                "variable_count": len(graph.variables),
+                "model_used": params.llm_model,
+                "cached": stats.get("cache_hits", 0) > 0,
+                "outputs": {
+                    "graph": self._graph_to_dict(graph),
+                    "edge_count": len(graph.edges),
+                    "variable_count": len(graph.variables),
+                    "model_used": params.llm_model,
+                    "cached": stats.get("cache_hits", 0) > 0,
+                },
+            }
+
+            # Write output file if specified
+            output_path = params.get_effective_output_path()
+            if output_path:
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                output_path.write_text(
+                    json.dumps(result["graph"], indent=2),
+                    encoding="utf-8",
+                )
+                result["output_file"] = str(output_path)
+
+            return result
+
+        except Exception as e:
+            raise ActionExecutionError(f"Graph generation failed: {e}")
+        finally:
+            if cache:
+                cache.close()
+
+    def _graph_to_dict(self, graph: "GeneratedGraph") -> Dict[str, Any]:
+        """Convert GeneratedGraph to dictionary.
+
+        Args:
+            graph: Generated graph object.
+
+        Returns:
+            Dictionary representation of the graph.
+        """
+        return {
+            "edges": [
+                {
+                    "source": edge.source,
+                    "target": edge.target,
+                    "confidence": edge.confidence,
+                }
+                for edge in graph.edges
+            ],
+            "variables": graph.variables,
+            "reasoning": graph.reasoning,
+        }
+
+    def _map_graph_names(
+        self, graph: "GeneratedGraph", mapping: Dict[str, str]
+    ) -> "GeneratedGraph":
+        """Map variable names in a graph using a mapping dictionary.
+
+        Args:
+            graph: The generated graph with edges to map.
+            mapping: Dictionary mapping old names to new names.
+
+        Returns:
+            New GeneratedGraph with mapped variable names.
+        """
+        from causaliq_knowledge.graph.response import (
+            GeneratedGraph,
+            ProposedEdge,
+        )
+
+        new_edges = []
+        for edge in graph.edges:
+            new_edge = ProposedEdge(
+                source=mapping.get(edge.source, edge.source),
+                target=mapping.get(edge.target, edge.target),
+                confidence=edge.confidence,
+            )
+            new_edges.append(new_edge)
+
+        new_variables = [mapping.get(v, v) for v in graph.variables]
+
+        return GeneratedGraph(
+            edges=new_edges,
+            variables=new_variables,
+            reasoning=graph.reasoning,
+            metadata=graph.metadata,
+        )
+
+
+# Export as CausalIQAction for auto-discovery by causaliq-workflow
+# This name is required by the auto-discovery convention
+CausalIQAction = GenerateGraphAction
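
Since `run()` defaults to `mode="dry-run"`, the action can be exercised without an LLM call: inputs are validated via `GenerateGraphParams`, the spec file's existence is checked, and a `{"status": "skipped", ...}` summary is returned. A minimal sketch of driving the action directly (the file paths are hypothetical):

```python
from causaliq_knowledge import CausalIQAction

action = CausalIQAction()
inputs = {
    "action": "generate_graph",             # must be in SUPPORTED_ACTIONS
    "model_spec": "data/cancer.json",       # hypothetical spec path
    "output": "results/cancer_graph.json",  # .json path, or "none" for stdout
    "llm_cache": "data/cancer_llm.db",      # .db path, or "none" to disable
}

# Dry-run: validates and reports what would happen; no LLM query is made.
result = action.run(inputs, mode="dry-run")
print(result["status"])  # "skipped"

# mode="run" would execute _execute_generate_graph and return the graph.
```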
@@ -173,7 +173,8 @@ class JsonEncoder(EntryEncoder):
         """Encode a string value with tokenisation.
 
         Strings are split into tokens (words/punctuation) with special
-        markers for string start/end.
+        markers for string start/end. Double quotes within the string
+        are encoded as a '\\"' token to distinguish them from the delimiters.
 
         Args:
             value: String to encode.
@@ -184,7 +185,11 @@ class JsonEncoder(EntryEncoder):
         # Split on whitespace and punctuation, keeping delimiters
         tokens = self._tokenise_string(value)
         for token in tokens:
-            self._encode_token(token, token_cache, result)
+            # Escape embedded quotes to distinguish from string delimiter
+            if token == '"':
+                self._encode_token('\\"', token_cache, result)
+            else:
+                self._encode_token(token, token_cache, result)
         self._encode_token('"', token_cache, result)
 
     def _encode_list(
@@ -296,6 +301,9 @@ class JsonEncoder(EntryEncoder):
     ) -> tuple[str, int]:
         """Decode a string value (after opening quote consumed).
 
+        Handles escaped quotes (the '\\"' token), which represent literal
+        double quotes within the string content.
+
         Args:
             blob: Binary data to decode.
             offset: Current position (after opening quote).
@@ -317,7 +325,11 @@ class JsonEncoder(EntryEncoder):
             if token == '"':
                 # End of string
                 return "".join(parts), offset
-            parts.append(token)
+            elif token == '\\"':
+                # Escaped quote - append literal quote character
+                parts.append('"')
+            else:
+                parts.append(token)
         raise ValueError("Unterminated string")
 
     def _decode_list(
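
The encoder change is easiest to see as a round trip: an embedded `"` is emitted as a `\"` token so the decoder can tell it apart from the closing delimiter. A self-contained sketch of the scheme (not the real `JsonEncoder` API, which streams token IDs through a `TokenCache`; this just illustrates the escaping logic the hunks add):

```python
# Standalone illustration of the quote-escaping scheme in the hunks above.
def encode_tokens(tokens: list[str]) -> list[str]:
    out = []
    for token in tokens:
        out.append('\\"' if token == '"' else token)  # escape embedded quote
    out.append('"')  # closing string delimiter
    return out

def decode_tokens(encoded: list[str]) -> str:
    parts: list[str] = []
    for token in encoded:
        if token == '"':          # unescaped quote ends the string
            return "".join(parts)
        parts.append('"' if token == '\\"' else token)
    raise ValueError("Unterminated string")

tokens = ['say ', '"', 'hi', '"']  # a string containing literal quotes
assert decode_tokens(encode_tokens(tokens)) == 'say "hi"'
```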
@@ -51,6 +51,8 @@ class TokenCache:
                 data BLOB NOT NULL,
                 created_at TEXT NOT NULL,
                 metadata BLOB,
+                hit_count INTEGER DEFAULT 0,
+                last_accessed_at TEXT,
                 PRIMARY KEY (hash, entry_type)
             );
 
@@ -239,6 +241,28 @@ class TokenCache:
         row = cursor.fetchone()
         return int(row[0]) if row else 0
 
+    def total_hits(self, entry_type: str | None = None) -> int:
+        """Get total cache hits across all entries.
+
+        Args:
+            entry_type: If provided, count only hits for this entry type.
+
+        Returns:
+            Total hit count.
+        """
+        if entry_type is None:
+            cursor = self.conn.execute(
+                "SELECT COALESCE(SUM(hit_count), 0) FROM cache_entries"
+            )
+        else:
+            cursor = self.conn.execute(
+                "SELECT COALESCE(SUM(hit_count), 0) FROM cache_entries "
+                "WHERE entry_type = ?",
+                (entry_type,),
+            )
+        row = cursor.fetchone()
+        return int(row[0]) if row else 0
+
     def get_or_create_token(self, token: str) -> int:
         """Get token ID, creating a new entry if needed.
 
@@ -319,7 +343,7 @@ class TokenCache:
         self.conn.commit()
 
     def get(self, hash: str, entry_type: str) -> bytes | None:
-        """Retrieve a cache entry.
+        """Retrieve a cache entry and increment hit count.
 
         Args:
             hash: Unique identifier for the entry.
@@ -334,7 +358,17 @@ class TokenCache:
             (hash, entry_type),
         )
         row = cursor.fetchone()
-        return row[0] if row else None
+        if row:
+            # Increment hit count and update last accessed time
+            self.conn.execute(
+                "UPDATE cache_entries SET hit_count = hit_count + 1, "
+                "last_accessed_at = ? WHERE hash = ? AND entry_type = ?",
+                (self._utcnow_iso(), hash, entry_type),
+            )
+            self.conn.commit()
+            result: bytes = row[0]
+            return result
+        return None
 
     def get_with_metadata(
         self, hash: str, entry_type: str
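
Taken together, the schema and `get()` changes mean every cache hit is now recorded, and `total_hits()` aggregates the counts. A minimal sketch using only calls visible in this diff (the hash and entry type values are hypothetical):

```python
from causaliq_knowledge.cache import TokenCache

cache = TokenCache("cancer_llm.db")  # constructed as in action.py above
cache.open()
try:
    # Each successful get() bumps hit_count and last_accessed_at.
    data = cache.get("some-hash", "llm_response")  # hypothetical keys
    if data is not None:
        print(f"hit: {len(data)} bytes")
    print("hits for this entry type:", cache.total_hits("llm_response"))
    print("hits across all entries: ", cache.total_hits())
finally:
    cache.close()
```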
--- /dev/null
+++ b/causaliq_knowledge/cli/__init__.py
@@ -0,0 +1,15 @@
+"""Command-line interface for causaliq-knowledge.
+
+This package provides the CLI implementation split into logical modules:
+
+- main: Core CLI entry point and query command
+- cache: Cache management commands (stats, export, import)
+- generate: Graph generation commands
+- models: Model listing command
+"""
+
+from __future__ import annotations
+
+from causaliq_knowledge.cli.main import cli, main
+
+__all__ = ["cli", "main"]
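
The new package re-exports `cli` and `main` at `causaliq_knowledge.cli`, so callers need not import the submodules directly. A hedged sketch of a console-script shim (only the import line is confirmed by the diff; that `main()` returns a process exit status is an assumption):

```python
# Hypothetical shim showing how an entry point might call the re-exported
# main(). The import path is the one established by the __init__.py above.
from causaliq_knowledge.cli import main

if __name__ == "__main__":
    raise SystemExit(main())
```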