bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,312 @@
1
+ """HuggingFace masked language model adapter for template filling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from transformers import (
9
+ AutoModelForMaskedLM,
10
+ AutoTokenizer,
11
+ PreTrainedModel,
12
+ PreTrainedTokenizer,
13
+ )
14
+
15
+ from bead.adapters.huggingface import DeviceType, HuggingFaceAdapterMixin
16
+ from bead.templates.adapters.base import TemplateFillingModelAdapter
17
+
18
+
19
class HuggingFaceMLMAdapter(HuggingFaceAdapterMixin, TemplateFillingModelAdapter):
    """Adapter for HuggingFace masked language models.

    Supports BERT, RoBERTa, ALBERT, and other MLM architectures.
    Wraps ``AutoTokenizer``/``AutoModelForMaskedLM`` loading, device
    placement, and top-k mask-token prediction (single-text and batched).

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "bert-base-uncased")
    device : DeviceType
        Computation device ("cpu", "cuda", "mps")
    cache_dir : Path | None
        Directory for caching model files

    Examples
    --------
    >>> adapter = HuggingFaceMLMAdapter("bert-base-uncased", device="cpu")
    >>> adapter.load_model()
    >>> predictions = adapter.predict_masked_token(
    ...     text="The cat [MASK] on the mat",
    ...     mask_position=0,
    ...     top_k=5
    ... )
    >>> for token, log_prob in predictions:
    ...     print(f"{token}: {log_prob:.2f}")
    >>> adapter.unload_model()
    """

    def __init__(
        self,
        model_name: str,
        device: DeviceType = "cpu",
        cache_dir: Path | None = None,
    ) -> None:
        # validate device before passing to parent so an invalid device
        # fails fast here rather than at first tensor move
        # (_validate_device is presumably provided by HuggingFaceAdapterMixin
        # — confirm against bead.adapters.huggingface)
        validated_device = self._validate_device(device)
        super().__init__(model_name, validated_device, cache_dir)
        # both stay None until load_model() succeeds
        self.model: PreTrainedModel | None = None
        self.tokenizer: PreTrainedTokenizer | None = None

    def load_model(self) -> None:
        """Load model and tokenizer from HuggingFace.

        Idempotent: returns immediately if the model is already loaded.
        On success the model is moved to ``self.device`` and put in
        evaluation mode (disables dropout for deterministic inference).

        Raises
        ------
        RuntimeError
            If model loading fails (any underlying exception is chained)
        """
        # _model_loaded flag is assumed to be initialized by the mixin's
        # __init__ — TODO confirm
        if self._model_loaded:
            return

        try:
            # load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )

            # load model
            self.model = AutoModelForMaskedLM.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )

            # move to device
            self.model.to(self.device)

            # set to evaluation mode
            self.model.eval()

            self._model_loaded = True

        except Exception as e:
            raise RuntimeError(f"Failed to load model {self.model_name}: {e}") from e

    def unload_model(self) -> None:
        """Unload model from memory.

        Moves the model back to CPU before deleting so GPU memory is
        actually released, then clears the CUDA cache when applicable.
        No-op if nothing is loaded. NOTE(review): no equivalent cache
        clear is done for "mps" — confirm whether one is needed.
        """
        if not self._model_loaded:
            return

        # move model to CPU and delete
        if self.model is not None:
            self.model.to("cpu")
            del self.model
            self.model = None

        del self.tokenizer
        self.tokenizer = None

        self._model_loaded = False

        # clear CUDA cache if using GPU
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def predict_masked_token(
        self,
        text: str,
        mask_position: int,
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Predict masked token at specified position.

        Parameters
        ----------
        text : str
            Text with mask token (e.g., "The cat [MASK] quickly")
        mask_position : int
            Index among the mask tokens found in *text* (0-indexed) —
            i.e. 0 selects the first mask occurrence, not an absolute
            token position
        top_k : int
            Number of top predictions to return (capped at vocab size)

        Returns
        -------
        list[tuple[str, float]]
            List of (token, log_probability) tuples, sorted by
            descending probability; tokens are decoded with special
            tokens skipped and surrounding whitespace stripped

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If mask_position is invalid or text has no mask token
        """
        if not self._model_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # tokenize input
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # find mask token ID
        mask_token_id = self.tokenizer.mask_token_id
        if mask_token_id is None:
            raise ValueError(f"Model {self.model_name} does not have a mask token")

        # find mask position in tokenized input
        input_ids = inputs["input_ids"][0]
        mask_positions = (input_ids == mask_token_id).nonzero(as_tuple=True)[0]

        if len(mask_positions) == 0:
            raise ValueError(f"No mask token found in text: {text}")

        if mask_position >= len(mask_positions):
            raise ValueError(
                f"mask_position {mask_position} out of range. "
                f"Found {len(mask_positions)} mask tokens in text."
            )

        # get actual token index of the selected mask occurrence
        mask_idx = mask_positions[mask_position].item()

        # forward pass (no_grad: inference only, no autograd bookkeeping)
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        # get predictions for mask position
        mask_logits = logits[0, mask_idx]

        # convert to log probabilities
        log_probs = torch.log_softmax(mask_logits, dim=0)

        # get top-k predictions (k capped by vocabulary size)
        top_log_probs, top_indices = torch.topk(log_probs, k=min(top_k, len(log_probs)))

        # convert to tokens; move to CPU before iterating Python-side
        predictions: list[tuple[str, float]] = []
        for log_prob, idx in zip(top_log_probs.cpu(), top_indices.cpu(), strict=True):
            token = self.tokenizer.decode([idx], skip_special_tokens=True).strip()
            predictions.append((token, float(log_prob)))

        return predictions

    def predict_masked_token_batch(
        self,
        texts: list[str],
        mask_position: int = 0,
        top_k: int = 10,
    ) -> list[list[tuple[str, float]]]:
        """Predict masked tokens for multiple texts in a single batch.

        One forward pass covers the whole batch; per-text mask lookup
        and top-k extraction then happen on the shared logits tensor.

        Parameters
        ----------
        texts : list[str]
            List of texts with mask tokens; an empty list returns []
        mask_position : int
            Index among the mask tokens found in each text (0-indexed,
            relative to mask tokens found, not an absolute position)
        top_k : int
            Number of top predictions to return per text

        Returns
        -------
        list[list[tuple[str, float]]]
            List of predictions for each text, in input order. Each
            element is a list of (token, log_probability) tuples.

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If any text has no mask token, or mask_position is out of
            range for any text
        """
        if not self._model_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if not texts:
            return []

        # tokenize all texts with padding so they stack into one tensor
        # NOTE(review): truncation=True could drop the mask token from a
        # very long text, which then raises ValueError below — confirm
        # this is the intended failure mode
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # find mask token ID
        mask_token_id = self.tokenizer.mask_token_id
        if mask_token_id is None:
            raise ValueError(f"Model {self.model_name} does not have a mask token")

        # forward pass for entire batch
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits  # shape: (batch_size, seq_len, vocab_size)

        # process each text in batch
        results: list[list[tuple[str, float]]] = []
        for i, text in enumerate(texts):
            # find mask position in this text
            input_ids = inputs["input_ids"][i]
            mask_positions = (input_ids == mask_token_id).nonzero(as_tuple=True)[0]

            if len(mask_positions) == 0:
                raise ValueError(f"No mask token found in text: {text}")

            if mask_position >= len(mask_positions):
                raise ValueError(
                    f"mask_position {mask_position} out of range. "
                    f"Found {len(mask_positions)} mask tokens in text."
                )

            # get actual token index of the selected mask occurrence
            mask_idx = mask_positions[mask_position].item()

            # get predictions for this mask position
            mask_logits = logits[i, mask_idx]

            # convert to log probabilities
            log_probs = torch.log_softmax(mask_logits, dim=0)

            # get top-k predictions (k capped by vocabulary size)
            top_log_probs, top_indices = torch.topk(
                log_probs, k=min(top_k, len(log_probs))
            )

            # convert to tokens; move to CPU before iterating Python-side
            predictions: list[tuple[str, float]] = []
            for log_prob, idx in zip(
                top_log_probs.cpu(), top_indices.cpu(), strict=True
            ):
                token = self.tokenizer.decode([idx], skip_special_tokens=True).strip()
                predictions.append((token, float(log_prob)))

            results.append(predictions)

        return results

    def get_mask_token(self) -> str:
        """Get the mask token for this model.

        Returns
        -------
        str
            Mask token string (e.g., "[MASK]" for BERT)

        Raises
        ------
        RuntimeError
            If model is not loaded
        ValueError
            If model has no mask token
        """
        if not self._model_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        mask_token = self.tokenizer.mask_token
        if mask_token is None:
            raise ValueError(f"Model {self.model_name} does not have a mask token")

        return mask_token
