bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,123 @@
1
+ """Categorical choice simulation strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+
16
class CategoricalStrategy(SimulationStrategy):
    """Simulate responses for categorical (unordered multi-class) tasks.

    The options are unordered category labels (e.g., NLI labels or
    sentiment classes). One model score per category is converted into
    a probability with a softmax and the response is sampled from the
    resulting distribution:

        P(category_i) = softmax([score_1, ..., score_n])[i]

    No temperature scaling is applied inside this class.

    Examples
    --------
    >>> strategy = CategoricalStrategy()
    >>> strategy.supported_task_type
    'categorical'
    """

    @property
    def supported_task_type(self) -> str:
        """Task type handled by this strategy.

        Returns
        -------
        str
            The literal ``"categorical"``.
        """
        return "categorical"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Check that a template describes a usable categorical task.

        The template must declare ``task_type == "categorical"`` and
        provide at least two entries in ``task_spec.options``. The item
        itself is not inspected.

        Parameters
        ----------
        item : Item
            Item to validate (unused by the checks below).
        item_template : ItemTemplate
            Template defining the task.

        Raises
        ------
        ValueError
            If the task type is wrong, options are missing/empty, or
            fewer than two options are given.
        """
        task_type = item_template.task_type
        if task_type != "categorical":
            raise ValueError(f"Expected task_type 'categorical', got '{task_type}'")

        categories = item_template.task_spec.options
        if not categories:
            raise ValueError("task_spec.options must be defined for categorical")
        if len(categories) < 2:
            raise ValueError("categorical requires at least 2 options")

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> str:
        """Sample one category for the given item.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining the task; its ``task_spec.options`` are
            the candidate categories.
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator used for sampling.

        Returns
        -------
        str
            Chosen category name.

        Raises
        ------
        ValueError
            If ``task_spec.options`` is None.
        """
        categories = item_template.task_spec.options
        if categories is None:
            raise ValueError("task_spec.options must be defined")

        count = len(categories)

        # one score per category, or None when scores cannot be obtained
        raw_scores = self.extract_model_outputs(item, model_output_key, count)
        if raw_scores is None:
            # no usable scores: every category is equally likely
            return categories[rng.randint(0, count)]

        # softmax; subtracting the maximum keeps exp() from overflowing
        logits = np.array(raw_scores)
        weights = np.exp(logits - logits.max())
        distribution = weights / weights.sum()

        picked = rng.choice(count, p=distribution)
        return categories[picked]
