causaliq-knowledge 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causaliq_knowledge/__init__.py +6 -3
- causaliq_knowledge/action.py +480 -0
- causaliq_knowledge/cache/__init__.py +18 -0
- causaliq_knowledge/cache/encoders/__init__.py +13 -0
- causaliq_knowledge/cache/encoders/base.py +90 -0
- causaliq_knowledge/cache/encoders/json_encoder.py +430 -0
- causaliq_knowledge/cache/token_cache.py +666 -0
- causaliq_knowledge/cli/__init__.py +15 -0
- causaliq_knowledge/cli/cache.py +478 -0
- causaliq_knowledge/cli/generate.py +410 -0
- causaliq_knowledge/cli/main.py +172 -0
- causaliq_knowledge/cli/models.py +309 -0
- causaliq_knowledge/graph/__init__.py +78 -0
- causaliq_knowledge/graph/generator.py +457 -0
- causaliq_knowledge/graph/loader.py +222 -0
- causaliq_knowledge/graph/models.py +426 -0
- causaliq_knowledge/graph/params.py +175 -0
- causaliq_knowledge/graph/prompts.py +445 -0
- causaliq_knowledge/graph/response.py +392 -0
- causaliq_knowledge/graph/view_filter.py +154 -0
- causaliq_knowledge/llm/base_client.py +147 -1
- causaliq_knowledge/llm/cache.py +443 -0
- causaliq_knowledge/py.typed +0 -0
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/METADATA +10 -6
- causaliq_knowledge-0.4.0.dist-info/RECORD +42 -0
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/WHEEL +1 -1
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/entry_points.txt +3 -0
- causaliq_knowledge/cli.py +0 -414
- causaliq_knowledge-0.2.0.dist-info/RECORD +0 -22
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {causaliq_knowledge-0.2.0.dist-info → causaliq_knowledge-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
"""Pydantic models for model specification schemas.
|
|
2
|
+
|
|
3
|
+
This module defines the data models for loading and validating
|
|
4
|
+
causal model specifications from JSON files.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Any, Optional
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class VariableType(str, Enum):
    """Kind of variable declared in a model specification.

    Because the enum also subclasses ``str``, each member compares equal
    to its lowercase string value (e.g. ``VariableType.BINARY == "binary"``).
    """

    BINARY = "binary"
    CATEGORICAL = "categorical"
    ORDINAL = "ordinal"
    CONTINUOUS = "continuous"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class VariableRole(str, Enum):
    """Position a variable occupies in the causal structure.

    Members are also ``str`` instances equal to their lowercase values.
    """

    # Root cause: the variable has no parents.
    EXOGENOUS = "exogenous"
    # Caused by other variables: the variable has parents.
    ENDOGENOUS = "endogenous"
    # Not observed in the data.
    LATENT = "latent"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class VariableSpec(BaseModel):
    """Specification for a single variable in the causal model.

    This model captures all metadata about a variable that can be used
    to provide context to LLMs for graph generation.

    Attributes:
        name: Benchmark/literature name used for ground truth and reporting.
        llm_name: Name used when querying LLMs (prevents memorisation).
            Defaults to name if not specified.
        display_name: Human-readable name for display.
        aliases: Alternative names for the variable.
        type: Variable type (binary, categorical, ordinal, continuous).
            String values are matched case-insensitively.
        states: Possible values/states for discrete variables.
        role: Causal role (exogenous, endogenous, latent). String values
            are matched case-insensitively.
        category: Domain-specific category (e.g., "environmental_exposure").
        short_description: Brief description of the variable.
        extended_description: Detailed description with domain context.
        base_rate: Prior probabilities for each state.
        conditional_rates: Conditional probabilities given parent states.
        sensitivity_hints: Hints about causal relationships.
        related_domain_knowledge: Domain knowledge statements.
        references: Literature references.

    Example:
        >>> var = VariableSpec(
        ...     name="smoke",
        ...     llm_name="tobacco_history",
        ...     type="binary",
        ...     states=["never", "ever"],
        ...     role="exogenous",
        ...     short_description="Patient has history of tobacco smoking."
        ... )
    """

    name: str = Field(
        ..., description="Benchmark/literature name for ground truth"
    )
    llm_name: str = Field(
        default="",
        description="Name used for LLM queries (defaults to name)",
    )
    display_name: Optional[str] = Field(
        default=None, description="Human-readable display name"
    )
    aliases: list[str] = Field(
        default_factory=list, description="Alternative names"
    )
    type: VariableType = Field(..., description="Variable type")
    states: list[str] = Field(
        default_factory=list,
        description="Possible states for discrete variables",
    )
    role: Optional[VariableRole] = Field(
        default=None, description="Causal role in the structure"
    )
    category: Optional[str] = Field(
        default=None, description="Domain-specific category"
    )
    short_description: Optional[str] = Field(
        default=None, description="Brief description"
    )
    extended_description: Optional[str] = Field(
        default=None, description="Detailed description with domain context"
    )
    base_rate: Optional[dict[str, float]] = Field(
        default=None, description="Prior probabilities for each state"
    )
    conditional_rates: Optional[dict[str, Any]] = Field(
        default=None, description="Conditional probabilities"
    )
    sensitivity_hints: Optional[str] = Field(
        default=None, description="Hints about causal relationships"
    )
    related_domain_knowledge: list[str] = Field(
        default_factory=list, description="Domain knowledge statements"
    )
    references: list[str] = Field(
        default_factory=list, description="Literature references"
    )

    @model_validator(mode="after")
    def set_llm_name_default(self) -> "VariableSpec":
        """Set llm_name to name if not specified or empty."""
        if not self.llm_name:
            # Bypass the model's own __setattr__ so this works regardless
            # of model config (e.g. frozen models or validate_assignment,
            # which would otherwise re-enter validation).
            object.__setattr__(self, "llm_name", self.name)
        return self

    @field_validator("type", mode="before")
    @classmethod
    def validate_type(cls, v: Any) -> Any:
        """Convert a string type to the VariableType enum (case-insensitive).

        Non-string inputs are returned unchanged so that Pydantic's own
        enum validation rejects them with a ValidationError, rather than
        this validator failing with AttributeError on ``.lower()``.

        Raises:
            ValueError: If a string does not name a VariableType member.
        """
        # VariableType is itself a str subclass, so exclude it explicitly.
        if isinstance(v, str) and not isinstance(v, VariableType):
            return VariableType(v.lower())
        return v

    @field_validator("role", mode="before")
    @classmethod
    def validate_role(cls, v: Any) -> Any:
        """Convert a string role to the VariableRole enum (case-insensitive).

        ``None`` and non-string inputs are returned unchanged; Pydantic then
        accepts ``None`` (the field is optional) or rejects other types with
        a ValidationError.

        Raises:
            ValueError: If a string does not name a VariableRole member.
        """
        # VariableRole is itself a str subclass, so exclude it explicitly.
        if isinstance(v, str) and not isinstance(v, VariableRole):
            return VariableRole(v.lower())
        return v
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class Provenance(BaseModel):
    """Where a model specification came from and how it was prepared.

    Every field is optional and defaults to ``None``.

    Attributes:
        source_network: Name of the benchmark network the spec derives from.
        source_reference: Citation for the original source.
        source_url: URL pointing at the source data.
        disguise_strategy: Strategy used to disguise variable names.
        memorization_risk: Risk level of LLM memorization.
        notes: Free-form notes about the source.
    """

    source_network: Optional[str] = Field(
        None, description="Source benchmark network name"
    )
    source_reference: Optional[str] = Field(
        None, description="Citation for original source"
    )
    source_url: Optional[str] = Field(
        None, description="URL to source data"
    )
    disguise_strategy: Optional[str] = Field(
        None, description="Variable name disguising strategy"
    )
    memorization_risk: Optional[str] = Field(
        None, description="LLM memorization risk level"
    )
    notes: Optional[str] = Field(
        None, description="Additional notes"
    )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class LLMGuidance(BaseModel):
    """Instructions on how LLMs should be used with this model.

    Both lists default to empty.

    Attributes:
        usage_notes: Notes describing how the model should be used.
        do_not_provide: Items of information to withhold from LLMs.
    """

    usage_notes: list[str] = Field(
        description="Usage guidance for LLMs",
        default_factory=list,
    )
    do_not_provide: list[str] = Field(
        description="Information to withhold from LLMs",
        default_factory=list,
    )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class ViewDefinition(BaseModel):
    """A named level of prompt detail (e.g. minimal, standard or rich).

    Attributes:
        description: What the view contains; ``None`` when not given.
        include_fields: Names of VariableSpec fields the view exposes;
            empty by default.
    """

    description: Optional[str] = Field(
        None, description="Description of this view"
    )
    include_fields: list[str] = Field(
        description="Fields to include in this view",
        default_factory=list,
    )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class PromptDetails(BaseModel):
    """The set of prompt-detail views available for a model.

    Each view gets a fresh ViewDefinition (via ``default_factory``) when
    not supplied, so instances never share mutable field lists.

    Attributes:
        minimal: Minimal view — by default only the variable name.
        standard: Standard view — by default name, type, short
            description and states.
        rich: Rich view — by default every descriptive VariableSpec field.
    """

    minimal: ViewDefinition = Field(
        description="Minimal context view",
        default_factory=lambda: ViewDefinition(include_fields=["name"]),
    )
    standard: ViewDefinition = Field(
        description="Standard context view",
        default_factory=lambda: ViewDefinition(
            include_fields=["name", "type", "short_description", "states"],
        ),
    )
    rich: ViewDefinition = Field(
        description="Rich context view",
        default_factory=lambda: ViewDefinition(
            include_fields=[
                "name", "display_name", "type", "role", "category",
                "short_description", "extended_description", "states",
                "base_rate", "conditional_rates", "sensitivity_hints",
                "related_domain_knowledge", "references",
            ],
        ),
    )
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class Constraints(BaseModel):
    """Structural restrictions that apply to the causal model.

    Pairs are stored as plain lists of strings; this model does not check
    that each entry has exactly two elements.

    Attributes:
        forbidden_edges: Variable pairs that must not be directly connected.
        partial_order: Ordering pairs [a, b] meaning a must precede b.
        tiers: Mapping from tier label to the variables in that tier.
        notes: Free-form notes about the constraints.
    """

    forbidden_edges: list[list[str]] = Field(
        description="Variable pairs that cannot have edges",
        default_factory=list,
    )
    partial_order: list[list[str]] = Field(
        description="Causal ordering constraints",
        default_factory=list,
    )
    tiers: dict[str, list[str]] = Field(
        description="Variable tier groupings",
        default_factory=dict,
    )
    notes: Optional[str] = Field(None, description="Notes about constraints")
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class CausalPrinciple(BaseModel):
    """A domain-level causal principle stated for the model.

    Attributes:
        id: Required identifier for the principle.
        statement: Required text of the principle.
        references: Supporting literature; empty by default.
    """

    id: str = Field(description="Principle identifier")
    statement: str = Field(description="The causal principle")
    references: list[str] = Field(
        description="Literature references",
        default_factory=list,
    )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class GroundTruth(BaseModel):
    """Reference structure used to evaluate generated graphs.

    Note: This should NOT be provided to LLMs during generation.

    Attributes:
        edges: Reference edges, expressed with benchmark variable names.
        v_structures: V-structure definitions as free-form dictionaries.
        adjacency_matrix: Optional adjacency-matrix representation.
    """

    edges: list[list[str]] = Field(
        description="Edges with benchmark variable names",
        default_factory=list,
    )
    v_structures: list[dict[str, Any]] = Field(
        description="V-structure definitions",
        default_factory=list,
    )
    adjacency_matrix: Optional[dict[str, Any]] = Field(
        None, description="Adjacency matrix representation"
    )
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class ModelSpec(BaseModel):
    """Complete specification for a causal model.

    This is the top-level model that represents an entire model
    specification JSON file.

    Attributes:
        schema_version: Version of the specification schema.
        dataset_id: Unique identifier for the dataset.
        domain: Domain of the causal model (e.g., "pulmonary_oncology").
        purpose: Purpose of the model specification.
        provenance: Source and provenance information.
        llm_guidance: Guidance for LLM usage.
        prompt_details: Prompt detail view definitions (minimal, standard,
            rich).
        variables: List of variable specifications.
        constraints: Structural constraints.
        causal_principles: Domain causal principles.
        ground_truth: Ground truth for evaluation (not for LLMs).

    Example:
        >>> spec = ModelSpec(
        ...     schema_version="2.0",
        ...     dataset_id="cancer",
        ...     domain="pulmonary_oncology",
        ...     variables=[
        ...         VariableSpec(
        ...             name="smoking", llm_name="tobacco_use", type="binary"
        ...         ),
        ...         VariableSpec(
        ...             name="cancer", llm_name="malignancy", type="binary"
        ...         ),
        ...     ]
        ... )
    """

    schema_version: str = Field(default="2.0", description="Schema version")
    dataset_id: str = Field(..., description="Dataset identifier")
    domain: str = Field(..., description="Domain of the causal model")
    purpose: Optional[str] = Field(
        default=None, description="Purpose of this specification"
    )
    provenance: Optional[Provenance] = Field(
        default=None, description="Source and provenance information"
    )
    llm_guidance: Optional[LLMGuidance] = Field(
        default=None, description="Guidance for LLM usage"
    )
    prompt_details: PromptDetails = Field(
        default_factory=PromptDetails,
        description="Prompt detail definitions",
        alias="prompt_details",
    )
    variables: list[VariableSpec] = Field(
        default_factory=list, description="Variable specifications"
    )
    constraints: Optional[Constraints] = Field(
        default=None, description="Structural constraints"
    )
    causal_principles: list[CausalPrinciple] = Field(
        default_factory=list, description="Domain causal principles"
    )
    ground_truth: Optional[GroundTruth] = Field(
        default=None, description="Ground truth for evaluation"
    )

    def get_variable(self, name: str) -> VariableSpec | None:
        """Get a variable by its benchmark name.

        Args:
            name: Benchmark variable name to look up.

        Returns:
            The first VariableSpec whose ``name`` matches, or None.
        """
        return next((var for var in self.variables if var.name == name), None)

    def get_variable_names(self) -> list[str]:
        """Get list of all benchmark variable names.

        Returns:
            List of variable names, in declaration order.
        """
        return [var.name for var in self.variables]

    def get_llm_names(self) -> list[str]:
        """Get list of all LLM variable names.

        Returns:
            List of llm_name values, in declaration order.
        """
        return [var.llm_name for var in self.variables]

    def get_llm_to_name_mapping(self) -> dict[str, str]:
        """Get mapping from LLM names to benchmark names.

        If several variables share an llm_name, the later variable
        overwrites the earlier one in the returned mapping.

        Returns:
            Dict mapping llm_name -> name.
        """
        return {var.llm_name: var.name for var in self.variables}

    def get_name_to_llm_mapping(self) -> dict[str, str]:
        """Get mapping from benchmark names to LLM names.

        If several variables share a name, the later variable overwrites
        the earlier one in the returned mapping.

        Returns:
            Dict mapping name -> llm_name.
        """
        return {var.name: var.llm_name for var in self.variables}

    def uses_distinct_llm_names(self) -> bool:
        """Check if any variable has a different llm_name from name.

        Returns:
            True if at least one variable has llm_name != name.
        """
        return any(var.llm_name != var.name for var in self.variables)
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Shared parameter models for graph generation.
|
|
2
|
+
|
|
3
|
+
This module provides Pydantic models for validating graph generation
|
|
4
|
+
parameters, shared between CLI commands and workflow actions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Optional
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel, Field, field_validator
|
|
13
|
+
|
|
14
|
+
from causaliq_knowledge.graph.view_filter import PromptDetail
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GenerateGraphParams(BaseModel):
    """Parameters for graph generation - shared by CLI and Action.

    This model provides validation for all graph generation parameters,
    ensuring consistent behaviour between CLI invocation and workflow
    action execution.

    Attributes:
        model_spec: Path to model specification JSON file.
        prompt_detail: Detail level for variable information in prompts.
        use_benchmark_names: Use benchmark names instead of LLM names.
        llm_model: LLM model identifier with provider prefix.
        output: Output destination - .json file path or "none" for stdout.
        llm_cache: Path to cache database file (.db) or "none" to disable.
        llm_temperature: LLM sampling temperature.

    Example:
        >>> params = GenerateGraphParams(
        ...     model_spec=Path("model.json"),
        ...     prompt_detail=PromptDetail.STANDARD,
        ...     llm_model="groq/llama-3.1-8b-instant",
        ...     output="none",
        ...     llm_cache="cache.db",
        ... )
    """

    model_spec: Path = Field(
        ...,
        description="Path to model specification JSON file",
    )
    prompt_detail: PromptDetail = Field(
        default=PromptDetail.STANDARD,
        description="Detail level for variable information in prompts",
    )
    use_benchmark_names: bool = Field(
        default=False,
        description="Use benchmark names instead of LLM names",
    )
    llm_model: str = Field(
        default="groq/llama-3.1-8b-instant",
        description="LLM model identifier with provider prefix",
    )
    output: str = Field(
        ...,
        description="Output destination: .json file path or 'none' for stdout",
    )
    llm_cache: str = Field(
        ...,
        description="Path to cache database file (.db) or 'none' to disable",
    )
    llm_temperature: float = Field(
        default=0.1,
        ge=0.0,
        le=2.0,
        description="LLM sampling temperature (0.0-2.0)",
    )

    model_config = {"arbitrary_types_allowed": True}

    @field_validator("llm_model")
    @classmethod
    def validate_llm_model_format(cls, v: str) -> str:
        """Validate the LLM model identifier.

        The identifier must start with a supported provider prefix AND
        name a model after it (e.g. ``"groq/llama-3.1-8b-instant"``); a
        bare prefix such as ``"openai/"`` is rejected.

        Raises:
            ValueError: If the prefix is unsupported or no model name
                follows it.
        """
        valid_prefixes = (
            "anthropic/",
            "deepseek/",
            "gemini/",
            "groq/",
            "mistral/",
            "ollama/",
            "openai/",
        )
        # Prefixes are mutually exclusive, so the first match is the match.
        prefix = next((p for p in valid_prefixes if v.startswith(p)), None)
        if prefix is None:
            raise ValueError(
                "LLM model must start with provider prefix. "
                f"Valid prefixes: {', '.join(valid_prefixes)}. Got: {v}"
            )
        if not v[len(prefix):].strip():
            raise ValueError(
                "LLM model identifier has no model name after "
                f"'{prefix}'. Got: {v}"
            )
        return v

    @field_validator("llm_cache")
    @classmethod
    def validate_llm_cache_format(cls, v: str) -> str:
        """Validate llm_cache is 'none' or a path ending with .db.

        Any casing of "none" is normalised to lowercase "none".
        """
        if v.lower() == "none":
            return "none"
        if not v.endswith(".db"):
            raise ValueError(
                "llm_cache must be 'none' or a path ending with .db. "
                f"Got: {v}"
            )
        return v

    @field_validator("output")
    @classmethod
    def validate_output_format(cls, v: str) -> str:
        """Validate output is 'none' or a path ending with .json.

        Any casing of "none" is normalised to lowercase "none".
        """
        if v.lower() == "none":
            return "none"
        if not v.endswith(".json"):
            raise ValueError(
                "output must be 'none' or a path ending with .json. "
                f"Got: {v}"
            )
        return v

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "GenerateGraphParams":
        """Create params from dictionary with string-to-enum conversion.

        This method handles conversion of string values to enum types,
        useful when receiving parameters from workflow inputs. The input
        dictionary is copied and never modified.

        Args:
            data: Dictionary of parameter values.

        Returns:
            Validated GenerateGraphParams instance.

        Raises:
            ValueError: If validation fails.
        """
        # Shallow copy so the caller's dict is left untouched.
        processed = dict(data)

        # Convert prompt_detail string to PromptDetail enum (case-insensitive)
        if "prompt_detail" in processed and isinstance(
            processed["prompt_detail"], str
        ):
            processed["prompt_detail"] = PromptDetail(
                processed["prompt_detail"].lower()
            )

        # Convert model_spec string to Path
        if "model_spec" in processed and isinstance(
            processed["model_spec"], str
        ):
            processed["model_spec"] = Path(processed["model_spec"])

        return cls(**processed)

    def get_effective_cache_path(self) -> Optional[Path]:
        """Get the effective cache path.

        Returns:
            Path to cache database, or None if caching is disabled.
        """
        if self.llm_cache.lower() == "none":
            return None
        return Path(self.llm_cache)

    def get_effective_output_path(self) -> Optional[Path]:
        """Get the effective output path.

        Returns:
            Path to output JSON file, or None for stdout.
        """
        if self.output.lower() == "none":
            return None
        return Path(self.output)
|