@@ -0,0 +1,103 @@
1
+ """Combinatorial utilities for template filling."""
2
+
3
from __future__ import annotations

import itertools
import math
import random
from collections.abc import Iterator
from typing import TypeVar
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
+ def cartesian_product[T](*iterables: list[T]) -> Iterator[tuple[T, ...]]:
14
+ """Generate Cartesian product of iterables.
15
+
16
+ Equivalent to itertools.product but with explicit type hints
17
+ and documentation for template filling use case.
18
+
19
+ Parameters
20
+ ----------
21
+ *iterables : list[T]
22
+ Variable number of iterables to combine.
23
+
24
+ Yields
25
+ ------
26
+ tuple[T, ...]
27
+ Each combination from the Cartesian product.
28
+
29
+ Examples
30
+ --------
31
+ >>> list(cartesian_product([1, 2], ['a', 'b']))
32
+ [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
33
+ """
34
+ return itertools.product(*iterables)
35
+
36
+
37
+ def count_combinations[T](*iterables: list[T]) -> int:
38
+ """Count total combinations without generating them.
39
+
40
+ Calculate the size of the Cartesian product space efficiently
41
+ without actually generating combinations.
42
+
43
+ Parameters
44
+ ----------
45
+ *iterables : list[Any]
46
+ Variable number of iterables.
47
+
48
+ Returns
49
+ -------
50
+ int
51
+ Total number of combinations.
52
+
53
+ Examples
54
+ --------
55
+ >>> count_combinations([1, 2], ['a', 'b'], [True, False])
56
+ 8
57
+ """
58
+ count = 1
59
+ for iterable in iterables:
60
+ count *= len(iterable)
61
+ return count
62
+
63
+
64
+ def stratified_sample[T](
65
+ groups: dict[str, list[T]],
66
+ n_per_group: int,
67
+ seed: int | None = None,
68
+ ) -> list[T]:
69
+ """Sample items from groups with balanced representation.
70
+
71
+ Ensure each group is represented approximately equally in the sample.
72
+
73
+ Parameters
74
+ ----------
75
+ groups : dict[str, list[T]]
76
+ Dictionary mapping group names to lists of items.
77
+ n_per_group : int
78
+ Number of items to sample from each group.
79
+ seed : int | None
80
+ Random seed for reproducibility.
81
+
82
+ Returns
83
+ -------
84
+ list[T]
85
+ Sampled items, balanced across groups.
86
+
87
+ Examples
88
+ --------
89
+ >>> groups = {"verbs": [v1, v2, v3], "nouns": [n1, n2, n3]}
90
+ >>> sample = stratified_sample(groups, n_per_group=2, seed=42)
91
+ >>> len(sample)
92
+ 4
93
+ """
94
+ if seed is not None:
95
+ random.seed(seed)
96
+
97
+ sampled: list[T] = []
98
+ for group_items in groups.values():
99
+ # Sample with replacement if n_per_group > len(group_items)
100
+ k = min(n_per_group, len(group_items))
101
+ sampled.extend(random.sample(group_items, k))
102
+
103
+ return sampled