causaliq-knowledge 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +5 -2
- causaliq_knowledge/action.py +480 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +15 -3
- causaliq_knowledge/cache/token_cache.py +36 -2
- causaliq_knowledge/cli/__init__.py +15 -0
- causaliq_knowledge/cli/cache.py +478 -0
- causaliq_knowledge/cli/generate.py +410 -0
- causaliq_knowledge/cli/main.py +172 -0
- causaliq_knowledge/cli/models.py +309 -0
- causaliq_knowledge/graph/__init__.py +78 -0
- causaliq_knowledge/graph/generator.py +457 -0
- causaliq_knowledge/graph/loader.py +222 -0
- causaliq_knowledge/graph/models.py +426 -0
- causaliq_knowledge/graph/params.py +175 -0
- causaliq_knowledge/graph/prompts.py +445 -0
- causaliq_knowledge/graph/response.py +392 -0
- causaliq_knowledge/graph/view_filter.py +154 -0
- causaliq_knowledge/llm/base_client.py +6 -0
- causaliq_knowledge/llm/cache.py +124 -61
- causaliq_knowledge/py.typed +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
- causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
- causaliq_knowledge/cli.py +0 -757
- causaliq_knowledge-0.3.0.dist-info/RECORD +0 -28
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
causaliq_knowledge/__init__.py
CHANGED
@@ -2,10 +2,11 @@
 causaliq-knowledge: LLM and human knowledge for causal discovery.
 """
 
+from causaliq_knowledge.action import CausalIQAction
 from causaliq_knowledge.base import KnowledgeProvider
 from causaliq_knowledge.models import EdgeDirection, EdgeKnowledge
 
-__version__ = "0.3.0"
+__version__ = "0.4.0"
 __author__ = "CausalIQ"
 __email__ = "info@causaliq.com"
 
@@ -17,7 +18,7 @@ __url__ = "https://github.com/causaliq/causaliq-knowledge"
 __license__ = "MIT"
 
 # Version tuple for programmatic access (major, minor, patch)
-VERSION = (0, 3, 0)
+VERSION = (0, 4, 0)
 
 __all__ = [
     "__version__",
@@ -29,5 +30,7 @@ __all__ = [
     "EdgeDirection",
     # Abstract interface
     "KnowledgeProvider",
+    # Workflow action (auto-discovered by causaliq-workflow)
+    "CausalIQAction",
     # Note: Import LLMKnowledge from causaliq_knowledge.llm
 ]
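With this release the workflow action is exported at the package top level. A minimal sketch of the new import surface, using only names and values shown in the diff above:

    from causaliq_knowledge import CausalIQAction, __version__, VERSION

    assert __version__ == "0.4.0"
    assert VERSION == (0, 4, 0)
    # CausalIQAction is the auto-discovery alias for GenerateGraphAction,
    # defined in the new causaliq_knowledge/action.py below.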
causaliq_knowledge/action.py
ADDED
@@ -0,0 +1,480 @@
+"""CausalIQ workflow action for graph generation.
+
+This module provides the workflow action integration for causaliq-knowledge,
+allowing graph generation to be used as a step in CausalIQ workflows.
+
+The action is auto-discovered by causaliq-workflow when this package is
+imported, using the convention of exporting a class named 'CausalIQAction'.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from causaliq_workflow.action import (
+    ActionExecutionError,
+    ActionInput,
+    ActionValidationError,
+)
+from causaliq_workflow.action import CausalIQAction as BaseCausalIQAction
+from causaliq_workflow.logger import WorkflowLogger
+from causaliq_workflow.registry import WorkflowContext
+from pydantic import ValidationError
+
+from causaliq_knowledge.graph.params import GenerateGraphParams
+
+if TYPE_CHECKING:  # pragma: no cover
+    from causaliq_knowledge.graph.response import GeneratedGraph
+
+logger = logging.getLogger(__name__)
+
+# Re-export for convenience (unused imports are intentional for API surface)
+__all__ = [
+    "ActionExecutionError",
+    "ActionInput",
+    "ActionValidationError",
+    "BaseCausalIQAction",
+    "WorkflowContext",
+    "WorkflowLogger",
+    "GenerateGraphAction",
+    "CausalIQAction",
+    "SUPPORTED_ACTIONS",
+]
+
+
+# Supported actions within this package
+SUPPORTED_ACTIONS = {"generate_graph"}
+
+
+def _create_action_inputs() -> Dict[str, Any]:
+    """Create action input specifications.
+
+    Returns:
+        Dictionary of ActionInput specifications.
+    """
+    return {
+        "action": ActionInput(
+            name="action",
+            description="Action to perform (e.g., 'generate_graph')",
+            required=True,
+            type_hint="str",
+        ),
+        "model_spec": ActionInput(
+            name="model_spec",
+            description="Path to model specification JSON file",
+            required=True,
+            type_hint="str",
+        ),
+        "prompt_detail": ActionInput(
+            name="prompt_detail",
+            description="Detail level for prompts: minimal, standard, or rich",
+            required=False,
+            default="standard",
+            type_hint="str",
+        ),
+        "use_benchmark_names": ActionInput(
+            name="use_benchmark_names",
+            description="Use benchmark names instead of LLM names",
+            required=False,
+            default=False,
+            type_hint="bool",
+        ),
+        "llm_model": ActionInput(
+            name="llm_model",
+            description="LLM model identifier (e.g., groq/llama-3.1-8b)",
+            required=False,
+            default="groq/llama-3.1-8b-instant",
+            type_hint="str",
+        ),
+        "output": ActionInput(
+            name="output",
+            description="Output: .json file path or 'none' for stdout",
+            required=True,
+            type_hint="str",
+        ),
+        "llm_cache": ActionInput(
+            name="llm_cache",
+            description="Path to cache database (.db) or 'none' to disable",
+            required=True,
+            type_hint="str",
+        ),
+        "llm_temperature": ActionInput(
+            name="llm_temperature",
+            description="LLM sampling temperature (0.0-2.0)",
+            required=False,
+            default=0.1,
+            type_hint="float",
+        ),
+    }
+
+
+class GenerateGraphAction(BaseCausalIQAction):
+    """Workflow action for generating causal graphs from model specifications.
+
+    This action integrates causaliq-knowledge graph generation into
+    CausalIQ workflows, allowing LLM-based graph generation to be used
+    as workflow steps.
+
+    The action supports the 'generate_graph' operation, which:
+    - Loads a model specification from a JSON file
+    - Queries an LLM to propose causal relationships
+    - Returns the generated graph structure
+
+    Attributes:
+        name: Action identifier for workflow 'uses' field.
+        version: Action version.
+        description: Human-readable description.
+        inputs: Input parameter specifications.
+
+    Example workflow step:
+        ```yaml
+        steps:
+          - name: Generate causal graph
+            uses: causaliq-knowledge
+            with:
+              action: generate_graph
+              model_spec: "{{data_dir}}/cancer.json"
+              llm_cache: "{{data_dir}}/cancer_llm.db"
+              prompt_detail: standard
+              llm_model: groq/llama-3.1-8b-instant
+        ```
+    """
+
+    name: str = "causaliq-knowledge"
+    version: str = "0.4.0"
+    description: str = "Generate causal graphs using LLM knowledge"
+    author: str = "CausalIQ"
+
+    inputs: Dict[str, Any] = _create_action_inputs()
+    outputs: Dict[str, str] = {
+        "graph": "Generated graph structure as JSON",
+        "edge_count": "Number of edges in generated graph",
+        "variable_count": "Number of variables in the model",
+        "model_used": "LLM model used for generation",
+        "cached": "Whether the result was retrieved from cache",
+    }
+
+    def validate_inputs(self, inputs: Dict[str, Any]) -> bool:
+        """Validate input values against specifications.
+
+        Args:
+            inputs: Dictionary of input values to validate.
+
+        Returns:
+            True if all inputs are valid.
+
+        Raises:
+            ActionValidationError: If validation fails.
+        """
+        # Check required 'action' parameter
+        if "action" not in inputs:
+            raise ActionValidationError(
+                "Missing required input: 'action'. "
+                f"Supported actions: {SUPPORTED_ACTIONS}"
+            )
+
+        action = inputs["action"]
+        if action not in SUPPORTED_ACTIONS:
+            raise ActionValidationError(
+                f"Unknown action: '{action}'. "
+                f"Supported actions: {SUPPORTED_ACTIONS}"
+            )
+
+        # For generate_graph, validate using GenerateGraphParams
+        if action == "generate_graph":
+            # Check required model_spec
+            if "model_spec" not in inputs:
+                raise ActionValidationError(
+                    "Missing required input: 'model_spec' for generate_graph"
+                )
+
+            # Build params dict (excluding 'action' which isn't a param)
+            params_dict = {k: v for k, v in inputs.items() if k != "action"}
+
+            try:
+                # Validate using Pydantic model
+                GenerateGraphParams.from_dict(params_dict)
+            except (ValidationError, ValueError) as e:
+                raise ActionValidationError(
+                    f"Invalid parameters for generate_graph: {e}"
+                )
+
+        return True
+
+    def run(
+        self,
+        inputs: Dict[str, Any],
+        mode: str = "dry-run",
+        context: Optional[Any] = None,
+        logger: Optional[Any] = None,
+    ) -> Dict[str, Any]:
+        """Execute the action with validated inputs.
+
+        Args:
+            inputs: Dictionary of input values keyed by input name.
+            mode: Execution mode ('dry-run', 'run', 'compare').
+            context: Workflow context for optimisation.
+            logger: Optional logger for task execution reporting.
+
+        Returns:
+            Dictionary containing:
+            - status: 'success' or 'skipped' (for dry-run)
+            - graph: Generated graph as JSON (if run mode)
+            - edge_count: Number of edges
+            - variable_count: Number of variables
+            - model_used: LLM model identifier
+            - cached: Whether result was from cache
+
+        Raises:
+            ActionExecutionError: If action execution fails.
+        """
+        # Validate inputs first
+        self.validate_inputs(inputs)
+
+        action = inputs["action"]
+
+        if action == "generate_graph":
+            return self._run_generate_graph(inputs, mode, context, logger)
+        else:  # pragma: no cover
+            # This shouldn't happen after validate_inputs
+            raise ActionExecutionError(f"Unknown action: {action}")
+
+    def _run_generate_graph(
+        self,
+        inputs: Dict[str, Any],
+        mode: str,
+        context: Optional[Any],
+        logger: Optional[Any],
+    ) -> Dict[str, Any]:
+        """Execute the generate_graph action.
+
+        Args:
+            inputs: Validated input parameters.
+            mode: Execution mode.
+            context: Workflow context.
+            logger: Optional workflow logger.
+
+        Returns:
+            Action result dictionary.
+        """
+        # Build params (excluding 'action')
+        params_dict = {k: v for k, v in inputs.items() if k != "action"}
+
+        try:
+            params = GenerateGraphParams.from_dict(params_dict)
+        except (ValidationError, ValueError) as e:
+            raise ActionExecutionError(f"Parameter validation failed: {e}")
+
+        # Check model_spec exists
+        if not params.model_spec.exists():
+            raise ActionExecutionError(
+                f"Model specification not found: {params.model_spec}"
+            )
+
+        # Dry-run mode: validate only, don't execute
+        if mode == "dry-run":
+            return self._dry_run_result(params)
+
+        # Run mode: execute graph generation
+        return self._execute_generate_graph(params)
+
+    def _dry_run_result(self, params: GenerateGraphParams) -> Dict[str, Any]:
+        """Return dry-run result without executing.
+
+        Args:
+            params: Validated parameters.
+
+        Returns:
+            Dry-run result dictionary.
+        """
+        return {
+            "status": "skipped",
+            "message": "Dry-run mode: would generate graph",
+            "model_spec": str(params.model_spec),
+            "llm_model": params.llm_model,
+            "prompt_detail": params.prompt_detail.value,
+            "output": params.output,
+        }
+
+    def _execute_generate_graph(
+        self, params: GenerateGraphParams
+    ) -> Dict[str, Any]:
+        """Execute graph generation.
+
+        Args:
+            params: Validated parameters.
+
+        Returns:
+            Result dictionary with generated graph.
+        """
+        # Import here to avoid slow startup and circular imports
+        from causaliq_knowledge.cache import TokenCache
+        from causaliq_knowledge.graph import ModelLoader
+        from causaliq_knowledge.graph.generator import (
+            GraphGenerator,
+            GraphGeneratorConfig,
+        )
+
+        try:
+            # Load model specification
+            spec = ModelLoader.load(params.model_spec)
+            logger.info(
+                f"Loaded model specification: {spec.dataset_id} "
+                f"({len(spec.variables)} variables)"
+            )
+        except Exception as e:
+            raise ActionExecutionError(
+                f"Failed to load model specification: {e}"
+            )
+
+        # Track mapping for name conversion
+        llm_to_benchmark_mapping: Dict[str, str] = {}
+
+        # Determine naming mode
+        use_llm_names = not params.use_benchmark_names
+        if use_llm_names and spec.uses_distinct_llm_names():
+            llm_to_benchmark_mapping = spec.get_llm_to_name_mapping()
+
+        # Set up cache
+        cache: Optional[TokenCache] = None
+        cache_path = params.get_effective_cache_path()
+        if cache_path is not None:
+            try:
+                cache = TokenCache(str(cache_path))
+                cache.open()
+            except Exception as e:
+                raise ActionExecutionError(f"Failed to open cache: {e}")
+
+        try:
+            # Import OutputFormat for generator config
+            from causaliq_knowledge.graph.prompts import OutputFormat
+
+            # Create generator - always use edge_list format
+            # Derive request_id from output filename stem
+            if params.output.lower() == "none":
+                request_id = "none"
+            else:
+                request_id = Path(params.output).stem
+
+            config = GraphGeneratorConfig(
+                temperature=params.llm_temperature,
+                output_format=OutputFormat.EDGE_LIST,
+                prompt_detail=params.prompt_detail,
+                use_llm_names=use_llm_names,
+                request_id=request_id,
+            )
+            generator = GraphGenerator(
+                model=params.llm_model, config=config, cache=cache
+            )
+
+            # Generate graph
+            graph = generator.generate_from_spec(
+                spec, level=params.prompt_detail
+            )
+
+            # Map LLM names back to benchmark names
+            if llm_to_benchmark_mapping:
+                graph = self._map_graph_names(graph, llm_to_benchmark_mapping)
+
+            # Get stats
+            stats = generator.get_stats()
+
+            # Build result
+            result = {
+                "status": "success",
+                "graph": self._graph_to_dict(graph),
+                "edge_count": len(graph.edges),
+                "variable_count": len(graph.variables),
+                "model_used": params.llm_model,
+                "cached": stats.get("cache_hits", 0) > 0,
+                "outputs": {
+                    "graph": self._graph_to_dict(graph),
+                    "edge_count": len(graph.edges),
+                    "variable_count": len(graph.variables),
+                    "model_used": params.llm_model,
+                    "cached": stats.get("cache_hits", 0) > 0,
+                },
+            }
+
+            # Write output file if specified
+            output_path = params.get_effective_output_path()
+            if output_path:
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                output_path.write_text(
+                    json.dumps(result["graph"], indent=2),
+                    encoding="utf-8",
+                )
+                result["output_file"] = str(output_path)
+
+            return result
+
+        except Exception as e:
+            raise ActionExecutionError(f"Graph generation failed: {e}")
+        finally:
+            if cache:
+                cache.close()
+
+    def _graph_to_dict(self, graph: "GeneratedGraph") -> Dict[str, Any]:
+        """Convert GeneratedGraph to dictionary.
+
+        Args:
+            graph: Generated graph object.
+
+        Returns:
+            Dictionary representation of the graph.
+        """
+        return {
+            "edges": [
+                {
+                    "source": edge.source,
+                    "target": edge.target,
+                    "confidence": edge.confidence,
+                }
+                for edge in graph.edges
+            ],
+            "variables": graph.variables,
+            "reasoning": graph.reasoning,
+        }
+
+    def _map_graph_names(
+        self, graph: "GeneratedGraph", mapping: Dict[str, str]
+    ) -> "GeneratedGraph":
+        """Map variable names in a graph using a mapping dictionary.
+
+        Args:
+            graph: The generated graph with edges to map.
+            mapping: Dictionary mapping old names to new names.
+
+        Returns:
+            New GeneratedGraph with mapped variable names.
+        """
+        from causaliq_knowledge.graph.response import (
+            GeneratedGraph,
+            ProposedEdge,
+        )
+
+        new_edges = []
+        for edge in graph.edges:
+            new_edge = ProposedEdge(
+                source=mapping.get(edge.source, edge.source),
+                target=mapping.get(edge.target, edge.target),
+                confidence=edge.confidence,
+            )
+            new_edges.append(new_edge)
+
+        new_variables = [mapping.get(v, v) for v in graph.variables]
+
+        return GeneratedGraph(
+            edges=new_edges,
+            variables=new_variables,
+            reasoning=graph.reasoning,
+            metadata=graph.metadata,
+        )
+
+
+# Export as CausalIQAction for auto-discovery by causaliq-workflow
+# This name is required by the auto-discovery convention
+CausalIQAction = GenerateGraphAction
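The new action can also be driven directly from Python rather than via workflow YAML. A minimal sketch, assuming the base action class allows no-argument construction; the file paths are hypothetical, and note that model_spec must point to an existing file even in dry-run mode, since the existence check runs before the dry-run branch:

    from causaliq_knowledge.action import GenerateGraphAction

    action = GenerateGraphAction()  # assumes no-arg construction is allowed
    result = action.run(
        inputs={
            "action": "generate_graph",
            "model_spec": "data/cancer.json",        # hypothetical path
            "output": "results/cancer_graph.json",   # hypothetical path
            "llm_cache": "none",                     # 'none' disables caching
        },
        mode="dry-run",  # validate inputs without calling the LLM
    )
    print(result["status"])  # "skipped" in dry-run mode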
causaliq_knowledge/cache/encoders/json_encoder.py
CHANGED
@@ -173,7 +173,8 @@ class JsonEncoder(EntryEncoder):
         """Encode a string value with tokenisation.
 
         Strings are split into tokens (words/punctuation) with special
-        markers for string start/end.
+        markers for string start/end. Double quotes within the string
+        are encoded as '\\"' token to distinguish from string delimiters.
 
         Args:
             value: String to encode.
@@ -184,7 +185,11 @@ class JsonEncoder(EntryEncoder):
         # Split on whitespace and punctuation, keeping delimiters
         tokens = self._tokenise_string(value)
         for token in tokens:
-            self._encode_token(token, token_cache, result)
+            # Escape embedded quotes to distinguish from string delimiter
+            if token == '"':
+                self._encode_token('\\"', token_cache, result)
+            else:
+                self._encode_token(token, token_cache, result)
         self._encode_token('"', token_cache, result)
 
     def _encode_list(
@@ -296,6 +301,9 @@ class JsonEncoder(EntryEncoder):
     ) -> tuple[str, int]:
         """Decode a string value (after opening quote consumed).
 
+        Handles escaped quotes ('\\"' token) which represent literal
+        double quotes within the string content.
+
         Args:
             blob: Binary data to decode.
             offset: Current position (after opening quote).
@@ -317,7 +325,11 @@ class JsonEncoder(EntryEncoder):
             if token == '"':
                 # End of string
                 return "".join(parts), offset
-            parts.append(token)
+            elif token == '\\"':
+                # Escaped quote - append literal quote character
+                parts.append('"')
+            else:
+                parts.append(token)
         raise ValueError("Unterminated string")
 
     def _decode_list(
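The escaping rule is easiest to see in isolation. A standalone sketch of the round-trip logic introduced above (the real JsonEncoder encodes token IDs through a TokenCache, which is omitted here):

    def encode_string_tokens(tokens: list[str]) -> list[str]:
        # Embedded double quotes become the '\\"' token so they cannot
        # be confused with the closing '"' delimiter token.
        return ['\\"' if t == '"' else t for t in tokens] + ['"']

    def decode_string_tokens(encoded: list[str]) -> str:
        parts: list[str] = []
        for token in encoded:
            if token == '"':        # end-of-string delimiter
                return "".join(parts)
            elif token == '\\"':    # escaped quote -> literal quote
                parts.append('"')
            else:
                parts.append(token)
        raise ValueError("Unterminated string")

    tokens = ['say', ' ', '"', 'hi', '"']
    assert decode_string_tokens(encode_string_tokens(tokens)) == 'say "hi"'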
causaliq_knowledge/cache/token_cache.py
CHANGED
@@ -51,6 +51,8 @@ class TokenCache:
                 data BLOB NOT NULL,
                 created_at TEXT NOT NULL,
                 metadata BLOB,
+                hit_count INTEGER DEFAULT 0,
+                last_accessed_at TEXT,
                 PRIMARY KEY (hash, entry_type)
             );
 
@@ -239,6 +241,28 @@ class TokenCache:
         row = cursor.fetchone()
         return int(row[0]) if row else 0
 
+    def total_hits(self, entry_type: str | None = None) -> int:
+        """Get total cache hits across all entries.
+
+        Args:
+            entry_type: If provided, count only hits for this entry type.
+
+        Returns:
+            Total hit count.
+        """
+        if entry_type is None:
+            cursor = self.conn.execute(
+                "SELECT COALESCE(SUM(hit_count), 0) FROM cache_entries"
+            )
+        else:
+            cursor = self.conn.execute(
+                "SELECT COALESCE(SUM(hit_count), 0) FROM cache_entries "
+                "WHERE entry_type = ?",
+                (entry_type,),
+            )
+        row = cursor.fetchone()
+        return int(row[0]) if row else 0
+
     def get_or_create_token(self, token: str) -> int:
         """Get token ID, creating a new entry if needed.
 
@@ -319,7 +343,7 @@ class TokenCache:
         self.conn.commit()
 
     def get(self, hash: str, entry_type: str) -> bytes | None:
-        """Retrieve a cache entry.
+        """Retrieve a cache entry and increment hit count.
 
         Args:
             hash: Unique identifier for the entry.
@@ -334,7 +358,17 @@ class TokenCache:
             (hash, entry_type),
         )
         row = cursor.fetchone()
-        return row[0] if row else None
+        if row:
+            # Increment hit count and update last accessed time
+            self.conn.execute(
+                "UPDATE cache_entries SET hit_count = hit_count + 1, "
+                "last_accessed_at = ? WHERE hash = ? AND entry_type = ?",
+                (self._utcnow_iso(), hash, entry_type),
+            )
+            self.conn.commit()
+            result: bytes = row[0]
+            return result
+        return None
 
     def get_with_metadata(
         self, hash: str, entry_type: str
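Taken together, get() now records usage and total_hits() reports it. A minimal sketch, assuming a cache database that already contains an entry; the path, hash, and entry_type values are hypothetical, and the write path is not shown in this diff:

    from causaliq_knowledge.cache import TokenCache

    cache = TokenCache("llm_cache.db")  # hypothetical path
    cache.open()
    try:
        # A successful get() increments hit_count and stamps last_accessed_at.
        data = cache.get("abc123", "llm_response")  # hypothetical key/type
        print("all hits:", cache.total_hits())
        print("LLM hits:", cache.total_hits(entry_type="llm_response"))
    finally:
        cache.close()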
causaliq_knowledge/cli/__init__.py
ADDED
@@ -0,0 +1,15 @@
+"""Command-line interface for causaliq-knowledge.
+
+This package provides the CLI implementation split into logical modules:
+
+- main: Core CLI entry point and query command
+- cache: Cache management commands (stats, export, import)
+- generate: Graph generation commands
+- models: Model listing command
+"""
+
+from __future__ import annotations
+
+from causaliq_knowledge.cli.main import cli, main
+
+__all__ = ["cli", "main"]