causaliq-knowledge 0.3.0-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +5 -2
- causaliq_knowledge/action.py +480 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +15 -3
- causaliq_knowledge/cache/token_cache.py +36 -2
- causaliq_knowledge/cli/__init__.py +15 -0
- causaliq_knowledge/cli/cache.py +478 -0
- causaliq_knowledge/cli/generate.py +410 -0
- causaliq_knowledge/cli/main.py +172 -0
- causaliq_knowledge/cli/models.py +309 -0
- causaliq_knowledge/graph/__init__.py +78 -0
- causaliq_knowledge/graph/generator.py +457 -0
- causaliq_knowledge/graph/loader.py +222 -0
- causaliq_knowledge/graph/models.py +426 -0
- causaliq_knowledge/graph/params.py +175 -0
- causaliq_knowledge/graph/prompts.py +445 -0
- causaliq_knowledge/graph/response.py +392 -0
- causaliq_knowledge/graph/view_filter.py +154 -0
- causaliq_knowledge/llm/base_client.py +6 -0
- causaliq_knowledge/llm/cache.py +124 -61
- causaliq_knowledge/py.typed +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
- causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
- causaliq_knowledge/cli.py +0 -757
- causaliq_knowledge-0.3.0.dist-info/RECORD +0 -28
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.3.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
causaliq_knowledge/graph/generator.py (new file)
@@ -0,0 +1,457 @@
+"""Graph generator using LLM providers.
+
+This module provides the GraphGenerator class for generating complete
+causal graphs from variable specifications using LLM providers.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import logging
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from causaliq_knowledge.graph.prompts import GraphQueryPrompt, OutputFormat
+from causaliq_knowledge.graph.response import (
+    GeneratedGraph,
+    GenerationMetadata,
+    parse_graph_response,
+)
+from causaliq_knowledge.graph.view_filter import PromptDetail
+from causaliq_knowledge.llm.anthropic_client import (
+    AnthropicClient,
+    AnthropicConfig,
+)
+from causaliq_knowledge.llm.deepseek_client import (
+    DeepSeekClient,
+    DeepSeekConfig,
+)
+from causaliq_knowledge.llm.gemini_client import GeminiClient, GeminiConfig
+from causaliq_knowledge.llm.groq_client import GroqClient, GroqConfig
+from causaliq_knowledge.llm.mistral_client import MistralClient, MistralConfig
+from causaliq_knowledge.llm.ollama_client import OllamaClient, OllamaConfig
+from causaliq_knowledge.llm.openai_client import OpenAIClient, OpenAIConfig
+
+if TYPE_CHECKING:  # pragma: no cover
+    from causaliq_knowledge.cache import TokenCache
+    from causaliq_knowledge.graph.models import ModelSpec
+
+logger = logging.getLogger(__name__)
+
+
+# Type alias for supported clients
+LLMClientType = Union[
+    AnthropicClient,
+    DeepSeekClient,
+    GeminiClient,
+    GroqClient,
+    MistralClient,
+    OllamaClient,
+    OpenAIClient,
+]
+
+
+@dataclass
+class GraphGeneratorConfig:
+    """Configuration for the GraphGenerator.
+
+    Attributes:
+        temperature: LLM sampling temperature (lower = more deterministic).
+        max_tokens: Maximum tokens in LLM response.
+        timeout: Request timeout in seconds.
+        output_format: Desired output format (edge_list or adjacency_matrix).
+        prompt_detail: Detail level for variable information in prompts.
+        use_llm_names: Use llm_name instead of benchmark name in prompts.
+        request_id: Optional identifier for requests (stored in metadata).
+    """
+
+    temperature: float = 0.1
+    max_tokens: int = 2000
+    timeout: float = 60.0
+    output_format: OutputFormat = OutputFormat.EDGE_LIST
+    prompt_detail: PromptDetail = PromptDetail.STANDARD
+    use_llm_names: bool = True
+    request_id: str = ""
+
+
+class GraphGenerator:
+    """Generate causal graphs from variable specifications using LLMs.
+
+    This class provides methods for generating complete causal graphs
+    from ModelSpec objects or variable dictionaries. It supports all
+    LLM providers available in causaliq-knowledge and integrates with
+    the TokenCache for efficient caching of requests.
+
+    Attributes:
+        model: The LLM model identifier (e.g., "groq/llama-3.1-8b-instant").
+        config: Configuration for generation parameters.
+
+    Example:
+        >>> from causaliq_knowledge.graph import ModelLoader
+        >>> from causaliq_knowledge.graph.generator import GraphGenerator
+        >>>
+        >>> # Load model specification
+        >>> spec = ModelLoader.load("model.json")
+        >>>
+        >>> # Create generator
+        >>> generator = GraphGenerator(model="groq/llama-3.1-8b-instant")
+        >>>
+        >>> # Generate graph
+        >>> graph = generator.generate_from_spec(spec)
+        >>> print(f"Generated {len(graph.edges)} edges")
+    """
+
+    def __init__(
+        self,
+        model: str = "groq/llama-3.1-8b-instant",
+        config: Optional[GraphGeneratorConfig] = None,
+        cache: Optional["TokenCache"] = None,
+    ) -> None:
+        """Initialise the GraphGenerator.
+
+        Args:
+            model: LLM model identifier with provider prefix. Supported:
+                - "groq/llama-3.1-8b-instant" (Groq API)
+                - "gemini/gemini-2.5-flash" (Google Gemini)
+                - "openai/gpt-4o" (OpenAI)
+                - "anthropic/claude-3-5-sonnet-20241022" (Anthropic)
+                - "deepseek/deepseek-chat" (DeepSeek)
+                - "mistral/mistral-small-latest" (Mistral)
+                - "ollama/llama3.2:1b" (Local Ollama)
+            config: Generation configuration. Uses defaults if None.
+            cache: TokenCache instance for caching. Disabled if None.
+
+        Raises:
+            ValueError: If the model prefix is not supported.
+        """
+        self._model = model
+        self._config = config or GraphGeneratorConfig()
+        self._cache = cache
+        self._client = self._create_client(model)
+        self._call_count = 0
+
+        # Configure cache on client if provided
+        if cache is not None:
+            self._client.set_cache(cache, use_cache=True)
+
+    def _create_client(self, model: str) -> LLMClientType:
+        """Create the appropriate LLM client for the model.
+
+        Args:
+            model: Model identifier with provider prefix.
+
+        Returns:
+            Configured LLM client instance.
+
+        Raises:
+            ValueError: If the model prefix is not supported.
+        """
+        config = self._config
+
+        if model.startswith("anthropic/"):
+            model_name = model.split("/", 1)[1]
+            return AnthropicClient(
+                config=AnthropicConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        elif model.startswith("deepseek/"):
+            model_name = model.split("/", 1)[1]
+            return DeepSeekClient(
+                config=DeepSeekConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        elif model.startswith("gemini/"):
+            model_name = model.split("/", 1)[1]
+            return GeminiClient(
+                config=GeminiConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        elif model.startswith("groq/"):
+            model_name = model.split("/", 1)[1]
+            return GroqClient(
+                config=GroqConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        elif model.startswith("mistral/"):
+            model_name = model.split("/", 1)[1]
+            return MistralClient(
+                config=MistralConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        elif model.startswith("ollama/"):
+            model_name = model.split("/", 1)[1]
+            return OllamaClient(
+                config=OllamaConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        elif model.startswith("openai/"):
+            model_name = model.split("/", 1)[1]
+            return OpenAIClient(
+                config=OpenAIConfig(
+                    model=model_name,
+                    temperature=config.temperature,
+                    max_tokens=config.max_tokens,
+                    timeout=config.timeout,
+                )
+            )
+        else:
+            supported = [
+                "anthropic/",
+                "deepseek/",
+                "gemini/",
+                "groq/",
+                "mistral/",
+                "ollama/",
+                "openai/",
+            ]
+            raise ValueError(
+                f"Model '{model}' not supported. "
+                f"Supported prefixes: {supported}."
+            )
+
+    @property
+    def model(self) -> str:
+        """Return the model identifier."""
+        return self._model
+
+    @property
+    def config(self) -> GraphGeneratorConfig:
+        """Return the generator configuration."""
+        return self._config
+
+    @property
+    def call_count(self) -> int:
+        """Return the number of generation calls made."""
+        return self._call_count
+
+    def set_cache(
+        self,
+        cache: Optional["TokenCache"],
+        use_cache: bool = True,
+    ) -> None:
+        """Configure caching for this generator.
+
+        Args:
+            cache: TokenCache instance for caching, or None to disable.
+            use_cache: Whether to use the cache (default True).
+        """
+        self._cache = cache
+        self._client.set_cache(cache, use_cache)
+
+    def _build_cache_key(
+        self,
+        prompt: GraphQueryPrompt,
+        system_prompt: str,
+        user_prompt: str,
+    ) -> str:
+        """Build a deterministic cache key for graph generation.
+
+        Uses a prefix to distinguish graph queries from edge queries.
+
+        Args:
+            prompt: The GraphQueryPrompt used.
+            system_prompt: The system prompt string.
+            user_prompt: The user prompt string.
+
+        Returns:
+            16-character hex string cache key with graph prefix.
+        """
+        key_data = {
+            "type": "graph_generation",
+            "model": self._model,
+            "output_format": self._config.output_format.value,
+            "prompt_detail": prompt.level.value,
+            "system_prompt": system_prompt,
+            "user_prompt": user_prompt,
+            "temperature": self._config.temperature,
+        }
+        key_json = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
+        return "graph_" + hashlib.sha256(key_json.encode()).hexdigest()[:12]
+
+    def generate_graph(
+        self,
+        variables: List[Dict[str, Any]],
+        level: Optional[PromptDetail] = None,
+        domain: Optional[str] = None,
+        output_format: Optional[OutputFormat] = None,
+        system_prompt: Optional[str] = None,
+    ) -> GeneratedGraph:
+        """Generate a causal graph from variable dictionaries.
+
+        Args:
+            variables: List of variable dictionaries with at least "name".
+            level: View level for context. Uses config default if None.
+            domain: Optional domain context for the query.
+            output_format: Output format. Uses config default if None.
+            system_prompt: Custom system prompt (optional).
+
+        Returns:
+            GeneratedGraph with proposed edges and metadata.
+
+        Raises:
+            ValueError: If LLM response cannot be parsed.
+        """
+        level = level or self._config.prompt_detail
+        output_format = output_format or self._config.output_format
+
+        # Build the prompt
+        prompt = GraphQueryPrompt(
+            variables=variables,
+            level=level,
+            domain=domain,
+            output_format=output_format,
+            system_prompt=system_prompt,
+        )
+
+        return self._execute_query(prompt)
+
+    def generate_from_spec(
+        self,
+        spec: "ModelSpec",
+        level: Optional[PromptDetail] = None,
+        output_format: Optional[OutputFormat] = None,
+        system_prompt: Optional[str] = None,
+        use_llm_names: Optional[bool] = None,
+    ) -> GeneratedGraph:
+        """Generate a causal graph from a ModelSpec.
+
+        Convenience method that extracts variables and domain from the
+        specification automatically.
+
+        Args:
+            spec: The model specification.
+            level: View level for context. Uses config default if None.
+            output_format: Output format. Uses config default if None.
+            system_prompt: Custom system prompt (optional).
+            use_llm_names: Use llm_name instead of benchmark name.
+                Uses config default if None.
+
+        Returns:
+            GeneratedGraph with proposed edges and metadata.
+
+        Raises:
+            ValueError: If LLM response cannot be parsed.
+        """
+        level = level or self._config.prompt_detail
+        output_format = output_format or self._config.output_format
+        use_llm = (
+            use_llm_names
+            if use_llm_names is not None
+            else self._config.use_llm_names
+        )
+
+        # Use the class method to create prompt from spec
+        prompt = GraphQueryPrompt.from_model_spec(
+            spec=spec,
+            level=level,
+            output_format=output_format,
+            system_prompt=system_prompt,
+            use_llm_names=use_llm,
+        )
+
+        return self._execute_query(prompt)
+
+    def _execute_query(self, prompt: GraphQueryPrompt) -> GeneratedGraph:
+        """Execute the LLM query and parse the response.
+
+        Args:
+            prompt: The configured GraphQueryPrompt.
+
+        Returns:
+            GeneratedGraph with parsed edges and metadata.
+
+        Raises:
+            ValueError: If LLM response cannot be parsed.
+        """
+        system_prompt, user_prompt = prompt.build()
+        variable_names = prompt.get_variable_names()
+
+        # Build messages for the LLM
+        messages: List[Dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        # Make the request (using cached_completion if cache is set)
+        start_time = time.perf_counter()
+        from_cache = False
+
+        if self._cache is not None and self._client.use_cache:
+            response = self._client.cached_completion(
+                messages, request_id=self._config.request_id
+            )
+            # Check if response was from cache by comparing timing
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+            # If latency is very low, likely from cache
+            from_cache = latency_ms < 50
+        else:
+            response = self._client.completion(messages)
+
+        latency_ms = int((time.perf_counter() - start_time) * 1000)
+        self._call_count += 1
+
+        # Parse the response
+        output_format_str = prompt.output_format.value
+        graph = parse_graph_response(
+            response.content,
+            variable_names,
+            output_format_str,
+        )
+
+        # Add metadata
+        provider = self._model.split("/")[0] if "/" in self._model else ""
+        model_name = (
+            self._model.split("/", 1)[1] if "/" in self._model else self._model
+        )
+
+        graph.metadata = GenerationMetadata(
+            model=model_name,
+            provider=provider,
+            timestamp=datetime.now(timezone.utc),
+            latency_ms=latency_ms,
+            input_tokens=response.input_tokens,
+            output_tokens=response.output_tokens,
+            cost_usd=response.cost,
+            from_cache=from_cache,
+        )
+
+        return graph
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get statistics about generation calls.
+
+        Returns:
+            Dict with call_count, model, and client stats.
+        """
+        return {
+            "model": self._model,
+            "call_count": self._call_count,
+            "client_call_count": self._client.call_count,
+        }
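For orientation, here is a minimal usage sketch of the new generator API, assembled only from the docstrings and signatures in the diff above; the "model.json" path, the assumption of a provider API key in the environment, and the printed fields are illustrative rather than part of the released package.

    # Sketch only: combines the documented ModelLoader and GraphGenerator APIs.
    # Assumes a valid model spec at "model.json" and provider credentials set up.
    from causaliq_knowledge.graph import ModelLoader
    from causaliq_knowledge.graph.generator import (
        GraphGenerator,
        GraphGeneratorConfig,
    )

    spec = ModelLoader.load("model.json")          # validated ModelSpec
    generator = GraphGenerator(
        model="groq/llama-3.1-8b-instant",
        config=GraphGeneratorConfig(temperature=0.1, max_tokens=2000),
    )

    graph = generator.generate_from_spec(spec)     # GeneratedGraph with edges + metadata
    print(f"Generated {len(graph.edges)} edges")
    meta = graph.metadata                          # GenerationMetadata set in _execute_query
    print(f"{meta.provider}/{meta.model}: {meta.latency_ms} ms, cached={meta.from_cache}")
    print(generator.get_stats())                   # model, call_count, client_call_count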
causaliq_knowledge/graph/loader.py (new file)
@@ -0,0 +1,222 @@
+"""Model specification loader with validation.
+
+This module provides functionality to load and validate model
+specification JSON files.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Union
+
+from causaliq_knowledge.graph.models import ModelSpec
+
+
+class ModelLoadError(Exception):
+    """Error raised when model loading fails.
+
+    Attributes:
+        message: Error description.
+        path: Path to the file that failed to load.
+        details: Additional error details.
+    """
+
+    def __init__(
+        self,
+        message: str,
+        path: Path | str | None = None,
+        details: str | None = None,
+    ) -> None:
+        self.message = message
+        self.path = path
+        self.details = details
+        full_message = message
+        if path:
+            full_message = f"{message}: {path}"
+        if details:
+            full_message = f"{full_message}\n Details: {details}"
+        super().__init__(full_message)
+
+
+class ModelLoader:
+    """Loader for model specification JSON files.
+
+    This class provides methods to load, validate, and access
+    model specifications from JSON files.
+
+    Example:
+        >>> loader = ModelLoader()
+        >>> spec = loader.load("path/to/model.json")
+        >>> print(spec.dataset_id)
+        'cancer'
+
+        >>> # Or load from dict
+        >>> spec = loader.from_dict(
+        ...     {"dataset_id": "test", "domain": "test", ...}
+        ... )
+    """
+
+    @staticmethod
+    def load(path: Union[str, Path]) -> ModelSpec:
+        """Load a model specification from a JSON file.
+
+        Args:
+            path: Path to the JSON file.
+
+        Returns:
+            Validated ModelSpec instance.
+
+        Raises:
+            ModelLoadError: If the file cannot be loaded or validated.
+        """
+        path = Path(path)
+
+        # Check file exists
+        if not path.exists():
+            raise ModelLoadError("Model specification file not found", path)
+
+        # Check file extension
+        if path.suffix.lower() != ".json":
+            raise ModelLoadError(
+                "Model specification must be a JSON file",
+                path,
+                f"Got extension: {path.suffix}",
+            )
+
+        # Load JSON
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except json.JSONDecodeError as e:
+            raise ModelLoadError(
+                "Invalid JSON in model specification",
+                path,
+                str(e),
+            ) from e
+        except OSError as e:
+            raise ModelLoadError(
+                "Failed to read model specification file",
+                path,
+                str(e),
+            ) from e
+
+        # Validate and create ModelSpec
+        return ModelLoader.from_dict(data, source_path=path)
+
+    @staticmethod
+    def from_dict(
+        data: dict,
+        source_path: Path | str | None = None,
+    ) -> ModelSpec:
+        """Create a ModelSpec from a dictionary.
+
+        Args:
+            data: Dictionary containing model specification.
+            source_path: Optional source path for error messages.
+
+        Returns:
+            Validated ModelSpec instance.
+
+        Raises:
+            ModelLoadError: If validation fails.
+        """
+        # Check required fields
+        required_fields = ["dataset_id", "domain"]
+        missing = [f for f in required_fields if f not in data]
+        if missing:
+            raise ModelLoadError(
+                f"Missing required fields: {', '.join(missing)}",
+                source_path,
+            )
+
+        # Validate with Pydantic
+        try:
+            return ModelSpec.model_validate(data)
+        except Exception as e:
+            raise ModelLoadError(
+                "Model specification validation failed",
+                source_path,
+                str(e),
+            ) from e
+
+    @staticmethod
+    def validate_variables(spec: ModelSpec) -> list[str]:
+        """Validate variable specifications and return warnings.
+
+        Performs additional validation beyond Pydantic schema:
+        - Checks for duplicate variable names
+        - Checks that states are defined for discrete variables
+        - Checks for empty variable list
+
+        Args:
+            spec: ModelSpec to validate.
+
+        Returns:
+            List of warning messages (empty if no issues).
+
+        Raises:
+            ModelLoadError: If critical validation errors found.
+        """
+        warnings: list[str] = []
+
+        # Check for empty variables
+        if not spec.variables:
+            raise ModelLoadError(
+                "Model specification has no variables defined"
+            )
+
+        # Check for duplicate names
+        names = [v.name for v in spec.variables]
+        duplicates = [n for n in names if names.count(n) > 1]
+        if duplicates:
+            raise ModelLoadError(
+                f"Duplicate variable names found: {', '.join(set(duplicates))}"
+            )
+
+        # Check states for discrete variables
+        for var in spec.variables:
+            if (
+                var.type
+                in (
+                    "binary",
+                    "categorical",
+                    "ordinal",
+                )
+                and not var.states
+            ):
+                warnings.append(
+                    f"Variable '{var.name}' is {var.type} "
+                    "but has no states defined"
+                )
+
+        # Check binary variables have exactly 2 states
+        for var in spec.variables:
+            if var.type == "binary" and var.states and len(var.states) != 2:
+                warnings.append(
+                    f"Variable '{var.name}' is binary "
+                    f"but has {len(var.states)} states"
+                )
+
+        return warnings
+
+    @staticmethod
+    def load_and_validate(
+        path: Union[str, Path],
+    ) -> tuple[ModelSpec, list[str]]:
+        """Load and fully validate a model specification.
+
+        Combines loading with additional validation checks.
+
+        Args:
+            path: Path to the JSON file.
+
+        Returns:
+            Tuple of (ModelSpec, list of warnings).
+
+        Raises:
+            ModelLoadError: If loading or validation fails.
+        """
+        spec = ModelLoader.load(path)
+        warnings = ModelLoader.validate_variables(spec)
+        return spec, warnings
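Similarly, a sketch of how the new loader module is intended to be driven, based only on the load_and_validate signature and docstrings above; the spec path and the printed attributes are illustrative assumptions.

    # Sketch only: load_and_validate returns the validated spec plus non-fatal warnings;
    # ModelLoadError covers missing files, bad JSON, and schema validation failures.
    from causaliq_knowledge.graph import ModelLoader
    from causaliq_knowledge.graph.loader import ModelLoadError

    try:
        spec, warnings = ModelLoader.load_and_validate("model.json")
    except ModelLoadError as err:
        print(f"Could not load model specification: {err}")
    else:
        print(f"Loaded '{spec.dataset_id}' ({spec.domain}) "
              f"with {len(spec.variables)} variables")
        for warning in warnings:
            print(f"Warning: {warning}")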