bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,784 @@
|
|
|
1
|
+
"""Item constructor for building experimental items from templates.
|
|
2
|
+
|
|
3
|
+
This module provides the ItemConstructor class which transforms filled templates
|
|
4
|
+
into experimental items by applying model-based constraints and collecting
|
|
5
|
+
model outputs for analysis.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Iterator
|
|
11
|
+
from datetime import UTC, datetime
|
|
12
|
+
from uuid import UUID
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from bead.dsl.ast import (
|
|
17
|
+
ASTNode,
|
|
18
|
+
AttributeAccess,
|
|
19
|
+
BinaryOp,
|
|
20
|
+
FunctionCall,
|
|
21
|
+
Literal,
|
|
22
|
+
UnaryOp,
|
|
23
|
+
Variable,
|
|
24
|
+
)
|
|
25
|
+
from bead.dsl.context import EvaluationContext
|
|
26
|
+
from bead.dsl.evaluator import Evaluator
|
|
27
|
+
from bead.dsl.parser import parse
|
|
28
|
+
from bead.dsl.stdlib import register_stdlib
|
|
29
|
+
from bead.items.adapters.registry import ModelAdapterRegistry
|
|
30
|
+
from bead.items.cache import ModelOutputCache
|
|
31
|
+
from bead.items.item import Item, MetadataValue, ModelOutput
|
|
32
|
+
from bead.items.item_template import ItemTemplate
|
|
33
|
+
from bead.resources.constraints import Constraint
|
|
34
|
+
from bead.templates.filler import FilledTemplate
|
|
35
|
+
from bead.templates.resolver import ConstraintResolver
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ItemConstructor:
|
|
39
|
+
"""Construct experimental items from filled templates.
|
|
40
|
+
|
|
41
|
+
Transforms filled templates into items by:
|
|
42
|
+
1. Resolving element references to text
|
|
43
|
+
2. Computing required model outputs (from constraints)
|
|
44
|
+
3. Evaluating constraints with model outputs
|
|
45
|
+
4. Creating Item instances with metadata
|
|
46
|
+
|
|
47
|
+
Parameters
|
|
48
|
+
----------
|
|
49
|
+
model_registry : ModelAdapterRegistry
|
|
50
|
+
Registry of model adapters for constraint evaluation.
|
|
51
|
+
cache : ModelOutputCache
|
|
52
|
+
Cache for model outputs to avoid redundant computation.
|
|
53
|
+
constraint_resolver : ConstraintResolver | None, optional
|
|
54
|
+
Resolver for evaluating non-model constraints. If None, only
|
|
55
|
+
model-based constraints can be evaluated.
|
|
56
|
+
|
|
57
|
+
Attributes
|
|
58
|
+
----------
|
|
59
|
+
model_registry : ModelAdapterRegistry
|
|
60
|
+
Registry of model adapters for constraint evaluation.
|
|
61
|
+
cache : ModelOutputCache
|
|
62
|
+
Cache for model outputs to avoid redundant computation.
|
|
63
|
+
constraint_resolver : ConstraintResolver | None
|
|
64
|
+
Resolver for evaluating constraints (not used for model constraints).
|
|
65
|
+
|
|
66
|
+
Examples
|
|
67
|
+
--------
|
|
68
|
+
>>> from bead.items.adapters.registry import default_registry
|
|
69
|
+
>>> from bead.items.cache import ModelOutputCache
|
|
70
|
+
>>> cache = ModelOutputCache(backend="memory")
|
|
71
|
+
>>> constructor = ItemConstructor(default_registry, cache)
|
|
72
|
+
>>> constraints = {constraint_id: constraint_obj}
|
|
73
|
+
>>> items = list(constructor.construct_items(
|
|
74
|
+
... template, filled_templates, constraints
|
|
75
|
+
... ))
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def __init__(
    self,
    model_registry: ModelAdapterRegistry,
    cache: ModelOutputCache,
    constraint_resolver: ConstraintResolver | None = None,
) -> None:
    """Initialize the constructor with model access and caching.

    Parameters
    ----------
    model_registry : ModelAdapterRegistry
        Registry of model adapters for constraint evaluation.
    cache : ModelOutputCache
        Cache for model outputs to avoid redundant computation.
    constraint_resolver : ConstraintResolver | None, optional
        Resolver for non-model constraints; may be None.
    """
    # Keep collaborators on the instance for the construction pipeline.
    self.constraint_resolver = constraint_resolver
    self.cache = cache
    self.model_registry = model_registry
    # DSL evaluator with caching enabled so repeated constraint
    # expressions are not re-evaluated from scratch.
    self._dsl_evaluator = Evaluator(use_cache=True)
