bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/cache.py
ADDED
@@ -0,0 +1,558 @@
"""Content-addressable cache for judgment model outputs.

This module provides caching infrastructure for model outputs during item
construction. It supports multiple backends (filesystem, in-memory) and various
operation types including log probabilities, NLI scores, embeddings, and
similarity metrics.

Note: This cache is distinct from bead.templates.adapters.cache, which handles
MLM predictions for template filling. This module caches judgment model outputs
used in item construction.
"""

from __future__ import annotations

import hashlib
import json
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Literal

import numpy as np

logger = logging.getLogger(__name__)


class CacheBackend(ABC):
    """Abstract base class for cache backends.

    Defines the interface that all cache backends must implement.
    """

    @abstractmethod
    def get(self, key: str) -> dict[str, object] | None:
        """Retrieve cache entry by key.

        Parameters
        ----------
        key
            Cache key to retrieve.

        Returns
        -------
        dict[str, object] | None
            Cache entry data if found, None otherwise.
        """
        pass

    @abstractmethod
    def set(self, key: str, data: dict[str, object]) -> None:
        """Store cache entry with key.

        Parameters
        ----------
        key
            Cache key.
        data
            Cache entry data to store.
        """
        pass

    @abstractmethod
    def delete(self, key: str) -> None:
        """Delete cache entry by key.

        Parameters
        ----------
        key
            Cache key to delete.
        """
        pass

    @abstractmethod
    def clear(self) -> None:
        """Clear all cache entries."""
        pass

    @abstractmethod
    def keys(self) -> list[str]:
        """Return all cache keys.

        Returns
        -------
        list[str]
            List of all cache keys in the backend.
        """
        pass


class FilesystemBackend(CacheBackend):
    """Filesystem-based cache backend.

    Stores each cache entry as a separate JSON file with the cache key as
    the filename.

    Parameters
    ----------
    cache_dir : Path
        Directory for cache storage.

    Attributes
    ----------
    cache_dir : Path
        Directory where cache files are stored.

    Examples
    --------
    >>> from pathlib import Path
    >>> backend = FilesystemBackend(cache_dir=Path(".cache"))
    >>> backend.set("abc123", {"result": 42})
    >>> backend.get("abc123")
    {'result': 42}
    """

    def __init__(self, cache_dir: Path) -> None:
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get(self, key: str) -> dict[str, object] | None:
        """Retrieve cache entry from filesystem.

        Parameters
        ----------
        key
            Cache key.

        Returns
        -------
        dict[str, object] | None
            Cache entry data if found, None otherwise.
        """
        cache_file = self.cache_dir / f"{key}.json"
        try:
            if cache_file.exists():
                with open(cache_file, encoding="utf-8") as f:
                    return json.load(f)
            return None
        except (json.JSONDecodeError, OSError) as e:
            logger.warning(f"Failed to read cache file {cache_file}: {e}")
            return None

    def set(self, key: str, data: dict[str, object]) -> None:
        """Store cache entry to filesystem.

        Parameters
        ----------
        key
            Cache key.
        data
            Cache entry data.
        """
        cache_file = self.cache_dir / f"{key}.json"
        try:
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
        except OSError as e:
            logger.warning(f"Failed to write cache file {cache_file}: {e}")

    def delete(self, key: str) -> None:
        """Delete cache entry from filesystem.

        Parameters
        ----------
        key
            Cache key to delete.
        """
        cache_file = self.cache_dir / f"{key}.json"
        try:
            if cache_file.exists():
                cache_file.unlink()
        except OSError as e:
            logger.warning(f"Failed to delete cache file {cache_file}: {e}")

    def clear(self) -> None:
        """Clear all cache entries from filesystem."""
        try:
            for cache_file in self.cache_dir.glob("*.json"):
                cache_file.unlink()
        except OSError as e:
            logger.warning(f"Failed to clear cache directory {self.cache_dir}: {e}")

    def keys(self) -> list[str]:
        """Return all cache keys from filesystem.

        Returns
        -------
        list[str]
            List of cache keys (filenames without .json extension).
        """
        try:
            return [f.stem for f in self.cache_dir.glob("*.json")]
        except OSError as e:
            logger.warning(f"Failed to list cache keys in {self.cache_dir}: {e}")
            return []


