causaliq-knowledge 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +5 -2
- causaliq_knowledge/action.py +480 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +15 -3
- causaliq_knowledge/cache/token_cache.py +36 -2
- causaliq_knowledge/cli/__init__.py +15 -0
- causaliq_knowledge/cli/cache.py +478 -0
- causaliq_knowledge/cli/generate.py +410 -0
- causaliq_knowledge/cli/main.py +172 -0
- causaliq_knowledge/cli/models.py +309 -0
- causaliq_knowledge/graph/__init__.py +78 -0
- causaliq_knowledge/graph/generator.py +457 -0
- causaliq_knowledge/graph/loader.py +222 -0
- causaliq_knowledge/graph/models.py +426 -0
- causaliq_knowledge/graph/params.py +175 -0
- causaliq_knowledge/graph/prompts.py +445 -0
- causaliq_knowledge/graph/response.py +392 -0
- causaliq_knowledge/graph/view_filter.py +154 -0
- causaliq_knowledge/llm/base_client.py +6 -0
- causaliq_knowledge/llm/cache.py +124 -61
- causaliq_knowledge/py.typed +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
- causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
- causaliq_knowledge/cli.py +0 -757
- causaliq_knowledge-0.3.0.dist-info/RECORD +0 -28
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
"""Graph generation CLI commands.
|
|
2
|
+
|
|
3
|
+
This module provides commands for generating causal graphs from
|
|
4
|
+
model specifications using LLMs.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import sys
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import TYPE_CHECKING, Optional
|
|
13
|
+
|
|
14
|
+
import click
|
|
15
|
+
from pydantic import ValidationError
|
|
16
|
+
|
|
17
|
+
from causaliq_knowledge.graph.params import GenerateGraphParams
|
|
18
|
+
from causaliq_knowledge.graph.view_filter import PromptDetail
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
21
|
+
from causaliq_knowledge.graph.models import ModelSpec
|
|
22
|
+
from causaliq_knowledge.graph.response import GeneratedGraph
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _map_graph_names(
    graph: "GeneratedGraph", mapping: dict[str, str]
) -> "GeneratedGraph":
    """Map variable names in a graph using a mapping dictionary.

    Names that do not appear in ``mapping`` are left unchanged. Each edge's
    confidence and reasoning are carried over unchanged, as are the
    graph-level reasoning and metadata.

    Args:
        graph: The generated graph with edges to map.
        mapping: Dictionary mapping old names to new names.

    Returns:
        New GeneratedGraph with mapped variable names.
    """
    from causaliq_knowledge.graph.response import GeneratedGraph, ProposedEdge

    new_edges = [
        ProposedEdge(
            source=mapping.get(edge.source, edge.source),
            target=mapping.get(edge.target, edge.target),
            confidence=edge.confidence,
            # Per-edge reasoning is written to the JSON output and echoed
            # when edges are printed, so it must survive the renaming.
            reasoning=edge.reasoning,
        )
        for edge in graph.edges
    ]

    # Map variable names too
    new_variables = [mapping.get(v, v) for v in graph.variables]

    return GeneratedGraph(
        edges=new_edges,
        variables=new_variables,
        reasoning=graph.reasoning,
        metadata=graph.metadata,
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@click.command("generate_graph")
@click.option(
    "--model-spec",
    "-s",
    required=True,
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    help="Path to model specification JSON file.",
)
@click.option(
    "--prompt-detail",
    "-p",
    "prompt_detail",
    default="standard",
    type=click.Choice(["minimal", "standard", "rich"], case_sensitive=False),
    help="Detail level for variable information in prompts.",
)
@click.option(
    "--use-benchmark-names/--use-llm-names",
    "use_benchmark_names",
    default=False,
    help="Use benchmark names instead of LLM names (test memorisation).",
)
@click.option(
    "--llm-model",
    "-m",
    "llm_model",
    default="groq/llama-3.1-8b-instant",
    help="LLM model to use (e.g., groq/llama-3.1-8b-instant).",
)
@click.option(
    "--output",
    "-o",
    required=True,
    help="Output: .json file path or 'none' for adjacency matrix to stdout.",
)
@click.option(
    "--llm-cache",
    "-c",
    "llm_cache",
    required=True,
    help="Path to cache database (.db) or 'none' to disable caching.",
)
@click.option(
    "--llm-temperature",
    "-t",
    type=float,
    default=0.1,
    help="LLM temperature (0.0-1.0). Lower = more deterministic.",
)
def generate_graph(
    model_spec: Path,
    prompt_detail: str,
    use_benchmark_names: bool,
    llm_model: str,
    output: str,
    llm_cache: str,
    llm_temperature: float,
) -> None:
    """Generate a causal graph from a model specification.

    Reads variable definitions from a JSON model specification file and
    uses an LLM to propose causal relationships between variables.

    By default, LLM names are used in prompts to prevent memorisation.
    Use --use-benchmark-names to test with original benchmark names.

    Output behaviour:
    - If output is a .json file: writes JSON to file, prints edges to stdout
    - If output is 'none': prints edges and adjacency matrix to stdout

    Examples:

        cqknow generate_graph -s model.json -c cache.db -o graph.json

        cqknow generate_graph -s model.json -c cache.db -o none

        cqknow generate_graph -s model.json -c none -o none --use-benchmark-names
    """
    # Import here to avoid slow startup for --help
    from causaliq_knowledge.cache import TokenCache
    from causaliq_knowledge.graph import ModelLoader
    from causaliq_knowledge.graph.generator import (
        GraphGenerator,
        GraphGeneratorConfig,
    )
    from causaliq_knowledge.graph.prompts import OutputFormat

    # Validate all parameters using shared model
    try:
        params = GenerateGraphParams(
            model_spec=model_spec,
            prompt_detail=PromptDetail(prompt_detail.lower()),
            use_benchmark_names=use_benchmark_names,
            llm_model=llm_model,
            output=output,
            llm_cache=llm_cache,
            llm_temperature=llm_temperature,
        )
    except ValidationError as e:
        # Report each Pydantic error against the CLI option spelling. A
        # model-level error may carry an empty ``loc``, so fall back to
        # "unknown" instead of indexing into it blindly.
        for error in e.errors():
            loc = error.get("loc") or ("unknown",)
            field = str(loc[0]).replace("_", "-")
            msg = error.get("msg", "validation error")
            click.echo(f"Error: --{field}: {msg}", err=True)
        sys.exit(1)

    # Get effective paths from validated params
    output_path = params.get_effective_output_path()

    # Load model specification
    try:
        spec = ModelLoader.load(params.model_spec)
        click.echo(
            f"Loaded model specification: {spec.dataset_id} "
            f"({len(spec.variables)} variables)",
            err=True,
        )
    except Exception as e:
        click.echo(f"Error loading model specification: {e}", err=True)
        sys.exit(1)

    # Track mapping for converting LLM output back to benchmark names
    llm_to_benchmark_mapping: dict[str, str] = {}

    # Determine naming mode
    use_llm_names = not params.use_benchmark_names
    if use_llm_names and spec.uses_distinct_llm_names():
        llm_to_benchmark_mapping = spec.get_llm_to_name_mapping()
        click.echo("Using LLM names (prevents memorisation)", err=True)
    elif params.use_benchmark_names:
        click.echo("Using benchmark names (memorisation test)", err=True)

    # Set up cache
    cache: Optional[TokenCache] = None
    cache_path = params.get_effective_cache_path()
    if cache_path is not None:
        try:
            cache = TokenCache(str(cache_path))
            cache.open()
            click.echo(f"Using cache: {cache_path}", err=True)
        except Exception as e:
            click.echo(f"Error opening cache: {e}", err=True)
            sys.exit(1)
    else:
        click.echo("Cache disabled", err=True)

    # Everything after the cache is opened runs inside try/finally so the
    # cache is closed on every exit path, including the sys.exit(1) error
    # paths below (SystemExit still runs ``finally``).
    try:
        # Create generator - use edge_list format for structured output
        try:
            # Derive request_id from output filename stem
            if params.output.lower() == "none":
                request_id = "none"
            else:
                request_id = Path(params.output).stem

            config = GraphGeneratorConfig(
                temperature=params.llm_temperature,
                output_format=OutputFormat.EDGE_LIST,
                prompt_detail=params.prompt_detail,
                use_llm_names=use_llm_names,
                request_id=request_id,
            )
            generator = GraphGenerator(
                model=params.llm_model, config=config, cache=cache
            )
        except ValueError as e:
            click.echo(f"Error creating generator: {e}", err=True)
            sys.exit(1)

        # Generate graph
        click.echo(f"Generating graph using {params.llm_model}...", err=True)
        click.echo(f"View level: {params.prompt_detail.value}", err=True)

        try:
            graph = generator.generate_from_spec(
                spec, level=params.prompt_detail
            )
        except Exception as e:
            click.echo(f"Error generating graph: {e}", err=True)
            sys.exit(1)

        # Map LLM names back to benchmark names
        if llm_to_benchmark_mapping:
            graph = _map_graph_names(graph, llm_to_benchmark_mapping)
            click.echo("Mapped LLM names back to benchmark names", err=True)

        # Build JSON output
        result = _build_output(
            graph, spec, params.llm_model, params.prompt_detail
        )

        # Output results - always print edges summary to stdout
        _print_edges(graph)
        _print_summary(graph, err=False)

        if output_path:
            # Write JSON to file
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_path.write_text(
                json.dumps(result, indent=2), encoding="utf-8"
            )
            click.echo(f"\nOutput written to: {output_path}", err=True)
        else:
            # Print adjacency matrix to stdout
            click.echo()
            _print_adjacency_matrix(graph, spec)

        # Show stats
        stats = generator.get_stats()
        if stats.get("client_call_count", 0) > 0:
            click.echo(
                f"\nLLM calls: {stats['client_call_count']}, "
                f"Generator calls: {stats['call_count']}",
                err=True,
            )
    finally:
        # Compare against None explicitly: a cache object could define
        # __len__/__bool__ and be falsy while still open.
        if cache is not None:
            cache.close()
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _build_output(
    graph: GeneratedGraph,
    spec: ModelSpec,
    llm_model: str,
    level: PromptDetail,
) -> dict:
    """Assemble the JSON-serialisable result for a generated graph.

    Args:
        graph: The GeneratedGraph result.
        spec: The ModelSpec used.
        llm_model: LLM model identifier.
        level: View level used.

    Returns:
        Dictionary holding the dataset id, domain, variable and edge counts,
        the edge list (source, target, confidence and, when non-empty,
        reasoning), the generation settings and, when the graph carries
        metadata, a ``metadata`` section with model/provider/token details.
    """

    def _edge_entry(edge) -> dict:
        # Reasoning is only included when the edge actually has some.
        entry = {
            "source": edge.source,
            "target": edge.target,
            "confidence": edge.confidence,
        }
        if edge.reasoning:
            entry["reasoning"] = edge.reasoning
        return entry

    edge_entries = [_edge_entry(edge) for edge in graph.edges]

    output: dict = {
        "dataset_id": spec.dataset_id,
        "domain": spec.domain,
        "variable_count": len(spec.variables),
        "edge_count": len(edge_entries),
        "edges": edge_entries,
        "generation": {"model": llm_model, "prompt_detail": level.value},
    }

    meta = graph.metadata
    if meta:
        metadata_fields = (
            "model",
            "provider",
            "input_tokens",
            "output_tokens",
            "from_cache",
        )
        output["metadata"] = {
            name: getattr(meta, name) for name in metadata_fields
        }

    return output
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _print_edges(graph: GeneratedGraph) -> None:
    """Echo every proposed edge, highest confidence first, with a text bar.

    Each edge line shows its rank, ``source → target``, a ten-segment bar
    and the confidence as a percentage. When present, reasoning is echoed
    beneath it, cut to 100 characters with a trailing ``...``.

    Args:
        graph: The GeneratedGraph result.
    """
    edges = graph.edges
    if not edges:
        click.echo("\nNo edges proposed by the LLM.")
        return

    click.echo(f"\nProposed Edges ({len(edges)}):\n")

    ranked = sorted(edges, key=lambda edge: edge.confidence, reverse=True)
    for rank, edge in enumerate(ranked, start=1):
        filled = int(edge.confidence * 10)
        bar = "█" * filled + "░" * (10 - filled)
        percent = edge.confidence * 100
        click.echo(
            f" {rank:2d}. {edge.source} → {edge.target} "
            f"[{bar}] {percent:5.1f}%"
        )
        if not edge.reasoning:
            continue
        snippet = edge.reasoning[:100]
        if len(edge.reasoning) > 100:
            snippet += "..."
        click.echo(f" {snippet}")
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _print_summary(graph: GeneratedGraph, err: bool = False) -> None:
    """Echo how many edges fall into each confidence band.

    Bands: high is ``confidence >= 0.7``, medium is
    ``0.4 <= confidence < 0.7`` and low is ``confidence < 0.4``.

    Args:
        graph: The GeneratedGraph result.
        err: If True, write to stderr instead of stdout.
    """
    high = medium = low = 0
    for edge in graph.edges:
        confidence = edge.confidence
        if confidence >= 0.7:
            high += 1
        elif 0.4 <= confidence:
            medium += 1
        elif confidence < 0.4:
            low += 1

    summary_lines = (
        f"\nEdge Confidence Summary ({len(graph.edges)} edges):",
        f" High confidence (>=0.7): {high}",
        f" Medium confidence (0.4-0.7): {medium}",
        f" Low confidence (<0.4): {low}",
    )
    for line in summary_lines:
        click.echo(line, err=err)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _print_adjacency_matrix(graph: GeneratedGraph, spec: ModelSpec) -> None:
    """Print adjacency matrix representation of the graph.

    Rows are edge sources and columns are edge targets, both in the order
    of ``spec.variables``. Every cell is four characters wide. It holds the
    edge confidence to one decimal place, or ``.`` when no such edge was
    proposed. Column headers show the first three characters of each
    variable name. Edges whose endpoints are not variable names in
    ``spec`` do not appear.

    Args:
        graph: The GeneratedGraph result.
        spec: The ModelSpec used for variable names.
    """
    cell_width = 4

    # Get variable names in order
    var_names = [v.name for v in spec.variables]

    # Build edge lookup (source, target) -> confidence
    edge_lookup = {(e.source, e.target): e.confidence for e in graph.edges}

    click.echo("Adjacency Matrix:")
    click.echo()

    if not var_names:
        # Nothing to tabulate; max() below would raise on an empty list.
        click.echo("(no variables)")
        return

    # The header and the data rows share one label-column width, and every
    # cell (confidence or ".") has the same width, so columns line up.
    label_width = max(len(name) for name in var_names) + 2
    header = " " * label_width + "".join(
        f"{name[:3]:>{cell_width}}" for name in var_names
    )
    click.echo(header)

    # Data rows
    for row_name in var_names:
        cells = []
        for col_name in var_names:
            if (row_name, col_name) in edge_lookup:
                conf = edge_lookup[(row_name, col_name)]
                cells.append(f"{conf:{cell_width}.1f}")
            else:
                cells.append(f"{'.':>{cell_width}}")
        click.echo(f"{row_name:<{label_width}}" + "".join(cells))
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Main CLI entry point and core commands.
|
|
2
|
+
|
|
3
|
+
This module provides the main CLI group and the query command for
|
|
4
|
+
querying LLMs about causal relationships between variables.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import sys
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
import click
|
|
14
|
+
|
|
15
|
+
from causaliq_knowledge import __version__
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@click.group()
@click.version_option(version=__version__)
def cli() -> None:
    """CausalIQ Knowledge - LLM knowledge for causal discovery.

    Query LLMs about causal relationships between variables.
    """
    # No body needed: the group only dispatches to its registered
    # sub-commands, and click uses the docstring above as the --help text.
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@cli.command("query")
@click.argument("node_a")
@click.argument("node_b")
@click.option(
    "--model",
    "-m",
    multiple=True,
    default=["groq/llama-3.1-8b-instant"],
    help="LLM model(s) to query. Can be specified multiple times.",
)
@click.option(
    "--domain",
    "-d",
    default=None,
    help="Domain context (e.g., 'medicine', 'economics').",
)
@click.option(
    "--strategy",
    "-s",
    type=click.Choice(["weighted_vote", "highest_confidence"]),
    default="weighted_vote",
    help="Consensus strategy for multi-model queries.",
)
@click.option(
    "--json",
    "output_json",
    is_flag=True,
    help="Output result as JSON.",
)
@click.option(
    "--llm-temperature",
    "-t",
    type=float,
    default=0.1,
    help="LLM temperature (0.0-1.0).",
)
def query_edge(
    node_a: str,
    node_b: str,
    model: tuple[str, ...],
    domain: Optional[str],
    strategy: str,
    output_json: bool,
    llm_temperature: float,
) -> None:
    """Query LLMs about a causal relationship between two variables.

    NODE_A and NODE_B are the variable names to query about.

    Examples:

        cqknow query smoking lung_cancer

        cqknow query smoking lung_cancer --domain medicine

        cqknow query X Y --model groq/llama-3.1-8b-instant \
            --model gemini/gemini-2.5-flash
    """
    # Deferred import keeps `--help` fast.
    from causaliq_knowledge.llm import LLMKnowledge

    # A domain, when given, is the only context passed to the provider.
    context = {"domain": domain} if domain else None

    try:
        provider = LLMKnowledge(
            models=list(model),
            consensus_strategy=strategy,
            temperature=llm_temperature,
        )
    except Exception as exc:
        click.echo(f"Error creating provider: {exc}", err=True)
        sys.exit(1)

    click.echo(
        f"Querying {len(model)} model(s) about: {node_a} -> {node_b}",
        err=True,
    )

    try:
        result = provider.query_edge(node_a, node_b, context=context)
    except Exception as exc:
        click.echo(f"Error querying LLM: {exc}", err=True)
        sys.exit(1)

    if output_json:
        payload = {
            "node_a": node_a,
            "node_b": node_b,
            "exists": result.exists,
            "direction": result.direction.value if result.direction else None,
            "confidence": result.confidence,
            "reasoning": result.reasoning,
            "model": result.model,
        }
        click.echo(json.dumps(payload, indent=2))
    else:
        # Human-readable report framed by 60-character rules.
        exists_label = {True: "Yes", False: "No", None: "Uncertain"}[
            result.exists
        ]
        if result.direction:
            direction_label = result.direction.value
        else:
            direction_label = "N/A"
        rule = "=" * 60

        click.echo(f"\n{rule}")
        click.echo(f"Query: Does '{node_a}' cause '{node_b}'?")
        click.echo(rule)
        click.echo(f"Exists: {exists_label}")
        click.echo(f"Direction: {direction_label}")
        click.echo(f"Confidence: {result.confidence:.2f}")
        click.echo(f"Model(s): {result.model or 'unknown'}")
        click.echo(rule)
        click.echo(f"Reasoning: {result.reasoning}")
        click.echo()

    # Report spend (on stderr) only when the queries actually cost money.
    stats = provider.get_stats()
    total_cost = stats["total_cost"]
    if total_cost > 0:
        click.echo(
            f"Cost: ${total_cost:.6f} ({stats['total_calls']} call(s))",
            err=True,
        )
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# Import and register command groups. These imports come after ``cli`` is
# defined above, which is why they carry E402 suppressions.
from causaliq_knowledge.cli.cache import cache_group  # noqa: E402
from causaliq_knowledge.cli.generate import generate_graph  # noqa: E402
from causaliq_knowledge.cli.models import list_models  # noqa: E402

# Attach each sub-command to the top-level ``cli`` group.
for _subcommand in (cache_group, generate_graph, list_models):
    cli.add_command(_subcommand)
del _subcommand
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def main() -> None:
    """Run the CLI by invoking the top-level ``cli`` click group.

    This is the function the module's ``__main__`` guard calls.
    """
    cli()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":  # pragma: no cover
    main()
|