bead-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. bead/__init__.py +11 -0
  2. bead/__main__.py +11 -0
  3. bead/active_learning/__init__.py +15 -0
  4. bead/active_learning/config.py +231 -0
  5. bead/active_learning/loop.py +566 -0
  6. bead/active_learning/models/__init__.py +24 -0
  7. bead/active_learning/models/base.py +852 -0
  8. bead/active_learning/models/binary.py +910 -0
  9. bead/active_learning/models/categorical.py +943 -0
  10. bead/active_learning/models/cloze.py +862 -0
  11. bead/active_learning/models/forced_choice.py +956 -0
  12. bead/active_learning/models/free_text.py +773 -0
  13. bead/active_learning/models/lora.py +365 -0
  14. bead/active_learning/models/magnitude.py +835 -0
  15. bead/active_learning/models/multi_select.py +795 -0
  16. bead/active_learning/models/ordinal_scale.py +811 -0
  17. bead/active_learning/models/peft_adapter.py +155 -0
  18. bead/active_learning/models/random_effects.py +639 -0
  19. bead/active_learning/selection.py +354 -0
  20. bead/active_learning/strategies.py +391 -0
  21. bead/active_learning/trainers/__init__.py +26 -0
  22. bead/active_learning/trainers/base.py +210 -0
  23. bead/active_learning/trainers/data_collator.py +172 -0
  24. bead/active_learning/trainers/dataset_utils.py +261 -0
  25. bead/active_learning/trainers/huggingface.py +304 -0
  26. bead/active_learning/trainers/lightning.py +324 -0
  27. bead/active_learning/trainers/metrics.py +424 -0
  28. bead/active_learning/trainers/mixed_effects.py +551 -0
  29. bead/active_learning/trainers/model_wrapper.py +509 -0
  30. bead/active_learning/trainers/registry.py +104 -0
  31. bead/adapters/__init__.py +11 -0
  32. bead/adapters/huggingface.py +61 -0
  33. bead/behavioral/__init__.py +116 -0
  34. bead/behavioral/analytics.py +646 -0
  35. bead/behavioral/extraction.py +343 -0
  36. bead/behavioral/merging.py +343 -0
  37. bead/cli/__init__.py +11 -0
  38. bead/cli/active_learning.py +513 -0
  39. bead/cli/active_learning_commands.py +779 -0
  40. bead/cli/completion.py +359 -0
  41. bead/cli/config.py +624 -0
  42. bead/cli/constraint_builders.py +286 -0
  43. bead/cli/deployment.py +859 -0
  44. bead/cli/deployment_trials.py +493 -0
  45. bead/cli/deployment_ui.py +332 -0
  46. bead/cli/display.py +378 -0
  47. bead/cli/items.py +960 -0
  48. bead/cli/items_factories.py +776 -0
  49. bead/cli/list_constraints.py +714 -0
  50. bead/cli/lists.py +490 -0
  51. bead/cli/main.py +430 -0
  52. bead/cli/models.py +877 -0
  53. bead/cli/resource_loaders.py +621 -0
  54. bead/cli/resources.py +1036 -0
  55. bead/cli/shell.py +356 -0
  56. bead/cli/simulate.py +840 -0
  57. bead/cli/templates.py +1158 -0
  58. bead/cli/training.py +1080 -0
  59. bead/cli/utils.py +614 -0
  60. bead/cli/workflow.py +1273 -0
  61. bead/config/__init__.py +68 -0
  62. bead/config/active_learning.py +1009 -0
  63. bead/config/config.py +192 -0
  64. bead/config/defaults.py +118 -0
  65. bead/config/deployment.py +217 -0
  66. bead/config/env.py +147 -0
  67. bead/config/item.py +45 -0
  68. bead/config/list.py +193 -0
  69. bead/config/loader.py +149 -0
  70. bead/config/logging.py +42 -0
  71. bead/config/model.py +49 -0
  72. bead/config/paths.py +46 -0
  73. bead/config/profiles.py +320 -0
  74. bead/config/resources.py +47 -0
  75. bead/config/serialization.py +210 -0
  76. bead/config/simulation.py +206 -0
  77. bead/config/template.py +238 -0
  78. bead/config/validation.py +267 -0
  79. bead/data/__init__.py +65 -0
  80. bead/data/base.py +87 -0
  81. bead/data/identifiers.py +97 -0
  82. bead/data/language_codes.py +61 -0
  83. bead/data/metadata.py +270 -0
  84. bead/data/range.py +123 -0
  85. bead/data/repository.py +358 -0
  86. bead/data/serialization.py +249 -0
  87. bead/data/timestamps.py +89 -0
  88. bead/data/validation.py +349 -0
  89. bead/data_collection/__init__.py +11 -0
  90. bead/data_collection/jatos.py +223 -0
  91. bead/data_collection/merger.py +154 -0
  92. bead/data_collection/prolific.py +198 -0
  93. bead/deployment/__init__.py +5 -0
  94. bead/deployment/distribution.py +402 -0
  95. bead/deployment/jatos/__init__.py +1 -0
  96. bead/deployment/jatos/api.py +200 -0
  97. bead/deployment/jatos/exporter.py +210 -0
  98. bead/deployment/jspsych/__init__.py +9 -0
  99. bead/deployment/jspsych/biome.json +44 -0
  100. bead/deployment/jspsych/config.py +411 -0
  101. bead/deployment/jspsych/generator.py +598 -0
  102. bead/deployment/jspsych/package.json +51 -0
  103. bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
  104. bead/deployment/jspsych/randomizer.py +299 -0
  105. bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
  106. bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
  107. bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
  108. bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
  109. bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
  110. bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
  111. bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
  112. bead/deployment/jspsych/src/plugins/rating.ts +248 -0
  113. bead/deployment/jspsych/src/slopit/index.ts +9 -0
  114. bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
  115. bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
  116. bead/deployment/jspsych/templates/experiment.css +1 -0
  117. bead/deployment/jspsych/templates/experiment.js.template +289 -0
  118. bead/deployment/jspsych/templates/index.html +51 -0
  119. bead/deployment/jspsych/templates/randomizer.js +241 -0
  120. bead/deployment/jspsych/templates/randomizer.js.template +313 -0
  121. bead/deployment/jspsych/trials.py +723 -0
  122. bead/deployment/jspsych/tsconfig.json +23 -0
  123. bead/deployment/jspsych/tsup.config.ts +30 -0
  124. bead/deployment/jspsych/ui/__init__.py +1 -0
  125. bead/deployment/jspsych/ui/components.py +383 -0
  126. bead/deployment/jspsych/ui/styles.py +411 -0
  127. bead/dsl/__init__.py +80 -0
  128. bead/dsl/ast.py +168 -0
  129. bead/dsl/context.py +178 -0
  130. bead/dsl/errors.py +71 -0
  131. bead/dsl/evaluator.py +570 -0
  132. bead/dsl/grammar.lark +81 -0
  133. bead/dsl/parser.py +231 -0
  134. bead/dsl/stdlib.py +929 -0
  135. bead/evaluation/__init__.py +13 -0
  136. bead/evaluation/convergence.py +485 -0
  137. bead/evaluation/interannotator.py +398 -0
  138. bead/items/__init__.py +40 -0
  139. bead/items/adapters/__init__.py +70 -0
  140. bead/items/adapters/anthropic.py +224 -0
  141. bead/items/adapters/api_utils.py +167 -0
  142. bead/items/adapters/base.py +216 -0
  143. bead/items/adapters/google.py +259 -0
  144. bead/items/adapters/huggingface.py +1074 -0
  145. bead/items/adapters/openai.py +323 -0
  146. bead/items/adapters/registry.py +202 -0
  147. bead/items/adapters/sentence_transformers.py +224 -0
  148. bead/items/adapters/togetherai.py +309 -0
  149. bead/items/binary.py +515 -0
  150. bead/items/cache.py +558 -0
  151. bead/items/categorical.py +593 -0
  152. bead/items/cloze.py +757 -0
  153. bead/items/constructor.py +784 -0
  154. bead/items/forced_choice.py +413 -0
  155. bead/items/free_text.py +681 -0
  156. bead/items/generation.py +432 -0
  157. bead/items/item.py +396 -0
  158. bead/items/item_template.py +787 -0
  159. bead/items/magnitude.py +573 -0
  160. bead/items/multi_select.py +621 -0
  161. bead/items/ordinal_scale.py +569 -0
  162. bead/items/scoring.py +448 -0
  163. bead/items/validation.py +723 -0
  164. bead/lists/__init__.py +30 -0
  165. bead/lists/balancer.py +263 -0
  166. bead/lists/constraints.py +1067 -0
  167. bead/lists/experiment_list.py +286 -0
  168. bead/lists/list_collection.py +378 -0
  169. bead/lists/partitioner.py +1141 -0
  170. bead/lists/stratification.py +254 -0
  171. bead/participants/__init__.py +73 -0
  172. bead/participants/collection.py +699 -0
  173. bead/participants/merging.py +312 -0
  174. bead/participants/metadata_spec.py +491 -0
  175. bead/participants/models.py +276 -0
  176. bead/resources/__init__.py +29 -0
  177. bead/resources/adapters/__init__.py +19 -0
  178. bead/resources/adapters/base.py +104 -0
  179. bead/resources/adapters/cache.py +128 -0
  180. bead/resources/adapters/glazing.py +508 -0
  181. bead/resources/adapters/registry.py +117 -0
  182. bead/resources/adapters/unimorph.py +796 -0
  183. bead/resources/classification.py +856 -0
  184. bead/resources/constraint_builders.py +329 -0
  185. bead/resources/constraints.py +165 -0
  186. bead/resources/lexical_item.py +223 -0
  187. bead/resources/lexicon.py +744 -0
  188. bead/resources/loaders.py +209 -0
  189. bead/resources/template.py +441 -0
  190. bead/resources/template_collection.py +707 -0
  191. bead/resources/template_generation.py +349 -0
  192. bead/simulation/__init__.py +29 -0
  193. bead/simulation/annotators/__init__.py +15 -0
  194. bead/simulation/annotators/base.py +175 -0
  195. bead/simulation/annotators/distance_based.py +135 -0
  196. bead/simulation/annotators/lm_based.py +114 -0
  197. bead/simulation/annotators/oracle.py +182 -0
  198. bead/simulation/annotators/random.py +181 -0
  199. bead/simulation/dsl_extension/__init__.py +3 -0
  200. bead/simulation/noise_models/__init__.py +13 -0
  201. bead/simulation/noise_models/base.py +42 -0
  202. bead/simulation/noise_models/random_noise.py +82 -0
  203. bead/simulation/noise_models/systematic.py +132 -0
  204. bead/simulation/noise_models/temperature.py +86 -0
  205. bead/simulation/runner.py +144 -0
  206. bead/simulation/strategies/__init__.py +23 -0
  207. bead/simulation/strategies/base.py +123 -0
  208. bead/simulation/strategies/binary.py +103 -0
  209. bead/simulation/strategies/categorical.py +123 -0
  210. bead/simulation/strategies/cloze.py +224 -0
  211. bead/simulation/strategies/forced_choice.py +127 -0
  212. bead/simulation/strategies/free_text.py +105 -0
  213. bead/simulation/strategies/magnitude.py +116 -0
  214. bead/simulation/strategies/multi_select.py +129 -0
  215. bead/simulation/strategies/ordinal_scale.py +131 -0
  216. bead/templates/__init__.py +27 -0
  217. bead/templates/adapters/__init__.py +17 -0
  218. bead/templates/adapters/base.py +128 -0
  219. bead/templates/adapters/cache.py +178 -0
  220. bead/templates/adapters/huggingface.py +312 -0
  221. bead/templates/combinatorics.py +103 -0
  222. bead/templates/filler.py +605 -0
  223. bead/templates/renderers.py +177 -0
  224. bead/templates/resolver.py +178 -0
  225. bead/templates/strategies.py +1806 -0
  226. bead/templates/streaming.py +195 -0
  227. bead-0.1.0.dist-info/METADATA +212 -0
  228. bead-0.1.0.dist-info/RECORD +231 -0
  229. bead-0.1.0.dist-info/WHEEL +4 -0
  230. bead-0.1.0.dist-info/entry_points.txt +2 -0
  231. bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/adapters/huggingface.py
@@ -0,0 +1,1074 @@
"""HuggingFace model adapters for language models and NLI.

This module provides adapters for HuggingFace Transformers models:
- HuggingFaceLanguageModel: Causal LMs (GPT-2, GPT-Neo, Llama, Mistral)
- HuggingFaceMaskedLanguageModel: Masked LMs (BERT, RoBERTa, ALBERT)
- HuggingFaceNLI: NLI models (RoBERTa-MNLI, DeBERTa-MNLI, BART-MNLI)
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
import psutil
import torch
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from bead.adapters.huggingface import DeviceType, HuggingFaceAdapterMixin
from bead.items.adapters.base import ModelAdapter
from bead.items.cache import ModelOutputCache

if TYPE_CHECKING:
    from transformers.models.auto.configuration_auto import AutoConfig

logger = logging.getLogger(__name__)

class HuggingFaceLanguageModel(HuggingFaceAdapterMixin, ModelAdapter):
    """Adapter for HuggingFace causal language models.

    Supports models like GPT-2, GPT-Neo, Llama, Mistral, and other
    autoregressive (left-to-right) language models.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "gpt2", "gpt2-medium").
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : {"cpu", "cuda", "mps"}
        Device to run model on. Falls back to CPU if device unavailable.
    model_version : str
        Version string for cache tracking.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> model = HuggingFaceLanguageModel("gpt2", cache, device="cpu")
    >>> log_prob = model.compute_log_probability("The cat sat on the mat.")
    >>> perplexity = model.compute_perplexity("The cat sat on the mat.")
    >>> embedding = model.get_embedding("The cat sat on the mat.")
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: DeviceType = "cpu",
        model_version: str = "unknown",
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = self._validate_device(device)
        self._model: PreTrainedModel | None = None
        self._tokenizer: PreTrainedTokenizerBase | None = None

    def _load_model(self) -> None:
        """Load model and tokenizer lazily on first use."""
        if self._model is None:
            logger.info(f"Loading causal LM: {self.model_name}")
            self._model = AutoModelForCausalLM.from_pretrained(self.model_name)
            self._model.to(self.device)
            self._model.eval()

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # set padding token for models that don't have one
            if self._tokenizer.pad_token is None:
                self._tokenizer.pad_token = self._tokenizer.eos_token

    @property
    def model(self) -> PreTrainedModel:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Get the tokenizer, loading if necessary."""
        self._load_model()
        assert self._tokenizer is not None
        return self._tokenizer

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text under language model.

        Uses the model's loss with labels=input_ids to compute the negative
        log-likelihood of the text.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Log probability of the text.
        """
        # Check cache
        cached = self.cache.get(self.model_name, "log_probability", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # compute loss (negative log-likelihood)
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
            )
            loss = outputs.loss.item()

        # loss is negative log-likelihood per token, convert to total log prob
        log_prob = -loss * input_ids.size(1)

        # cache result
        self.cache.set(
            self.model_name,
            "log_probability",
            log_prob,
            model_version=self.model_version,
            text=text,
        )

        return log_prob

    def _infer_optimal_batch_size(self) -> int:
        """Infer optimal batch size based on available resources.

        Considers:
        - Device type (CPU, CUDA, MPS)
        - Available memory
        - Model size
        - Sequence length estimates

        Returns
        -------
        int
            Recommended batch size.
        """
        # estimate model size
        model_params = sum(
            p.numel() * p.element_size() for p in self.model.parameters()
        )

        if self.device == "cuda":
            try:
                # get GPU memory
                free_memory, _ = torch.cuda.mem_get_info(self.device)

                # conservative estimate: allow model + 4x model size for activations
                # reserve 20% for safety margin
                available_for_batch = (free_memory * 0.8) - model_params
                memory_per_item = model_params * 4  # very rough estimate

                batch_size = int(available_for_batch / memory_per_item)

                # clamp between reasonable bounds
                batch_size = max(8, min(batch_size, 256))

                free_gb = free_memory / 1e9
                model_gb = model_params / 1e9
                logger.info(
                    f"Inferred batch size {batch_size} for CUDA "
                    f"(free: {free_gb:.1f}GB, model: {model_gb:.2f}GB)"
                )
                return batch_size

            except Exception as e:
                logger.warning(
                    f"Failed to infer CUDA batch size: {e}, using default 32"
                )
                return 32

        elif self.device == "mps":
            try:
                # mps (Apple Silicon) - use system RAM as proxy
                # mps shares unified memory with system
                available_memory = psutil.virtual_memory().available

                # reserve 4GB for system + model
                available_for_batch = max(
                    0, available_memory - (4 * 1024**3) - model_params
                )
                memory_per_item = model_params * 3  # mps is more efficient than CUDA

                batch_size = int(available_for_batch / memory_per_item)

                # clamp between reasonable bounds
                batch_size = max(8, min(batch_size, 256))

                avail_gb = available_memory / 1e9
                model_gb = model_params / 1e9
                logger.info(
                    f"Inferred batch size {batch_size} for MPS "
                    f"(available: {avail_gb:.1f}GB, model: {model_gb:.2f}GB)"
                )
                return batch_size

            except Exception as e:
                logger.warning(f"Failed to infer MPS batch size: {e}, using default 64")
                return 64

        else:  # CPU
            try:
                # cpu - check available RAM
                available_memory = psutil.virtual_memory().available

                # reserve 2GB for system + model
                available_for_batch = max(
                    0, available_memory - (2 * 1024**3) - model_params
                )
                memory_per_item = model_params * 2  # cpu has less overhead than GPU

                batch_size = int(available_for_batch / memory_per_item)

                # clamp between reasonable bounds
                batch_size = max(4, min(batch_size, 128))

                avail_gb = available_memory / 1e9
                model_gb = model_params / 1e9
                logger.info(
                    f"Inferred batch size {batch_size} for CPU "
                    f"(available: {avail_gb:.1f}GB, model: {model_gb:.2f}GB)"
                )
                return batch_size

            except Exception as e:
                logger.warning(f"Failed to infer CPU batch size: {e}, using default 16")
                return 16

    def compute_log_probability_batch(
        self, texts: list[str], batch_size: int | None = None
    ) -> list[float]:
        """Compute log probabilities for multiple texts efficiently.

        Uses batched tokenization and inference for significant speedup.
        Checks cache before computing, only processes uncached texts.

        Parameters
        ----------
        texts : list[str]
            Texts to compute log probabilities for.
        batch_size : int | None, default=None
            Number of texts to process in each batch. If None, automatically
            infers optimal batch size based on available device memory and
            model size.

        Returns
        -------
        list[float]
            Log probabilities for each text, in the same order as input.

        Examples
        --------
        >>> texts = ["The cat sat.", "The dog ran.", "The bird flew."]
        >>> log_probs = model.compute_log_probability_batch(texts)
        >>> len(log_probs) == len(texts)
        True
        """
        # infer batch size if not provided
        if batch_size is None:
            batch_size = self._infer_optimal_batch_size()

        # check cache for all texts
        results: list[float | None] = []
        uncached_indices: list[int] = []
        uncached_texts: list[str] = []

        for i, text in enumerate(texts):
            cached = self.cache.get(self.model_name, "log_probability", text=text)
            if cached is not None:
                results.append(cached)
            else:
                results.append(None)  # placeholder
                uncached_indices.append(i)
                uncached_texts.append(text)

        # if everything was cached, return immediately
        if not uncached_texts:
            logger.info(f"All {len(texts)} texts found in cache")
            return [r for r in results if r is not None]

        # log cache statistics
        n_cached = len(texts) - len(uncached_texts)
        cache_rate = (n_cached / len(texts)) * 100 if texts else 0
        logger.info(
            f"Cache: {n_cached}/{len(texts)} texts ({cache_rate:.1f}%), "
            f"processing {len(uncached_texts)} uncached with batch_size={batch_size}"
        )

        # process uncached texts in batches with progress tracking
        uncached_scores: list[float] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        ) as progress:
            task = progress.add_task(
                f"[cyan]Scoring with {self.model_name}[/cyan]",
                total=len(uncached_texts),
            )

            for batch_start in range(0, len(uncached_texts), batch_size):
                batch_texts = uncached_texts[batch_start : batch_start + batch_size]
                batch_scores = self._process_batch(batch_texts)
                uncached_scores.extend(batch_scores)
                progress.update(task, advance=len(batch_texts))

        # merge cached and newly computed results
        uncached_iter = iter(uncached_scores)
        final_results: list[float] = []
        for result in results:
            if result is None:
                final_results.append(next(uncached_iter))
            else:
                final_results.append(result)

        return final_results

    def _process_batch(self, batch_texts: list[str]) -> list[float]:
        """Process a single batch of texts and return scores.

        Parameters
        ----------
        batch_texts : list[str]
            Texts to process in this batch.

        Returns
        -------
        list[float]
            Log probabilities for each text.
        """
        batch_scores: list[float] = []

        # tokenize batch
        inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # compute losses for batch
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )

        # for batched inputs, we need to compute loss per item
        # the model returns average loss across batch, so we need
        # to compute per-item losses manually
        logits = outputs.logits  # [batch, seq_len, vocab]

        # shift for causal LM: predict next token
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_attention = attention_mask[..., 1:].contiguous()

        # compute log probabilities per token
        log_probs_per_token = torch.nn.functional.log_softmax(shift_logits, dim=-1)

        # gather log probs for actual tokens
        gathered_log_probs = torch.gather(
            log_probs_per_token,
            dim=-1,
            index=shift_labels.unsqueeze(-1),
        ).squeeze(-1)

        # mask padding tokens and sum per sequence
        masked_log_probs = gathered_log_probs * shift_attention
        sequence_log_probs = masked_log_probs.sum(dim=1)

        # convert to list and cache
        for text, log_prob_tensor in zip(batch_texts, sequence_log_probs, strict=True):
            log_prob = log_prob_tensor.item()
            batch_scores.append(log_prob)

            # cache result
            self.cache.set(
                self.model_name,
                "log_probability",
                log_prob,
                model_version=self.model_version,
                text=text,
            )

        return batch_scores

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Perplexity is exp(average negative log-likelihood per token).

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (positive value).
        """
        # check cache
        cached = self.cache.get(self.model_name, "perplexity", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # compute loss
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
            )
            loss = outputs.loss.item()

        # perplexity is exp(loss)
        perplexity = np.exp(loss)

        # cache result
        self.cache.set(
            self.model_name,
            "perplexity",
            perplexity,
            model_version=self.model_version,
            text=text,
        )

        return float(perplexity)

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses mean pooling of last hidden states as the text embedding.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get hidden states
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            hidden_states = outputs.hidden_states[-1]  # last layer

        # mean pooling (weighted by attention mask)
        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size())
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        embedding = (sum_hidden / sum_mask).squeeze(0).cpu().numpy()

        # cache result
        self.cache.set(
            self.model_name,
            "embedding",
            embedding,
            model_version=self.model_version,
            text=text,
        )

        return embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Not supported for causal language models.

        Raises
        ------
        NotImplementedError
            Always raised, as causal LMs don't support NLI directly.
        """
        raise NotImplementedError(
            f"NLI is not supported for causal language model {self.model_name}. "
            "Use HuggingFaceNLI adapter with an NLI-trained model instead."
        )


class HuggingFaceMaskedLanguageModel(HuggingFaceAdapterMixin, ModelAdapter):
    """Adapter for HuggingFace masked language models.

    Supports models like BERT, RoBERTa, ALBERT, and other masked language
    models (MLMs).

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "bert-base-uncased").
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : {"cpu", "cuda", "mps"}
        Device to run model on. Falls back to CPU if device unavailable.
    model_version : str
        Version string for cache tracking.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> model = HuggingFaceMaskedLanguageModel("bert-base-uncased", cache)
    >>> log_prob = model.compute_log_probability("The cat sat on the mat.")
    >>> embedding = model.get_embedding("The cat sat on the mat.")
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: DeviceType = "cpu",
        model_version: str = "unknown",
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = self._validate_device(device)
        self._model: PreTrainedModel | None = None
        self._tokenizer: PreTrainedTokenizerBase | None = None

    def _load_model(self) -> None:
        """Load model and tokenizer lazily on first use."""
        if self._model is None:
            logger.info(f"Loading masked LM: {self.model_name}")
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_name)
            self._model.to(self.device)
            self._model.eval()

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    @property
    def model(self) -> PreTrainedModel:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Get the tokenizer, loading if necessary."""
        self._load_model()
        assert self._tokenizer is not None
        return self._tokenizer

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text using pseudo-log-likelihood.

        For MLMs, we use pseudo-log-likelihood: mask each token one at a time
        and sum the log probabilities of predicting each token.

        This is computationally expensive - caching is critical.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Pseudo-log-probability of the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "log_probability", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].to(self.device)

        # compute pseudo-log-likelihood by masking each token
        total_log_prob = 0.0
        num_tokens = input_ids.size(1)

        with torch.no_grad():
            for i in range(num_tokens):
                # skip special tokens
                if input_ids[0, i] in [
                    self.tokenizer.cls_token_id,
                    self.tokenizer.sep_token_id,
                    self.tokenizer.pad_token_id,
                ]:
                    continue

                # create masked version
                masked_input = input_ids.clone()
                original_token = masked_input[0, i].item()
                masked_input[0, i] = self.tokenizer.mask_token_id

                # get prediction
                outputs = self.model(masked_input)
                logits = outputs.logits[0, i]  # logits for masked position

                # compute log probability of original token
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                total_log_prob += log_probs[original_token].item()

        # cache result
        self.cache.set(
            self.model_name,
            "log_probability",
            total_log_prob,
            model_version=self.model_version,
            text=text,
        )

        return total_log_prob

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity based on pseudo-log-likelihood.

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (positive value).
        """
        # check cache
        cached = self.cache.get(self.model_name, "perplexity", text=text)
        if cached is not None:
            return cached

        # get log probability
        log_prob = self.compute_log_probability(text)

        # count non-special tokens
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"]
        num_tokens = sum(
            1
            for token_id in input_ids[0].tolist()
            if token_id
            not in [
                self.tokenizer.cls_token_id,
                self.tokenizer.sep_token_id,
                self.tokenizer.pad_token_id,
            ]
        )

        # perplexity is exp(-log_prob / num_tokens)
        perplexity = np.exp(-log_prob / max(num_tokens, 1))

        # cache result
        self.cache.set(
            self.model_name,
            "perplexity",
            perplexity,
            model_version=self.model_version,
            text=text,
        )

        return float(perplexity)

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses the [CLS] token embedding from the last layer.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get hidden states
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            # use [CLS] token from last layer
            hidden_states = outputs.hidden_states[-1]
            cls_embedding = hidden_states[0, 0].cpu().numpy()

        # cache result
        self.cache.set(
            self.model_name,
            "embedding",
            cls_embedding,
            model_version=self.model_version,
            text=text,
        )

        return cls_embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Not supported for masked language models.

        Raises
        ------
        NotImplementedError
            Always raised, as MLMs don't support NLI directly.
        """
        raise NotImplementedError(
            f"NLI is not supported for masked language model {self.model_name}. "
            "Use HuggingFaceNLI adapter with an NLI-trained model instead."
        )


class HuggingFaceNLI(HuggingFaceAdapterMixin, ModelAdapter):
    """Adapter for HuggingFace NLI models.

    Supports NLI models trained on MNLI and similar datasets
    (e.g., "roberta-large-mnli", "microsoft/deberta-base-mnli").

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier for NLI model.
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : {"cpu", "cuda", "mps"}
        Device to run model on. Falls back to CPU if device unavailable.
    model_version : str
        Version string for cache tracking.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> nli = HuggingFaceNLI("roberta-large-mnli", cache, device="cpu")
    >>> scores = nli.compute_nli(
    ...     premise="Mary loves reading books.",
    ...     hypothesis="Mary enjoys literature."
    ... )
    >>> label = nli.get_nli_label(
    ...     premise="Mary loves reading books.",
    ...     hypothesis="Mary enjoys literature."
    ... )
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: DeviceType = "cpu",
        model_version: str = "unknown",
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = self._validate_device(device)
        self._model: PreTrainedModel | None = None
        self._tokenizer: PreTrainedTokenizerBase | None = None
        self._label_mapping: dict[str, str] = {}

    def _load_model(self) -> None:
        """Load model and tokenizer lazily on first use."""
        if self._model is None:
            logger.info(f"Loading NLI model: {self.model_name}")
            self._model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name
            )
            self._model.to(self.device)
            self._model.eval()

            # Get label mapping from config
            config = AutoConfig.from_pretrained(self.model_name)
            if hasattr(config, "id2label"):
                # Build mapping from model labels to standard labels
                self._label_mapping = self._build_label_mapping(config.id2label)
            else:
                # Default mapping (assume standard order)
                self._label_mapping = {
                    "0": "entailment",
                    "1": "neutral",
                    "2": "contradiction",
                }

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def _build_label_mapping(self, id2label: dict[int, str]) -> dict[str, str]:
        """Build mapping from model label IDs to standard NLI labels.

        Parameters
        ----------
        id2label
            Mapping from label IDs to label strings from model config.

        Returns
        -------
        dict[str, str]
            Mapping from label IDs (as strings) to standard labels.
        """
        mapping: dict[str, str] = {}
        for idx, label in id2label.items():
            # normalize label to lowercase
            normalized = label.lower()
            # map to standard labels
            if "entail" in normalized:
                mapping[str(idx)] = "entailment"
            elif "neutral" in normalized:
                mapping[str(idx)] = "neutral"
            elif "contradict" in normalized:
                mapping[str(idx)] = "contradiction"
            else:
                # keep original if we can't map it
                mapping[str(idx)] = normalized
        return mapping

    @property
    def model(self) -> PreTrainedModel:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Get the tokenizer, loading if necessary."""
        self._load_model()
        assert self._tokenizer is not None
        return self._tokenizer

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text.

        Not supported for NLI models.

        Raises
        ------
        NotImplementedError
            Always raised, as NLI models don't provide log probabilities.
        """
        raise NotImplementedError(
            f"Log probability is not supported for NLI model {self.model_name}. "
            "Use HuggingFaceLanguageModel or HuggingFaceMaskedLanguageModel instead."
        )

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Not supported for NLI models.

        Raises
        ------
        NotImplementedError
            Always raised, as NLI models don't provide perplexity.
        """
        raise NotImplementedError(
            f"Perplexity is not supported for NLI model {self.model_name}. "
            "Use HuggingFaceLanguageModel or HuggingFaceMaskedLanguageModel instead."
        )

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses the model's encoder to get embeddings. Note that NLI models
        are typically fine-tuned for classification, so embeddings may not
        be optimal for general similarity tasks.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get hidden states (using base model if available)
        with torch.no_grad():
            # try to access base model for embeddings
            if hasattr(self.model, "roberta"):
                base_model = self.model.roberta
            elif hasattr(self.model, "deberta"):
                base_model = self.model.deberta
            elif hasattr(self.model, "bert"):
                base_model = self.model.bert
            else:
                # fallback: use full model with output_hidden_states
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )
                hidden_states = outputs.hidden_states[-1]
                embedding = hidden_states[0, 0].cpu().numpy()
                self.cache.set(
                    self.model_name,
                    "embedding",
                    embedding,
                    model_version=self.model_version,
                    text=text,
                )
                return embedding

            # use base model
            outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
            # use [CLS] token
            embedding = outputs.last_hidden_state[0, 0].cpu().numpy()

        # cache result
        self.cache.set(
            self.model_name,
            "embedding",
            embedding,
            model_version=self.model_version,
            text=text,
        )

        return embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        dict[str, float]
            Dictionary with keys "entailment", "neutral", "contradiction"
            mapping to probability scores that sum to ~1.0.
        """
        # check cache
        cached = self.cache.get(
            self.model_name, "nli", premise=premise, hypothesis=hypothesis
        )
        if cached is not None:
            return cached

        # tokenize premise-hypothesis pair
        inputs = self.tokenizer(
            premise,
            hypothesis,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get logits
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[0]

        # convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

        # map to standard labels
        scores: dict[str, float] = {}
        for idx, prob in enumerate(probs):
            label = self._label_mapping.get(str(idx), str(idx))
            scores[label] = float(prob)

        # ensure we have all three standard labels
        for label in ["entailment", "neutral", "contradiction"]:
            if label not in scores:
                scores[label] = 0.0

        # cache result
        self.cache.set(
            self.model_name,
            "nli",
            scores,
            model_version=self.model_version,
            premise=premise,
            hypothesis=hypothesis,
        )

        return scores
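
For orientation, the sketch below shows how the adapters defined in bead/items/adapters/huggingface.py are typically wired together. It is not part of the package; it simply restates the docstring examples above in one runnable script, with the cache directory and example sentences chosen for illustration.

    # Usage sketch (not part of the package): combines the three adapters'
    # docstring examples into one script. Paths and texts are illustrative.
    from pathlib import Path

    from bead.items.adapters.huggingface import HuggingFaceLanguageModel, HuggingFaceNLI
    from bead.items.cache import ModelOutputCache

    # one shared on-disk cache; repeated calls with the same text are cache hits
    cache = ModelOutputCache(cache_dir=Path(".cache"))

    # causal LM scoring: single text plus the batched variant
    lm = HuggingFaceLanguageModel("gpt2", cache, device="cpu")
    log_prob = lm.compute_log_probability("The cat sat on the mat.")
    log_probs = lm.compute_log_probability_batch(
        ["The cat sat.", "The dog ran.", "The bird flew."]
    )  # batch size is inferred from available memory when not given

    # NLI scoring: probabilities over entailment / neutral / contradiction
    nli = HuggingFaceNLI("roberta-large-mnli", cache, device="cpu")
    scores = nli.compute_nli(
        premise="Mary loves reading books.",
        hypothesis="Mary enjoys literature.",
    )
    print(log_prob, log_probs, scores)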