bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/items/adapters/huggingface.py
@@ -0,0 +1,1074 @@
"""HuggingFace model adapters for language models and NLI.

This module provides adapters for HuggingFace Transformers models:
- HuggingFaceLanguageModel: Causal LMs (GPT-2, GPT-Neo, Llama, Mistral)
- HuggingFaceMaskedLanguageModel: Masked LMs (BERT, RoBERTa, ALBERT)
- HuggingFaceNLI: NLI models (RoBERTa-MNLI, DeBERTa-MNLI, BART-MNLI)
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
import psutil
import torch
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from bead.adapters.huggingface import DeviceType, HuggingFaceAdapterMixin
from bead.items.adapters.base import ModelAdapter
from bead.items.cache import ModelOutputCache

if TYPE_CHECKING:
    from transformers.models.auto.configuration_auto import AutoConfig

logger = logging.getLogger(__name__)


class HuggingFaceLanguageModel(HuggingFaceAdapterMixin, ModelAdapter):
    """Adapter for HuggingFace causal language models.

    Supports models like GPT-2, GPT-Neo, Llama, Mistral, and other
    autoregressive (left-to-right) language models.

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "gpt2", "gpt2-medium").
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : {"cpu", "cuda", "mps"}
        Device to run model on. Falls back to CPU if device unavailable.
    model_version : str
        Version string for cache tracking.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> model = HuggingFaceLanguageModel("gpt2", cache, device="cpu")
    >>> log_prob = model.compute_log_probability("The cat sat on the mat.")
    >>> perplexity = model.compute_perplexity("The cat sat on the mat.")
    >>> embedding = model.get_embedding("The cat sat on the mat.")
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: DeviceType = "cpu",
        model_version: str = "unknown",
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = self._validate_device(device)
        self._model: PreTrainedModel | None = None
        self._tokenizer: PreTrainedTokenizerBase | None = None

    def _load_model(self) -> None:
        """Load model and tokenizer lazily on first use."""
        if self._model is None:
            logger.info(f"Loading causal LM: {self.model_name}")
            self._model = AutoModelForCausalLM.from_pretrained(self.model_name)
            self._model.to(self.device)
            self._model.eval()

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # set padding token for models that don't have one
            if self._tokenizer.pad_token is None:
                self._tokenizer.pad_token = self._tokenizer.eos_token

    @property
    def model(self) -> PreTrainedModel:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Get the tokenizer, loading if necessary."""
        self._load_model()
        assert self._tokenizer is not None
        return self._tokenizer

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text under language model.

        Uses the model's loss with labels=input_ids to compute the negative
        log-likelihood of the text.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Log probability of the text.
        """
        # Check cache
        cached = self.cache.get(self.model_name, "log_probability", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # compute loss (negative log-likelihood)
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
            )
            loss = outputs.loss.item()

        # loss is negative log-likelihood per token, convert to total log prob
        log_prob = -loss * input_ids.size(1)

        # cache result
        self.cache.set(
            self.model_name,
            "log_probability",
            log_prob,
            model_version=self.model_version,
            text=text,
        )

        return log_prob

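    # Note on the conversion above: with labels=input_ids, HuggingFace causal
    # LMs return the mean cross-entropy over the shifted target positions, so
    # e.g. a mean loss of 4.0 on an 8-token input corresponds to a total log
    # probability of roughly -4.0 * 8 = -32.0, which is the quantity
    # compute_log_probability recovers (up to the one-token offset introduced
    # by the shift).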
    def _infer_optimal_batch_size(self) -> int:
        """Infer optimal batch size based on available resources.

        Considers:
        - Device type (CPU, CUDA, MPS)
        - Available memory
        - Model size
        - Sequence length estimates

        Returns
        -------
        int
            Recommended batch size.
        """
        # estimate model size
        model_params = sum(
            p.numel() * p.element_size() for p in self.model.parameters()
        )

        if self.device == "cuda":
            try:
                # get GPU memory
                free_memory, _ = torch.cuda.mem_get_info(self.device)

                # conservative estimate: allow model + 4x model size for activations
                # reserve 20% for safety margin
                available_for_batch = (free_memory * 0.8) - model_params
                memory_per_item = model_params * 4  # very rough estimate

                batch_size = int(available_for_batch / memory_per_item)

                # clamp between reasonable bounds
                batch_size = max(8, min(batch_size, 256))

                free_gb = free_memory / 1e9
                model_gb = model_params / 1e9
                logger.info(
                    f"Inferred batch size {batch_size} for CUDA "
                    f"(free: {free_gb:.1f}GB, model: {model_gb:.2f}GB)"
                )
                return batch_size

            except Exception as e:
                logger.warning(
                    f"Failed to infer CUDA batch size: {e}, using default 32"
                )
                return 32

        elif self.device == "mps":
            try:
                # mps (Apple Silicon) - use system RAM as proxy
                # mps shares unified memory with system
                available_memory = psutil.virtual_memory().available

                # reserve 4GB for system + model
                available_for_batch = max(
                    0, available_memory - (4 * 1024**3) - model_params
                )
                memory_per_item = model_params * 3  # mps is more efficient than CUDA

                batch_size = int(available_for_batch / memory_per_item)

                # clamp between reasonable bounds
                batch_size = max(8, min(batch_size, 256))

                avail_gb = available_memory / 1e9
                model_gb = model_params / 1e9
                logger.info(
                    f"Inferred batch size {batch_size} for MPS "
                    f"(available: {avail_gb:.1f}GB, model: {model_gb:.2f}GB)"
                )
                return batch_size

            except Exception as e:
                logger.warning(f"Failed to infer MPS batch size: {e}, using default 64")
                return 64

        else:  # CPU
            try:
                # cpu - check available RAM
                available_memory = psutil.virtual_memory().available

                # reserve 2GB for system + model
                available_for_batch = max(
                    0, available_memory - (2 * 1024**3) - model_params
                )
                memory_per_item = model_params * 2  # cpu has less overhead than GPU

                batch_size = int(available_for_batch / memory_per_item)

                # clamp between reasonable bounds
                batch_size = max(4, min(batch_size, 128))

                avail_gb = available_memory / 1e9
                model_gb = model_params / 1e9
                logger.info(
                    f"Inferred batch size {batch_size} for CPU "
                    f"(available: {avail_gb:.1f}GB, model: {model_gb:.2f}GB)"
                )
                return batch_size

            except Exception as e:
                logger.warning(f"Failed to infer CPU batch size: {e}, using default 16")
                return 16

    def compute_log_probability_batch(
        self, texts: list[str], batch_size: int | None = None
    ) -> list[float]:
        """Compute log probabilities for multiple texts efficiently.

        Uses batched tokenization and inference for significant speedup.
        Checks cache before computing, only processes uncached texts.

        Parameters
        ----------
        texts : list[str]
            Texts to compute log probabilities for.
        batch_size : int | None, default=None
            Number of texts to process in each batch. If None, automatically
            infers optimal batch size based on available device memory and
            model size.

        Returns
        -------
        list[float]
            Log probabilities for each text, in the same order as input.

        Examples
        --------
        >>> texts = ["The cat sat.", "The dog ran.", "The bird flew."]
        >>> log_probs = model.compute_log_probability_batch(texts)
        >>> len(log_probs) == len(texts)
        True
        """
        # infer batch size if not provided
        if batch_size is None:
            batch_size = self._infer_optimal_batch_size()

        # check cache for all texts
        results: list[float | None] = []
        uncached_indices: list[int] = []
        uncached_texts: list[str] = []

        for i, text in enumerate(texts):
            cached = self.cache.get(self.model_name, "log_probability", text=text)
            if cached is not None:
                results.append(cached)
            else:
                results.append(None)  # placeholder
                uncached_indices.append(i)
                uncached_texts.append(text)

        # if everything was cached, return immediately
        if not uncached_texts:
            logger.info(f"All {len(texts)} texts found in cache")
            return [r for r in results if r is not None]

        # log cache statistics
        n_cached = len(texts) - len(uncached_texts)
        cache_rate = (n_cached / len(texts)) * 100 if texts else 0
        logger.info(
            f"Cache: {n_cached}/{len(texts)} texts ({cache_rate:.1f}%), "
            f"processing {len(uncached_texts)} uncached with batch_size={batch_size}"
        )

        # process uncached texts in batches with progress tracking
        uncached_scores: list[float] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        ) as progress:
            task = progress.add_task(
                f"[cyan]Scoring with {self.model_name}[/cyan]",
                total=len(uncached_texts),
            )

            for batch_start in range(0, len(uncached_texts), batch_size):
                batch_texts = uncached_texts[batch_start : batch_start + batch_size]
                batch_scores = self._process_batch(batch_texts)
                uncached_scores.extend(batch_scores)
                progress.update(task, advance=len(batch_texts))

        # merge cached and newly computed results
        uncached_iter = iter(uncached_scores)
        final_results: list[float] = []
        for result in results:
            if result is None:
                final_results.append(next(uncached_iter))
            else:
                final_results.append(result)

        return final_results

    def _process_batch(self, batch_texts: list[str]) -> list[float]:
        """Process a single batch of texts and return scores.

        Parameters
        ----------
        batch_texts : list[str]
            Texts to process in this batch.

        Returns
        -------
        list[float]
            Log probabilities for each text.
        """
        batch_scores: list[float] = []

        # tokenize batch
        inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # compute losses for batch
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )

        # for batched inputs, we need to compute loss per item
        # the model returns average loss across batch, so we need
        # to compute per-item losses manually
        logits = outputs.logits  # [batch, seq_len, vocab]

        # shift for causal LM: predict next token
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_attention = attention_mask[..., 1:].contiguous()

        # compute log probabilities per token
        log_probs_per_token = torch.nn.functional.log_softmax(shift_logits, dim=-1)

        # gather log probs for actual tokens
        gathered_log_probs = torch.gather(
            log_probs_per_token,
            dim=-1,
            index=shift_labels.unsqueeze(-1),
        ).squeeze(-1)

        # mask padding tokens and sum per sequence
        masked_log_probs = gathered_log_probs * shift_attention
        sequence_log_probs = masked_log_probs.sum(dim=1)

        # convert to list and cache
        for text, log_prob_tensor in zip(batch_texts, sequence_log_probs, strict=True):
            log_prob = log_prob_tensor.item()
            batch_scores.append(log_prob)

            # cache result
            self.cache.set(
                self.model_name,
                "log_probability",
                log_prob,
                model_version=self.model_version,
                text=text,
            )

        return batch_scores

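    # The shift-and-gather reduction used in _process_batch, factored out as an
    # illustrative sketch. This helper is hypothetical (not part of the adapter
    # API); it assumes logits of shape [batch, seq_len, vocab] aligned with
    # input_ids and attention_mask of shape [batch, seq_len].
    @staticmethod
    def _example_per_sequence_log_probs(
        logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return one summed log probability per sequence (illustrative only)."""
        # align position t's logits with the observed token at position t + 1
        shift_logits = logits[..., :-1, :]
        shift_labels = input_ids[..., 1:]
        shift_mask = attention_mask[..., 1:]
        # per-token log probabilities of the observed next tokens
        token_log_probs = torch.gather(
            torch.nn.functional.log_softmax(shift_logits, dim=-1),
            dim=-1,
            index=shift_labels.unsqueeze(-1),
        ).squeeze(-1)
        # zero out padding positions, then sum within each sequence
        return (token_log_probs * shift_mask).sum(dim=1)
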
    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Perplexity is exp(average negative log-likelihood per token).

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (positive value).
        """
        # check cache
        cached = self.cache.get(self.model_name, "perplexity", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # compute loss
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
            )
            loss = outputs.loss.item()

        # perplexity is exp(loss)
        perplexity = np.exp(loss)

        # cache result
        self.cache.set(
            self.model_name,
            "perplexity",
            perplexity,
            model_version=self.model_version,
            text=text,
        )

        return float(perplexity)

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses mean pooling of last hidden states as the text embedding.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get hidden states
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            hidden_states = outputs.hidden_states[-1]  # last layer

        # mean pooling (weighted by attention mask)
        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size())
        sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        embedding = (sum_hidden / sum_mask).squeeze(0).cpu().numpy()

        # cache result
        self.cache.set(
            self.model_name,
            "embedding",
            embedding,
            model_version=self.model_version,
            text=text,
        )

        return embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Not supported for causal language models.

        Raises
        ------
        NotImplementedError
            Always raised, as causal LMs don't support NLI directly.
        """
        raise NotImplementedError(
            f"NLI is not supported for causal language model {self.model_name}. "
            "Use HuggingFaceNLI adapter with an NLI-trained model instead."
        )


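# Minimal usage sketch for the causal adapter (hypothetical helper, not part of
# the module API). Assumes a local ".cache" directory and the small "gpt2"
# checkpoint; the sentence pair below is only an example of a minimal pair.
def _example_score_minimal_pairs() -> list[bool]:
    """Return, per pair, whether the first sentence scores higher (illustrative)."""
    from pathlib import Path

    cache = ModelOutputCache(cache_dir=Path(".cache"))
    model = HuggingFaceLanguageModel("gpt2", cache, device="cpu")
    pairs = [
        (
            "The keys to the cabinet are on the table.",
            "The keys to the cabinet is on the table.",
        ),
    ]
    # one batched call scores both members of every pair
    flat = [sentence for pair in pairs for sentence in pair]
    scores = model.compute_log_probability_batch(flat)
    return [scores[2 * i] > scores[2 * i + 1] for i in range(len(pairs))]

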
class HuggingFaceMaskedLanguageModel(HuggingFaceAdapterMixin, ModelAdapter):
    """Adapter for HuggingFace masked language models.

    Supports models like BERT, RoBERTa, ALBERT, and other masked language
    models (MLMs).

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier (e.g., "bert-base-uncased").
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : {"cpu", "cuda", "mps"}
        Device to run model on. Falls back to CPU if device unavailable.
    model_version : str
        Version string for cache tracking.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> model = HuggingFaceMaskedLanguageModel("bert-base-uncased", cache)
    >>> log_prob = model.compute_log_probability("The cat sat on the mat.")
    >>> embedding = model.get_embedding("The cat sat on the mat.")
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: DeviceType = "cpu",
        model_version: str = "unknown",
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = self._validate_device(device)
        self._model: PreTrainedModel | None = None
        self._tokenizer: PreTrainedTokenizerBase | None = None

    def _load_model(self) -> None:
        """Load model and tokenizer lazily on first use."""
        if self._model is None:
            logger.info(f"Loading masked LM: {self.model_name}")
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_name)
            self._model.to(self.device)
            self._model.eval()

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    @property
    def model(self) -> PreTrainedModel:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Get the tokenizer, loading if necessary."""
        self._load_model()
        assert self._tokenizer is not None
        return self._tokenizer

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text using pseudo-log-likelihood.

        For MLMs, we use pseudo-log-likelihood: mask each token one at a time
        and sum the log probabilities of predicting each token.

        This is computationally expensive - caching is critical.

        Parameters
        ----------
        text : str
            Text to compute log probability for.

        Returns
        -------
        float
            Pseudo-log-probability of the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "log_probability", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].to(self.device)

        # compute pseudo-log-likelihood by masking each token
        total_log_prob = 0.0
        num_tokens = input_ids.size(1)

        with torch.no_grad():
            for i in range(num_tokens):
                # skip special tokens
                if input_ids[0, i] in [
                    self.tokenizer.cls_token_id,
                    self.tokenizer.sep_token_id,
                    self.tokenizer.pad_token_id,
                ]:
                    continue

                # create masked version
                masked_input = input_ids.clone()
                original_token = masked_input[0, i].item()
                masked_input[0, i] = self.tokenizer.mask_token_id

                # get prediction
                outputs = self.model(masked_input)
                logits = outputs.logits[0, i]  # logits for masked position

                # compute log probability of original token
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                total_log_prob += log_probs[original_token].item()

        # cache result
        self.cache.set(
            self.model_name,
            "log_probability",
            total_log_prob,
            model_version=self.model_version,
            text=text,
        )

        return total_log_prob

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity based on pseudo-log-likelihood.

        Parameters
        ----------
        text : str
            Text to compute perplexity for.

        Returns
        -------
        float
            Perplexity of the text (positive value).
        """
        # check cache
        cached = self.cache.get(self.model_name, "perplexity", text=text)
        if cached is not None:
            return cached

        # get log probability
        log_prob = self.compute_log_probability(text)

        # count non-special tokens
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"]
        num_tokens = sum(
            1
            for token_id in input_ids[0].tolist()
            if token_id
            not in [
                self.tokenizer.cls_token_id,
                self.tokenizer.sep_token_id,
                self.tokenizer.pad_token_id,
            ]
        )

        # perplexity is exp(-log_prob / num_tokens)
        perplexity = np.exp(-log_prob / max(num_tokens, 1))

        # cache result
        self.cache.set(
            self.model_name,
            "perplexity",
            perplexity,
            model_version=self.model_version,
            text=text,
        )

        return float(perplexity)

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses the [CLS] token embedding from the last layer.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get hidden states
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            # use [CLS] token from last layer
            hidden_states = outputs.hidden_states[-1]
            cls_embedding = hidden_states[0, 0].cpu().numpy()

        # cache result
        self.cache.set(
            self.model_name,
            "embedding",
            cls_embedding,
            model_version=self.model_version,
            text=text,
        )

        return cls_embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Not supported for masked language models.

        Raises
        ------
        NotImplementedError
            Always raised, as MLMs don't support NLI directly.
        """
        raise NotImplementedError(
            f"NLI is not supported for masked language model {self.model_name}. "
            "Use HuggingFaceNLI adapter with an NLI-trained model instead."
        )


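# Illustrative sketch (hypothetical helper, not part of the module API): both
# adapters expose compute_log_probability, but the causal adapter returns a
# true autoregressive log probability while the masked adapter returns a
# pseudo-log-likelihood, so the two scores are comparable within a model
# family rather than across them. Assumes the "gpt2" and "bert-base-uncased"
# checkpoints and a local ".cache" directory.
def _example_compare_causal_and_masked(text: str) -> dict[str, float]:
    """Return causal and masked scores for the same text (illustrative only)."""
    from pathlib import Path

    cache = ModelOutputCache(cache_dir=Path(".cache"))
    causal = HuggingFaceLanguageModel("gpt2", cache, device="cpu")
    masked = HuggingFaceMaskedLanguageModel("bert-base-uncased", cache, device="cpu")
    return {
        "causal_log_prob": causal.compute_log_probability(text),
        "masked_pseudo_log_prob": masked.compute_log_probability(text),
    }

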
class HuggingFaceNLI(HuggingFaceAdapterMixin, ModelAdapter):
    """Adapter for HuggingFace NLI models.

    Supports NLI models trained on MNLI and similar datasets
    (e.g., "roberta-large-mnli", "microsoft/deberta-base-mnli").

    Parameters
    ----------
    model_name : str
        HuggingFace model identifier for NLI model.
    cache : ModelOutputCache
        Cache instance for storing model outputs.
    device : {"cpu", "cuda", "mps"}
        Device to run model on. Falls back to CPU if device unavailable.
    model_version : str
        Version string for cache tracking.

    Examples
    --------
    >>> from pathlib import Path
    >>> from bead.items.cache import ModelOutputCache
    >>> cache = ModelOutputCache(cache_dir=Path(".cache"))
    >>> nli = HuggingFaceNLI("roberta-large-mnli", cache, device="cpu")
    >>> scores = nli.compute_nli(
    ...     premise="Mary loves reading books.",
    ...     hypothesis="Mary enjoys literature."
    ... )
    >>> label = nli.get_nli_label(
    ...     premise="Mary loves reading books.",
    ...     hypothesis="Mary enjoys literature."
    ... )
    """

    def __init__(
        self,
        model_name: str,
        cache: ModelOutputCache,
        device: DeviceType = "cpu",
        model_version: str = "unknown",
    ) -> None:
        super().__init__(model_name, cache, model_version)
        self.device = self._validate_device(device)
        self._model: PreTrainedModel | None = None
        self._tokenizer: PreTrainedTokenizerBase | None = None
        self._label_mapping: dict[str, str] = {}

    def _load_model(self) -> None:
        """Load model and tokenizer lazily on first use."""
        if self._model is None:
            logger.info(f"Loading NLI model: {self.model_name}")
            self._model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name
            )
            self._model.to(self.device)
            self._model.eval()

            # Get label mapping from config
            config = AutoConfig.from_pretrained(self.model_name)
            if hasattr(config, "id2label"):
                # Build mapping from model labels to standard labels
                self._label_mapping = self._build_label_mapping(config.id2label)
            else:
                # Default mapping (assume standard order)
                self._label_mapping = {
                    "0": "entailment",
                    "1": "neutral",
                    "2": "contradiction",
                }

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def _build_label_mapping(self, id2label: dict[int, str]) -> dict[str, str]:
        """Build mapping from model label IDs to standard NLI labels.

        Parameters
        ----------
        id2label
            Mapping from label IDs to label strings from model config.

        Returns
        -------
        dict[str, str]
            Mapping from label IDs (as strings) to standard labels.
        """
        mapping: dict[str, str] = {}
        for idx, label in id2label.items():
            # normalize label to lowercase
            normalized = label.lower()
            # map to standard labels
            if "entail" in normalized:
                mapping[str(idx)] = "entailment"
            elif "neutral" in normalized:
                mapping[str(idx)] = "neutral"
            elif "contradict" in normalized:
                mapping[str(idx)] = "contradiction"
            else:
                # keep original if we can't map it
                mapping[str(idx)] = normalized
        return mapping

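    # Worked example of the mapping above: a hypothetical config with
    # id2label = {0: "CONTRADICTION", 1: "NEUTRAL", 2: "ENTAILMENT"} yields
    # {"0": "contradiction", "1": "neutral", "2": "entailment"}, which
    # compute_nli below uses to attach standard names to the softmax outputs
    # by class index.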
    @property
    def model(self) -> PreTrainedModel:
        """Get the model, loading if necessary."""
        self._load_model()
        assert self._model is not None
        return self._model

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Get the tokenizer, loading if necessary."""
        self._load_model()
        assert self._tokenizer is not None
        return self._tokenizer

    def compute_log_probability(self, text: str) -> float:
        """Compute log probability of text.

        Not supported for NLI models.

        Raises
        ------
        NotImplementedError
            Always raised, as NLI models don't provide log probabilities.
        """
        raise NotImplementedError(
            f"Log probability is not supported for NLI model {self.model_name}. "
            "Use HuggingFaceLanguageModel or HuggingFaceMaskedLanguageModel instead."
        )

    def compute_perplexity(self, text: str) -> float:
        """Compute perplexity of text.

        Not supported for NLI models.

        Raises
        ------
        NotImplementedError
            Always raised, as NLI models don't provide perplexity.
        """
        raise NotImplementedError(
            f"Perplexity is not supported for NLI model {self.model_name}. "
            "Use HuggingFaceLanguageModel or HuggingFaceMaskedLanguageModel instead."
        )

    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding vector for text.

        Uses the model's encoder to get embeddings. Note that NLI models
        are typically fine-tuned for classification, so embeddings may not
        be optimal for general similarity tasks.

        Parameters
        ----------
        text : str
            Text to embed.

        Returns
        -------
        np.ndarray
            Embedding vector for the text.
        """
        # check cache
        cached = self.cache.get(self.model_name, "embedding", text=text)
        if cached is not None:
            return cached

        # tokenize
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get hidden states (using base model if available)
        with torch.no_grad():
            # try to access base model for embeddings
            if hasattr(self.model, "roberta"):
                base_model = self.model.roberta
            elif hasattr(self.model, "deberta"):
                base_model = self.model.deberta
            elif hasattr(self.model, "bert"):
                base_model = self.model.bert
            else:
                # fallback: use full model with output_hidden_states
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )
                hidden_states = outputs.hidden_states[-1]
                embedding = hidden_states[0, 0].cpu().numpy()
                self.cache.set(
                    self.model_name,
                    "embedding",
                    embedding,
                    model_version=self.model_version,
                    text=text,
                )
                return embedding

            # use base model
            outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
            # use [CLS] token
            embedding = outputs.last_hidden_state[0, 0].cpu().numpy()

        # cache result
        self.cache.set(
            self.model_name,
            "embedding",
            embedding,
            model_version=self.model_version,
            text=text,
        )

        return embedding

    def compute_nli(self, premise: str, hypothesis: str) -> dict[str, float]:
        """Compute natural language inference scores.

        Parameters
        ----------
        premise : str
            Premise text.
        hypothesis : str
            Hypothesis text.

        Returns
        -------
        dict[str, float]
            Dictionary with keys "entailment", "neutral", "contradiction"
            mapping to probability scores that sum to ~1.0.
        """
        # check cache
        cached = self.cache.get(
            self.model_name, "nli", premise=premise, hypothesis=hypothesis
        )
        if cached is not None:
            return cached

        # tokenize premise-hypothesis pair
        inputs = self.tokenizer(
            premise,
            hypothesis,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # get logits
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[0]

        # convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

        # map to standard labels
        scores: dict[str, float] = {}
        for idx, prob in enumerate(probs):
            label = self._label_mapping.get(str(idx), str(idx))
            scores[label] = float(prob)

        # ensure we have all three standard labels
        for label in ["entailment", "neutral", "contradiction"]:
            if label not in scores:
                scores[label] = 0.0

        # cache result
        self.cache.set(
            self.model_name,
            "nli",
            scores,
            model_version=self.model_version,
            premise=premise,
            hypothesis=hypothesis,
        )

        return scores
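

# Minimal end-of-module usage sketch (hypothetical helper, not part of the
# module API). Assumes the "roberta-large-mnli" checkpoint and a local
# ".cache" directory; it simply picks the highest-probability label from the
# score dictionary returned by compute_nli.
def _example_nli_label(premise: str, hypothesis: str) -> str:
    """Return the argmax NLI label for a premise/hypothesis pair (illustrative)."""
    from pathlib import Path

    cache = ModelOutputCache(cache_dir=Path(".cache"))
    nli = HuggingFaceNLI("roberta-large-mnli", cache, device="cpu")
    scores = nli.compute_nli(premise=premise, hypothesis=hypothesis)
    return max(scores, key=scores.get)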