@@ -0,0 +1,224 @@
1
+ """Cloze (fill-in-the-blank) simulation strategy using MLM scores."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+ __all__ = ["ClozeStrategy"]
16
+
17
+
18
class ClozeStrategy(SimulationStrategy):
    """MLM-based strategy for cloze (fill-in-the-blank) tasks.

    For every unfilled slot of an item, candidate fillers are collected
    from the item's model outputs and one is *sampled* from a softmax
    over their scores (higher-scoring candidates are more likely, but
    not guaranteed). Slots without any scores fall back to ground truth
    from the item metadata, a small built-in word list chosen by slot
    name, or a ``"[slot_name]"`` placeholder.

    Model outputs are expected to be stored as separate entries with
    ``operation`` equal to the requested key (e.g., "mlm_score") and
    ``inputs`` containing ``{"slot_name": ..., "candidate": ...}``.

    Examples
    --------
    >>> from bead.simulation.strategies.cloze import ClozeStrategy
    >>> strategy = ClozeStrategy()
    >>> # item with unfilled_slots and model_outputs with MLM scores
    >>> # response = strategy.simulate_response(item, template, "mlm_score", rng)
    >>> # Returns: {"determiner": "the", "verb": "chased", "object": "mouse"}
    """

    @property
    def supported_task_type(self) -> str:
        """Get supported task type.

        Returns
        -------
        str
            Always returns "cloze".
        """
        return "cloze"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate item is compatible with cloze strategy.

        Parameters
        ----------
        item : Item
            Item to validate; must have at least one unfilled slot.
        item_template : ItemTemplate
            Template defining task; must have task_type "cloze".

        Raises
        ------
        ValueError
            If the task type is not "cloze" or the item has no
            unfilled slots.
        """
        if item_template.task_type != "cloze":
            msg = f"Expected task_type 'cloze', got '{item_template.task_type}'"
            raise ValueError(msg)

        if not item.unfilled_slots:
            raise ValueError("cloze task requires at least one unfilled slot")

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> dict[str, str]:
        """Simulate cloze response using MLM scores.

        For each unfilled slot, samples a filler from the softmax of the
        slot's candidate scores. Slots with no scores use
        ``_get_fallback_filler``.

        Parameters
        ----------
        item : Item
            Item to annotate.
        item_template : ItemTemplate
            Template defining task constraints (not read here).
        model_output_key : str
            Key identifying which model outputs to use (e.g., "mlm_score").
        rng : np.random.RandomState
            Random number generator for stochasticity.

        Returns
        -------
        dict[str, str]
            Dictionary mapping slot names to selected fillers.

        Examples
        --------
        >>> response = {"determiner": "the", "verb": "chased", "object": "mouse"}
        """
        response: dict[str, str] = {}

        for slot in item.unfilled_slots:
            slot_name = slot.slot_name

            # try to get MLM scores for this slot
            slot_scores = self._get_slot_scores(item, slot_name, model_output_key)

            if not slot_scores:
                # no scores: ground truth, word-list, or placeholder
                response[slot_name] = self._get_fallback_filler(item, slot_name, rng)
                continue

            fillers = list(slot_scores)
            scores = np.array(list(slot_scores.values()))

            # softmax; subtracting the max is for numerical stability
            exp_scores = np.exp(scores - np.max(scores))
            probs = exp_scores / np.sum(exp_scores)

            # sample (rather than argmax) so responses stay stochastic
            selected_idx = rng.choice(len(fillers), p=probs)
            response[slot_name] = fillers[selected_idx]

        return response

    def _get_slot_scores(
        self, item: Item, slot_name: str, model_output_key: str
    ) -> dict[str, float]:
        """Extract MLM scores for a specific slot.

        Keeps model outputs where:
        - operation equals model_output_key (e.g., "mlm_score")
        - inputs["slot_name"] equals slot_name
        - inputs["candidate"] is present (the filler being scored)
        - output is an int or float (other output types are ignored)

        If the same candidate appears more than once, the last score wins.

        Parameters
        ----------
        item : Item
            Item containing model outputs.
        slot_name : str
            Name of the slot to get scores for.
        model_output_key : str
            Operation type to filter by.

        Returns
        -------
        dict[str, float]
            Mapping from candidate fillers to MLM scores; empty when no
            matching outputs exist.
        """
        scores: dict[str, float] = {}

        for model_output in item.model_outputs:
            if model_output.operation != model_output_key:
                continue

            inputs = model_output.inputs

            # check if this output is for our slot
            if inputs.get("slot_name") != slot_name:
                continue

            candidate = inputs.get("candidate")
            if candidate is None:
                continue

            # only numeric outputs can be used as scores
            score = model_output.output
            if isinstance(score, int | float):
                scores[str(candidate)] = float(score)

        return scores

    def _get_fallback_filler(
        self, item: Item, slot_name: str, rng: np.random.RandomState
    ) -> str:
        """Get fallback filler when MLM scores unavailable.

        Priority:
        1. Ground truth from item_metadata["ground_truth"][slot_name]
        2. Random common filler based on slot name pattern
        3. Generic placeholder ``"[slot_name]"``

        Parameters
        ----------
        item : Item
            Item to get fallback from.
        slot_name : str
            Slot name.
        rng : np.random.RandomState
            Random number generator.

        Returns
        -------
        str
            Fallback filler.
        """
        # try ground truth
        if hasattr(item, "item_metadata") and item.item_metadata:
            ground_truth = item.item_metadata.get("ground_truth")
            if isinstance(ground_truth, dict) and slot_name in ground_truth:
                return str(ground_truth[slot_name])

        # Common fallbacks by slot name pattern. Matching is by substring
        # and the first hit wins, so "adverb" MUST come before "verb":
        # "verb" is a substring of "adverb" and would otherwise shadow it.
        fallback_options = {
            "determiner": ["the", "a", "an", "this", "that"],
            "adverb": ["very", "well", "just", "now", "here"],
            "verb": ["is", "was", "has", "can", "will"],
            "noun": ["thing", "person", "place", "time", "way"],
            "adjective": ["good", "new", "old", "big", "small"],
            "preposition": ["in", "on", "at", "to", "for"],
        }

        # match slot name to category
        slot_lower = slot_name.lower()
        for category, options in fallback_options.items():
            if category in slot_lower:
                return str(rng.choice(options))

        # generic fallback
        return f"[{slot_name}]"
