robo-lib 0.0.10__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- robo_lib/__init__.py +2 -3
- robo_lib/components.py +204 -192
- {robo_lib-0.0.10.dist-info → robo_lib-1.0.0.dist-info}/METADATA +2 -2
- robo_lib-1.0.0.dist-info/RECORD +6 -0
- {robo_lib-0.0.10.dist-info → robo_lib-1.0.0.dist-info}/WHEEL +1 -1
- robo_lib-0.0.10.dist-info/RECORD +0 -6
- {robo_lib-0.0.10.dist-info → robo_lib-1.0.0.dist-info}/licenses/LICENSE +0 -0
robo_lib/__init__.py
CHANGED
@@ -1,8 +1,7 @@
 from .components import TokenizerConstructor as TokenizerConstructor
 from .components import create_mask as create_mask
-from .components import
-from .components import
-from .components import scan_max_block_size as scan_max_block_size
+from .components import pre_process_data as pre_process_data
+from .components import safe_stack as safe_stack
 from .components import DataProcessor as DataProcessor
 from .components import get_valid_samples as get_valid_samples
 from .components import get_batch as get_batch
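The package's public surface changes accordingly: the `scan_max_block_size` re-export is dropped and `pre_process_data` and `safe_stack` are added. A minimal sketch of the imports that work against 1.0.0, using only the names visible in the hunk above:

```python
# Top-level names re-exported by robo_lib/__init__.py in 1.0.0 (per the hunk above).
from robo_lib import (
    TokenizerConstructor,
    create_mask,
    pre_process_data,   # new in 1.0.0
    safe_stack,         # new in 1.0.0
    DataProcessor,
    get_valid_samples,
    get_batch,
)
```

Code that imported `scan_max_block_size` from the package root will need updating.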
robo_lib/components.py
CHANGED
@@ -6,6 +6,8 @@ import numpy as np
 import random
 import pickle
 import itertools
+from pathlib import Path
+import os
 
 class TokenizerConstructor:
     '''
@@ -30,6 +32,7 @@ class TokenizerConstructor:
                  tokenizer_type:str="BPE",
                  pre_tokenizers:list[str]|str=["Whitespace"],
                  normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
+                 vocab:dict[str,int] = {},
                  special_tokens:list[str]|str=[],
                  unknown_token_string:str="<unk>",
                  start_token_string:str="<sos>",
@@ -42,25 +45,28 @@ class TokenizerConstructor:
 
         if isinstance(special_tokens, str):
             special_tokens = [special_tokens]
-        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token
-        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string
-        self.start_token = self.special_tokens.index(start_token_string) if start_token_string
-        self.end_token = self.special_tokens.index(end_token_string) if end_token_string
-        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string
-        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string
+        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
+        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
+        self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
+        self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
+        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
+        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None
 
         if tokenizer_type == "BPE":
             self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordLevel":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordPiece":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "Unigram":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
             self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
+
+        if self.pad_token is not None:
+            self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)
 
         if isinstance(pre_tokenizers, str):
             pre_tokenizers = [pre_tokenizers]
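Two behavioural changes are visible here: WordLevel and WordPiece models can now be seeded with a fixed `vocab`, and padding is enabled on the underlying `tokenizers.Tokenizer` whenever a pad token is configured. A hedged construction sketch; the vocabulary and token strings are illustrative, and the keyword names mirror the call made inside `DataProcessor` later in this diff:

```python
from robo_lib import TokenizerConstructor

# Illustrative fixed vocabulary; any {token: id} mapping works for a WordLevel model.
vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}

tok = TokenizerConstructor(
    min_frequency=1,
    tokenizer_type="WordLevel",
    vocab=vocab,                       # new in 1.0.0
    special_tokens=["<pad>", "<unk>"],
    unknown_token_string="<unk>",
    start_token_string=None,
    end_token_string=None,
    pad_token_string="<pad>",
)
# Because a pad token is configured, the constructor calls enable_padding()
# on the wrapped tokenizers.Tokenizer, so batch encodes come back equal length.
```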
@@ -122,6 +128,13 @@ class TokenizerConstructor:
     def encode(self, inp:str) -> list[int]:
         return self.tokenizer_type.encode(inp).ids
 
+    def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
+        if max_length is not None:
+            self.tokenizer_type.enable_truncation(max_length=max_length)
+        out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
+        self.tokenizer_type.no_truncation()
+        return out
+
     def decode(self, inp:list[int]) -> str:
         return self.tokenizer_type.decode(inp)
 
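A short sketch of the new `encode_batch` method; `tok` is assumed to be an already-configured `TokenizerConstructor`, and the sentences and `max_length` value are illustrative:

```python
# Batch-encode with per-call truncation; truncation is switched off again afterwards,
# so subsequent encode()/encode_batch() calls are unaffected.
ids = tok.encode_batch(
    ["first example sentence", "a second, slightly longer example sentence"],
    max_length=16,
)
# ids is a list of token-id lists, one per input string.
```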
@@ -136,38 +149,35 @@ def create_mask(row:list, block_size:int) -> list[bool]:
     mask = [1]*len(row) + [0]*(block_size - len(row))
     return mask
 
-def
-    '''
-
-    returns padded row. Row is padded until length block_size with specified pad_token value
-
-    '''
-    row.extend([pad_token]*(block_size - len(row)))
-    return row
-
-def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
+def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
     '''
 
-    returns
+    returns string row with the tokenizer's start and end tokens if they exist
 
     '''
-
-
-
-
-
-
-
+    if start_token_string is None and end_token_string is None:
+        return data
+    else:
+        for i in range(len(data)):
+            if start_token_string is not None:
+                data[i] = start_token_string + data[i]
+            if end_token_string is not None:
+                data[i] = data[i] + end_token_string
+
+        return data
 
-def
+def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
     '''
 
-
+    torch stack with check to ensure tensors are valid in input list
+
+    returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises error if no valid tensors
 
     '''
-
-
-
+    out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
+    if len(out_list) == 0:
+        raise ValueError("no valid tensors in list.")
+    return torch.stack(out_list)
 
 
 class DataProcessor:
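The old padding and per-row processing helpers are removed in favour of two list-level helpers: `pre_process_data` wraps each string with the given start/end token strings, and `safe_stack` stacks only the valid tensors in a list. A quick hedged sketch with made-up inputs:

```python
import torch
from robo_lib import pre_process_data, safe_stack

rows = ["hello world", "how are you"]
rows = pre_process_data(rows, "<sos>", "<eos>")   # illustrative token strings
# -> ["<sos>hello world<eos>", "<sos>how are you<eos>"]

# Non-tensor entries are dropped before stacking; an all-invalid list raises ValueError.
stacked = safe_stack([torch.zeros(3), None, torch.ones(3)])   # shape (2, 3)
```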
@@ -196,93 +206,55 @@ class DataProcessor:
         self.enc_tokenizer = enc_tokenizer
 
     def process_list(self,
-                     save_path:str,
                      dec_data:list[str]|str,
                      dec_max_block_size:int=None,
                      dec_create_masks:bool=True,
-                     dec_block_size_exceeded_policy:str=None,
                      enc_data:list[str]=None,
                      enc_max_block_size:int=None,
                      enc_create_masks:bool=True,
-
+                     save_path:str = "."
                      ) -> None:
 
         if isinstance(dec_data, str):
             dec_data = [dec_data]
         dec_data_length = len(dec_data)
-        save_path = save_path.replace(".pt", "")
-
-        if dec_max_block_size == None:
-            dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
 
-        if enc_data
-
+        if enc_data is not None:
+            if self.enc_tokenizer is None:
+                self.enc_tokenizer = self.dec_tokenizer
 
             enc_data_length = len(enc_data)
             if dec_data_length != enc_data_length:
-                raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
+                raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        dec_processed_item = dec_processed_item[:dec_max_block_size]
-                    elif dec_block_size_exceeded_policy == "skip":
-                        continue
-                    elif dec_block_size_exceeded_policy == None:
-                        raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
-            if dec_create_masks:
-                dec_mask = create_mask(dec_processed_item, dec_max_block_size)
-            dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
-
-            if enc_data != None:
-                enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
-                if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
-                    if enc_block_size_exceeded_policy == "trim":
-                        enc_processed_item = enc_processed_item[:enc_max_block_size]
-                    elif enc_block_size_exceeded_policy == "skip":
-                        continue
-                    elif enc_block_size_exceeded_policy == None:
-                        raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
-                if enc_create_masks:
-                    enc_mask = create_mask(enc_processed_item, enc_max_block_size)
-                enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
-
-            dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
-            if dec_create_masks:
-                dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
-
-            if enc_data != None:
-                enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
-                if enc_create_masks:
-                    enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
-
-        dec_out_list = torch.stack([row for row in dec_out_list if row != []])
-        torch.save(dec_out_list, save_path + "_decoder_data.pt")
+        print("processing data")
+        dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
+        if dec_create_masks:
+            mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+            dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
+
+        if enc_data is not None:
+            enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
+            if enc_create_masks:
+                mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+                enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])
+
+        dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
+        Path(save_path).mkdir(parents=True, exist_ok=True)
+        torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
         if dec_create_masks:
-            dec_mask_list = torch.
-            torch.save(dec_mask_list, save_path
-        if enc_data
-            enc_out_list = torch.
-            torch.save(enc_out_list, save_path
+            dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
+            torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
+        if enc_data is not None:
+            enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
+            torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
         if enc_create_masks:
-            enc_mask_list = torch.
-            torch.save(enc_mask_list, save_path
+            enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
+            torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))
 
 
-def get_valid_samples(random_samples:torch.
-                      masks:torch.
+def get_valid_samples(random_samples:torch.Tensor,
+                      masks:torch.Tensor,
                       block_size:int
                       ) -> list[int]:
     '''
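`process_list` now tokenizes whole lists via `encode_batch`, truncates over-long rows instead of honouring the removed `*_block_size_exceeded_policy` arguments, and writes fixed file names (`decoder_data.pt`, `decoder_mask_data.pt`, `encoder_data.pt`, `encoder_mask_data.pt`) into a `save_path` directory that it creates if needed. A hedged sketch; `dec_tok`/`enc_tok` and the data lists are illustrative, and the `DataProcessor` constructor keywords are assumed since only `process_list` appears in this hunk:

```python
from robo_lib import DataProcessor

# Assumption: DataProcessor is built from the decoder (and optional encoder) tokenizers.
proc = DataProcessor(dec_tokenizer=dec_tok, enc_tokenizer=enc_tok)

proc.process_list(
    dec_data=target_sentences,     # illustrative list[str]
    enc_data=source_sentences,     # illustrative list[str]
    dec_max_block_size=128,
    enc_max_block_size=128,
    save_path="data/train",        # a directory, created if missing
)
# -> data/train/decoder_data.pt, decoder_mask_data.pt, encoder_data.pt, encoder_mask_data.pt
```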
@@ -294,9 +266,9 @@ def get_valid_samples(random_samples:torch.tensor,
     valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
     return valid_samples
 
-def get_batch(data:torch.
-              random_samples:torch.
-              masks:torch.
+def get_batch(data:torch.Tensor,
+              random_samples:torch.Tensor,
+              masks:torch.Tensor=None,
               block_size:int=None,
               get_offset:bool=True
               ) -> tuple[torch.tensor]:
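For context on the annotations being fixed here, a small sketch of how `get_batch` is driven (tensors and sizes are illustrative): it returns an input batch, a one-token-offset target batch when `get_offset=True`, and the matching mask slice when masks are supplied.

```python
import torch
from robo_lib import get_batch

data = torch.randint(0, 1000, (500, 64))            # illustrative (rows, block) token matrix
masks = torch.ones_like(data)                        # illustrative attention masks
random_samples = torch.randint(0, data.shape[0], (32,))

batch_in, batch_out, masks_in = get_batch(
    data, random_samples, masks=masks, block_size=32, get_offset=True
)
# batch_out is batch_in shifted by one position: the usual next-token-prediction targets.
```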
@@ -308,53 +280,78 @@ def get_batch(data:torch.tensor,
 
     '''
     batch_size = len(random_samples)
-    if block_size
+    if block_size is not None and block_size != data.shape[1]:
         if block_size >= data.shape[1]:
             raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
 
-        if masks
+        if masks is not None:
             random_point = get_valid_samples(random_samples, masks, block_size)
         else:
             random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
         batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
-        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks
+        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
         batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
     else:
         block_size = data.shape[1]
         batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
-        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks
+        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
         batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
 
     return batch_in, batch_out, masks_in
 
-def top_kp_filter(logits:torch.
-                  top_k:int,
-                  top_p:float=None
-                  ) -> torch.
+def top_kp_filter(logits: torch.Tensor,
+                  top_k: int = None,
+                  top_p: float = None
+                  ) -> torch.Tensor:
     '''
+    Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.
 
-
-
-
-
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
+    Args:
+        logits: (batch_size, vocab_size) tensor of raw logits.
+        top_k: keep only top_k tokens with highest logits.
+        top_p: keep the smallest set of tokens with cumulative probability >= top_p.
 
-
-
-
-
-
+    Returns:
+        selected: tensor of selected token indices (batch_size,)
+    '''
+    logits = logits.clone()  # avoid modifying input logits in-place
+
+    # Apply top-p filtering if specified
+    if top_p is not None:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = F.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
+
+        # Remove tokens with cumulative probability above threshold (except first token)
+        sorted_mask = cumulative_probs > top_p
+        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+        sorted_mask[..., 0] = False
+
+        # Mask tokens to remove by setting logits to -inf
+        indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+        logits[indices_to_remove] = float('-inf')
+
+    # Apply top-k filtering if specified
+    if top_k is not None:
+        top_k = min(top_k, logits.size(-1))  # safety check
+        topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+        topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
+
+        # For each batch, sample from top_k candidates
+        selected = []
+        for i in range(topk_probs.shape[0]):
+            candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
-
-
-
-
-
-
-
-
-    selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)
+    else:
+        # If only top_p is specified, sample from entire filtered logits
+        probs = F.softmax(logits, dim=-1).cpu().numpy()
+        selected = []
+        for i in range(probs.shape[0]):
+            candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
     return selected
 
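The rewritten `top_kp_filter` makes both `top_k` and `top_p` optional, applies nucleus filtering before top-k sampling, and samples one token per batch row. A hedged sketch with random scores (`top_kp_filter` is defined in `robo_lib.components`; it is not among the re-exports shown in the `__init__.py` hunk above):

```python
import torch
from robo_lib.components import top_kp_filter

scores = torch.randn(4, 1000)        # illustrative (batch_size, vocab_size) scores

picked = top_kp_filter(scores, top_k=50)             # sample from the 50 best per row
picked = top_kp_filter(scores, top_k=50, top_p=0.9)  # nucleus filter first, then top-k sample
# picked is a LongTensor of shape (4,), one sampled token id per row.
```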
@@ -387,10 +384,10 @@ class SelfAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.
-                q:torch.
-                v:torch.
-                mask:torch.
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
 
@@ -406,7 +403,7 @@ class SelfAttention(nn.Module):
         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
         if self.triangle_mask and self.block_size >= 0:
             wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
-        if mask
+        if mask is not None:
             wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
         wei = F.softmax(wei, dim=-1)
         wei = self.dropout(wei)
@@ -438,10 +435,10 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.
-                q:torch.
-                v:torch.
-                mask:torch.
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
 
@@ -475,7 +472,7 @@ class FeedForward(nn.Module):
         )
 
     def forward(self,
-                x:torch.
+                x:torch.Tensor
                 ) -> torch.tensor:
         return self.net(x)
 
@@ -500,8 +497,8 @@ class EncoderBlock(nn.Module):
         self.ln2 = nn.LayerNorm(n_embed)
 
     def forward(self,
-                x:torch.
-                mask:torch.
+                x:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask)
         x = self.ln1(att + x)
@@ -541,15 +538,15 @@ class DecoderBlock(nn.Module):
             self.ca = None
 
     def forward(self,
-                x:torch.
-                enc_k:torch.
-                enc_v:torch.
+                x:torch.Tensor,
+                enc_k:torch.Tensor,
+                enc_v:torch.Tensor,
                 mask_out:bool=None,
-                mask_in:torch.
+                mask_in:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask_out)
         x = self.ln1(att + x)
-        if self.ca
+        if self.ca is not None:
             catt = self.ca(enc_k, x, enc_v, mask=mask_in)
             x = self.ln3(catt + x)
         ff = self.ffwd(x)
@@ -615,6 +612,7 @@ class RoboConstructor(nn.Module):
                  enc_vocab_size:int=None,
                  enc_block_size:int=None,
                  enc_expansion_factor:int=4,
+                 enc_positional_encoding:bool=True,
                  dropout:float=0.1,
                  device:str=None
                  ) -> None:
@@ -627,7 +625,7 @@ class RoboConstructor(nn.Module):
         self.dec_expansion_factor = dec_expansion_factor
         self.dropout = dropout
 
-        if device
+        if device is None:
             self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         else:
             self.device = device
@@ -635,6 +633,7 @@ class RoboConstructor(nn.Module):
         self.dec_positional_embedding_table = nn.Embedding(dec_block_size, n_embed)
 
         if enc_n_blocks != 0:
+            self.enc_positional_encoding = enc_positional_encoding
             self.enc_n_blocks = enc_n_blocks
             self.enc_n_head = enc_n_head
             self.enc_expansion_factor = enc_expansion_factor
@@ -642,7 +641,8 @@ class RoboConstructor(nn.Module):
             self.enc_block_size = enc_block_size
             self.cross_attention = True
             self.enc_token_embedding_table = nn.Embedding(enc_vocab_size, n_embed)
-
+            if enc_positional_encoding:
+                self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
             self.encoder_blocks = MySequential(*[EncoderBlock(n_embed, enc_n_head, enc_expansion_factor, dropout=dropout) for _ in range(enc_n_blocks)])
         else:
             self.cross_attention = False
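The encoder can now be built without a positional embedding table via the new `enc_positional_encoding` flag. A hedged construction sketch: the `enc_*` keywords and `n_embed` appear in this diff, while the decoder-side names are assumed by analogy, and all sizes are illustrative:

```python
from robo_lib.components import RoboConstructor

model = RoboConstructor(
    n_embed=256,
    dec_n_blocks=4,          # assumed name, by analogy with enc_n_blocks
    dec_n_head=4,            # assumed name, by analogy with enc_n_head
    dec_vocab_size=8000,     # assumed name
    dec_block_size=128,
    enc_n_blocks=4,
    enc_n_head=4,
    enc_vocab_size=8000,
    enc_block_size=128,
    enc_positional_encoding=False,   # new in 1.0.0: skip the encoder positional table
)
```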
@@ -670,13 +670,13 @@ class RoboConstructor(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
     def forward(self,
-                dec_in:torch.
-                dec_mask:torch.
-                enc_in:torch.
-                enc_mask:torch.
+                dec_in:torch.Tensor,
+                dec_mask:torch.Tensor=None,
+                enc_in:torch.Tensor=None,
+                enc_mask:torch.Tensor=None
                 ) -> torch.tensor:
         _, dec_T = dec_in.shape
-        if enc_in
+        if enc_in is not None:
             _, enc_T = enc_in.shape
 
         dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -685,8 +685,11 @@ class RoboConstructor(nn.Module):
 
         if self.cross_attention:
             enc_tok_emb = self.enc_token_embedding_table(enc_in)
-
-
+            if self.enc_positional_encoding:
+                enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
+                enc_x = enc_tok_emb + enc_pos_emb
+            else:
+                enc_x = enc_tok_emb
 
             enc_out, enc_mask = self.encoder_blocks(enc_x, enc_mask)
         else:
@@ -712,13 +715,13 @@ class RoboConstructor(nn.Module):
 
         dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
         dec_train_batch_in = dec_train_batch_in.to(self.device)
-        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out
-        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in
+        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
+        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None
 
         if self.cross_attention:
             enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
             enc_train_batch_in = enc_train_batch_in.to(self.device)
-            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in
+            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
         else:
             enc_train_batch_in = None
             enc_train_masks_in = None
@@ -730,14 +733,8 @@ class RoboConstructor(nn.Module):
                    max_iters:int,
                    eval_interval:int,
                    batch_size:int,
-
-
-                   dec_training_masks_path:str=None,
-                   dec_eval_masks_path:str=None,
-                   enc_training_path:str=None,
-                   enc_eval_path:str=None,
-                   enc_training_masks_path:str=None,
-                   enc_eval_masks_path:str=None,
+                   training_dir_path:str,
+                   eval_dir_path:str,
                    eval_iters:int=3,
                    learning_rate:float=1e-4,
                    pad_token:int=None,
@@ -746,21 +743,36 @@ class RoboConstructor(nn.Module):
                    label_smoothing:float=0.1
                    ) -> None:
 
+        dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
         dec_training_data = torch.load(dec_training_path, weights_only=True)
-
-
-
-
-
-
-
-
-
+
+        dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
+        dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
+
+        dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
+        dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
+
+        dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
+        dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
+
+        enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
+        enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
+
+        enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
+        enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
+
+        enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
+        enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
+
+        enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
+        enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
+
+        if pad_token is None and dec_tokenizer is not None:
             pad_token = dec_tokenizer.pad_token
 
         self.to(self.device)
 
-        if pad_token
+        if pad_token is not None:
             loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
         else:
             loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
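Training now takes the two directories written by `DataProcessor.process_list` instead of eight individual `*_path` arguments, and loads whichever of the fixed-name `.pt` files exist. A hedged sketch; the training method's name is not visible in these hunks, so `train_robo` is an assumption, as are `model`, `tok`, and the paths:

```python
# model: a RoboConstructor; tok: the decoder TokenizerConstructor (used for the pad token).
# "data/train" and "data/eval" are directories produced by DataProcessor.process_list().
model.train_robo(                      # method name assumed; keyword names are from this diff
    max_iters=10_000,
    eval_interval=500,
    batch_size=32,
    training_dir_path="data/train",
    eval_dir_path="data/eval",
    eval_iters=3,
    learning_rate=1e-4,
    dec_tokenizer=tok,
    save_path="robo_model",            # illustrative; checkpointed via save_component at eval steps
)
```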
@@ -776,7 +788,7 @@ class RoboConstructor(nn.Module):
                 proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
                 losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
             out["train"] = losses.mean()
-            if dec_eval_data
+            if dec_eval_data is not None:
                 for k in range(eval_iters):
                     dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
                     proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -792,7 +804,7 @@ class RoboConstructor(nn.Module):
             if iter % eval_interval == 0 or iter == max_iters-1:
                 losses = estimate_loss()
                 print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
-                if save_path
+                if save_path is not None:
                     save_component(self, save_path=save_path)
 
             dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -819,25 +831,25 @@ class RoboConstructor(nn.Module):
                  top_k:int=None,
                  top_p:float=None
                  ) -> list[int]|str:
-        max_new_tokens = self.dec_block_size if max_new_tokens
+        max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens
 
         if self.cross_attention:
-            if enc_tokenizer
-                if enc_start_token
+            if enc_tokenizer is not None:
+                if enc_start_token is None:
                     enc_start_token = enc_tokenizer.start_token
-                if enc_end_token
+                if enc_end_token is None:
                     enc_end_token = enc_tokenizer.end_token
                 if isinstance(inputs, str):
                     inputs = enc_tokenizer.encode(inputs)
 
-        if dec_tokenizer
-            if dec_start_token
+        if dec_tokenizer is not None:
+            if dec_start_token is None:
                 dec_start_token = dec_tokenizer.start_token
-            if dec_end_token
+            if dec_end_token is None:
                 dec_end_token = dec_tokenizer.end_token
-            if new_line_token
+            if new_line_token is None:
                 new_line_token = dec_tokenizer.new_line_token
-            if self.cross_attention
+            if not self.cross_attention and isinstance(inputs, str):
                 inputs = dec_tokenizer.encode(inputs)
 
 
@@ -846,7 +858,7 @@ class RoboConstructor(nn.Module):
             idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
         else:
             enc_input = None
-            if separator_token
+            if separator_token is not None:
                 idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
             else:
                 idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -860,7 +872,7 @@ class RoboConstructor(nn.Module):
             logits = proj_output[:, -1, :]
             probabilities = F.log_softmax(logits/temperature, dim=-1)
 
-            if top_k
+            if top_k is None and top_p is None:
                 idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
             else:
                 idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -868,10 +880,10 @@ class RoboConstructor(nn.Module):
             if idx_next[0] == dec_end_token:
                 break
 
-        if dec_tokenizer
+        if dec_tokenizer is None:
            return idx[0].tolist()
        else:
-            if new_line_token
+            if new_line_token is not None:
                return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
            else:
                return dec_tokenizer.decode(idx[0].tolist())
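Generation keywords touched by these hunks include `max_new_tokens`, `dec_tokenizer`, `separator_token`, `temperature`, `top_k`, and `top_p`. A hedged sketch; the method name is not shown in the diff, so `generate` is an assumption, and the prompt and settings are illustrative:

```python
# model: a trained RoboConstructor; tok: its decoder TokenizerConstructor.
text = model.generate(                 # method name assumed; keyword names are from this diff
    "once upon a time",
    dec_tokenizer=tok,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
)
print(text)
```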
{robo_lib-0.0.10.dist-info → robo_lib-1.0.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: robo_lib
-Version: 0.0
+Version: 1.0.0
 Summary: A package to create, configure, and train transformer models.
 Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
 Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
robo_lib-1.0.0.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+robo_lib/__init__.py,sha256=NnzWHWwpFcSJD_XRMWKKPQFAIrRBFYiCFN0pgUGPygc,968
+robo_lib/components.py,sha256=M_1M1Y56_W0bSElZlg3M6gRoJJPAnUchTO3N8AdsEV8,43091
+robo_lib-1.0.0.dist-info/METADATA,sha256=GAnmrynDr3-hv9KyCjXlpx5I8v2BLQJCIDXURoGFw2w,9633
+robo_lib-1.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+robo_lib-1.0.0.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
+robo_lib-1.0.0.dist-info/RECORD,,
robo_lib-0.0.10.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
-robo_lib/components.py,sha256=OjusjkSlMlAsTEq1kSqixKXG9sBw8Re8hsXTEy_bJ48,42315
-robo_lib-0.0.10.dist-info/METADATA,sha256=a30lSFG-Eo9UGFQErA64MTbeVqCeD8BwViXMmB2OPX4,9634
-robo_lib-0.0.10.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
-robo_lib-0.0.10.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
-robo_lib-0.0.10.dist-info/RECORD,,
{robo_lib-0.0.10.dist-info → robo_lib-1.0.0.dist-info}/licenses/LICENSE
File without changes