robo-lib 0.0.11 (py3-none-any.whl) → 1.0.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- robo_lib/__init__.py +2 -3
- robo_lib/components.py +246 -269
- {robo_lib-0.0.11.dist-info → robo_lib-1.0.1.dist-info}/METADATA +8 -17
- robo_lib-1.0.1.dist-info/RECORD +6 -0
- robo_lib-0.0.11.dist-info/RECORD +0 -6
- {robo_lib-0.0.11.dist-info → robo_lib-1.0.1.dist-info}/WHEEL +0 -0
- {robo_lib-0.0.11.dist-info → robo_lib-1.0.1.dist-info}/licenses/LICENSE +0 -0
robo_lib/__init__.py
CHANGED
@@ -1,8 +1,7 @@
 from .components import TokenizerConstructor as TokenizerConstructor
 from .components import create_mask as create_mask
-from .components import
-from .components import
-from .components import scan_max_block_size as scan_max_block_size
+from .components import pre_process_data as pre_process_data
+from .components import safe_stack as safe_stack
 from .components import DataProcessor as DataProcessor
 from .components import get_valid_samples as get_valid_samples
 from .components import get_batch as get_batch
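For orientation, the exports visible in this hunk can be imported directly from the package. A minimal sketch (the full `__init__.py` may re-export more than is shown in the hunk above):

```
# exports visible in the hunk above (the package may also re-export others, e.g. RoboConstructor)
from robo_lib import TokenizerConstructor, create_mask, pre_process_data, safe_stack
from robo_lib import DataProcessor, get_valid_samples, get_batch
```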
robo_lib/components.py
CHANGED
@@ -6,30 +6,28 @@ import numpy as np
 import random
 import pickle
 import itertools
+from pathlib import Path
+import os
+from typing import List, Literal
+
+pre_tokenizers = Literal["Whitespace", "IndividualDigit", "Digits", "BertPreTokenizer", "ByteLevel", "Metaspace", "Punctuation", "UnicodeScripts", "WhitespaceSplit"]
 
 class TokenizerConstructor:
     '''
-
     simple assembler for tokenizer using the tokenizers library
-    tokenizer parameters can be set using strings and list[
+    tokenizer parameters can be set using strings and list[str]s
     strings used for tokenizer_type, pre_tokenizers, normalizers arguments are the names of those present in the
     tokenizers library. Additionally "IndividualDigits" can be used in normalizers for tokenizers.pre_tokenizers.Digits(individual_digits=True)
 
-    train([paths]) function points to text files to be used for training the tokenizer instance
-
-    encode(string) function encodes string using trained tokenizer instance
-
-    decode(list[int]) function decodes list of tokenz using trained tokenizer instance
-
     vocab_size attribute returns the tokenizer instance's vocab_size (untrained tokenizer will have vocab_size=None)
 
-
     '''
     def __init__(self,
                  min_frequency:int=2,
-                 tokenizer_type:
-                 pre_tokenizers:
+                 tokenizer_type:Literal["BPE", "WordLevel", "WordPiece", "Unigram"] = "BPE",
+                 pre_tokenizers: pre_tokenizers|List[pre_tokenizers]=["Whitespace"],
                  normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
+                 vocab:dict[str,int] = {},
                  special_tokens:list[str]|str=[],
                  unknown_token_string:str="<unk>",
                  start_token_string:str="<sos>",
@@ -42,25 +40,29 @@ class TokenizerConstructor:
 
         if isinstance(special_tokens, str):
             special_tokens = [special_tokens]
-        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token
-        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string
-        self.start_token = self.special_tokens.index(start_token_string) if start_token_string
-        self.end_token = self.special_tokens.index(end_token_string) if end_token_string
-        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string
-        self.
+        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token is not None]
+        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string is not None else None
+        self.start_token = self.special_tokens.index(start_token_string) if start_token_string is not None else None
+        self.end_token = self.special_tokens.index(end_token_string) if end_token_string is not None else None
+        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string is not None else None
+        self.pad_token_string = pad_token_string
+        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string is not None else None
 
         if tokenizer_type == "BPE":
             self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordLevel":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab = vocab, unk_token=unknown_token_string))
             self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "WordPiece":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(vocab = vocab, unk_token=unknown_token_string))
            self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
         elif tokenizer_type == "Unigram":
-            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(
+            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram())
             self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
+
+        if self.pad_token is not None:
+            self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=pad_token_string)
 
         if isinstance(pre_tokenizers, str):
             pre_tokenizers = [pre_tokenizers]
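A minimal usage sketch of the constructor options shown above (the training file path is hypothetical; `vocab` is only passed to the WordLevel and WordPiece models, as the hunk shows):

```
from robo_lib import TokenizerConstructor

# BPE tokenizer with the default pre-tokenizer/normalizer stack
tok = TokenizerConstructor(tokenizer_type="BPE", pre_tokenizers=["Whitespace"], min_frequency=2)
tok.train("data/corpus.txt")          # hypothetical path; accepts a str or list[str]
ids = tok.encode("hello world")       # list[int]
text = tok.decode(ids)                # str
```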
@@ -114,79 +116,76 @@ class TokenizerConstructor:
 
 
     def train(self, training_paths:list[str]|str) -> None:
+        '''
+        points to text files to be used for training the tokenizer instance
+        '''
         if isinstance(training_paths, str):
             training_paths = [training_paths]
         self.tokenizer_type.train(training_paths, trainer=self.trainer)
         self.vocab_size = self.tokenizer_type.get_vocab_size()
 
     def encode(self, inp:str) -> list[int]:
+        '''
+        encodes string using trained tokenizer instance
+        '''
         return self.tokenizer_type.encode(inp).ids
 
+    def encode_batch(self, inp:list[str], max_length:int=None) -> list[list[int]]:
+        '''
+        encodes strings in parallel and truncates entries with length > max_length
+        '''
+        if max_length is not None:
+            self.tokenizer_type.enable_truncation(max_length=max_length)
+            self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=self.pad_token_string, length=max_length)
+        out = [row.ids for row in self.tokenizer_type.encode_batch(inp)]
+        self.tokenizer_type.no_truncation()
+        self.tokenizer_type.enable_padding(pad_id=self.pad_token, pad_token=self.pad_token_string)
+        return out
+
     def decode(self, inp:list[int]) -> str:
+        '''
+        decodes list of tokenz using trained tokenizer instance
+        '''
         return self.tokenizer_type.decode(inp)
 
 
 
 def create_mask(row:list, block_size:int) -> list[bool]:
     '''
-
     creates a mask list of length block_size for row, asuming mask does cover the entire row input
-
     '''
     mask = [1]*len(row) + [0]*(block_size - len(row))
     return mask
 
-def
-    '''
-
-    returns padded row. Row is padded until length block_size with specified pad_token value
-
-    '''
-    row.extend([pad_token]*(block_size - len(row)))
-    return row
-
-def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
+def pre_process_data(data:str, start_token_string:str, end_token_string:str) -> list[int]:
     '''
-
-    returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist
-
+    returns data with the tokenizer's start and end tokens added to each row if they exist
     '''
-
-
-
-
-
-
-
+    if start_token_string is None and end_token_string is None:
+        return data
+    else:
+        for i in range(len(data)):
+            if start_token_string is not None:
+                data[i] = start_token_string + data[i]
+            if end_token_string is not None:
+                data[i] = data[i] + end_token_string
+
+        return data
 
-def
+def safe_stack(tensor_list:list[torch.tensor]) -> torch.tensor:
     '''
-
-    returns
-
+    torch stack with check to ensure tensors are valid in input list
+    returns torch.stack(out_list) for all valid torch tensors in tensor_list. raises error if no valid tensors
     '''
-
-
-
+    out_list = [row for row in tensor_list if isinstance(row, torch.Tensor)]
+    if len(out_list) == 0:
+        raise ValueError("no valid tensors in list.")
+    return torch.stack(out_list)
 
 
 class DataProcessor:
     '''
-
     data processor can be instantiated by specifying the tokenizer(s) for decoder and encoder data
-
-    process_list() function processes raw data in the form of list[str] or str for decoder and encoder simultaneously and
-    saves them to save_path as .pt files.
-    - encoder and decoder input data should have matching input and outputs so enc_data[n] should have its corresponding
-    decoder data at dec_data[n].
-    - max block size can be specified for both input and output, default takes the max
-    block size provided in the data respectively.
-    - if enc/dec_block_size is specified and enc/dec_block_size_exceeded_policy is not, an error will occur if a piece
-    of data larger than enc/dec_block_size is encountered. enc/dec_block_size_exceeded_policy can be set to "skip" or
-    "trim" to skip rows larger than enc/dec_block_size or truncate the row to specified enc/dec_block_size respectively.
-    - enc/dec_create_masks saves masks tensors to save_path as .pt files.
-
-
     '''
     def __init__(self,
                  dec_tokenizer:TokenizerConstructor,
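A short sketch of the helpers introduced above, using a tokenizer trained as before (input strings are illustrative):

```
import torch
from robo_lib import pre_process_data, safe_stack

texts = pre_process_data(["bonjour", "merci"], "<sos>", "<eos>")   # adds start/end token strings to each row
batch_ids = tok.encode_batch(texts, max_length=16)                 # rows padded/truncated to max_length

rows = [torch.tensor(ids, dtype=torch.long) for ids in batch_ids]
stacked = safe_stack(rows)                                         # raises ValueError if no valid tensors
```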
@@ -196,165 +195,153 @@ class DataProcessor:
         self.enc_tokenizer = enc_tokenizer
 
     def process_list(self,
-                     save_path:str,
                      dec_data:list[str]|str,
                      dec_max_block_size:int=None,
                      dec_create_masks:bool=True,
-                     dec_block_size_exceeded_policy:str=None,
                      enc_data:list[str]=None,
                      enc_max_block_size:int=None,
                      enc_create_masks:bool=True,
-
+                     save_path:str = "."
                      ) -> None:
+        '''
+        processes raw data in the form of list[str] or str for decoder and encoder simultaneously and
+        saves them to save_path as .pt files.
+        - encoder and decoder input data should have matching input and outputs so enc_data[n] should have its corresponding
+        decoder data at dec_data[n].
+        - max block size can be specified for both input and output, default takes the max
+        block size provided in the data respectively. If data length > max_length, the data is trimmed.
+        '''
 
         if isinstance(dec_data, str):
             dec_data = [dec_data]
         dec_data_length = len(dec_data)
-        save_path = save_path.replace(".pt", "")
-
-        if dec_max_block_size == None:
-            dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
 
-        if enc_data
-
+        if enc_data is not None:
+            if self.enc_tokenizer is None:
+                self.enc_tokenizer = self.dec_tokenizer
 
             enc_data_length = len(enc_data)
             if dec_data_length != enc_data_length:
-                raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
+                raise Exception(f"decoder data and encoder data lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
 
-
-
-
-            enc_out_list = [[]]*enc_data_length
-            enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
-        else:
-            enc_out_list = []
-            enc_mask_list = []
-
-        dec_out_list = [[]]*dec_data_length
-        dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
-        for index in range(len(dec_out_list)):
-            dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
-            if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
-                if dec_block_size_exceeded_policy == "trim":
-                    dec_processed_item = dec_processed_item[:dec_max_block_size]
-                elif dec_block_size_exceeded_policy == "skip":
-                    continue
-                elif dec_block_size_exceeded_policy == None:
-                    raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
-            if dec_create_masks:
-                dec_mask = create_mask(dec_processed_item, dec_max_block_size)
-            dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
-
-            if enc_data != None:
-                enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
-                if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
-                    if enc_block_size_exceeded_policy == "trim":
-                        enc_processed_item = enc_processed_item[:enc_max_block_size]
-                    elif enc_block_size_exceeded_policy == "skip":
-                        continue
-                    elif enc_block_size_exceeded_policy == None:
-                        raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
-                if enc_create_masks:
-                    enc_mask = create_mask(enc_processed_item, enc_max_block_size)
-                enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
-
-            dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
-            if dec_create_masks:
-                dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
-
-            if enc_data != None:
-                enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
-                if enc_create_masks:
-                    enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
-
-        dec_out_list = torch.stack([row for row in dec_out_list if row != []])
-        torch.save(dec_out_list, save_path + "_decoder_data.pt")
+        print("processing data")
+        dec_out_list = self.dec_tokenizer.encode_batch(dec_data, max_length=dec_max_block_size)
         if dec_create_masks:
-
-
-
-
-
+            mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.dec_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+            dec_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in dec_out_list])
+
+        if enc_data is not None:
+            enc_out_list = self.enc_tokenizer.encode_batch(enc_data, max_length=enc_max_block_size)
             if enc_create_masks:
-
-
+                mask_tokenizer = TokenizerConstructor(min_frequency=1, tokenizer_type="WordLevel", vocab={str(self.enc_tokenizer.pad_token): 0, "<unk>": 1}, special_tokens=["<pad>", "<unk>"], unknown_token_string="<unk>", start_token_string=None, end_token_string=None, pad_token_string=None)
+                enc_mask_list = mask_tokenizer.encode_batch([str(i).replace("[", "").replace("]", "").replace(",", "") for i in enc_out_list])
 
+        dec_out_list = torch.tensor(dec_out_list, dtype=torch.long)
+        Path(save_path).mkdir(parents=True, exist_ok=True)
+        torch.save(dec_out_list, os.path.join(save_path, "decoder_data.pt"))
+        if dec_create_masks:
+            dec_mask_list = torch.tensor(dec_mask_list, dtype=torch.long)
+            torch.save(dec_mask_list, os.path.join(save_path, "decoder_mask_data.pt"))
+        if enc_data is not None:
+            enc_out_list = torch.tensor(enc_out_list, dtype=torch.long)
+            torch.save(enc_out_list, os.path.join(save_path, "encoder_data.pt"))
+            if enc_create_masks:
+                enc_mask_list = torch.tensor(enc_mask_list, dtype=torch.long)
+                torch.save(enc_mask_list, os.path.join(save_path, "encoder_mask_data.pt"))
 
-
-
+
+def get_valid_samples(random_samples:torch.Tensor,
+                      masks:torch.Tensor,
                      block_size:int
                      ) -> list[int]:
    '''
-
    returns list of len(random_samples) with values corresponding to index values of masks that ensure minimum masked
    values when taking sample of length block_size
-
    '''
    valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
    return valid_samples
 
-def get_batch(data:torch.
-              random_samples:torch.
-              masks:torch.
+def get_batch(data:torch.Tensor,
+              random_samples:torch.Tensor,
+              masks:torch.Tensor=None,
              block_size:int=None,
              get_offset:bool=True
              ) -> tuple[torch.tensor]:
    '''
-
    returns random batches from data tensor using random sample for data selection.
    - returns corresponding batch offset by 1 unless get_offset=False
    - returns corresponding masks batch if masks data is specified
-
    '''
    batch_size = len(random_samples)
-    if block_size
+    if block_size is not None and block_size != data.shape[1]:
        if block_size >= data.shape[1]:
            raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
 
-        if masks
+        if masks is not None:
            random_point = get_valid_samples(random_samples, masks, block_size)
        else:
            random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
        batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
-        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks
+        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks is not None else None
        batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
    else:
        block_size = data.shape[1]
        batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
-        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks
+        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks is not None else None
        batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
 
    return batch_in, batch_out, masks_in
 
-def top_kp_filter(logits:torch.
-                  top_k:int,
-                  top_p:float=None
-                  ) -> torch.
+def top_kp_filter(logits: torch.Tensor,
+                  top_k: int = None,
+                  top_p: float = None
+                  ) -> torch.Tensor:
    '''
+    Returns predicted token by filtering output logits using top_k and/or top_p (nucleus) filtering.
 
-
-
+    Args:
+        logits: (batch_size, vocab_size) tensor of raw logits.
+        top_k: keep only top_k tokens with highest logits.
+        top_p: keep the smallest set of tokens with cumulative probability >= top_p.
    '''
-
-
-
-
-
-
-
-    indices_to_remove = filter.scatter(1, sorted_indices, filter)
-    logits[indices_to_remove] = float("-inf")
+    logits = logits.clone()  # avoid modifying input logits in-place
+
+    # Apply top-p filtering if specified
+    if top_p is not None:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = F.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs, dim=-1)
 
-
-
+        # Remove tokens with cumulative probability above threshold (except first token)
+        sorted_mask = cumulative_probs > top_p
+        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+        sorted_mask[..., 0] = False
 
-
-
-
-    sorted_logits[0][0] += 1 - sum(sorted_logits[0])
+        # Mask tokens to remove by setting logits to -inf
+        indices_to_remove = sorted_mask.scatter(1, sorted_indices, sorted_mask)
+        logits[indices_to_remove] = float('-inf')
 
-
+    # Apply top-k filtering if specified
+    if top_k is not None:
+        top_k = min(top_k, logits.size(-1))  # safety check
+        topk_logits, topk_indices = torch.topk(logits, top_k, dim=-1)
+        topk_probs = F.softmax(topk_logits, dim=-1).cpu().numpy()
+
+        # For each batch, sample from top_k candidates
+        selected = []
+        for i in range(topk_probs.shape[0]):
+            candidate = np.random.choice(topk_indices[i].cpu().numpy(), 1, p=topk_probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
+
+    else:
+        # If only top_p is specified, sample from entire filtered logits
+        probs = F.softmax(logits, dim=-1).cpu().numpy()
+        selected = []
+        for i in range(probs.shape[0]):
+            candidate = np.random.choice(len(probs[i]), 1, p=probs[i])
+            selected.append(candidate[0])
+        selected = torch.tensor(selected, dtype=torch.long)
 
    return selected
 
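A hedged sketch of the reworked pipeline above: `process_list` now writes fixed-name `.pt` files into a `save_path` directory, and `get_batch` draws training batches from the saved tensors. File names follow the hunk above; the data and directory names are illustrative, and `enc_tokenizer` is assumed optional for a decoder-only setup:

```
import torch
from robo_lib import DataProcessor, get_batch

proc = DataProcessor(dec_tokenizer=tok)                     # tok: a trained TokenizerConstructor
proc.process_list(dec_data=["some text", "more text"],
                  dec_max_block_size=32,
                  save_path="data/training")                # writes data/training/decoder_data.pt (+ mask file)

data = torch.load("data/training/decoder_data.pt", weights_only=True)
masks = torch.load("data/training/decoder_mask_data.pt", weights_only=True)
samples = torch.randint(data.shape[0], (8,))
batch_in, batch_out, masks_in = get_batch(data, samples, masks=masks, block_size=16)
```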
@@ -362,10 +349,8 @@ def top_kp_filter(logits:torch.tensor,
 
 class SelfAttention(nn.Module):
     '''
-
     single self attention block of size head_size.
     triangle_mask=True to apply look-ahead mask of size block_size.
-
     '''
     def __init__(self,
                  head_size:int,
@@ -387,16 +372,14 @@ class SelfAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.
-                q:torch.
-                v:torch.
-                mask:torch.
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
-
         k, q and v are key, tensors to get key, query and value tensors.
         custom mask tensor can be applied.
-
         '''
         _,T,_ = k.shape
 
@@ -406,7 +389,7 @@ class SelfAttention(nn.Module):
         wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
         if self.triangle_mask and self.block_size >= 0:
             wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
-        if mask
+        if mask is not None:
             wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
         wei = F.softmax(wei, dim=-1)
         wei = self.dropout(wei)
@@ -417,12 +400,10 @@ class SelfAttention(nn.Module):
 
 class MultiHeadAttention(nn.Module):
     '''
-
     multi-head attention block consisting of num_heads SelfAttention blocks and a linear layer to
     rejoin outputs.
     specified head_size, n_embed, dropout, block_size and triangle_mask values are passed through to
     SelfAttention blocks
-
     '''
     def __init__(self,
                  num_heads:int,
@@ -438,16 +419,14 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self,
-                k:torch.
-                q:torch.
-                v:torch.
-                mask:torch.
+                k:torch.Tensor,
+                q:torch.Tensor,
+                v:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> torch.tensor:
         '''
-
         k, q and v are key, tensors to get key, query and value tensors.
         custom mask tensor can be applied.
-
         '''
         out = torch.cat([h(k, q, v, mask=mask) for h in self.heads], dim=-1)
         out = self.dropout(self.proj(out))
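The forward signatures above now take explicit torch.Tensor arguments with an optional mask. A minimal self-attention call might look like the sketch below; shapes are illustrative, and the constructor parameter names beyond `head_size` are assumed from the MultiHeadAttention docstring rather than confirmed by this hunk:

```
import torch

# hypothetical sizes: batch of 4, sequence length 8, embedding width 32
x = torch.randn(4, 8, 32)
mask = torch.ones(4, 8, dtype=torch.bool)            # optional padding mask

sa = SelfAttention(head_size=16, n_embed=32, dropout=0.1, block_size=8, triangle_mask=True)
out = sa(x, x, x, mask=mask)                         # self-attention: k, q and v are all x
```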
@@ -455,11 +434,9 @@ class MultiHeadAttention(nn.Module):
 
 class FeedForward(nn.Module):
     '''
-
     feed forward layer used after multi-head attention consisting of 2 lieanr layers with
     a ReLU in between. Linear layers expand from n_embed to n_embed * expansion_factor and
     back to n_embed.
-
     '''
     def __init__(self,
                  n_embed:int,
@@ -475,16 +452,14 @@ class FeedForward(nn.Module):
         )
 
     def forward(self,
-                x:torch.
+                x:torch.Tensor
                 ) -> torch.tensor:
         return self.net(x)
 
 class EncoderBlock(nn.Module):
     '''
-
     encoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
     head_size is calculated from n_embed // n_head
-
     '''
     def __init__(self,
                  n_embed:int,
@@ -500,8 +475,8 @@ class EncoderBlock(nn.Module):
         self.ln2 = nn.LayerNorm(n_embed)
 
     def forward(self,
-                x:torch.
-                mask:torch.
+                x:torch.Tensor,
+                mask:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask)
         x = self.ln1(att + x)
@@ -512,13 +487,11 @@ class EncoderBlock(nn.Module):
 
 class DecoderBlock(nn.Module):
     '''
-
     decoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
     if cross-attention is True, a multi-head attention block and layerNorm is added before feed-forward
     taking specified enc_k and enc_v tensors as value and key tensors. These values should be the output
     of an encoder block.
     head_size is calculated from n_embed // n_head
-
     '''
     def __init__(self,
                  n_embed:int,
@@ -541,15 +514,15 @@ class DecoderBlock(nn.Module):
             self.ca = None
 
     def forward(self,
-                x:torch.
-                enc_k:torch.
-                enc_v:torch.
+                x:torch.Tensor,
+                enc_k:torch.Tensor,
+                enc_v:torch.Tensor,
                 mask_out:bool=None,
-                mask_in:torch.
+                mask_in:torch.Tensor=None
                 ) -> tuple[torch.tensor]:
         att = self.sa(x, x, x, mask=mask_out)
         x = self.ln1(att + x)
-        if self.ca
+        if self.ca is not None:
             catt = self.ca(enc_k, x, enc_v, mask=mask_in)
             x = self.ln3(catt + x)
         ff = self.ffwd(x)
@@ -558,9 +531,7 @@ class DecoderBlock(nn.Module):
 
 class MySequential(nn.Sequential):
     '''
-
     MySequential serves the same purpose as nn.Sequential but allows for multiple inputs and outputs
-
     '''
     def forward(self, *input):
         for module in self._modules.values():
@@ -569,39 +540,12 @@ class MySequential(nn.Sequential):
 
 class RoboConstructor(nn.Module):
     '''
-
     RoboConstructor assembles an encoder-decoder or decoder-only transformer.
     if the enc_* variables are not specified, or enc_n_blocks==0, the transformer will be decoder-only.
     - if any of the dec_* variables are not specified (except dec_expansion_factor) an error will occur.
     - if enc_n_blocks > 0 and any of the enc_* variables are not specified (except enc_expansion_factor and enc_block_size) an error will occur.
     dropout can be specified, default=0.1.
     if device is not specified, device will default to first available among ("cuda", "mps", "cpu")
-
-    prep_data() function returns a batch of specified batch_size, from dec_data (and dec_masks, enc_data and enc_masks if specified)
-    - if encoder is configured in this instance, enc_data must be specified.
-    - dec_block_size must be specified.
-    - if enc_block_size is not specified, the entire block_size of enc_data will be used.
-    this function is for use in train_robo()
-
-    train_robo() function trains the RoboConstructor instance transformer.
-    - training parameters can be specified such as max_iters, eval_interval, batch_size, eval_iters, learning_rate, label_smoothing.
-    - paths must be specified for decoder training data (and encoder training data if encoder-decoder transformer)
-    - optional paths to specify: decoder and encoder masks, decoder and encoder validation data, decoder and encoder validation masks data
-    - if neither pad_token or tokenizer is specified (or tokenizer has no pad_token), any padding in labels will contribute towards the loss
-    which may cause unwanted results. Specifying pad_token and/or tokenizer allows loss to be calculated while ignoring any padding in labels
-    - specify save_path to save the model as a .pkl file every eval_interval iterations using the save_component function.
-
-    generate() function uses the tranformer model from the RoboConstructor instance to generate an output from an input.
-    - input can be in the form of a string if input tokenizer is specified (enc_tokenizer for encoder-decoder and dec_tokenizder for decoder-only),
-    otherwise, it must be in the form of a list of tokens.
-    - if dec_tokenizer is specified, output will be a string.
-    - new tokens are generated until the dec_end_token (or dec_tokenizer.end_token) is generated, or the number of tokens generated == max_new_tokens.
-    - if input tokenizer is not specified, or input tokenizer.start_token is None, enc_start_token must be specified for an encoder-decoder model.
-    - separator_token is used to separate the input and generated tokens for a decoder-only model. If this value is not specified, there
-    will be no distinction between input tokens and generated tokens to the transformer, even if dec_tokenizer is specified.
-    - if new_line_token is not specified, output will be returned in one line, without any "\n" line separators.
-    - temperature, top_k and top_p can be specified to adjust the output.
-
     '''
     def __init__(self,
                  n_embed:int,
@@ -628,7 +572,7 @@ class RoboConstructor(nn.Module):
         self.dec_expansion_factor = dec_expansion_factor
         self.dropout = dropout
 
-        if device
+        if device is None:
             self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         else:
             self.device = device
@@ -673,13 +617,13 @@ class RoboConstructor(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
     def forward(self,
-                dec_in:torch.
-                dec_mask:torch.
-                enc_in:torch.
-                enc_mask:torch.
+                dec_in:torch.Tensor,
+                dec_mask:torch.Tensor=None,
+                enc_in:torch.Tensor=None,
+                enc_mask:torch.Tensor=None
                 ) -> torch.tensor:
         _, dec_T = dec_in.shape
-        if enc_in
+        if enc_in is not None:
             _, enc_T = enc_in.shape
 
         dec_tok_emb = self.dec_token_embedding_table(dec_in)
@@ -714,17 +658,24 @@ class RoboConstructor(nn.Module):
                   enc_block_size:int=None,
                   enc_masks:str=None
                   ) -> tuple[torch.tensor]:
+        '''
+        returns a batch of specified batch_size, from dec_data (and dec_masks, enc_data and enc_masks if specified)
+        - if encoder is configured in this instance, enc_data must be specified.
+        - dec_block_size must be specified.
+        - if enc_block_size is not specified, the entire block_size of enc_data will be used.
+        this method is for use in train_robo()
+        '''
         random_samples = torch.randint(dec_data.shape[0], (batch_size,))
 
         dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
         dec_train_batch_in = dec_train_batch_in.to(self.device)
-        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out
-        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in
+        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out is not None else None
+        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in is not None else None
 
         if self.cross_attention:
             enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
             enc_train_batch_in = enc_train_batch_in.to(self.device)
-            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in
+            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in is not None else None
         else:
             enc_train_batch_in = None
             enc_train_masks_in = None
@@ -736,14 +687,8 @@ class RoboConstructor(nn.Module):
                    max_iters:int,
                    eval_interval:int,
                    batch_size:int,
-
-
-                   dec_training_masks_path:str=None,
-                   dec_eval_masks_path:str=None,
-                   enc_training_path:str=None,
-                   enc_eval_path:str=None,
-                   enc_training_masks_path:str=None,
-                   enc_eval_masks_path:str=None,
+                   training_dir_path:str,
+                   eval_dir_path:str=None,
                    eval_iters:int=3,
                    learning_rate:float=1e-4,
                    pad_token:int=None,
@@ -751,22 +696,46 @@ class RoboConstructor(nn.Module):
                    save_path:str=None,
                    label_smoothing:float=0.1
                    ) -> None:
+        '''
+        trains the RoboConstructor instance transformer.
+        - training parameters can be specified such as max_iters, eval_interval, batch_size, eval_iters, learning_rate, label_smoothing.
+        - paths must be specified for decoder training data (and encoder training data if encoder-decoder transformer)
+        - optional paths to specify: decoder and encoder masks, decoder and encoder validation data, decoder and encoder validation masks data
+        - if neither pad_token or tokenizer is specified (or tokenizer has no pad_token), any padding in labels will contribute towards the loss
+        which may cause unwanted results. Specifying pad_token and/or tokenizer allows loss to be calculated while ignoring any padding in labels
+        - specify save_path to save the model as a .pkl file every eval_interval iterations using the save_component function.
+        '''
 
+        dec_training_path = os.path.join(training_dir_path, "decoder_data.pt")
         dec_training_data = torch.load(dec_training_path, weights_only=True)
-
-
-
-
-
-
-
-
-
+
+        dec_eval_path = os.path.join(eval_dir_path, "decoder_data.pt")
+        dec_eval_data = torch.load(dec_eval_path, weights_only=True) if os.path.isfile(dec_eval_path) else None
+
+        dec_training_masks_path = os.path.join(training_dir_path, "decoder_mask_data.pt")
+        dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if os.path.isfile(dec_training_masks_path) else None
+
+        dec_eval_masks_path = os.path.join(eval_dir_path, "decoder_mask_data.pt")
+        dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if os.path.isfile(dec_eval_masks_path) else None
+
+        enc_training_path = os.path.join(training_dir_path, "encoder_data.pt")
+        enc_training_data = torch.load(enc_training_path, weights_only=True) if os.path.isfile(enc_training_path) else None
+
+        enc_eval_path = os.path.join(eval_dir_path, "encoder_data.pt")
+        enc_eval_data = torch.load(enc_eval_path, weights_only=True) if os.path.isfile(enc_eval_path) else None
+
+        enc_training_masks_path = os.path.join(training_dir_path, "encoder_mask_data.pt")
+        enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if os.path.isfile(enc_training_masks_path) else None
+
+        enc_eval_masks_path = os.path.join(eval_dir_path, "encoder_mask_data.pt")
+        enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if os.path.isfile(enc_eval_masks_path) else None
+
+        if pad_token is None and dec_tokenizer is not None:
             pad_token = dec_tokenizer.pad_token
 
         self.to(self.device)
 
-        if pad_token
+        if pad_token is not None:
             loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
         else:
             loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
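The hunks above replace the per-file path arguments of `train_robo` with two directory arguments, matching the file names written by `DataProcessor.process_list` in this version. A call in the new style (mirroring the README example later in this diff; directory names are illustrative):

```
robo.train_robo(
    max_iters=20000,
    eval_interval=200,
    batch_size=128,
    training_dir_path="data/training",   # must contain decoder_data.pt; mask and encoder files are loaded if present
    eval_dir_path="data/validation",     # same file names; missing files are skipped
    dec_tokenizer=decoder_tok,
    save_path="models/eng_to_fr_robo.pkl"
)
```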
@@ -782,7 +751,7 @@ class RoboConstructor(nn.Module):
                 proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
                 losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
             out["train"] = losses.mean()
-            if dec_eval_data
+            if dec_eval_data is not None:
                 for k in range(eval_iters):
                     dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
                     proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
@@ -798,7 +767,7 @@ class RoboConstructor(nn.Module):
             if iter % eval_interval == 0 or iter == max_iters-1:
                 losses = estimate_loss()
                 print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
-                if save_path
+                if save_path is not None:
                     save_component(self, save_path=save_path)
 
             dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
@@ -825,25 +794,37 @@ class RoboConstructor(nn.Module):
                  top_k:int=None,
                  top_p:float=None
                  ) -> list[int]|str:
-
+        '''
+        uses the tranformer model from the RoboConstructor instance to generate an output from an input.
+        - input can be in the form of a string if input tokenizer is specified (enc_tokenizer for encoder-decoder and dec_tokenizder for decoder-only),
+        otherwise, it must be in the form of a list of tokens.
+        - if dec_tokenizer is specified, output will be a string.
+        - new tokens are generated until the dec_end_token (or dec_tokenizer.end_token) is generated, or the number of tokens generated == max_new_tokens.
+        - if input tokenizer is not specified, or input tokenizer.start_token is None, enc_start_token must be specified for an encoder-decoder model.
+        - separator_token is used to separate the input and generated tokens for a decoder-only model. If this value is not specified, there
+        will be no distinction between input tokens and generated tokens to the transformer, even if dec_tokenizer is specified.
+        - if new_line_token is not specified, output will be returned in one line, without any "\n" line separators.
+        - temperature, top_k and top_p can be specified to adjust the output.
+        '''
+        max_new_tokens = self.dec_block_size if max_new_tokens is None else max_new_tokens
 
         if self.cross_attention:
-            if enc_tokenizer
-                if enc_start_token
+            if enc_tokenizer is not None:
+                if enc_start_token is None:
                     enc_start_token = enc_tokenizer.start_token
-                if enc_end_token
+                if enc_end_token is None:
                     enc_end_token = enc_tokenizer.end_token
                 if isinstance(inputs, str):
                     inputs = enc_tokenizer.encode(inputs)
 
-        if dec_tokenizer
-            if dec_start_token
+        if dec_tokenizer is not None:
+            if dec_start_token is None:
                 dec_start_token = dec_tokenizer.start_token
-            if dec_end_token
+            if dec_end_token is None:
                 dec_end_token = dec_tokenizer.end_token
-            if new_line_token
+            if new_line_token is None:
                 new_line_token = dec_tokenizer.new_line_token
-            if self.cross_attention
+            if not self.cross_attention and isinstance(inputs, str):
                 inputs = dec_tokenizer.encode(inputs)
 
 
@@ -852,7 +833,7 @@ class RoboConstructor(nn.Module):
             idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
         else:
             enc_input = None
-            if separator_token
+            if separator_token is not None:
                 idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
             else:
                 idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
@@ -866,7 +847,7 @@ class RoboConstructor(nn.Module):
             logits = proj_output[:, -1, :]
             probabilities = F.log_softmax(logits/temperature, dim=-1)
 
-            if top_k
+            if top_k is None and top_p is None:
                 idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
             else:
                 idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
@@ -874,10 +855,10 @@ class RoboConstructor(nn.Module):
             if idx_next[0] == dec_end_token:
                 break
 
-        if dec_tokenizer
+        if dec_tokenizer is None:
             return idx[0].tolist()
         else:
-            if new_line_token
+            if new_line_token is not None:
                 return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
             else:
                 return dec_tokenizer.decode(idx[0].tolist())
@@ -885,9 +866,7 @@ class RoboConstructor(nn.Module):
 
 def save_component(component, save_path:str) -> None:
     '''
-
     saves component (such as TokenizerConstructor or RoboConstructor) as .pkl file.
-
     '''
     save_path = save_path + ".pkl" if save_path[-4:] != ".pkl" else save_path
     with open(save_path, "wb") as comp:
@@ -895,9 +874,7 @@ def save_component(component, save_path:str) -> None:
 
 def load_component(load_path:str):
     '''
-
     loads saved .pkl file into variable.
-
     '''
     load_path = load_path + ".pkl" if load_path[-4:] != ".pkl" else load_path
     with open(load_path, "rb") as comp:
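A hedged end-to-end sketch of the inference-side API shown above. Paths and tokenizer variables are illustrative, keyword names beyond those visible in the hunks are assumed, and the helpers are assumed to be importable at package level (otherwise import them from `robo_lib.components`):

```
from robo_lib import load_component, save_component

robo = load_component("models/eng_to_fr_robo.pkl")         # ".pkl" is appended if missing
out = robo.generate("hello, how are you?",
                    enc_tokenizer=encoder_tok,             # tokenizers created or loaded elsewhere
                    dec_tokenizer=decoder_tok,
                    top_k=50)
print(out)                                                  # decoded string when dec_tokenizer is given
save_component(robo, "models/eng_to_fr_robo")               # written back as models/eng_to_fr_robo.pkl
```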
{robo_lib-0.0.11.dist-info → robo_lib-1.0.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: robo_lib
-Version:
+Version: 1.0.1
 Summary: A package to create, configure, and train transformer models.
 Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
 Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
@@ -13,6 +13,7 @@ Requires-Python: >=3.8
 Requires-Dist: numpy
 Requires-Dist: tokenizers
 Requires-Dist: torch
+Requires-Dist: typing
 Description-Content-Type: text/markdown
 
 # robo-lib
@@ -83,10 +84,8 @@ proc.process_list(
     save_path="data/training",
     dec_data=french_train,
     dec_max_block_size=100,
-    dec_block_size_exceeded_policy="skip",
     enc_data=english_train,
-    enc_max_block_size=100
-    enc_block_size_exceeded_policy="skip"
+    enc_max_block_size=100
 )
 
 # process and save validation data as data/validation*.pt
@@ -94,10 +93,8 @@ proc.process_list(
     save_path="data/validation",
     dec_data=french_val,
     dec_max_block_size=100,
-    dec_block_size_exceeded_policy="skip",
     enc_data=english_val,
-    enc_max_block_size=100
-    enc_block_size_exceeded_policy="skip"
+    enc_max_block_size=100
 )
 ```
 - The `RoboConstructor` class is used to create and configure transformer models before trainin.
@@ -128,14 +125,8 @@ robo.train_robo(
     max_iters=20000,
     eval_interval=200,
     batch_size=128,
-
-
-    dec_training_masks_path="data/training_decoder_mask_data.pt",
-    dec_eval_masks_path="data/validation_decoder_mask_data.pt",
-    enc_training_path="data/training_encoder_data.pt",
-    enc_eval_path="data/validation_encoder_data.pt",
-    enc_training_masks_path="data/training_encoder_mask_data.pt",
-    enc_eval_masks_path="data/validation_encoder_mask_data.pt",
+    training_dir_path="data/training",
+    eval_dir_path="data/validation",
     dec_tokenizer=decoder_tok,
     save_path="models/eng_to_fr_robo.pkl"
 )
@@ -223,8 +214,8 @@ robo.train(
     max_iters=20000,
     eval_interval=200,
     batch_size=64,
-
-
+    training_dir_path="data/shakespeare_train",
+    eval_dir_path="data/shakespeare_valid",
     dec_tokenizer=tok,
     save_path="models/shakespeare_robo.pkl"
 )
robo_lib-1.0.1.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+robo_lib/__init__.py,sha256=NnzWHWwpFcSJD_XRMWKKPQFAIrRBFYiCFN0pgUGPygc,968
+robo_lib/components.py,sha256=mfvdNC77d1k1vmlNwG3ri2MbfmEn3haACAnRf56b_c4,43164
+robo_lib-1.0.1.dist-info/METADATA,sha256=4CG07VLULgAcGlfNeNXS9Pjzs7SXP5gNf95ddgGbWqc,9051
+robo_lib-1.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+robo_lib-1.0.1.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
+robo_lib-1.0.1.dist-info/RECORD,,
robo_lib-0.0.11.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
-robo_lib/components.py,sha256=L_GUEHdKC_-Xn56ObQ9-DH8T1ywaz0M8jlWv227gZBs,42591
-robo_lib-0.0.11.dist-info/METADATA,sha256=ePF06l2FXzo0qjK8v9Vob4WnOQ61KVd0mUqd7JVG7j4,9634
-robo_lib-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-robo_lib-0.0.11.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
-robo_lib-0.0.11.dist-info/RECORD,,
{robo_lib-0.0.11.dist-info → robo_lib-1.0.1.dist-info}/WHEEL
File without changes
{robo_lib-0.0.11.dist-info → robo_lib-1.0.1.dist-info}/licenses/LICENSE
File without changes