class InMemoryBackend(CacheBackend):
    """In-memory cache backend.

    Stores cache entries in a dictionary. No persistence across program runs.
    Useful for testing and temporary caching scenarios.

    Examples
    --------
    >>> backend = InMemoryBackend()
    >>> backend.set("xyz789", {"result": 3.14})
    >>> backend.get("xyz789")
    {'result': 3.14}
    """

    def __init__(self) -> None:
        self._cache: dict[str, dict[str, object]] = {}

    def get(self, key: str) -> dict[str, object] | None:
        """Retrieve cache entry from memory.

        Parameters
        ----------
        key
            Cache key.

        Returns
        -------
        dict[str, object] | None
            Cache entry data if found, None otherwise.
        """
        return self._cache.get(key)

    def set(self, key: str, data: dict[str, object]) -> None:
        """Store cache entry in memory.

        Parameters
        ----------
        key
            Cache key.
        data
            Cache entry data.
        """
        self._cache[key] = data

    def delete(self, key: str) -> None:
        """Delete cache entry from memory.

        Parameters
        ----------
        key
            Cache key to delete.
        """
        self._cache.pop(key, None)

    def clear(self) -> None:
        """Clear all cache entries from memory."""
        self._cache.clear()

    def keys(self) -> list[str]:
        """Return all cache keys from memory.

        Returns
        -------
        list[str]
            List of cache keys.
        """
        return list(self._cache.keys())


