causaliq-knowledge 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. causaliq_knowledge/__init__.py +6 -3
  2. causaliq_knowledge/action.py +480 -0
  3. causaliq_knowledge/cache/__init__.py +18 -0
  4. causaliq_knowledge/cache/encoders/__init__.py +13 -0
  5. causaliq_knowledge/cache/encoders/base.py +90 -0
  6. causaliq_knowledge/cache/encoders/json_encoder.py +430 -0
  7. causaliq_knowledge/cache/token_cache.py +666 -0
  8. causaliq_knowledge/cli/__init__.py +15 -0
  9. causaliq_knowledge/cli/cache.py +478 -0
  10. causaliq_knowledge/cli/generate.py +410 -0
  11. causaliq_knowledge/cli/main.py +172 -0
  12. causaliq_knowledge/cli/models.py +309 -0
  13. causaliq_knowledge/graph/__init__.py +78 -0
  14. causaliq_knowledge/graph/generator.py +457 -0
  15. causaliq_knowledge/graph/loader.py +222 -0
  16. causaliq_knowledge/graph/models.py +426 -0
  17. causaliq_knowledge/graph/params.py +175 -0
  18. causaliq_knowledge/graph/prompts.py +445 -0
  19. causaliq_knowledge/graph/response.py +392 -0
  20. causaliq_knowledge/graph/view_filter.py +154 -0
  21. causaliq_knowledge/llm/base_client.py +147 -1
  22. causaliq_knowledge/llm/cache.py +443 -0
  23. causaliq_knowledge/py.typed +0 -0
  24. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
  25. causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
  26. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +1 -1
  27. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
  28. causaliq_knowledge/cli.py +0 -414
  29. causaliq_knowledge-0.2.0.dist-info/RECORD +0 -22
  30. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
  31. {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,410 @@
1
+ """Graph generation CLI commands.
2
+
3
+ This module provides commands for generating causal graphs from
4
+ model specifications using LLMs.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import sys
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Optional
13
+
14
+ import click
15
+ from pydantic import ValidationError
16
+
17
+ from causaliq_knowledge.graph.params import GenerateGraphParams
18
+ from causaliq_knowledge.graph.view_filter import PromptDetail
19
+
20
+ if TYPE_CHECKING: # pragma: no cover
21
+ from causaliq_knowledge.graph.models import ModelSpec
22
+ from causaliq_knowledge.graph.response import GeneratedGraph
23
+
24
+
25
def _map_graph_names(
    graph: "GeneratedGraph", mapping: dict[str, str]
) -> "GeneratedGraph":
    """Map variable names in a graph using a mapping dictionary.

    Args:
        graph: The generated graph with edges to map.
        mapping: Dictionary mapping old names to new names.

    Returns:
        New GeneratedGraph with mapped variable names.  Names absent
        from the mapping are passed through unchanged.
    """
    # Local import: response module is only needed when mapping occurs.
    from causaliq_knowledge.graph.response import GeneratedGraph, ProposedEdge

    # Rebuild each edge with mapped endpoints.  Fix: carry the per-edge
    # reasoning across — previously it was dropped here, so any graph
    # generated with LLM names lost edge reasoning in the JSON output
    # and the edge listing.
    new_edges = [
        ProposedEdge(
            source=mapping.get(edge.source, edge.source),
            target=mapping.get(edge.target, edge.target),
            confidence=edge.confidence,
            reasoning=edge.reasoning,
        )
        for edge in graph.edges
    ]

    # Map variable names too
    new_variables = [mapping.get(v, v) for v in graph.variables]

    return GeneratedGraph(
        edges=new_edges,
        variables=new_variables,
        reasoning=graph.reasoning,
        metadata=graph.metadata,
    )
57
+
58
+
59
@click.command("generate_graph")
@click.option(
    "--model-spec",
    "-s",
    required=True,
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    help="Path to model specification JSON file.",
)
@click.option(
    "--prompt-detail",
    "-p",
    "prompt_detail",
    default="standard",
    type=click.Choice(["minimal", "standard", "rich"], case_sensitive=False),
    help="Detail level for variable information in prompts.",
)
@click.option(
    "--use-benchmark-names/--use-llm-names",
    "use_benchmark_names",
    default=False,
    help="Use benchmark names instead of LLM names (test memorisation).",
)
@click.option(
    "--llm-model",
    "-m",
    "llm_model",
    default="groq/llama-3.1-8b-instant",
    help="LLM model to use (e.g., groq/llama-3.1-8b-instant).",
)
@click.option(
    "--output",
    "-o",
    required=True,
    help="Output: .json file path or 'none' for adjacency matrix to stdout.",
)
@click.option(
    "--llm-cache",
    "-c",
    "llm_cache",
    required=True,
    help="Path to cache database (.db) or 'none' to disable caching.",
)
@click.option(
    "--llm-temperature",
    "-t",
    type=float,
    default=0.1,
    help="LLM temperature (0.0-1.0). Lower = more deterministic.",
)
def generate_graph(
    model_spec: Path,
    prompt_detail: str,
    use_benchmark_names: bool,
    llm_model: str,
    output: str,
    llm_cache: str,
    llm_temperature: float,
) -> None:
    """Generate a causal graph from a model specification.

    Reads variable definitions from a JSON model specification file and
    uses an LLM to propose causal relationships between variables.

    By default, LLM names are used in prompts to prevent memorisation.
    Use --use-benchmark-names to test with original benchmark names.

    Output behaviour:

    - If output is a .json file: writes JSON to file, prints edges to stdout
    - If output is 'none': prints adjacency matrix to stdout, edges to stderr

    Examples:

        cqknow generate_graph -s model.json -c cache.db -o graph.json

        cqknow generate_graph -s model.json -c cache.db -o none

        cqknow generate_graph -s model.json -c none -o none --use-benchmark-names
    """
    # Import here to avoid slow startup for --help
    from causaliq_knowledge.cache import TokenCache
    from causaliq_knowledge.graph import ModelLoader
    from causaliq_knowledge.graph.generator import (
        GraphGenerator,
        GraphGeneratorConfig,
    )
    from causaliq_knowledge.graph.prompts import OutputFormat

    # Validate all parameters using shared model
    try:
        params = GenerateGraphParams(
            model_spec=model_spec,
            prompt_detail=PromptDetail(prompt_detail.lower()),
            use_benchmark_names=use_benchmark_names,
            llm_model=llm_model,
            output=output,
            llm_cache=llm_cache,
            llm_temperature=llm_temperature,
        )
    except ValidationError as e:
        # Format Pydantic errors for CLI.  Only the first element of each
        # error location is reported; for these flat parameters that is
        # the field name itself.
        for error in e.errors():
            field = error.get("loc", ["unknown"])[0]
            msg = error.get("msg", "validation error")
            click.echo(f"Error: --{field}: {msg}", err=True)
        sys.exit(1)

    # Get effective paths from validated params.
    # None here means "print adjacency matrix to stdout" further below.
    output_path = params.get_effective_output_path()

    # Load model specification
    try:
        spec = ModelLoader.load(params.model_spec)
        click.echo(
            f"Loaded model specification: {spec.dataset_id} "
            f"({len(spec.variables)} variables)",
            err=True,
        )
    except Exception as e:
        click.echo(f"Error loading model specification: {e}", err=True)
        sys.exit(1)

    # Track mapping for converting LLM output back to benchmark names.
    # Stays empty when benchmark names are used directly, or when the
    # spec has no distinct LLM names.
    llm_to_benchmark_mapping: dict[str, str] = {}

    # Determine naming mode
    use_llm_names = not params.use_benchmark_names
    if use_llm_names and spec.uses_distinct_llm_names():
        llm_to_benchmark_mapping = spec.get_llm_to_name_mapping()
        click.echo("Using LLM names (prevents memorisation)", err=True)
    elif params.use_benchmark_names:
        click.echo("Using benchmark names (memorisation test)", err=True)

    # Set up cache
    cache: Optional[TokenCache] = None
    cache_path = params.get_effective_cache_path()
    if cache_path is not None:
        try:
            cache = TokenCache(str(cache_path))
            cache.open()
            click.echo(f"Using cache: {cache_path}", err=True)
        except Exception as e:
            click.echo(f"Error opening cache: {e}", err=True)
            sys.exit(1)
    else:
        click.echo("Cache disabled", err=True)

    # Create generator - use edge_list format for structured output
    # NOTE(review): an opened cache is not closed on the sys.exit error
    # paths below — presumably TokenCache releases its resources at
    # process exit; confirm.
    try:
        # Derive request_id from output filename stem
        if params.output.lower() == "none":
            request_id = "none"
        else:
            request_id = Path(params.output).stem

        config = GraphGeneratorConfig(
            temperature=params.llm_temperature,
            output_format=OutputFormat.EDGE_LIST,
            prompt_detail=params.prompt_detail,
            use_llm_names=use_llm_names,
            request_id=request_id,
        )
        generator = GraphGenerator(
            model=params.llm_model, config=config, cache=cache
        )
    except ValueError as e:
        click.echo(f"Error creating generator: {e}", err=True)
        sys.exit(1)

    # Generate graph
    click.echo(f"Generating graph using {params.llm_model}...", err=True)
    click.echo(f"View level: {params.prompt_detail.value}", err=True)

    try:
        graph = generator.generate_from_spec(spec, level=params.prompt_detail)
    except Exception as e:
        click.echo(f"Error generating graph: {e}", err=True)
        sys.exit(1)

    # Map LLM names back to benchmark names
    if llm_to_benchmark_mapping:
        graph = _map_graph_names(graph, llm_to_benchmark_mapping)
        click.echo("Mapped LLM names back to benchmark names", err=True)

    # Build JSON output
    result = _build_output(graph, spec, params.llm_model, params.prompt_detail)

    # Output results - always print edges summary to stdout
    _print_edges(graph)
    _print_summary(graph, err=False)

    if output_path:
        # Write JSON to file
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
        click.echo(f"\nOutput written to: {output_path}", err=True)
    else:
        # Print adjacency matrix to stdout
        click.echo()
        _print_adjacency_matrix(graph, spec)

    # Show stats (only when at least one real LLM call was made, i.e.
    # not everything was served from cache)
    stats = generator.get_stats()
    if stats.get("client_call_count", 0) > 0:
        click.echo(
            f"\nLLM calls: {stats['client_call_count']}, "
            f"Generator calls: {stats['call_count']}",
            err=True,
        )

    # Close cache if opened
    if cache:
        cache.close()
271
+
272
+
273
+ def _build_output(
274
+ graph: GeneratedGraph,
275
+ spec: ModelSpec,
276
+ llm_model: str,
277
+ level: PromptDetail,
278
+ ) -> dict:
279
+ """Build output dictionary for the generated graph.
280
+
281
+ Args:
282
+ graph: The GeneratedGraph result.
283
+ spec: The ModelSpec used.
284
+ llm_model: LLM model identifier.
285
+ level: View level used.
286
+
287
+ Returns:
288
+ Dictionary suitable for JSON output.
289
+ """
290
+ edges = []
291
+ for edge in graph.edges:
292
+ edge_dict = {
293
+ "source": edge.source,
294
+ "target": edge.target,
295
+ "confidence": edge.confidence,
296
+ }
297
+ if edge.reasoning:
298
+ edge_dict["reasoning"] = edge.reasoning
299
+ edges.append(edge_dict)
300
+
301
+ result = {
302
+ "dataset_id": spec.dataset_id,
303
+ "domain": spec.domain,
304
+ "variable_count": len(spec.variables),
305
+ "edge_count": len(edges),
306
+ "edges": edges,
307
+ "generation": {
308
+ "model": llm_model,
309
+ "prompt_detail": level.value,
310
+ },
311
+ }
312
+
313
+ # Add metadata if available
314
+ if graph.metadata:
315
+ result["metadata"] = {
316
+ "model": graph.metadata.model,
317
+ "provider": graph.metadata.provider,
318
+ "input_tokens": graph.metadata.input_tokens,
319
+ "output_tokens": graph.metadata.output_tokens,
320
+ "from_cache": graph.metadata.from_cache,
321
+ }
322
+
323
+ return result
324
+
325
+
326
def _print_edges(graph: GeneratedGraph) -> None:
    """Print proposed edges, highest confidence first.

    Each edge is shown with a ten-segment confidence bar and its
    confidence as a percentage; truncated reasoning (at most 100
    characters) is printed beneath when present.

    Args:
        graph: The GeneratedGraph result.
    """
    if not graph.edges:
        click.echo("\nNo edges proposed by the LLM.")
        return

    click.echo(f"\nProposed Edges ({len(graph.edges)}):\n")

    # Rank edges so the most confident ones appear at the top.
    ranked = sorted(graph.edges, key=lambda e: e.confidence, reverse=True)

    for i, edge in enumerate(ranked, 1):
        # Ten-character bar: one filled segment per 10% of confidence.
        filled = int(edge.confidence * 10)
        conf_bar = "█" * filled + "░" * (10 - filled)
        conf_pct = edge.confidence * 100
        click.echo(
            f"  {i:2d}. {edge.source} → {edge.target} "
            f"[{conf_bar}] {conf_pct:5.1f}%"
        )
        if edge.reasoning:
            # Show at most the first 100 characters of the reasoning.
            reasoning = edge.reasoning[:100]
            if len(edge.reasoning) > 100:
                reasoning += "..."
            click.echo(f"      {reasoning}")
358
+
359
+
360
def _print_summary(graph: GeneratedGraph, err: bool = False) -> None:
    """Print a brief summary of the generated graph.

    Edges are bucketed by confidence: high (>= 0.7), medium
    (0.4 inclusive to 0.7 exclusive) and low (< 0.4).

    Args:
        graph: The GeneratedGraph result.
        err: Whether to print to stderr.
    """
    edge_count = len(graph.edges)

    # Single pass: each edge falls into exactly one band.
    high_conf = med_conf = low_conf = 0
    for e in graph.edges:
        if e.confidence >= 0.7:
            high_conf += 1
        elif e.confidence >= 0.4:
            med_conf += 1
        else:
            low_conf += 1

    click.echo(f"\nEdge Confidence Summary ({edge_count} edges):", err=err)
    click.echo(f"  High confidence (>=0.7):     {high_conf}", err=err)
    click.echo(f"  Medium confidence (0.4-0.7): {med_conf}", err=err)
    click.echo(f"  Low confidence (<0.4):       {low_conf}", err=err)
376
+
377
+
378
def _print_adjacency_matrix(graph: GeneratedGraph, spec: ModelSpec) -> None:
    """Print adjacency matrix representation of the graph.

    Row-variable -> column-variable cells show the edge confidence to
    one decimal place; absent edges are shown as ".".  Column headers
    are truncated to the first three characters of each name.

    Args:
        graph: The GeneratedGraph result.
        spec: The ModelSpec used for variable names.
    """
    # Get variable names in order
    var_names = [v.name for v in spec.variables]

    click.echo("Adjacency Matrix:")
    click.echo()

    # Robustness: a spec with no variables would previously crash the
    # max() call below with ValueError.
    if not var_names:
        click.echo("(no variables)")
        return

    # Build edge lookup (source, target) -> confidence
    edge_lookup = {(e.source, e.target): e.confidence for e in graph.edges}

    # Header row: blank stub over the row-name column, then one
    # 4-character cell per variable.
    max_name_len = max(len(name) for name in var_names)
    header = " " * (max_name_len + 2)
    for name in var_names:
        header += f"{name[:3]:>4}"
    click.echo(header)

    # Data rows.  The enumerate() indices were unused, so iterate the
    # names directly; use .get() to avoid a double dict lookup.
    for row_name in var_names:
        row = f"{row_name:<{max_name_len}}  "
        for col_name in var_names:
            conf = edge_lookup.get((row_name, col_name))
            if conf is None:
                row += "   ."
            else:
                row += f"{conf:4.1f}"
        click.echo(row)
@@ -0,0 +1,172 @@
1
+ """Main CLI entry point and core commands.
2
+
3
+ This module provides the main CLI group and the query command for
4
+ querying LLMs about causal relationships between variables.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import sys
11
+ from typing import Optional
12
+
13
+ import click
14
+
15
+ from causaliq_knowledge import __version__
16
+
17
+
18
@click.group()
@click.version_option(version=__version__)
def cli() -> None:
    """CausalIQ Knowledge - LLM knowledge for causal discovery.

    Query LLMs about causal relationships between variables.
    """
    # Group body intentionally empty: subcommands are attached to the
    # group elsewhere via cli.add_command().  The docstring above is the
    # body, so no `pass` is needed.
26
+
27
+
28
@cli.command("query")
@click.argument("node_a")
@click.argument("node_b")
@click.option(
    "--model",
    "-m",
    multiple=True,
    default=["groq/llama-3.1-8b-instant"],
    help="LLM model(s) to query. Can be specified multiple times.",
)
@click.option(
    "--domain",
    "-d",
    default=None,
    help="Domain context (e.g., 'medicine', 'economics').",
)
@click.option(
    "--strategy",
    "-s",
    type=click.Choice(["weighted_vote", "highest_confidence"]),
    default="weighted_vote",
    help="Consensus strategy for multi-model queries.",
)
@click.option(
    "--json",
    "output_json",
    is_flag=True,
    help="Output result as JSON.",
)
@click.option(
    "--llm-temperature",
    "-t",
    type=float,
    default=0.1,
    help="LLM temperature (0.0-1.0).",
)
def query_edge(
    node_a: str,
    node_b: str,
    model: tuple[str, ...],
    domain: Optional[str],
    strategy: str,
    output_json: bool,
    llm_temperature: float,
) -> None:
    """Query LLMs about a causal relationship between two variables.

    NODE_A and NODE_B are the variable names to query about.

    Examples:

        cqknow query smoking lung_cancer

        cqknow query smoking lung_cancer --domain medicine

        cqknow query X Y --model groq/llama-3.1-8b-instant \
            --model gemini/gemini-2.5-flash
    """
    # Import here to avoid slow startup for --help
    from causaliq_knowledge.llm import LLMKnowledge

    # Build context: only a domain hint is passed through at present.
    context = None
    if domain:
        context = {"domain": domain}

    # Create provider
    try:
        provider = LLMKnowledge(
            models=list(model),
            consensus_strategy=strategy,
            temperature=llm_temperature,
        )
    except Exception as e:
        click.echo(f"Error creating provider: {e}", err=True)
        sys.exit(1)

    # Query (progress goes to stderr so stdout stays machine-readable)
    click.echo(
        f"Querying {len(model)} model(s) about: {node_a} -> {node_b}",
        err=True,
    )

    try:
        result = provider.query_edge(node_a, node_b, context=context)
    except Exception as e:
        click.echo(f"Error querying LLM: {e}", err=True)
        sys.exit(1)

    # Output
    if output_json:
        output = {
            "node_a": node_a,
            "node_b": node_b,
            "exists": result.exists,
            "direction": result.direction.value if result.direction else None,
            "confidence": result.confidence,
            "reasoning": result.reasoning,
            "model": result.model,
        }
        click.echo(json.dumps(output, indent=2))
    else:
        # Human-readable output
        # NOTE(review): assumes result.exists is always True/False/None;
        # any other value would raise KeyError here — confirm the
        # provider's contract.
        exists_map = {True: "Yes", False: "No", None: "Uncertain"}
        exists_str = exists_map[result.exists]
        direction_str = result.direction.value if result.direction else "N/A"

        click.echo(f"\n{'='*60}")
        click.echo(f"Query: Does '{node_a}' cause '{node_b}'?")
        click.echo("=" * 60)
        click.echo(f"Exists:      {exists_str}")
        click.echo(f"Direction:   {direction_str}")
        click.echo(f"Confidence:  {result.confidence:.2f}")
        click.echo(f"Model(s):    {result.model or 'unknown'}")
        click.echo(f"{'='*60}")
        click.echo(f"Reasoning: {result.reasoning}")
        click.echo()

    # Show stats (skipped entirely for free/zero-cost calls)
    stats = provider.get_stats()
    if stats["total_cost"] > 0:
        click.echo(
            f"Cost: ${stats['total_cost']:.6f} "
            f"({stats['total_calls']} call(s))",
            err=True,
        )
154
+
155
+
156
# Import and register command groups.
# NOTE(review): these imports sit below the command definitions
# (hence the noqa: E402 markers) — presumably so the subcommand modules
# can themselves import from this package without a circular import;
# confirm before moving them to the top of the file.
from causaliq_knowledge.cli.cache import cache_group  # noqa: E402
from causaliq_knowledge.cli.generate import generate_graph  # noqa: E402
from causaliq_knowledge.cli.models import list_models  # noqa: E402

# Attach the subcommands to the top-level `cli` group at import time.
cli.add_command(cache_group)
cli.add_command(generate_graph)
cli.add_command(list_models)
164
+
165
+
166
def main() -> None:
    """Entry point for the CLI.

    Thin wrapper around the click group — presumably the target of the
    package's console-script entry point (confirm against
    entry_points.txt).  Click handles argument parsing and exit codes.
    """
    cli()


if __name__ == "__main__":  # pragma: no cover
    main()