robo-lib 0.0.10__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ __pycache__/
@@ -1,6 +1,6 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: robo_lib
- Version: 0.0.10
+ Version: 1.0.0
  Summary: A package to create, configure, and train transformer models.
  Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
  Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
  [project]
  name = "robo_lib"
- version = "0.0.10"
+ version = "1.0.0"
  authors = [
  { name="Erik Papp", email="erik3papp@gmail.com" },
  ]
@@ -1,8 +1,7 @@
  from .components import TokenizerConstructor as TokenizerConstructor
  from .components import create_mask as create_mask
- from .components import pad as pad
- from .components import process_row as process_row
- from .components import scan_max_block_size as scan_max_block_size
+ from .components import pre_process_data as pre_process_data
+ from .components import safe_stack as safe_stack
  from .components import DataProcessor as DataProcessor
  from .components import get_valid_samples as get_valid_samples
  from .components import get_batch as get_batch
@@ -6,6 +6,8 @@ import numpy as np
  import random
  import pickle
  import itertools
+ from pathlib import Path
+ import os
 
  class TokenizerConstructor:
  '''
@@ -30,6 +32,7 @@ class TokenizerConstructor:
  tokenizer_type:str="BPE",
  pre_tokenizers:list[str]|str=["Whitespace"],
  normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
+ vocab:dict[str,int] = {},
  special_tokens:list[str]|str=[],
  unknown_token_string:str="<unk>",
  start_token_string:str="<sos>",
@@ -42,25 +45,28 @@ class TokenizerConstructor:
 
  if isinstance(special_tokens, str):
  special_tokens = [special_tokens]
- self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
- self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
- self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
- self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
- self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
- self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None
+ self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
+ self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
+ self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
+ self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
+ self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
+ self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None
 
  if tokenizer_type == "BPE":
  self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
  self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
  elif tokenizer_type == "WordLevel":
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
  self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
  elif tokenizer_type == "WordPiece":
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
  self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
  elif tokenizer_type == "Unigram":
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
  self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
+
+ if self.pad_token is not None:
+ self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)
 
  if isinstance(pre_tokenizers, str):
  pre_tokenizers = [pre_tokenizers]
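Taken together, the new `vocab` argument and the automatic `enable_padding()` call mean a WordLevel tokenizer can now be built from a preset vocabulary and pad batches out of the box. A minimal sketch, assuming the 1.0.0 constructor shown above (token strings and ids are illustrative):

    from robo_lib import TokenizerConstructor

    # Preset vocabulary instead of training from a corpus.
    tok = TokenizerConstructor(
        tokenizer_type="WordLevel",
        vocab={"hello": 0, "world": 1, "<pad>": 2, "<unk>": 3},
        special_tokens=["<pad>", "<unk>"],
        unknown_token_string="<unk>",
        pad_token_string="<pad>",
        start_token_string=None,
        end_token_string=None,
        new_line_token_string=None,
    )
    # pad_token_string is set, so __init__ has already called
    # tok.tokenizer_type.enable_padding(...) as in the hunk above.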
@@ -122,6 +128,13 @@ class TokenizerConstructor:
  def encode(self, inp:str) -> list[int]:
  return self.tokenizer_type.encode(inp).ids
 
+ def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
+ if max_length is not None:
+ self.tokenizer_type.enable_truncation(max_length=max_length)
+ out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
+ self.tokenizer_type.no_truncation()
+ return out
+
  def decode(self, inp:list[int]) -> str:
  return self.tokenizer_type.decode(inp)
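The new `encode_batch` wrapper enables truncation only for the duration of the call. A hedged usage sketch (assumes a trained or vocab-initialized tokenizer named `tok`):

    ids = tok.encode_batch(["hello world", "world"], max_length=8)
    # ids is a list[list[int]]; truncation is switched off again before
    # returning, so subsequent encodes are unaffected.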
 
@@ -136,38 +149,35 @@ def create_mask(row:list, block_size:int) -> list[bool]:
  mask = [1]*len(row) + [0]*(block_size - len(row))
  return mask
 
- def pad(row:list, block_size:int, pad_token:int) -> list[int]:
- '''
-
- returns padded row. Row is padded until length block_size with specified pad_token value
-
- '''
- row.extend([pad_token]*(block_size - len(row)))
- return row
-
- def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
+ def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
  '''
 
- returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist
+ returns string row with the tokenizer's start and end tokens if they exist
 
  '''
- processed_row = tokenizer.encode(row)
- if tokenizer.start_token != None:
- processed_row.insert(0, tokenizer.start_token)
- if tokenizer.end_token != None:
- processed_row.append(tokenizer.end_token)
-
- return processed_row
+ if start_token_string is None and end_token_string is None:
+ return data
+ else:
+ for i in range(len(data)):
+ if start_token_string is not None:
+ data[i] = start_token_string + data[i]
+ if end_token_string is not None:
+ data[i] = data[i] + end_token_string
+
+ return data
 
- def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
+ def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
  '''
 
- returns max_block_size of given list of strings by taking the length of the longest process_row(row) in data
+ torch stack with check to ensure tensors are valid in input list
+
+ returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises error if no valid tensors
 
  '''
- lengths = [len(process_row(p, tokenizer)) for p in data]
- max_block_size_scanner = max(lengths)
- return max_block_size_scanner
+ out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
+ if len(out_list) == 0:
+ raise ValueError("no valid tensors in list.")
+ return torch.stack(out_list)
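A short sketch of the two replacement helpers (hypothetical values; `pre_process_data` mutates and returns the list of strings, `safe_stack` drops non-tensor rows):

    import torch
    from robo_lib import pre_process_data, safe_stack

    rows = pre_process_data(["hello", "world"], "<sos>", "<eos>")
    # rows == ["<sos>hello<eos>", "<sos>world<eos>"]

    stacked = safe_stack([torch.tensor([1, 2]), None, torch.tensor([3, 4])])
    # stacked.shape == (2, 2); a list with no tensors raises ValueError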
 
 
  class DataProcessor:
@@ -196,93 +206,55 @@ class DataProcessor:
  self.enc_tokenizer = enc_tokenizer
 
  def process_list(self,
- save_path:str,
  dec_data:list[str]|str,
  dec_max_block_size:int=None,
  dec_create_masks:bool=True,
- dec_block_size_exceeded_policy:str=None,
  enc_data:list[str]=None,
  enc_max_block_size:int=None,
  enc_create_masks:bool=True,
- enc_block_size_exceeded_policy:str=None
+ save_path:str = "."
  ) -> None:
 
  if isinstance(dec_data, str):
  dec_data = [dec_data]
  dec_data_length = len(dec_data)
- save_path = save_path.replace(".pt", "")
-
- if dec_max_block_size == None:
- dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
 
- if enc_data != None:
- self.enc_tokenizer = self.dec_tokenizer if self.enc_tokenizer == None else self.enc_tokenizer
+ if enc_data is not None:
+ if self.enc_tokenizer is None:
+ self.enc_tokenizer = self.dec_tokenizer
 
  enc_data_length = len(enc_data)
  if dec_data_length != enc_data_length:
- raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
+ raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
 
- if enc_max_block_size == None:
- enc_max_block_size = scan_max_block_size(enc_data, self.enc_tokenizer)
-
- enc_out_list = [[]]*enc_data_length
- enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
- else:
- enc_out_list = []
- enc_mask_list = []
-
- dec_out_list = [[]]*dec_data_length
- dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
- for index in range(len(dec_out_list)):
- dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
- if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
- if dec_block_size_exceeded_policy == "trim":
- dec_processed_item = dec_processed_item[:dec_max_block_size]
- elif dec_block_size_exceeded_policy == "skip":
- continue
- elif dec_block_size_exceeded_policy == None:
- raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
- if dec_create_masks:
- dec_mask = create_mask(dec_processed_item, dec_max_block_size)
- dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
-
- if enc_data != None:
- enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
- if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
- if enc_block_size_exceeded_policy == "trim":
- enc_processed_item = enc_processed_item[:enc_max_block_size]
- elif enc_block_size_exceeded_policy == "skip":
- continue
- elif enc_block_size_exceeded_policy == None:
- raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
- if enc_create_masks:
- enc_mask = create_mask(enc_processed_item, enc_max_block_size)
- enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
-
- dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
- if dec_create_masks:
- dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
-
- if enc_data != None:
- enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
- if enc_create_masks:
- enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
-
- dec_out_list = torch.stack([row for row in dec_out_list if row != []])
- torch.save(dec_out_list, save_path + "_decoder_data.pt")
+ print("processing data")
+ dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
+ if dec_create_masks:
+ mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+ dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
+
+ if enc_data is not None:
+ enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
+ if enc_create_masks:
+ mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+ enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])
+
+ dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
+ Path(save_path).mkdir(parents=True, exist_ok=True)
+ torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
  if dec_create_masks:
- dec_mask_list = torch.stack([row for row in dec_mask_list if row != []])
- torch.save(dec_mask_list, save_path + "_decoder_mask_data.pt")
- if enc_data != None:
- enc_out_list = torch.stack([row for row in enc_out_list if row != []])
- torch.save(enc_out_list, save_path + "_encoder_data.pt")
+ dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
+ torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
+ if enc_data is not None:
+ enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
+ torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
  if enc_create_masks:
- enc_mask_list = torch.stack([row for row in enc_mask_list if row != []])
- torch.save(enc_mask_list, save_path + "_encoder_mask_data.pt")
+ enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
+ torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))
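`process_list` now writes fixed file names into a directory (created if missing) instead of deriving file names from a single path prefix. A hedged sketch of the 1.0.0 call, assuming a `DataProcessor` built with the tokenizer above (paths illustrative):

    from robo_lib import DataProcessor

    processor = DataProcessor(dec_tokenizer=tok)
    processor.process_list(
        dec_data=["hello world", "world hello"],
        dec_max_block_size=10,
        save_path="data/train",  # a directory, not a .pt prefix
    )
    # writes data/train/decoder_data.pt and data/train/decoder_mask_data.pt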
 
 
- def get_valid_samples(random_samples:torch.tensor,
- masks:torch.tensor,
+ def get_valid_samples(random_samples:torch.Tensor,
+ masks:torch.Tensor,
  block_size:int
  ) -> list[int]:
  '''
@@ -294,9 +266,9 @@ def get_valid_samples(random_samples:torch.tensor,
  valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
  return valid_samples
 
- def get_batch(data:torch.tensor,
- random_samples:torch.tensor,
- masks:torch.tensor=None,
+ def get_batch(data:torch.Tensor,
+ random_samples:torch.Tensor,
+ masks:torch.Tensor=None,
  block_size:int=None,
  get_offset:bool=True
  ) -> tuple[torch.tensor]:
@@ -308,53 +280,78 @@ def get_batch(data:torch.tensor,
 
  '''
  batch_size = len(random_samples)
- if block_size != None and block_size != data.shape[1]:
+ if block_size is not None and block_size != data.shape[1]:
  if block_size >= data.shape[1]:
  raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
 
- if masks != None:
+ if masks is not None:
  random_point = get_valid_samples(random_samples, masks, block_size)
  else:
  random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
  batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
- masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
+ masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
  batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
  else:
  block_size = data.shape[1]
  batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
- masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
+ masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
  batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
 
  return batch_in, batch_out, masks_in
 
- def top_kp_filter(logits:torch.tensor,
- top_k:int,
- top_p:float=None
- ) -> torch.tensor:
+ def top_kp_filter(logits: torch.Tensor,
+ top_k: int = None,
+ top_p: float = None
+ ) -> torch.Tensor:
  '''
+ Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.
 
- returns predicted token by filtering output logits using top_k and top_p
-
- '''
- if top_p != None:
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
- cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
+ Args:
+ logits: (batch_size, vocab_size) tensor of raw logits.
+ top_k: keep only top_k tokens with highest logits.
+ top_p: keep the smallest set of tokens with cumulative probability >= top_p.
 
- filter = cumulative_probs > top_p
- filter[..., 1:] = filter[..., :-1].clone()
- filter[..., 0] = 0
- indices_to_remove = filter.scatter(1, sorted_indices, filter)
- logits[indices_to_remove] = float("-inf")
+ Returns:
+ selected: tensor of selected token indices (batch_size,)
+ '''
+ logits = logits.clone() # avoid modifying input logits in-place
+
+ # Apply top-p filtering if specified
+ if top_p is not None:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ probs = F.softmax(sorted_logits, dim=-1)
+ cumulative_probs = torch.cumsum(probs, dim=-1)
+
+ # Remove tokens with cumulative probability above threshold (except first token)
+ sorted_mask = cumulative_probs > top_p
+ sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+ sorted_mask[..., 0] = False
+
+ # Mask tokens to remove by setting logits to -inf
+ indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+ logits[indices_to_remove] = float('-inf')
+
+ # Apply top-k filtering if specified
+ if top_k is not None:
+ top_k = min(top_k, logits.size(-1)) # safety check
+ topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+ topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
+
+ # For each batch, sample from top_k candidates
+ selected = []
+ for i in range(topk_probs.shape[0]):
+ candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
+ selected.append(candidate[0])
+ selected = torch.tensor(selected, dtype=torch.long)
 
- if top_k != None:
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-
- sorted_logits = F.softmax(sorted_logits[:, :top_k], dim=-1)
- sorted_indices = sorted_indices[:, :top_k].detach().cpu()
- sorted_logits = sorted_logits.detach().cpu().numpy()
- sorted_logits[0][0] += 1 - sum(sorted_logits[0])
-
- selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)
+ else:
+ # If only top_p is specified, sample from entire filtered logits
+ probs = F.softmax(logits, dim=-1).cpu().numpy()
+ selected = []
+ for i in range(probs.shape[0]):
+ candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
+ selected.append(candidate[0])
+ selected = torch.tensor(selected, dtype=torch.long)
 
  return selected
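A toy call against the rewritten filter (values illustrative; sampling is random, so the result is one of the surviving indices):

    import torch
    from robo_lib import top_kp_filter

    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    token = top_kp_filter(logits, top_k=3, top_p=0.9)
    # token is a (1,) long tensor drawn from the top-3 indices {2, 3, 4}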
 
@@ -387,10 +384,10 @@ class SelfAttention(nn.Module):
  self.dropout = nn.Dropout(dropout)
 
  def forward(self,
- k:torch.tensor,
- q:torch.tensor,
- v:torch.tensor,
- mask:torch.tensor=None
+ k:torch.Tensor,
+ q:torch.Tensor,
+ v:torch.Tensor,
+ mask:torch.Tensor=None
  ) -> torch.tensor:
  '''
 
@@ -406,7 +403,7 @@ class SelfAttention(nn.Module):
  wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
  if self.triangle_mask and self.block_size >= 0:
  wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
- if mask != None:
+ if mask is not None:
  wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
  wei = F.softmax(wei, dim=-1)
  wei = self.dropout(wei)
@@ -438,10 +435,10 @@ class MultiHeadAttention(nn.Module):
  self.dropout = nn.Dropout(dropout)
 
  def forward(self,
- k:torch.tensor,
- q:torch.tensor,
- v:torch.tensor,
- mask:torch.tensor=None
+ k:torch.Tensor,
+ q:torch.Tensor,
+ v:torch.Tensor,
+ mask:torch.Tensor=None
  ) -> torch.tensor:
  '''
 
@@ -475,7 +472,7 @@ class FeedForward(nn.Module):
  )
 
  def forward(self,
- x:torch.tensor
+ x:torch.Tensor
  ) -> torch.tensor:
  return self.net(x)
 
@@ -500,8 +497,8 @@ class EncoderBlock(nn.Module):
  self.ln2 = nn.LayerNorm(n_embed)
 
  def forward(self,
- x:torch.tensor,
- mask:torch.tensor=None
+ x:torch.Tensor,
+ mask:torch.Tensor=None
  ) -> tuple[torch.tensor]:
  att = self.sa(x, x, x, mask=mask)
  x = self.ln1(att + x)
@@ -541,15 +538,15 @@ class DecoderBlock(nn.Module):
  self.ca = None
 
  def forward(self,
- x:torch.tensor,
- enc_k:torch.tensor,
- enc_v:torch.tensor,
+ x:torch.Tensor,
+ enc_k:torch.Tensor,
+ enc_v:torch.Tensor,
  mask_out:bool=None,
- mask_in:torch.tensor=None
+ mask_in:torch.Tensor=None
  ) -> tuple[torch.tensor]:
  att = self.sa(x, x, x, mask=mask_out)
  x = self.ln1(att + x)
- if self.ca != None:
+ if self.ca is not None:
  catt = self.ca(enc_k, x, enc_v, mask=mask_in)
  x = self.ln3(catt + x)
  ff = self.ffwd(x)
@@ -615,6 +612,7 @@ class RoboConstructor(nn.Module):
  enc_vocab_size:int=None,
  enc_block_size:int=None,
  enc_expansion_factor:int=4,
+ enc_positional_encoding:bool=True,
  dropout:float=0.1,
  device:str=None
  ) -> None:
@@ -627,7 +625,7 @@ class RoboConstructor(nn.Module):
  self.dec_expansion_factor = dec_expansion_factor
  self.dropout = dropout
 
- if device == None:
+ if device is None:
  self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  else:
  self.device = device
@@ -635,6 +633,7 @@ class RoboConstructor(nn.Module):
  self.dec_positional_embedding_table = nn.Embedding(dec_block_size, n_embed)
 
  if enc_n_blocks != 0:
+ self.enc_positional_encoding = enc_positional_encoding
  self.enc_n_blocks = enc_n_blocks
  self.enc_n_head = enc_n_head
  self.enc_expansion_factor = enc_expansion_factor
@@ -642,7 +641,8 @@ class RoboConstructor(nn.Module):
  self.enc_block_size = enc_block_size
  self.cross_attention = True
  self.enc_token_embedding_table = nn.Embedding(enc_vocab_size, n_embed)
- self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
+ if enc_positional_encoding:
+ self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
  self.encoder_blocks = MySequential(*[EncoderBlock(n_embed, enc_n_head, enc_expansion_factor, dropout=dropout) for _ in range(enc_n_blocks)])
  else:
  self.cross_attention = False
@@ -670,13 +670,13 @@ class RoboConstructor(nn.Module):
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
  def forward(self,
- dec_in:torch.tensor,
- dec_mask:torch.tensor=None,
- enc_in:torch.tensor=None,
- enc_mask:torch.tensor=None
+ dec_in:torch.Tensor,
+ dec_mask:torch.Tensor=None,
+ enc_in:torch.Tensor=None,
+ enc_mask:torch.Tensor=None
  ) -> torch.tensor:
  _, dec_T = dec_in.shape
- if enc_in != None:
+ if enc_in is not None:
  _, enc_T = enc_in.shape
 
  dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -685,8 +685,11 @@ class RoboConstructor(nn.Module):
 
  if self.cross_attention:
  enc_tok_emb = self.enc_token_embedding_table(enc_in)
- enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
- enc_x = enc_tok_emb + enc_pos_emb
+ if self.enc_positional_encoding:
+ enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
+ enc_x = enc_tok_emb + enc_pos_emb
+ else:
+ enc_x = enc_tok_emb
 
  enc_out, enc_mask = self.encoder_blocks(enc_x, enc_mask)
  else:
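The new `enc_positional_encoding` flag lets an encoder-decoder model skip the learned positional-embedding table on the encoder side; the encoder then sees token embeddings only. A hedged construction sketch with illustrative sizes:

    from robo_lib import RoboConstructor

    model = RoboConstructor(
        n_embed=16, dec_n_blocks=2, dec_n_head=2,
        dec_vocab_size=50, dec_block_size=32,
        enc_n_blocks=2, enc_n_head=2,
        enc_vocab_size=50, enc_block_size=32,
        enc_positional_encoding=False,
    )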
@@ -712,13 +715,13 @@ class RoboConstructor(nn.Module):
 
  dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
  dec_train_batch_in = dec_train_batch_in.to(self.device)
- dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
- dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None
+ dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
+ dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None
 
  if self.cross_attention:
  enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
  enc_train_batch_in = enc_train_batch_in.to(self.device)
- enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
+ enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
  else:
  enc_train_batch_in = None
  enc_train_masks_in = None
@@ -730,14 +733,8 @@ class RoboConstructor(nn.Module):
  max_iters:int,
  eval_interval:int,
  batch_size:int,
- dec_training_path:str,
- dec_eval_path:str=None,
- dec_training_masks_path:str=None,
- dec_eval_masks_path:str=None,
- enc_training_path:str=None,
- enc_eval_path:str=None,
- enc_training_masks_path:str=None,
- enc_eval_masks_path:str=None,
+ training_dir_path:str,
+ eval_dir_path:str,
  eval_iters:int=3,
  learning_rate:float=1e-4,
  pad_token:int=None,
@@ -746,21 +743,36 @@ class RoboConstructor(nn.Module):
  label_smoothing:float=0.1
  ) -> None:
 
+ dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
  dec_training_data = torch.load(dec_training_path, weights_only=True)
- dec_eval_data = torch.load(dec_eval_path, weights_only=True) if dec_eval_path != None else None
- dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if dec_training_masks_path != None else None
- dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if dec_eval_masks_path != None else None
- enc_training_data = torch.load(enc_training_path, weights_only=True) if enc_training_path != None else None
- enc_eval_data = torch.load(enc_eval_path, weights_only=True) if enc_eval_path != None else None
- enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if enc_training_masks_path != None else None
- enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if enc_eval_masks_path != None else None
-
- if pad_token == None and dec_tokenizer != None:
+
+ dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
+ dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
+
+ dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
+ dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
+
+ dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
+ dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
+
+ enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
+ enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
+
+ enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
+ enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
+
+ enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
+ enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
+
+ enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
+ enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
+
+ if pad_token is None and dec_tokenizer is not None:
  pad_token = dec_tokenizer.pad_token
 
  self.to(self.device)
 
- if pad_token != None:
+ if pad_token is not None:
  loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
  else:
  loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
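Training now takes the two directories written by `process_list` and derives the eight old paths itself, loading encoder and mask tensors only when the files exist. A sketch of the new call; the method name `train_robo` and the extra keyword arguments are assumptions based on the surrounding hunks and earlier releases, not confirmed by this diff:

    # hypothetical usage; see the parameter hunk above for the new signature
    model.train_robo(
        max_iters=10000,
        eval_interval=1000,
        batch_size=32,
        training_dir_path="data/train",  # must contain decoder_data.pt, ...
        eval_dir_path="data/eval",
        dec_tokenizer=tok,               # supplies pad_token when not given
    )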
@@ -776,7 +788,7 @@ class RoboConstructor(nn.Module):
  proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
  losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
  out["train"] = losses.mean()
- if dec_eval_data != None:
+ if dec_eval_data is not None:
  for k in range(eval_iters):
  dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
  proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -792,7 +804,7 @@ class RoboConstructor(nn.Module):
  if iter % eval_interval == 0 or iter == max_iters-1:
  losses = estimate_loss()
  print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
- if save_path != None:
+ if save_path is not None:
  save_component(self, save_path=save_path)
 
  dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -819,25 +831,25 @@ class RoboConstructor(nn.Module):
  top_k:int=None,
  top_p:float=None
  ) -> list[int]|str:
- max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens
+ max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens
 
  if self.cross_attention:
- if enc_tokenizer != None:
- if enc_start_token == None:
+ if enc_tokenizer is not None:
+ if enc_start_token is None:
  enc_start_token = enc_tokenizer.start_token
- if enc_end_token == None:
+ if enc_end_token is None:
  enc_end_token = enc_tokenizer.end_token
  if isinstance(inputs, str):
  inputs = enc_tokenizer.encode(inputs)
 
- if dec_tokenizer != None:
- if dec_start_token == None:
+ if dec_tokenizer is not None:
+ if dec_start_token is None:
  dec_start_token = dec_tokenizer.start_token
- if dec_end_token == None:
+ if dec_end_token is None:
  dec_end_token = dec_tokenizer.end_token
- if new_line_token == None:
+ if new_line_token is None:
  new_line_token = dec_tokenizer.new_line_token
- if self.cross_attention == False and isinstance(inputs, str):
+ if not self.cross_attention and isinstance(inputs, str):
  inputs = dec_tokenizer.encode(inputs)
 
 
@@ -846,7 +858,7 @@ class RoboConstructor(nn.Module):
  idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
  else:
  enc_input = None
- if separator_token != None:
+ if separator_token is not None:
  idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
  else:
  idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -860,7 +872,7 @@ class RoboConstructor(nn.Module):
  logits = proj_output[:, -1, :]
  probabilities = F.log_softmax(logits/temperature, dim=-1)
 
- if top_k == None and top_p == None:
+ if top_k is None and top_p is None:
  idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
  else:
  idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -868,10 +880,10 @@ class RoboConstructor(nn.Module):
  if idx_next[0] == dec_end_token:
  break
 
- if dec_tokenizer == None:
+ if dec_tokenizer is None:
  return idx[0].tolist()
  else:
- if new_line_token != None:
+ if new_line_token is not None:
  return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
  else:
  return dec_tokenizer.decode(idx[0].tolist())
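A hedged sketch of generation through the filtered decoding path above (arguments illustrative; `temperature`, `top_k`, and `top_p` flow into `top_kp_filter`):

    text = model.generate(
        inputs="hello",
        dec_tokenizer=tok,
        max_new_tokens=50,
        temperature=0.8,
        top_k=40,
        top_p=0.9,
    )
    print(text)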
@@ -0,0 +1,82 @@
+ import os
+ import shutil
+ import torch
+ import pytest
+ from robo_lib.components import DataProcessor, TokenizerConstructor
+
+
+ @pytest.fixture
+ def temp_save_path():
+ path = "temp_test_dir"
+ yield path
+ if os.path.exists(path):
+ shutil.rmtree(path)
+
+
+ @pytest.fixture
+ def dummy_tokenizer():
+ tokenizer = TokenizerConstructor(
+ tokenizer_type="WordLevel",
+ pre_tokenizers="Whitespace",
+ special_tokens=["<pad>", "<unk>"],
+ vocab={"hello": 0, "world": 1, "<pad>": 2, "<unk>": 3},
+ unknown_token_string="<unk>",
+ pad_token_string="<pad>",
+ start_token_string=None,
+ end_token_string=None,
+ new_line_token_string=None
+ )
+ # Faking "training" so encode works with the vocab
+ tokenizer.tokenizer_type.add_tokens(["hello", "world"])
+ return tokenizer
+
+
+ def test_data_processor_initialization(dummy_tokenizer):
+ processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+ assert processor.dec_tokenizer is dummy_tokenizer
+ assert processor.enc_tokenizer is None
+
+
+ def test_process_list_decoder_only(dummy_tokenizer, temp_save_path):
+ processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+ data = ["hello world", "world hello"]
+
+ processor.process_list(
+ dec_data=data,
+ dec_max_block_size=10,
+ save_path=temp_save_path
+ )
+
+ assert os.path.exists(os.path.join(temp_save_path, "decoder_data.pt"))
+ assert os.path.exists(os.path.join(temp_save_path, "decoder_mask_data.pt"))
+
+ tensor = torch.load(os.path.join(temp_save_path, "decoder_data.pt"))
+ assert isinstance(tensor, torch.Tensor)
+ assert tensor.shape[0] == len(data)
+
+
+ def test_process_list_encoder_decoder(dummy_tokenizer, temp_save_path):
+ processor = DataProcessor(dec_tokenizer=dummy_tokenizer, enc_tokenizer=dummy_tokenizer)
+ data = ["hello world", "world hello"]
+
+ processor.process_list(
+ dec_data=data,
+ enc_data=data,
+ dec_max_block_size=10,
+ enc_max_block_size=10,
+ save_path=temp_save_path
+ )
+
+ assert os.path.exists(os.path.join(temp_save_path, "decoder_data.pt"))
+ assert os.path.exists(os.path.join(temp_save_path, "encoder_data.pt"))
+ assert os.path.exists(os.path.join(temp_save_path, "decoder_mask_data.pt"))
+ assert os.path.exists(os.path.join(temp_save_path, "encoder_mask_data.pt"))
+
+
+ def test_process_list_mismatched_lengths_raises(dummy_tokenizer):
+ processor = DataProcessor(dec_tokenizer=dummy_tokenizer)
+ dec_data = ["hello world"]
+ enc_data = ["world hello", "extra row"]
+
+ with pytest.raises(Exception, match="decoder data and encoder data lengths do not match"):
+ processor.process_list(dec_data=dec_data, enc_data=enc_data)
@@ -0,0 +1,176 @@
+ import pytest
+ import torch
+ import numpy as np
+ import random
+ from robo_lib import create_mask, pre_process_data, safe_stack, get_valid_samples, get_batch, top_kp_filter
+
+ def test_create_mask_basic():
+ row = [1, 2, 3]
+ block_size = 5
+ expected = [1, 1, 1, 0, 0]
+ assert create_mask(row, block_size) == expected
+
+ def test_create_mask_equal_length():
+ row = [1, 2, 3, 4]
+ block_size = 4
+ expected = [1, 1, 1, 1]
+ assert create_mask(row, block_size) == expected
+
+ def test_create_mask_empty_row():
+ row = []
+ block_size = 3
+ expected = [0, 0, 0]
+ assert create_mask(row, block_size) == expected
+
+ def test_pre_process_data_none_tokens():
+ data = ["hello", "world"]
+ start_token = None
+ end_token = None
+ # Should return the input unchanged
+ assert pre_process_data(data.copy(), start_token, end_token) == data
+
+ def test_pre_process_data_start_token_only():
+ data = ["hello", "world"]
+ start_token = "<s>"
+ end_token = None
+ expected = ["<s>hello", "<s>world"]
+ assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+ def test_pre_process_data_end_token_only():
+ data = ["hello", "world"]
+ start_token = None
+ end_token = "</s>"
+ expected = ["hello</s>", "world</s>"]
+ assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+ def test_pre_process_data_both_tokens():
+ data = ["hello", "world"]
+ start_token = "<s>"
+ end_token = "</s>"
+ expected = ["<s>hello</s>", "<s>world</s>"]
+ assert pre_process_data(data.copy(), start_token, end_token) == expected
+
+ def test_safe_stack_valid_tensors():
+ t1 = torch.tensor([1, 2])
+ t2 = torch.tensor([3, 4])
+ tensor_list = [t1, t2]
+ stacked = safe_stack(tensor_list)
+ assert isinstance(stacked, torch.Tensor)
+ assert stacked.shape == (2, 2)
+
+ def test_safe_stack_ignore_non_tensors():
+ t1 = torch.tensor([1, 2])
+ not_tensor = [1, 2, 3]
+ tensor_list = [t1, not_tensor]
+ stacked = safe_stack(tensor_list)
+ assert stacked.shape == (1, 2)
+
+ def test_safe_stack_raises_for_empty():
+ with pytest.raises(ValueError):
+ safe_stack(["not a tensor", 123, None])
+
+
+ # For reproducibility
+ random.seed(0)
+ torch.manual_seed(0)
+ np.random.seed(0)
+
+ def test_get_valid_samples_all_masked_less_than_block():
+ masks = torch.tensor([[1, 0, 0], [1, 1, 0]])
+ random_samples = torch.tensor([0, 1])
+ block_size = 2
+ result = get_valid_samples(random_samples, masks, block_size)
+ # For first row sum(masks) = 1 <= block_size => should return 0
+ # For second row sum(masks) = 2 <= block_size => 0
+ assert result == [0, 0]
+
+ def test_get_valid_samples_some_greater_than_block():
+ masks = torch.tensor([[1, 1, 1], [1, 1, 0]])
+ random_samples = torch.tensor([0, 1])
+ block_size = 2
+ result = get_valid_samples(random_samples, masks, block_size)
+ # first sum = 3 > 2, so random index in [0, 1]
+ # second sum = 2 <= 2, so 0
+ assert result[1] == 0
+ assert 0 <= result[0] <= 1
+
+ def test_get_batch_no_masks_get_offset_true():
+ data = torch.arange(30).view(5, 6) # 5 rows, 6 cols
+ random_samples = torch.tensor([0, 1, 2])
+ block_size = 4
+ batch_in, batch_out, masks_in = get_batch(data, random_samples, masks=None, block_size=block_size, get_offset=True)
+ assert batch_in.shape == (3, block_size-1)
+ assert batch_out.shape == (3, block_size-1)
+ assert masks_in is None
+
+ def test_get_batch_with_masks_get_offset_false():
+ data = torch.arange(30).view(5, 6)
+ masks = torch.ones_like(data)
+ random_samples = torch.tensor([0, 1])
+ block_size = 5
+ batch_in, batch_out, masks_in = get_batch(data, random_samples, masks=masks, block_size=block_size, get_offset=False)
+ assert batch_in.shape == (2, block_size)
+ assert batch_out is None
+ assert masks_in.shape == (2, block_size)
+
+ def test_get_batch_block_size_larger_than_data_length_raises():
+ data = torch.arange(20).view(4, 5)
+ random_samples = torch.tensor([0])
+ block_size = 6
+ with pytest.raises(Exception):
+ get_batch(data, random_samples, block_size=block_size)
+
+
+ def test_top_kp_filter_top_k_only():
+ # Create dummy logits batch (2 samples, vocab size 5)
+ logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0],
+ [5.0, 4.0, 3.0, 2.0, 1.0]])
+ top_k = 3
+ selected = top_kp_filter(logits, top_k=top_k, top_p=None)
+
+ assert selected.shape == (2,)
+ # Selected indices must be in top_k tokens
+ for i, sel in enumerate(selected):
+ topk_indices = torch.topk(logits[i], top_k).indices.tolist()
+ assert sel.item() in topk_indices
+
+ def test_top_kp_filter_top_p_only():
+ # Dummy logits with clear probabilities
+ logits = torch.tensor([[0.1, 0.2, 0.3, 0.4],
+ [0.4, 0.3, 0.2, 0.1]])
+ top_p = 0.7
+ selected = top_kp_filter(logits, top_k=None, top_p=top_p)
+
+ assert selected.shape == (2,)
+ # Selected indices must be in vocab range
+ for sel in selected:
+ assert 0 <= sel.item() < logits.shape[1]
+
+ def test_top_kp_filter_top_k_and_top_p():
+ logits = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
+ [0.5, 0.4, 0.3, 0.2, 0.1]])
+ top_k = 2
+ top_p = 0.6
+ selected = top_kp_filter(logits, top_k=top_k, top_p=top_p)
+
+ assert selected.shape == (2,)
+ for i, sel in enumerate(selected):
+ # With both filters, selected index should be in top_k indices
+ topk_indices = torch.topk(logits[i], top_k).indices.tolist()
+ assert sel.item() in topk_indices
+
+ def test_top_kp_filter_no_filter():
+ logits = torch.tensor([[0.1, 0.2, 0.3],
+ [0.3, 0.2, 0.1]])
+ selected = top_kp_filter(logits, top_k=None, top_p=None)
+
+ assert selected.shape == (2,)
+ for sel in selected:
+ assert 0 <= sel.item() < logits.shape[1]
+
+ def test_top_kp_filter_empty_logits():
+ # Edge case: logits empty or zero size
+ logits = torch.empty((0, 0))
+ with pytest.raises(IndexError):
+ _ = top_kp_filter(logits, top_k=1, top_p=0.5)
+
@@ -0,0 +1,130 @@
+ import pytest
+ import torch
+ import tempfile
+ import os
+ from types import SimpleNamespace
+ from unittest.mock import patch, MagicMock
+ from robo_lib import RoboConstructor, save_component, load_component
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ # ---------- FIXTURES AND MOCKS ----------
+
+ @pytest.fixture
+ def mock_encoder_block():
+ return MagicMock()
+
+ @pytest.fixture
+ def mock_decoder_block():
+ return MagicMock()
+
+ @pytest.fixture
+ def mock_my_sequential():
+ class DummySequential(torch.nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ def forward(self, *args, **kwargs):
+ return args[0], None, None, None, None
+ return DummySequential
+
+ @pytest.fixture
+ def dummy_tokenizer():
+ return SimpleNamespace(
+ start_token=1,
+ end_token=2,
+ pad_token=0,
+ new_line_token=3,
+ encode=lambda s: [4, 5, 6],
+ decode=lambda tokens: "decoded"
+ )
+
+ @pytest.fixture
+ def dummy_data():
+ return torch.randint(0, 10, (8, 32)).to(device) # 8 samples of 32 tokens
+
+ @pytest.fixture
+ def robo_decoder_only(mock_my_sequential):
+ with patch("robo_lib.DecoderBlock", return_value=MagicMock()), \
+ patch("robo_lib.MySequential", mock_my_sequential):
+ return RoboConstructor(
+ n_embed=16,
+ dec_n_blocks=2,
+ dec_n_head=2,
+ dec_vocab_size=50,
+ dec_block_size=32
+ ).to(device)
+
+ @pytest.fixture
+ def robo_enc_dec(mock_my_sequential):
+ with patch("robo_lib.DecoderBlock", return_value=MagicMock()), \
+ patch("robo_lib.EncoderBlock", return_value=MagicMock()), \
+ patch("robo_lib.MySequential", mock_my_sequential):
+ return RoboConstructor(
+ n_embed=16,
+ dec_n_blocks=2,
+ dec_n_head=2,
+ dec_vocab_size=50,
+ dec_block_size=32,
+ enc_n_blocks=2,
+ enc_n_head=2,
+ enc_vocab_size=50,
+ enc_block_size=32
+ ).to(device)
+
+ # ---------- TESTS ----------
+
+ def test_decoder_only_init(robo_decoder_only):
+ assert not robo_decoder_only.cross_attention
+ assert robo_decoder_only.decoder_blocks is not None
+ assert robo_decoder_only.encoder_blocks is None
+
+ def test_encoder_decoder_init(robo_enc_dec):
+ assert robo_enc_dec.cross_attention
+ assert robo_enc_dec.encoder_blocks is not None
+
+ def test_forward_decoder_only(robo_decoder_only):
+ input_tensor = torch.randint(0, 50, (2, 32)).to(device)
+ output = robo_decoder_only(dec_in=input_tensor)
+ assert output.shape[:2] == (2, 32)
+
+ def test_forward_encoder_decoder(robo_enc_dec):
+ dec_input = torch.randint(0, 50, (2, 32)).to(device)
+ enc_input = torch.randint(0, 50, (2, 32)).to(device)
+ output = robo_enc_dec(dec_in=dec_input, enc_in=enc_input)
+ assert output.shape[:2] == (2, 32)
+
+ @patch("robo_lib.get_batch")
+ def test_prep_data_decoder_only(mock_get_batch, robo_decoder_only, dummy_data):
+ mock_get_batch.return_value = (dummy_data[:2], dummy_data[:2], dummy_data[:2])
+ out = robo_decoder_only.prep_data(batch_size=2, dec_data=dummy_data, dec_block_size=32)
+ assert len(out) == 5
+ assert out[0].shape[0] == 2
+
+ @patch("robo_lib.get_batch")
+ def test_prep_data_encoder_decoder(mock_get_batch, robo_enc_dec, dummy_data):
+ mock_get_batch.side_effect = [
+ (dummy_data[:2], dummy_data[:2], dummy_data[:2]), # decoder
+ (dummy_data[:2], None, dummy_data[:2]) # encoder
+ ]
+ out = robo_enc_dec.prep_data(batch_size=2, dec_data=dummy_data, dec_block_size=32, enc_data=dummy_data, enc_block_size=32)
+ assert len(out) == 5
+ assert out[3].shape[0] == 2 # encoder input
+
+ @patch("robo_lib.top_kp_filter", return_value=torch.tensor([2]))
+ def test_generate_decoder_only(mock_top_kp, robo_decoder_only, dummy_tokenizer):
+ out = robo_decoder_only.generate(inputs="hello", dec_tokenizer=dummy_tokenizer, max_new_tokens=3, dec_start_token=1, dec_end_token=2)
+ assert isinstance(out, str)
+
+ @patch("robo_lib.top_kp_filter", return_value=torch.tensor([2]))
+ def test_generate_encoder_decoder(mock_top_kp, robo_enc_dec, dummy_tokenizer):
+ out = robo_enc_dec.generate(inputs="hello", enc_tokenizer=dummy_tokenizer, dec_tokenizer=dummy_tokenizer,
+ max_new_tokens=3, enc_start_token=1, enc_end_token=2, dec_start_token=1, dec_end_token=2)
+ assert isinstance(out, str)
+
+ def test_save_and_load_component(robo_decoder_only):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = os.path.join(tmpdir, "test_model")
+ save_component(robo_decoder_only, path)
+ loaded = load_component(path)
+ assert isinstance(loaded, RoboConstructor)
@@ -0,0 +1,89 @@
+ import pytest
+ import os
+ from tempfile import NamedTemporaryFile
+ from robo_lib import TokenizerConstructor
+
+
+ @pytest.fixture
+ def training_file():
+ with NamedTemporaryFile(mode="w+", delete=False) as f:
+ f.write("Hello world\nThis is a test\nTokenizer test\n")
+ f.flush()
+ yield f.name
+ os.remove(f.name)
+
+
+ def test_tokenizer_creation():
+ tokenizer = TokenizerConstructor(
+ tokenizer_type="BPE",
+ pre_tokenizers="Whitespace",
+ normalizers=["Lowercase"],
+ special_tokens=["<unk>", "<pad>"],
+ vocab_size=100
+ )
+ assert tokenizer is not None
+ assert "<unk>" in tokenizer.special_tokens
+ assert tokenizer.vocab_size is None # Untrained tokenizer should have vocab_size None
+
+
+ def test_tokenizer_train(training_file):
+ tokenizer = TokenizerConstructor(
+ tokenizer_type="WordLevel",
+ pre_tokenizers="Whitespace",
+ normalizers=["Lowercase"],
+ special_tokens=["<unk>", "<pad>"],
+ vocab_size=50
+ )
+ tokenizer.train(training_file)
+ assert tokenizer.vocab_size is not None
+ assert tokenizer.vocab_size > 0
+
+
+ def test_tokenizer_encode_decode(training_file):
+ tokenizer = TokenizerConstructor(
+ tokenizer_type="BPE",
+ pre_tokenizers="Whitespace",
+ normalizers=["Lowercase"],
+ special_tokens=["<unk>", "<pad>"],
+ vocab_size=50
+ )
+ tokenizer.train(training_file)
+ encoded = tokenizer.encode("This is a test")
+ assert isinstance(encoded, list)
+ assert all(isinstance(i, int) for i in encoded)
+
+ decoded = tokenizer.decode(encoded)
+ assert isinstance(decoded, str)
+ assert len(decoded) > 0
+
+
+ def test_tokenizer_encode_batch(training_file):
+ tokenizer = TokenizerConstructor(
+ tokenizer_type="BPE",
+ pre_tokenizers="Whitespace",
+ normalizers=["Lowercase"],
+ special_tokens=["<unk>", "<pad>"],
+ vocab_size=50
+ )
+ tokenizer.train(training_file)
+ batch = ["This is a test", "Hello world"]
+ encoded_batch = tokenizer.encode_batch(batch)
+ assert isinstance(encoded_batch, list)
+ assert len(encoded_batch) == len(batch)
+ assert all(isinstance(seq, list) for seq in encoded_batch)
+
+ encoded_truncated = tokenizer.encode_batch(batch, max_length=3)
+ assert all(len(seq) <= 3 for seq in encoded_truncated)
+
+
+ def test_special_token_indexes():
+ tokenizer = TokenizerConstructor(
+ tokenizer_type="BPE",
+ pre_tokenizers="Whitespace",
+ special_tokens=["<unk>", "<sos>", "<eos>", "<pad>", "\n"]
+ )
+ assert tokenizer.unknown_token == tokenizer.special_tokens.index("<unk>")
+ assert tokenizer.start_token == tokenizer.special_tokens.index("<sos>")
+ assert tokenizer.end_token == tokenizer.special_tokens.index("<eos>")
+ assert tokenizer.pad_token == tokenizer.special_tokens.index("<pad>")
+ assert tokenizer.new_line_token == tokenizer.special_tokens.index("\n")