@@ -0,0 +1,127 @@
1
+ """Forced choice simulation strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+
16
class ForcedChoiceStrategy(SimulationStrategy):
    """Simulate responses for forced_choice tasks (n-AFC).

    Covers 2AFC, 3AFC, 4AFC, and so on. One model score per option is
    turned into a preference probability with a softmax, and the
    response is sampled from that distribution:

        P(choose i) = softmax([score_1, ..., score_n])[i]

    For two options this is the same as
    ``sigmoid(score_A - score_B)``. No temperature scaling is applied
    inside this class.

    Examples
    --------
    >>> strategy = ForcedChoiceStrategy()
    >>> strategy.supported_task_type
    'forced_choice'
    """

    @property
    def supported_task_type(self) -> str:
        """Task type handled by this strategy.

        Returns
        -------
        str
            The literal ``"forced_choice"``.
        """
        return "forced_choice"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Check that a template describes a usable forced choice task.

        The template must declare ``task_type == "forced_choice"`` and
        list at least two entries in ``task_spec.options``. Model
        outputs on the item are not checked here.

        Parameters
        ----------
        item : Item
            Item to validate (unused by the checks below).
        item_template : ItemTemplate
            Template defining the task.

        Raises
        ------
        ValueError
            If the task type is wrong, options are missing/empty, or
            fewer than two options are given.
        """
        declared_type = item_template.task_type
        if declared_type != "forced_choice":
            raise ValueError(
                f"Expected task_type 'forced_choice', got '{declared_type}'"
            )

        alternatives = item_template.task_spec.options
        if not alternatives:
            raise ValueError("task_spec.options must be defined for forced_choice")
        if len(alternatives) < 2:
            raise ValueError("forced_choice requires at least 2 options")

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> str:
        """Sample one option for the given item.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining the task; its ``task_spec.options`` are
            the alternatives to choose from.
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator used for sampling.

        Returns
        -------
        str
            Chosen option name.

        Raises
        ------
        ValueError
            If ``task_spec.options`` is None.
        """
        alternatives = item_template.task_spec.options
        if alternatives is None:
            raise ValueError("task_spec.options must be defined")

        n_alternatives = len(alternatives)

        # one score per alternative, or None when scores cannot be obtained
        option_scores = self.extract_model_outputs(
            item, model_output_key, n_alternatives
        )

        if option_scores is not None:
            # softmax over the raw scores (shifted by the max so exp()
            # stays numerically stable); the original author's note says
            # any scaling is left to the noise model applied later
            shifted = np.array(option_scores)
            shifted = np.exp(shifted - shifted.max())
            preference = shifted / shifted.sum()
            winner = rng.choice(n_alternatives, p=preference)
        else:
            # no usable scores: pick uniformly at random
            winner = rng.randint(0, n_alternatives)

        return alternatives[winner]
