causaliq-knowledge 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,457 @@
+ """Graph generator using LLM providers.
+
+ This module provides the GraphGenerator class for generating complete
+ causal graphs from variable specifications using LLM providers.
+ """
+
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ import logging
+ import time
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+ from causaliq_knowledge.graph.prompts import GraphQueryPrompt, OutputFormat
+ from causaliq_knowledge.graph.response import (
+     GeneratedGraph,
+     GenerationMetadata,
+     parse_graph_response,
+ )
+ from causaliq_knowledge.graph.view_filter import PromptDetail
+ from causaliq_knowledge.llm.anthropic_client import (
+     AnthropicClient,
+     AnthropicConfig,
+ )
+ from causaliq_knowledge.llm.deepseek_client import (
+     DeepSeekClient,
+     DeepSeekConfig,
+ )
+ from causaliq_knowledge.llm.gemini_client import GeminiClient, GeminiConfig
+ from causaliq_knowledge.llm.groq_client import GroqClient, GroqConfig
+ from causaliq_knowledge.llm.mistral_client import MistralClient, MistralConfig
+ from causaliq_knowledge.llm.ollama_client import OllamaClient, OllamaConfig
+ from causaliq_knowledge.llm.openai_client import OpenAIClient, OpenAIConfig
+
+ if TYPE_CHECKING:  # pragma: no cover
+     from causaliq_knowledge.cache import TokenCache
+     from causaliq_knowledge.graph.models import ModelSpec
+
+ logger = logging.getLogger(__name__)
+
+
+ # Type alias for supported clients
+ LLMClientType = Union[
+     AnthropicClient,
+     DeepSeekClient,
+     GeminiClient,
+     GroqClient,
+     MistralClient,
+     OllamaClient,
+     OpenAIClient,
+ ]
+
+
+ @dataclass
+ class GraphGeneratorConfig:
+     """Configuration for the GraphGenerator.
+
+     Attributes:
+         temperature: LLM sampling temperature (lower = more deterministic).
+         max_tokens: Maximum tokens in LLM response.
+         timeout: Request timeout in seconds.
+         output_format: Desired output format (edge_list or adjacency_matrix).
+         prompt_detail: Detail level for variable information in prompts.
+         use_llm_names: Use llm_name instead of benchmark name in prompts.
+         request_id: Optional identifier for requests (stored in metadata).
+     """
+
+     temperature: float = 0.1
+     max_tokens: int = 2000
+     timeout: float = 60.0
+     output_format: OutputFormat = OutputFormat.EDGE_LIST
+     prompt_detail: PromptDetail = PromptDetail.STANDARD
+     use_llm_names: bool = True
+     request_id: str = ""
+
+
+ class GraphGenerator:
+     """Generate causal graphs from variable specifications using LLMs.
+
+     This class provides methods for generating complete causal graphs
+     from ModelSpec objects or variable dictionaries. It supports all
+     LLM providers available in causaliq-knowledge and integrates with
+     the TokenCache for efficient caching of requests.
+
+     Attributes:
+         model: The LLM model identifier (e.g., "groq/llama-3.1-8b-instant").
+         config: Configuration for generation parameters.
+
+     Example:
+         >>> from causaliq_knowledge.graph import ModelLoader
+         >>> from causaliq_knowledge.graph.generator import GraphGenerator
+         >>>
+         >>> # Load model specification
+         >>> spec = ModelLoader.load("model.json")
+         >>>
+         >>> # Create generator
+         >>> generator = GraphGenerator(model="groq/llama-3.1-8b-instant")
+         >>>
+         >>> # Generate graph
+         >>> graph = generator.generate_from_spec(spec)
+         >>> print(f"Generated {len(graph.edges)} edges")
+     """
+
+     def __init__(
+         self,
+         model: str = "groq/llama-3.1-8b-instant",
+         config: Optional[GraphGeneratorConfig] = None,
+         cache: Optional["TokenCache"] = None,
+     ) -> None:
+         """Initialise the GraphGenerator.
+
+         Args:
+             model: LLM model identifier with provider prefix. Supported:
+                 - "groq/llama-3.1-8b-instant" (Groq API)
+                 - "gemini/gemini-2.5-flash" (Google Gemini)
+                 - "openai/gpt-4o" (OpenAI)
+                 - "anthropic/claude-3-5-sonnet-20241022" (Anthropic)
+                 - "deepseek/deepseek-chat" (DeepSeek)
+                 - "mistral/mistral-small-latest" (Mistral)
+                 - "ollama/llama3.2:1b" (Local Ollama)
+             config: Generation configuration. Uses defaults if None.
+             cache: TokenCache instance for caching. Disabled if None.
+
+         Raises:
+             ValueError: If the model prefix is not supported.
+         """
+         self._model = model
+         self._config = config or GraphGeneratorConfig()
+         self._cache = cache
+         self._client = self._create_client(model)
+         self._call_count = 0
+
+         # Configure cache on client if provided
+         if cache is not None:
+             self._client.set_cache(cache, use_cache=True)
+
+     def _create_client(self, model: str) -> LLMClientType:
+         """Create the appropriate LLM client for the model.
+
+         Args:
+             model: Model identifier with provider prefix.
+
+         Returns:
+             Configured LLM client instance.
+
+         Raises:
+             ValueError: If the model prefix is not supported.
+         """
+         config = self._config
+
+         if model.startswith("anthropic/"):
+             model_name = model.split("/", 1)[1]
+             return AnthropicClient(
+                 config=AnthropicConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         elif model.startswith("deepseek/"):
+             model_name = model.split("/", 1)[1]
+             return DeepSeekClient(
+                 config=DeepSeekConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         elif model.startswith("gemini/"):
+             model_name = model.split("/", 1)[1]
+             return GeminiClient(
+                 config=GeminiConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         elif model.startswith("groq/"):
+             model_name = model.split("/", 1)[1]
+             return GroqClient(
+                 config=GroqConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         elif model.startswith("mistral/"):
+             model_name = model.split("/", 1)[1]
+             return MistralClient(
+                 config=MistralConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         elif model.startswith("ollama/"):
+             model_name = model.split("/", 1)[1]
+             return OllamaClient(
+                 config=OllamaConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         elif model.startswith("openai/"):
+             model_name = model.split("/", 1)[1]
+             return OpenAIClient(
+                 config=OpenAIConfig(
+                     model=model_name,
+                     temperature=config.temperature,
+                     max_tokens=config.max_tokens,
+                     timeout=config.timeout,
+                 )
+             )
+         else:
+             supported = [
+                 "anthropic/",
+                 "deepseek/",
+                 "gemini/",
+                 "groq/",
+                 "mistral/",
+                 "ollama/",
+                 "openai/",
+             ]
+             raise ValueError(
+                 f"Model '{model}' not supported. "
+                 f"Supported prefixes: {supported}."
+             )
+
+     @property
+     def model(self) -> str:
+         """Return the model identifier."""
+         return self._model
+
+     @property
+     def config(self) -> GraphGeneratorConfig:
+         """Return the generator configuration."""
+         return self._config
+
+     @property
+     def call_count(self) -> int:
+         """Return the number of generation calls made."""
+         return self._call_count
+
+     def set_cache(
+         self,
+         cache: Optional["TokenCache"],
+         use_cache: bool = True,
+     ) -> None:
+         """Configure caching for this generator.
+
+         Args:
+             cache: TokenCache instance for caching, or None to disable.
+             use_cache: Whether to use the cache (default True).
+         """
+         self._cache = cache
+         self._client.set_cache(cache, use_cache)
+
+     def _build_cache_key(
+         self,
+         prompt: GraphQueryPrompt,
+         system_prompt: str,
+         user_prompt: str,
+     ) -> str:
+         """Build a deterministic cache key for graph generation.
+
+         Uses a prefix to distinguish graph queries from edge queries.
+
+         Args:
+             prompt: The GraphQueryPrompt used.
+             system_prompt: The system prompt string.
+             user_prompt: The user prompt string.
+
+         Returns:
+             Cache key: "graph_" prefix plus a 12-character hex digest.
+         """
+         key_data = {
+             "type": "graph_generation",
+             "model": self._model,
+             "output_format": self._config.output_format.value,
+             "prompt_detail": prompt.level.value,
+             "system_prompt": system_prompt,
+             "user_prompt": user_prompt,
+             "temperature": self._config.temperature,
+         }
+         key_json = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
+         return "graph_" + hashlib.sha256(key_json.encode()).hexdigest()[:12]
+
+     def generate_graph(
+         self,
+         variables: List[Dict[str, Any]],
+         level: Optional[PromptDetail] = None,
+         domain: Optional[str] = None,
+         output_format: Optional[OutputFormat] = None,
+         system_prompt: Optional[str] = None,
+     ) -> GeneratedGraph:
+         """Generate a causal graph from variable dictionaries.
+
+         Args:
+             variables: List of variable dictionaries with at least "name".
+             level: View level for context. Uses config default if None.
+             domain: Optional domain context for the query.
+             output_format: Output format. Uses config default if None.
+             system_prompt: Custom system prompt (optional).
+
+         Returns:
+             GeneratedGraph with proposed edges and metadata.
+
+         Raises:
+             ValueError: If LLM response cannot be parsed.
+         """
+         level = level or self._config.prompt_detail
+         output_format = output_format or self._config.output_format
+
+         # Build the prompt
+         prompt = GraphQueryPrompt(
+             variables=variables,
+             level=level,
+             domain=domain,
+             output_format=output_format,
+             system_prompt=system_prompt,
+         )
+
+         return self._execute_query(prompt)
+
+     def generate_from_spec(
+         self,
+         spec: "ModelSpec",
+         level: Optional[PromptDetail] = None,
+         output_format: Optional[OutputFormat] = None,
+         system_prompt: Optional[str] = None,
+         use_llm_names: Optional[bool] = None,
+     ) -> GeneratedGraph:
+         """Generate a causal graph from a ModelSpec.
+
+         Convenience method that extracts variables and domain from the
+         specification automatically.
+
+         Args:
+             spec: The model specification.
+             level: View level for context. Uses config default if None.
+             output_format: Output format. Uses config default if None.
+             system_prompt: Custom system prompt (optional).
+             use_llm_names: Use llm_name instead of benchmark name.
+                 Uses config default if None.
+
+         Returns:
+             GeneratedGraph with proposed edges and metadata.
+
+         Raises:
+             ValueError: If LLM response cannot be parsed.
+         """
+         level = level or self._config.prompt_detail
+         output_format = output_format or self._config.output_format
+         use_llm = (
+             use_llm_names
+             if use_llm_names is not None
+             else self._config.use_llm_names
+         )
+
+         # Use the class method to create prompt from spec
+         prompt = GraphQueryPrompt.from_model_spec(
+             spec=spec,
+             level=level,
+             output_format=output_format,
+             system_prompt=system_prompt,
+             use_llm_names=use_llm,
+         )
+
+         return self._execute_query(prompt)
+
+     def _execute_query(self, prompt: GraphQueryPrompt) -> GeneratedGraph:
+         """Execute the LLM query and parse the response.
+
+         Args:
+             prompt: The configured GraphQueryPrompt.
+
+         Returns:
+             GeneratedGraph with parsed edges and metadata.
+
+         Raises:
+             ValueError: If LLM response cannot be parsed.
+         """
+         system_prompt, user_prompt = prompt.build()
+         variable_names = prompt.get_variable_names()
+
+         # Build messages for the LLM
+         messages: List[Dict[str, str]] = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": user_prompt})
+
+         # Make the request (using cached_completion if cache is set)
+         start_time = time.perf_counter()
+         from_cache = False
+
+         if self._cache is not None and self._client.use_cache:
+             response = self._client.cached_completion(
+                 messages, request_id=self._config.request_id
+             )
+             # Infer whether the response came from the cache via timing
+             latency_ms = int((time.perf_counter() - start_time) * 1000)
+             # Heuristic: a round trip under 50 ms is assumed to be a cache hit
+             from_cache = latency_ms < 50
+         else:
+             response = self._client.completion(messages)
+
+         latency_ms = int((time.perf_counter() - start_time) * 1000)
+         self._call_count += 1
+
+         # Parse the response
+         output_format_str = prompt.output_format.value
+         graph = parse_graph_response(
+             response.content,
+             variable_names,
+             output_format_str,
+         )
+
+         # Add metadata
+         provider = self._model.split("/")[0] if "/" in self._model else ""
+         model_name = (
+             self._model.split("/", 1)[1] if "/" in self._model else self._model
+         )
+
+         graph.metadata = GenerationMetadata(
+             model=model_name,
+             provider=provider,
+             timestamp=datetime.now(timezone.utc),
+             latency_ms=latency_ms,
+             input_tokens=response.input_tokens,
+             output_tokens=response.output_tokens,
+             cost_usd=response.cost,
+             from_cache=from_cache,
+         )
+
+         return graph
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get statistics about generation calls.
+
+         Returns:
+             Dict with call_count, model, and client stats.
+         """
+         return {
+             "model": self._model,
+             "call_count": self._call_count,
+             "client_call_count": self._client.call_count,
+         }
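
Taken together, the hunk above adds a complete graph-generation entry point in 0.4.0. The sketch below is a minimal usage example assembled only from names visible in this file (GraphGenerator, GraphGeneratorConfig, OutputFormat.EDGE_LIST, PromptDetail.STANDARD, GeneratedGraph.edges and the GenerationMetadata fields); the variable names, the "epidemiology" domain and the assumption that a Groq API key is already configured in the environment are illustrative and not part of the package.

    from causaliq_knowledge.graph.generator import (
        GraphGenerator,
        GraphGeneratorConfig,
    )
    from causaliq_knowledge.graph.prompts import OutputFormat
    from causaliq_knowledge.graph.view_filter import PromptDetail

    # Variable dictionaries need at least a "name" key (see generate_graph).
    variables = [{"name": "smoking"}, {"name": "lung_cancer"}]

    config = GraphGeneratorConfig(
        temperature=0.0,
        output_format=OutputFormat.EDGE_LIST,
        prompt_detail=PromptDetail.STANDARD,
    )
    generator = GraphGenerator(model="groq/llama-3.1-8b-instant", config=config)

    # domain is optional context passed through to the prompt builder.
    graph = generator.generate_graph(variables, domain="epidemiology")
    for edge in graph.edges:
        print(edge)
    print(graph.metadata.latency_ms, graph.metadata.from_cache)

Passing a TokenCache via the cache argument (or set_cache) routes requests through cached_completion, and metadata.from_cache then reflects the sub-50 ms timing heuristic shown in _execute_query.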
@@ -0,0 +1,222 @@
+ """Model specification loader with validation.
+
+ This module provides functionality to load and validate model
+ specification JSON files.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Union
+
+ from causaliq_knowledge.graph.models import ModelSpec
+
+
+ class ModelLoadError(Exception):
+     """Error raised when model loading fails.
+
+     Attributes:
+         message: Error description.
+         path: Path to the file that failed to load.
+         details: Additional error details.
+     """
+
+     def __init__(
+         self,
+         message: str,
+         path: Path | str | None = None,
+         details: str | None = None,
+     ) -> None:
+         self.message = message
+         self.path = path
+         self.details = details
+         full_message = message
+         if path:
+             full_message = f"{message}: {path}"
+         if details:
+             full_message = f"{full_message}\n Details: {details}"
+         super().__init__(full_message)
+
+
+ class ModelLoader:
+     """Loader for model specification JSON files.
+
+     This class provides methods to load, validate, and access
+     model specifications from JSON files.
+
+     Example:
+         >>> loader = ModelLoader()
+         >>> spec = loader.load("path/to/model.json")
+         >>> print(spec.dataset_id)
+         cancer
+
+         >>> # Or load from dict
+         >>> spec = loader.from_dict(
+         ...     {"dataset_id": "test", "domain": "test", ...}
+         ... )
+     """
+
+     @staticmethod
+     def load(path: Union[str, Path]) -> ModelSpec:
+         """Load a model specification from a JSON file.
+
+         Args:
+             path: Path to the JSON file.
+
+         Returns:
+             Validated ModelSpec instance.
+
+         Raises:
+             ModelLoadError: If the file cannot be loaded or validated.
+         """
+         path = Path(path)
+
+         # Check file exists
+         if not path.exists():
+             raise ModelLoadError("Model specification file not found", path)
+
+         # Check file extension
+         if path.suffix.lower() != ".json":
+             raise ModelLoadError(
+                 "Model specification must be a JSON file",
+                 path,
+                 f"Got extension: {path.suffix}",
+             )
+
+         # Load JSON
+         try:
+             with open(path, "r", encoding="utf-8") as f:
+                 data = json.load(f)
+         except json.JSONDecodeError as e:
+             raise ModelLoadError(
+                 "Invalid JSON in model specification",
+                 path,
+                 str(e),
+             ) from e
+         except OSError as e:
+             raise ModelLoadError(
+                 "Failed to read model specification file",
+                 path,
+                 str(e),
+             ) from e
+
+         # Validate and create ModelSpec
+         return ModelLoader.from_dict(data, source_path=path)
+
+     @staticmethod
+     def from_dict(
+         data: dict,
+         source_path: Path | str | None = None,
+     ) -> ModelSpec:
+         """Create a ModelSpec from a dictionary.
+
+         Args:
+             data: Dictionary containing model specification.
+             source_path: Optional source path for error messages.
+
+         Returns:
+             Validated ModelSpec instance.
+
+         Raises:
+             ModelLoadError: If validation fails.
+         """
+         # Check required fields
+         required_fields = ["dataset_id", "domain"]
+         missing = [f for f in required_fields if f not in data]
+         if missing:
+             raise ModelLoadError(
+                 f"Missing required fields: {', '.join(missing)}",
+                 source_path,
+             )
+
+         # Validate with Pydantic
+         try:
+             return ModelSpec.model_validate(data)
+         except Exception as e:
+             raise ModelLoadError(
+                 "Model specification validation failed",
+                 source_path,
+                 str(e),
+             ) from e
+
+     @staticmethod
+     def validate_variables(spec: ModelSpec) -> list[str]:
+         """Validate variable specifications and return warnings.
+
+         Performs additional validation beyond Pydantic schema:
+         - Checks for duplicate variable names
+         - Checks that states are defined for discrete variables
+         - Checks for empty variable list
+
+         Args:
+             spec: ModelSpec to validate.
+
+         Returns:
+             List of warning messages (empty if no issues).
+
+         Raises:
+             ModelLoadError: If critical validation errors are found.
+         """
+         warnings: list[str] = []
+
+         # Check for empty variables
+         if not spec.variables:
+             raise ModelLoadError(
+                 "Model specification has no variables defined"
+             )
+
+         # Check for duplicate names
+         names = [v.name for v in spec.variables]
+         duplicates = [n for n in names if names.count(n) > 1]
+         if duplicates:
+             raise ModelLoadError(
+                 f"Duplicate variable names found: {', '.join(set(duplicates))}"
+             )
+
+         # Check states for discrete variables
+         for var in spec.variables:
+             if (
+                 var.type
+                 in (
+                     "binary",
+                     "categorical",
+                     "ordinal",
+                 )
+                 and not var.states
+             ):
+                 warnings.append(
+                     f"Variable '{var.name}' is {var.type} "
+                     "but has no states defined"
+                 )
+
+         # Check binary variables have exactly 2 states
+         for var in spec.variables:
+             if var.type == "binary" and var.states and len(var.states) != 2:
+                 warnings.append(
+                     f"Variable '{var.name}' is binary "
+                     f"but has {len(var.states)} states"
+                 )
+
+         return warnings
+
+     @staticmethod
+     def load_and_validate(
+         path: Union[str, Path],
+     ) -> tuple[ModelSpec, list[str]]:
+         """Load and fully validate a model specification.
+
+         Combines loading with additional validation checks.
+
+         Args:
+             path: Path to the JSON file.
+
+         Returns:
+             Tuple of (ModelSpec, list of warnings).
+
+         Raises:
+             ModelLoadError: If loading or validation fails.
+         """
+         spec = ModelLoader.load(path)
+         warnings = ModelLoader.validate_variables(spec)
+         return spec, warnings
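
The loader hunk pairs with the generator above: load_and_validate returns the spec plus soft warnings, while hard failures surface as ModelLoadError with the message, path and details attributes defined earlier. A minimal sketch of that flow follows; the import path for ModelLoadError and the "model.json" filename are assumptions for illustration (only ModelLoader's import from causaliq_knowledge.graph is confirmed by the generator docstring).

    from causaliq_knowledge.graph import ModelLoader

    # ModelLoadError's public import path is not shown in this diff; importing
    # it alongside ModelLoader is an assumption made for this sketch.
    from causaliq_knowledge.graph import ModelLoadError

    try:
        spec, warnings = ModelLoader.load_and_validate("model.json")
    except ModelLoadError as err:
        # err carries the message, path and details set in __init__ above.
        print(f"failed to load model specification: {err}")
    else:
        for message in warnings:
            print(f"warning: {message}")
        print(f"loaded '{spec.dataset_id}' with {len(spec.variables)} variables")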