robo-lib 0.0.11.tar.gz → 1.0.0.tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1 @@
+__pycache__/
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: robo_lib
-Version: 0.0.11
+Version: 1.0.0
 Summary: A package to create, configure, and train transformer models.
 Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
 Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "robo_lib"
-version = "0.0.11"
+version = "1.0.0"
 authors = [
   { name="Erik Papp", email="erik3papp@gmail.com" },
 ]
@@ -1,8 +1,7 @@
 from .components import TokenizerConstructor as TokenizerConstructor
 from .components import create_mask as create_mask
-from .components import pad as pad
-from .components import process_row as process_row
-from .components import scan_max_block_size as scan_max_block_size
+from .components import pre_process_data as pre_process_data
+from .components import safe_stack as safe_stack
 from .components import DataProcessor as DataProcessor
 from .components import get_valid_samples as get_valid_samples
 from .components import get_batch as get_batch
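
1.0.0 removes pad, process_row, and scan_max_block_size from the package's public exports and adds pre_process_data and safe_stack, so imports written against 0.0.11 will break. A minimal migration sketch (only the function names are taken from the hunk above):

    # 0.0.11
    # from robo_lib import pad, process_row, scan_max_block_size
    # 1.0.0
    from robo_lib import pre_process_data, safe_stack
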
@@ -6,6 +6,8 @@ import numpy as np
 import random
 import pickle
 import itertools
+from pathlib import Path
+import os
 
 class TokenizerConstructor:
     '''
@@ -30,6 +32,7 @@ class TokenizerConstructor:
                 tokenizer_type:str="BPE",
                 pre_tokenizers:list[str]|str=["Whitespace"],
                 normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
+                vocab:dict[str,int] = {},
                 special_tokens:list[str]|str=[],
                 unknown_token_string:str="<unk>",
                 start_token_string:str="<sos>",
@@ -42,25 +45,28 @@ class TokenizerConstructor:
 
         if isinstance(special_tokens, str):
             special_tokens = [special_tokens]
-        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
-        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
-        self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
-        self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
-        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
-        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None
+        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
+        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
+        self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
+        self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
+        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
+        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None
 
         if tokenizer_type == "BPE":
             self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordLevel":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordPiece":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "Unigram":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
             self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
+
+        if self.pad_token is not None:
+            self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)
 
         if isinstance(pre_tokenizers, str):
             pre_tokenizers = [pre_tokenizers]
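
The new vocab parameter lets WordLevel and WordPiece tokenizers start from a fixed vocabulary, and a configured pad token now turns on the underlying tokenizer's padding automatically. A sketch mirroring the dummy_tokenizer fixture in the new tests (the vocabulary values are illustrative):

    from robo_lib import TokenizerConstructor

    tok = TokenizerConstructor(
        tokenizer_type="WordLevel",
        pre_tokenizers="Whitespace",
        special_tokens=["<pad>", "<unk>"],
        vocab={"hello": 0, "world": 1, "<pad>": 2, "<unk>": 3},
        unknown_token_string="<unk>",
        pad_token_string="<pad>",
        start_token_string=None,
        end_token_string=None,
        new_line_token_string=None,
    )
    # pad_token_string is set, so enable_padding() is called in __init__
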
@@ -122,6 +128,13 @@ class TokenizerConstructor:
     def encode(self, inp:str) -> list[int]:
         return self.tokenizer_type.encode(inp).ids
 
+    def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
+        if max_length is not None:
+            self.tokenizer_type.enable_truncation(max_length=max_length)
+        out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
+        self.tokenizer_type.no_truncation()
+        return out
+
     def decode(self, inp:list[int]) -> str:
         return self.tokenizer_type.decode(inp)
 
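The new encode_batch enables truncation only for the duration of the call and resets it with no_truncation() afterwards; because padding is switched on whenever a pad token exists, the rows of a batch come back with one shared length. Assuming the tok tokenizer sketched above:

    ids = tok.encode_batch(["hello world", "world"], max_length=8)
    # list[list[int]]: each row truncated to at most 8 ids
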
@@ -136,38 +149,35 @@ def create_mask(row:list, block_size:int) -> list[bool]:
     mask = [1]*len(row) + [0]*(block_size - len(row))
     return mask
 
-def pad(row:list, block_size:int, pad_token:int) -> list[int]:
-    '''
-
-    returns padded row. Row is padded until length block_size with specified pad_token value
-
-    '''
-    row.extend([pad_token]*(block_size - len(row)))
-    return row
-
-def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
+def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
     '''
 
-    returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist
+    returns string row with the tokenizer's start and end tokens if they exist
 
     '''
-    processed_row = tokenizer.encode(row)
-    if tokenizer.start_token != None:
-        processed_row.insert(0, tokenizer.start_token)
-    if tokenizer.end_token != None:
-        processed_row.append(tokenizer.end_token)
-
-    return processed_row
+    if start_token_string is None and end_token_string is None:
+        return data
+    else:
+        for i in range(len(data)):
+            if start_token_string is not None:
+                data[i] = start_token_string + data[i]
+            if end_token_string is not None:
+                data[i] = data[i] + end_token_string
+
+    return data
 
-def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
+def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
     '''
 
-    returns max_block_size of given list of strings by taking the length of the longest process_row(row) in data
+    torch stack with check to ensure tensors are valid in input list
+
+    returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises error if no valid tensors
 
     '''
-    lengths = [len(process_row(p, tokenizer)) for p in data]
-    max_block_size_scanner = max(lengths)
-    return max_block_size_scanner
+    out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
+    if len(out_list) == 0:
+        raise ValueError("no valid tensors in list.")
+    return torch.stack(out_list)
 
 
 class DataProcessor:
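
pre_process_data wraps raw strings in start/end token strings before tokenization (the old process_row worked on token ids instead), and safe_stack drops anything that is not a torch.Tensor before stacking. A short sketch of the behaviour asserted in the new tests:

    import torch
    from robo_lib import pre_process_data, safe_stack

    pre_process_data(["hello", "world"], "<s>", "</s>")
    # -> ["<s>hello</s>", "<s>world</s>"]  (mutates and returns the list)

    safe_stack([torch.tensor([1, 2]), "not a tensor", torch.tensor([3, 4])])
    # -> tensor of shape (2, 2); raises ValueError if no valid tensors remain
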
@@ -196,93 +206,55 @@ class DataProcessor:
         self.enc_tokenizer = enc_tokenizer
 
     def process_list(self,
-                     save_path:str,
                      dec_data:list[str]|str,
                      dec_max_block_size:int=None,
                      dec_create_masks:bool=True,
-                     dec_block_size_exceeded_policy:str=None,
                      enc_data:list[str]=None,
                      enc_max_block_size:int=None,
                      enc_create_masks:bool=True,
-                     enc_block_size_exceeded_policy:str=None
+                     save_path:str = "."
                      ) -> None:
 
         if isinstance(dec_data, str):
             dec_data = [dec_data]
         dec_data_length = len(dec_data)
-        save_path = save_path.replace(".pt", "")
-
-        if dec_max_block_size == None:
-            dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
 
-        if enc_data != None:
-            self.enc_tokenizer = self.dec_tokenizer if self.enc_tokenizer == None else self.enc_tokenizer
+        if enc_data is not None:
+            if self.enc_tokenizer is None:
+                self.enc_tokenizer = self.dec_tokenizer
 
             enc_data_length = len(enc_data)
             if dec_data_length != enc_data_length:
-                raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
+                raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
 
-            if enc_max_block_size == None:
-                enc_max_block_size = scan_max_block_size(enc_data, self.enc_tokenizer)
-
-            enc_out_list = [[]]*enc_data_length
-            enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
-        else:
-            enc_out_list = []
-            enc_mask_list = []
-
-        dec_out_list = [[]]*dec_data_length
-        dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
-        for index in range(len(dec_out_list)):
-            dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
-            if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
-                if dec_block_size_exceeded_policy == "trim":
-                    dec_processed_item = dec_processed_item[:dec_max_block_size]
-                elif dec_block_size_exceeded_policy == "skip":
-                    continue
-                elif dec_block_size_exceeded_policy == None:
-                    raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
-            if dec_create_masks:
-                dec_mask = create_mask(dec_processed_item, dec_max_block_size)
-            dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
-
-            if enc_data != None:
-                enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
-                if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
-                    if enc_block_size_exceeded_policy == "trim":
-                        enc_processed_item = enc_processed_item[:enc_max_block_size]
-                    elif enc_block_size_exceeded_policy == "skip":
-                        continue
-                    elif enc_block_size_exceeded_policy == None:
-                        raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
-                if enc_create_masks:
-                    enc_mask = create_mask(enc_processed_item, enc_max_block_size)
-                enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
-
-            dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
-            if dec_create_masks:
-                dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
-
-            if enc_data != None:
-                enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
-                if enc_create_masks:
-                    enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
-
-        dec_out_list = torch.stack([row for row in dec_out_list if row != []])
-        torch.save(dec_out_list, save_path + "_decoder_data.pt")
+        print("processing data")
+        dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
         if dec_create_masks:
-            dec_mask_list = torch.stack([row for row in dec_mask_list if row != []])
-            torch.save(dec_mask_list, save_path + "_decoder_mask_data.pt")
-        if enc_data != None:
-            enc_out_list = torch.stack([row for row in enc_out_list if row != []])
-            torch.save(enc_out_list, save_path + "_encoder_data.pt")
+            mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
            dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
+
+        if enc_data is not None:
+            enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
             if enc_create_masks:
-                enc_mask_list = torch.stack([row for row in enc_mask_list if row != []])
-                torch.save(enc_mask_list, save_path + "_encoder_mask_data.pt")
+                mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+                enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])
 
+        dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
+        Path(save_path).mkdir(parents=True, exist_ok=True)
+        torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
+        if dec_create_masks:
+            dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
+            torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
+        if enc_data is not None:
+            enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
+            torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
+            if enc_create_masks:
+                enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
+                torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))
 
-def get_valid_samples(random_samples:torch.tensor,
-                      masks:torch.tensor,
+
+def get_valid_samples(random_samples:torch.Tensor,
+                      masks:torch.Tensor,
                       block_size:int
                       ) -> list[int]:
     '''
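
process_list now treats save_path as a directory (created on demand) with fixed file names inside it, rather than a file-name prefix, and the trim/skip block-size policies are gone in favour of plain truncation to dec_max_block_size/enc_max_block_size. A decoder-only sketch (the directory name is an assumption; tok is the tokenizer sketched earlier):

    from robo_lib import DataProcessor

    processor = DataProcessor(dec_tokenizer=tok)
    processor.process_list(
        dec_data=["hello world", "world hello"],
        dec_max_block_size=10,
        save_path="data/train",
    )
    # writes data/train/decoder_data.pt and data/train/decoder_mask_data.pt;
    # with enc_data it also writes encoder_data.pt / encoder_mask_data.pt
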
@@ -294,9 +266,9 @@ def get_valid_samples(random_samples:torch.tensor,
     valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
     return valid_samples
 
-def get_batch(data:torch.tensor,
-              random_samples:torch.tensor,
-              masks:torch.tensor=None,
+def get_batch(data:torch.Tensor,
+              random_samples:torch.Tensor,
+              masks:torch.Tensor=None,
               block_size:int=None,
               get_offset:bool=True
               ) -> tuple[torch.tensor]:
@@ -308,53 +280,78 @@ def get_batch(data:torch.tensor,
 
     '''
     batch_size = len(random_samples)
-    if block_size != None and block_size != data.shape[1]:
+    if block_size is not None and block_size != data.shape[1]:
         if block_size >= data.shape[1]:
             raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
 
-        if masks != None:
+        if masks is not None:
            random_point = get_valid_samples(random_samples, masks, block_size)
        else:
            random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
        batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
-        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
+        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
        batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
    else:
        block_size = data.shape[1]
        batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
-        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
+        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
        batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
 
    return batch_in, batch_out, masks_in
 
-def top_kp_filter(logits:torch.tensor,
-                  top_k:int,
-                  top_p:float=None
-                  ) -> torch.tensor:
+def top_kp_filter(logits: torch.Tensor,
+                  top_k: int = None,
+                  top_p: float = None
+                  ) -> torch.Tensor:
     '''
+    Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.
 
-    returns predicted token by filtering output logits using top_k and top_p
-
-    '''
-    if top_p != None:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
+    Args:
+        logits: (batch_size, vocab_size) tensor of raw logits.
+        top_k: keep only top_k tokens with highest logits.
+        top_p: keep the smallest set of tokens with cumulative probability >= top_p.
 
-        filter = cumulative_probs > top_p
-        filter[..., 1:] = filter[..., :-1].clone()
-        filter[..., 0] = 0
-        indices_to_remove = filter.scatter(1, sorted_indices, filter)
-        logits[indices_to_remove] = float("-inf")
-
-    if top_k != None:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    Returns:
+        selected: tensor of selected token indices (batch_size,)
+    '''
+    logits = logits.clone()  # avoid modifying input logits in-place
+
+    # Apply top-p filtering if specified
+    if top_p is not None:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = F.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
+
+        # Remove tokens with cumulative probability above threshold (except first token)
+        sorted_mask = cumulative_probs > top_p
+        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+        sorted_mask[..., 0] = False
+
+        # Mask tokens to remove by setting logits to -inf
+        indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+        logits[indices_to_remove] = float('-inf')
+
+    # Apply top-k filtering if specified
+    if top_k is not None:
+        top_k = min(top_k, logits.size(-1))  # safety check
+        topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+        topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
+
+        # For each batch, sample from top_k candidates
+        selected = []
+        for i in range(topk_probs.shape[0]):
+            candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
-    sorted_logits = F.softmax(sorted_logits[:, :top_k], dim=-1)
-    sorted_indices = sorted_indices[:, :top_k].detach().cpu()
-    sorted_logits = sorted_logits.detach().cpu().numpy()
-    sorted_logits[0][0] += 1 - sum(sorted_logits[0])
-
-    selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)
+    else:
+        # If only top_p is specified, sample from entire filtered logits
+        probs = F.softmax(logits, dim=-1).cpu().numpy()
+        selected = []
+        for i in range(probs.shape[0]):
+            candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
     return selected
 
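top_kp_filter is now batched: it takes (batch_size, vocab_size) logits, applies nucleus (top_p) and/or top_k filtering, and samples one token id per row instead of only handling row 0. A usage sketch with made-up shapes:

    import torch
    from robo_lib import top_kp_filter

    logits = torch.randn(2, 100)  # (batch_size, vocab_size)
    next_tokens = top_kp_filter(logits, top_k=50, top_p=0.9)
    # next_tokens has shape (2,): one sampled token id per batch row
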
@@ -387,10 +384,10 @@ class SelfAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.tensor,
-                q:torch.tensor,
-                v:torch.tensor,
-                mask:torch.tensor=None
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
 
@@ -406,7 +403,7 @@ class SelfAttention(nn.Module):
         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
         if self.triangle_mask and self.block_size >= 0:
             wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
-        if mask != None:
+        if mask is not None:
             wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
         wei = F.softmax(wei, dim=-1)
         wei = self.dropout(wei)
@@ -438,10 +435,10 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.tensor,
-                q:torch.tensor,
-                v:torch.tensor,
-                mask:torch.tensor=None
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
 
@@ -475,7 +472,7 @@ class FeedForward(nn.Module):
         )
 
     def forward(self,
-                x:torch.tensor
+                x:torch.Tensor
                 ) -> torch.tensor:
         return self.net(x)
 
@@ -500,8 +497,8 @@ class EncoderBlock(nn.Module):
         self.ln2 = nn.LayerNorm(n_embed)
 
     def forward(self,
-                x:torch.tensor,
-                mask:torch.tensor=None
+                x:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask)
         x = self.ln1(att + x)
@@ -541,15 +538,15 @@ class DecoderBlock(nn.Module):
             self.ca = None
 
     def forward(self,
-                x:torch.tensor,
-                enc_k:torch.tensor,
-                enc_v:torch.tensor,
+                x:torch.Tensor,
+                enc_k:torch.Tensor,
+                enc_v:torch.Tensor,
                 mask_out:bool=None,
-                mask_in:torch.tensor=None
+                mask_in:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask_out)
         x = self.ln1(att + x)
-        if self.ca != None:
+        if self.ca is not None:
             catt = self.ca(enc_k, x, enc_v, mask=mask_in)
             x = self.ln3(catt + x)
         ff = self.ffwd(x)
@@ -628,7 +625,7 @@ class RoboConstructor(nn.Module):
         self.dec_expansion_factor = dec_expansion_factor
         self.dropout = dropout
 
-        if device == None:
+        if device is None:
             self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         else:
             self.device = device
@@ -673,13 +670,13 @@ class RoboConstructor(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
     def forward(self,
-                dec_in:torch.tensor,
-                dec_mask:torch.tensor=None,
-                enc_in:torch.tensor=None,
-                enc_mask:torch.tensor=None
+                dec_in:torch.Tensor,
+                dec_mask:torch.Tensor=None,
+                enc_in:torch.Tensor=None,
+                enc_mask:torch.Tensor=None
                 ) -> torch.tensor:
         _, dec_T = dec_in.shape
-        if enc_in != None:
+        if enc_in is not None:
             _, enc_T = enc_in.shape
 
         dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -718,13 +715,13 @@ class RoboConstructor(nn.Module):
 
         dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
         dec_train_batch_in = dec_train_batch_in.to(self.device)
-        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
-        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None
+        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
+        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None
 
         if self.cross_attention:
             enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
             enc_train_batch_in = enc_train_batch_in.to(self.device)
-            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
+            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
         else:
             enc_train_batch_in = None
             enc_train_masks_in = None
@@ -736,14 +733,8 @@ class RoboConstructor(nn.Module):
                    max_iters:int,
                    eval_interval:int,
                    batch_size:int,
-                   dec_training_path:str,
-                   dec_eval_path:str=None,
-                   dec_training_masks_path:str=None,
-                   dec_eval_masks_path:str=None,
-                   enc_training_path:str=None,
-                   enc_eval_path:str=None,
-                   enc_training_masks_path:str=None,
-                   enc_eval_masks_path:str=None,
+                   training_dir_path:str,
+                   eval_dir_path:str,
                    eval_iters:int=3,
                    learning_rate:float=1e-4,
                    pad_token:int=None,
@@ -752,21 +743,36 @@ class RoboConstructor(nn.Module):
                    label_smoothing:float=0.1
                    ) -> None:
 
+        dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
         dec_training_data = torch.load(dec_training_path, weights_only=True)
-        dec_eval_data = torch.load(dec_eval_path, weights_only=True) if dec_eval_path != None else None
-        dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if dec_training_masks_path != None else None
-        dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if dec_eval_masks_path != None else None
-        enc_training_data = torch.load(enc_training_path, weights_only=True) if enc_training_path != None else None
-        enc_eval_data = torch.load(enc_eval_path, weights_only=True) if enc_eval_path != None else None
-        enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if enc_training_masks_path != None else None
-        enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if enc_eval_masks_path != None else None
-
-        if pad_token == None and dec_tokenizer != None:
+
+        dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
+        dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
+
+        dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
+        dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
+
+        dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
+        dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
+
+        enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
+        enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
+
+        enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
+        enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
+
+        enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
+        enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
+
+        enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
+        enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
+
+        if pad_token is None and dec_tokenizer is not None:
             pad_token = dec_tokenizer.pad_token
 
         self.to(self.device)
 
-        if pad_token != None:
+        if pad_token is not None:
             loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
         else:
             loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
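
train_robo replaces the eight per-tensor path arguments with two directories laid out exactly as process_list writes them; mask and encoder files are loaded only if they exist on disk. A sketch for a RoboConstructor instance model, with directory names and hyperparameters as assumptions:

    model.train_robo(
        max_iters=1000,
        eval_interval=100,
        batch_size=32,
        training_dir_path="data/train",  # must contain decoder_data.pt
        eval_dir_path="data/eval",
    )
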
@@ -782,7 +788,7 @@ class RoboConstructor(nn.Module):
                 proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
                 losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
             out["train"] = losses.mean()
-            if dec_eval_data != None:
+            if dec_eval_data is not None:
                 for k in range(eval_iters):
                     dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
                     proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -798,7 +804,7 @@ class RoboConstructor(nn.Module):
             if iter % eval_interval == 0 or iter == max_iters-1:
                 losses = estimate_loss()
                 print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
-                if save_path != None:
+                if save_path is not None:
                     save_component(self, save_path=save_path)
 
             dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -825,25 +831,25 @@ class RoboConstructor(nn.Module):
                 top_k:int=None,
                 top_p:float=None
                 ) -> list[int]|str:
-        max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens
+        max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens
 
         if self.cross_attention:
-            if enc_tokenizer != None:
-                if enc_start_token == None:
+            if enc_tokenizer is not None:
+                if enc_start_token is None:
                     enc_start_token = enc_tokenizer.start_token
-                if enc_end_token == None:
+                if enc_end_token is None:
                     enc_end_token = enc_tokenizer.end_token
                 if isinstance(inputs, str):
                     inputs = enc_tokenizer.encode(inputs)
 
-        if dec_tokenizer != None:
-            if dec_start_token == None:
+        if dec_tokenizer is not None:
+            if dec_start_token is None:
                 dec_start_token = dec_tokenizer.start_token
-            if dec_end_token == None:
+            if dec_end_token is None:
                 dec_end_token = dec_tokenizer.end_token
-            if new_line_token == None:
+            if new_line_token is None:
                 new_line_token = dec_tokenizer.new_line_token
-            if self.cross_attention == False and isinstance(inputs, str):
+            if not self.cross_attention and isinstance(inputs, str):
                 inputs = dec_tokenizer.encode(inputs)
 
 
@@ -852,7 +858,7 @@ class RoboConstructor(nn.Module):
             idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
         else:
             enc_input = None
-            if separator_token != None:
+            if separator_token is not None:
                 idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
             else:
                 idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -866,7 +872,7 @@ class RoboConstructor(nn.Module):
             logits = proj_output[:, -1, :]
             probabilities = F.log_softmax(logits/temperature, dim=-1)
 
-            if top_k == None and top_p == None:
+            if top_k is None and top_p is None:
                 idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
             else:
                 idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -874,10 +880,10 @@ class RoboConstructor(nn.Module):
             if idx_next[0] == dec_end_token:
                 break
 
-        if dec_tokenizer == None:
+        if dec_tokenizer is None:
             return idx[0].tolist()
         else:
-            if new_line_token != None:
+            if new_line_token is not None:
                 return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
             else:
                 return dec_tokenizer.decode(idx[0].tolist())
@@ -0,0 +1,82 @@
+import os
+import shutil
+import torch
+import pytest
+from robo_lib.components import DataProcessor, TokenizerConstructor
+
+
+@pytest.fixture
+def temp_save_path():
+    path = "temp_test_dir"
+    yield path
+    if os.path.exists(path):
+        shutil.rmtree(path)
+
+
+@pytest.fixture
+def dummy_tokenizer():
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="WordLevel",
+        pre_tokenizers="Whitespace",
+        special_tokens=["<pad>", "<unk>"],
+        vocab={"hello": 0, "world": 1, "<pad>": 2, "<unk>": 3},
+        unknown_token_string="<unk>",
+        pad_token_string="<pad>",
+        start_token_string=None,
+        end_token_string=None,
+        new_line_token_string=None
+    )
+    # Faking "training" so encode works with the vocab
+    tokenizer.tokenizer_type.add_tokens(["hello", "world"])
+    return tokenizer
+
+
+def test_data_processor_initialization(dummy_tokenizer):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+    assert processor.dec_tokenizer is dummy_tokenizer
+    assert processor.enc_tokenizer is None
+
+
+def test_process_list_decoder_only(dummy_tokenizer, temp_save_path):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+    data = ["hello world", "world hello"]
+
+    processor.process_list(
+        dec_data=data,
+        dec_max_block_size=10,
+        save_path=temp_save_path
+    )
+
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_mask_data.pt"))
+
+    tensor = torch.load(os.path.join(temp_save_path, "decoder_data.pt"))
+    assert isinstance(tensor, torch.Tensor)
+    assert tensor.shape[0] == len(data)
+
+
+def test_process_list_encoder_decoder(dummy_tokenizer, temp_save_path):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer, enc_tokenizer=dummy_tokenizer)
+    data = ["hello world", "world hello"]
+
+    processor.process_list(
+        dec_data=data,
+        enc_data=data,
+        dec_max_block_size=10,
+        enc_max_block_size=10,
+        save_path=temp_save_path
+    )
+
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "encoder_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "decoder_mask_data.pt"))
+    assert os.path.exists(os.path.join(temp_save_path, "encoder_mask_data.pt"))
+
+
+def test_process_list_mismatched_lengths_raises(dummy_tokenizer):
+    processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+    dec_data = ["hello world"]
+    enc_data = ["world hello", "extra row"]
+
+    with pytest.raises(Exception, match="decoder data and encoder data lengths do not match"):
+        processor.process_list(dec_data=dec_data, enc_data=enc_data)
@@ -0,0 +1,176 @@
+import pytest
+import torch
+import numpy as np
+import random
+from robo_lib import create_mask, pre_process_data, safe_stack, get_valid_samples, get_batch, top_kp_filter
+
+def test_create_mask_basic():
+    row = [1, 2, 3]
+    block_size = 5
+    expected = [1, 1, 1, 0, 0]
+    assert create_mask(row, block_size) == expected
+
+def test_create_mask_equal_length():
+    row = [1, 2, 3, 4]
+    block_size = 4
+    expected = [1, 1, 1, 1]
+    assert create_mask(row, block_size) == expected
+
+def test_create_mask_empty_row():
+    row = []
+    block_size = 3
+    expected = [0, 0, 0]
+    assert create_mask(row, block_size) == expected
+
+def test_pre_process_data_none_tokens():
+    data = ["hello", "world"]
+    start_token = None
+    end_token = None
+    # Should return the input unchanged
+    assert pre_process_data(data.copy(), start_token, end_token) == data
+
+def test_pre_process_data_start_token_only():
+    data = ["hello", "world"]
+    start_token = "<s>"
+    end_token = None
+    expected = ["<s>hello", "<s>world"]
+    assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+def test_pre_process_data_end_token_only():
+    data = ["hello", "world"]
+    start_token = None
+    end_token = "</s>"
+    expected = ["hello</s>", "world</s>"]
+    assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+def test_pre_process_data_both_tokens():
+    data = ["hello", "world"]
+    start_token = "<s>"
+    end_token = "</s>"
+    expected = ["<s>hello</s>", "<s>world</s>"]
+    assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+def test_safe_stack_valid_tensors():
+    t1 = torch.tensor([1, 2])
+    t2 = torch.tensor([3, 4])
+    tensor_list = [t1, t2]
+    stacked = safe_stack(tensor_list)
+    assert isinstance(stacked, torch.Tensor)
+    assert stacked.shape == (2, 2)
+
+def test_safe_stack_ignore_non_tensors():
+    t1 = torch.tensor([1, 2])
+    not_tensor = [1, 2, 3]
+    tensor_list = [t1, not_tensor]
+    stacked = safe_stack(tensor_list)
+    assert stacked.shape == (1, 2)
+
+def test_safe_stack_raises_for_empty():
+    with pytest.raises(ValueError):
+        safe_stack(["not a tensor", 123, None])
+
+
+# For reproducibility
+random.seed(0)
+torch.manual_seed(0)
+np.random.seed(0)
+
+def test_get_valid_samples_all_masked_less_than_block():
+    masks = torch.tensor([[1, 0, 0], [1, 1, 0]])
+    random_samples = torch.tensor([0, 1])
+    block_size = 2
+    result = get_valid_samples(random_samples, masks, block_size)
+    # For first row sum(masks) = 1 <= block_size => should return 0
+    # For second row sum(masks) = 2 <= block_size => 0
+    assert result == [0, 0]
+
+def test_get_valid_samples_some_greater_than_block():
+    masks = torch.tensor([[1, 1, 1], [1, 1, 0]])
+    random_samples = torch.tensor([0, 1])
+    block_size = 2
+    result = get_valid_samples(random_samples, masks, block_size)
+    # first sum = 3 > 2, so random index in [0, 1]
+    # second sum = 2 <= 2, so 0
+    assert result[1] == 0
+    assert 0 <= result[0] <= 1
+
+def test_get_batch_no_masks_get_offset_true():
+    data = torch.arange(30).view(5, 6)  # 5 rows, 6 cols
+    random_samples = torch.tensor([0, 1, 2])
+    block_size = 4
+    batch_in, batch_out, masks_in = get_batch(data, random_samples, masks=None, block_size=block_size, get_offset=True)
+    assert batch_in.shape == (3, block_size-1)
+    assert batch_out.shape == (3, block_size-1)
+    assert masks_in is None
+
+def test_get_batch_with_masks_get_offset_false():
+    data = torch.arange(30).view(5, 6)
+    masks = torch.ones_like(data)
+    random_samples = torch.tensor([0, 1])
+    block_size = 5
+    batch_in, batch_out, masks_in = get_batch(data, random_samples, masks=masks, block_size=block_size, get_offset=False)
+    assert batch_in.shape == (2, block_size)
+    assert batch_out is None
+    assert masks_in.shape == (2, block_size)
+
+def test_get_batch_block_size_larger_than_data_length_raises():
+    data = torch.arange(20).view(4, 5)
+    random_samples = torch.tensor([0])
+    block_size = 6
+    with pytest.raises(Exception):
+        get_batch(data, random_samples, block_size=block_size)
+
+
+def test_top_kp_filter_top_k_only():
+    # Create dummy logits batch (2 samples, vocab size 5)
+    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
+                           [5.0, 4.0, 3.0, 2.0, 1.0]])
+    top_k = 3
+    selected = top_kp_filter(logits, top_k=top_k, top_p=None)
+
+    assert selected.shape == (2,)
+    # Selected indices must be in top_k tokens
+    for i, sel in enumerate(selected):
+        topk_indices = torch.topk(logits[i], top_k).indices.tolist()
+        assert sel.item() in topk_indices
+
+def test_top_kp_filter_top_p_only():
+    # Dummy logits with clear probabilities
+    logits = torch.tensor([[0.1, 0.2, 0.3, 0.4],
+                           [0.4, 0.3, 0.2, 0.1]])
+    top_p = 0.7
+    selected = top_kp_filter(logits, top_k=None, top_p=top_p)
+
+    assert selected.shape == (2,)
+    # Selected indices must be in vocab range
+    for sel in selected:
+        assert 0 <= sel.item() < logits.shape[1]
+
+def test_top_kp_filter_top_k_and_top_p():
+    logits = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
+                           [0.5, 0.4, 0.3, 0.2, 0.1]])
+    top_k = 2
+    top_p = 0.6
+    selected = top_kp_filter(logits, top_k=top_k, top_p=top_p)
+
+    assert selected.shape == (2,)
+    for i, sel in enumerate(selected):
+        # With both filters, selected index should be in top_k indices
+        topk_indices = torch.topk(logits[i], top_k).indices.tolist()
+        assert sel.item() in topk_indices
+
+def test_top_kp_filter_no_filter():
+    logits = torch.tensor([[0.1, 0.2, 0.3],
+                           [0.3, 0.2, 0.1]])
+    selected = top_kp_filter(logits, top_k=None, top_p=None)
+
+    assert selected.shape == (2,)
+    for sel in selected:
+        assert 0 <= sel.item() < logits.shape[1]
+
+def test_top_kp_filter_empty_logits():
+    # Edge case: logits empty or zero size
+    logits = torch.empty((0, 0))
+    with pytest.raises(IndexError):
+        _ = top_kp_filter(logits, top_k=1, top_p=0.5)
+
@@ -0,0 +1,130 @@
+import pytest
+import torch
+import tempfile
+import os
+from types import SimpleNamespace
+from unittest.mock import patch, MagicMock
+from robo_lib import RoboConstructor, save_component, load_component
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# ---------- FIXTURES AND MOCKS ----------
+
+@pytest.fixture
+def mock_encoder_block():
+    return MagicMock()
+
+@pytest.fixture
+def mock_decoder_block():
+    return MagicMock()
+
+@pytest.fixture
+def mock_my_sequential():
+    class DummySequential(torch.nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+        def forward(self, *args, **kwargs):
+            return args[0], None, None, None, None
+    return DummySequential
+
+@pytest.fixture
+def dummy_tokenizer():
+    return SimpleNamespace(
+        start_token=1,
+        end_token=2,
+        pad_token=0,
+        new_line_token=3,
+        encode=lambda s: [4, 5, 6],
+        decode=lambda tokens: "decoded"
+    )
+
+@pytest.fixture
+def dummy_data():
+    return torch.randint(0, 10, (8, 32)).to(device)  # 8 samples of 32 tokens
+
+@pytest.fixture
+def robo_decoder_only(mock_my_sequential):
+    with patch("robo_lib.DecoderBlock", return_value=MagicMock()), \
+         patch("robo_lib.MySequential", mock_my_sequential):
+        return RoboConstructor(
+            n_embed=16,
+            dec_n_blocks=2,
+            dec_n_head=2,
+            dec_vocab_size=50,
+            dec_block_size=32
+        ).to(device)
+
+@pytest.fixture
+def robo_enc_dec(mock_my_sequential):
+    with patch("robo_lib.DecoderBlock", return_value=MagicMock()), \
+         patch("robo_lib.EncoderBlock", return_value=MagicMock()), \
+         patch("robo_lib.MySequential", mock_my_sequential):
+        return RoboConstructor(
+            n_embed=16,
+            dec_n_blocks=2,
+            dec_n_head=2,
+            dec_vocab_size=50,
+            dec_block_size=32,
+            enc_n_blocks=2,
+            enc_n_head=2,
+            enc_vocab_size=50,
+            enc_block_size=32
+        ).to(device)
+
+# ---------- TESTS ----------
+
+def test_decoder_only_init(robo_decoder_only):
+    assert not robo_decoder_only.cross_attention
+    assert robo_decoder_only.decoder_blocks is not None
+    assert robo_decoder_only.encoder_blocks is None
+
+def test_encoder_decoder_init(robo_enc_dec):
+    assert robo_enc_dec.cross_attention
+    assert robo_enc_dec.encoder_blocks is not None
+
+def test_forward_decoder_only(robo_decoder_only):
+    input_tensor = torch.randint(0, 50, (2, 32)).to(device)
+    output = robo_decoder_only(dec_in=input_tensor)
+    assert output.shape[:2] == (2, 32)
+
+def test_forward_encoder_decoder(robo_enc_dec):
+    dec_input = torch.randint(0, 50, (2, 32)).to(device)
+    enc_input = torch.randint(0, 50, (2, 32)).to(device)
+    output = robo_enc_dec(dec_in=dec_input, enc_in=enc_input)
+    assert output.shape[:2] == (2, 32)
+
+@patch("robo_lib.get_batch")
+def test_prep_data_decoder_only(mock_get_batch, robo_decoder_only, dummy_data):
+    mock_get_batch.return_value = (dummy_data[:2], dummy_data[:2], dummy_data[:2])
+    out = robo_decoder_only.prep_data(batch_size=2, dec_data=dummy_data, dec_block_size=32)
+    assert len(out) == 5
+    assert out[0].shape[0] == 2
+
+@patch("robo_lib.get_batch")
+def test_prep_data_encoder_decoder(mock_get_batch, robo_enc_dec, dummy_data):
+    mock_get_batch.side_effect = [
+        (dummy_data[:2], dummy_data[:2], dummy_data[:2]),  # decoder
+        (dummy_data[:2], None, dummy_data[:2])  # encoder
+    ]
+    out = robo_enc_dec.prep_data(batch_size=2, dec_data=dummy_data, dec_block_size=32, enc_data=dummy_data, enc_block_size=32)
+    assert len(out) == 5
+    assert out[3].shape[0] == 2  # encoder input
+
+@patch("robo_lib.top_kp_filter", return_value=torch.tensor([2]))
+def test_generate_decoder_only(mock_top_kp, robo_decoder_only, dummy_tokenizer):
+    out = robo_decoder_only.generate(inputs="hello", dec_tokenizer=dummy_tokenizer, max_new_tokens=3, dec_start_token=1, dec_end_token=2)
+    assert isinstance(out, str)
+
+@patch("robo_lib.top_kp_filter", return_value=torch.tensor([2]))
+def test_generate_encoder_decoder(mock_top_kp, robo_enc_dec, dummy_tokenizer):
+    out = robo_enc_dec.generate(inputs="hello", enc_tokenizer=dummy_tokenizer, dec_tokenizer=dummy_tokenizer,
+                                max_new_tokens=3, enc_start_token=1, enc_end_token=2, dec_start_token=1, dec_end_token=2)
+    assert isinstance(out, str)
+
+def test_save_and_load_component(robo_decoder_only):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "test_model")
+        save_component(robo_decoder_only, path)
+        loaded = load_component(path)
+        assert isinstance(loaded, RoboConstructor)
@@ -0,0 +1,89 @@
+import pytest
+import os
+from tempfile import NamedTemporaryFile
+from robo_lib import TokenizerConstructor
+
+
+@pytest.fixture
+def training_file():
+    with NamedTemporaryFile(mode="w+", delete=False) as f:
+        f.write("Hello world\nThis is a test\nTokenizer test\n")
+        f.flush()
+        yield f.name
+    os.remove(f.name)
+
+
+def test_tokenizer_creation():
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=100
+    )
+    assert tokenizer is not None
+    assert "<unk>" in tokenizer.special_tokens
+    assert tokenizer.vocab_size is None  # Untrained tokenizer should have vocab_size None
+
+
+def test_tokenizer_train(training_file):
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="WordLevel",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=50
+    )
+    tokenizer.train(training_file)
+    assert tokenizer.vocab_size is not None
+    assert tokenizer.vocab_size > 0
+
+
+def test_tokenizer_encode_decode(training_file):
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=50
+    )
+    tokenizer.train(training_file)
+    encoded = tokenizer.encode("This is a test")
+    assert isinstance(encoded, list)
+    assert all(isinstance(i, int) for i in encoded)
+
+    decoded = tokenizer.decode(encoded)
+    assert isinstance(decoded, str)
+    assert len(decoded) > 0
+
+
+def test_tokenizer_encode_batch(training_file):
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        normalizers=["Lowercase"],
+        special_tokens=["<unk>", "<pad>"],
+        vocab_size=50
+    )
+    tokenizer.train(training_file)
+    batch = ["This is a test", "Hello world"]
+    encoded_batch = tokenizer.encode_batch(batch)
+    assert isinstance(encoded_batch, list)
+    assert len(encoded_batch) == len(batch)
+    assert all(isinstance(seq, list) for seq in encoded_batch)
+
+    encoded_truncated = tokenizer.encode_batch(batch, max_length=3)
+    assert all(len(seq) <= 3 for seq in encoded_truncated)
+
+
+def test_special_token_indexes():
+    tokenizer = TokenizerConstructor(
+        tokenizer_type="BPE",
+        pre_tokenizers="Whitespace",
+        special_tokens=["<unk>", "<sos>", "<eos>", "<pad>", "\n"]
+    )
+    assert tokenizer.unknown_token == tokenizer.special_tokens.index("<unk>")
+    assert tokenizer.start_token == tokenizer.special_tokens.index("<sos>")
+    assert tokenizer.end_token == tokenizer.special_tokens.index("<eos>")
+    assert tokenizer.pad_token == tokenizer.special_tokens.index("<pad>")
+    assert tokenizer.new_line_token == tokenizer.special_tokens.index("\n")
File without changes
File without changes
File without changes