swarmauri_embedding_mlm 0.6.0.dev154__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
1
+ from typing import List, Union, Any, Literal
2
+ import logging
3
+ from pydantic import PrivateAttr
4
+ import torch
5
+ from torch.utils.data import TensorDataset, DataLoader
6
+ from torch.optim import AdamW
7
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
8
+
9
+ from swarmauri_base.embeddings.EmbeddingBase import EmbeddingBase
10
+ from swarmauri_standard.vectors.Vector import Vector
11
+ from swarmauri_core.ComponentBase import ComponentBase
12
+
13
+
14
@ComponentBase.register_type(EmbeddingBase, "MlmEmbedding")
class MlmEmbedding(EmbeddingBase):
    """
    EmbeddingBase implementation that fine-tunes a Masked Language Model (MLM).

    The tokenizer and masked-LM model named by ``embedding_name`` are loaded in
    ``__init__`` and moved to CUDA when available, otherwise CPU. ``fit`` runs
    one masked-language-modelling pass over the supplied documents.
    ``transform`` and ``infer_vector`` mean-pool the model output over the
    non-padding tokens of each input.

    Fields:
        - embedding_name: Model id or path passed to ``from_pretrained``.
        - batch_size: Mini-batch size used by ``fit``.
        - learning_rate: AdamW learning rate used by ``fit``.
        - masking_ratio: Probability that a non-special token is selected as a
          prediction target. It is also reused as the [MASK]-replacement
          probability in ``_mask_tokens``.
        - randomness_ratio: Probability that a selected token not replaced by
          [MASK] is replaced with a random token id.
        - epochs: Number of completed ``fit`` passes; incremented once per call.
        - add_new_tokens: When True, ``fit`` first adds whitespace-delimited
          words that are missing from the tokenizer vocabulary.
    """

    embedding_name: str = "bert-base-uncased"
    batch_size: int = 32
    learning_rate: float = 5e-5
    masking_ratio: float = 0.15
    randomness_ratio: float = 0.10
    epochs: int = 0
    add_new_tokens: bool = False
    _tokenizer = PrivateAttr()
    _model = PrivateAttr()
    _device = PrivateAttr()
    _mask_token_id = PrivateAttr()
    type: Literal["MlmEmbedding"] = "MlmEmbedding"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._tokenizer = AutoTokenizer.from_pretrained(self.embedding_name)
        self._model = AutoModelForMaskedLM.from_pretrained(self.embedding_name)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)
        self._refresh_mask_token_id()

    def _refresh_mask_token_id(self) -> None:
        """Cache the id of the current tokenizer's mask token (used by ``_mask_tokens``)."""
        self._mask_token_id = self._tokenizer.convert_tokens_to_ids(
            [self._tokenizer.mask_token]
        )[0]

    def extract_features(self) -> List[str]:
        """
        Extracts the tokens from the vocabulary of the fine-tuned MLM.

        Returns:
            - List[str]: One token string per id in ``range(len(tokenizer))``,
              so tokens added via ``fit(add_new_tokens=True)`` are included.
        """
        vocab_size = len(self._tokenizer)
        return [self._tokenizer.convert_ids_to_tokens(i) for i in range(vocab_size)]

    def _mask_tokens(self, encodings):
        """
        Build masked inputs and MLM labels from a tokenizer batch.

        Positions are selected with probability ``masking_ratio``. Special
        tokens (including padding) are never selected. Labels keep the
        original ids at selected positions and are -100 elsewhere, which the
        model's loss ignores.

        Returns:
            - (input_ids, attention_mask, labels), all on ``self._device``.
        """
        input_ids = encodings.input_ids.to(self._device)
        attention_mask = encodings.attention_mask.to(self._device)

        # Clone before input_ids is modified in place below.
        labels = input_ids.detach().clone()

        probability_matrix = torch.full(
            labels.shape, self.masking_ratio, device=self._device
        )
        special_tokens_mask = [
            self._tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True
            )
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool, device=self._device),
            value=0.0,
        )
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Only selected positions contribute to the loss.
        labels[~masked_indices] = -100

        # NOTE(review): the replacement probability reuses masking_ratio, so by
        # default only ~15% of selected tokens become [MASK]. The common BERT
        # recipe uses 80%. Behaviour is kept as-is; confirm whether this is
        # intended.
        indices_replaced = (
            torch.bernoulli(
                torch.full(labels.shape, self.masking_ratio, device=self._device)
            ).bool()
            & masked_indices
        )
        input_ids[indices_replaced] = self._mask_token_id

        # Of the selected-but-not-[MASK] positions, replace some with random ids;
        # the rest keep their original token.
        indices_random = (
            torch.bernoulli(
                torch.full(labels.shape, self.randomness_ratio, device=self._device)
            ).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self._tokenizer), labels.shape, dtype=torch.long, device=self._device
        )
        input_ids[indices_random] = random_words[indices_random]

        return input_ids, attention_mask, labels

    def fit(self, documents: List[Union[str, Any]], **kwargs) -> None:
        """
        Run one masked-language-modelling pass over ``documents``.

        Parameters:
            - documents: Texts to fine-tune on. An empty list is a no-op.
            - **kwargs: Accepted so ``fit_transform`` can forward its keyword
              arguments; currently unused.
        """
        if not documents:
            return

        logger = logging.getLogger(__name__)

        if self.add_new_tokens:
            new_tokens = self.find_new_tokens(documents)
            if new_tokens:
                num_added_toks = self._tokenizer.add_tokens(new_tokens)
                if num_added_toks > 0:
                    logger.info("Added %d new tokens.", num_added_toks)
                    # Grow the embedding matrix so the new ids are valid inputs.
                    self._model.resize_token_embeddings(len(self._tokenizer))

        encodings = self._tokenizer(
            documents,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        input_ids, attention_mask, labels = self._mask_tokens(encodings)
        optimizer = AdamW(self._model.parameters(), lr=self.learning_rate)
        dataset = TensorDataset(input_ids, attention_mask, labels)
        data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._model.train()
        total_loss = 0.0
        num_batches = 0
        for batch_input_ids, batch_attention_mask, batch_labels in data_loader:
            # A batch with no selected positions has no training signal; its
            # mean cross-entropy would be NaN (0 / 0), so skip it.
            if not bool((batch_labels != -100).any()):
                continue
            outputs = self._model(
                input_ids=batch_input_ids.to(self._device),
                attention_mask=batch_attention_mask.to(self._device),
                labels=batch_labels.to(self._device),
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # One call to fit is one pass over the data: count it once, after the loop.
        self.epochs += 1
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        logger.info("Epoch %d complete. Loss %s", self.epochs, mean_loss)

    def find_new_tokens(self, documents):
        """
        Return whitespace-delimited words from ``documents`` absent from the vocabulary.

        Returns:
            - A sorted list of new words (sorted so added token ids are
              reproducible across runs), or None when there are none.
        """
        unique_words = {word for doc in documents for word in doc.split()}
        existing_vocab = set(self._tokenizer.get_vocab().keys())
        new_tokens = sorted(unique_words - existing_vocab)
        return new_tokens if new_tokens else None

    def _embed(self, data) -> torch.Tensor:
        """
        Tokenize ``data`` and mean-pool the model output over non-padding tokens.

        Uses ``last_hidden_state`` when the model output provides it; otherwise
        the MLM ``logits``. Returns one pooled row per input.
        """
        self._model.eval()
        inputs = self._tokenizer(
            data, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        inputs = {k: v.to(self._device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._model(**inputs)

        if hasattr(outputs, "last_hidden_state"):
            token_states = outputs.last_hidden_state
        else:
            token_states = outputs["logits"]

        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            return token_states.mean(1)
        # Weight by the attention mask so padding added for batched input does
        # not dilute shorter sequences' embeddings.
        weights = attention_mask.unsqueeze(-1).to(token_states.dtype)
        return (token_states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    def transform(self, documents: List[Union[str, Any]]) -> List[Vector]:
        """
        Generates embeddings for a list of documents using the fine-tuned MLM.

        Each document is encoded separately and returned as one Vector.
        """
        embedding_list = []
        for document in documents:
            embedding = self._embed(document).cpu().numpy()
            embedding_list.append(Vector(value=embedding.squeeze().tolist()))
        return embedding_list

    def fit_transform(self, documents: List[Union[str, Any]], **kwargs) -> List[Vector]:
        """
        Fine-tunes the MLM and generates embeddings for the provided documents.
        """
        self.fit(documents, **kwargs)
        return self.transform(documents)

    def infer_vector(self, data: Union[str, Any], *args, **kwargs) -> Vector:
        """
        Generates an embedding for the input data.

        Parameters:
            - data (Union[str, Any]): The input text. It can be a single string
              or a batch of strings; padding in a batch is excluded from pooling.
        """
        # Move to CPU for compatibility with downstream consumers.
        embedding = self._embed(data).cpu().numpy()
        return Vector(value=embedding.squeeze().tolist())

    def save_model(self, path: str) -> None:
        """
        Saves the model and tokenizer to the specified directory.
        """
        self._model.save_pretrained(path)
        self._tokenizer.save_pretrained(path)

    def load_model(self, path: str) -> None:
        """
        Loads the model and tokenizer from the specified directory.

        The model is moved to the current device, and the cached mask-token id
        is refreshed to match the loaded tokenizer.
        """
        self._model = AutoModelForMaskedLM.from_pretrained(path)
        self._tokenizer = AutoTokenizer.from_pretrained(path)
        self._model.to(self._device)
        self._refresh_mask_token_id()
@@ -0,0 +1,12 @@
1
from importlib.metadata import PackageNotFoundError, version as _dist_version

from .MlmEmbedding import MlmEmbedding

__all__ = ["MlmEmbedding"]

try:
    # Read the installed distribution's version so __version__ cannot drift
    # from the packaging metadata.
    __version__ = _dist_version("swarmauri_embedding_mlm")
except PackageNotFoundError:
    # Not installed (e.g. imported from a source checkout): fall back to the
    # version this file was released with.
    __version__ = "0.6.0.dev154"

__long_desc__ = """

# Swarmauri Mlm Embedding Plugin

Visit us at: https://swarmauri.com
Follow us at: https://github.com/swarmauri
Star us at: https://github.com/swarmauri/swarmauri-sdk

"""
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.3
2
+ Name: swarmauri_embedding_mlm
3
+ Version: 0.6.0.dev154
4
+ Summary: Swarmauri embedding plugin that fine-tunes a Masked Language Model (MLM)
5
+ License: Apache-2.0
6
+ Author: Jacob Stewart
7
+ Author-email: jacob@swarmauri.com
8
+ Requires-Python: >=3.10,<3.13
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Requires-Dist: swarmauri_base (>=0.6.0.dev154,<0.7.0)
15
+ Requires-Dist: swarmauri_core (>=0.6.0.dev154,<0.7.0)
16
+ Project-URL: Repository, http://github.com/swarmauri/swarmauri-sdk
17
+ Description-Content-Type: text/markdown
18
+
19
+ # Swarmauri MLM Embedding Plugin
@@ -0,0 +1,6 @@
1
+ swarmauri_embedding_mlm/__init__.py,sha256=hBjPhsN8xkYyoky86ZepjcoZ0J9gIdECszdXHtynz48,273
2
+ swarmauri_embedding_mlm/MlmEmbedding.py,sha256=ralHd6Q5I4DaAjuatgNLn-FiHIEsz32ggp07SL5FVdI,8910
3
+ swarmauri_embedding_mlm-0.6.0.dev154.dist-info/entry_points.txt,sha256=jOovi3L_GTOA6c_-AgE5RrBYoZh-cfmnu_A02o1yQV4,87
4
+ swarmauri_embedding_mlm-0.6.0.dev154.dist-info/METADATA,sha256=nXTQ9ebnBuuPlcx7KNHkZG9T0-8O_MbueNKVl5xOxDI,733
5
+ swarmauri_embedding_mlm-0.6.0.dev154.dist-info/WHEEL,sha256=IYZQI976HJqqOpQU6PHkJ8fb3tMNBFjg-Cn-pwAbaFM,88
6
+ swarmauri_embedding_mlm-0.6.0.dev154.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: poetry-core 2.0.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,3 @@
1
+ [swarmauri.embeddings]
2
+ MlmEmbedding=swarmauri_embedding_mlm.MlmEmbedding:MlmEmbedding
3
+