@@ -0,0 +1,105 @@
1
+ """Free text simulation strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+
16
class FreeTextStrategy(SimulationStrategy):
    """Strategy for free_text tasks.

    Produces a text response with simple rules, tried in order:

    1. ``item_metadata["response"]`` (a ground-truth response)
    2. ``item_metadata["response_template"]``
    3. the first non-empty value in ``rendered_elements``
    4. the literal string ``"No response"``

    Note: This is a simplified implementation for simulation purposes.
    For realistic free text generation, consider using LLMs.

    Examples
    --------
    >>> strategy = FreeTextStrategy()
    >>> strategy.supported_task_type
    'free_text'
    """

    @property
    def supported_task_type(self) -> str:
        """Return 'free_text'.

        Returns
        -------
        str
            Task type identifier.
        """
        return "free_text"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate item for free text.

        Checks:
        - task_type is 'free_text'

        Parameters
        ----------
        item : Item
            Item to validate (not inspected).
        item_template : ItemTemplate
            Template defining task.

        Raises
        ------
        ValueError
            If the template's task_type is not 'free_text'.
        """
        if item_template.task_type != "free_text":
            raise ValueError(
                f"Expected task_type 'free_text', got '{item_template.task_type}'"
            )

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> str:
        """Generate free text response.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining task (unused).
        model_output_key : str
            Key for model outputs (unused for free text).
        rng : np.random.RandomState
            Random number generator (unused; the result is deterministic).

        Returns
        -------
        str
            Generated text response.
        """
        # Missing or None metadata is treated as empty, so a bare item
        # falls through to the later fallbacks instead of raising TypeError.
        metadata = getattr(item, "item_metadata", None) or {}

        # ground-truth response first, then a response template
        for key in ("response", "response_template"):
            if key in metadata:
                return str(metadata[key])

        # otherwise reuse the first non-empty rendered element
        rendered = getattr(item, "rendered_elements", None) or {}
        for value in rendered.values():
            if value:
                # str() keeps the declared return type even for odd values
                return str(value)

        # final fallback: generic response
        return "No response"
@@ -0,0 +1,116 @@
1
+ """Magnitude estimation simulation strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+
16
class MagnitudeStrategy(SimulationStrategy):
    """Strategy for magnitude estimation tasks.

    Handles unbounded numeric magnitude estimation. Converts a single
    model output (typically an LM score such as a log probability) to a
    strictly positive magnitude:

        magnitude = exp(score / scale_factor)

    The mapping is monotonically increasing, so:
    - Better scores (less negative) -> larger magnitudes
    - Worse scores (more negative) -> smaller magnitudes

    ``exp`` is positive for every real argument, so no sign flip is
    needed to keep magnitudes positive.

    Parameters
    ----------
    scale_factor : float
        Positive divisor applied to the score before exponentiation.
        Larger values flatten the mapping (magnitudes move closer
        to 1); smaller values spread magnitudes apart. Default: 10.0.

    Raises
    ------
    ValueError
        If ``scale_factor`` is not a positive number.

    Examples
    --------
    >>> strategy = MagnitudeStrategy()
    >>> strategy.supported_task_type
    'magnitude'
    """

    def __init__(self, scale_factor: float = 10.0) -> None:
        # `not x > 0` also rejects NaN, which `x <= 0` would let through
        if not scale_factor > 0:
            raise ValueError(f"scale_factor must be positive, got {scale_factor}")
        self.scale_factor = scale_factor

    @property
    def supported_task_type(self) -> str:
        """Return 'magnitude'.

        Returns
        -------
        str
            Task type identifier.
        """
        return "magnitude"

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate item for magnitude estimation.

        Checks:
        - task_type is 'magnitude'

        Model outputs are not required; ``simulate_response`` falls back
        to a random draw when none are available.

        Parameters
        ----------
        item : Item
            Item to validate (not inspected).
        item_template : ItemTemplate
            Template defining task.

        Raises
        ------
        ValueError
            If the template's task_type is not 'magnitude'.
        """
        if item_template.task_type != "magnitude":
            raise ValueError(
                f"Expected task_type 'magnitude', got '{item_template.task_type}'"
            )

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> float:
        """Generate magnitude estimation response.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining task (unused).
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator; only used when no score is available.

        Returns
        -------
        float
            Estimated magnitude (positive value).
        """
        # extract model output (expect single value)
        scores = self.extract_model_outputs(item, model_output_key, required_count=1)

        if scores is None:
            # fallback to random positive value (log-normal)
            return float(rng.lognormal(mean=0, sigma=1))

        # One monotone mapping for every sign of score. Branching on the
        # sign and negating negative scores would give exp(|score| / scale),
        # which assigns the *largest* magnitudes to the worst (most
        # negative) scores.
        return float(np.exp(scores[0] / self.scale_factor))