class ModelOutputCache:
    """Content-addressable cache for judgment model outputs.

    Caches results from various model operations to avoid redundant computation.
    Supports multiple operation types including log probabilities, perplexity,
    NLI scores, embeddings, and similarity metrics.

    Cache keys are automatically generated using SHA-256 hashing of the model
    name, operation type, and all input parameters, ensuring deterministic
    cache hits for identical inputs.

    Parameters
    ----------
    cache_dir : Path | None
        Directory for cache files (filesystem backend only).
        Defaults to ~/.cache/bead/models if not specified.
    backend : {"filesystem", "memory"}
        Cache backend type. "filesystem" persists across runs,
        "memory" is ephemeral.
    enabled : bool
        Whether caching is enabled.

    Attributes
    ----------
    enabled : bool
        Whether caching is enabled. When False, all operations are no-ops.

    Examples
    --------
    Basic usage with filesystem backend:

    >>> from pathlib import Path
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> result = cache.get("gpt2", "log_probability", text="Hello world")
    >>> if result is None:
    ...     result = -2.5
    ...     cache.set("gpt2", "log_probability", result, text="Hello world")

    Caching NLI scores:

    >>> nli_scores = cache.get("roberta-nli", "nli",
    ...     premise="Mary loves books",
    ...     hypothesis="Mary enjoys reading")
    >>> if nli_scores is None:
    ...     nli_scores = {"entailment": 0.9, "neutral": 0.08, "contradiction": 0.02}
    ...     cache.set("roberta-nli", "nli", nli_scores,
    ...         premise="Mary loves books", hypothesis="Mary enjoys reading")

    Caching embeddings:

    >>> import numpy as np
    >>> embedding = cache.get("bert-base", "embedding", text="Hello")
    >>> if embedding is None:
    ...     embedding = np.random.rand(768)
    ...     cache.set("bert-base", "embedding", embedding, text="Hello")
    """

    def __init__(
        self,
        cache_dir: Path | None = None,
        backend: Literal["filesystem", "memory"] = "filesystem",
        enabled: bool = True,
    ) -> None:
        self.enabled = enabled

        if backend == "filesystem":
            if cache_dir is None:
                cache_dir = Path.home() / ".cache" / "bead" / "models"
            self._backend: CacheBackend = FilesystemBackend(cache_dir)
        elif backend == "memory":
            self._backend = InMemoryBackend()
        else:
            raise ValueError(f"Unknown backend: {backend}")

    def generate_cache_key(
        self, model_name: str, operation: str, **inputs: str | int | float | bool | None
    ) -> str:
        """Generate deterministic cache key from inputs.

        Parameters
        ----------
        model_name
            Model identifier.
        operation
            Operation type (e.g., "log_probability", "embedding").
        **inputs
            Input parameters for the operation (text, premise, hypothesis).

        Returns
        -------
        str
            SHA-256 hex digest as cache key.
        """
        # create deterministic dict with sorted keys
        key_data = {
            "model_name": model_name,
            "operation": operation,
            "inputs": self._serialize_for_hash(inputs),
        }

        # json with sorted keys for determinism
        key_json = json.dumps(key_data, sort_keys=True)

        # sha-256 hash
        return hashlib.sha256(key_json.encode("utf-8")).hexdigest()

    def _serialize_for_hash(self, obj: object) -> object:
        """Serialize object for deterministic hashing.

        Converts numpy arrays to lists and sorts dict keys.

        Parameters
        ----------
        obj
            Object to serialize. Accepts numpy arrays, dicts, lists, tuples,
            and primitive types.

        Returns
        -------
        object
            JSON-serializable version of the object.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {k: self._serialize_for_hash(v) for k, v in sorted(obj.items())}  # type: ignore[misc]
        elif isinstance(obj, list | tuple):
            return [self._serialize_for_hash(item) for item in obj]  # type: ignore[misc]
        else:
            return obj

    def _serialize_result(self, result: object) -> object:
        """Serialize result for storage.

        Parameters
        ----------
        result
            Result to serialize. Accepts numpy arrays, dicts, lists, tuples,
            and primitive types.

        Returns
        -------
        object
            JSON-serializable version of result.
        """
        if isinstance(result, np.ndarray):
            return {
                "__type__": "ndarray",
                "data": result.tolist(),
                "dtype": str(result.dtype),  # type: ignore[arg-type]
            }
        elif isinstance(result, dict):
            return {k: self._serialize_result(v) for k, v in result.items()}  # type: ignore[misc]
        elif isinstance(result, list | tuple):
            return [self._serialize_result(item) for item in result]  # type: ignore[misc]
        else:
            return result

    def _deserialize_result(self, result: Any) -> Any:
        """Deserialize result from storage.

        Parameters
        ----------
        result
            Serialized result from cache storage.

        Returns
        -------
        Any
            Deserialized result with numpy arrays restored.
        """
        if isinstance(result, dict):
            if result.get("__type__") == "ndarray":  # type: ignore[union-attr]
                return np.array(result["data"], dtype=result["dtype"])  # type: ignore[arg-type]
            else:
                return {k: self._deserialize_result(v) for k, v in result.items()}  # type: ignore[misc]
        elif isinstance(result, list):
            return [self._deserialize_result(item) for item in result]  # type: ignore[misc]
        else:
            return result

    def get(
        self, model_name: str, operation: str, **inputs: str | int | float | bool | None
    ) -> Any:
        """Retrieve cached result.

        Parameters
        ----------
        model_name
            Model identifier.
        operation
            Operation type (e.g., "log_probability", "nli", "embedding").
        **inputs
            Input parameters for the operation (text, premise, hypothesis).

        Returns
        -------
        Any
            Cached result if found, None otherwise.
        """
        if not self.enabled:
            return None

        cache_key = self.generate_cache_key(model_name, operation, **inputs)
        entry = self._backend.get(cache_key)

        if entry is None:
            return None

        # deserialize and return result
        return self._deserialize_result(entry["result"])

    def set(
        self,
        model_name: str,
        operation: str,
        result: float | dict[str, float] | list[float] | np.ndarray,
        model_version: str | None = None,
        **inputs: str | int | float | bool | None,
    ) -> None:
        """Store result in cache.

        Parameters
        ----------
        model_name
            Model identifier.
        operation
            Operation type (e.g., "log_probability", "nli", "embedding").
        result
            Result to cache (log probability, NLI scores, embedding, etc.).
        model_version
            Optional model version string for tracking.
        **inputs
            Input parameters for the operation (text, premise, hypothesis).
        """
        if not self.enabled:
            return

        cache_key = self.generate_cache_key(model_name, operation, **inputs)

        # create cache entry with metadata
        entry = {
            "cache_key": cache_key,
            "timestamp": datetime.now(UTC).isoformat(),
            "model_name": model_name,
            "model_version": model_version,
            "operation": operation,
            "inputs": self._serialize_for_hash(inputs),
            "result": self._serialize_result(result),
        }

        self._backend.set(cache_key, entry)

    def invalidate(
        self, model_name: str, operation: str, **inputs: str | int | float | bool | None
    ) -> None:
        """Invalidate specific cache entry.

        Parameters
        ----------
        model_name
            Model identifier.
        operation
            Operation type.
        **inputs
            Input parameters for the operation.
        """
        cache_key = self.generate_cache_key(model_name, operation, **inputs)
        self._backend.delete(cache_key)

    def clear_model(self, model_name: str) -> None:
        """Clear all cache entries for a specific model.

        Parameters
        ----------
        model_name : str
            Model identifier.
        """
        # get all keys and filter by model name
        keys_to_delete: list[str] = []
        for key in self._backend.keys():
            entry = self._backend.get(key)
            if entry and entry.get("model_name") == model_name:
                keys_to_delete.append(key)

        # delete matching entries
        for key in keys_to_delete:
            self._backend.delete(key)

    def clear(self) -> None:
        """Clear all cache entries."""
        self._backend.clear()