robo_lib-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- robo_lib/__init__.py +18 -0
- robo_lib/components.py +893 -0
- robo_lib-0.0.4.dist-info/METADATA +18 -0
- robo_lib-0.0.4.dist-info/RECORD +6 -0
- robo_lib-0.0.4.dist-info/WHEEL +4 -0
- robo_lib-0.0.4.dist-info/licenses/LICENSE +21 -0
robo_lib/__init__.py
ADDED
@@ -0,0 +1,18 @@
from .components import TokenizerConstructor as TokenizerConstructor
from .components import create_mask as create_mask
from .components import pad as pad
from .components import process_row as process_row
from .components import scan_max_block_size as scan_max_block_size
from .components import DataProcessor as DataProcessor
from .components import get_valid_samples as get_valid_samples
from .components import get_batch as get_batch
from .components import top_kp_filter as top_kp_filter
from .components import SelfAttention as SelfAttention
from .components import MultiHeadAttention as MultiHeadAttention
from .components import FeedForward as FeedForward
from .components import EncoderBlock as EncoderBlock
from .components import DecoderBlock as DecoderBlock
from .components import MySequential as MySequential
from .components import RoboConstructor as RoboConstructor
from .components import save_component as save_component
from .components import load_component as load_component
robo_lib/components.py
ADDED
@@ -0,0 +1,893 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import tokenizers
import numpy as np
import random
import pickle
import itertools

class TokenizerConstructor:
    '''

    simple assembler for a tokenizer using the tokenizers library
    tokenizer parameters can be set using strings and list[string]s
    strings used for tokenizer_type, pre_tokenizers, normalizers arguments are the names of those present in the
    tokenizers library. Additionally "IndividualDigits" can be used in pre_tokenizers for tokenizers.pre_tokenizers.Digits(individual_digits=True)

    train([paths]) function points to text files to be used for training the tokenizer instance

    encode(string) function encodes string using trained tokenizer instance

    decode(list[int]) function decodes list of tokens using trained tokenizer instance

    vocab_size attribute returns the tokenizer instance's vocab_size (untrained tokenizer will have vocab_size=None)

    '''
    def __init__(self,
                 min_frequency:int=2,
                 tokenizer_type:str="BPE",
                 pre_tokenizers:list[str]|str=["Whitespace"],
                 normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
                 special_tokens:list[str]|str=[],
                 unknown_token_string:str="<unk>",
                 start_token_string:str="<sos>",
                 end_token_string:str="<eos>",
                 pad_token_string:str="<pad>",
                 new_line_token_string:str="\n",
                 vocab_size:int=30000
                 ) -> None:
        self.vocab_size = None

        if isinstance(special_tokens, str):
            special_tokens = [special_tokens]
        self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
        self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
        self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
        self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
        self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
        self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None

        if tokenizer_type == "BPE":
            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
            self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
        elif tokenizer_type == "WordLevel":
            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
            self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
        elif tokenizer_type == "WordPiece":
            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
            self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
        elif tokenizer_type == "Unigram":
            self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
            self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)

        if isinstance(pre_tokenizers, str):
            pre_tokenizers = [pre_tokenizers]
        sequence = []
        for pre_tok in pre_tokenizers:
            if pre_tok == "Whitespace":
                sequence.append(tokenizers.pre_tokenizers.Whitespace())
            elif pre_tok == "IndividualDigits":
                sequence.append(tokenizers.pre_tokenizers.Digits(individual_digits=True))
            elif pre_tok == "Digits":
                sequence.append(tokenizers.pre_tokenizers.Digits(individual_digits=False))
            elif pre_tok == "BertPreTokenizer":
                sequence.append(tokenizers.pre_tokenizers.BertPreTokenizer())
            elif pre_tok == "ByteLevel":
                sequence.append(tokenizers.pre_tokenizers.ByteLevel())
            elif pre_tok == "Metaspace":
                sequence.append(tokenizers.pre_tokenizers.Metaspace())
            elif pre_tok == "Punctuation":
                sequence.append(tokenizers.pre_tokenizers.Punctuation())
            elif pre_tok == "UnicodeScripts":
                sequence.append(tokenizers.pre_tokenizers.UnicodeScripts())
            elif pre_tok == "WhitespaceSplit":
                sequence.append(tokenizers.pre_tokenizers.WhitespaceSplit())
        self.tokenizer_type.pre_tokenizer = tokenizers.pre_tokenizers.Sequence(sequence)

        if isinstance(normalizers, str):
            normalizers = [normalizers]
        sequence = []
        for norm in normalizers:
            if norm == "Lowercase":
                sequence.append(tokenizers.normalizers.Lowercase())
            elif norm == "NFC":
                sequence.append(tokenizers.normalizers.NFC())
            elif norm == "NFD":
                sequence.append(tokenizers.normalizers.NFD())
            elif norm == "NFKC":
                sequence.append(tokenizers.normalizers.NFKC())
            elif norm == "NFKD":
                sequence.append(tokenizers.normalizers.NFKD())
            elif norm == "Nmt":
                sequence.append(tokenizers.normalizers.Nmt())
            elif norm == "BertNormalizer":
                sequence.append(tokenizers.normalizers.BertNormalizer())
            elif norm == "StripAccents":
                sequence.append(tokenizers.normalizers.StripAccents())
            elif norm == "Strip":
                sequence.append(tokenizers.normalizers.Strip())
        self.tokenizer_type.normalizer = tokenizers.normalizers.Sequence(sequence)


    def train(self, training_paths:list[str]|str) -> None:
        if isinstance(training_paths, str):
            training_paths = [training_paths]
        self.tokenizer_type.train(training_paths, trainer=self.trainer)
        self.vocab_size = self.tokenizer_type.get_vocab_size()

    def encode(self, inp:str) -> list[int]:
        return self.tokenizer_type.encode(inp).ids

    def decode(self, inp:list[int]) -> str:
        return self.tokenizer_type.decode(inp)

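A minimal usage sketch for the class above (illustrative only, not part of the packaged file). The configuration happens at construction time and nothing is trained until train() is called; the path "corpus.txt" is a hypothetical plain-text training file.

# Illustrative sketch only -- not part of robo_lib/components.py.
# "corpus.txt" is a hypothetical path to any plain-text training file.
from robo_lib import TokenizerConstructor

tok = TokenizerConstructor(tokenizer_type="BPE", vocab_size=8000)
tok.train("corpus.txt")                  # accepts a single path or a list of paths

ids = tok.encode("Hello world 123")      # list[int]; start/end tokens are not added here
text = tok.decode(ids)                   # back to a string
print(tok.vocab_size, tok.pad_token, tok.start_token, tok.end_token)
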
def create_mask(row:list, block_size:int) -> list[bool]:
    '''

    creates a mask list of length block_size for row, assuming the mask does cover the entire row input

    '''
    mask = [1]*len(row) + [0]*(block_size - len(row))
    return mask

def pad(row:list, block_size:int, pad_token:int) -> list[int]:
    '''

    returns padded row. Row is padded until length block_size with specified pad_token value

    '''
    row.extend([pad_token]*(block_size - len(row)))
    return row

def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
    '''

    returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist

    '''
    processed_row = tokenizer.encode(row)
    if tokenizer.start_token != None:
        processed_row.insert(0, tokenizer.start_token)
    if tokenizer.end_token != None:
        processed_row.append(tokenizer.end_token)

    return processed_row

def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
    '''

    returns max_block_size of given list of strings by taking the length of the longest process_row(row) in data

    '''
    lengths = [len(process_row(p, tokenizer)) for p in data]
    max_block_size_scanner = max(lengths)
    return max_block_size_scanner

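Taken together, these helpers turn one string into a fixed-length token row plus a 1/0 mask of the same length. A small sketch of how they compose (illustrative only, not part of the package; `tok` is assumed to be an already-trained TokenizerConstructor):

# Illustrative sketch only -- assumes `tok` is an already-trained TokenizerConstructor.
from robo_lib import create_mask, pad, process_row, scan_max_block_size

rows = ["a short example", "another row of text"]
block = scan_max_block_size(rows, tok)          # longest tokenized row incl. start/end tokens

row = process_row(rows[0], tok)                 # [start_token, ..., end_token]
mask = create_mask(row, block)                  # 1 where real tokens sit, 0 where padding will go
row = pad(row, block, tok.pad_token)            # now len(row) == block
assert len(row) == len(mask) == block
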
class DataProcessor:
    '''

    data processor can be instantiated by specifying the tokenizer(s) for decoder and encoder data

    process_list() function processes raw data in the form of list[str] or str for decoder and encoder simultaneously and
    saves them to save_path as .pt files.
        - encoder and decoder input data should have matching inputs and outputs, so enc_data[n] should have its corresponding
        decoder data at dec_data[n].
        - max block size can be specified for both input and output; the default takes the max
        block size found in the respective data.
        - if enc/dec_max_block_size is specified and enc/dec_block_size_exceeded_policy is not, an error will occur if a piece
        of data larger than enc/dec_max_block_size is encountered. enc/dec_block_size_exceeded_policy can be set to "skip" or
        "trim" to skip rows larger than enc/dec_max_block_size or truncate the row to the specified enc/dec_max_block_size respectively.
        - enc/dec_create_masks saves mask tensors to save_path as .pt files.

    '''
    def __init__(self,
                 dec_tokenizer:TokenizerConstructor,
                 enc_tokenizer:TokenizerConstructor=None
                 ) -> None:
        self.dec_tokenizer = dec_tokenizer
        self.enc_tokenizer = enc_tokenizer

    def process_list(self,
                     save_path:str,
                     dec_data:list[str]|str,
                     dec_max_block_size:int=None,
                     dec_create_masks:bool=True,
                     dec_block_size_exceeded_policy:str=None,
                     enc_data:list[str]=None,
                     enc_create_masks=True,
                     enc_max_block_size:int=None,
                     enc_block_size_exceeded_policy:str=None
                     ) -> None:

        if isinstance(dec_data, str):
            dec_data = [dec_data]
        dec_data_length = len(dec_data)
        save_path = save_path.replace(".pt", "")

        if dec_max_block_size == None:
            dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)

        if enc_data != None:
            self.enc_tokenizer = self.dec_tokenizer if self.enc_tokenizer == None else self.enc_tokenizer

            enc_data_length = len(enc_data)
            if dec_data_length != enc_data_length:
                raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")

            if enc_max_block_size == None:
                enc_max_block_size = scan_max_block_size(enc_data, self.enc_tokenizer)

            enc_out_list = [[]]*enc_data_length
            enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
        else:
            enc_out_list = []
            enc_mask_list = []

        dec_out_list = [[]]*dec_data_length
        dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
        for index in range(len(dec_out_list)):
            dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
            if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
                if dec_block_size_exceeded_policy == "trim":
                    dec_processed_item = dec_processed_item[:dec_max_block_size]
                elif dec_block_size_exceeded_policy == "skip":
                    continue
                elif dec_block_size_exceeded_policy == None:
                    raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
            if dec_create_masks:
                dec_mask = create_mask(dec_processed_item, dec_max_block_size)
            dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)

            if enc_data != None:
                enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
                if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
                    if enc_block_size_exceeded_policy == "trim":
                        enc_processed_item = enc_processed_item[:enc_max_block_size]
                    elif enc_block_size_exceeded_policy == "skip":
                        continue
                    elif enc_block_size_exceeded_policy == None:
                        raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
                if enc_create_masks:
                    enc_mask = create_mask(enc_processed_item, enc_max_block_size)
                enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)

            dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
            if dec_create_masks:
                dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)

            if enc_data != None:
                enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
                if enc_create_masks:
                    enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)

        dec_out_list = torch.stack([row for row in dec_out_list if row != []])
        torch.save(dec_out_list, save_path + "_decoder_data.pt")
        if dec_create_masks:
            dec_mask_list = torch.stack([row for row in dec_mask_list if row != []])
            torch.save(dec_mask_list, save_path + "_decoder_mask_data.pt")
        if enc_data != None:
            enc_out_list = torch.stack([row for row in enc_out_list if row != []])
            torch.save(enc_out_list, save_path + "_encoder_data.pt")
            if enc_create_masks:
                enc_mask_list = torch.stack([row for row in enc_mask_list if row != []])
                torch.save(enc_mask_list, save_path + "_encoder_mask_data.pt")

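A usage sketch for process_list() (illustrative only, not part of the packaged file). `dec_tok` and `enc_tok` are assumed pre-trained tokenizers and "prepared/train" is a hypothetical save prefix; with masks enabled the call writes prepared/train_decoder_data.pt, prepared/train_decoder_mask_data.pt and the encoder equivalents.

# Illustrative sketch only -- `dec_tok` and `enc_tok` are assumed pre-trained TokenizerConstructors.
from robo_lib import DataProcessor

english = ["the cat sat", "hello there"]        # hypothetical encoder-side rows
french = ["le chat s'est assis", "bonjour"]     # matching decoder-side rows (same indices)

proc = DataProcessor(dec_tokenizer=dec_tok, enc_tokenizer=enc_tok)
proc.process_list(
    save_path="prepared/train",                 # prefix for the generated .pt files
    dec_data=french,
    enc_data=english,
)
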
def get_valid_samples(random_samples:torch.tensor,
                      masks:torch.tensor,
                      block_size:int
                      ) -> list[int]:
    '''

    returns list of len(random_samples) with values corresponding to index values of masks that ensure minimum masked
    values when taking sample of length block_size

    '''
    valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
    return valid_samples

def get_batch(data:torch.tensor,
              random_samples:torch.tensor,
              masks:torch.tensor=None,
              block_size:int=None,
              get_offset:bool=True
              ) -> tuple[torch.tensor]:
    '''

    returns random batches from data tensor using random_samples for data selection.
    - returns corresponding batch offset by 1 unless get_offset=False
    - returns corresponding masks batch if masks data is specified

    '''
    batch_size = len(random_samples)
    if block_size != None and block_size != data.shape[1]:
        if block_size >= data.shape[1]:
            raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")

        if masks != None:
            random_point = get_valid_samples(random_samples, masks, block_size)
        else:
            random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
        batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
        masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
        batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
    else:
        block_size = data.shape[1]
        batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
        masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
        batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None

    return batch_in, batch_out, masks_in

def top_kp_filter(logits:torch.tensor,
                  top_k:int,
                  top_p:float=None
                  ) -> torch.tensor:
    '''

    returns predicted token by filtering output logits using top_k and top_p

    '''
    if top_p != None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(sorted_logits, dim=-1)

        filter = cumulative_probs > top_p
        filter[..., 1:] = filter[..., :-1].clone()
        filter[..., 0] = 0
        indices_to_remove = filter.scatter(1, sorted_indices, filter)
        logits[indices_to_remove] = float("-inf")

    if top_k != None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)

        sorted_logits = F.softmax(sorted_logits[:, :top_k], dim=-1)
        sorted_indices = sorted_indices[:, :top_k].detach().cpu()
        sorted_logits = sorted_logits.detach().cpu().numpy()
        sorted_logits[0][0] += 1 - sum(sorted_logits[0])

        selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)

    return selected

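get_batch() pairs each sampled row with the same row shifted by one position, which is the usual next-token-prediction setup. A sketch using random integer data in place of real processed rows (illustrative only, not part of the package):

# Illustrative sketch only -- random integer "token" data instead of real processed rows.
import torch
from robo_lib import get_batch

data = torch.randint(0, 100, (500, 64))            # 500 rows, block size 64
samples = torch.randint(data.shape[0], (8,))       # pick 8 row indices

x, y, m = get_batch(data, samples, get_offset=True)
print(x.shape, y.shape, m)                         # (8, 63), (8, 63), None -- y is x shifted by one
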
class SelfAttention(nn.Module):
    '''

    single self attention block of size head_size.
    triangle_mask=True to apply look-ahead mask of size block_size.

    '''
    def __init__(self,
                 head_size:int,
                 n_embed:int,
                 dropout:float,
                 block_size:int=0,
                 triangle_mask:bool=True
                 ) -> None:
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.triangle_mask = triangle_mask
        self.block_size = block_size

        if self.triangle_mask and self.block_size >= 0:
            self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self,
                k:torch.tensor,
                q:torch.tensor,
                v:torch.tensor,
                mask:torch.tensor=None
                ) -> torch.tensor:
        '''

        k, q and v are the input tensors used to compute the key, query and value tensors.
        custom mask tensor can be applied.

        '''
        _,T,_ = k.shape

        k = self.key(k)
        q = self.query(q)

        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
        if self.triangle_mask and self.block_size >= 0:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        if mask != None:
            wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(v)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    '''

    multi-head attention block consisting of num_heads SelfAttention blocks and a linear layer to
    rejoin outputs.
    specified head_size, n_embed, dropout, block_size and triangle_mask values are passed through to
    SelfAttention blocks

    '''
    def __init__(self,
                 num_heads:int,
                 head_size:int,
                 n_embed:int,
                 dropout:float=0.1,
                 block_size:int=0,
                 triangle_mask:bool=True
                 ) -> None:
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(head_size, n_embed, dropout, block_size=block_size, triangle_mask=triangle_mask) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                k:torch.tensor,
                q:torch.tensor,
                v:torch.tensor,
                mask:torch.tensor=None
                ) -> torch.tensor:
        '''

        k, q and v are the input tensors used to compute the key, query and value tensors.
        custom mask tensor can be applied.

        '''
        out = torch.cat([h(k, q, v, mask=mask) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    '''

    feed forward layer used after multi-head attention consisting of 2 linear layers with
    a ReLU in between. Linear layers expand from n_embed to n_embed * expansion_factor and
    back to n_embed.

    '''
    def __init__(self,
                 n_embed:int,
                 expansion_factor:int,
                 dropout:float=0.1
                 ) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, expansion_factor * n_embed),
            nn.ReLU(),
            nn.Linear(expansion_factor * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self,
                x:torch.tensor
                ) -> torch.tensor:
        return self.net(x)

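A quick shape check for the attention stack above (illustrative only, not part of the package). With triangle_mask=False no causal buffer is registered, so block_size can stay at its default; inputs and outputs are (batch, time, n_embed):

# Illustrative sketch only -- checks tensor shapes through MultiHeadAttention and FeedForward.
import torch
from robo_lib import MultiHeadAttention, FeedForward

n_embed, n_head = 64, 4
mha = MultiHeadAttention(n_head, n_embed // n_head, n_embed, triangle_mask=False)
ffwd = FeedForward(n_embed, expansion_factor=4)

x = torch.randn(2, 10, n_embed)        # (batch, time, n_embed)
out = mha(x, x, x)                     # self-attention: k, q and v are all x
print(out.shape)                       # torch.Size([2, 10, 64])
print(ffwd(out).shape)                 # torch.Size([2, 10, 64])
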
class EncoderBlock(nn.Module):
    '''

    encoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
    head_size is calculated from n_embed // n_head

    '''
    def __init__(self,
                 n_embed:int,
                 n_head:int,
                 expansion_factor:int,
                 dropout:float=0.1
                 ) -> None:
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embed, dropout, triangle_mask=False)
        self.ffwd = FeedForward(n_embed, expansion_factor, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self,
                x:torch.tensor,
                mask:torch.tensor=None
                ) -> tuple[torch.tensor]:
        att = self.sa(x, x, x, mask=mask)
        x = self.ln1(att + x)
        ff = self.ffwd(x)
        out = self.ln2(ff + x)
        return out, mask


class DecoderBlock(nn.Module):
    '''

    decoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
    if cross_attention is True, a multi-head attention block and LayerNorm is added before feed-forward,
    taking the specified enc_k and enc_v tensors as key and value tensors. These values should be the output
    of an encoder block.
    head_size is calculated from n_embed // n_head

    '''
    def __init__(self,
                 n_embed:int,
                 n_head:int,
                 expansion_factor:int,
                 cross_attention:bool=False,
                 block_size:int=0,
                 dropout:float=0.1
                 ) -> None:
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embed, dropout, block_size=block_size, triangle_mask=True)
        self.ffwd = FeedForward(n_embed, expansion_factor, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        if cross_attention:
            self.ca = MultiHeadAttention(n_head, head_size, n_embed, dropout, triangle_mask=False)
            self.ln3 = nn.LayerNorm(n_embed)
        else:
            self.ca = None

    def forward(self,
                x:torch.tensor,
                enc_k:torch.tensor,
                enc_v:torch.tensor,
                mask_out:torch.tensor=None,
                mask_in:torch.tensor=None
                ) -> tuple[torch.tensor]:
        att = self.sa(x, x, x, mask=mask_out)
        x = self.ln1(att + x)
        if self.ca != None:
            catt = self.ca(enc_k, x, enc_v, mask=mask_in)
            x = self.ln3(catt + x)
        ff = self.ffwd(x)
        out = self.ln2(ff + x)
        return out, enc_k, enc_v, mask_out, mask_in

class MySequential(nn.Sequential):
    '''

    MySequential serves the same purpose as nn.Sequential but allows for multiple inputs and outputs

    '''
    def forward(self, *input):
        for module in self._modules.values():
            input = module(*input)
        return input

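MySequential exists because each block above returns a tuple (the encoder returns (out, mask), the decoder returns five values), while nn.Sequential only forwards a single tensor. A small sketch of chaining encoder blocks this way (illustrative only, not part of the package):

# Illustrative sketch only -- chains two EncoderBlocks, passing (x, mask) between them.
import torch
from robo_lib import EncoderBlock, MySequential

enc = MySequential(*[EncoderBlock(n_embed=64, n_head=4, expansion_factor=4) for _ in range(2)])

x = torch.randn(2, 10, 64)                     # (batch, time, n_embed)
mask = torch.ones(2, 10, dtype=torch.bool)     # every position is a real token
out, mask = enc(x, mask)                       # each block returns (out, mask) for the next one
print(out.shape)                               # torch.Size([2, 10, 64])
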
class RoboConstructor(nn.Module):
    '''

    RoboConstructor assembles an encoder-decoder or decoder-only transformer.
    if the enc_* variables are not specified, or enc_n_blocks==0, the transformer will be decoder-only.
        - if any of the dec_* variables are not specified (except dec_expansion_factor) an error will occur.
        - if enc_n_blocks > 0 and any of the enc_* variables are not specified (except enc_expansion_factor and enc_block_size) an error will occur.
    dropout can be specified, default=0.1.
    if device is not specified, device will default to the first available among ("cuda", "mps", "cpu")

    prep_data() function returns a batch of specified batch_size, from dec_data (and dec_masks, enc_data and enc_masks if specified)
        - if encoder is configured in this instance, enc_data must be specified.
        - dec_block_size must be specified.
        - if enc_block_size is not specified, the entire block_size of enc_data will be used.
        this function is for use in train_robo()

    train_robo() function trains the RoboConstructor instance transformer.
        - training parameters can be specified such as max_iters, eval_interval, batch_size, eval_iters, learning_rate, label_smoothing.
        - paths must be specified for decoder training data (and encoder training data if encoder-decoder transformer)
        - optional paths to specify: decoder and encoder masks, decoder and encoder validation data, decoder and encoder validation masks data
        - if neither pad_token nor tokenizer is specified (or tokenizer has no pad_token), any padding in labels will contribute towards the loss,
        which may cause unwanted results. Specifying pad_token and/or tokenizer allows loss to be calculated while ignoring any padding in labels
        - specify save_path to save the model as a .pkl file every eval_interval iterations using the save_component function.

    generate() function uses the transformer model from the RoboConstructor instance to generate an output from an input.
        - input can be in the form of a string if an input tokenizer is specified (enc_tokenizer for encoder-decoder and dec_tokenizer for decoder-only),
        otherwise, it must be in the form of a list of tokens.
        - if dec_tokenizer is specified, output will be a string.
        - new tokens are generated until the dec_end_token (or dec_tokenizer.end_token) is generated, or the number of tokens generated == max_new_tokens.
        - if input tokenizer is not specified, or input tokenizer.start_token is None, enc_start_token must be specified for an encoder-decoder model.
        - separator_token is used to separate the input and generated tokens for a decoder-only model. If this value is not specified, there
        will be no distinction between input tokens and generated tokens to the transformer, even if dec_tokenizer is specified.
        - if new_line_token is not specified, output will be returned in one line, without any "\n" line separators.
        - temperature, top_k and top_p can be specified to adjust the output.

    '''
    def __init__(self,
                 n_embed:int,
                 dec_n_blocks:int,
                 dec_n_head:int,
                 dec_vocab_size:int,
                 dec_block_size:int,
                 dec_expansion_factor:int=4,
                 enc_n_blocks:int=0,
                 enc_n_head:int=None,
                 enc_vocab_size:int=None,
                 enc_block_size:int=None,
                 enc_expansion_factor:int=4,
                 dropout:float=0.1,
                 device:str=None
                 ) -> None:
        super().__init__()
        self.n_embed = n_embed
        self.dec_n_blocks = dec_n_blocks
        self.dec_n_head = dec_n_head
        self.dec_vocab_size = dec_vocab_size
        self.dec_block_size = dec_block_size
        self.dec_expansion_factor = dec_expansion_factor
        self.dropout = dropout

        if device == None:
            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        else:
            self.device = device
        self.dec_token_embedding_table = nn.Embedding(dec_vocab_size, n_embed)
        self.dec_positional_embedding_table = nn.Embedding(dec_block_size, n_embed)

        if enc_n_blocks != 0:
            self.enc_n_blocks = enc_n_blocks
            self.enc_n_head = enc_n_head
            self.enc_expansion_factor = enc_expansion_factor
            self.enc_vocab_size = enc_vocab_size
            self.enc_block_size = enc_block_size
            self.cross_attention = True
            self.enc_token_embedding_table = nn.Embedding(enc_vocab_size, n_embed)
            self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
            self.encoder_blocks = MySequential(*[EncoderBlock(n_embed, enc_n_head, enc_expansion_factor, dropout=dropout) for _ in range(enc_n_blocks)])
        else:
            self.cross_attention = False
            self.enc_block_size = None

        self.decoder_blocks = MySequential(*[DecoderBlock(n_embed, dec_n_head, dec_expansion_factor, cross_attention=self.cross_attention, block_size=self.dec_block_size, dropout=dropout) for _ in range(dec_n_blocks)])
        self.ln = nn.LayerNorm(n_embed)
        self.lid = nn.Linear(n_embed, dec_vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self,
                dec_in:torch.tensor,
                dec_mask:torch.tensor=None,
                enc_in:torch.tensor=None,
                enc_mask:torch.tensor=None
                ) -> torch.tensor:
        _, dec_T = dec_in.shape
        if enc_in != None:
            _, enc_T = enc_in.shape

        dec_tok_emb = self.dec_token_embedding_table(dec_in)
        dec_pos_emb = self.dec_positional_embedding_table(torch.arange(dec_T, device=self.device))
        dec_x = dec_tok_emb + dec_pos_emb

        if self.cross_attention:
            enc_tok_emb = self.enc_token_embedding_table(enc_in)
            enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
            enc_x = enc_tok_emb + enc_pos_emb

            enc_out, enc_mask = self.encoder_blocks(enc_x, enc_mask)
        else:
            enc_out = None

        x, _, _, _, _ = self.decoder_blocks(dec_x, enc_out, enc_out, dec_mask, enc_mask)
        x = self.ln(x)
        proj_output = self.lid(x)

        return proj_output


    def prep_data(self,
                  batch_size:int,
                  dec_data:str,
                  dec_block_size:int,
                  dec_masks:str=None,
                  enc_data:str=None,
                  enc_block_size:int=None,
                  enc_masks:str=None
                  ) -> tuple[torch.tensor]:
        random_samples = torch.randint(dec_data.shape[0], (batch_size,))

        dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
        dec_train_batch_in = dec_train_batch_in.to(self.device)
        dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
        dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None

        if self.cross_attention:
            enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
            enc_train_batch_in = enc_train_batch_in.to(self.device)
            enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
        else:
            enc_train_batch_in = None
            enc_train_masks_in = None

        return dec_train_batch_in, dec_train_batch_out, dec_train_masks_in, enc_train_batch_in, enc_train_masks_in


    def train_robo(self,
                   max_iters:int,
                   eval_interval:int,
                   batch_size:int,
                   dec_training_path:str,
                   dec_eval_path:str=None,
                   dec_training_masks_path:str=None,
                   dec_eval_masks_path:str=None,
                   enc_training_path:str=None,
                   enc_eval_path:str=None,
                   enc_training_masks_path:str=None,
                   enc_eval_masks_path:str=None,
                   eval_iters:int=3,
                   learning_rate:float=1e-4,
                   pad_token:int=None,
                   tokenizer:TokenizerConstructor=None,
                   save_path:str=None,
                   label_smoothing:float=0.1
                   ) -> None:

        dec_training_data = torch.load(dec_training_path, weights_only=True)
        dec_eval_data = torch.load(dec_eval_path, weights_only=True) if dec_eval_path != None else None
        dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if dec_training_masks_path != None else None
        dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if dec_eval_masks_path != None else None
        enc_training_data = torch.load(enc_training_path, weights_only=True) if enc_training_path != None else None
        enc_eval_data = torch.load(enc_eval_path, weights_only=True) if enc_eval_path != None else None
        enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if enc_training_masks_path != None else None
        enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if enc_eval_masks_path != None else None

        if pad_token == None and tokenizer != None:
            pad_token = tokenizer.pad_token

        self.to(self.device)

        if pad_token != None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
        else:
            loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
        print(sum(p.numel() for p in self.parameters())/1e6, "M parameters")
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)

        @torch.no_grad()
        def estimate_loss() -> dict:
            out = {}
            self.eval()
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
                proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
                losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
            out["train"] = losses.mean()
            if dec_eval_data != None:
                for k in range(eval_iters):
                    dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
                    proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
                    losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
                out["eval"] = losses.mean()
            else:
                out["eval"] = np.nan
            self.train()
            return out

        self.train()
        for iter in range(max_iters):
            if iter % eval_interval == 0 or iter == max_iters-1:
                losses = estimate_loss()
                print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
                if save_path != None:
                    save_component(self, save_path=save_path)

            dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
            proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
            loss = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        self.eval()

    # use dec and enc tokenizers
    def generate(self,
                 inputs:list[int]|str,
                 max_new_tokens:int=None,
                 dec_tokenizer:TokenizerConstructor=None,
                 enc_tokenizer:TokenizerConstructor=None,
                 dec_start_token:int=None,
                 enc_start_token:int=None,
                 enc_end_token:int=None,
                 dec_end_token:int=None,
                 separator_token:int=None,
                 new_line_token:int=None,
                 temperature:float=1,
                 top_k:int=None,
                 top_p:float=None
                 ) -> list[int]|str:
        max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens

        if self.cross_attention:
            if enc_tokenizer != None:
                if enc_start_token == None:
                    enc_start_token = enc_tokenizer.start_token
                if enc_end_token == None:
                    enc_end_token = enc_tokenizer.end_token
                if isinstance(inputs, str):
                    inputs = enc_tokenizer.encode(inputs)

        if dec_tokenizer != None:
            if dec_start_token == None:
                dec_start_token = dec_tokenizer.start_token
            if dec_end_token == None:
                dec_end_token = dec_tokenizer.end_token
            if new_line_token == None:
                new_line_token = dec_tokenizer.new_line_token
            if self.cross_attention == False and isinstance(inputs, str):
                inputs = dec_tokenizer.encode(inputs)


        if self.cross_attention:
            enc_input = torch.tensor([[enc_start_token] + inputs + [enc_end_token]], dtype=torch.long, device=self.device)
            idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
        else:
            enc_input = None
            if separator_token != None:
                idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
            else:
                idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)

        self.eval()
        for _ in range(1, max_new_tokens):
            idx_cond = idx[:, -self.dec_block_size:] if idx.shape[1] > self.dec_block_size else idx

            proj_output = self(idx_cond, enc_in=enc_input)

            logits = proj_output[:, -1, :]
            probabilities = F.log_softmax(logits/temperature, dim=-1)

            if top_k == None and top_p == None:
                idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
            else:
                idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
            idx = torch.cat((idx, idx_next), dim=-1)
            if idx_next[0] == dec_end_token:
                break

        if dec_tokenizer == None:
            return idx[0].tolist()
        else:
            if new_line_token != None:
                return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
            else:
                return dec_tokenizer.decode(idx[0].tolist())

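An end-to-end sketch for a decoder-only model (illustrative only, not part of the packaged file). The .pt paths are hypothetical outputs of a DataProcessor.process_list() call with save_path="prepared/train", and `tok` is assumed to be an already-trained TokenizerConstructor:

# Illustrative sketch only -- small decoder-only model; paths and `tok` are assumptions.
from robo_lib import RoboConstructor

model = RoboConstructor(
    n_embed=128,
    dec_n_blocks=4,
    dec_n_head=4,
    dec_vocab_size=tok.vocab_size,
    dec_block_size=64,
)

model.train_robo(
    max_iters=2000,
    eval_interval=200,
    batch_size=32,
    dec_training_path="prepared/train_decoder_data.pt",
    dec_training_masks_path="prepared/train_decoder_mask_data.pt",
    tokenizer=tok,                       # lets the loss ignore pad tokens
    save_path="models/robo",             # checkpointed every eval_interval iterations
)

print(model.generate("once upon a time", dec_tokenizer=tok, max_new_tokens=50, top_k=40))
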
def save_component(component, save_path:str) -> None:
    '''

    saves component (such as TokenizerConstructor or RoboConstructor) as .pkl file.

    '''
    save_path = save_path + ".pkl" if save_path[-4:] != ".pkl" else save_path
    with open(save_path, "wb") as comp:
        pickle.dump(component, comp, pickle.HIGHEST_PROTOCOL)

def load_component(load_path:str):
    '''

    loads saved .pkl file into variable.

    '''
    load_path = load_path + ".pkl" if load_path[-4:] != ".pkl" else load_path
    with open(load_path, "rb") as comp:
        loaded_component = pickle.load(comp)
    return loaded_component
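These two helpers round-trip any picklable component. A short sketch (illustrative only, not part of the package; `tok` is again an assumed TokenizerConstructor instance):

# Illustrative sketch only -- pickle round-trip of a component.
from robo_lib import save_component, load_component

save_component(tok, "tokenizer")            # writes tokenizer.pkl (".pkl" is appended if missing)
tok_again = load_component("tokenizer.pkl")
print(tok_again.vocab_size)
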
robo_lib-0.0.4.dist-info/METADATA
ADDED
@@ -0,0 +1,18 @@
Metadata-Version: 2.3
Name: robo_lib
Version: 0.0.4
Summary: A package to configure, create and train transformer models.
Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
Author-email: Erik Papp <erik3papp@gmail.com>
License-File: LICENSE
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.8
Requires-Dist: numpy
Requires-Dist: tokenizers
Requires-Dist: torch
Description-Content-Type: text/markdown

# robo_pack
robo_lib-0.0.4.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
robo_lib/components.py,sha256=kNtDslsSfjV4b9mKxGB7ZjbjdPvk-o_1i0AKyA6c4Mk,42025
robo_lib-0.0.4.dist-info/METADATA,sha256=mAADQzCKvZLsTrGwxUGFj2SIgfTQLKOhutm4tX9CuJw,616
robo_lib-0.0.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
robo_lib-0.0.4.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
robo_lib-0.0.4.dist-info/RECORD,,
robo_lib-0.0.4.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) [2024] [Erik Papp]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.