+
|
|
89
|
+
def construct_items(
    self,
    item_template: ItemTemplate,
    filled_templates: dict[UUID, FilledTemplate],
    constraints: dict[UUID, Constraint],
) -> Iterator[Item]:
    """Construct items for one combination of filled templates.

    Pipeline: render elements to text, execute the model calls the
    constraints require, evaluate the constraints, and yield an Item
    only when every constraint is satisfied.

    Parameters
    ----------
    item_template : ItemTemplate
        Template defining item structure and constraints.
    filled_templates : dict[UUID, FilledTemplate]
        Map of filled template UUIDs to FilledTemplate instances.
    constraints : dict[UUID, Constraint]
        Map of constraint UUIDs to Constraint objects.

    Yields
    ------
    Item
        Constructed items that satisfy all constraints.

    Raises
    ------
    ValueError
        If the template references missing filled templates or constraints.
    RuntimeError
        If constraint evaluation or model computation fails.
    """
    texts = self._render_elements(item_template, filled_templates)
    outputs = self._compute_model_outputs(item_template, texts, constraints)
    satisfaction = self._check_constraints(
        item_template, texts, outputs, constraints
    )

    # Guard clause: a single unsatisfied constraint vetoes the item.
    if not all(satisfaction.values()):
        return

    yield Item(
        item_template_id=item_template.id,
        filled_template_refs=list(filled_templates.keys()),
        rendered_elements=texts,
        model_outputs=outputs,
        constraint_satisfaction=satisfaction,
    )
