bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,129 @@
1
+ """Multi-select simulation strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+
16
class MultiSelectStrategy(SimulationStrategy):
    """Strategy for multi_select tasks.

    Handles tasks where any subset of the options may be selected. Every
    option is decided independently. When model outputs are available,
    option ``i`` is selected with probability

        P(select option i) = sigmoid(score_i / temperature)

    When no model outputs are available for the item, each option is
    instead selected independently with probability ``threshold``.

    Parameters
    ----------
    threshold : float
        Per-option selection probability used only by the random fallback
        (when model outputs are missing). It is not applied to model-based
        selections. Default: 0.5.
    temperature : float
        Positive value dividing each score before the sigmoid; larger values
        push selection probabilities toward 0.5. Default: 1.0.

    Raises
    ------
    ValueError
        If ``temperature`` is not positive.

    Examples
    --------
    >>> strategy = MultiSelectStrategy()
    >>> strategy.supported_task_type
    'multi_select'
    """

    def __init__(self, threshold: float = 0.5, temperature: float = 1.0) -> None:
        # Zero would divide by zero and a negative value would invert the
        # score ordering. ``not t > 0`` also rejects NaN.
        if not temperature > 0:
            msg = f"temperature must be positive, got {temperature}"
            raise ValueError(msg)
        self.threshold = threshold
        self.temperature = temperature

    @property
    def supported_task_type(self) -> str:
        """Return 'multi_select'.

        Returns
        -------
        str
            Task type identifier.
        """
        return "multi_select"

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Return the logistic sigmoid of ``x`` without overflow warnings.

        ``np.exp(-x)`` overflows for large negative ``x``. Branching on the
        sign keeps the exponent non-positive, so the result is still exact
        in the limits (0.0 / 1.0).
        """
        x = float(x)
        if x >= 0.0:
            return 1.0 / (1.0 + float(np.exp(-x)))
        z = float(np.exp(x))
        return z / (1.0 + z)

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate item for multi-select.

        Checks:
        - task_type is 'multi_select'
        - task_spec.options is defined
        - At least 2 options

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template defining task.

        Raises
        ------
        ValueError
            If validation fails.
        """
        if item_template.task_type != "multi_select":
            msg = f"Expected task_type 'multi_select', got '{item_template.task_type}'"
            raise ValueError(msg)

        if not item_template.task_spec.options:
            raise ValueError("task_spec.options must be defined for multi_select")

        if len(item_template.task_spec.options) < 2:
            raise ValueError("multi_select requires at least 2 options")

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> list[str]:
        """Generate multi-select response.

        Draws exactly one ``rng.random()`` value per option, in option order,
        on both the model-based and the fallback path.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining task.
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator.

        Returns
        -------
        list[str]
            Selected options, in the order they appear in
            ``task_spec.options``.

        Raises
        ------
        ValueError
            If ``task_spec.options`` is not defined, or if the number of
            scores does not match the number of options.
        """
        options = item_template.task_spec.options
        if options is None:
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            msg = "task_spec.options must be defined for multi_select"
            raise ValueError(msg)

        # One score per option.
        scores = self.extract_model_outputs(
            item, model_output_key, required_count=len(options)
        )

        if scores is None:
            # Fallback: select each option independently with probability
            # ``threshold``.
            return [option for option in options if rng.random() < self.threshold]

        # Independent Bernoulli draw per option with p = sigmoid(score / T).
        return [
            option
            for option, score in zip(options, scores, strict=True)
            if rng.random() < self._sigmoid(score / self.temperature)
        ]
@@ -0,0 +1,131 @@
1
+ """Ordinal scale simulation strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+
9
+ from bead.simulation.strategies.base import SimulationStrategy
10
+
11
+ if TYPE_CHECKING:
12
+ from bead.items.item import Item
13
+ from bead.items.item_template import ItemTemplate
14
+
15
+
16
class OrdinalScaleStrategy(SimulationStrategy):
    """Strategy for ordinal_scale tasks (Likert scales).

    Handles discrete ordinal scales (e.g., 1-7, 1-5) whose inclusive
    ``(min, max)`` bounds come from ``task_spec.scale_bounds``.

    When a model score is available, the response is deterministic:

    - squash the score into (0, 1) with a logistic sigmoid
    - map that value linearly onto ``[min, max]``
    - round to the nearest integer (``np.round``, ties to even) and clamp
      to the bounds

    No noise is added, so equal scores always yield the same rating. The
    random number generator is only used when model outputs are missing;
    a rating is then drawn uniformly from the scale.

    Examples
    --------
    >>> strategy = OrdinalScaleStrategy()
    >>> strategy.supported_task_type
    'ordinal_scale'
    """

    @property
    def supported_task_type(self) -> str:
        """Return 'ordinal_scale'.

        Returns
        -------
        str
            Task type identifier.
        """
        return "ordinal_scale"

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Return the logistic sigmoid of ``x`` without overflow warnings.

        ``np.exp(-x)`` overflows for large negative ``x``. Branching on the
        sign keeps the exponent non-positive.
        """
        x = float(x)
        if x >= 0.0:
            return 1.0 / (1.0 + float(np.exp(-x)))
        z = float(np.exp(x))
        return z / (1.0 + z)

    def validate_item(self, item: Item, item_template: ItemTemplate) -> None:
        """Validate item for ordinal scale.

        Checks:
        - task_type is 'ordinal_scale'
        - task_spec.scale_bounds is defined
        - scale_bounds has valid min/max

        Parameters
        ----------
        item : Item
            Item to validate.
        item_template : ItemTemplate
            Template defining task.

        Raises
        ------
        ValueError
            If validation fails.
        """
        if item_template.task_type != "ordinal_scale":
            msg = f"Expected task_type 'ordinal_scale', got '{item_template.task_type}'"
            raise ValueError(msg)

        if not item_template.task_spec.scale_bounds:
            msg = "task_spec.scale_bounds must be defined for ordinal_scale"
            raise ValueError(msg)

        min_val, max_val = item_template.task_spec.scale_bounds
        if min_val >= max_val:
            msg = f"scale_bounds min ({min_val}) must be less than max ({max_val})"
            raise ValueError(msg)

    def simulate_response(
        self,
        item: Item,
        item_template: ItemTemplate,
        model_output_key: str,
        rng: np.random.RandomState,
    ) -> int:
        """Generate ordinal scale response.

        Parameters
        ----------
        item : Item
            Item to respond to.
        item_template : ItemTemplate
            Template defining task.
        model_output_key : str
            Key for model outputs (e.g., "lm_score").
        rng : np.random.RandomState
            Random number generator. Only used when model outputs are
            missing.

        Returns
        -------
        int
            Rating within the inclusive scale bounds.

        Raises
        ------
        ValueError
            If ``task_spec.scale_bounds`` is not defined.
        """
        scale_bounds = item_template.task_spec.scale_bounds
        if scale_bounds is None:
            msg = "task_spec.scale_bounds must be defined"
            raise ValueError(msg)

        min_val, max_val = scale_bounds
        scale_range = max_val - min_val

        # Expect a single score.
        scores = self.extract_model_outputs(item, model_output_key, required_count=1)

        if scores is None:
            # Fallback: uniform over the inclusive scale (randint's upper
            # bound is exclusive, hence +1).
            return int(rng.randint(min_val, max_val + 1))

        # Map the unbounded score to [0, 1], then linearly onto the scale.
        position = self._sigmoid(scores[0])
        continuous_rating = min_val + position * scale_range

        # np.round rounds half to even (e.g. 2.5 -> 2).
        rating = int(np.round(continuous_rating))

        # Defensive clamp; the sigmoid output already lies within [0, 1].
        return max(min_val, min(max_val, rating))
@@ -0,0 +1,27 @@
1
+ """Template filling functionality.
2
+
3
+ Provides template filling strategies (exhaustive, random, stratified) and
4
+ constraint resolution for generating experimental stimuli.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from bead.templates.filler import CSPFiller, FilledTemplate, TemplateFiller
10
+ from bead.templates.resolver import ConstraintResolver
11
+ from bead.templates.strategies import (
12
+ ExhaustiveStrategy,
13
+ RandomStrategy,
14
+ StrategyFiller,
15
+ StratifiedStrategy,
16
+ )
17
+
18
# Public names re-exported by bead.templates, in their original order.
# The trailing comment on each entry gives the module it is imported from.
__all__ = [
    "TemplateFiller",  # abstract base class (bead.templates.filler)
    "CSPFiller",  # bead.templates.filler
    "StrategyFiller",  # bead.templates.strategies
    "FilledTemplate",  # bead.templates.filler
    "ConstraintResolver",  # bead.templates.resolver
    "ExhaustiveStrategy",  # bead.templates.strategies
    "RandomStrategy",  # bead.templates.strategies
    "StratifiedStrategy",  # bead.templates.strategies
]
@@ -0,0 +1,17 @@
1
+ """Template filling model adapters.
2
+
3
+ Provides masked language model adapters for template filling (Stage 2).
4
+ Separate from judgment prediction models (Stage 3).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from bead.templates.adapters.base import TemplateFillingModelAdapter
10
+ from bead.templates.adapters.cache import ModelOutputCache
11
+ from bead.templates.adapters.huggingface import HuggingFaceMLMAdapter
12
+
13
# Public names re-exported by bead.templates.adapters; the trailing comment
# on each entry gives the module it is imported from.
__all__ = [
    "TemplateFillingModelAdapter",  # bead.templates.adapters.base
    "ModelOutputCache",  # bead.templates.adapters.cache
    "HuggingFaceMLMAdapter",  # bead.templates.adapters.huggingface
]
@@ -0,0 +1,128 @@
1
+ """Base adapter for template filling models.
2
+
3
+ This module defines the abstract interface for models used in template filling.
4
+ These adapters are SEPARATE from judgment prediction model adapters (Stage 6).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from abc import ABC, abstractmethod
10
+ from pathlib import Path
11
+
12
+
13
class TemplateFillingModelAdapter(ABC):
    """Abstract interface for models used to fill template slots.

    These adapters serve template filling only. The adapters that predict
    human judgments later in the pipeline form a separate hierarchy.

    Parameters
    ----------
    model_name : str
        Model identifier, e.g. ``"bert-base-uncased"``.
    device : str, default "cpu"
        Computation device such as ``"cpu"``, ``"cuda"`` or ``"mps"``.
    cache_dir : Path | None, default None
        Directory for caching model files, if any.

    Notes
    -----
    ``__init__`` sets ``_model_loaded = False``, and nothing in this base
    class ever sets it to ``True``. ``is_loaded`` just reports that flag, so
    a concrete adapter has to update it for ``is_loaded`` to reflect reality.

    Used in a ``with`` statement, the adapter calls ``load_model()`` on
    entry and ``unload_model()`` on exit, including when the body raises.
    Exceptions are never suppressed.

    Examples
    --------
    >>> from bead.templates.adapters import HuggingFaceMLMAdapter
    >>> adapter = HuggingFaceMLMAdapter("bert-base-uncased", device="cpu")
    >>> adapter.load_model()
    >>> predictions = adapter.predict_masked_token(
    ...     text="The cat [MASK] on the mat",
    ...     mask_position=2,
    ...     top_k=5,
    ... )
    >>> adapter.unload_model()
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cpu",
        cache_dir: Path | None = None,
    ) -> None:
        self.model_name, self.device, self.cache_dir = model_name, device, cache_dir
        # Loaded-state flag reported by is_loaded(); starts unloaded.
        self._model_loaded = False

    @abstractmethod
    def load_model(self) -> None:
        """Bring the model into memory.

        Raises
        ------
        RuntimeError
            When the model cannot be loaded.
        """
        ...

    @abstractmethod
    def unload_model(self) -> None:
        """Release the model and the resources it holds."""
        ...

    @abstractmethod
    def predict_masked_token(
        self,
        text: str,
        mask_position: int,
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Return the best candidates for the masked token.

        Parameters
        ----------
        text : str
            Input containing a mask token, e.g. ``"The cat [MASK] quickly"``.
        mask_position : int
            Zero-based token index of the mask.
        top_k : int, default 10
            How many candidates to return.

        Returns
        -------
        list[tuple[str, float]]
            ``(token, log_probability)`` pairs ordered by probability,
            most likely first.

        Raises
        ------
        RuntimeError
            When the model has not been loaded.
        ValueError
            When ``mask_position`` is invalid.

        Examples
        --------
        >>> predictions = adapter.predict_masked_token(
        ...     text="The cat [MASK] on the mat",
        ...     mask_position=2,
        ...     top_k=3,
        ... )
        >>> predictions
        [("sat", -0.5), ("slept", -1.2), ("jumped", -1.5)]
        """
        ...

    def is_loaded(self) -> bool:
        """Report whether the model is currently in memory.

        Returns
        -------
        bool
            The value of the ``_model_loaded`` flag.
        """
        return self._model_loaded

    def __enter__(self) -> TemplateFillingModelAdapter:
        """Load the model and hand this adapter to the ``with`` body."""
        self.load_model()
        return self

    def __exit__(self, *args: object) -> None:
        """Unload the model; returning None lets any exception propagate."""
        self.unload_model()
