robo-lib 0.0.11__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- robo_lib-1.0.0/.gitignore +1 -0
- {robo_lib-0.0.11 → robo_lib-1.0.0}/PKG-INFO +1 -1
- {robo_lib-0.0.11 → robo_lib-1.0.0}/pyproject.toml +1 -1
- {robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/__init__.py +2 -3
- {robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py +195 -189
- robo_lib-1.0.0/tests/test_data_processor.py +82 -0
- robo_lib-1.0.0/tests/test_functions.py +176 -0
- robo_lib-1.0.0/tests/test_robo_constructor.py +130 -0
- robo_lib-1.0.0/tests/test_tokenizer_constructor.py +89 -0
- {robo_lib-0.0.11 → robo_lib-1.0.0}/LICENSE +0 -0
- {robo_lib-0.0.11 → robo_lib-1.0.0}/README.md +0 -0
- {robo_lib-0.0.11 → robo_lib-1.0.0}/tests/__init__.py +0 -0
robo_lib-1.0.0/.gitignore

@@ -0,0 +1 @@
+__pycache__/
{robo_lib-0.0.11 → robo_lib-1.0.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: robo_lib
-Version: 0.0.11
+Version: 1.0.0
 Summary: A package to create, configure, and train transformer models.
 Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
 Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/__init__.py

@@ -1,8 +1,7 @@
 from .components import TokenizerConstructor as TokenizerConstructor
 from .components import create_mask as create_mask
-from .components import pad as pad
-from .components import process_row as process_row
-from .components import scan_max_block_size as scan_max_block_size
+from .components import pre_process_data as pre_process_data
+from .components import safe_stack as safe_stack
 from .components import DataProcessor as DataProcessor
 from .components import get_valid_samples as get_valid_samples
 from .components import get_batch as get_batch
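Callers upgrading from 0.0.11 lose the `pad`, `process_row`, and `scan_max_block_size` helpers in favour of the two new functions. A minimal migration sketch (assuming robo_lib 1.0.0 as diffed here):

```python
# 0.0.11 (removed): from robo_lib import pad, process_row, scan_max_block_size
# 1.0.0 replacements for row preparation and stacking:
from robo_lib import pre_process_data, safe_stack
```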
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py

@@ -6,6 +6,8 @@ import numpy as np
 import random
 import pickle
 import itertools
+from pathlib import Path
+import os
 
 class TokenizerConstructor:
     '''
@@ -30,6 +32,7 @@ class TokenizerConstructor:
                  tokenizer_type:str="BPE",
                  pre_tokenizers:list[str]|str=["Whitespace"],
                  normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
+                 vocab:dict[str,int] = {},
                  special_tokens:list[str]|str=[],
                  unknown_token_string:str="<unk>",
                  start_token_string:str="<sos>",
@@ -42,25 +45,28 @@ class TokenizerConstructor:
 
         if isinstance(special_tokens, str):
             special_tokens = [special_tokens]
-        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
-        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
-        self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
-        self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
-        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
-        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None
+        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
+        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
+        self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
+        self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
+        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
+        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None
 
         if tokenizer_type == "BPE":
             self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordLevel":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordPiece":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
            self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "Unigram":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
             self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
+
+        if self.pad_token is not None:
+            self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)
 
         if isinstance(pre_tokenizers, str):
             pre_tokenizers = [pre_tokenizers]
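Two behavioural changes land here: `WordLevel` and `WordPiece` models can now be seeded with a preset `vocab` instead of requiring training, and a non-`None` pad token switches on the backend's padding. A hedged sketch of a vocab-seeded tokenizer, mirroring the constructor arguments above (the concrete vocab is illustrative):

```python
from robo_lib import TokenizerConstructor

# WordLevel tokenizer seeded with a fixed vocab; usable without train().
tok = TokenizerConstructor(
    tokenizer_type="WordLevel",
    vocab={"hello": 0, "world": 1, "<pad>": 2, "<unk>": 3},
    special_tokens=["<pad>", "<unk>"],
    unknown_token_string="<unk>",
    pad_token_string="<pad>",       # non-None, so padding is enabled
    start_token_string=None,
    end_token_string=None,
    new_line_token_string=None,
)
print(tok.encode("hello world"))    # [0, 1] with this vocab
```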
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -122,6 +128,13 @@ class TokenizerConstructor:
     def encode(self, inp:str) -> list[int]:
         return self.tokenizer_type.encode(inp).ids
 
+    def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
+        if max_length is not None:
+            self.tokenizer_type.enable_truncation(max_length=max_length)
+        out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
+        self.tokenizer_type.no_truncation()
+        return out
+
     def decode(self, inp:list[int]) -> str:
         return self.tokenizer_type.decode(inp)
 
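`encode_batch` turns truncation on only for the duration of the call and resets it before returning, so a `max_length` passed here does not leak into later `encode` calls. Continuing the sketch above:

```python
# Batch-encode with per-call truncation (tok as constructed earlier).
ids = tok.encode_batch(["hello world", "world"], max_length=3)
assert all(len(row) <= 3 for row in ids)   # each row capped at 3 ids
```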
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -136,38 +149,35 @@ def create_mask(row:list, block_size:int) -> list[bool]:
     mask = [1]*len(row) + [0]*(block_size - len(row))
     return mask
 
-def pad(row:list, block_size:int, pad_token:int) -> list[int]:
-    '''
-
-    returns padded row. Row is padded until length block_size with specified pad_token value
-
-    '''
-    row.extend([pad_token]*(block_size - len(row)))
-    return row
-
-def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
+def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
     '''
 
-    returns tokenized row
+    returns string row with the tokenizer's start and end tokens if they exist
 
     '''
-
-
-
-
-
-
-
+    if start_token_string is None and end_token_string is None:
+        return data
+    else:
+        for i in range(len(data)):
+            if start_token_string is not None:
+                data[i] = start_token_string + data[i]
+            if end_token_string is not None:
+                data[i] = data[i] + end_token_string
+
+        return data
 
-def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
+def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
     '''
 
-
+    torch stack with check to ensure tensors are valid in input list
+
+    returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises error if no valid tensors
 
     '''
-
-
-
+    out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
+    if len(out_list) == 0:
+        raise ValueError("no valid tensors in list.")
+    return torch.stack(out_list)
 
 
 class DataProcessor:
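`pre_process_data` mutates and returns the list of strings with the given start/end token strings attached, and `safe_stack` drops anything that is not a `torch.Tensor` before stacking. A quick sketch matching the bodies above:

```python
import torch
from robo_lib import pre_process_data, safe_stack

rows = pre_process_data(["hello", "world"], "<sos>", "<eos>")
print(rows)              # ['<sos>hello<eos>', '<sos>world<eos>']

stacked = safe_stack([torch.tensor([1, 2]), "skip me", torch.tensor([3, 4])])
print(stacked.shape)     # torch.Size([2, 2])
# safe_stack([None, 123]) would raise ValueError("no valid tensors in list.")
```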
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -196,93 +206,55 @@ class DataProcessor:
         self.enc_tokenizer = enc_tokenizer
 
     def process_list(self,
-                     save_path:str,
                      dec_data:list[str]|str,
                      dec_max_block_size:int=None,
                      dec_create_masks:bool=True,
-                     dec_block_size_exceeded_policy:str=None,
                      enc_data:list[str]=None,
                      enc_max_block_size:int=None,
                      enc_create_masks:bool=True,
-                     enc_block_size_exceeded_policy:str=None
+                     save_path:str = "."
                      ) -> None:
 
         if isinstance(dec_data, str):
             dec_data = [dec_data]
         dec_data_length = len(dec_data)
-        save_path = save_path.replace(".pt", "")
-
-        if dec_max_block_size == None:
-            dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
 
-        if enc_data != None:
-
+        if enc_data is not None:
+            if self.enc_tokenizer is None:
+                self.enc_tokenizer = self.dec_tokenizer
 
             enc_data_length = len(enc_data)
             if dec_data_length != enc_data_length:
-                raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
+                raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
 
-
-
-
-            enc_out_list = [[]]*enc_data_length
-            enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
-        else:
-            enc_out_list = []
-            enc_mask_list = []
-
-        dec_out_list = [[]]*dec_data_length
-        dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
-        for index in range(len(dec_out_list)):
-            dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
-            if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
-                if dec_block_size_exceeded_policy == "trim":
-                    dec_processed_item = dec_processed_item[:dec_max_block_size]
-                elif dec_block_size_exceeded_policy == "skip":
-                    continue
-                elif dec_block_size_exceeded_policy == None:
-                    raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
-            if dec_create_masks:
-                dec_mask = create_mask(dec_processed_item, dec_max_block_size)
-            dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
-
-            if enc_data != None:
-                enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
-                if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
-                    if enc_block_size_exceeded_policy == "trim":
-                        enc_processed_item = enc_processed_item[:enc_max_block_size]
-                    elif enc_block_size_exceeded_policy == "skip":
-                        continue
-                    elif enc_block_size_exceeded_policy == None:
-                        raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
-                if enc_create_masks:
-                    enc_mask = create_mask(enc_processed_item, enc_max_block_size)
-                enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
-
-            dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
-            if dec_create_masks:
-                dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
-
-            if enc_data != None:
-                enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
-                if enc_create_masks:
-                    enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
-
-        dec_out_list = torch.stack([row for row in dec_out_list if row != []])
-        torch.save(dec_out_list, save_path + "_decoder_data.pt")
+        print("processing data")
+        dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
         if dec_create_masks:
-
-
-
-
-
+            mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+            dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
+
+        if enc_data is not None:
+            enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
             if enc_create_masks:
-
-
+                mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+                enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])
 
+        dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
+        Path(save_path).mkdir(parents=True, exist_ok=True)
+        torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
+        if dec_create_masks:
+            dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
+            torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
+        if enc_data is not None:
+            enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
+            torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
+        if enc_create_masks:
+            enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
+            torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))
 
-def get_valid_samples(random_samples:torch.tensor,
-                      masks:torch.tensor,
+
+def get_valid_samples(random_samples:torch.Tensor,
+                      masks:torch.Tensor,
                       block_size:int
                       ) -> list[int]:
     '''
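The practical upshot: 0.0.11 took a `save_path` file prefix plus trim/skip policies for overlong rows, while 1.0.0 truncates through the tokenizer and writes fixed file names into a directory it creates on demand. A hedged sketch of the new call shape (directory name illustrative):

```python
from robo_lib import DataProcessor

proc = DataProcessor(dec_tokenizer=tok)      # tok as sketched earlier
proc.process_list(
    dec_data=["hello world", "world hello"],
    dec_max_block_size=10,
    save_path="processed",                   # directory, created if missing
)
# -> processed/decoder_data.pt and processed/decoder_mask_data.pt
#    (plus encoder_data.pt / encoder_mask_data.pt when enc_data is given)
```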
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -294,9 +266,9 @@ def get_valid_samples(random_samples:torch.tensor,
     valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
     return valid_samples
 
-def get_batch(data:torch.tensor,
-              random_samples:torch.tensor,
-              masks:torch.tensor=None,
+def get_batch(data:torch.Tensor,
+              random_samples:torch.Tensor,
+              masks:torch.Tensor=None,
               block_size:int=None,
               get_offset:bool=True
               ) -> tuple[torch.tensor]:
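Only the annotations changed here (`torch.tensor` to `torch.Tensor`); the return contract is still `(batch_in, batch_out, masks_in)`. For reference, a small call:

```python
import torch
from robo_lib import get_batch

data = torch.arange(30).view(5, 6)           # 5 rows of 6 token ids
samples = torch.tensor([0, 1, 2])            # rows to draw from
batch_in, batch_out, masks_in = get_batch(data, samples, block_size=4)
print(batch_in.shape, batch_out.shape)       # (3, 3) and (3, 3): offset pairs
print(masks_in)                              # None, since no masks were passed
```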
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -308,53 +280,78 @@ def get_batch(data:torch.tensor,
 
     '''
     batch_size = len(random_samples)
-    if block_size != None and block_size != data.shape[1]:
+    if block_size is not None and block_size != data.shape[1]:
         if block_size >= data.shape[1]:
             raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
 
-        if masks != None:
+        if masks is not None:
            random_point = get_valid_samples(random_samples, masks, block_size)
         else:
             random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
         batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
-        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
+        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
         batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
     else:
         block_size = data.shape[1]
         batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
-        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
+        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
         batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
 
     return batch_in, batch_out, masks_in
 
-def top_kp_filter(logits:torch.tensor,
-                  top_k:int,
-                  top_p:float=None
-                  ) -> torch.tensor:
+def top_kp_filter(logits: torch.Tensor,
+                  top_k: int = None,
+                  top_p: float = None
+                  ) -> torch.Tensor:
     '''
+    Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.
 
-
-
-
-
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
+    Args:
+        logits: (batch_size, vocab_size) tensor of raw logits.
+        top_k: keep only top_k tokens with highest logits.
+        top_p: keep the smallest set of tokens with cumulative probability >= top_p.
 
-
-
-
-
-
-
-    if
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    Returns:
+        selected: tensor of selected token indices (batch_size,)
+    '''
+    logits = logits.clone()  # avoid modifying input logits in-place
+
+    # Apply top-p filtering if specified
+    if top_p is not None:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = F.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
+
+        # Remove tokens with cumulative probability above threshold (except first token)
+        sorted_mask = cumulative_probs > top_p
+        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+        sorted_mask[..., 0] = False
+
+        # Mask tokens to remove by setting logits to -inf
+        indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+        logits[indices_to_remove] = float('-inf')
+
+    # Apply top-k filtering if specified
+    if top_k is not None:
+        top_k = min(top_k, logits.size(-1))  # safety check
+        topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+        topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
+
+        # For each batch, sample from top_k candidates
+        selected = []
+        for i in range(topk_probs.shape[0]):
+            candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
-
-
-
-
-
-
+    else:
+        # If only top_p is specified, sample from entire filtered logits
+        probs = F.softmax(logits, dim=-1).cpu().numpy()
+        selected = []
+        for i in range(probs.shape[0]):
+            candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
     return selected
 
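The rewrite makes both filters composable: nucleus filtering first masks the tail of the distribution to `-inf`, then sampling happens either over the `top_k` survivors or over everything that remains. A sketch of the two entry points:

```python
import torch
from robo_lib import top_kp_filter

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
                       [5.0, 4.0, 3.0, 2.0, 1.0]])

picked = top_kp_filter(logits, top_k=3)      # sample among the 3 best per row
print(picked.shape)                          # torch.Size([2])

picked = top_kp_filter(logits, top_p=0.7)    # nucleus-only sampling
```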
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -387,10 +384,10 @@ class SelfAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.tensor,
-                q:torch.tensor,
-                v:torch.tensor,
-                mask:torch.tensor=None
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
 
@@ -406,7 +403,7 @@ class SelfAttention(nn.Module):
         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
         if self.triangle_mask and self.block_size >= 0:
             wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
-        if mask != None:
+        if mask is not None:
             wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
         wei = F.softmax(wei, dim=-1)
         wei = self.dropout(wei)
@@ -438,10 +435,10 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.tensor,
-                q:torch.tensor,
-                v:torch.tensor,
-                mask:torch.tensor=None
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
 
@@ -475,7 +472,7 @@ class FeedForward(nn.Module):
         )
 
     def forward(self,
-                x:torch.tensor
+                x:torch.Tensor
                 ) -> torch.tensor:
         return self.net(x)
 
@@ -500,8 +497,8 @@ class EncoderBlock(nn.Module):
         self.ln2 = nn.LayerNorm(n_embed)
 
     def forward(self,
-                x:torch.tensor,
-                mask:torch.tensor=None
+                x:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask)
         x = self.ln1(att + x)
@@ -541,15 +538,15 @@ class DecoderBlock(nn.Module):
             self.ca = None
 
     def forward(self,
-                x:torch.tensor,
-                enc_k:torch.tensor,
-                enc_v:torch.tensor,
+                x:torch.Tensor,
+                enc_k:torch.Tensor,
+                enc_v:torch.Tensor,
                 mask_out:bool=None,
-                mask_in:torch.tensor=None
+                mask_in:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask_out)
         x = self.ln1(att + x)
-        if self.ca != None:
+        if self.ca is not None:
             catt = self.ca(enc_k, x, enc_v, mask=mask_in)
             x = self.ln3(catt + x)
         ff = self.ffwd(x)
@@ -628,7 +625,7 @@ class RoboConstructor(nn.Module):
         self.dec_expansion_factor = dec_expansion_factor
         self.dropout = dropout
 
-        if device == None:
+        if device is None:
             self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         else:
             self.device = device
@@ -673,13 +670,13 @@ class RoboConstructor(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
     def forward(self,
-                dec_in:torch.tensor,
-                dec_mask:torch.tensor=None,
-                enc_in:torch.tensor=None,
-                enc_mask:torch.tensor=None
+                dec_in:torch.Tensor,
+                dec_mask:torch.Tensor=None,
+                enc_in:torch.Tensor=None,
+                enc_mask:torch.Tensor=None
                 ) -> torch.tensor:
         _, dec_T = dec_in.shape
-        if enc_in != None:
+        if enc_in is not None:
             _, enc_T = enc_in.shape
 
         dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -718,13 +715,13 @@ class RoboConstructor(nn.Module):
 
         dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
         dec_train_batch_in = dec_train_batch_in.to(self.device)
-        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
-        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None
+        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
+        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None
 
         if self.cross_attention:
             enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
             enc_train_batch_in = enc_train_batch_in.to(self.device)
-            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
+            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
         else:
             enc_train_batch_in = None
             enc_train_masks_in = None
@@ -736,14 +733,8 @@ class RoboConstructor(nn.Module):
                    max_iters:int,
                    eval_interval:int,
                    batch_size:int,
-                   dec_training_path:str,
-                   dec_eval_path:str=None,
-                   dec_training_masks_path:str=None,
-                   dec_eval_masks_path:str=None,
-                   enc_training_path:str=None,
-                   enc_eval_path:str=None,
-                   enc_training_masks_path:str=None,
-                   enc_eval_masks_path:str=None,
+                   training_dir_path:str,
+                   eval_dir_path:str,
                    eval_iters:int=3,
                    learning_rate:float=1e-4,
                    pad_token:int=None,
@@ -752,21 +743,36 @@ class RoboConstructor(nn.Module):
                    label_smoothing:float=0.1
                    ) -> None:
 
+        dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
         dec_training_data = torch.load(dec_training_path, weights_only=True)
-
-
-
-
-
-
-
-
-
+
+        dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
+        dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
+
+        dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
+        dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
+
+        dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
+        dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
+
+        enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
+        enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
+
+        enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
+        enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
+
+        enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
+        enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
+
+        enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
+        enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
+
+        if pad_token is None and dec_tokenizer is not None:
             pad_token = dec_tokenizer.pad_token
 
         self.to(self.device)
 
-        if pad_token != None:
+        if pad_token is not None:
             loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
         else:
             loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
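`train_robo` now derives all eight tensor paths from two directories and loads each file only if it exists, so callers simply point it at the folders written by `process_list`. A hedged sketch (`model` assumed to be a configured `RoboConstructor`; directory names illustrative):

```python
# Assumes `model` is a RoboConstructor and `tok` a TokenizerConstructor.
model.train_robo(
    max_iters=1000,
    eval_interval=100,
    batch_size=32,
    training_dir_path="processed_train",   # expects decoder_data.pt etc.
    eval_dir_path="processed_eval",
    dec_tokenizer=tok,                     # lets pad_token default from tok
    save_path="robo_model",
)
```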
{robo_lib-0.0.11 → robo_lib-1.0.0}/robo_lib/components.py (continued)

@@ -782,7 +788,7 @@ class RoboConstructor(nn.Module):
                 proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
                 losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
             out["train"] = losses.mean()
-            if dec_eval_data != None:
+            if dec_eval_data is not None:
                 for k in range(eval_iters):
                     dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
                     proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -798,7 +804,7 @@ class RoboConstructor(nn.Module):
             if iter % eval_interval == 0 or iter == max_iters-1:
                 losses = estimate_loss()
                 print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
-                if save_path != None:
+                if save_path is not None:
                     save_component(self, save_path=save_path)
 
             dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -825,25 +831,25 @@ class RoboConstructor(nn.Module):
                  top_k:int=None,
                  top_p:float=None
                  ) -> list[int]|str:
-        max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens
+        max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens
 
         if self.cross_attention:
-            if enc_tokenizer != None:
-                if enc_start_token == None:
+            if enc_tokenizer is not None:
+                if enc_start_token is None:
                     enc_start_token = enc_tokenizer.start_token
-                if enc_end_token == None:
+                if enc_end_token is None:
                     enc_end_token = enc_tokenizer.end_token
                 if isinstance(inputs, str):
                     inputs = enc_tokenizer.encode(inputs)
 
-        if dec_tokenizer != None:
-            if dec_start_token == None:
+        if dec_tokenizer is not None:
+            if dec_start_token is None:
                 dec_start_token = dec_tokenizer.start_token
-            if dec_end_token == None:
+            if dec_end_token is None:
                 dec_end_token = dec_tokenizer.end_token
-            if new_line_token == None:
+            if new_line_token is None:
                 new_line_token = dec_tokenizer.new_line_token
-            if self.cross_attention == False and isinstance(inputs, str):
+            if not self.cross_attention and isinstance(inputs, str):
                 inputs = dec_tokenizer.encode(inputs)
 
 
@@ -852,7 +858,7 @@ class RoboConstructor(nn.Module):
             idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
         else:
             enc_input = None
-            if separator_token != None:
+            if separator_token is not None:
                 idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
             else:
                 idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -866,7 +872,7 @@ class RoboConstructor(nn.Module):
             logits = proj_output[:, -1, :]
             probabilities = F.log_softmax(logits/temperature, dim=-1)
 
-            if top_k == None and top_p == None:
+            if top_k is None and top_p is None:
                 idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
             else:
                 idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -874,10 +880,10 @@ class RoboConstructor(nn.Module):
             if idx_next[0] == dec_end_token:
                 break
 
-        if dec_tokenizer == None:
+        if dec_tokenizer is None:
             return idx[0].tolist()
         else:
-            if new_line_token != None:
+            if new_line_token is not None:
                 return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
             else:
                 return dec_tokenizer.decode(idx[0].tolist())
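Generation itself is unchanged apart from the `is None` guards: greedy argmax unless `top_k` or `top_p` is supplied. A sketch mirroring the calls in the new tests:

```python
# Assumes `model` and `tok` from the sketches above.
text = model.generate(
    inputs="hello",
    dec_tokenizer=tok,
    max_new_tokens=20,
    top_k=5,                  # sampled; omit top_k/top_p for greedy decoding
)
print(text)
```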
robo_lib-1.0.0/tests/test_data_processor.py (new file)

@@ -0,0 +1,82 @@
+import os
+import shutil
+import torch
+import pytest
+from robo_lib.components import DataProcessor, TokenizerConstructor
+
+
+@pytest.fixture
+def temp_save_path():
+    path = "temp_test_dir"
+    yield path
+    if os.path.exists(path):
+        shutil.rmtree(path)
+
+
+@pytest.fixture
+def dummy_tokenizer():
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="WordLevel",
+        pre_tokenizers="Whitespace",
+        special_tokens=["<pad>", "<unk>"],
+        vocab={"hello": 0, "world": 1, "<pad>": 2, "<unk>": 3},
+        unknown_token_string="<unk>",
+        pad_token_string="<pad>",
+        start_token_string=None,
+        end_token_string=None,
+        new_line_token_string=None
+    )
+    # Faking "training" so encode works with the vocab
+    tokenizer.tokenizer_type.add_tokens(["hello", "world"])
+    return tokenizer
+
+
+def test_data_processor_initialization(dummy_tokenizer):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+    assert processor.dec_tokenizer is dummy_tokenizer
+    assert processor.enc_tokenizer is None
+
+
+def test_process_list_decoder_only(dummy_tokenizer, temp_save_path):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+    data = ["hello world", "world hello"]
+
+    processor.process_list(
+        dec_data=data,
+        dec_max_block_size=10,
+        save_path=temp_save_path
+    )
+
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_mask_data.pt"))
+
+    tensor = torch.load(os.path.join(temp_save_path, "decoder_data.pt"))
+    assert isinstance(tensor, torch.Tensor)
+    assert tensor.shape[0] == len(data)
+
+
+def test_process_list_encoder_decoder(dummy_tokenizer, temp_save_path):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer, enc_tokenizer=dummy_tokenizer)
+    data = ["hello world", "world hello"]
+
+    processor.process_list(
+        dec_data=data,
+        enc_data=data,
+        dec_max_block_size=10,
+        enc_max_block_size=10,
+        save_path=temp_save_path
+    )
+
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "encoder_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_mask_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "encoder_mask_data.pt"))
+
+
+def test_process_list_mismatched_lengths_raises(dummy_tokenizer):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+    dec_data = ["hello world"]
+    enc_data = ["world hello", "extra row"]
+
+    with pytest.raises(Exception, match="decoder data and encoder data lengths do not match"):
+        processor.process_list(dec_data=dec_data, enc_data=enc_data)
robo_lib-1.0.0/tests/test_functions.py (new file)

@@ -0,0 +1,176 @@
+import pytest
+import torch
+import numpy as np
+import random
+from robo_lib import create_mask, pre_process_data, safe_stack, get_valid_samples, get_batch, top_kp_filter
+
+def test_create_mask_basic():
+    row = [1, 2, 3]
+    block_size = 5
+    expected = [1, 1, 1, 0, 0]
+    assert create_mask(row, block_size) == expected
+
+def test_create_mask_equal_length():
+    row = [1, 2, 3, 4]
+    block_size = 4
+    expected = [1, 1, 1, 1]
+    assert create_mask(row, block_size) == expected
+
+def test_create_mask_empty_row():
+    row = []
+    block_size = 3
+    expected = [0, 0, 0]
+    assert create_mask(row, block_size) == expected
+
+def test_pre_process_data_none_tokens():
+    data = ["hello", "world"]
+    start_token = None
+    end_token = None
+    # Should return the input unchanged
+    assert pre_process_data(data.copy(), start_token, end_token) == data
+
+def test_pre_process_data_start_token_only():
+    data = ["hello", "world"]
+    start_token = "<s>"
+    end_token = None
+    expected = ["<s>hello", "<s>world"]
+    assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+def test_pre_process_data_end_token_only():
+    data = ["hello", "world"]
+    start_token = None
+    end_token = "</s>"
+    expected = ["hello</s>", "world</s>"]
+    assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+def test_pre_process_data_both_tokens():
+    data = ["hello", "world"]
+    start_token = "<s>"
+    end_token = "</s>"
+    expected = ["<s>hello</s>", "<s>world</s>"]
+    assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+def test_safe_stack_valid_tensors():
+    t1 = torch.tensor([1, 2])
+    t2 = torch.tensor([3, 4])
+    tensor_list = [t1, t2]
+    stacked = safe_stack(tensor_list)
+    assert isinstance(stacked, torch.Tensor)
+    assert stacked.shape == (2, 2)
+
+def test_safe_stack_ignore_non_tensors():
+    t1 = torch.tensor([1, 2])
+    not_tensor = [1, 2, 3]
+    tensor_list = [t1, not_tensor]
+    stacked = safe_stack(tensor_list)
+    assert stacked.shape == (1, 2)
+
+def test_safe_stack_raises_for_empty():
+    with pytest.raises(ValueError):
+        safe_stack(["not a tensor", 123, None])
+
+
+# For reproducibility
+random.seed(0)
+torch.manual_seed(0)
+np.random.seed(0)
+
+def test_get_valid_samples_all_masked_less_than_block():
+    masks = torch.tensor([[1, 0, 0], [1, 1, 0]])
+    random_samples = torch.tensor([0, 1])
+    block_size = 2
+    result = get_valid_samples(random_samples, masks, block_size)
+    # For first row sum(masks) = 1 <= block_size => should return 0
+    # For second row sum(masks) = 2 <= block_size => 0
+    assert result == [0, 0]
+
+def test_get_valid_samples_some_greater_than_block():
+    masks = torch.tensor([[1, 1, 1], [1, 1, 0]])
+    random_samples = torch.tensor([0, 1])
+    block_size = 2
+    result = get_valid_samples(random_samples, masks, block_size)
+    # first sum = 3 > 2, so random index in [0, 1]
+    # second sum = 2 <= 2, so 0
+    assert result[1] == 0
+    assert 0 <= result[0] <= 1
+
+def test_get_batch_no_masks_get_offset_true():
+    data = torch.arange(30).view(5, 6)  # 5 rows, 6 cols
+    random_samples = torch.tensor([0, 1, 2])
+    block_size = 4
+    batch_in, batch_out, masks_in = get_batch(data, random_samples, masks=None, block_size=block_size, get_offset=True)
+    assert batch_in.shape == (3, block_size-1)
+    assert batch_out.shape == (3, block_size-1)
+    assert masks_in is None
+
+def test_get_batch_with_masks_get_offset_false():
+    data = torch.arange(30).view(5, 6)
+    masks = torch.ones_like(data)
+    random_samples = torch.tensor([0, 1])
+    block_size = 5
+    batch_in, batch_out, masks_in = get_batch(data, random_samples, masks=masks, block_size=block_size, get_offset=False)
+    assert batch_in.shape == (2, block_size)
+    assert batch_out is None
+    assert masks_in.shape == (2, block_size)
+
+def test_get_batch_block_size_larger_than_data_length_raises():
+    data = torch.arange(20).view(4, 5)
+    random_samples = torch.tensor([0])
+    block_size = 6
+    with pytest.raises(Exception):
+        get_batch(data, random_samples, block_size=block_size)
+
+
+def test_top_kp_filter_top_k_only():
+    # Create dummy logits batch (2 samples, vocab size 5)
+    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
+                           [5.0, 4.0, 3.0, 2.0, 1.0]])
+    top_k = 3
+    selected = top_kp_filter(logits, top_k=top_k, top_p=None)
+
+    assert selected.shape == (2,)
+    # Selected indices must be in top_k tokens
+    for i, sel in enumerate(selected):
+        topk_indices = torch.topk(logits[i], top_k).indices.tolist()
+        assert sel.item() in topk_indices
+
+def test_top_kp_filter_top_p_only():
+    # Dummy logits with clear probabilities
+    logits = torch.tensor([[0.1, 0.2, 0.3, 0.4],
+                           [0.4, 0.3, 0.2, 0.1]])
+    top_p = 0.7
+    selected = top_kp_filter(logits, top_k=None, top_p=top_p)
+
+    assert selected.shape == (2,)
+    # Selected indices must be in vocab range
+    for sel in selected:
+        assert 0 <= sel.item() < logits.shape[1]
+
+def test_top_kp_filter_top_k_and_top_p():
+    logits = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
+                           [0.5, 0.4, 0.3, 0.2, 0.1]])
+    top_k = 2
+    top_p = 0.6
+    selected = top_kp_filter(logits, top_k=top_k, top_p=top_p)
+
+    assert selected.shape == (2,)
+    for i, sel in enumerate(selected):
+        # With both filters, selected index should be in top_k indices
+        topk_indices = torch.topk(logits[i], top_k).indices.tolist()
+        assert sel.item() in topk_indices
+
+def test_top_kp_filter_no_filter():
+    logits = torch.tensor([[0.1, 0.2, 0.3],
+                           [0.3, 0.2, 0.1]])
+    selected = top_kp_filter(logits, top_k=None, top_p=None)
+
+    assert selected.shape == (2,)
+    for sel in selected:
+        assert 0 <= sel.item() < logits.shape[1]
+
+def test_top_kp_filter_empty_logits():
+    # Edge case: logits empty or zero size
+    logits = torch.empty((0, 0))
+    with pytest.raises(IndexError):
+        _ = top_kp_filter(logits, top_k=1, top_p=0.5)
+
robo_lib-1.0.0/tests/test_robo_constructor.py (new file)

@@ -0,0 +1,130 @@
+import pytest
+import torch
+import tempfile
+import os
+from types import SimpleNamespace
+from unittest.mock import patch, MagicMock
+from robo_lib import RoboConstructor, save_component, load_component
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# ---------- FIXTURES AND MOCKS ----------
+
+@pytest.fixture
+def mock_encoder_block():
+    return MagicMock()
+
+@pytest.fixture
+def mock_decoder_block():
+    return MagicMock()
+
+@pytest.fixture
+def mock_my_sequential():
+    class DummySequential(torch.nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+        def forward(self, *args, **kwargs):
+            return args[0], None, None, None, None
+    return DummySequential
+
+@pytest.fixture
+def dummy_tokenizer():
+    return SimpleNamespace(
+        start_token=1,
+        end_token=2,
+        pad_token=0,
+        new_line_token=3,
+        encode=lambda s: [4, 5, 6],
+        decode=lambda tokens: "decoded"
+    )
+
+@pytest.fixture
+def dummy_data():
+    return torch.randint(0, 10, (8, 32)).to(device)  # 8 samples of 32 tokens
+
+@pytest.fixture
+def robo_decoder_only(mock_my_sequential):
+    with patch("robo_lib.DecoderBlock", return_value=MagicMock()), \
+         patch("robo_lib.MySequential", mock_my_sequential):
+        return RoboConstructor(
+            n_embed=16,
+            dec_n_blocks=2,
+            dec_n_head=2,
+            dec_vocab_size=50,
+            dec_block_size=32
+        ).to(device)
+
+@pytest.fixture
+def robo_enc_dec(mock_my_sequential):
+    with patch("robo_lib.DecoderBlock", return_value=MagicMock()), \
+         patch("robo_lib.EncoderBlock", return_value=MagicMock()), \
+         patch("robo_lib.MySequential", mock_my_sequential):
+        return RoboConstructor(
+            n_embed=16,
+            dec_n_blocks=2,
+            dec_n_head=2,
+            dec_vocab_size=50,
+            dec_block_size=32,
+            enc_n_blocks=2,
+            enc_n_head=2,
+            enc_vocab_size=50,
+            enc_block_size=32
+        ).to(device)
+
+# ---------- TESTS ----------
+
+def test_decoder_only_init(robo_decoder_only):
+    assert not robo_decoder_only.cross_attention
+    assert robo_decoder_only.decoder_blocks is not None
+    assert robo_decoder_only.encoder_blocks is None
+
+def test_encoder_decoder_init(robo_enc_dec):
+    assert robo_enc_dec.cross_attention
+    assert robo_enc_dec.encoder_blocks is not None
+
+def test_forward_decoder_only(robo_decoder_only):
+    input_tensor = torch.randint(0, 50, (2, 32)).to(device)
+    output = robo_decoder_only(dec_in=input_tensor)
+    assert output.shape[:2] == (2, 32)
+
+def test_forward_encoder_decoder(robo_enc_dec):
+    dec_input = torch.randint(0, 50, (2, 32)).to(device)
+    enc_input = torch.randint(0, 50, (2, 32)).to(device)
+    output = robo_enc_dec(dec_in=dec_input, enc_in=enc_input)
+    assert output.shape[:2] == (2, 32)
+
+@patch("robo_lib.get_batch")
+def test_prep_data_decoder_only(mock_get_batch, robo_decoder_only, dummy_data):
+    mock_get_batch.return_value = (dummy_data[:2], dummy_data[:2], dummy_data[:2])
+    out = robo_decoder_only.prep_data(batch_size=2, dec_data=dummy_data, dec_block_size=32)
+    assert len(out) == 5
+    assert out[0].shape[0] == 2
+
+@patch("robo_lib.get_batch")
+def test_prep_data_encoder_decoder(mock_get_batch, robo_enc_dec, dummy_data):
+    mock_get_batch.side_effect = [
+        (dummy_data[:2], dummy_data[:2], dummy_data[:2]),  # decoder
+        (dummy_data[:2], None, dummy_data[:2])  # encoder
+    ]
+    out = robo_enc_dec.prep_data(batch_size=2, dec_data=dummy_data, dec_block_size=32, enc_data=dummy_data, enc_block_size=32)
+    assert len(out) == 5
+    assert out[3].shape[0] == 2  # encoder input
+
+@patch("robo_lib.top_kp_filter", return_value=torch.tensor([2]))
+def test_generate_decoder_only(mock_top_kp, robo_decoder_only, dummy_tokenizer):
+    out = robo_decoder_only.generate(inputs="hello", dec_tokenizer=dummy_tokenizer, max_new_tokens=3, dec_start_token=1, dec_end_token=2)
+    assert isinstance(out, str)
+
+@patch("robo_lib.top_kp_filter", return_value=torch.tensor([2]))
+def test_generate_encoder_decoder(mock_top_kp, robo_enc_dec, dummy_tokenizer):
+    out = robo_enc_dec.generate(inputs="hello", enc_tokenizer=dummy_tokenizer, dec_tokenizer=dummy_tokenizer,
+                                max_new_tokens=3, enc_start_token=1, enc_end_token=2, dec_start_token=1, dec_end_token=2)
+    assert isinstance(out, str)
+
+def test_save_and_load_component(robo_decoder_only):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "test_model")
+        save_component(robo_decoder_only, path)
+        loaded = load_component(path)
+        assert isinstance(loaded, RoboConstructor)
robo_lib-1.0.0/tests/test_tokenizer_constructor.py (new file)

@@ -0,0 +1,89 @@
+import pytest
+import os
+from tempfile import NamedTemporaryFile
+from robo_lib import TokenizerConstructor
+
+
+@pytest.fixture
+def training_file():
+    with NamedTemporaryFile(mode="w+", delete=False) as f:
+        f.write("Hello world\nThis is a test\nTokenizer test\n")
+        f.flush()
+        yield f.name
+    os.remove(f.name)
+
+
+def test_tokenizer_creation():
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=100
+    )
+    assert tokenizer is not None
+    assert "<unk>" in tokenizer.special_tokens
+    assert tokenizer.vocab_size is None  # Untrained tokenizer should have vocab_size None
+
+
+def test_tokenizer_train(training_file):
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="WordLevel",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=50
+    )
+    tokenizer.train(training_file)
+    assert tokenizer.vocab_size is not None
+    assert tokenizer.vocab_size > 0
+
+
+def test_tokenizer_encode_decode(training_file):
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=50
+    )
+    tokenizer.train(training_file)
+    encoded = tokenizer.encode("This is a test")
+    assert isinstance(encoded, list)
+    assert all(isinstance(i, int) for i in encoded)
+
+    decoded = tokenizer.decode(encoded)
+    assert isinstance(decoded, str)
+    assert len(decoded) > 0
+
+
+def test_tokenizer_encode_batch(training_file):
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=50
+    )
+    tokenizer.train(training_file)
+    batch = ["This is a test", "Hello world"]
+    encoded_batch = tokenizer.encode_batch(batch)
+    assert isinstance(encoded_batch, list)
+    assert len(encoded_batch) == len(batch)
+    assert all(isinstance(seq, list) for seq in encoded_batch)
+
+    encoded_truncated = tokenizer.encode_batch(batch, max_length=3)
+    assert all(len(seq) <= 3 for seq in encoded_truncated)
+
+
+def test_special_token_indexes():
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        special_tokens=["<unk>", "<sos>", "<eos>", "<pad>", "\n"]
+    )
+    assert tokenizer.unknown_token == tokenizer.special_tokens.index("<unk>")
+    assert tokenizer.start_token == tokenizer.special_tokens.index("<sos>")
+    assert tokenizer.end_token == tokenizer.special_tokens.index("<eos>")
+    assert tokenizer.pad_token == tokenizer.special_tokens.index("<pad>")
+    assert tokenizer.new_line_token == tokenizer.special_tokens.index("\n")
Files without changes: LICENSE, README.md, tests/__init__.py