159
|
+
|
|
160
|
+
def _render_elements(
|
|
161
|
+
self,
|
|
162
|
+
item_template: ItemTemplate,
|
|
163
|
+
filled_templates: dict[UUID, FilledTemplate],
|
|
164
|
+
) -> dict[str, str]:
|
|
165
|
+
"""Render ItemElements to text.
|
|
166
|
+
|
|
167
|
+
Resolve element references: text elements use content directly,
|
|
168
|
+
filled_template_ref elements use the rendered text from FilledTemplate.
|
|
169
|
+
|
|
170
|
+
Parameters
|
|
171
|
+
----------
|
|
172
|
+
item_template : ItemTemplate
|
|
173
|
+
Template with elements to render.
|
|
174
|
+
filled_templates : dict[UUID, FilledTemplate]
|
|
175
|
+
Map of filled template UUIDs to instances.
|
|
176
|
+
|
|
177
|
+
Returns
|
|
178
|
+
-------
|
|
179
|
+
dict[str, str]
|
|
180
|
+
Map of element names to rendered text.
|
|
181
|
+
|
|
182
|
+
Raises
|
|
183
|
+
------
|
|
184
|
+
ValueError
|
|
185
|
+
If element references missing filled template.
|
|
186
|
+
"""
|
|
187
|
+
rendered: dict[str, str] = {}
|
|
188
|
+
|
|
189
|
+
for element in item_template.elements:
|
|
190
|
+
if element.is_text:
|
|
191
|
+
# Static text element
|
|
192
|
+
rendered[element.element_name] = element.content or ""
|
|
193
|
+
elif element.is_template_ref:
|
|
194
|
+
# Reference to filled template
|
|
195
|
+
ref_id = element.filled_template_ref_id
|
|
196
|
+
if ref_id is None:
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"Element {element.element_name} has no filled_template_ref_id"
|
|
199
|
+
)
|
|
200
|
+
if ref_id not in filled_templates:
|
|
201
|
+
raise ValueError(
|
|
202
|
+
f"Element {element.element_name} references missing "
|
|
203
|
+
f"filled template {ref_id}"
|
|
204
|
+
)
|
|
205
|
+
filled_template = filled_templates[ref_id]
|
|
206
|
+
rendered[element.element_name] = filled_template.rendered_text
|
|
207
|
+
|
|
208
|
+
return rendered
|
|
209
|
+
|
|
210
|
+
def _compute_model_outputs(
    self,
    item_template: ItemTemplate,
    rendered_elements: dict[str, str],
    constraints: dict[UUID, Constraint],
) -> list[ModelOutput]:
    """Run every model call demanded by the template's constraints.

    Each constraint expression is parsed to an AST, scanned for model
    function calls, and those calls are executed through the adapter
    registry (results are cached).

    Parameters
    ----------
    item_template : ItemTemplate
        Template with constraints.
    rendered_elements : dict[str, str]
        Rendered element text used to resolve call arguments.
    constraints : dict[UUID, Constraint]
        Map of constraint UUIDs to Constraint objects.

    Returns
    -------
    list[ModelOutput]
        All model outputs computed for this item.

    Raises
    ------
    ValueError
        If a constraint UUID is absent from ``constraints``.
    RuntimeError
        If parsing a constraint or executing a model call fails.
    """
    outputs: list[ModelOutput] = []

    for cid in item_template.constraints:
        if cid not in constraints:
            raise ValueError(f"Constraint {cid} not found")
        constraint = constraints[cid]

        # Parse the DSL expression; wrap parser failures with context.
        try:
            tree = parse(constraint.expression)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to parse constraint '{constraint.expression}': {exc}"
            ) from exc

        # Execute every statically-extractable model call in the tree.
        for spec in self._extract_model_calls(tree, rendered_elements):
            try:
                produced = self._execute_model_call(spec)
                if produced:
                    outputs.append(produced)
            except Exception as exc:
                raise RuntimeError(
                    f"Failed to execute model call {spec}: {exc}"
                ) from exc

    return outputs
|
274
|
+
|
|
275
|
+
def _extract_model_calls(
    self, ast_node: ASTNode, rendered_elements: dict[str, str]
) -> list[dict[str, str | int | float | bool | None]]:
    """Collect model-function call specs from a constraint AST.

    Walks the tree looking for calls to the known model functions
    (lm_prob, lm_perplexity, nli, similarity, embedding) and returns a
    spec for each one whose arguments can be resolved statically.

    Parameters
    ----------
    ast_node : ASTNode
        Root of the (sub)tree to scan.
    rendered_elements : dict[str, str]
        Rendered elements used to resolve Variable arguments.

    Returns
    -------
    list[dict[str, str | int | float | bool | None]]
        One call specification per resolvable model call, in traversal
        order.
    """
    model_functions = {
        "lm_prob",
        "lm_perplexity",
        "nli",
        "similarity",
        "embedding",
    }
    found: list[dict[str, str | int | float | bool | None]] = []

    def walk(node: ASTNode) -> None:
        if isinstance(node, FunctionCall):
            # The callee may be a bare name or an attribute (method).
            target = node.function
            if isinstance(target, Variable):
                name = target.name
            elif isinstance(target, AttributeAccess):
                name = target.attribute
            else:
                # Unsupported callee shapes are not descended into,
                # matching the original short-circuit behavior.
                return
            if name in model_functions:
                spec = self._extract_call_args(
                    name, node.arguments, rendered_elements
                )
                if spec:
                    found.append(spec)
            # Arguments may themselves contain nested model calls.
            for arg in node.arguments:
                walk(arg)
        elif isinstance(node, BinaryOp):
            walk(node.left)
            walk(node.right)
        elif isinstance(node, UnaryOp):
            walk(node.operand)
        elif isinstance(node, AttributeAccess):
            walk(node.object)

    walk(ast_node)
    return found