@@ -0,0 +1,178 @@
1
+ """Content-addressable cache for model predictions.
2
+
3
+ This module implements caching for template filling model predictions
4
+ using SHA256-based content addressing.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import json
11
+ from pathlib import Path
12
+
13
+
14
class ModelOutputCache:
    """Content-addressable on-disk cache for model predictions.

    Each entry is stored as ``<sha256>.json`` inside ``cache_dir``. The
    SHA256 digest is computed over a sorted-key JSON encoding of:

    - Model name
    - Input text
    - Mask position
    - Top-K parameter

    The cache is best-effort. Missing, unreadable or malformed entries are
    reported as misses, and filesystem errors while writing are ignored.

    Parameters
    ----------
    cache_dir : Path
        Directory for cache storage (a ``str`` path is also accepted).
        Created with its parents on construction when caching is enabled.
    enabled : bool
        When ``False``, ``get`` always misses and ``set``/``clear`` do
        nothing. Default: ``True``.

    Examples
    --------
    >>> cache = ModelOutputCache(cache_dir=Path("/tmp/cache"), enabled=True)
    >>> key_args = ("bert-base-uncased", "The cat [MASK]", 2, 10)
    >>> predictions = cache.get(*key_args)
    >>> if predictions is None:
    ...     predictions = model.predict(...)
    ...     cache.set(*key_args, predictions)
    """

    def __init__(self, cache_dir: Path, enabled: bool = True) -> None:
        # Path() accepts str paths and is a no-op for Path inputs.
        self.cache_dir = Path(cache_dir)
        self.enabled = enabled

        if self.enabled:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _compute_key(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
    ) -> str:
        """Compute cache key from inputs.

        Parameters
        ----------
        model_name : str
            Model identifier
        input_text : str
            Input text with mask
        mask_position : int
            Position of mask token
        top_k : int
            Number of predictions

        Returns
        -------
        str
            SHA256 hex digest
        """
        key_data = {
            "model_name": model_name,
            "input_text": input_text,
            "mask_position": mask_position,
            "top_k": top_k,
        }

        # Sorted keys make the serialization, and therefore the hash,
        # deterministic.
        key_json = json.dumps(key_data, sort_keys=True)

        return hashlib.sha256(key_json.encode("utf-8")).hexdigest()

    def _cache_path(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
    ) -> Path:
        """Return the JSON file that holds the entry for these inputs."""
        cache_key = self._compute_key(model_name, input_text, mask_position, top_k)
        return self.cache_dir / f"{cache_key}.json"

    def get(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
    ) -> list[tuple[str, float]] | None:
        """Get cached predictions.

        Parameters
        ----------
        model_name : str
            Model identifier
        input_text : str
            Input text
        mask_position : int
            Mask position
        top_k : int
            Number of predictions

        Returns
        -------
        list[tuple[str, float]] | None
            Cached ``(token, log_prob)`` pairs, or None when caching is
            disabled or the entry is missing, unreadable or malformed.
        """
        if not self.enabled:
            return None

        cache_file = self._cache_path(model_name, input_text, mask_position, top_k)

        try:
            with open(cache_file, encoding="utf-8") as f:
                data = json.load(f)
            return [(entry["token"], entry["log_prob"]) for entry in data]
        except (OSError, ValueError, KeyError, TypeError):
            # Treat any failure as a miss:
            # - OSError covers a missing file and read errors.
            # - ValueError covers JSONDecodeError and UnicodeDecodeError.
            # - KeyError/TypeError cover valid JSON of the wrong shape, such
            #   as an object or a list of numbers.
            return None

    def set(
        self,
        model_name: str,
        input_text: str,
        mask_position: int,
        top_k: int,
        predictions: list[tuple[str, float]],
    ) -> None:
        """Store predictions in cache.

        The entry is written to a temporary file in ``cache_dir`` and then
        renamed over the final path, so concurrent readers never observe a
        partially written entry. Filesystem errors are ignored (the cache
        is best-effort).

        Parameters
        ----------
        model_name : str
            Model identifier
        input_text : str
            Input text
        mask_position : int
            Mask position
        top_k : int
            Number of predictions
        predictions : list[tuple[str, float]]
            ``(token, log_prob)`` pairs to cache. Numeric scalars such as
            numpy floats are converted to plain floats.
        """
        if not self.enabled:
            return

        import tempfile

        cache_file = self._cache_path(model_name, input_text, mask_position, top_k)

        # float() converts numpy scalars (e.g. float32), which json cannot
        # encode. It is a no-op for Python floats.
        data = [
            {"token": token, "log_prob": float(log_prob)}
            for token, log_prob in predictions
        ]

        tmp_path: Path | None = None
        try:
            # The ".tmp" suffix keeps in-flight files out of the "*.json"
            # entries that get() and clear() look at.
            with tempfile.NamedTemporaryFile(
                "w",
                encoding="utf-8",
                dir=self.cache_dir,
                prefix=f"{cache_file.stem}.",
                suffix=".tmp",
                delete=False,
            ) as tmp:
                tmp_path = Path(tmp.name)
                json.dump(data, tmp, indent=2)
            # Rename only after the handle is closed (required on Windows).
            tmp_path.replace(cache_file)
            tmp_path = None
        except OSError:
            # Best-effort cache: write failures are deliberately ignored.
            pass
        finally:
            # Remove the temporary file if it was not renamed into place.
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    def clear(self) -> None:
        """Delete every ``*.json`` entry in ``cache_dir``.

        Does nothing when caching is disabled or the directory does not
        exist. Files that cannot be removed are skipped.
        """
        if not self.enabled or not self.cache_dir.exists():
            return

        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
            except OSError:
                # Best-effort: skip entries that vanished or cannot be removed.
                pass