robo-lib 0.0.11__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
robo_lib/__init__.py CHANGED
@@ -1,8 +1,7 @@
  from .components import TokenizerConstructor as TokenizerConstructor
  from .components import create_mask as create_mask
- from .components import pad as pad
- from .components import process_row as process_row
- from .components import scan_max_block_size as scan_max_block_size
+ from .components import pre_process_data as pre_process_data
+ from .components import safe_stack as safe_stack
  from .components import DataProcessor as DataProcessor
  from .components import get_valid_samples as get_valid_samples
  from .components import get_batch as get_batch
robo_lib/components.py CHANGED
@@ -6,6 +6,8 @@ import numpy as np
  import random
  import pickle
  import itertools
+ from pathlib import Path
+ import os

  class TokenizerConstructor:
  '''
@@ -30,6 +32,7 @@ class TokenizerConstructor:
  tokenizer_type:str="BPE",
  pre_tokenizers:list[str]|str=["Whitespace"],
  normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
+ vocab:dict[str,int] = {},
  special_tokens:list[str]|str=[],
  unknown_token_string:str="<unk>",
  start_token_string:str="<sos>",
@@ -42,25 +45,28 @@ class TokenizerConstructor:

  if isinstance(special_tokens, str):
  special_tokens = [special_tokens]
- self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
- self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
- self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
- self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
- self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
- self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None
+ self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
+ self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
+ self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
+ self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
+ self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
+ self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None

  if tokenizer_type == "BPE":
  self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
  self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
  elif tokenizer_type == "WordLevel":
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
  self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
  elif tokenizer_type == "WordPiece":
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
  self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
  elif tokenizer_type == "Unigram":
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
  self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
+
+ if self.pad_token is not None:
+ self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)

  if isinstance(pre_tokenizers, str):
  pre_tokenizers = [pre_tokenizers]
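The new vocab argument is only passed through to the WordLevel and WordPiece models, and padding is switched on whenever a pad token is configured. A minimal construction sketch, with an illustrative vocabulary that is not taken from the package:

    from robo_lib import TokenizerConstructor

    # hypothetical fixed vocabulary for a WordLevel tokenizer
    vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
    tok = TokenizerConstructor(
        tokenizer_type="WordLevel",
        vocab=vocab,
        unknown_token_string="<unk>",
        pad_token_string="<pad>",
    )
    # padding is enabled in __init__ because pad_token_string is not None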
@@ -122,6 +128,13 @@ class TokenizerConstructor:
  def encode(self, inp:str) -> list[int]:
  return self.tokenizer_type.encode(inp).ids

+ def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
+ if max_length is not None:
+ self.tokenizer_type.enable_truncation(max_length=max_length)
+ out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
+ self.tokenizer_type.no_truncation()
+ return out
+
  def decode(self, inp:list[int]) -> str:
  return self.tokenizer_type.decode(inp)

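A usage sketch of the new encode_batch method, assuming tok is an already-trained TokenizerConstructor (the variable name and inputs are illustrative):

    ids = tok.encode_batch(["first sentence", "a second, longer sentence"], max_length=16)
    # each entry is a list of token ids; truncation is applied only while max_length is set,
    # and rows share a common length when a pad token was configured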
@@ -136,38 +149,35 @@ def create_mask(row:list, block_size:int) -> list[bool]:
  mask = [1]*len(row) + [0]*(block_size - len(row))
  return mask

- def pad(row:list, block_size:int, pad_token:int) -> list[int]:
- '''
-
- returns padded row. Row is padded until length block_size with specified pad_token value
-
- '''
- row.extend([pad_token]*(block_size - len(row)))
- return row
-
- def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
+ def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
  '''

- returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist
+ returns string row with the tokenizer's start and end tokens if they exist

  '''
- processed_row = tokenizer.encode(row)
- if tokenizer.start_token != None:
- processed_row.insert(0, tokenizer.start_token)
- if tokenizer.end_token != None:
- processed_row.append(tokenizer.end_token)
-
- return processed_row
+ if start_token_string is None and end_token_string is None:
+ return data
+ else:
+ for i in range(len(data)):
+ if start_token_string is not None:
+ data[i] = start_token_string + data[i]
+ if end_token_string is not None:
+ data[i] = data[i] + end_token_string
+
+ return data

- def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
+ def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
  '''

- returns max_block_size of given list of strings by taking the length of the longest process_row(row) in data
+ torch stack with check to ensure tensors are valid in input list
+
+ returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises error if no valid tensors

  '''
- lengths = [len(process_row(p, tokenizer)) for p in data]
- max_block_size_scanner = max(lengths)
- return max_block_size_scanner
+ out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
+ if len(out_list) == 0:
+ raise ValueError("no valid tensors in list.")
+ return torch.stack(out_list)


  class DataProcessor:
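An illustrative call to the new safe_stack helper, which drops non-tensor entries before stacking and raises ValueError when nothing valid remains:

    import torch
    from robo_lib import safe_stack

    rows = [torch.tensor([1, 2, 3]), None, torch.tensor([4, 5, 6])]
    stacked = safe_stack(rows)  # shape (2, 3); ValueError if no entries are tensors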
@@ -196,93 +206,55 @@ class DataProcessor:
  self.enc_tokenizer = enc_tokenizer

  def process_list(self,
- save_path:str,
  dec_data:list[str]|str,
  dec_max_block_size:int=None,
  dec_create_masks:bool=True,
- dec_block_size_exceeded_policy:str=None,
  enc_data:list[str]=None,
  enc_max_block_size:int=None,
  enc_create_masks:bool=True,
- enc_block_size_exceeded_policy:str=None
+ save_path:str = "."
  ) -> None:

  if isinstance(dec_data, str):
  dec_data = [dec_data]
  dec_data_length = len(dec_data)
- save_path = save_path.replace(".pt", "")
-
- if dec_max_block_size == None:
- dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)

- if enc_data != None:
- self.enc_tokenizer = self.dec_tokenizer if self.enc_tokenizer == None else self.enc_tokenizer
+ if enc_data is not None:
+ if self.enc_tokenizer is None:
+ self.enc_tokenizer = self.dec_tokenizer

  enc_data_length = len(enc_data)
  if dec_data_length != enc_data_length:
- raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
+ raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")

- if enc_max_block_size == None:
- enc_max_block_size = scan_max_block_size(enc_data, self.enc_tokenizer)
-
- enc_out_list = [[]]*enc_data_length
- enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
- else:
- enc_out_list = []
- enc_mask_list = []
-
- dec_out_list = [[]]*dec_data_length
- dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
- for index in range(len(dec_out_list)):
- dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
- if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
- if dec_block_size_exceeded_policy == "trim":
- dec_processed_item = dec_processed_item[:dec_max_block_size]
- elif dec_block_size_exceeded_policy == "skip":
- continue
- elif dec_block_size_exceeded_policy == None:
- raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
- if dec_create_masks:
- dec_mask = create_mask(dec_processed_item, dec_max_block_size)
- dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
-
- if enc_data != None:
- enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
- if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
- if enc_block_size_exceeded_policy == "trim":
- enc_processed_item = enc_processed_item[:enc_max_block_size]
- elif enc_block_size_exceeded_policy == "skip":
- continue
- elif enc_block_size_exceeded_policy == None:
- raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
- if enc_create_masks:
- enc_mask = create_mask(enc_processed_item, enc_max_block_size)
- enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
-
- dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
- if dec_create_masks:
- dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
-
- if enc_data != None:
- enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
- if enc_create_masks:
- enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
-
- dec_out_list = torch.stack([row for row in dec_out_list if row != []])
- torch.save(dec_out_list, save_path + "_decoder_data.pt")
+ print("processing data")
+ dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
  if dec_create_masks:
- dec_mask_list = torch.stack([row for row in dec_mask_list if row != []])
- torch.save(dec_mask_list, save_path + "_decoder_mask_data.pt")
- if enc_data != None:
- enc_out_list = torch.stack([row for row in enc_out_list if row != []])
- torch.save(enc_out_list, save_path + "_encoder_data.pt")
+ mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+ dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
+
+ if enc_data is not None:
+ enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
  if enc_create_masks:
- enc_mask_list = torch.stack([row for row in enc_mask_list if row != []])
- torch.save(enc_mask_list, save_path + "_encoder_mask_data.pt")
+ mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+ enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])

+ dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
+ Path(save_path).mkdir(parents=True, exist_ok=True)
+ torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
+ if dec_create_masks:
+ dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
+ torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
+ if enc_data is not None:
+ enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
+ torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
+ if enc_create_masks:
+ enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
+ torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))

- def get_valid_samples(random_samples:torch.tensor,
- masks:torch.tensor,
+
+ def get_valid_samples(random_samples:torch.Tensor,
+ masks:torch.Tensor,
  block_size:int
  ) -> list[int]:
  '''
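With the reworked process_list above, save_path now names a directory and the output file names are fixed (decoder_data.pt, decoder_mask_data.pt, plus encoder equivalents when enc_data is given). A call sketch under those assumptions, with illustrative data, a tokenizer tok trained elsewhere, and keyword names inferred from the attributes used above:

    from robo_lib import DataProcessor

    processor = DataProcessor(dec_tokenizer=tok)
    processor.process_list(
        dec_data=["first training row", "second training row"],
        dec_max_block_size=32,
        save_path="data/train",  # writes data/train/decoder_data.pt and data/train/decoder_mask_data.pt
    )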
@@ -294,9 +266,9 @@ def get_valid_samples(random_samples:torch.tensor,
  valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
  return valid_samples

- def get_batch(data:torch.tensor,
- random_samples:torch.tensor,
- masks:torch.tensor=None,
+ def get_batch(data:torch.Tensor,
+ random_samples:torch.Tensor,
+ masks:torch.Tensor=None,
  block_size:int=None,
  get_offset:bool=True
  ) -> tuple[torch.tensor]:
@@ -308,53 +280,78 @@ def get_batch(data:torch.tensor,

  '''
  batch_size = len(random_samples)
- if block_size != None and block_size != data.shape[1]:
+ if block_size is not None and block_size != data.shape[1]:
  if block_size >= data.shape[1]:
  raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")

- if masks != None:
+ if masks is not None:
  random_point = get_valid_samples(random_samples, masks, block_size)
  else:
  random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
  batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
- masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
+ masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
  batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
  else:
  block_size = data.shape[1]
  batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
- masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
+ masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
  batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None

  return batch_in, batch_out, masks_in

- def top_kp_filter(logits:torch.tensor,
- top_k:int,
- top_p:float=None
- ) -> torch.tensor:
+ def top_kp_filter(logits: torch.Tensor,
+ top_k: int = None,
+ top_p: float = None
+ ) -> torch.Tensor:
  '''
+ Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.

- returns predicted token by filtering output logits using top_k and top_p
-
- '''
- if top_p != None:
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
- cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
+ Args:
+ logits: (batch_size, vocab_size) tensor of raw logits.
+ top_k: keep only top_k tokens with highest logits.
+ top_p: keep the smallest set of tokens with cumulative probability >= top_p.

- filter = cumulative_probs > top_p
- filter[..., 1:] = filter[..., :-1].clone()
- filter[..., 0] = 0
- indices_to_remove = filter.scatter(1, sorted_indices, filter)
- logits[indices_to_remove] = float("-inf")
-
- if top_k != None:
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ Returns:
+ selected: tensor of selected token indices (batch_size,)
+ '''
+ logits = logits.clone()  # avoid modifying input logits in-place
+
+ # Apply top-p filtering if specified
+ if top_p is not None:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ probs = F.softmax(sorted_logits, dim=-1)
+ cumulative_probs = torch.cumsum(probs, dim=-1)
+
+ # Remove tokens with cumulative probability above threshold (except first token)
+ sorted_mask = cumulative_probs > top_p
+ sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+ sorted_mask[..., 0] = False
+
+ # Mask tokens to remove by setting logits to -inf
+ indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+ logits[indices_to_remove] = float('-inf')
+
+ # Apply top-k filtering if specified
+ if top_k is not None:
+ top_k = min(top_k, logits.size(-1))  # safety check
+ topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+ topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
+
+ # For each batch, sample from top_k candidates
+ selected = []
+ for i in range(topk_probs.shape[0]):
+ candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
+ selected.append(candidate[0])
+ selected = torch.tensor(selected, dtype=torch.long)

- sorted_logits = F.softmax(sorted_logits[:, :top_k], dim=-1)
- sorted_indices = sorted_indices[:, :top_k].detach().cpu()
- sorted_logits = sorted_logits.detach().cpu().numpy()
- sorted_logits[0][0] += 1 - sum(sorted_logits[0])
-
- selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)
+ else:
+ # If only top_p is specified, sample from entire filtered logits
+ probs = F.softmax(logits, dim=-1).cpu().numpy()
+ selected = []
+ for i in range(probs.shape[0]):
+ candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
+ selected.append(candidate[0])
+ selected = torch.tensor(selected, dtype=torch.long)

  return selected

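A small sketch exercising the rewritten top_kp_filter directly on a batch of random logits (values are arbitrary):

    import torch
    from robo_lib.components import top_kp_filter

    logits = torch.randn(2, 100)  # (batch_size, vocab_size)
    next_tokens = top_kp_filter(logits, top_k=10, top_p=0.9)  # long tensor of shape (2,)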
@@ -387,10 +384,10 @@ class SelfAttention(nn.Module):
  self.dropout = nn.Dropout(dropout)

  def forward(self,
- k:torch.tensor,
- q:torch.tensor,
- v:torch.tensor,
- mask:torch.tensor=None
+ k:torch.Tensor,
+ q:torch.Tensor,
+ v:torch.Tensor,
+ mask:torch.Tensor=None
  ) -> torch.tensor:
  '''

@@ -406,7 +403,7 @@ class SelfAttention(nn.Module):
  wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
  if self.triangle_mask and self.block_size >= 0:
  wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
- if mask != None:
+ if mask is not None:
  wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
  wei = F.softmax(wei, dim=-1)
  wei = self.dropout(wei)
@@ -438,10 +435,10 @@ class MultiHeadAttention(nn.Module):
  self.dropout = nn.Dropout(dropout)

  def forward(self,
- k:torch.tensor,
- q:torch.tensor,
- v:torch.tensor,
- mask:torch.tensor=None
+ k:torch.Tensor,
+ q:torch.Tensor,
+ v:torch.Tensor,
+ mask:torch.Tensor=None
  ) -> torch.tensor:
  '''

@@ -475,7 +472,7 @@ class FeedForward(nn.Module):
  )

  def forward(self,
- x:torch.tensor
+ x:torch.Tensor
  ) -> torch.tensor:
  return self.net(x)

@@ -500,8 +497,8 @@ class EncoderBlock(nn.Module):
  self.ln2 = nn.LayerNorm(n_embed)

  def forward(self,
- x:torch.tensor,
- mask:torch.tensor=None
+ x:torch.Tensor,
+ mask:torch.Tensor=None
  ) -> tuple[torch.tensor]:
  att = self.sa(x, x, x, mask=mask)
  x = self.ln1(att + x)
@@ -541,15 +538,15 @@ class DecoderBlock(nn.Module):
  self.ca = None

  def forward(self,
- x:torch.tensor,
- enc_k:torch.tensor,
- enc_v:torch.tensor,
+ x:torch.Tensor,
+ enc_k:torch.Tensor,
+ enc_v:torch.Tensor,
  mask_out:bool=None,
- mask_in:torch.tensor=None
+ mask_in:torch.Tensor=None
  ) -> tuple[torch.tensor]:
  att = self.sa(x, x, x, mask=mask_out)
  x = self.ln1(att + x)
- if self.ca != None:
+ if self.ca is not None:
  catt = self.ca(enc_k, x, enc_v, mask=mask_in)
  x = self.ln3(catt + x)
  ff = self.ffwd(x)
@@ -628,7 +625,7 @@ class RoboConstructor(nn.Module):
  self.dec_expansion_factor = dec_expansion_factor
  self.dropout = dropout

- if device == None:
+ if device is None:
  self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  else:
  self.device = device
@@ -673,13 +670,13 @@ class RoboConstructor(nn.Module):
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self,
- dec_in:torch.tensor,
- dec_mask:torch.tensor=None,
- enc_in:torch.tensor=None,
- enc_mask:torch.tensor=None
+ dec_in:torch.Tensor,
+ dec_mask:torch.Tensor=None,
+ enc_in:torch.Tensor=None,
+ enc_mask:torch.Tensor=None
  ) -> torch.tensor:
  _, dec_T = dec_in.shape
- if enc_in != None:
+ if enc_in is not None:
  _, enc_T = enc_in.shape

  dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -718,13 +715,13 @@ class RoboConstructor(nn.Module):

  dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
  dec_train_batch_in = dec_train_batch_in.to(self.device)
- dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
- dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None
+ dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
+ dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None

  if self.cross_attention:
  enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
  enc_train_batch_in = enc_train_batch_in.to(self.device)
- enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
+ enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
  else:
  enc_train_batch_in = None
  enc_train_masks_in = None
@@ -736,14 +733,8 @@ class RoboConstructor(nn.Module):
  max_iters:int,
  eval_interval:int,
  batch_size:int,
- dec_training_path:str,
- dec_eval_path:str=None,
- dec_training_masks_path:str=None,
- dec_eval_masks_path:str=None,
- enc_training_path:str=None,
- enc_eval_path:str=None,
- enc_training_masks_path:str=None,
- enc_eval_masks_path:str=None,
+ training_dir_path:str,
+ eval_dir_path:str,
  eval_iters:int=3,
  learning_rate:float=1e-4,
  pad_token:int=None,
@@ -752,21 +743,36 @@ class RoboConstructor(nn.Module):
  label_smoothing:float=0.1
  ) -> None:

+ dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
  dec_training_data = torch.load(dec_training_path, weights_only=True)
- dec_eval_data = torch.load(dec_eval_path, weights_only=True) if dec_eval_path != None else None
- dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if dec_training_masks_path != None else None
- dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if dec_eval_masks_path != None else None
- enc_training_data = torch.load(enc_training_path, weights_only=True) if enc_training_path != None else None
- enc_eval_data = torch.load(enc_eval_path, weights_only=True) if enc_eval_path != None else None
- enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if enc_training_masks_path != None else None
- enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if enc_eval_masks_path != None else None
-
- if pad_token == None and dec_tokenizer != None:
+
+ dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
+ dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
+
+ dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
+ dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
+
+ dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
+ dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
+
+ enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
+ enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
+
+ enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
+ enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
+
+ enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
+ enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
+
+ enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
+ enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
+
+ if pad_token is None and dec_tokenizer is not None:
  pad_token = dec_tokenizer.pad_token

  self.to(self.device)

- if pad_token != None:
+ if pad_token is not None:
  loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
  else:
  loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
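A training-call sketch under the new directory convention; model is assumed to be an already-constructed RoboConstructor, and the remaining argument names are taken from the body above, so treat them as assumptions:

    model.train_robo(
        max_iters=10_000,
        eval_interval=500,
        batch_size=32,
        training_dir_path="data/train",  # must contain decoder_data.pt; mask/encoder files are optional
        eval_dir_path="data/eval",
        dec_tokenizer=tok,
        save_path="robo_model",
    )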
@@ -782,7 +788,7 @@ class RoboConstructor(nn.Module):
  proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
  losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
  out["train"] = losses.mean()
- if dec_eval_data != None:
+ if dec_eval_data is not None:
  for k in range(eval_iters):
  dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
  proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -798,7 +804,7 @@ class RoboConstructor(nn.Module):
  if iter % eval_interval == 0 or iter == max_iters-1:
  losses = estimate_loss()
  print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
- if save_path != None:
+ if save_path is not None:
  save_component(self, save_path=save_path)

  dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -825,25 +831,25 @@ class RoboConstructor(nn.Module):
  top_k:int=None,
  top_p:float=None
  ) -> list[int]|str:
- max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens
+ max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens

  if self.cross_attention:
- if enc_tokenizer != None:
- if enc_start_token == None:
+ if enc_tokenizer is not None:
+ if enc_start_token is None:
  enc_start_token = enc_tokenizer.start_token
- if enc_end_token == None:
+ if enc_end_token is None:
  enc_end_token = enc_tokenizer.end_token
  if isinstance(inputs, str):
  inputs = enc_tokenizer.encode(inputs)

- if dec_tokenizer != None:
- if dec_start_token == None:
+ if dec_tokenizer is not None:
+ if dec_start_token is None:
  dec_start_token = dec_tokenizer.start_token
- if dec_end_token == None:
+ if dec_end_token is None:
  dec_end_token = dec_tokenizer.end_token
- if new_line_token == None:
+ if new_line_token is None:
  new_line_token = dec_tokenizer.new_line_token
- if self.cross_attention == False and isinstance(inputs, str):
+ if not self.cross_attention and isinstance(inputs, str):
  inputs = dec_tokenizer.encode(inputs)


@@ -852,7 +858,7 @@ class RoboConstructor(nn.Module):
  idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
  else:
  enc_input = None
- if separator_token != None:
+ if separator_token is not None:
  idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
  else:
  idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -866,7 +872,7 @@ class RoboConstructor(nn.Module):
  logits = proj_output[:, -1, :]
  probabilities = F.log_softmax(logits/temperature, dim=-1)

- if top_k == None and top_p == None:
+ if top_k is None and top_p is None:
  idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
  else:
  idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -874,10 +880,10 @@ class RoboConstructor(nn.Module):
  if idx_next[0] == dec_end_token:
  break

- if dec_tokenizer == None:
+ if dec_tokenizer is None:
  return idx[0].tolist()
  else:
- if new_line_token != None:
+ if new_line_token is not None:
  return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
  else:
  return dec_tokenizer.decode(idx[0].tolist())
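And a generation sketch using the parameters touched in the hunks above (the prompt and sampling settings are placeholders):

    output = model.generate(
        "an example prompt",
        dec_tokenizer=tok,
        max_new_tokens=64,
        temperature=0.8,
        top_k=50,
    )
    print(output)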
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: robo_lib
- Version: 0.0.11
+ Version: 1.0.0
  Summary: A package to create, configure, and train transformer models.
  Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
  Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
@@ -0,0 +1,6 @@
+ robo_lib/__init__.py,sha256=NnzWHWwpFcSJD_XRMWKKPQFAIrRBFYiCFN0pgUGPygc,968
+ robo_lib/components.py,sha256=M_1M1Y56_W0bSElZlg3M6gRoJJPAnUchTO3N8AdsEV8,43091
+ robo_lib-1.0.0.dist-info/METADATA,sha256=GAnmrynDr3-hv9KyCjXlpx5I8v2BLQJCIDXURoGFw2w,9633
+ robo_lib-1.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ robo_lib-1.0.0.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
+ robo_lib-1.0.0.dist-info/RECORD,,
@@ -1,6 +0,0 @@
- robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
- robo_lib/components.py,sha256=L_GUEHdKC_-Xn56ObQ9-DH8T1ywaz0M8jlWv227gZBs,42591
- robo_lib-0.0.11.dist-info/METADATA,sha256=ePF06l2FXzo0qjK8v9Vob4WnOQ61KVd0mUqd7JVG7j4,9634
- robo_lib-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- robo_lib-0.0.11.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
- robo_lib-0.0.11.dist-info/RECORD,,