|
337
|
+
|
|
338
|
+
def _extract_call_args(
|
|
339
|
+
self,
|
|
340
|
+
func_name: str,
|
|
341
|
+
args: list[ASTNode],
|
|
342
|
+
rendered_elements: dict[str, str],
|
|
343
|
+
) -> dict[str, str | int | float | bool | None] | None:
|
|
344
|
+
"""Extract arguments from a model function call.
|
|
345
|
+
|
|
346
|
+
Parameters
|
|
347
|
+
----------
|
|
348
|
+
func_name : str
|
|
349
|
+
Name of the function.
|
|
350
|
+
args : list[ASTNode]
|
|
351
|
+
AST nodes representing function arguments.
|
|
352
|
+
rendered_elements : dict[str, str]
|
|
353
|
+
Rendered elements for variable resolution.
|
|
354
|
+
|
|
355
|
+
Returns
|
|
356
|
+
-------
|
|
357
|
+
dict[str, Any] | None
|
|
358
|
+
Call specification with function, args, and model name.
|
|
359
|
+
"""
|
|
360
|
+
# Resolve literal values and variables
|
|
361
|
+
resolved_args: list[str | int | float | bool | None] = []
|
|
362
|
+
for arg in args:
|
|
363
|
+
if isinstance(arg, Literal):
|
|
364
|
+
resolved_args.append(arg.value)
|
|
365
|
+
elif isinstance(arg, Variable):
|
|
366
|
+
# Try to resolve from rendered elements
|
|
367
|
+
if arg.name in rendered_elements:
|
|
368
|
+
resolved_args.append(rendered_elements[arg.name])
|
|
369
|
+
else:
|
|
370
|
+
# Can't resolve, skip this call
|
|
371
|
+
return None
|
|
372
|
+
else:
|
|
373
|
+
# Complex expression, can't extract statically
|
|
374
|
+
return None
|
|
375
|
+
|
|
376
|
+
# Build call specification based on function type
|
|
377
|
+
if func_name in {"lm_prob", "lm_perplexity"}:
|
|
378
|
+
# lm_prob(text, model='gpt2')
|
|
379
|
+
if len(resolved_args) == 0:
|
|
380
|
+
return None
|
|
381
|
+
text = str(resolved_args[0])
|
|
382
|
+
model = str(resolved_args[1]) if len(resolved_args) > 1 else "gpt2"
|
|
383
|
+
operation = "log_probability" if func_name == "lm_prob" else "perplexity"
|
|
384
|
+
return {
|
|
385
|
+
"function": func_name,
|
|
386
|
+
"text": text,
|
|
387
|
+
"model": model,
|
|
388
|
+
"operation": operation,
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
elif func_name == "nli":
|
|
392
|
+
# nli(premise, hypothesis, model='roberta-large-mnli')
|
|
393
|
+
if len(resolved_args) < 2:
|
|
394
|
+
return None
|
|
395
|
+
premise = str(resolved_args[0])
|
|
396
|
+
hypothesis = str(resolved_args[1])
|
|
397
|
+
default_nli_model = "roberta-large-mnli"
|
|
398
|
+
model = (
|
|
399
|
+
str(resolved_args[2]) if len(resolved_args) > 2 else default_nli_model
|
|
400
|
+
)
|
|
401
|
+
return {
|
|
402
|
+
"function": func_name,
|
|
403
|
+
"premise": premise,
|
|
404
|
+
"hypothesis": hypothesis,
|
|
405
|
+
"model": model,
|
|
406
|
+
"operation": "nli",
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
elif func_name == "similarity":
|
|
410
|
+
# similarity(text1, text2, model='all-MiniLM-L6-v2')
|
|
411
|
+
if len(resolved_args) < 2:
|
|
412
|
+
return None
|
|
413
|
+
text1 = str(resolved_args[0])
|
|
414
|
+
text2 = str(resolved_args[1])
|
|
415
|
+
model = (
|
|
416
|
+
str(resolved_args[2]) if len(resolved_args) > 2 else "all-MiniLM-L6-v2"
|
|
417
|
+
)
|
|
418
|
+
return {
|
|
419
|
+
"function": func_name,
|
|
420
|
+
"text1": text1,
|
|
421
|
+
"text2": text2,
|
|
422
|
+
"model": model,
|
|
423
|
+
"operation": "similarity",
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
elif func_name == "embedding":
|
|
427
|
+
# embedding(text, model='all-MiniLM-L6-v2')
|
|
428
|
+
if len(resolved_args) == 0:
|
|
429
|
+
return None
|
|
430
|
+
text = str(resolved_args[0])
|
|
431
|
+
model = (
|
|
432
|
+
str(resolved_args[1]) if len(resolved_args) > 1 else "all-MiniLM-L6-v2"
|
|
433
|
+
)
|
|
434
|
+
return {
|
|
435
|
+
"function": func_name,
|
|
436
|
+
"text": text,
|
|
437
|
+
"model": model,
|
|
438
|
+
"operation": "embedding",
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
return None
|
|
442
|
+
|
|
443
|
+
def _execute_model_call(
    self, call_spec: dict[str, str | int | float | bool | None]
) -> ModelOutput | None:
    """Execute a single model call and return its ModelOutput.

    Maps the spec's operation onto an adapter type, consults the cache
    first, and only runs the adapter on a cache miss.

    NOTE(review): despite the ``ModelOutput | None`` annotation, every
    visible path either returns a ModelOutput or raises — confirm
    whether any caller actually relies on a None result.

    Parameters
    ----------
    call_spec : dict[str, str | int | float | bool | None]
        Call specification with function, inputs, model, and operation
        (as produced by ``_extract_call_args``).

    Returns
    -------
    ModelOutput | None
        Model output, either freshly computed or reconstructed from the
        cache (cached results carry ``model_version="unknown"``).

    Raises
    ------
    ValueError
        If the spec's operation is not one of the supported operations.
    RuntimeError
        If model execution fails.
    """
    operation = str(call_spec["operation"])
    model_name = str(call_spec["model"])

    # Determine adapter type based on operation
    if operation in {"log_probability", "perplexity"}:
        adapter_type = "huggingface_lm"
    elif operation == "nli":
        adapter_type = "huggingface_nli"
    elif operation in {"similarity", "embedding"}:
        adapter_type = "sentence_transformer"
    else:
        raise ValueError(f"Unknown operation: {operation}")

    # Build the per-operation input dict used both for cache lookups
    # and (stringified) as the ModelOutput's recorded inputs.
    cache_key_args: dict[str, str | int | float | bool | None] = {}
    if operation in {"log_probability", "perplexity"}:
        cache_key_args = {"text": call_spec["text"]}
    elif operation == "nli":
        cache_key_args = {
            "premise": call_spec["premise"],
            "hypothesis": call_spec["hypothesis"],
        }
    elif operation == "similarity":
        cache_key_args = {
            "text1": call_spec["text1"],
            "text2": call_spec["text2"],
        }
    elif operation == "embedding":
        cache_key_args = {"text": call_spec["text"]}

    # Check cache first; a hit avoids loading/running the adapter.
    cached_result = self.cache.get(model_name, operation, **cache_key_args)
    if cached_result is not None:
        # Already cached, create ModelOutput from cache
        cache_key = self.cache.generate_cache_key(
            model_name, operation, **cache_key_args
        )
        # Convert inputs to MetadataValue compatible dict
        metadata_inputs: dict[str, MetadataValue] = {
            k: str(v) for k, v in cache_key_args.items()
        }
        return ModelOutput(
            model_name=model_name,
            model_version="unknown",  # Could fetch from cache
            operation=operation,
            inputs=metadata_inputs,
            output=cached_result,
            cache_key=cache_key,
            computation_metadata={
                "timestamp": datetime.now(UTC).isoformat(),
                "from_cache": True,
            },
        )

    # Cache miss: get adapter and execute
    adapter = self.model_registry.get_adapter(
        adapter_type=adapter_type,
        model_name=model_name,
        cache=self.cache,
    )

    # Dispatch to the adapter method matching the operation.
    if operation == "log_probability":
        result = adapter.compute_log_probability(str(call_spec["text"]))
    elif operation == "perplexity":
        result = adapter.compute_perplexity(str(call_spec["text"]))
    elif operation == "nli":
        result = adapter.compute_nli(
            str(call_spec["premise"]), str(call_spec["hypothesis"])
        )
    elif operation == "similarity":
        result = adapter.compute_similarity(
            str(call_spec["text1"]), str(call_spec["text2"])
        )
    elif operation == "embedding":
        result = adapter.get_embedding(str(call_spec["text"]))
    else:
        # Unreachable given the adapter-type check above; kept defensively.
        raise ValueError(f"Unknown operation: {operation}")

    # Generate cache key
    cache_key = self.cache.generate_cache_key(
        model_name, operation, **cache_key_args
    )

    # Convert inputs to MetadataValue compatible dict
    metadata_inputs: dict[str, MetadataValue] = {
        k: str(v) for k, v in cache_key_args.items()
    }

    # Create ModelOutput; adapters without a model_version attribute
    # are recorded as "unknown".
    model_version = (
        adapter.model_version if hasattr(adapter, "model_version") else "unknown"
    )
    return ModelOutput(
        model_name=model_name,
        model_version=model_version,
        operation=operation,
        inputs=metadata_inputs,
        output=result,  # type: ignore[arg-type] # Output can be various types
        cache_key=cache_key,
        computation_metadata={
            "timestamp": datetime.now(UTC).isoformat(),
            "from_cache": False,
        },
    )
|
567
|
+
|
|
568
|
+
def _check_constraints(
|
|
569
|
+
self,
|
|
570
|
+
item_template: ItemTemplate,
|
|
571
|
+
rendered_elements: dict[str, str],
|
|
572
|
+
model_outputs: list[ModelOutput],
|
|
573
|
+
constraints: dict[UUID, Constraint],
|
|
574
|
+
) -> dict[UUID, bool]:
|
|
575
|
+
"""Evaluate constraints using model outputs.
|
|
576
|
+
|
|
577
|
+
Check each constraint against rendered elements and model outputs.
|
|
578
|
+
|
|
579
|
+
Parameters
|
|
580
|
+
----------
|
|
581
|
+
item_template : ItemTemplate
|
|
582
|
+
Template with constraints.
|
|
583
|
+
rendered_elements : dict[str, str]
|
|
584
|
+
Rendered element text.
|
|
585
|
+
model_outputs : list[ModelOutput]
|
|
586
|
+
Model outputs to use in constraint evaluation.
|
|
587
|
+
constraints : dict[UUID, Constraint]
|
|
588
|
+
Map of constraint UUIDs to Constraint objects.
|
|
589
|
+
|
|
590
|
+
Returns
|
|
591
|
+
-------
|
|
592
|
+
dict[UUID, bool]
|
|
593
|
+
Map of constraint UUIDs to satisfaction status.
|
|
594
|
+
|
|
595
|
+
Raises
|
|
596
|
+
------
|
|
597
|
+
RuntimeError
|
|
598
|
+
If constraint evaluation fails.
|
|
599
|
+
ValueError
|
|
600
|
+
If constraint UUID not found.
|
|
601
|
+
"""
|
|
602
|
+
constraint_satisfaction: dict[UUID, bool] = {}
|
|
603
|
+
|
|
604
|
+
# Evaluate each constraint
|
|
605
|
+
for constraint_id in item_template.constraints:
|
|
606
|
+
if constraint_id not in constraints:
|
|
607
|
+
raise ValueError(f"Constraint {constraint_id} not found")
|
|
608
|
+
|
|
609
|
+
constraint = constraints[constraint_id]
|
|
610
|
+
|
|
611
|
+
# Evaluate constraint
|
|
612
|
+
satisfied = self._evaluate_dsl_constraint(
|
|
613
|
+
constraint, rendered_elements, model_outputs
|
|
614
|
+
)
|
|
615
|
+
constraint_satisfaction[constraint_id] = satisfied
|
|
616
|
+
|
|
617
|
+
return constraint_satisfaction
|
|
618
|
+
|
|
619
|
+
def _evaluate_dsl_constraint(
    self,
    constraint: Constraint,
    rendered_elements: dict[str, str],
    model_outputs: list[ModelOutput],
) -> bool:
    """Parse and evaluate one DSL constraint expression.

    Builds a fresh evaluation context containing the standard library,
    the model-backed functions, and one variable per rendered element,
    then evaluates the constraint's expression in it.

    Parameters
    ----------
    constraint : Constraint
        Constraint whose ``expression`` is evaluated.
    rendered_elements : dict[str, str]
        Rendered element text, bound as variables in the DSL context.
    model_outputs : list[ModelOutput]
        Model outputs made available to the registered model functions.

    Returns
    -------
    bool
        True if the constraint is satisfied.

    Raises
    ------
    RuntimeError
        If parsing or evaluating the DSL expression fails.
    """
    # Fresh context per evaluation: stdlib, model functions, then variables.
    ctx = EvaluationContext()
    register_stdlib(ctx)
    self._register_model_functions(ctx, model_outputs)

    for element_name, element_text in rendered_elements.items():
        ctx.set_variable(element_name, element_text)

    try:
        tree = parse(constraint.expression)
        value = self._dsl_evaluator.evaluate(tree, ctx)
        # bool() stays inside the try so a failing __bool__ is also
        # reported as a constraint-evaluation failure.
        return bool(value)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to evaluate constraint '{constraint.expression}': {exc}"
        ) from exc
|
|
671
|
+
|
|
672
|
+
def _register_model_functions(
|
|
673
|
+
self,
|
|
674
|
+
context: EvaluationContext,
|
|
675
|
+
model_outputs: list[ModelOutput],
|
|
676
|
+
) -> None:
|
|
677
|
+
"""Register model functions in DSL context.
|
|
678
|
+
|
|
679
|
+
Add functions like lm_prob(), nli(), similarity() that can access
|
|
680
|
+
precomputed model outputs from cache.
|
|
681
|
+
|
|
682
|
+
Parameters
|
|
683
|
+
----------
|
|
684
|
+
context : EvaluationContext
|
|
685
|
+
DSL evaluation context.
|
|
686
|
+
model_outputs : list[ModelOutput]
|
|
687
|
+
Precomputed model outputs.
|
|
688
|
+
"""
|
|
689
|
+
# Create lookup for model outputs
|
|
690
|
+
output_map: dict[tuple[str, str, str], ModelOutput] = {}
|
|
691
|
+
for output in model_outputs:
|
|
692
|
+
# Key includes model, operation, and stringified inputs
|
|
693
|
+
inputs_str = str(sorted(output.inputs.items()))
|
|
694
|
+
key = (output.model_name, output.operation, inputs_str)
|
|
695
|
+
output_map[key] = output
|
|
696
|
+
|
|
697
|
+
# Define model functions that use cached outputs
|
|
698
|
+
def lm_prob(text: str, model: str = "gpt2") -> float:
|
|
699
|
+
"""Get log probability from cache or compute."""
|
|
700
|
+
# Check cache first
|
|
701
|
+
cached = self.cache.get(model, "log_probability", text=text)
|
|
702
|
+
if cached is not None:
|
|
703
|
+
return float(cached)
|
|
704
|
+
|
|
705
|
+
# Compute if not cached
|
|
706
|
+
adapter = self.model_registry.get_adapter(
|
|
707
|
+
adapter_type="huggingface_lm",
|
|
708
|
+
model_name=model,
|
|
709
|
+
cache=self.cache,
|
|
710
|
+
)
|
|
711
|
+
result = adapter.compute_log_probability(text)
|
|
712
|
+
return result
|
|
713
|
+
|
|
714
|
+
def lm_perplexity(text: str, model: str = "gpt2") -> float:
|
|
715
|
+
"""Get perplexity from cache or compute."""
|
|
716
|
+
cached = self.cache.get(model, "perplexity", text=text)
|
|
717
|
+
if cached is not None:
|
|
718
|
+
return float(cached)
|
|
719
|
+
|
|
720
|
+
adapter = self.model_registry.get_adapter(
|
|
721
|
+
adapter_type="huggingface_lm",
|
|
722
|
+
model_name=model,
|
|
723
|
+
cache=self.cache,
|
|
724
|
+
)
|
|
725
|
+
result = adapter.compute_perplexity(text)
|
|
726
|
+
return result
|
|
727
|
+
|
|
728
|
+
def nli(
|
|
729
|
+
premise: str, hypothesis: str, model: str = "roberta-large-mnli"
|
|
730
|
+
) -> dict[str, float]:
|
|
731
|
+
"""Get NLI scores from cache or compute."""
|
|
732
|
+
cached = self.cache.get(
|
|
733
|
+
model, "nli", premise=premise, hypothesis=hypothesis
|
|
734
|
+
)
|
|
735
|
+
if cached is not None:
|
|
736
|
+
return dict(cached) # type: ignore[arg-type]
|
|
737
|
+
|
|
738
|
+
adapter = self.model_registry.get_adapter(
|
|
739
|
+
adapter_type="huggingface_nli",
|
|
740
|
+
model_name=model,
|
|
741
|
+
cache=self.cache,
|
|
742
|
+
)
|
|
743
|
+
result = adapter.compute_nli(premise, hypothesis)
|
|
744
|
+
return result
|
|
745
|
+
|
|
746
|
+
def similarity(
|
|
747
|
+
text1: str, text2: str, model: str = "all-MiniLM-L6-v2"
|
|
748
|
+
) -> float:
|
|
749
|
+
"""Get similarity from cache or compute."""
|
|
750
|
+
cached = self.cache.get(model, "similarity", text1=text1, text2=text2)
|
|
751
|
+
if cached is not None:
|
|
752
|
+
return float(cached)
|
|
753
|
+
|
|
754
|
+
adapter = self.model_registry.get_adapter(
|
|
755
|
+
adapter_type="sentence_transformer",
|
|
756
|
+
model_name=model,
|
|
757
|
+
cache=self.cache,
|
|
758
|
+
)
|
|
759
|
+
result = adapter.compute_similarity(text1, text2)
|
|
760
|
+
return result
|
|
761
|
+
|
|
762
|
+
def embedding(text: str, model: str = "all-MiniLM-L6-v2") -> list[float]:
|
|
763
|
+
"""Get embedding from cache or compute."""
|
|
764
|
+
cached = self.cache.get(model, "embedding", text=text)
|
|
765
|
+
if cached is not None:
|
|
766
|
+
# Convert numpy array back to list
|
|
767
|
+
if isinstance(cached, np.ndarray):
|
|
768
|
+
return cached.tolist() # type: ignore[return-value]
|
|
769
|
+
return list(cached) # type: ignore[arg-type]
|
|
770
|
+
|
|
771
|
+
adapter = self.model_registry.get_adapter(
|
|
772
|
+
adapter_type="sentence_transformer",
|
|
773
|
+
model_name=model,
|
|
774
|
+
cache=self.cache,
|
|
775
|
+
)
|
|
776
|
+
result = adapter.get_embedding(text)
|
|
777
|
+
return result.tolist() # type: ignore[return-value]
|
|
778
|
+
|
|
779
|
+
# Register functions in context
|
|
780
|
+
context.set_function("lm_prob", lm_prob)
|
|
781
|
+
context.set_function("lm_perplexity", lm_perplexity)
|
|
782
|
+
context.set_function("nli", nli)
|
|
783
|
+
context.set_function("similarity", similarity)
|
|
784
|
+
context.set_function("embedding", embedding)
|