bead 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/cache.py ADDED
@@ -0,0 +1,558 @@
1
+ """Content-addressable cache for judgment model outputs.
2
+
3
+ This module provides caching infrastructure for model outputs during item
4
+ construction. It supports multiple backends (filesystem, in-memory) and various
5
+ operation types including log probabilities, NLI scores, embeddings, and
6
+ similarity metrics.
7
+
8
+ Note: This cache is distinct from bead.templates.adapters.cache, which handles
9
+ MLM predictions for template filling. This module caches judgment model outputs
10
+ used in item construction.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import hashlib
16
+ import json
17
+ import logging
18
+ from abc import ABC, abstractmethod
19
+ from datetime import UTC, datetime
20
+ from pathlib import Path
21
+ from typing import Any, Literal
22
+
23
+ import numpy as np
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class CacheBackend(ABC):
    """Interface shared by every cache storage backend.

    A backend is a flat key/value store whose keys are strings and whose
    values are ``dict[str, object]`` entries. Concrete subclasses must
    implement all five abstract methods declared here.
    """

    @abstractmethod
    def get(self, key: str) -> dict[str, object] | None:
        """Look up the entry stored under ``key``.

        Parameters
        ----------
        key
            Cache key to look up.

        Returns
        -------
        dict[str, object] | None
            The stored entry, or None when nothing is stored under ``key``.
        """
        ...

    @abstractmethod
    def set(self, key: str, data: dict[str, object]) -> None:
        """Store ``data`` under ``key``.

        Parameters
        ----------
        key
            Cache key to write.
        data
            Entry to associate with ``key``.
        """
        ...

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove the entry stored under ``key``.

        Parameters
        ----------
        key
            Cache key whose entry should be removed.
        """
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove every entry held by the backend."""
        ...

    @abstractmethod
    def keys(self) -> list[str]:
        """List the keys currently held by the backend.

        Returns
        -------
        list[str]
            Every cache key present in the backend.
        """
        ...
89
+
90
+
91
class FilesystemBackend(CacheBackend):
    """Filesystem-based cache backend.

    Stores each cache entry as a separate JSON file named ``<key>.json``
    inside ``cache_dir``. Writes go to a uniquely named temporary file in the
    same directory, which is then renamed over the target. A crash or a
    serialization error therefore never leaves a truncated entry behind.

    Parameters
    ----------
    cache_dir : Path
        Directory for cache storage. It is created (with parents) if it does
        not exist. A ``str`` path is also accepted.

    Attributes
    ----------
    cache_dir : Path
        Directory where cache files are stored.

    Examples
    --------
    >>> from pathlib import Path
    >>> backend = FilesystemBackend(cache_dir=Path(".cache"))
    >>> backend.set("abc123", {"result": 42})
    >>> backend.get("abc123")
    {'result': 42}
    """

    def __init__(self, cache_dir: Path) -> None:
        # Path() is a no-op for Path inputs and also lets callers pass a str.
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _path_for(self, key: str) -> Path:
        """Return the JSON file path used to store ``key``."""
        return self.cache_dir / f"{key}.json"

    def get(self, key: str) -> dict[str, object] | None:
        """Retrieve cache entry from filesystem.

        Parameters
        ----------
        key
            Cache key.

        Returns
        -------
        dict[str, object] | None
            Cache entry data if found, None otherwise. None is also returned
            when the file cannot be read, is not valid UTF-8/JSON, or does
            not contain a JSON object. Those cases log a warning; a plain
            missing file does not.
        """
        cache_file = self._path_for(key)
        try:
            # EAFP: opening directly avoids the race between an exists()
            # check and open() when another process deletes the file.
            with cache_file.open(encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return None
        except (ValueError, OSError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            # raised by a corrupted file.
            logger.warning("Failed to read cache file %s: %s", cache_file, e)
            return None
        if not isinstance(data, dict):
            logger.warning(
                "Ignoring cache file %s: expected a JSON object, got %s",
                cache_file,
                type(data).__name__,
            )
            return None
        return data

    def set(self, key: str, data: dict[str, object]) -> None:
        """Store cache entry to filesystem.

        Failures to write (OSError) are logged and swallowed. Errors from
        serializing ``data`` (e.g. TypeError for non-JSON values) propagate
        to the caller, and any existing entry for ``key`` is left intact.

        Parameters
        ----------
        key
            Cache key.
        data
            Cache entry data; must be JSON-serializable.
        """
        import uuid  # local: only needed to name the temporary file

        cache_file = self._path_for(key)
        # A unique name per call keeps concurrent writers from sharing a temp
        # file. The ".tmp" suffix keeps it out of keys()/clear(), which only
        # match "*.json".
        tmp_file = cache_file.with_name(f".{cache_file.name}.{uuid.uuid4().hex}.tmp")
        written = False
        try:
            with tmp_file.open("x", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            # Path.replace overwrites the target in one rename (os.replace),
            # so readers see either the old entry or the new one.
            tmp_file.replace(cache_file)
            written = True
        except OSError as e:
            logger.warning("Failed to write cache file %s: %s", cache_file, e)
        finally:
            if not written:
                # Drop the partial temp file after any failure, whether it was
                # an OSError logged above or a serialization error still
                # propagating.
                try:
                    tmp_file.unlink(missing_ok=True)
                except OSError:
                    pass

    def delete(self, key: str) -> None:
        """Delete cache entry from filesystem.

        Deleting a key that is not present is not an error.

        Parameters
        ----------
        key
            Cache key to delete.
        """
        cache_file = self._path_for(key)
        try:
            cache_file.unlink(missing_ok=True)
        except OSError as e:
            logger.warning("Failed to delete cache file %s: %s", cache_file, e)

    def clear(self) -> None:
        """Clear all cache entries from filesystem.

        Every ``*.json`` file in ``cache_dir`` is removed. A file that cannot
        be deleted is logged and skipped, so the remaining entries are still
        cleared.
        """
        try:
            # Snapshot the listing so the directory is not mutated mid-scan.
            cache_files = list(self.cache_dir.glob("*.json"))
        except OSError as e:
            logger.warning("Failed to clear cache directory %s: %s", self.cache_dir, e)
            return
        for cache_file in cache_files:
            try:
                cache_file.unlink(missing_ok=True)
            except OSError as e:
                logger.warning("Failed to delete cache file %s: %s", cache_file, e)

    def keys(self) -> list[str]:
        """Return all cache keys from filesystem.

        Returns
        -------
        list[str]
            List of cache keys (filenames without the .json extension).
            An empty list is returned (with a warning) if the directory
            cannot be listed.
        """
        try:
            return [f.stem for f in self.cache_dir.glob("*.json")]
        except OSError as e:
            logger.warning("Failed to list cache keys in %s: %s", self.cache_dir, e)
            return []
196
+
197
+
198
class InMemoryBackend(CacheBackend):
    """Cache backend backed by a plain ``dict`` living in process memory.

    Nothing is written to disk, so entries are lost when the program exits.
    Entries are stored and handed back by reference, with no copying.
    Mutating a dict after passing it to ``set``, or after receiving it from
    ``get``, therefore also changes the cached entry. Intended for tests and
    short-lived caching.

    Examples
    --------
    >>> backend = InMemoryBackend()
    >>> backend.set("xyz789", {"result": 3.14})
    >>> backend.get("xyz789")
    {'result': 3.14}
    """

    def __init__(self) -> None:
        # key -> entry; the sole storage for this backend.
        self._cache: dict[str, dict[str, object]] = {}

    def get(self, key: str) -> dict[str, object] | None:
        """Fetch the entry stored under ``key``.

        Parameters
        ----------
        key
            Cache key to look up.

        Returns
        -------
        dict[str, object] | None
            The stored entry object itself, or None if ``key`` is absent.
        """
        try:
            return self._cache[key]
        except KeyError:
            return None

    def set(self, key: str, data: dict[str, object]) -> None:
        """Remember ``data`` under ``key``, replacing any previous entry.

        Parameters
        ----------
        key
            Cache key to write.
        data
            Entry to store (kept by reference).
        """
        self._cache.update({key: data})

    def delete(self, key: str) -> None:
        """Forget the entry under ``key``; a missing key is silently ignored.

        Parameters
        ----------
        key
            Cache key whose entry should be removed.
        """
        if key in self._cache:
            del self._cache[key]

    def clear(self) -> None:
        """Empty the in-memory store."""
        self._cache.clear()

    def keys(self) -> list[str]:
        """List every key currently held in memory.

        Returns
        -------
        list[str]
            A new list containing the stored keys in insertion order.
        """
        return [*self._cache]
265
+
266
+
267
+ class ModelOutputCache:
268
+ """Content-addressable cache for judgment model outputs.
269
+
270
+ Caches results from various model operations to avoid redundant computation.
271
+ Supports multiple operation types including log probabilities, perplexity,
272
+ NLI scores, embeddings, and similarity metrics.
273
+
274
+ Cache keys are automatically generated using SHA-256 hashing of the model
275
+ name, operation type, and all input parameters, ensuring deterministic
276
+ cache hits for identical inputs.
277
+
278
+ Parameters
279
+ ----------
280
+ cache_dir : Path | None
281
+ Directory for cache files (filesystem backend only).
282
+ Defaults to ~/.cache/bead/models if not specified.
283
+ backend : {"filesystem", "memory"}
284
+ Cache backend type. "filesystem" persists across runs,
285
+ "memory" is ephemeral.
286
+ enabled : bool
287
+ Whether caching is enabled.
288
+
289
+ Attributes
290
+ ----------
291
+ enabled : bool
292
+ Whether caching is enabled. When False, all operations are no-ops.
293
+
294
+ Examples
295
+ --------
296
+ Basic usage with filesystem backend:
297
+
298
+ >>> from pathlib import Path
299
+ >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
300
+ >>> result = cache.get("gpt2", "log_probability", text="Hello world")
301
+ >>> if result is None:
302
+ ... result = -2.5
303
+ ... cache.set("gpt2", "log_probability", result, text="Hello world")
304
+
305
+ Caching NLI scores:
306
+
307
+ >>> nli_scores = cache.get("roberta-nli", "nli",
308
+ ... premise="Mary loves books",
309
+ ... hypothesis="Mary enjoys reading")
310
+ >>> if nli_scores is None:
311
+ ... nli_scores = {"entailment": 0.9, "neutral": 0.08, "contradiction": 0.02}
312
+ ... cache.set("roberta-nli", "nli", nli_scores,
313
+ ... premise="Mary loves books", hypothesis="Mary enjoys reading")
314
+
315
+ Caching embeddings:
316
+
317
+ >>> import numpy as np
318
+ >>> embedding = cache.get("bert-base", "embedding", text="Hello")
319
+ >>> if embedding is None:
320
+ ... embedding = np.random.rand(768)
321
+ ... cache.set("bert-base", "embedding", embedding, text="Hello")
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ cache_dir: Path | None = None,
327
+ backend: Literal["filesystem", "memory"] = "filesystem",
328
+ enabled: bool = True,
329
+ ) -> None:
330
+ self.enabled = enabled
331
+
332
+ if backend == "filesystem":
333
+ if cache_dir is None:
334
+ cache_dir = Path.home() / ".cache" / "bead" / "models"
335
+ self._backend: CacheBackend = FilesystemBackend(cache_dir)
336
+ elif backend == "memory":
337
+ self._backend = InMemoryBackend()
338
+ else:
339
+ raise ValueError(f"Unknown backend: {backend}")
340
+
341
+ def generate_cache_key(
342
+ self, model_name: str, operation: str, **inputs: str | int | float | bool | None
343
+ ) -> str:
344
+ """Generate deterministic cache key from inputs.
345
+
346
+ Parameters
347
+ ----------
348
+ model_name
349
+ Model identifier.
350
+ operation
351
+ Operation type (e.g., "log_probability", "embedding").
352
+ **inputs
353
+ Input parameters for the operation (text, premise, hypothesis).
354
+
355
+ Returns
356
+ -------
357
+ str
358
+ SHA-256 hex digest as cache key.
359
+ """
360
+ # create deterministic dict with sorted keys
361
+ key_data = {
362
+ "model_name": model_name,
363
+ "operation": operation,
364
+ "inputs": self._serialize_for_hash(inputs),
365
+ }
366
+
367
+ # json with sorted keys for determinism
368
+ key_json = json.dumps(key_data, sort_keys=True)
369
+
370
+ # sha-256 hash
371
+ return hashlib.sha256(key_json.encode("utf-8")).hexdigest()
372
+
373
+ def _serialize_for_hash(self, obj: object) -> object:
374
+ """Serialize object for deterministic hashing.
375
+
376
+ Converts numpy arrays to lists and sorts dict keys.
377
+
378
+ Parameters
379
+ ----------
380
+ obj
381
+ Object to serialize. Accepts numpy arrays, dicts, lists, tuples,
382
+ and primitive types.
383
+
384
+ Returns
385
+ -------
386
+ object
387
+ JSON-serializable version of the object.
388
+ """
389
+ if isinstance(obj, np.ndarray):
390
+ return obj.tolist()
391
+ elif isinstance(obj, dict):
392
+ return {k: self._serialize_for_hash(v) for k, v in sorted(obj.items())} # type: ignore[misc]
393
+ elif isinstance(obj, list | tuple):
394
+ return [self._serialize_for_hash(item) for item in obj] # type: ignore[misc]
395
+ else:
396
+ return obj
397
+
398
+ def _serialize_result(self, result: object) -> object:
399
+ """Serialize result for storage.
400
+
401
+ Parameters
402
+ ----------
403
+ result
404
+ Result to serialize. Accepts numpy arrays, dicts, lists, tuples,
405
+ and primitive types.
406
+
407
+ Returns
408
+ -------
409
+ object
410
+ JSON-serializable version of result.
411
+ """
412
+ if isinstance(result, np.ndarray):
413
+ return {
414
+ "__type__": "ndarray",
415
+ "data": result.tolist(),
416
+ "dtype": str(result.dtype), # type: ignore[arg-type]
417
+ }
418
+ elif isinstance(result, dict):
419
+ return {k: self._serialize_result(v) for k, v in result.items()} # type: ignore[misc]
420
+ elif isinstance(result, list | tuple):
421
+ return [self._serialize_result(item) for item in result] # type: ignore[misc]
422
+ else:
423
+ return result
424
+
425
+ def _deserialize_result(self, result: Any) -> Any:
426
+ """Deserialize result from storage.
427
+
428
+ Parameters
429
+ ----------
430
+ result
431
+ Serialized result from cache storage.
432
+
433
+ Returns
434
+ -------
435
+ Any
436
+ Deserialized result with numpy arrays restored.
437
+ """
438
+ if isinstance(result, dict):
439
+ if result.get("__type__") == "ndarray": # type: ignore[union-attr]
440
+ return np.array(result["data"], dtype=result["dtype"]) # type: ignore[arg-type]
441
+ else:
442
+ return {k: self._deserialize_result(v) for k, v in result.items()} # type: ignore[misc]
443
+ elif isinstance(result, list):
444
+ return [self._deserialize_result(item) for item in result] # type: ignore[misc]
445
+ else:
446
+ return result
447
+
448
+ def get(
449
+ self, model_name: str, operation: str, **inputs: str | int | float | bool | None
450
+ ) -> Any:
451
+ """Retrieve cached result.
452
+
453
+ Parameters
454
+ ----------
455
+ model_name
456
+ Model identifier.
457
+ operation
458
+ Operation type (e.g., "log_probability", "nli", "embedding").
459
+ **inputs
460
+ Input parameters for the operation (text, premise, hypothesis).
461
+
462
+ Returns
463
+ -------
464
+ Any
465
+ Cached result if found, None otherwise.
466
+ """
467
+ if not self.enabled:
468
+ return None
469
+
470
+ cache_key = self.generate_cache_key(model_name, operation, **inputs)
471
+ entry = self._backend.get(cache_key)
472
+
473
+ if entry is None:
474
+ return None
475
+
476
+ # deserialize and return result
477
+ return self._deserialize_result(entry["result"])
478
+
479
+ def set(
480
+ self,
481
+ model_name: str,
482
+ operation: str,
483
+ result: float | dict[str, float] | list[float] | np.ndarray,
484
+ model_version: str | None = None,
485
+ **inputs: str | int | float | bool | None,
486
+ ) -> None:
487
+ """Store result in cache.
488
+
489
+ Parameters
490
+ ----------
491
+ model_name
492
+ Model identifier.
493
+ operation
494
+ Operation type (e.g., "log_probability", "nli", "embedding").
495
+ result
496
+ Result to cache (log probability, NLI scores, embedding, etc.).
497
+ model_version
498
+ Optional model version string for tracking.
499
+ **inputs
500
+ Input parameters for the operation (text, premise, hypothesis).
501
+ """
502
+ if not self.enabled:
503
+ return
504
+
505
+ cache_key = self.generate_cache_key(model_name, operation, **inputs)
506
+
507
+ # create cache entry with metadata
508
+ entry = {
509
+ "cache_key": cache_key,
510
+ "timestamp": datetime.now(UTC).isoformat(),
511
+ "model_name": model_name,
512
+ "model_version": model_version,
513
+ "operation": operation,
514
+ "inputs": self._serialize_for_hash(inputs),
515
+ "result": self._serialize_result(result),
516
+ }
517
+
518
+ self._backend.set(cache_key, entry)
519
+
520
+ def invalidate(
521
+ self, model_name: str, operation: str, **inputs: str | int | float | bool | None
522
+ ) -> None:
523
+ """Invalidate specific cache entry.
524
+
525
+ Parameters
526
+ ----------
527
+ model_name
528
+ Model identifier.
529
+ operation
530
+ Operation type.
531
+ **inputs
532
+ Input parameters for the operation.
533
+ """
534
+ cache_key = self.generate_cache_key(model_name, operation, **inputs)
535
+ self._backend.delete(cache_key)
536
+
537
+ def clear_model(self, model_name: str) -> None:
538
+ """Clear all cache entries for a specific model.
539
+
540
+ Parameters
541
+ ----------
542
+ model_name : str
543
+ Model identifier.
544
+ """
545
+ # get all keys and filter by model name
546
+ keys_to_delete: list[str] = []
547
+ for key in self._backend.keys():
548
+ entry = self._backend.get(key)
549
+ if entry and entry.get("model_name") == model_name:
550
+ keys_to_delete.append(key)
551
+
552
+ # delete matching entries
553
+ for key in keys_to_delete:
554
+ self._backend.delete(key)
555
+
556
+ def clear(self) -> None:
557
+ """Clear all cache entries."""
558
+ self._backend.clear()