robo_lib-0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
robo_lib/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ from .components import TokenizerConstructor as TokenizerConstructor
2
+ from .components import create_mask as create_mask
3
+ from .components import pad as pad
4
+ from .components import process_row as process_row
5
+ from .components import scan_max_block_size as scan_max_block_size
6
+ from .components import DataProcessor as DataProcessor
7
+ from .components import get_valid_samples as get_valid_samples
8
+ from .components import get_batch as get_batch
9
+ from .components import top_kp_filter as top_kp_filter
10
+ from .components import SelfAttention as SelfAttention
11
+ from .components import MultiHeadAttention as MultiHeadAttention
12
+ from .components import FeedForward as FeedForward
13
+ from .components import EncoderBlock as EncoderBlock
14
+ from .components import DecoderBlock as DecoderBlock
15
+ from .components import MySequential as MySequential
16
+ from .components import RoboConstructor as RoboConstructor
17
+ from .components import save_component as save_component
18
+ from .components import load_component as load_component
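robo_lib/__init__.py simply re-exports the public API from robo_lib.components, so everything can be imported from the package root. An illustrative import (assuming the wheel is installed):

from robo_lib import TokenizerConstructor, DataProcessor, RoboConstructor, save_component, load_component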
robo_lib/components.py ADDED
@@ -0,0 +1,893 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import tokenizers
5
+ import numpy as np
6
+ import random
7
+ import pickle
8
+ import itertools
9
+
10
+ class TokenizerConstructor:
11
+ '''
12
+
13
+ simple assembler for a tokenizer built on the tokenizers library
14
+ tokenizer parameters can be set using strings and list[string]s
15
+ strings used for the tokenizer_type, pre_tokenizers and normalizers arguments are the names of those present in the
16
+ tokenizers library. Additionally, "IndividualDigits" can be used in pre_tokenizers for tokenizers.pre_tokenizers.Digits(individual_digits=True)
17
+
18
+ train([paths]) function trains the tokenizer instance on the text file(s) at the given path(s)
19
+
20
+ encode(string) function encodes a string using the trained tokenizer instance
21
+
22
+ decode(list[int]) function decodes a list of tokens using the trained tokenizer instance
23
+
24
+ vocab_size attribute returns the tokenizer instance's vocab_size (an untrained tokenizer will have vocab_size=None)
25
+
26
+
27
+ '''
28
+ def __init__(self,
29
+ min_frequency:int=2,
30
+ tokenizer_type:str="BPE",
31
+ pre_tokenizers:list[str]|str=["Whitespace"],
32
+ normalizers:list[str]|str=["Lowercase", "NFD", "StripAccents", "Strip"],
33
+ special_tokens:list[str]|str=[],
34
+ unknown_token_string:str="<unk>",
35
+ start_token_string:str="<sos>",
36
+ end_token_string:str="<eos>",
37
+ pad_token_string:str="<pad>",
38
+ new_line_token_string:str="\n",
39
+ vocab_size:int=30000
40
+ ) -> None:
41
+ self.vocab_size = None
42
+
43
+ if isinstance(special_tokens, str):
44
+ special_tokens = [special_tokens]
45
+ self.special_tokens = special_tokens + [token for token in [unknown_token_string, start_token_string, end_token_string, pad_token_string, new_line_token_string] if token not in special_tokens and token != None]
46
+ self.unknown_token = self.special_tokens.index(unknown_token_string) if unknown_token_string != None else None
47
+ self.start_token = self.special_tokens.index(start_token_string) if start_token_string != None else None
48
+ self.end_token = self.special_tokens.index(end_token_string) if end_token_string != None else None
49
+ self.pad_token = self.special_tokens.index(pad_token_string) if pad_token_string != None else None
50
+ self.new_line_token = self.special_tokens.index(new_line_token_string) if new_line_token_string != None else None
51
+
52
+ if tokenizer_type == "BPE":
53
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unknown_token_string))
54
+ self.trainer = tokenizers.trainers.BpeTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
55
+ elif tokenizer_type == "WordLevel":
56
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordLevel(unk_token=unknown_token_string))
57
+ self.trainer = tokenizers.trainers.WordLevelTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
58
+ elif tokenizer_type == "WordPiece":
59
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.WordPiece(unk_token=unknown_token_string))
60
+ self.trainer = tokenizers.trainers.WordPieceTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
61
+ elif tokenizer_type == "Unigram":
62
+ self.tokenizer_type = tokenizers.Tokenizer(tokenizers.models.Unigram(unk_token=unknown_token_string))
63
+ self.trainer = tokenizers.trainers.UnigramTrainer(special_tokens=self.special_tokens, min_frequency=min_frequency, vocab_size=vocab_size)
64
+
65
+ if isinstance(pre_tokenizers, str):
66
+ pre_tokenizers = [pre_tokenizers]
67
+ sequence = []
68
+ for pre_tok in pre_tokenizers:
69
+ if pre_tok == "Whitespace":
70
+ sequence.append(tokenizers.pre_tokenizers.Whitespace())
71
+ elif pre_tok == "IndividualDigits":
72
+ sequence.append(tokenizers.pre_tokenizers.Digits(individual_digits=True))
73
+ elif pre_tok == "Digits":
74
+ sequence.append(tokenizers.pre_tokenizers.Digits(individual_digits=False))
75
+ elif pre_tok == "BertPreTokenizer":
76
+ sequence.append(tokenizers.pre_tokenizers.BertPreTokenizer())
77
+ elif pre_tok == "ByteLevel":
78
+ sequence.append(tokenizers.pre_tokenizers.ByteLevel())
79
+ elif pre_tok == "Metaspace":
80
+ sequence.append(tokenizers.pre_tokenizers.Metaspace())
81
+ elif pre_tok == "Punctuation":
82
+ sequence.append(tokenizers.pre_tokenizers.Punctuation())
83
+ elif pre_tok == "UnicodeScripts":
84
+ sequence.append(tokenizers.pre_tokenizers.UnicodeScripts())
85
+ elif pre_tok == "WhitespaceSplit":
86
+ sequence.append(tokenizers.pre_tokenizers.WhitespaceSplit())
87
+ self.tokenizer_type.pre_tokenizer = tokenizers.pre_tokenizers.Sequence(sequence)
88
+
89
+ if isinstance(normalizers, str):
90
+ normalizers = [normalizers]
91
+ sequence = []
92
+ for norm in normalizers:
93
+ if norm == "Lowercase":
94
+ sequence.append(tokenizers.normalizers.Lowercase())
95
+ elif norm == "NFC":
96
+ sequence.append(tokenizers.normalizers.NFC())
97
+ elif norm == "NFD":
98
+ sequence.append(tokenizers.normalizers.NFD())
99
+ elif norm == "NFKC":
100
+ sequence.append(tokenizers.normalizers.NFKC())
101
+ elif norm == "NFKD":
102
+ sequence.append(tokenizers.normalizers.NFKD())
103
+ elif norm == "Nmt":
104
+ sequence.append(tokenizers.normalizers.Nmt())
105
+ elif norm == "BertNormalizer":
106
+ sequence.append(tokenizers.normalizers.BertNormalizer())
107
+ elif norm == "StripAccents":
108
+ sequence.append(tokenizers.normalizers.StripAccents())
109
+ elif norm == "Strip":
110
+ sequence.append(tokenizers.normalizers.Strip())
111
+ elif norm == "BertNormalizer":
112
+ sequence.append(tokenizers.normalizers.BertNormalizer())
113
+ self.tokenizer_type.normalizer = tokenizers.normalizers.Sequence(sequence)
114
+
115
+
116
+ def train(self, training_paths:list[str]|str) -> None:
117
+ if isinstance(training_paths, str):
118
+ training_paths = [training_paths]
119
+ self.tokenizer_type.train(training_paths, trainer=self.trainer)
120
+ self.vocab_size = self.tokenizer_type.get_vocab_size()
121
+
122
+ def encode(self, inp:str) -> list[int]:
123
+ return self.tokenizer_type.encode(inp).ids
124
+
125
+ def decode(self, inp:list[int]) -> str:
126
+ return self.tokenizer_type.decode(inp)
127
+
128
+
129
+
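# Illustrative usage sketch for TokenizerConstructor (not part of the released file; the file path is a placeholder):
#
#   tok = TokenizerConstructor(tokenizer_type="BPE", vocab_size=30000)
#   tok.train("corpus.txt")           # train on one or more text files
#   ids = tok.encode("hello world")   # list[int]; start/end tokens are added by process_row, not by encode
#   text = tok.decode(ids)
#   print(tok.vocab_size)             # populated after training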
130
+ def create_mask(row:list, block_size:int) -> list[bool]:
131
+ '''
132
+
133
+ creates a mask list of length block_size for row: positions covered by row are marked 1 and the remaining (padding) positions are 0. Assumes len(row) <= block_size
134
+
135
+ '''
136
+ mask = [1]*len(row) + [0]*(block_size - len(row))
137
+ return mask
138
+
139
+ def pad(row:list, block_size:int, pad_token:int) -> list[int]:
140
+ '''
141
+
142
+ returns padded row. Row is padded until length block_size with specified pad_token value
143
+
144
+ '''
145
+ row.extend([pad_token]*(block_size - len(row)))
146
+ return row
147
+
148
+ def process_row(row:str, tokenizer:TokenizerConstructor) -> list[int]:
149
+ '''
150
+
151
+ returns tokenized row using specified tokenizer, and adds the tokenizer's start and end tokens if they exist
152
+
153
+ '''
154
+ processed_row = tokenizer.encode(row)
155
+ if tokenizer.start_token != None:
156
+ processed_row.insert(0, tokenizer.start_token)
157
+ if tokenizer.end_token != None:
158
+ processed_row.append(tokenizer.end_token)
159
+
160
+ return processed_row
161
+
162
+ def scan_max_block_size(data:list[str], tokenizer:TokenizerConstructor) -> int:
163
+ '''
164
+
165
+ returns max_block_size of given list of strings by taking the length of the longest process_row(row) in data
166
+
167
+ '''
168
+ lengths = [len(process_row(p, tokenizer)) for p in data]
169
+ max_block_size_scanner = max(lengths)
170
+ return max_block_size_scanner
171
+
172
+
173
+ class DataProcessor:
174
+ '''
175
+
176
+ data processor can be instantiated by specifying the tokenizer(s) for decoder and encoder data
177
+
178
+ process_list() function processes raw data in the form of list[str] or str for decoder and encoder simultaneously and
179
+ saves them to save_path as .pt files.
180
+ - encoder and decoder data should be aligned, so enc_data[n] should have its corresponding
182
+ decoder data at dec_data[n].
183
+ - a max block size can be specified for both encoder and decoder data; by default, the longest
184
+ processed row in the respective data is used.
184
+ - if enc/dec_block_size is specified and enc/dec_block_size_exceeded_policy is not, an error will occur if a piece
185
+ of data larger than enc/dec_block_size is encountered. enc/dec_block_size_exceeded_policy can be set to "skip" or
186
+ "trim" to skip rows larger than enc/dec_block_size or truncate the row to specified enc/dec_block_size respectively.
187
+ - enc/dec_create_masks saves mask tensors to save_path as .pt files.
188
+
189
+
190
+ '''
191
+ def __init__(self,
192
+ dec_tokenizer:TokenizerConstructor,
193
+ enc_tokenizer:TokenizerConstructor=None
194
+ ) -> None:
195
+ self.dec_tokenizer = dec_tokenizer
196
+ self.enc_tokenizer = enc_tokenizer
197
+
198
+ def process_list(self,
199
+ save_path:str,
200
+ dec_data:list[str]|str,
201
+ dec_max_block_size:int=None,
202
+ dec_create_masks:bool=True,
203
+ dec_block_size_exceeded_policy:str=None,
204
+ enc_data:list[str]=None,
205
+ enc_create_masks=True,
206
+ enc_max_block_size:int=None,
207
+ enc_block_size_exceeded_policy:str=None
208
+ ) -> None:
209
+
210
+ if isinstance(dec_data, str):
211
+ dec_data = [dec_data]
212
+ dec_data_length = len(dec_data)
213
+ save_path = save_path.replace(".pt", "")
214
+
215
+ if dec_max_block_size == None:
216
+ dec_max_block_size = scan_max_block_size(dec_data, self.dec_tokenizer)
217
+
218
+ if enc_data != None:
219
+ self.enc_tokenizer = self.dec_tokenizer if self.enc_tokenizer == None else self.enc_tokenizer
220
+
221
+ enc_data_length = len(enc_data)
222
+ if dec_data_length != enc_data_length:
223
+ raise Exception(f"decoder and encoder lengths do not match. decoder_data_length is {dec_data_length}, encoder_data_length is {enc_data_length}")
224
+
225
+ if enc_max_block_size == None:
226
+ enc_max_block_size = scan_max_block_size(enc_data, self.enc_tokenizer)
227
+
228
+ enc_out_list = [[]]*enc_data_length
229
+ enc_mask_list = [[]]*enc_data_length if enc_create_masks else []
230
+ else:
231
+ enc_out_list = []
232
+ enc_mask_list = []
233
+
234
+ dec_out_list = [[]]*dec_data_length
235
+ dec_mask_list = [[]]*dec_data_length if dec_create_masks else []
236
+ for index in range(len(dec_out_list)):
237
+ dec_processed_item = process_row(dec_data[index], self.dec_tokenizer)
238
+ if dec_max_block_size != None and len(dec_processed_item) > dec_max_block_size:
239
+ if dec_block_size_exceeded_policy == "trim":
240
+ dec_processed_item = dec_processed_item[:dec_max_block_size]
241
+ elif dec_block_size_exceeded_policy == "skip":
242
+ continue
243
+ elif dec_block_size_exceeded_policy == None:
244
+ raise Exception(f"encountered item in dec_data larger than maximum block size ({dec_max_block_size})")
245
+ if dec_create_masks:
246
+ dec_mask = create_mask(dec_processed_item, dec_max_block_size)
247
+ dec_processed_item = pad(dec_processed_item, dec_max_block_size, self.dec_tokenizer.pad_token)
248
+
249
+ if enc_data != None:
250
+ enc_processed_item = process_row(enc_data[index], self.enc_tokenizer)
251
+ if enc_max_block_size != None and len(enc_processed_item) > enc_max_block_size:
252
+ if enc_block_size_exceeded_policy == "trim":
253
+ enc_processed_item = enc_processed_item[:enc_max_block_size]
254
+ elif enc_block_size_exceeded_policy == "skip":
255
+ continue
256
+ elif enc_block_size_exceeded_policy == None:
257
+ raise Exception(f"encountered item in enc_data larger than maximum block size ({enc_max_block_size})")
258
+ if enc_create_masks:
259
+ enc_mask = create_mask(enc_processed_item, enc_max_block_size)
260
+ enc_processed_item = pad(enc_processed_item, enc_max_block_size, self.enc_tokenizer.pad_token)
261
+
262
+ dec_out_list[index] = torch.tensor(dec_processed_item, dtype=torch.long)
263
+ if dec_create_masks:
264
+ dec_mask_list[index] = torch.tensor(dec_mask, dtype=torch.bool)
265
+
266
+ if enc_data != None:
267
+ enc_out_list[index] = torch.tensor(enc_processed_item, dtype=torch.long)
268
+ if enc_create_masks:
269
+ enc_mask_list[index] = torch.tensor(enc_mask, dtype=torch.bool)
270
+
271
+ dec_out_list = torch.stack([row for row in dec_out_list if row != []])
272
+ torch.save(dec_out_list, save_path + "_decoder_data.pt")
273
+ if dec_create_masks:
274
+ dec_mask_list = torch.stack([row for row in dec_mask_list if row != []])
275
+ torch.save(dec_mask_list, save_path + "_decoder_mask_data.pt")
276
+ if enc_data != None:
277
+ enc_out_list = torch.stack([row for row in enc_out_list if row != []])
278
+ torch.save(enc_out_list, save_path + "_encoder_data.pt")
279
+ if enc_create_masks:
280
+ enc_mask_list = torch.stack([row for row in enc_mask_list if row != []])
281
+ torch.save(enc_mask_list, save_path + "_encoder_mask_data.pt")
282
+
283
+
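# Illustrative sketch of DataProcessor.process_list (not part of the released file; paths and data are placeholders,
# and dec_tok / enc_tok stand for trained TokenizerConstructor instances):
#
#   proc = DataProcessor(dec_tokenizer=dec_tok, enc_tokenizer=enc_tok)
#   proc.process_list(
#       save_path="data/train",           # writes data/train_decoder_data.pt, data/train_decoder_mask_data.pt,
#       dec_data=["bonjour", "merci"],    # data/train_encoder_data.pt and data/train_encoder_mask_data.pt
#       enc_data=["hello", "thank you"],  # aligned: enc_data[n] pairs with dec_data[n]
#       dec_max_block_size=32,
#       enc_max_block_size=32,
#       dec_block_size_exceeded_policy="trim",
#       enc_block_size_exceeded_policy="trim",
#   )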
284
+ def get_valid_samples(random_samples:torch.tensor,
285
+ masks:torch.tensor,
286
+ block_size:int
287
+ ) -> list[int]:
288
+ '''
289
+
290
+ returns a list of len(random_samples) start indices, chosen from the masks so that a sample of length
291
+ block_size taken from each row contains as few masked positions as possible
292
+
293
+ '''
294
+ valid_samples = [0 if sum(masks[row_num]) <= block_size else random.randint(0, sum(masks[row_num]) - block_size) for row_num in random_samples]
295
+ return valid_samples
296
+
297
+ def get_batch(data:torch.tensor,
298
+ random_samples:torch.tensor,
299
+ masks:torch.tensor=None,
300
+ block_size:int=None,
301
+ get_offset:bool=True
302
+ ) -> tuple[torch.tensor]:
303
+ '''
304
+
305
+ returns random batches from data tensor using random sample for data selection.
306
+ - returns corresponding batch offset by 1 unless get_offset=False
307
+ - returns corresponding masks batch if masks data is specified
308
+
309
+ '''
310
+ batch_size = len(random_samples)
311
+ if block_size != None and block_size != data.shape[1]:
312
+ if block_size >= data.shape[1]:
313
+ raise Exception(f"specified block size ({block_size}) is larger than input tensor length ({data.shape[1]})")
314
+
315
+ if masks != None:
316
+ random_point = get_valid_samples(random_samples, masks, block_size)
317
+ else:
318
+ random_point = torch.randint(data.shape[1] - block_size, (batch_size,))
319
+ batch_in = torch.stack([data[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)])
320
+ masks_in = torch.stack([masks[random_samples[i]][random_point[i]:random_point[i]+block_size-int(get_offset)] for i in range(batch_size)]) if masks != None else None
321
+ batch_out = torch.stack([data[random_samples[i]][1+random_point[i]:random_point[i]+block_size] for i in range(batch_size)]) if get_offset else None
322
+ else:
323
+ block_size = data.shape[1]
324
+ batch_in = torch.stack([data[row_num][:block_size-int(get_offset)] for row_num in random_samples])
325
+ masks_in = torch.stack([masks[row_num][:block_size-int(get_offset)] for row_num in random_samples]) if masks != None else None
326
+ batch_out = torch.stack([data[row_num][1:block_size] for row_num in random_samples]) if get_offset else None
327
+
328
+ return batch_in, batch_out, masks_in
329
+
330
+ def top_kp_filter(logits:torch.tensor,
331
+ top_k:int,
332
+ top_p:float=None
333
+ ) -> torch.tensor:
334
+ '''
335
+
336
+ returns predicted token by filtering output logits using top_k and top_p
337
+
338
+ '''
339
+ if top_p != None:
340
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
341
+ cumulative_probs = torch.cumsum(sorted_logits, dim=-1)
342
+
343
+ filter = cumulative_probs > top_p
344
+ filter[..., 1:] = filter[..., :-1].clone()
345
+ filter[..., 0] = 0
346
+ indices_to_remove = filter.scatter(1, sorted_indices, filter)
347
+ logits[indices_to_remove] = float("-inf")
348
+
349
+ if top_k != None:
350
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
351
+
352
+ sorted_logits = F.softmax(sorted_logits[:, :top_k], dim=-1)
353
+ sorted_indices = sorted_indices[:, :top_k].detach().cpu()
354
+ sorted_logits = sorted_logits.detach().cpu().numpy()
355
+ sorted_logits[0][0] += 1 - sum(sorted_logits[0])
356
+
357
+ selected = torch.tensor(np.random.choice(sorted_indices[0], 1, p=sorted_logits[0]), dtype=torch.long)
358
+
359
+ return selected
360
+
361
+
362
+
363
+ class SelfAttention(nn.Module):
364
+ '''
365
+
366
+ single self attention block of size head_size.
367
+ triangle_mask=True to apply look-ahead mask of size block_size.
368
+
369
+ '''
370
+ def __init__(self,
371
+ head_size:int,
372
+ n_embed:int,
373
+ dropout:float,
374
+ block_size:int=0,
375
+ triangle_mask:bool=True
376
+ ) -> None:
377
+ super().__init__()
378
+ self.key = nn.Linear(n_embed, head_size, bias=False)
379
+ self.query = nn.Linear(n_embed, head_size, bias=False)
380
+ self.value = nn.Linear(n_embed, head_size, bias=False)
381
+ self.triangle_mask = triangle_mask
382
+ self.block_size = block_size
383
+
384
+ if self.triangle_mask and self.block_size >= 0:
385
+ self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
386
+
387
+ self.dropout = nn.Dropout(dropout)
388
+
389
+ def forward(self,
390
+ k:torch.tensor,
391
+ q:torch.tensor,
392
+ v:torch.tensor,
393
+ mask:torch.tensor=None
394
+ ) -> torch.tensor:
395
+ '''
396
+
397
+ k, q and v are the input tensors used to compute the key, query and value tensors.
398
+ custom mask tensor can be applied.
399
+
400
+ '''
401
+ _,T,_ = k.shape
402
+
403
+ k = self.key(k)
404
+ q = self.query(q)
405
+
406
+ wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
407
+ if self.triangle_mask and self.block_size >= 0:
408
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
409
+ if mask != None:
410
+ wei = wei.masked_fill(mask.unsqueeze(1)==0, float("-inf"))
411
+ wei = F.softmax(wei, dim=-1)
412
+ wei = self.dropout(wei)
413
+
414
+ v = self.value(v)
415
+ out = wei @ v
416
+ return out
417
+
418
+ class MultiHeadAttention(nn.Module):
419
+ '''
420
+
421
+ multi-head attention block consisting of num_heads SelfAttention blocks and a linear layer to
422
+ rejoin outputs.
423
+ specified head_size, n_embed, dropout, block_size and triangle_mask values are passed through to
424
+ SelfAttention blocks
425
+
426
+ '''
427
+ def __init__(self,
428
+ num_heads:int,
429
+ head_size:int,
430
+ n_embed:int,
431
+ dropout:float=0.1,
432
+ block_size:int=0,
433
+ triangle_mask:bool=True
434
+ ) -> None:
435
+ super().__init__()
436
+ self.heads = nn.ModuleList([SelfAttention(head_size, n_embed, dropout, block_size=block_size, triangle_mask=triangle_mask) for _ in range(num_heads)])
437
+ self.proj = nn.Linear(head_size * num_heads, n_embed)
438
+ self.dropout = nn.Dropout(dropout)
439
+
440
+ def forward(self,
441
+ k:torch.tensor,
442
+ q:torch.tensor,
443
+ v:torch.tensor,
444
+ mask:torch.tensor=None
445
+ ) -> torch.tensor:
446
+ '''
447
+
448
+ k, q and v are the input tensors used to compute the key, query and value tensors.
449
+ custom mask tensor can be applied.
450
+
451
+ '''
452
+ out = torch.cat([h(k, q, v, mask=mask) for h in self.heads], dim=-1)
453
+ out = self.dropout(self.proj(out))
454
+ return out
455
+
456
+ class FeedForward(nn.Module):
457
+ '''
458
+
459
+ feed forward layer used after multi-head attention, consisting of 2 linear layers with
460
+ a ReLU in between. Linear layers expand from n_embed to n_embed * expansion_factor and
461
+ back to n_embed.
462
+
463
+ '''
464
+ def __init__(self,
465
+ n_embed:int,
466
+ expansion_factor:int,
467
+ dropout:float=0.1
468
+ ) -> None:
469
+ super().__init__()
470
+ self.net = nn.Sequential(
471
+ nn.Linear(n_embed, expansion_factor * n_embed),
472
+ nn.ReLU(),
473
+ nn.Linear(expansion_factor * n_embed, n_embed),
474
+ nn.Dropout(dropout),
475
+ )
476
+
477
+ def forward(self,
478
+ x:torch.tensor
479
+ ) -> torch.tensor:
480
+ return self.net(x)
481
+
482
+ class EncoderBlock(nn.Module):
483
+ '''
484
+
485
+ encoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
486
+ head_size is calculated from n_embed // n_head
487
+
488
+ '''
489
+ def __init__(self,
490
+ n_embed:int,
491
+ n_head:int,
492
+ expansion_factor:int,
493
+ dropout:float=0.1
494
+ ) -> None:
495
+ super().__init__()
496
+ head_size = n_embed // n_head
497
+ self.sa = MultiHeadAttention(n_head, head_size, n_embed, dropout, triangle_mask=False)
498
+ self.ffwd = FeedForward(n_embed, expansion_factor, dropout)
499
+ self.ln1 = nn.LayerNorm(n_embed)
500
+ self.ln2 = nn.LayerNorm(n_embed)
501
+
502
+ def forward(self,
503
+ x:torch.tensor,
504
+ mask:torch.tensor=None
505
+ ) -> tuple[torch.tensor]:
506
+ att = self.sa(x, x, x, mask=mask)
507
+ x = self.ln1(att + x)
508
+ ff = self.ffwd(x)
509
+ out = self.ln2(ff + x)
510
+ return out, mask
511
+
512
+
513
+ class DecoderBlock(nn.Module):
514
+ '''
515
+
516
+ decoder block consists of a sequence of multi-head attention, LayerNorm, feed-forward, LayerNorm
517
+ if cross_attention is True, a multi-head attention block and LayerNorm are added before the feed-forward layer,
518
+ taking the specified enc_k and enc_v tensors as key and value tensors. These should be the output
519
+ of an encoder block.
520
+ head_size is calculated from n_embed // n_head
521
+
522
+ '''
523
+ def __init__(self,
524
+ n_embed:int,
525
+ n_head:int,
526
+ expansion_factor:int,
527
+ cross_attention:bool=False,
528
+ block_size:int=0,
529
+ dropout:float=0.1
530
+ ) -> None:
531
+ super().__init__()
532
+ head_size = n_embed // n_head
533
+ self.sa = MultiHeadAttention(n_head, head_size, n_embed, dropout, block_size=block_size, triangle_mask=True)
534
+ self.ffwd = FeedForward(n_embed, expansion_factor, dropout)
535
+ self.ln1 = nn.LayerNorm(n_embed)
536
+ self.ln2 = nn.LayerNorm(n_embed)
537
+ if cross_attention:
538
+ self.ca = MultiHeadAttention(n_head, head_size, n_embed, dropout, triangle_mask=False)
539
+ self.ln3 = nn.LayerNorm(n_embed)
540
+ else:
541
+ self.ca = None
542
+
543
+ def forward(self,
544
+ x:torch.tensor,
545
+ enc_k:torch.tensor,
546
+ enc_v:torch.tensor,
547
+ mask_out:torch.tensor=None,
548
+ mask_in:torch.tensor=None
549
+ ) -> tuple[torch.tensor]:
550
+ att = self.sa(x, x, x, mask=mask_out)
551
+ x = self.ln1(att + x)
552
+ if self.ca != None:
553
+ catt = self.ca(enc_k, x, enc_v, mask=mask_in)
554
+ x = self.ln3(catt + x)
555
+ ff = self.ffwd(x)
556
+ out = self.ln2(ff + x)
557
+ return out, enc_k, enc_v, mask_out, mask_in
558
+
559
+ class MySequential(nn.Sequential):
560
+ '''
561
+
562
+ MySequential serves the same purpose as nn.Sequential but allows for multiple inputs and outputs
563
+
564
+ '''
565
+ def forward(self, *input):
566
+ for module in self._modules.values():
567
+ input = module(*input)
568
+ return input
569
+
570
+ class RoboConstructor(nn.Module):
571
+ '''
572
+
573
+ RoboConstructor assembles an encoder-decoder or decoder-only transformer.
574
+ if the enc_* variables are not specified, or enc_n_blocks==0, the transformer will be decoder-only.
575
+ - if any of the dec_* variables are not specified (except dec_expansion_factor) an error will occur.
576
+ - if enc_n_blocks > 0 and any of the enc_* variables are not specified (except enc_expansion_factor and enc_block_size) an error will occur.
577
+ dropout can be specified, default=0.1.
578
+ if device is not specified, device will default to first available among ("cuda", "mps", "cpu")
579
+
580
+ prep_data() function returns a batch of specified batch_size, from dec_data (and dec_masks, enc_data and enc_masks if specified)
581
+ - if encoder is configured in this instance, enc_data must be specified.
582
+ - dec_block_size must be specified.
583
+ - if enc_block_size is not specified, the entire block_size of enc_data will be used.
584
+ this function is for use in train_robo()
585
+
586
+ train_robo() function trains the RoboConstructor instance transformer.
587
+ - training parameters can be specified such as max_iters, eval_interval, batch_size, eval_iters, learning_rate, label_smoothing.
588
+ - paths must be specified for decoder training data (and encoder training data if encoder-decoder transformer)
589
+ - optional paths to specify: decoder and encoder masks, decoder and encoder validation data, decoder and encoder validation masks data
590
+ - if neither pad_token nor tokenizer is specified (or the tokenizer has no pad_token), any padding in labels will contribute towards the loss
591
+ which may cause unwanted results. Specifying pad_token and/or tokenizer allows loss to be calculated while ignoring any padding in labels
592
+ - specify save_path to save the model as a .pkl file every eval_interval iterations using the save_component function.
593
+
594
+ generate() function uses the transformer model from the RoboConstructor instance to generate an output from an input.
595
+ - input can be in the form of a string if an input tokenizer is specified (enc_tokenizer for encoder-decoder and dec_tokenizer for decoder-only),
596
+ otherwise, it must be in the form of a list of tokens.
597
+ - if dec_tokenizer is specified, output will be a string.
598
+ - new tokens are generated until the dec_end_token (or dec_tokenizer.end_token) is generated, or the number of tokens generated == max_new_tokens.
599
+ - if input tokenizer is not specified, or input tokenizer.start_token is None, enc_start_token must be specified for an encoder-decoder model.
600
+ - separator_token is used to separate the input and generated tokens for a decoder-only model. If this value is not specified, there
601
+ will be no distinction between input tokens and generated tokens to the transformer, even if dec_tokenizer is specified.
602
+ - if new_line_token is not specified, output will be returned in one line, without any "\n" line separators.
603
+ - temperature, top_k and top_p can be specified to adjust the output.
604
+
605
+ '''
606
+ def __init__(self,
607
+ n_embed:int,
608
+ dec_n_blocks:int,
609
+ dec_n_head:int,
610
+ dec_vocab_size:int,
611
+ dec_block_size:int,
612
+ dec_expansion_factor:int=4,
613
+ enc_n_blocks:int=0,
614
+ enc_n_head:int=None,
615
+ enc_vocab_size:int=None,
616
+ enc_block_size:int=None,
617
+ enc_expansion_factor:int=4,
618
+ dropout:float=0.1,
619
+ device:str=None
620
+ ) -> None:
621
+ super().__init__()
622
+ self.n_embed = n_embed
623
+ self.dec_n_blocks = dec_n_blocks
624
+ self.dec_n_head = dec_n_head
625
+ self.dec_vocab_size = dec_vocab_size
626
+ self.dec_block_size = dec_block_size
627
+ self.dec_expansion_factor = dec_expansion_factor
628
+ self.dropout = dropout
629
+
630
+ if device == None:
631
+ self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
632
+ else:
633
+ self.device = device
634
+ self.dec_token_embedding_table = nn.Embedding(dec_vocab_size, n_embed)
635
+ self.dec_positional_embedding_table = nn.Embedding(dec_block_size, n_embed)
636
+
637
+ if enc_n_blocks != 0:
638
+ self.enc_n_blocks = enc_n_blocks
639
+ self.enc_n_head = enc_n_head
640
+ self.enc_expansion_factor = enc_expansion_factor
641
+ self.enc_vocab_size = enc_vocab_size
642
+ self.enc_block_size = enc_block_size
643
+ self.cross_attention = True
644
+ self.enc_token_embedding_table = nn.Embedding(enc_vocab_size, n_embed)
645
+ self.enc_positional_embedding_table = nn.Embedding(enc_block_size, n_embed)
646
+ self.encoder_blocks = MySequential(*[EncoderBlock(n_embed, enc_n_head, enc_expansion_factor, dropout=dropout) for _ in range(enc_n_blocks)])
647
+ else:
648
+ self.cross_attention = False
649
+ self.enc_block_size = None
650
+
651
+ self.decoder_blocks = MySequential(*[DecoderBlock(n_embed, dec_n_head, dec_expansion_factor, cross_attention=self.cross_attention, block_size=self.dec_block_size, dropout=dropout) for _ in range(dec_n_blocks)])
652
+ self.ln = nn.LayerNorm(n_embed)
653
+ self.lid = nn.Linear(n_embed, dec_vocab_size)
654
+
655
+ self.apply(self._init_weights)
656
+
657
+ def _init_weights(self, module) -> None:
658
+ if isinstance(module, nn.Linear):
659
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
660
+ if module.bias is not None:
661
+ torch.nn.init.zeros_(module.bias)
662
+ elif isinstance(module, nn.Embedding):
663
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
664
+
665
+ def forward(self,
666
+ dec_in:torch.tensor,
667
+ dec_mask:torch.tensor=None,
668
+ enc_in:torch.tensor=None,
669
+ enc_mask:torch.tensor=None
670
+ ) -> torch.tensor:
671
+ _, dec_T = dec_in.shape
672
+ if enc_in != None:
673
+ _, enc_T = enc_in.shape
674
+
675
+ dec_tok_emb = self.dec_token_embedding_table(dec_in)
676
+ dec_pos_emb = self.dec_positional_embedding_table(torch.arange(dec_T, device=self.device))
677
+ dec_x = dec_tok_emb + dec_pos_emb
678
+
679
+ if self.cross_attention:
680
+ enc_tok_emb = self.enc_token_embedding_table(enc_in)
681
+ enc_pos_emb = self.enc_positional_embedding_table(torch.arange(enc_T, device=self.device))
682
+ enc_x = enc_tok_emb + enc_pos_emb
683
+
684
+ enc_out, enc_mask = self.encoder_blocks(enc_x, enc_mask)
685
+ else:
686
+ enc_out = None
687
+
688
+ x, _, _, _, _ = self.decoder_blocks(dec_x, enc_out, enc_out, dec_mask, enc_mask)
689
+ x = self.ln(x)
690
+ proj_output = self.lid(x)
691
+
692
+ return proj_output
693
+
694
+
695
+ def prep_data(self,
696
+ batch_size:int,
697
+ dec_data:str,
698
+ dec_block_size:int,
699
+ dec_masks:str=None,
700
+ enc_data:str=None,
701
+ enc_block_size:int=None,
702
+ enc_masks:str=None
703
+ ) -> tuple[torch.tensor]:
704
+ random_samples = torch.randint(dec_data.shape[0], (batch_size,))
705
+
706
+ dec_train_batch_in, dec_train_batch_out, dec_train_masks_in = get_batch(dec_data, random_samples, masks=dec_masks, block_size=dec_block_size, get_offset=True)
707
+ dec_train_batch_in = dec_train_batch_in.to(self.device)
708
+ dec_train_batch_out = dec_train_batch_out.to(self.device) if dec_train_batch_out != None else None
709
+ dec_train_masks_in = dec_train_masks_in.to(self.device) if dec_train_masks_in != None else None
710
+
711
+ if self.cross_attention:
712
+ enc_train_batch_in, _, enc_train_masks_in = get_batch(enc_data, random_samples, masks=enc_masks, block_size=enc_block_size, get_offset=False)
713
+ enc_train_batch_in = enc_train_batch_in.to(self.device)
714
+ enc_train_masks_in = enc_train_masks_in.to(self.device) if enc_train_masks_in != None else None
715
+ else:
716
+ enc_train_batch_in = None
717
+ enc_train_masks_in = None
718
+
719
+ return dec_train_batch_in, dec_train_batch_out, dec_train_masks_in, enc_train_batch_in, enc_train_masks_in
720
+
721
+
722
+ def train_robo(self,
723
+ max_iters:int,
724
+ eval_interval:int,
725
+ batch_size:int,
726
+ dec_training_path:str,
727
+ dec_eval_path:str=None,
728
+ dec_training_masks_path:str=None,
729
+ dec_eval_masks_path:str=None,
730
+ enc_training_path:str=None,
731
+ enc_eval_path:str=None,
732
+ enc_training_masks_path:str=None,
733
+ enc_eval_masks_path:str=None,
734
+ eval_iters:int=3,
735
+ learning_rate:float=1e-4,
736
+ pad_token:int=None,
737
+ tokenizer:TokenizerConstructor=None,
738
+ save_path:str=None,
739
+ label_smoothing:float=0.1
740
+ ) -> None:
741
+
742
+ dec_training_data = torch.load(dec_training_path, weights_only=True)
743
+ dec_eval_data = torch.load(dec_eval_path, weights_only=True) if dec_eval_path != None else None
744
+ dec_training_masks_data = torch.load(dec_training_masks_path, weights_only=True) if dec_training_masks_path != None else None
745
+ dec_eval_masks_data = torch.load(dec_eval_masks_path, weights_only=True) if dec_eval_masks_path != None else None
746
+ enc_training_data = torch.load(enc_training_path, weights_only=True) if enc_training_path != None else None
747
+ enc_eval_data = torch.load(enc_eval_path, weights_only=True) if enc_eval_path != None else None
748
+ enc_training_masks_data = torch.load(enc_training_masks_path, weights_only=True) if enc_training_masks_path != None else None
749
+ enc_eval_masks_data = torch.load(enc_eval_masks_path, weights_only=True) if enc_eval_masks_path != None else None
750
+
751
+ if pad_token == None and tokenizer != None:
752
+ pad_token = tokenizer.pad_token
753
+
754
+ self.to(self.device)
755
+
756
+ if pad_token != None:
757
+ loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token, label_smoothing=label_smoothing).to(self.device)
758
+ else:
759
+ loss_fn = nn.CrossEntropyLoss(label_smoothing=label_smoothing).to(self.device)
760
+ print(sum(p.numel() for p in self.parameters())/1e6, "M parameters")
761
+ optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)
762
+ @torch.no_grad()
763
+ def estimate_loss() -> dict:
764
+ out = {}
765
+ self.eval()
766
+ losses = torch.zeros(eval_iters)
767
+ for k in range(eval_iters):
768
+ dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
769
+ proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
770
+ losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
771
+ out["train"] = losses.mean()
772
+ if dec_eval_data != None:
773
+ for k in range(eval_iters):
774
+ dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_eval_data, dec_masks=dec_eval_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_eval_data, enc_masks=enc_eval_masks_data, enc_block_size=self.enc_block_size)
775
+ proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
776
+ losses[k] = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
777
+ out["eval"] = losses.mean()
778
+ else:
779
+ out["eval"] = np.nan
780
+ self.train()
781
+ return out
782
+
783
+ self.train()
784
+ for iter in range(max_iters):
785
+ if iter % eval_interval == 0 or iter == max_iters-1:
786
+ losses = estimate_loss()
787
+ print(f"step {iter}: train loss {losses['train']:.4f}, eval loss {losses['eval']:.4f}")
788
+ if save_path != None:
789
+ save_component(self, save_path=save_path)
790
+
791
+ dec_x, dec_y, dec_mask, enc_x, enc_mask = self.prep_data(batch_size, dec_training_data, dec_masks=dec_training_masks_data, dec_block_size=self.dec_block_size, enc_data=enc_training_data, enc_masks=enc_training_masks_data, enc_block_size=self.enc_block_size)
792
+ proj_output = self.forward(dec_x, dec_mask, enc_x, enc_mask)
793
+ loss = loss_fn(proj_output.view(-1, self.dec_vocab_size), dec_y.view(-1))
794
+ loss.backward()
795
+ optimizer.step()
796
+ optimizer.zero_grad()
797
+
798
+ self.eval()
799
+
800
+ # use dec and enc tokenizers
801
+ def generate(self,
802
+ inputs:list[int]|str,
803
+ max_new_tokens:int=None,
804
+ dec_tokenizer:TokenizerConstructor=None,
805
+ enc_tokenizer:TokenizerConstructor=None,
806
+ dec_start_token:int=None,
807
+ enc_start_token:int=None,
808
+ enc_end_token:int=None,
809
+ dec_end_token:int=None,
810
+ separator_token:int=None,
811
+ new_line_token:int=None,
812
+ temperature:float=1,
813
+ top_k:int=None,
814
+ top_p:float=None
815
+ ) -> list[int]|str:
816
+ max_new_tokens = self.dec_block_size if max_new_tokens == None else max_new_tokens
817
+
818
+ if self.cross_attention:
819
+ if enc_tokenizer != None:
820
+ if enc_start_token == None:
821
+ enc_start_token = enc_tokenizer.start_token
822
+ if enc_end_token == None:
823
+ enc_end_token = enc_tokenizer.end_token
824
+ if isinstance(inputs, str):
825
+ inputs = enc_tokenizer.encode(inputs)
826
+
827
+ if dec_tokenizer != None:
828
+ if dec_start_token == None:
829
+ dec_start_token = dec_tokenizer.start_token
830
+ if dec_end_token == None:
831
+ dec_end_token = dec_tokenizer.end_token
832
+ if new_line_token == None:
833
+ new_line_token = dec_tokenizer.new_line_token
834
+ if self.cross_attention == False and isinstance(inputs, str):
835
+ inputs = dec_tokenizer.encode(inputs)
836
+
837
+
838
+ if self.cross_attention:
839
+ enc_input = torch.tensor([[enc_start_token] + inputs + [enc_end_token]], dtype=torch.long, device=self.device)
840
+ idx = torch.tensor([[dec_start_token]], dtype=torch.long, device=self.device)
841
+ else:
842
+ enc_input = None
843
+ if separator_token != None:
844
+ idx = torch.tensor([[dec_start_token] + inputs + [separator_token]], dtype=torch.long, device=self.device)
845
+ else:
846
+ idx = torch.tensor([[dec_start_token] + inputs], dtype=torch.long, device=self.device)
847
+
848
+ self.eval()
849
+ for _ in range(1, max_new_tokens):
850
+ idx_cond = idx[:, -self.dec_block_size:] if idx.shape[1] > self.dec_block_size else idx
851
+
852
+ proj_output = self(idx_cond, enc_in=enc_input)
853
+
854
+ logits = proj_output[:, -1, :]
855
+ probabilities = F.log_softmax(logits/temperature, dim=-1)
856
+
857
+ if top_k == None and top_p == None:
858
+ idx_next = torch.max(probabilities, dim=-1).indices.unsqueeze(0)
859
+ else:
860
+ idx_next = top_kp_filter(probabilities, top_k=top_k, top_p=top_p).unsqueeze(0).to(self.device)
861
+ idx = torch.cat((idx, idx_next), dim=-1)
862
+ if idx_next[0] == dec_end_token:
863
+ break
864
+
865
+ if dec_tokenizer == None:
866
+ return idx[0].tolist()
867
+ else:
868
+ if new_line_token != None:
869
+ return "\n".join([dec_tokenizer.decode(list(y)) for x, y in itertools.groupby(idx[0].tolist(), lambda z: z == 0) if not x])
870
+ else:
871
+ return dec_tokenizer.decode(idx[0].tolist())
872
+
873
+
874
+ def save_component(component, save_path:str) -> None:
875
+ '''
876
+
877
+ saves component (such as TokenizerConstructor or RoboConstructor) as .pkl file.
878
+
879
+ '''
880
+ save_path = save_path + ".pkl" if save_path[-4:] != ".pkl" else save_path
881
+ with open(save_path, "wb") as comp:
882
+ pickle.dump(component, comp, pickle.HIGHEST_PROTOCOL)
883
+
884
+ def load_component(load_path:str):
885
+ '''
886
+
887
+ loads saved .pkl file into variable.
888
+
889
+ '''
890
+ load_path = load_path + ".pkl" if load_path[-4:] != ".pkl" else load_path
891
+ with open(load_path, "rb") as comp:
892
+ loaded_component = pickle.load(comp)
893
+ return loaded_component
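Taken together, the components above form a small pipeline: train a tokenizer, turn raw text into padded .pt tensors with DataProcessor, assemble a transformer with RoboConstructor, train it with train_robo, and sample from it with generate. A minimal, illustrative decoder-only sketch of that flow (paths, sizes and hyperparameters are placeholders, not values shipped with the package):

from robo_lib import TokenizerConstructor, DataProcessor, RoboConstructor, save_component, load_component

# 1. tokenizer
tok = TokenizerConstructor(tokenizer_type="BPE", vocab_size=8000)
tok.train("corpus.txt")  # hypothetical training text

# 2. data: decoder-only, so only dec_data is given
proc = DataProcessor(dec_tokenizer=tok)
proc.process_list(save_path="data/train", dec_data=["some text", "more text"], dec_max_block_size=64, dec_block_size_exceeded_policy="trim")

# 3. model: decoder-only transformer (enc_n_blocks defaults to 0)
model = RoboConstructor(n_embed=128, dec_n_blocks=4, dec_n_head=4, dec_vocab_size=tok.vocab_size, dec_block_size=64)

# 4. training from the saved tensors
model.train_robo(
    max_iters=1000,
    eval_interval=200,
    batch_size=32,
    dec_training_path="data/train_decoder_data.pt",
    dec_training_masks_path="data/train_decoder_mask_data.pt",
    tokenizer=tok,                 # lets the loss ignore pad tokens
    save_path="robo_model",        # periodically pickled via save_component
)

# 5. generation
model = load_component("robo_model")
print(model.generate("some text", dec_tokenizer=tok, max_new_tokens=50, top_k=40))

For an encoder-decoder model, also set enc_n_blocks, enc_n_head, enc_vocab_size and enc_block_size on RoboConstructor and pass the encoder data paths to train_robo.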
robo_lib-0.0.4.dist-info/METADATA ADDED
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.3
2
+ Name: robo_lib
3
+ Version: 0.0.4
4
+ Summary: A package to configure, create and train transformer models.
5
+ Project-URL: Homepage, https://github.com/hamburgerfish/robo_pack
6
+ Project-URL: Issues, https://github.com/hamburgerfish/robo_pack/issues
7
+ Author-email: Erik Papp <erik3papp@gmail.com>
8
+ License-File: LICENSE
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Programming Language :: Python :: 3
12
+ Requires-Python: >=3.8
13
+ Requires-Dist: numpy
14
+ Requires-Dist: tokenizers
15
+ Requires-Dist: torch
16
+ Description-Content-Type: text/markdown
17
+
18
+ # robo_pack
robo_lib-0.0.4.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
1
+ robo_lib/__init__.py,sha256=iVOAsANj0lScVW9KKMxCULYmpp0cv4sv1k3sHjBSlE0,1012
2
+ robo_lib/components.py,sha256=kNtDslsSfjV4b9mKxGB7ZjbjdPvk-o_1i0AKyA6c4Mk,42025
3
+ robo_lib-0.0.4.dist-info/METADATA,sha256=mAADQzCKvZLsTrGwxUGFj2SIgfTQLKOhutm4tX9CuJw,616
4
+ robo_lib-0.0.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
5
+ robo_lib-0.0.4.dist-info/licenses/LICENSE,sha256=4XzkkpFqPzH0GH3zxOqRTqc7xUKSEe7dWPOuJYW95ac,1089
6
+ robo_lib-0.0.4.dist-info/RECORD,,
robo_lib-0.0.4.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.25.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
robo_lib-0.0.4.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) [2024] [Erik Papp]
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.