robo-lib 0.0.11__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
robo_lib/__init__.py CHANGED
@@ -1,8 +1,7 @@
1
1
  from .components import TokenizerConstructor as TokenizerConstructor
2
2
  from .components import create_mask as create_mask
3
- from .components import pad as pad
4
- from .components import process_row as process_row
5
- from .components import scan_max_block_size as scan_max_block_size
3
+ from .components import pre_process_data as pre_process_data
4
+ from .components import safe_stack as safe_stack
6
5
  from .components import DataProcessor as DataProcessor
7
6
  from .components import get_valid_samples as get_valid_samples
8
7
  from .components import get_batch as get_batch
robo_lib/components.py CHANGED
@@ -6,30 +6,28 @@ import numpy as np
6
6
  import random
7
7
  import pickle
8
8
  import itertools
9
+ from pathlib import Path
10
+ import os
11
+ from typing import List, Literal
12
+
13
+ pre_tokenizers = Literal["Whitespace", "IndividualDigit", "Digits", "BertPreTokenizer", "ByteLevel", "Metaspace", "Punctuation", "UnicodeScripts", "WhitespaceSplit"]
9
14
 
10
15
  class TokenizerConstructor:
11
16
  '''
12
-
13
17
  simple assembler for tokenizer using the tokenizers library
14
- tokenizer parameters can be set using strings and list[string]s
18
+ tokenizer parameters can be set using strings and list[str]s
15
19
  strings used for tokenizer_type, pre_tokenizers, normalizers arguments are the names of those present in the
16
20
  tokenizers library. Additionally "IndividualDigits" can be used in normalizers for tokenizers.pre_tokenizers.Digits(individual_digits=True)
17
21
 
18
- train([paths]) function points to text files to be used for training the tokenizer instance
19
-
20
- encode(string) function encodes string using trained tokenizer instance
21
-
22
- decode(list[int]) function decodes list of tokenz using trained tokenizer instance
23
-
24
22
  vocab_size attribute returns the tokenizer instance's vocab_size (untrained tokenizer will have vocab_size=None)
25
23
 
26
-
27
24
  '''
28
25
  def __init__(self,
29
26
  min_frequency:int=2,
30
- tokenizer_type:str="BPE",
31
- pre_tokenizers:list[str]|str=["Whitespace"],
27
+ tokenizer_type:Literal["BPE", "WordLevel", "WordPiece", "Unigram"] = "BPE",
28
+ pre_tokenizers: pre_tokenizers|List[pre_tokenizers]=["Whitespace"],
32
29
  normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
30
+ vocab:dict[str,int] = {},
33
31
  special_tokens:list[str]|str=[],
34
32
  unknown_token_string:str="<unk>",
35
33
  start_token_string:str="<sos>",
@@ -42,25 +40,29 @@ class TokenizerConstructor:
42
40
 
43
41
  if isinstance(special_tokens, str):
44
42
  special_tokens = [special_tokens]
45
- self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
46
- self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
47
- self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
48
- self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
49
- self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
50
- self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None
43
+ self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
44
+ self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
45
+ self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
46
+ self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
47
+ self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
48
+ self.pad_token_string = pad_token_string
49
+ self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None
51
50
 
52
51
  if tokenizer_type == "BPE":
53
52
  self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
54
53
  self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
55
54
  elif tokenizer_type == "WordLevel":
56
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
55
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
57
56
  self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
58
57
  elif tokenizer_type == "WordPiece":
59
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
58
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
60
59
  self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
61
60
  elif tokenizer_type == "Unigram":
62
- self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
61
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
63
62
  self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
63
+
64
+ if self.pad_token is not None:
65
+ self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)
64
66
 
65
67
  if isinstance(pre_tokenizers, str):
66
68
  pre_tokenizers = [pre_tokenizers]
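To illustrate the reworked 1.0.1 constructor shown in this hunk (the `Literal` tokenizer types, the typed `pre_tokenizers` argument, the optional `vocab` dict, and padding being enabled automatically whenever a pad token exists), here is a minimal usage sketch. It assumes the exports shown in `__init__.py` and a placeholder corpus path; the token strings follow the defaults visible in the signature.

```python
# Minimal sketch of the 1.0.1 TokenizerConstructor API shown in this diff.
# "data/corpus.txt" is a placeholder path.
from robo_lib.components import TokenizerConstructor

tok = TokenizerConstructor(
    min_frequency=2,
    tokenizer_type="BPE",              # one of "BPE", "WordLevel", "WordPiece", "Unigram"
    pre_tokenizers=["Whitespace"],
    normalizers=["Lowercase", "NFD", "StripAccents", "Strip"],
)

tok.train("data/corpus.txt")           # train() accepts a single path or a list of paths
ids = tok.encode("hello world")        # list[int]
text = tok.decode(ids)

# If a pad token string is configured (the default), the constructor enables
# padding on the underlying tokenizer with pad_id=tok.pad_token.
```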
@@ -114,79 +116,76 @@ class TokenizerConstructor:
114
116
 
115
117
 
116
118
  def train(self, training_paths:list[str]|str) -> None:
119
+ '''
120
+ points to text files to be used for training the tokenizer instance
121
+ '''
117
122
  if isinstance(training_paths, str):
118
123
  training_paths = [training_paths]
119
124
  self.tokenizer_type.train(training_paths, trainer=self.trainer)
120
125
  self.vocab_size = self.tokenizer_type.get_vocab_size()
121
126
 
122
127
  def encode(self, inp:str) -> list[int]:
128
+ '''
129
+ encodes string using trained tokenizer instance
130
+ '''
123
131
  return self.tokenizer_type.encode(inp).ids
124
132
 
133
+ def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
134
+ '''
135
+ encodes strings in parallel and truncates entries with length > max_length
136
+ '''
137
+ if max_length is not None:
138
+ self.tokenizer_type.enable_truncation(max_length=max_length)
139
+ self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=self.pad_token_string, length=max_length)
140
+ out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
141
+ self.tokenizer_type.no_truncation()
142
+ self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=self.pad_token_string)
143
+ return out
144
+
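A sketch of the new `encode_batch` method added above: it encodes rows in parallel with padding enabled, and when `max_length` is given it truncates and pads every row to that fixed length before restoring the tokenizer's default padding. Assumes a trained tokenizer `tok` with the default special tokens, as in the previous sketch.

```python
rows = ["first example", "a somewhat longer second example"]

# Without max_length: rows are padded to the longest row in the batch.
batch_ids = tok.encode_batch(rows)

# With max_length: longer rows are truncated, shorter ones padded, so every row has length 12.
fixed = tok.encode_batch(rows, max_length=12)
assert all(len(row) == 12 for row in fixed)
```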
125
145
  def decode(self, inp:list[int]) -> str:
146
+ '''
147
+ decodes list of tokens using trained tokenizer instance
148
+ '''
126
149
  return self.tokenizer_type.decode(inp)
127
150
 
128
151
 
129
152
 
130
153
  def create_mask(row:list, block_size:int) -> list[bool]:
131
154
  '''
132
-
133
155
  creates a mask list of length block_size for row, assuming the mask covers the entire row input
134
-
135
156
  '''
136
157
  mask = [1]*len(row) + [0]*(block_size - len(row))
137
158
  return mask
138
159
 
139
- def pad(row:list, block_size:int, pad_token:int) -> list[int]:
140
- '''
141
-
142
- returns padded row. Row is padded until length block_size with specified pad_token value
143
-
144
- '''
145
- row.extend([pad_token]*(block_size - len(row)))
146
- return row
147
-
148
- def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
160
+ def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
149
161
  '''
150
-
151
- returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist
152
-
162
+ returns data with the given start and end token strings added to each row if they exist
153
163
  '''
154
- processed_row = tokenizer.encode(row)
155
- if tokenizer.start_token != None:
156
- processed_row.insert(0, tokenizer.start_token)
157
- if tokenizer.end_token != None:
158
- processed_row.append(tokenizer.end_token)
159
-
160
- return processed_row
164
+ if start_token_string is None and end_token_string is None:
165
+ return data
166
+ else:
167
+ for i in range(len(data)):
168
+ if start_token_string is not None:
169
+ data[i] = start_token_string + data[i]
170
+ if end_token_string is not None:
171
+ data[i] = data[i] + end_token_string
172
+
173
+ return data
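`pre_process_data` replaces the token-level `process_row` helper: it works on raw strings, simply prepending and appending the start and end token strings to every row (despite the `str` annotation, the body iterates over a list of rows). A small sketch; the token strings here are illustrative defaults.

```python
from robo_lib.components import pre_process_data

rows = ["bonjour le monde", "salut"]
# Adds the literal start/end token strings to each row and returns the list.
rows = pre_process_data(rows, start_token_string="<sos>", end_token_string="<eos>")
# rows == ["<sos>bonjour le monde<eos>", "<sos>salut<eos>"]
```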
161
174
 
162
- def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
175
+ def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
163
176
  '''
164
-
165
- returns max_block_size of given list of strings by taking the length of the longest process_row(row) in data
166
-
177
+ torch.stack with a check that the tensors in the input list are valid
178
+ returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises an error if there are no valid tensors
167
179
  '''
168
- lengths = [len(process_row(p, tokenizer)) for p in data]
169
- max_block_size_scanner = max(lengths)
170
- return max_block_size_scanner
180
+ out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
181
+ if len(out_list) == 0:
182
+ raise ValueError("no valid tensors in list.")
183
+ return torch.stack(out_list)
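`safe_stack` replaces the old filter-then-stack pattern used in `process_list`: it drops anything that is not a `torch.Tensor` and raises if nothing valid remains. A sketch:

```python
import torch
from robo_lib.components import safe_stack

mixed = [torch.zeros(4), None, torch.ones(4)]   # e.g. rows that failed processing are None
stacked = safe_stack(mixed)                     # shape (2, 4): invalid entries are skipped

try:
    safe_stack([None, "not a tensor"])
except ValueError as err:
    print(err)                                  # "no valid tensors in list."
```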
171
184
 
172
185
 
173
186
  class DataProcessor:
174
187
  '''
175
-
176
188
  data processor can be instantiated by specifying the tokenizer(s) for decoder and encoder data
177
-
178
- process_list() function processes raw data in the form of list[str] or str for decoder and encoder simultaneously and
179
- saves them to save_path as .pt files.
180
- - encoder and decoder input data should have matching input and outputs so enc_data[n] should have its corresponding
181
- decoder data at dec_data[n].
182
- - max block size can be specified for both input and output, default takes the max
183
- block size provided in the data respectively.
184
- - if enc/dec_block_size is specified and enc/dec_block_size_exceeded_policy is not, an error will occur if a piece
185
- of data larger than enc/dec_block_size is encountered. enc/dec_block_size_exceeded_policy can be set to "skip" or
186
- "trim" to skip rows larger than enc/dec_block_size or truncate the row to specified enc/dec_block_size respectively.
187
- - enc/dec_create_masks saves masks tensors to save_path as .pt files.
188
-
189
-
190
189
  '''
191
190
  def __init__(self,
192
191
  dec_tokenizer:TokenizerConstructor,
@@ -196,165 +195,153 @@ class DataProcessor:
196
195
  self.enc_tokenizer = enc_tokenizer
197
196
 
198
197
  def process_list(self,
199
- save_path:str,
200
198
  dec_data:list[str]|str,
201
199
  dec_max_block_size:int=None,
202
200
  dec_create_masks:bool=True,
203
- dec_block_size_exceeded_policy:str=None,
204
201
  enc_data:list[str]=None,
205
202
  enc_max_block_size:int=None,
206
203
  enc_create_masks:bool=True,
207
- enc_block_size_exceeded_policy:str=None
204
+ save_path:str = "."
208
205
  ) -> None:
206
+ '''
207
+ processes raw data in the form of list[str] or str for decoder and encoder simultaneously and
208
+ saves them to save_path as .pt files.
209
+ - encoder and decoder input data should have matching input and outputs so enc_data[n] should have its corresponding
210
+ decoder data at dec_data[n].
211
+ - max block size can be specified for both input and output; by default the longest
212
+ row in the respective data is used. Rows longer than the max block size are trimmed.
213
+ '''
209
214
 
210
215
  if isinstance(dec_data, str):
211
216
  dec_data = [dec_data]
212
217
  dec_data_length = len(dec_data)
213
- save_path = save_path.replace(".pt", "")
214
-
215
- if dec_max_block_size == None:
216
- dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
217
218
 
218
- if enc_data != None:
219
- self.enc_tokenizer = self.dec_tokenizer if self.enc_tokenizer == None else self.enc_tokenizer
219
+ if enc_data is not None:
220
+ if self.enc_tokenizer is None:
221
+ self.enc_tokenizer = self.dec_tokenizer
220
222
 
221
223
  enc_data_length = len(enc_data)
222
224
  if dec_data_length != enc_data_length:
223
- raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
225
+ raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
224
226
 
225
- if enc_max_block_size == None:
226
- enc_max_block_size = scan_max_block_size(enc_data, self.enc_tokenizer)
227
-
228
- enc_out_list = [[]]*enc_data_length
229
- enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
230
- else:
231
- enc_out_list = []
232
- enc_mask_list = []
233
-
234
- dec_out_list = [[]]*dec_data_length
235
- dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
236
- for index in range(len(dec_out_list)):
237
- dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
238
- if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
239
- if dec_block_size_exceeded_policy == "trim":
240
- dec_processed_item = dec_processed_item[:dec_max_block_size]
241
- elif dec_block_size_exceeded_policy == "skip":
242
- continue
243
- elif dec_block_size_exceeded_policy == None:
244
- raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
245
- if dec_create_masks:
246
- dec_mask = create_mask(dec_processed_item, dec_max_block_size)
247
- dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
248
-
249
- if enc_data != None:
250
- enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
251
- if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
252
- if enc_block_size_exceeded_policy == "trim":
253
- enc_processed_item = enc_processed_item[:enc_max_block_size]
254
- elif enc_block_size_exceeded_policy == "skip":
255
- continue
256
- elif enc_block_size_exceeded_policy == None:
257
- raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
258
- if enc_create_masks:
259
- enc_mask = create_mask(enc_processed_item, enc_max_block_size)
260
- enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
261
-
262
- dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
263
- if dec_create_masks:
264
- dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
265
-
266
- if enc_data != None:
267
- enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
268
- if enc_create_masks:
269
- enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
270
-
271
- dec_out_list = torch.stack([row for row in dec_out_list if row != []])
272
- torch.save(dec_out_list, save_path + "_decoder_data.pt")
227
+ print("processing data")
228
+ dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
273
229
  if dec_create_masks:
274
- dec_mask_list = torch.stack([row for row in dec_mask_list if row != []])
275
- torch.save(dec_mask_list, save_path + "_decoder_mask_data.pt")
276
- if enc_data != None:
277
- enc_out_list = torch.stack([row for row in enc_out_list if row != []])
278
- torch.save(enc_out_list, save_path + "_encoder_data.pt")
230
+ mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
231
+ dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
232
+
233
+ if enc_data is not None:
234
+ enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
279
235
  if enc_create_masks:
280
- enc_mask_list = torch.stack([row for row in enc_mask_list if row != []])
281
- torch.save(enc_mask_list, save_path + "_encoder_mask_data.pt")
236
+ mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
237
+ enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])
282
238
 
239
+ dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
240
+ Path(save_path).mkdir(parents=True, exist_ok=True)
241
+ torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
242
+ if dec_create_masks:
243
+ dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
244
+ torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
245
+ if enc_data is not None:
246
+ enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
247
+ torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
248
+ if enc_create_masks:
249
+ enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
250
+ torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))
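With the reworked `process_list`, `save_path` is now a directory (created if missing) and the output file names are fixed, which is what `train_robo`'s new `training_dir_path`/`eval_dir_path` arguments expect. A sketch of the resulting layout, assuming trained tokenizers `dec_tok`/`enc_tok` and the parallel `french_train`/`english_train` lists used in the README below:

```python
from robo_lib.components import DataProcessor

proc = DataProcessor(dec_tokenizer=dec_tok, enc_tokenizer=enc_tok)
proc.process_list(
    dec_data=french_train,        # list[str]
    dec_max_block_size=100,
    enc_data=english_train,       # list[str], same length as dec_data
    enc_max_block_size=100,
    save_path="data/training",
)
# data/training/ now contains:
#   decoder_data.pt, decoder_mask_data.pt, encoder_data.pt, encoder_mask_data.pt
```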
283
251
 
284
- def get_valid_samples(random_samples:torch.tensor,
285
- masks:torch.tensor,
252
+
253
+ def get_valid_samples(random_samples:torch.Tensor,
254
+ masks:torch.Tensor,
286
255
  block_size:int
287
256
  ) -> list[int]:
288
257
  '''
289
-
290
258
  returns list of len(random_samples) with values corresponding to index values of masks that ensure minimum masked
291
259
  values when taking sample of length block_size
292
-
293
260
  '''
294
261
  valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
295
262
  return valid_samples
296
263
 
297
- def get_batch(data:torch.tensor,
298
- random_samples:torch.tensor,
299
- masks:torch.tensor=None,
264
+ def get_batch(data:torch.Tensor,
265
+ random_samples:torch.Tensor,
266
+ masks:torch.Tensor=None,
300
267
  block_size:int=None,
301
268
  get_offset:bool=True
302
269
  ) -> tuple[torch.tensor]:
303
270
  '''
304
-
305
271
  returns random batches from data tensor using random sample for data selection.
306
272
  - returns corresponding batch offset by 1 unless get_offset=False
307
273
  - returns corresponding masks batch if masks data is specified
308
-
309
274
  '''
310
275
  batch_size = len(random_samples)
311
- if block_size != None and block_size != data.shape[1]:
276
+ if block_size is not None and block_size != data.shape[1]:
312
277
  if block_size >= data.shape[1]:
313
278
  raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
314
279
 
315
- if masks != None:
280
+ if masks is not None:
316
281
  random_point = get_valid_samples(random_samples, masks, block_size)
317
282
  else:
318
283
  random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
319
284
  batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
320
- masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
285
+ masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
321
286
  batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
322
287
  else:
323
288
  block_size = data.shape[1]
324
289
  batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
325
- masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
290
+ masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
326
291
  batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
327
292
 
328
293
  return batch_in, batch_out, masks_in
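A toy sketch of `get_batch` on pre-processed data: it picks the rows indexed by `random_samples` and, when `get_offset=True` (the default), returns the input batch together with the same batch shifted by one token for next-token targets.

```python
import torch
from robo_lib.components import get_batch

data = torch.randint(0, 100, (32, 16))           # 32 padded rows of 16 tokens
samples = torch.randint(0, data.shape[0], (4,))  # pick 4 random rows

batch_in, batch_out, masks_in = get_batch(data, samples)  # masks_in is None here
print(batch_in.shape, batch_out.shape)           # (4, 15) and (4, 15): offset by one token
```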
329
294
 
330
- def top_kp_filter(logits:torch.tensor,
331
- top_k:int,
332
- top_p:float=None
333
- ) -> torch.tensor:
295
+ def top_kp_filter(logits: torch.Tensor,
296
+ top_k: int = None,
297
+ top_p: float = None
298
+ ) -> torch.Tensor:
334
299
  '''
300
+ Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.
335
301
 
336
- returns predicted token by filtering output logits using top_k and top_p
337
-
302
+ Args:
303
+ logits: (batch_size, vocab_size) tensor of raw logits.
304
+ top_k: keep only top_k tokens with highest logits.
305
+ top_p: keep the smallest set of tokens with cumulative probability >= top_p.
338
306
  '''
339
- if top_p != None:
340
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
341
- cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
342
-
343
- filter = cumulative_probs > top_p
344
- filter[..., 1:] = filter[..., :-1].clone()
345
- filter[..., 0] = 0
346
- indices_to_remove = filter.scatter(1, sorted_indices, filter)
347
- logits[indices_to_remove] = float("-inf")
307
+ logits = logits.clone() # avoid modifying input logits in-place
308
+
309
+ # Apply top-p filtering if specified
310
+ if top_p is not None:
311
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
312
+ probs = F.softmax(sorted_logits, dim=-1)
313
+ cumulative_probs = torch.cumsum(probs, dim=-1)
348
314
 
349
- if top_k != None:
350
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
315
+ # Remove tokens with cumulative probability above threshold (except first token)
316
+ sorted_mask = cumulative_probs > top_p
317
+ sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
318
+ sorted_mask[..., 0] = False
351
319
 
352
- sorted_logits = F.softmax(sorted_logits[:, :top_k], dim=-1)
353
- sorted_indices = sorted_indices[:, :top_k].detach().cpu()
354
- sorted_logits = sorted_logits.detach().cpu().numpy()
355
- sorted_logits[0][0] += 1 - sum(sorted_logits[0])
320
+ # Mask tokens to remove by setting logits to -inf
321
+ indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
322
+ logits[indices_to_remove] = float('-inf')
356
323
 
357
- selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)
324
+ # Apply top-k filtering if specified
325
+ if top_k is not None:
326
+ top_k = min(top_k, logits.size(-1)) # safety check
327
+ topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
328
+ topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
329
+
330
+ # For each batch, sample from top_k candidates
331
+ selected = []
332
+ for i in range(topk_probs.shape[0]):
333
+ candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
334
+ selected.append(candidate[0])
335
+ selected = torch.tensor(selected, dtype=torch.long)
336
+
337
+ else:
338
+ # If only top_p is specified, sample from entire filtered logits
339
+ probs = F.softmax(logits, dim=-1).cpu().numpy()
340
+ selected = []
341
+ for i in range(probs.shape[0]):
342
+ candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
343
+ selected.append(candidate[0])
344
+ selected = torch.tensor(selected, dtype=torch.long)
358
345
 
359
346
  return selected
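The rewritten `top_kp_filter` now samples one token per batch row and supports top-k, top-p (nucleus), or both; inside `generate` it is fed log-softmax probabilities, but any logits tensor works since it re-normalises internally. A sketch:

```python
import torch
from robo_lib.components import top_kp_filter

logits = torch.randn(2, 1000)                    # (batch_size, vocab_size)

next_ids = top_kp_filter(logits, top_k=50)               # top-k sampling
next_ids = top_kp_filter(logits, top_p=0.9)              # nucleus sampling only
next_ids = top_kp_filter(logits, top_k=50, top_p=0.9)    # nucleus filter, then top-k
print(next_ids.shape)                            # torch.Size([2]): one token id per row
```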
360
347
 
@@ -362,10 +349,8 @@ def top_kp_filter(logits:torch.tensor,
362
349
 
363
350
  class SelfAttention(nn.Module):
364
351
  '''
365
-
366
352
  single self attention block of size head_size.
367
353
  triangle_mask=True to apply look-ahead mask of size block_size.
368
-
369
354
  '''
370
355
  def __init__(self,
371
356
  head_size:int,
@@ -387,16 +372,14 @@ class SelfAttention(nn.Module):
387
372
  self.dropout = nn.Dropout(dropout)
388
373
 
389
374
  def forward(self,
390
- k:torch.tensor,
391
- q:torch.tensor,
392
- v:torch.tensor,
393
- mask:torch.tensor=None
375
+ k:torch.Tensor,
376
+ q:torch.Tensor,
377
+ v:torch.Tensor,
378
+ mask:torch.Tensor=None
394
379
  ) -> torch.tensor:
395
380
  '''
396
-
397
381
  k, q and v are the input tensors used to compute the key, query and value tensors.
398
382
  custom mask tensor can be applied.
399
-
400
383
  '''
401
384
  _,T,_ = k.shape
402
385
 
@@ -406,7 +389,7 @@ class SelfAttention(nn.Module):
406
389
  wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
407
390
  if self.triangle_mask and self.block_size >= 0:
408
391
  wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
409
- if mask != None:
392
+ if mask is not None:
410
393
  wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
411
394
  wei = F.softmax(wei, dim=-1)
412
395
  wei = self.dropout(wei)
@@ -417,12 +400,10 @@ class SelfAttention(nn.Module):
417
400
 
418
401
  class MultiHeadAttention(nn.Module):
419
402
  '''
420
-
421
403
  multi-head attention block consisting of num_heads SelfAttention blocks and a linear layer to
422
404
  rejoin outputs.
423
405
  specified head_size, n_embed, dropout, block_size and triangle_mask values are passed through to
424
406
  SelfAttention blocks
425
-
426
407
  '''
427
408
  def __init__(self,
428
409
  num_heads:int,
@@ -438,16 +419,14 @@ class MultiHeadAttention(nn.Module):
438
419
  self.dropout = nn.Dropout(dropout)
439
420
 
440
421
  def forward(self,
441
- k:torch.tensor,
442
- q:torch.tensor,
443
- v:torch.tensor,
444
- mask:torch.tensor=None
422
+ k:torch.Tensor,
423
+ q:torch.Tensor,
424
+ v:torch.Tensor,
425
+ mask:torch.Tensor=None
445
426
  ) -> torch.tensor:
446
427
  '''
447
-
448
428
  k, q and v are the input tensors used to compute the key, query and value tensors.
449
429
  custom mask tensor can be applied.
450
-
451
430
  '''
452
431
  out = torch.cat([h(k, q, v, mask=mask) for h in self.heads], dim=-1)
453
432
  out = self.dropout(self.proj(out))
@@ -455,11 +434,9 @@ class MultiHeadAttention(nn.Module):
455
434
 
456
435
  class FeedForward(nn.Module):
457
436
  '''
458
-
459
437
  feed forward layer used after multi-head attention consisting of 2 linear layers with
460
438
  a ReLU in between. Linear layers expand from n_embed to n_embed * expansion_factor and
461
439
  back to n_embed.
462
-
463
440
  '''
464
441
  def __init__(self,
465
442
  n_embed:int,
@@ -475,16 +452,14 @@ class FeedForward(nn.Module):
475
452
  )
476
453
 
477
454
  def forward(self,
478
- x:torch.tensor
455
+ x:torch.Tensor
479
456
  ) -> torch.tensor:
480
457
  return self.net(x)
481
458
 
482
459
  class EncoderBlock(nn.Module):
483
460
  '''
484
-
485
461
  encoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
486
462
  head_size is calculated from n_embed // n_head
487
-
488
463
  '''
489
464
  def __init__(self,
490
465
  n_embed:int,
@@ -500,8 +475,8 @@ class EncoderBlock(nn.Module):
500
475
  self.ln2 = nn.LayerNorm(n_embed)
501
476
 
502
477
  def forward(self,
503
- x:torch.tensor,
504
- mask:torch.tensor=None
478
+ x:torch.Tensor,
479
+ mask:torch.Tensor=None
505
480
  ) -> tuple[torch.tensor]:
506
481
  att = self.sa(x, x, x, mask=mask)
507
482
  x = self.ln1(att + x)
@@ -512,13 +487,11 @@ class EncoderBlock(nn.Module):
512
487
 
513
488
  class DecoderBlock(nn.Module):
514
489
  '''
515
-
516
490
  decoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
517
491
  if cross-attention is True, a multi-head attention block and layerNorm is added before feed-forward
518
492
  taking specified enc_k and enc_v tensors as value and key tensors. These values should be the output
519
493
  of an encoder block.
520
494
  head_size is calculated from n_embed // n_head
521
-
522
495
  '''
523
496
  def __init__(self,
524
497
  n_embed:int,
@@ -541,15 +514,15 @@ class DecoderBlock(nn.Module):
541
514
  self.ca = None
542
515
 
543
516
  def forward(self,
544
- x:torch.tensor,
545
- enc_k:torch.tensor,
546
- enc_v:torch.tensor,
517
+ x:torch.Tensor,
518
+ enc_k:torch.Tensor,
519
+ enc_v:torch.Tensor,
547
520
  mask_out:bool=None,
548
- mask_in:torch.tensor=None
521
+ mask_in:torch.Tensor=None
549
522
  ) -> tuple[torch.tensor]:
550
523
  att = self.sa(x, x, x, mask=mask_out)
551
524
  x = self.ln1(att + x)
552
- if self.ca != None:
525
+ if self.ca is not None:
553
526
  catt = self.ca(enc_k, x, enc_v, mask=mask_in)
554
527
  x = self.ln3(catt + x)
555
528
  ff = self.ffwd(x)
@@ -558,9 +531,7 @@ class DecoderBlock(nn.Module):
558
531
 
559
532
  class MySequential(nn.Sequential):
560
533
  '''
561
-
562
534
  MySequential serves the same purpose as nn.Sequential but allows for multiple inputs and outputs
563
-
564
535
  '''
565
536
  def forward(self, *input):
566
537
  for module in self._modules.values():
@@ -569,39 +540,12 @@ class MySequential(nn.Sequential):
569
540
 
570
541
  class RoboConstructor(nn.Module):
571
542
  '''
572
-
573
543
  RoboConstructor assembles an encoder-decoder or decoder-only transformer.
574
544
  if the enc_* variables are not specified, or enc_n_blocks==0, the transformer will be decoder-only.
575
545
  - if any of the dec_* variables are not specified (except dec_expansion_factor) an error will occur.
576
546
  - if enc_n_blocks > 0 and any of the enc_* variables are not specified (except enc_expansion_factor and enc_block_size) an error will occur.
577
547
  dropout can be specified, default=0.1.
578
548
  if device is not specified, device will default to first available among ("cuda", "mps", "cpu")
579
-
580
- prep_data() function returns a batch of specified batch_size, from dec_data (and dec_masks, enc_data and enc_masks if specified)
581
- - if encoder is configured in this instance, enc_data must be specified.
582
- - dec_block_size must be specified.
583
- - if enc_block_size is not specified, the entire block_size of enc_data will be used.
584
- this function is for use in train_robo()
585
-
586
- train_robo() function trains the RoboConstructor instance transformer.
587
- - training parameters can be specified such as max_iters, eval_interval, batch_size, eval_iters, learning_rate, label_smoothing.
588
- - paths must be specified for decoder training data (and encoder training data if encoder-decoder transformer)
589
- - optional paths to specify: decoder and encoder masks, decoder and encoder validation data, decoder and encoder validation masks data
590
- - if neither pad_token or tokenizer is specified (or tokenizer has no pad_token), any padding in labels will contribute towards the loss
591
- which may cause unwanted results. Specifying pad_token and/or tokenizer allows loss to be calculated while ignoring any padding in labels
592
- - specify save_path to save the model as a .pkl file every eval_interval iterations using the save_component function.
593
-
594
- generate() function uses the tranformer model from the RoboConstructor instance to generate an output from an input.
595
- - input can be in the form of a string if input tokenizer is specified (enc_tokenizer for encoder-decoder and dec_tokenizder for decoder-only),
596
- otherwise, it must be in the form of a list of tokens.
597
- - if dec_tokenizer is specified, output will be a string.
598
- - new tokens are generated until the dec_end_token (or dec_tokenizer.end_token) is generated, or the number of tokens generated == max_new_tokens.
599
- - if input tokenizer is not specified, or input tokenizer.start_token is None, enc_start_token must be specified for an encoder-decoder model.
600
- - separator_token is used to separate the input and generated tokens for a decoder-only model. If this value is not specified, there
601
- will be no distinction between input tokens and generated tokens to the transformer, even if dec_tokenizer is specified.
602
- - if new_line_token is not specified, output will be returned in one line, without any "\n" line separators.
603
- - temperature, top_k and top_p can be specified to adjust the output.
604
-
605
549
  '''
606
550
  def __init__(self,
607
551
  n_embed:int,
@@ -628,7 +572,7 @@ class RoboConstructor(nn.Module):
628
572
  self.dec_expansion_factor = dec_expansion_factor
629
573
  self.dropout = dropout
630
574
 
631
- if device == None:
575
+ if device is None:
632
576
  self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
633
577
  else:
634
578
  self.device = device
@@ -673,13 +617,13 @@ class RoboConstructor(nn.Module):
673
617
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
674
618
 
675
619
  def forward(self,
676
- dec_in:torch.tensor,
677
- dec_mask:torch.tensor=None,
678
- enc_in:torch.tensor=None,
679
- enc_mask:torch.tensor=None
620
+ dec_in:torch.Tensor,
621
+ dec_mask:torch.Tensor=None,
622
+ enc_in:torch.Tensor=None,
623
+ enc_mask:torch.Tensor=None
680
624
  ) -> torch.tensor:
681
625
  _, dec_T = dec_in.shape
682
- if enc_in != None:
626
+ if enc_in is not None:
683
627
  _, enc_T = enc_in.shape
684
628
 
685
629
  dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -714,17 +658,24 @@ class RoboConstructor(nn.Module):
714
658
  enc_block_size:int=None,
715
659
  enc_masks:str=None
716
660
  ) -> tuple[torch.tensor]:
661
+ '''
662
+ returns a batch of specified batch_size, from dec_data (and dec_masks, enc_data and enc_masks if specified)
663
+ - if encoder is configured in this instance, enc_data must be specified.
664
+ - dec_block_size must be specified.
665
+ - if enc_block_size is not specified, the entire block_size of enc_data will be used.
666
+ this method is for use in train_robo()
667
+ '''
717
668
  random_samples = torch.randint(dec_data.shape[0], (batch_size,))
718
669
 
719
670
  dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
720
671
  dec_train_batch_in = dec_train_batch_in.to(self.device)
721
- dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
722
- dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None
672
+ dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
673
+ dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None
723
674
 
724
675
  if self.cross_attention:
725
676
  enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
726
677
  enc_train_batch_in = enc_train_batch_in.to(self.device)
727
- enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
678
+ enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
728
679
  else:
729
680
  enc_train_batch_in = None
730
681
  enc_train_masks_in = None
@@ -736,14 +687,8 @@ class RoboConstructor(nn.Module):
736
687
  max_iters:int,
737
688
  eval_interval:int,
738
689
  batch_size:int,
739
- dec_training_path:str,
740
- dec_eval_path:str=None,
741
- dec_training_masks_path:str=None,
742
- dec_eval_masks_path:str=None,
743
- enc_training_path:str=None,
744
- enc_eval_path:str=None,
745
- enc_training_masks_path:str=None,
746
- enc_eval_masks_path:str=None,
690
+ training_dir_path:str,
691
+ eval_dir_path:str=None,
747
692
  eval_iters:int=3,
748
693
  learning_rate:float=1e-4,
749
694
  pad_token:int=None,
@@ -751,22 +696,46 @@ class RoboConstructor(nn.Module):
751
696
  save_path:str=None,
752
697
  label_smoothing:float=0.1
753
698
  ) -> None:
699
+ '''
700
+ trains the RoboConstructor instance transformer.
701
+ - training parameters can be specified such as max_iters, eval_interval, batch_size, eval_iters, learning_rate, label_smoothing.
702
+ - training_dir_path must contain the decoder training data (and encoder training data if encoder-decoder transformer)
703
+ - optional files (decoder and encoder masks, decoder and encoder validation data, decoder and encoder validation masks data) are loaded from the directories if present
704
+ - if neither pad_token nor tokenizer is specified (or the tokenizer has no pad_token), any padding in labels will contribute towards the loss
705
+ which may cause unwanted results. Specifying pad_token and/or tokenizer allows loss to be calculated while ignoring any padding in labels
706
+ - specify save_path to save the model as a .pkl file every eval_interval iterations using the save_component function.
707
+ '''
754
708
 
709
+ dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
755
710
  dec_training_data = torch.load(dec_training_path, weights_only=True)
756
- dec_eval_data = torch.load(dec_eval_path, weights_only=True) if dec_eval_path != None else None
757
- dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if dec_training_masks_path != None else None
758
- dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if dec_eval_masks_path != None else None
759
- enc_training_data = torch.load(enc_training_path, weights_only=True) if enc_training_path != None else None
760
- enc_eval_data = torch.load(enc_eval_path, weights_only=True) if enc_eval_path != None else None
761
- enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if enc_training_masks_path != None else None
762
- enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if enc_eval_masks_path != None else None
763
-
764
- if pad_token == None and dec_tokenizer != None:
711
+
712
+ dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
713
+ dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
714
+
715
+ dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
716
+ dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
717
+
718
+ dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
719
+ dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
720
+
721
+ enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
722
+ enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
723
+
724
+ enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
725
+ enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
726
+
727
+ enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
728
+ enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
729
+
730
+ enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
731
+ enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
732
+
733
+ if pad_token is None and dec_tokenizer is not None:
765
734
  pad_token = dec_tokenizer.pad_token
766
735
 
767
736
  self.to(self.device)
768
737
 
769
- if pad_token != None:
738
+ if pad_token is not None:
770
739
  loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
771
740
  else:
772
741
  loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
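Since the individual `*_path` arguments were replaced by `training_dir_path`/`eval_dir_path`, `train_robo` now looks for `decoder_data.pt`, `decoder_mask_data.pt`, `encoder_data.pt` and `encoder_mask_data.pt` inside those directories itself (missing files simply load as `None`). A sketch, assuming `robo` is a configured `RoboConstructor` and the directories were produced by `process_list` as above:

```python
# Hypothetical training call against the directories written by process_list.
robo.train_robo(
    max_iters=20_000,
    eval_interval=200,
    batch_size=128,
    training_dir_path="data/training",   # must contain decoder_data.pt (plus encoder_data.pt for enc-dec models)
    eval_dir_path="data/validation",     # optional mask/validation files are picked up if present
    dec_tokenizer=dec_tok,               # lets the loss ignore pad tokens
    save_path="models/eng_to_fr_robo",   # checkpointed as a .pkl every eval_interval iterations
)
```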
@@ -782,7 +751,7 @@ class RoboConstructor(nn.Module):
782
751
  proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
783
752
  losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
784
753
  out["train"] = losses.mean()
785
- if dec_eval_data != None:
754
+ if dec_eval_data is not None:
786
755
  for k in range(eval_iters):
787
756
  dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
788
757
  proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -798,7 +767,7 @@ class RoboConstructor(nn.Module):
798
767
  if iter % eval_interval == 0 or iter == max_iters-1:
799
768
  losses = estimate_loss()
800
769
  print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
801
- if save_path != None:
770
+ if save_path is not None:
802
771
  save_component(self, save_path=save_path)
803
772
 
804
773
  dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -825,25 +794,37 @@ class RoboConstructor(nn.Module):
825
794
  top_k:int=None,
826
795
  top_p:float=None
827
796
  ) -> list[int]|str:
828
- max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens
797
+ '''
798
+ uses the transformer model from the RoboConstructor instance to generate an output from an input.
799
+ - input can be in the form of a string if input tokenizer is specified (enc_tokenizer for encoder-decoder and dec_tokenizer for decoder-only),
800
+ otherwise, it must be in the form of a list of tokens.
801
+ - if dec_tokenizer is specified, output will be a string.
802
+ - new tokens are generated until the dec_end_token (or dec_tokenizer.end_token) is generated, or the number of tokens generated == max_new_tokens.
803
+ - if input tokenizer is not specified, or input tokenizer.start_token is None, enc_start_token must be specified for an encoder-decoder model.
804
+ - separator_token is used to separate the input and generated tokens for a decoder-only model. If this value is not specified, there
805
+ will be no distinction between input tokens and generated tokens to the transformer, even if dec_tokenizer is specified.
806
+ - if new_line_token is not specified, output will be returned in one line, without any "\n" line separators.
807
+ - temperature, top_k and top_p can be specified to adjust the output.
808
+ '''
809
+ max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens
829
810
 
830
811
  if self.cross_attention:
831
- if enc_tokenizer != None:
832
- if enc_start_token == None:
812
+ if enc_tokenizer is not None:
813
+ if enc_start_token is None:
833
814
  enc_start_token = enc_tokenizer.start_token
834
- if enc_end_token == None:
815
+ if enc_end_token is None:
835
816
  enc_end_token = enc_tokenizer.end_token
836
817
  if isinstance(inputs, str):
837
818
  inputs = enc_tokenizer.encode(inputs)
838
819
 
839
- if dec_tokenizer != None:
840
- if dec_start_token == None:
820
+ if dec_tokenizer is not None:
821
+ if dec_start_token is None:
841
822
  dec_start_token = dec_tokenizer.start_token
842
- if dec_end_token == None:
823
+ if dec_end_token is None:
843
824
  dec_end_token = dec_tokenizer.end_token
844
- if new_line_token == None:
825
+ if new_line_token is None:
845
826
  new_line_token = dec_tokenizer.new_line_token
846
- if self.cross_attention == False and isinstance(inputs, str):
827
+ if not self.cross_attention and isinstance(inputs, str):
847
828
  inputs = dec_tokenizer.encode(inputs)
848
829
 
849
830
 
@@ -852,7 +833,7 @@ class RoboConstructor(nn.Module):
852
833
  idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
853
834
  else:
854
835
  enc_input = None
855
- if separator_token != None:
836
+ if separator_token is not None:
856
837
  idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
857
838
  else:
858
839
  idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -866,7 +847,7 @@ class RoboConstructor(nn.Module):
866
847
  logits = proj_output[:, -1, :]
867
848
  probabilities = F.log_softmax(logits/temperature, dim=-1)
868
849
 
869
- if top_k == None and top_p == None:
850
+ if top_k is None and top_p is None:
870
851
  idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
871
852
  else:
872
853
  idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -874,10 +855,10 @@ class RoboConstructor(nn.Module):
874
855
  if idx_next[0] == dec_end_token:
875
856
  break
876
857
 
877
- if dec_tokenizer == None:
858
+ if dec_tokenizer is None:
878
859
  return idx[0].tolist()
879
860
  else:
880
- if new_line_token != None:
861
+ if new_line_token is not None:
881
862
  return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
882
863
  else:
883
864
  return dec_tokenizer.decode(idx[0].tolist())
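A sketch of `generate` on a trained encoder-decoder `robo` with the tokenizers used above; parameter names follow the body shown in this diff, so treat them as illustrative. For a decoder-only model, pass only `dec_tokenizer` and optionally a `separator_token`.

```python
# Hypothetical generation call; the tokenizers supply start/end/new-line tokens.
translation = robo.generate(
    "how are you today?",        # str input is encoded with enc_tokenizer
    enc_tokenizer=enc_tok,
    dec_tokenizer=dec_tok,       # output is decoded back to a string
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)
print(translation)
```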
@@ -885,9 +866,7 @@ class RoboConstructor(nn.Module):
885
866
 
886
867
  def save_component(component, save_path:str) -> None:
887
868
  '''
888
-
889
869
  saves component (such as TokenizerConstructor or RoboConstructor) as .pkl file.
890
-
891
870
  '''
892
871
  save_path = save_path + ".pkl" if save_path[-4:] != ".pkl" else save_path
893
872
  with open(save_path, "wb") as comp:
@@ -895,9 +874,7 @@ def save_component(component, save_path:str) -> None:
895
874
 
896
875
  def load_component(load_path:str):
897
876
  '''
898
-
899
877
  loads saved .pkl file into variable.
900
-
901
878
  '''
902
879
  load_path = load_path + ".pkl" if load_path[-4:] != ".pkl" else load_path
903
880
  with open(load_path, "rb") as comp:
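`save_component` / `load_component` round-trip any component (tokenizers, data processors, models) through pickle, appending `.pkl` to the path when it is missing. A sketch, assuming the target directory already exists:

```python
from robo_lib.components import save_component, load_component

save_component(tok, "models/tokenizer")         # written as models/tokenizer.pkl
tok_again = load_component("models/tokenizer")  # ".pkl" is appended automatically
```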
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: robo_lib
3
- Version: 0.0.11
3
+ Version: 1.0.1
4
4
  Summary: A package to create, configure, and train transformer models.
5
5
  Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
6
6
  Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
@@ -13,6 +13,7 @@ Requires-Python: >=3.8
13
13
  Requires-Dist: numpy
14
14
  Requires-Dist: tokenizers
15
15
  Requires-Dist: torch
16
+ Requires-Dist: typing
16
17
  Description-Content-Type: text/markdown
17
18
 
18
19
  # robo-lib
@@ -83,10 +84,8 @@ proc.process_list(
83
84
  save_path="data/training",
84
85
  dec_data=french_train,
85
86
  dec_max_block_size=100,
86
- dec_block_size_exceeded_policy="skip",
87
87
  enc_data=english_train,
88
- enc_max_block_size=100,
89
- enc_block_size_exceeded_policy="skip"
88
+ enc_max_block_size=100
90
89
  )
91
90
 
92
91
  # process and save validation data as data/validation*.pt
@@ -94,10 +93,8 @@ proc.process_list(
94
93
  save_path="data/validation",
95
94
  dec_data=french_val,
96
95
  dec_max_block_size=100,
97
- dec_block_size_exceeded_policy="skip",
98
96
  enc_data=english_val,
99
- enc_max_block_size=100,
100
- enc_block_size_exceeded_policy="skip"
97
+ enc_max_block_size=100
101
98
  )
102
99
  ```
103
100
  - The `RoboConstructor` class is used to create and configure transformer models before training.
@@ -128,14 +125,8 @@ robo.train_robo(
128
125
  max_iters=20000,
129
126
  eval_interval=200,
130
127
  batch_size=128,
131
- dec_training_path="data/training_decoder_data.pt",
132
- dec_eval_path="data/validation_decoder_data.pt",
133
- dec_training_masks_path="data/training_decoder_mask_data.pt",
134
- dec_eval_masks_path="data/validation_decoder_mask_data.pt",
135
- enc_training_path="data/training_encoder_data.pt",
136
- enc_eval_path="data/validation_encoder_data.pt",
137
- enc_training_masks_path="data/training_encoder_mask_data.pt",
138
- enc_eval_masks_path="data/validation_encoder_mask_data.pt",
128
+ training_dir_path="data/training",
129
+ eval_dir_path="data/validation",
139
130
  dec_tokenizer=decoder_tok,
140
131
  save_path="models/eng_to_fr_robo.pkl"
141
132
  )
@@ -223,8 +214,8 @@ robo.train(
223
214
  max_iters=20000,
224
215
  eval_interval=200,
225
216
  batch_size=64,
226
- dec_training_path="data/shakespeare_train_decoder_data.pt",
227
- dec_eval_path="data/shakespeare_valid_decoder_data.pt",
217
+ training_dir_path="data/shakespeare_train",
218
+ eval_dir_path="data/shakespeare_valid",
228
219
  dec_tokenizer=tok,
229
220
  save_path="models/shakespeare_robo.pkl"
230
221
  )
@@ -0,0 +1,6 @@
1
+ robo_lib/__init__.py,sha256=NnzWHWwpFcSJD_XRMWKKPQFAIrRBFYiCFN0pgUGPygc,968
2
+ robo_lib/components.py,sha256=mfvdNC77d1k1vmlNwG3ri2MbfmEn3haACAnRf56b_c4,43164
3
+ robo_lib-1.0.1.dist-info/METADATA,sha256=4CG07VLULgAcGlfNeNXS9Pjzs7SXP5gNf95ddgGbWqc,9051
4
+ robo_lib-1.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ robo_lib-1.0.1.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
6
+ robo_lib-1.0.1.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
2
- robo_lib/components.py,sha256=L_GUEHdKC_-Xn56ObQ9-DH8T1ywaz0M8jlWv227gZBs,42591
3
- robo_lib-0.0.11.dist-info/METADATA,sha256=ePF06l2FXzo0qjK8v9Vob4WnOQ61KVd0mUqd7JVG7j4,9634
4
- robo_lib-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- robo_lib-0.0.11.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
6
- robo_lib-0.0.11.dist-info/RECORD,,