swarmauri_embedding_mlm 0.6.0.dev154__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarmauri_embedding_mlm/MlmEmbedding.py +229 -0
- swarmauri_embedding_mlm/__init__.py +12 -0
- swarmauri_embedding_mlm-0.6.0.dev154.dist-info/METADATA +19 -0
- swarmauri_embedding_mlm-0.6.0.dev154.dist-info/RECORD +6 -0
- swarmauri_embedding_mlm-0.6.0.dev154.dist-info/WHEEL +4 -0
- swarmauri_embedding_mlm-0.6.0.dev154.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
from typing import List, Union, Any, Literal
|
|
2
|
+
import logging
|
|
3
|
+
from pydantic import PrivateAttr
|
|
4
|
+
import torch
|
|
5
|
+
from torch.utils.data import TensorDataset, DataLoader
|
|
6
|
+
from torch.optim import AdamW
|
|
7
|
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
|
8
|
+
|
|
9
|
+
from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
|
|
10
|
+
from swarmauri_standard.vectors.Vector import Vector
|
|
11
|
+
from swarmauri_core.ComponentBase import ComponentBase
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@ComponentBase.register_type(EmbeddingBase, "MlmEmbedding")
class MlmEmbedding(EmbeddingBase):
    """
    EmbeddingBase implementation that fine-tunes a Masked Language Model (MLM).

    On construction, the tokenizer and masked-LM model named by
    ``embedding_name`` are loaded with ``from_pretrained``. The model is moved
    to CUDA when it is available, otherwise to CPU.

    ``fit`` runs one pass of masked-language-model training over the given
    documents. ``transform`` / ``infer_vector`` run the model without
    gradients and mean-pool its per-token output into a :class:`Vector`.
    """

    # Model/tokenizer identifier passed to ``from_pretrained``.
    embedding_name: str = "bert-base-uncased"
    # Mini-batch size used by ``fit``.
    batch_size: int = 32
    # AdamW learning rate used by ``fit``.
    learning_rate: float = 5e-5
    # Probability that a non-special token is selected as an MLM target
    # (``_mask_tokens`` also reuses it as the [MASK]-replacement probability).
    masking_ratio: float = 0.15
    # Probability that a selected token not replaced by [MASK] gets a random id.
    randomness_ratio: float = 0.10
    # Counter of completed training passes made by ``fit`` (it is incremented,
    # not read as a number of epochs to run).
    epochs: int = 0
    # When True, ``fit`` adds whitespace-split words missing from the tokenizer
    # vocabulary and resizes the model's token embeddings to match.
    add_new_tokens: bool = False
    _tokenizer = PrivateAttr()
    _model = PrivateAttr()
    _device = PrivateAttr()
    _mask_token_id = PrivateAttr()
    type: Literal["MlmEmbedding"] = "MlmEmbedding"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._tokenizer = AutoTokenizer.from_pretrained(self.embedding_name)
        self._model = AutoModelForMaskedLM.from_pretrained(self.embedding_name)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        self._refresh_mask_token_id()

    def _refresh_mask_token_id(self) -> None:
        """Cache the id of the current tokenizer's mask token."""
        self._mask_token_id = self._tokenizer.convert_tokens_to_ids(
            [self._tokenizer.mask_token]
        )[0]

    def extract_features(self) -> List[str]:
        """
        Extracts the tokens from the vocabulary of the fine-tuned MLM.

        Returns:
            - List[str]: The token string for every id in ``range(len(tokenizer))``,
              including any tokens added by ``fit``.
        """
        # One batched lookup instead of one call per id.
        return self._tokenizer.convert_ids_to_tokens(list(range(len(self._tokenizer))))

    def _mask_tokens(self, encodings):
        """
        Build MLM training tensors from a tokenizer batch.

        Parameters:
            - encodings: Tokenizer output exposing ``input_ids`` and
              ``attention_mask`` tensors.

        Returns:
            - ``(input_ids, attention_mask, labels)``, all on ``self._device``.
              ``labels`` keeps the original ids at selected positions and is
              ``-100`` elsewhere. ``input_ids`` has some selected positions
              replaced by the mask token or by a random vocabulary id.
        """
        input_ids = encodings.input_ids.to(self._device)
        attention_mask = encodings.attention_mask.to(self._device)

        labels = input_ids.detach().clone()

        # Select MLM targets with probability ``masking_ratio``. Positions the
        # tokenizer reports as special tokens are never selected.
        probability_matrix = torch.full(
            labels.shape, self.masking_ratio, device=self._device
        )
        special_tokens_mask = [
            self._tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True
            )
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool, device=self._device),
            value=0.0,
        )
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Unselected positions get -100, which the transformers MLM loss is
        # expected to ignore, so only selected positions are scored.
        labels[~masked_indices] = -100

        # NOTE(review): the [MASK]-replacement probability reuses
        # ``masking_ratio`` (0.15 by default). The usual BERT recipe replaces
        # 80% of selected tokens with [MASK] and 10% with random ids.
        # Confirm this is intended.
        indices_replaced = (
            torch.bernoulli(
                torch.full(labels.shape, self.masking_ratio, device=self._device)
            ).bool()
            & masked_indices
        )
        input_ids[indices_replaced] = self._mask_token_id

        # Of the selected tokens not replaced by [MASK], swap some for random ids.
        indices_random = (
            torch.bernoulli(
                torch.full(labels.shape, self.randomness_ratio, device=self._device)
            ).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self._tokenizer), labels.shape, dtype=torch.long, device=self._device
        )
        input_ids[indices_random] = random_words[indices_random]

        return input_ids, attention_mask, labels

    def fit(self, documents: List[Union[str, Any]], **kwargs) -> None:
        """
        Fine-tune the MLM with one pass (one epoch) over ``documents``.

        The masking pattern is drawn once per call, before the pass starts.

        Parameters:
            - documents: Texts to train on.
            - **kwargs: Accepted and ignored, so that ``fit_transform`` can
              forward its keyword arguments without raising ``TypeError``.
        """
        if not documents:
            # Nothing to train on; skip tokenization and the optimizer entirely.
            return

        if self.add_new_tokens:
            new_tokens = self.find_new_tokens(documents)
            if new_tokens:
                num_added_toks = self._tokenizer.add_tokens(new_tokens)
                if num_added_toks > 0:
                    logging.info("Added %d new tokens.", num_added_toks)
                    # Grow the embedding matrix so the new token ids are valid.
                    self._model.resize_token_embeddings(len(self._tokenizer))

        encodings = self._tokenizer(
            documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        input_ids, attention_mask, labels = self._mask_tokens(encodings)
        optimizer = AdamW(self._model.parameters(), lr=self.learning_rate)
        dataset = TensorDataset(input_ids, attention_mask, labels)
        data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._model.train()
        total_loss = 0.0
        trained_batches = 0
        for batch_input_ids, batch_attention_mask, batch_labels in data_loader:
            if not (batch_labels != -100).any():
                # No selected targets in this batch: the loss would be NaN and
                # stepping on it would corrupt the weights.
                continue
            outputs = self._model(
                input_ids=batch_input_ids.to(self._device),
                attention_mask=batch_attention_mask.to(self._device),
                labels=batch_labels.to(self._device),
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            trained_batches += 1

        # One completed pass over the data is one epoch.
        self.epochs += 1
        if trained_batches:
            logging.info(
                "Epoch %d complete. Mean loss %.4f",
                self.epochs,
                total_loss / trained_batches,
            )
        else:
            logging.info(
                "Epoch %d complete. No masked tokens; model not updated.", self.epochs
            )

    def find_new_tokens(self, documents) -> Union[List[str], None]:
        """
        Return whitespace-split words from ``documents`` that are not in the
        tokenizer vocabulary, or ``None`` when there are none.

        Parameters:
            - documents: Iterable of strings (each must support ``.split()``).
        """
        unique_words = {word for doc in documents for word in doc.split()}
        existing_vocab = set(self._tokenizer.get_vocab().keys())
        new_tokens = list(unique_words - existing_vocab)
        return new_tokens if new_tokens else None

    @staticmethod
    def _masked_mean(token_states, attention_mask):
        """
        Average ``token_states`` over dim 1, counting only positions where
        ``attention_mask`` is non-zero.

        When ``attention_mask`` is None, this is a plain mean over dim 1.
        """
        if attention_mask is None:
            return token_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
        summed = (token_states * mask).sum(dim=1)
        # clamp avoids division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _encode_and_pool(self, data: Union[str, Any]) -> Vector:
        """
        Tokenize ``data``, run the model without gradients, and mean-pool the
        per-token output into a :class:`Vector`.

        NOTE(review): a masked-LM model's output is expected to expose
        ``logits`` rather than ``last_hidden_state``. In that case the vector
        is the mean of vocabulary logits (length = vocab size), not a
        hidden-state embedding. Confirm this is intended.
        """
        inputs = self._tokenizer(
            data, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._model(**inputs)
        if hasattr(outputs, "last_hidden_state"):
            token_states = outputs.last_hidden_state
        else:
            token_states = outputs["logits"]
        # Exclude padding positions so batched inputs do not influence each other.
        pooled = self._masked_mean(token_states, inputs.get("attention_mask"))
        embedding = pooled.cpu().numpy()
        return Vector(value=embedding.squeeze().tolist())

    def transform(self, documents: List[Union[str, Any]]) -> List[Vector]:
        """
        Generates embeddings for a list of documents using the fine-tuned MLM.

        Each document is encoded on its own, giving one :class:`Vector` per
        document in input order.
        """
        self._model.eval()
        return [self._encode_and_pool(document) for document in documents]

    def fit_transform(self, documents: List[Union[str, Any]], **kwargs) -> List[Vector]:
        """
        Fine-tunes the MLM and generates embeddings for the provided documents.
        """
        self.fit(documents, **kwargs)
        return self.transform(documents)

    def infer_vector(self, data: Union[str, Any], *args, **kwargs) -> Vector:
        """
        Generates an embedding for the input data.

        Parameters:
        - data (Union[str, Any]): The input data, expected to be a textual representation.
                                  Could be a single string or a batch of strings.
        """
        self._model.eval()
        return self._encode_and_pool(data)

    def save_model(self, path: str) -> None:
        """
        Saves the model and tokenizer to the specified directory.
        """
        self._model.save_pretrained(path)
        self._tokenizer.save_pretrained(path)

    def load_model(self, path: str) -> None:
        """
        Loads the model and tokenizer from the specified directory.

        Also refreshes the cached mask-token id, because the loaded tokenizer
        may differ from the one used at construction.
        """
        self._model = AutoModelForMaskedLM.from_pretrained(path)
        self._tokenizer = AutoTokenizer.from_pretrained(path)
        self._refresh_mask_token_id()
        self._model.to(self._device)  # Ensure the model is loaded to the correct device
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .MlmEmbedding import MlmEmbedding
|
|
2
|
+
|
|
3
|
+
# Keep in sync with the distribution version recorded in the wheel METADATA.
__version__ = "0.6.0.dev154"
__long_desc__ = """

# Swarmauri Mlm Embedding Plugin

Visit us at: https://swarmauri.com
Follow us at: https://github.com/swarmauri
Star us at: https://github.com/swarmauri/swarmauri-sdk

"""
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: swarmauri_embedding_mlm
|
|
3
|
+
Version: 0.6.0.dev154
|
|
4
|
+
Summary: example community package
|
|
5
|
+
License: Apache-2.0
|
|
6
|
+
Author: Jacob Stewart
|
|
7
|
+
Author-email: jacob@swarmauri.com
|
|
8
|
+
Requires-Python: >=3.10,<3.13
|
|
9
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Requires-Dist: swarmauri_base (>=0.6.0.dev154,<0.7.0)
|
|
15
|
+
Requires-Dist: swarmauri_core (>=0.6.0.dev154,<0.7.0)
|
|
16
|
+
Project-URL: Repository, http://github.com/swarmauri/swarmauri-sdk
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
|
|
19
|
+
# Swarmauri Example Community Package
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
swarmauri_embedding_mlm/__init__.py,sha256=hBjPhsN8xkYyoky86ZepjcoZ0J9gIdECszdXHtynz48,273
|
|
2
|
+
swarmauri_embedding_mlm/MlmEmbedding.py,sha256=ralHd6Q5I4DaAjuatgNLn-FiHIEsz32ggp07SL5FVdI,8910
|
|
3
|
+
swarmauri_embedding_mlm-0.6.0.dev154.dist-info/entry_points.txt,sha256=jOovi3L_GTOA6c_-AgE5RrBYoZh-cfmnu_A02o1yQV4,87
|
|
4
|
+
swarmauri_embedding_mlm-0.6.0.dev154.dist-info/METADATA,sha256=nXTQ9ebnBuuPlcx7KNHkZG9T0-8O_MbueNKVl5xOxDI,733
|
|
5
|
+
swarmauri_embedding_mlm-0.6.0.dev154.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
|
|
6
|
+
swarmauri_embedding_mlm-0.6.0.dev154.dist-info/RECORD,,
|