bead 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Multi-select simulation strategy."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from bead.simulation.strategies.base import SimulationStrategy
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from bead.items.item import Item
|
|
13
|
+
from bead.items.item_template import ItemTemplate
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MultiSelectStrategy(SimulationStrategy):
    """Simulation strategy for multi_select tasks.

    Treats every option as an independent binary decision. Given model
    outputs, option i is selected with probability
    ``sigmoid(score_i / temperature)``; without model outputs, each
    option is selected with probability ``threshold``.

    Parameters
    ----------
    threshold
        Per-option selection probability used by the random fallback,
        and the default probability cutoff. Default: 0.5.
    temperature
        Temperature applied to model scores before the sigmoid.
        Default: 1.0.

    Examples
    --------
    >>> strategy = MultiSelectStrategy()
    >>> strategy.supported_task_type
    'multi_select'
    """

    def __init__(self, threshold: float = 0.5, temperature: float = 1.0) -> None:
        self.threshold = threshold
        self.temperature = temperature

    @property
    def supported_task_type(self) -> str:
        """Return 'multi_select'."""
        return "multi_select"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate that the item/template pair is usable for multi-select.

        Requirements checked:
        - ``task_type`` is ``'multi_select'``
        - ``task_spec.options`` is defined and non-empty
        - at least two options exist

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template defining task.

        Raises
        ------
        ValueError
            If any requirement is not met.
        """
        if item_template.task_type != "multi_select":
            msg = f"Expected task_type 'multi_select', got '{item_template.task_type}'"
            raise ValueError(msg)

        opts = item_template.task_spec.options
        if not opts:
            raise ValueError("task_spec.options must be defined for multi_select")
        if len(opts) < 2:
            raise ValueError("multi_select requires at least 2 options")

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> list[str]:
        """Generate a multi-select response.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining task.
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator.

        Returns
        -------
        list[str]
            Names of the options that were selected.
        """
        options = item_template.task_spec.options
        assert options is not None, "options validated in validate()"

        # one score expected per option
        scores = self.extract_model_outputs(item, model_output_key, len(options))

        if scores is None:
            # no model outputs: independent coin flip per option at threshold
            return [opt for opt in options if rng.random() < self.threshold]

        chosen: list[str] = []
        for opt, raw_score in zip(options, scores, strict=True):
            # temperature-scaled sigmoid gives an independent selection probability
            p_select = 1.0 / (1.0 + np.exp(-raw_score / self.temperature))
            if rng.random() < p_select:
                chosen.append(opt)

        return chosen
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Ordinal scale simulation strategy."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from bead.simulation.strategies.base import SimulationStrategy
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from bead.items.item import Item
|
|
13
|
+
from bead.items.item_template import ItemTemplate
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OrdinalScaleStrategy(SimulationStrategy):
    """Strategy for ordinal_scale tasks (Likert scales).

    Handles discrete ordinal scales (e.g., 1-7, 1-5). The model score is
    squashed through a sigmoid into [0, 1], mapped linearly onto the
    scale range, rounded to the nearest integer, and clamped to the
    scale bounds. Without model outputs, a uniform random rating across
    the scale is returned.

    Examples
    --------
    >>> strategy = OrdinalScaleStrategy()
    >>> strategy.supported_task_type
    'ordinal_scale'
    """

    @property
    def supported_task_type(self) -> str:
        """Return 'ordinal_scale'.

        Returns
        -------
        str
            Task type identifier.
        """
        return "ordinal_scale"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate that the item/template pair is usable for ordinal scale.

        Requirements checked:
        - ``task_type`` is ``'ordinal_scale'``
        - ``task_spec.scale_bounds`` is defined
        - the bounds satisfy ``min < max``

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template defining task.

        Raises
        ------
        ValueError
            If any requirement is not met.
        """
        if item_template.task_type != "ordinal_scale":
            msg = f"Expected task_type 'ordinal_scale', got '{item_template.task_type}'"
            raise ValueError(msg)

        bounds = item_template.task_spec.scale_bounds
        if not bounds:
            msg = "task_spec.scale_bounds must be defined for ordinal_scale"
            raise ValueError(msg)

        min_val, max_val = bounds
        if min_val >= max_val:
            msg = f"scale_bounds min ({min_val}) must be less than max ({max_val})"
            raise ValueError(msg)

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> int:
        """Generate an ordinal scale response.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining task.
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator.

        Returns
        -------
        int
            Rating on the ordinal scale, within the configured bounds.
        """
        bounds = item_template.task_spec.scale_bounds
        if bounds is None:
            msg = "task_spec.scale_bounds must be defined"
            raise ValueError(msg)
        lo, hi = bounds

        # a single score is expected for the whole item
        scores = self.extract_model_outputs(item, model_output_key, required_count=1)
        if scores is None:
            # no model outputs: uniform random rating across the scale
            return int(rng.randint(lo, hi + 1))

        # squash the unbounded score into [0, 1], then stretch onto the scale
        squashed = 1.0 / (1.0 + np.exp(-scores[0]))
        continuous_position = lo + squashed * (hi - lo)

        # nearest integer, clamped so rounding can never leave the bounds
        rating = int(np.round(continuous_position))
        return max(lo, min(hi, rating))
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Template filling functionality.
|
|
2
|
+
|
|
3
|
+
Provides template filling strategies (exhaustive, random, stratified) and
|
|
4
|
+
constraint resolution for generating experimental stimuli.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from bead.templates.filler import CSPFiller, FilledTemplate, TemplateFiller
|
|
10
|
+
from bead.templates.resolver import ConstraintResolver
|
|
11
|
+
from bead.templates.strategies import (
|
|
12
|
+
ExhaustiveStrategy,
|
|
13
|
+
RandomStrategy,
|
|
14
|
+
StrategyFiller,
|
|
15
|
+
StratifiedStrategy,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"TemplateFiller", # ABC
|
|
20
|
+
"CSPFiller",
|
|
21
|
+
"StrategyFiller",
|
|
22
|
+
"FilledTemplate",
|
|
23
|
+
"ConstraintResolver",
|
|
24
|
+
"ExhaustiveStrategy",
|
|
25
|
+
"RandomStrategy",
|
|
26
|
+
"StratifiedStrategy",
|
|
27
|
+
]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Template filling model adapters.
|
|
2
|
+
|
|
3
|
+
Provides masked language model adapters for template filling (Stage 2).
|
|
4
|
+
Separate from judgment prediction models (Stage 3).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from bead.templates.adapters.base import TemplateFillingModelAdapter
|
|
10
|
+
from bead.templates.adapters.cache import ModelOutputCache
|
|
11
|
+
from bead.templates.adapters.huggingface import HuggingFaceMLMAdapter
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"TemplateFillingModelAdapter",
|
|
15
|
+
"ModelOutputCache",
|
|
16
|
+
"HuggingFaceMLMAdapter",
|
|
17
|
+
]
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Base adapter for template filling models.
|
|
2
|
+
|
|
3
|
+
This module defines the abstract interface for models used in template filling.
|
|
4
|
+
These adapters are SEPARATE from judgment prediction model adapters (Stage 6).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TemplateFillingModelAdapter(ABC):
    """Abstract interface for models used during template filling.

    Deliberately SEPARATE from judgment prediction model adapters, which
    live later in the pipeline and predict human judgments.

    Subclasses implement loading/unloading and masked-token prediction;
    this base class supplies the loaded-state flag and context-manager
    support (load on ``__enter__``, unload on ``__exit__``).

    Parameters
    ----------
    model_name : str
        Model identifier (e.g., "bert-base-uncased")
    device : str
        Computation device ("cpu", "cuda", "mps")
    cache_dir : Path | None
        Directory for caching model files

    Examples
    --------
    >>> from bead.templates.adapters import TemplateFillingModelAdapter
    >>> # Implemented by HuggingFaceMLMAdapter
    >>> adapter = HuggingFaceMLMAdapter("bert-base-uncased", device="cpu")
    >>> adapter.load_model()
    >>> predictions = adapter.predict_masked_token(
    ...     text="The cat [MASK] on the mat",
    ...     mask_position=2,
    ...     top_k=5
    ... )
    >>> adapter.unload_model()
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cpu",
        cache_dir: Path | None = None,
    ) -> None:
        self.model_name = model_name
        self.device = device
        self.cache_dir = cache_dir
        # subclasses flip this flag in load_model/unload_model
        self._model_loaded = False

    @abstractmethod
    def load_model(self) -> None:
        """Load the model into memory.

        Raises
        ------
        RuntimeError
            If model loading fails
        """

    @abstractmethod
    def unload_model(self) -> None:
        """Unload the model from memory to free resources."""

    @abstractmethod
    def predict_masked_token(
        self,
        text: str,
        mask_position: int,
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Predict the masked token at the given position.

        Parameters
        ----------
        text : str
            Text with mask token (e.g., "The cat [MASK] quickly")
        mask_position : int
            Token position of mask (0-indexed)
        top_k : int
            Number of top predictions to return

        Returns
        -------
        list[tuple[str, float]]
            (token, log_probability) pairs, sorted by probability

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If mask_position is invalid

        Examples
        --------
        >>> predictions = adapter.predict_masked_token(
        ...     text="The cat [MASK] on the mat",
        ...     mask_position=2,
        ...     top_k=3
        ... )
        >>> predictions
        [("sat", -0.5), ("slept", -1.2), ("jumped", -1.5)]
        """

    def is_loaded(self) -> bool:
        """Report whether the model is currently loaded.

        Returns
        -------
        bool
            True if model is loaded in memory
        """
        return self._model_loaded

    def __enter__(self) -> TemplateFillingModelAdapter:
        """Load the model and return self for use in a ``with`` block."""
        self.load_model()
        return self

    def __exit__(self, *args: object) -> None:
        """Unload the model when the ``with`` block exits."""
        self.unload_model()
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Content-addressable cache for model predictions.
|
|
2
|
+
|
|
3
|
+
This module implements caching for template filling model predictions
|
|
4
|
+
using SHA256-based content addressing.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelOutputCache:
    """Content-addressable cache for model predictions.

    Uses SHA256 hashing to create deterministic cache keys based on:
    - Model name
    - Input text
    - Mask position
    - Top-K parameter

    Writes are atomic (temp file + rename), so a concurrent reader or an
    interrupted process can never observe a half-written cache entry.

    Parameters
    ----------
    cache_dir : Path
        Directory for cache storage
    enabled : bool
        Enable/disable caching. When disabled, ``get`` always misses and
        ``set`` is a no-op.

    Examples
    --------
    >>> cache = ModelOutputCache(cache_dir=Path("/tmp/cache"), enabled=True)
    >>> key_args = ("bert-base-uncased", "The cat [MASK]", 2, 10)
    >>> predictions = cache.get(*key_args)
    >>> if predictions is None:
    ...     predictions = model.predict(...)
    ...     cache.set(*key_args, predictions)
    """

    def __init__(self, cache_dir: Path, enabled: bool = True) -> None:
        self.cache_dir = cache_dir
        self.enabled = enabled

        if self.enabled:
            # ensure the storage directory exists up front so get/set
            # never have to create it on the fly
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _compute_key(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
    ) -> str:
        """Compute cache key from inputs.

        Parameters
        ----------
        model_name : str
            Model identifier
        input_text : str
            Input text with mask
        mask_position : int
            Position of mask token
        top_k : int
            Number of predictions

        Returns
        -------
        str
            SHA256 hex digest
        """
        # deterministic key: JSON with sorted keys guarantees a stable
        # serialization for identical inputs
        key_data = {
            "model_name": model_name,
            "input_text": input_text,
            "mask_position": mask_position,
            "top_k": top_k,
        }
        key_json = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_json.encode("utf-8")).hexdigest()

    def get(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
    ) -> list[tuple[str, float]] | None:
        """Get cached predictions.

        Parameters
        ----------
        model_name : str
            Model identifier
        input_text : str
            Input text
        mask_position : int
            Mask position
        top_k : int
            Number of predictions

        Returns
        -------
        list[tuple[str, float]] | None
            Cached predictions, or None on a miss or unreadable entry
        """
        if not self.enabled:
            return None

        cache_key = self._compute_key(model_name, input_text, mask_position, top_k)
        cache_file = self.cache_dir / f"{cache_key}.json"

        # EAFP: attempt the read directly; a missing file is just a miss.
        # TypeError covers structurally corrupt JSON (e.g. a list of
        # strings instead of a list of objects), which the original
        # except tuple let propagate.
        try:
            data = json.loads(cache_file.read_text(encoding="utf-8"))
            return [(item["token"], item["log_prob"]) for item in data]
        except (json.JSONDecodeError, KeyError, TypeError, OSError):
            # cache miss or corruption; treat identically
            return None

    def set(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
        predictions: list[tuple[str, float]],
    ) -> None:
        """Store predictions in cache.

        The entry is written to a temporary file and atomically renamed
        into place, so readers never see a partially written file.

        Parameters
        ----------
        model_name : str
            Model identifier
        input_text : str
            Input text
        mask_position : int
            Mask position
        top_k : int
            Number of predictions
        predictions : list[tuple[str, float]]
            Predictions to cache
        """
        if not self.enabled:
            return

        cache_key = self._compute_key(model_name, input_text, mask_position, top_k)
        cache_file = self.cache_dir / f"{cache_key}.json"
        # temp name stays out of the "*.json" glob used by clear()
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")

        # convert to serializable format
        data = [
            {"token": token, "log_prob": log_prob} for token, log_prob in predictions
        ]

        try:
            tmp_file.write_text(json.dumps(data, indent=2), encoding="utf-8")
            # atomic on POSIX; best effort elsewhere
            tmp_file.replace(cache_file)
        except OSError:
            # the cache is best-effort: never let a write failure
            # propagate to the caller
            pass

    def clear(self) -> None:
        """Clear all cached predictions (including stray temp files)."""
        if not self.enabled or not self.cache_dir.exists():
            return

        for pattern in ("*.json", "*.json.tmp"):
            for cache_file in self.cache_dir.glob(pattern):
                try:
                    cache_file.unlink()
                except OSError:
                    pass
|