mlable-torch 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ # MLable
2
+
3
+ PyTorch libs: layers, blocks, optimizers, etc.
4
+
5
+ ## TODO
6
+
7
+ See [TODO](TODO.md).
8
+
9
+ ## License
10
+
11
+ Licensed under the [aGPLv3](LICENSE.md).
12
+
13
+ [code-micrograd]: https://github.com/karpathy/micrograd
14
+ [video-karpathy]: https://www.youtube.com/@AndrejKarpathy/videos
@@ -0,0 +1,28 @@
1
+ Metadata-Version: 2.1
2
+ Name: mlable-torch
3
+ Version: 0.2.0
4
+ Summary: PyTorch libs.
5
+ Author: apehex
6
+ Author-email: apehex@protonmail.com
7
+ Requires-Python: >=3.10,<3.12
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Requires-Dist: torch (>=2.2)
12
+ Description-Content-Type: text/markdown
13
+
14
+ # MLable
15
+
16
+ PyTorch libs: layers, blocks, optimizers, etc.
17
+
18
+ ## TODO
19
+
20
+ See [TODO](TODO.md).
21
+
22
+ ## License
23
+
24
+ Licensed under the [aGPLv3](LICENSE.md).
25
+
26
+ [code-micrograd]: https://github.com/karpathy/micrograd
27
+ [video-karpathy]: https://www.youtube.com/@AndrejKarpathy/videos
28
+
File without changes
@@ -0,0 +1,7 @@
1
+ import torch
2
+
3
+ # BATCH #######################################################################
4
+
5
def batch(x: torch.Tensor, y: torch.Tensor, size: int) -> tuple:
    """Draw a random mini-batch of `size` samples from the paired dataset (x, y).

    Rows are picked uniformly at random, with replacement, along the first axis.
    """
    picks = torch.randint(low=0, high=x.shape[0], size=(size,))
    return (x[picks], y[picks])
@@ -0,0 +1,266 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+ # NORMALIZATION ###############################################################
6
+
7
class BatchNorm1d(torch.nn.Module):
    """Batch normalization over the last axis, with running statistics.

    While `training` is True the input is normalized with the CURRENT batch
    statistics (and the running buffers are updated); otherwise the
    accumulated running statistics are used.
    """

    def __init__(self, dim: int, epsilon: float=1e-5, momentum: float=0.1, **kwargs) -> None:
        super(BatchNorm1d, self).__init__(**kwargs)
        self._epsilon = epsilon
        self._momentum = momentum
        # parameters (trained with backprop)
        self._gamma = torch.nn.Parameter(torch.ones(dim), requires_grad=True)
        self._beta = torch.nn.Parameter(torch.zeros(dim), requires_grad=True)
        # buffers (trained with a running 'momentum update'), registered so they
        # are saved in the state dict and moved with the module across devices
        self.register_buffer("mean", torch.zeros(dim))
        self.register_buffer("variance", torch.ones(dim))

    def forward(self, x: torch.Tensor, training: bool, **kwargs) -> torch.Tensor:
        if training:
            __axes = list(range(x.ndim - 1)) # reduce all axes except the last one
            # batch statistics, shape (dim,): kept in the graph so gradients
            # flow through the normalization
            __mean = x.mean(__axes)
            __var = x.var(__axes)
            # update the running buffers in-place so the registered buffers
            # (and not detached copies) track the data
            with torch.no_grad():
                self.mean.mul_(1. - self._momentum).add_(self._momentum * __mean)
                self.variance.mul_(1. - self._momentum).add_(self._momentum * __var)
        else:
            __mean = self.mean
            __var = self.variance
        # normalize x
        __x = (x - __mean) / torch.sqrt(__var + self._epsilon)
        # scale
        return self._gamma * __x + self._beta
34
+
35
+ # ACTIVATION ##################################################################
36
+
37
class Tanh(torch.nn.Module):
    """Hyperbolic tangent activation, wrapped as a stateless module."""
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return x.tanh()
40
+
41
class NewGELU(torch.nn.Module):
    """Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415"""
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # tanh approximation of the Gaussian CDF
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
45
+
46
+ # RESHAPING ###################################################################
47
+
48
+ def _normalize_shape(shape: list) -> list:
49
+ return [-1 if __d is None else __d for __d in shape]
50
+
51
+ def _normalize_dim(dim: int) -> int:
52
+ return -1 if (dim is None or dim < 0) else dim
53
+
54
+ def _multiply_dim(dim_l: int, dim_r: int) -> int:
55
+ return -1 if (dim_l == -1 or dim_r == -1) else dim_l * dim_r
56
+
57
+ def _divide_dim(dim_l: int, dim_r: int) -> int:
58
+ return -1 if (dim_l == -1 or dim_r == -1) else dim_l // dim_r
59
+
60
class Divide(torch.nn.Module):
    """Move a factor of one axis onto another axis via a reshape.

    Both axes are interpreted relative to the NEW shape / rank (i.e. after the
    optional insertion of a grouping axis).
    """

    def __init__(
        self,
        input_axis: int, # relative to the NEW shape / rank
        output_axis: int, # same
        factor: int,
        insert: bool=False,
        **kwargs
    ) -> None:
        super(Divide, self).__init__(**kwargs)
        self._input_axis = input_axis
        self._output_axis = output_axis
        self._factor = factor
        self._insert = insert

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # symbolic (None) axes become -1 so `view` can infer them
        new_shape = [-1 if d is None else d for d in inputs.shape]
        # rank of the NEW shape, counting the optional grouping axis
        rank = len(new_shape) + int(self._insert)
        # resolve the axes against the new rank
        src = self._input_axis % rank
        dst = self._output_axis % rank
        # optionally open a fresh axis to receive the factor
        if self._insert:
            new_shape.insert(dst, 1)
        # shrink the source axis and grow the target axis by the same factor;
        # the wildcard -1 stays -1 after scaling
        new_shape[src] = -1 if new_shape[src] == -1 else new_shape[src] // self._factor
        new_shape[dst] = -1 if new_shape[dst] == -1 else new_shape[dst] * self._factor
        return inputs.view(*new_shape)
89
+
90
class Merge(torch.nn.Module):
    """Collapse two axes of a tensor into one via a reshape.

    The merged dimension lands on the left axis when `left` is True,
    otherwise on the right axis; the other axis is removed.
    """

    def __init__(
        self,
        left_axis: int=-2,
        right_axis: int=-1,
        left: bool=True,
        **kwargs
    ) -> None:
        super(Merge, self).__init__(**kwargs)
        self._left_axis = left_axis
        self._right_axis = right_axis
        self._left = left

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # symbolic (None) axes become -1 so `view` can infer them
        shape = [-1 if d is None else d for d in inputs.shape]
        rank = len(shape)
        # resolve the two target axes
        left = self._left_axis % rank
        right = self._right_axis % rank
        # merged dimension; the wildcard -1 is absorbing
        merged = -1 if (shape[left] == -1 or shape[right] == -1) else shape[left] * shape[right]
        # decide which axis keeps the merged dimension and which disappears
        keep, drop = (left, right) if self._left else (right, left)
        shape[keep] = merged
        shape.pop(drop)
        # actually merge the two axes
        return inputs.view(*shape)
119
+
120
+ # LINEAR ######################################################################
121
+
122
class Linear(torch.nn.Module):
    """Affine map y = x @ W (+ b), with W scaled at init by 1 / sqrt(fan_in)."""

    def __init__(self, in_features: int, out_features: int, bias: bool=True, **kwargs) -> None:
        super(Linear, self).__init__(**kwargs)
        # scaled init keeps the output variance close to the input variance
        self._weight = torch.nn.Parameter(torch.randn((in_features, out_features)) / (in_features ** 0.5), requires_grad=True)
        self._bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=True) if bias else None

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        out = x @ self._weight
        return out if self._bias is None else out + self._bias
133
+
134
class Embedding(torch.nn.Module):
    """Lookup table from integer token ids to dense float vectors."""

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs) -> None:
        super(Embedding, self).__init__(**kwargs)
        self._depth = num_embeddings
        self._weight = torch.nn.Parameter(torch.randn((num_embeddings, embedding_dim)), requires_grad=True)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # direct row lookup: numerically identical to one_hot(x) @ weight,
        # but without materializing the (..., num_embeddings) one-hot tensor
        return self._weight[x]
143
+
144
class PositionalEmbedding(torch.nn.Module):
    """Add a trainable bias per (sequence index, embedding channel) pair."""

    def __init__(
        self,
        time_dim: int,
        embed_dim: int,
        input_axis: int=1, # axis of the sequence
        output_axis: int=-1, # axis of the embedding
        **kwargs
    ):
        super(PositionalEmbedding, self).__init__(**kwargs)
        # weights
        self._input_axis = input_axis
        self._output_axis = output_axis
        self._kernel = torch.nn.Parameter(torch.randn((time_dim, embed_dim)), requires_grad=True)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        rank = inputs.ndim
        kept = {self._input_axis % rank, self._output_axis % rank}
        # broadcast shape for the kernel: keep the sequence and embedding axes,
        # pad every other axis with 1
        view = [d if i in kept else 1 for i, d in enumerate(inputs.shape)]
        # each index in the sequence axis has a dedicated bias (different from a dense bias)
        return inputs + self._kernel.view(*view)
165
+
166
+ # RECURRENT ###################################################################
167
+
168
class RNNCell(torch.nn.Module):
    """Vanilla recurrent cell: h' = tanh(W @ [x, h])."""

    def __init__(self, embed_dim: int, state_dim: int, **kwargs) -> None:
        super(RNNCell, self).__init__(**kwargs)
        # single projection over the concatenated input and state
        self._weights = Linear(in_features=embed_dim + state_dim, out_features=state_dim)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        stacked = torch.cat([x, h], dim=-1)
        return torch.tanh(self._weights(stacked))
176
+
177
class GRUCell(torch.nn.Module):
    """Gated recurrent unit cell (reset + update gates)."""

    def __init__(self, embed_dim: int, state_dim: int, **kwargs) -> None:
        super(GRUCell, self).__init__(**kwargs)
        # input, forget, output, gate
        self._xh_to_z = Linear(in_features=embed_dim + state_dim, out_features=state_dim)
        self._xh_to_r = Linear(in_features=embed_dim + state_dim, out_features=state_dim)
        self._xh_to_hhat = Linear(in_features=embed_dim + state_dim, out_features=state_dim)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        xh = torch.cat([x, h], dim=-1)
        # reset gate: how much of the previous state enters the candidate
        r = torch.sigmoid(self._xh_to_r(xh))
        # update gate: how much of the candidate replaces the previous state
        z = torch.sigmoid(self._xh_to_z(xh))
        # candidate state, computed on the reset state
        hhat = torch.tanh(self._xh_to_hhat(torch.cat([x, r * h], dim=-1)))
        # interpolate between the previous and the candidate states
        return (1. - z) * h + z * hhat
198
+
199
+ # ATTENTION ###################################################################
200
+
201
class CausalSelfAttention(torch.nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Assumes inputs of shape (batch, time, embed) with time == `time_dim`,
    since the causal mask is built at that fixed size — TODO confirm with callers.
    """

    def __init__(self, time_dim: int, embed_dim: int, num_heads: int, **kwargs) -> None:
        super(CausalSelfAttention, self).__init__(**kwargs)
        # each head gets an equal slice of the embedding
        assert embed_dim % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self._attention = Linear(in_features=embed_dim, out_features=3 * embed_dim)
        # output projection
        self._projection = Linear(in_features=embed_dim, out_features=embed_dim)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self._mask = torch.tril(torch.ones(time_dim, time_dim)).view(1, 1, time_dim, time_dim)
        # NOTE(review): `self._mask` stays a plain attribute after registration,
        # so it will not follow the module across devices / dtypes; the
        # registered buffer is accessible as `self.mask`
        self.register_buffer("mask", self._mask)
        # save the shape
        self._head_count = num_heads
        # NOTE(review): despite the name, this is the FULL embedding dimension;
        # it is used below as the split size for the fused 3E projection
        self._head_dim = embed_dim

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # insert a new axis to group by attention head
        __shape = list(x.shape)
        __shape.insert(2, self._head_count)
        __shape[-1] = __shape[-1] // self._head_count
        # calculate query, key, values for all heads in batch
        __q, __k, __v = self._attention(x).split(self._head_dim, dim=-1)
        # group by head rather than time
        __k = __k.view(*__shape).transpose(1, 2) # (B, H, T, E/H)
        __q = __q.view(*__shape).transpose(1, 2) # (B, H, T, E/H)
        __v = __v.view(*__shape).transpose(1, 2) # (B, H, T, E/H)
        # self-attention, scaled by 1/sqrt(head dim)
        __w = (__q @ __k.transpose(-2, -1)) * (1.0 / math.sqrt(__shape[-1])) # (B, H, T, E/H) x (B, H, E/H, T) -> (B, H, T, T)
        # causal: only attend to past tokens
        __w = __w.masked_fill(self._mask == 0, float('-inf'))
        __w = torch.nn.functional.softmax(__w, dim=-1)
        # values
        __y = __w @ __v # (B, H, T, T) x (B, H, T, E/H) -> (B, H, T, E/H)
        # assemble heads
        __y = __y.transpose(1, 2).contiguous().view(*x.shape) # original shape (B, T, E)
        # output projection
        return self._projection(__y)
238
+
239
+ # BLOCKS ######################################################################
240
+
241
class Sequential(torch.nn.Module):
    """Chain layers by feeding each one's output into the next.

    The layers are wrapped in a `torch.nn.ModuleList` so they are registered
    as submodules: a plain Python list would hide their parameters from
    `parameters()`, `state_dict()` and `.to(device)`.
    """

    def __init__(self, layers: list, **kwargs) -> None:
        super(Sequential, self).__init__(**kwargs)
        # register the children so their parameters are tracked
        self._layers = torch.nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, training: bool=True, **kwargs) -> torch.Tensor:
        __x = x
        # forward, threading the `training` flag through every layer
        for __l in self._layers:
            __x = __l(x=__x, training=training, **kwargs)
        # conclude
        return __x
253
+
254
class TransformerBlock(torch.nn.Module):
    """Layer norm + causal attention, then layer norm + feed-forward expansion.

    NOTE(review): there are no residual connections around the attention and
    feed-forward stages, unlike the standard transformer block — confirm this
    is intentional before reusing.
    """

    def __init__(self, time_dim: int, embed_dim: int, num_heads: int, **kwargs) -> None:
        super(TransformerBlock, self).__init__(**kwargs)
        self._block = torch.nn.Sequential(
            torch.nn.LayerNorm(embed_dim),
            CausalSelfAttention(time_dim=time_dim, embed_dim=embed_dim, num_heads=num_heads),
            torch.nn.LayerNorm(embed_dim),
            Linear(embed_dim, 4 * embed_dim),
            Linear(4 * embed_dim, embed_dim),
            NewGELU())

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # torch.nn.Sequential accepts only the input tensor: extra keyword
        # arguments (e.g. the `training` flag passed by this file's own
        # `Sequential` wrapper) are accepted here but deliberately NOT
        # forwarded — forwarding them raised TypeError
        return self._block(x)
@@ -0,0 +1,64 @@
1
+ import functools
2
+
3
+ import torch
4
+
5
+ import mlable.torch.data as _mtd
6
+
7
+ # LEARNING RATE ###############################################################
8
+
9
def learning_rate_waveform(step: int, lr_min: float, lr_max: float, lr_exp: float, rampup: int, sustain: int, steps_per_epoch: int=1024) -> float:
    """Piecewise learning-rate schedule: linear ramp-up, plateau, then exponential decay.

    The schedule is driven by the epoch index (step // steps_per_epoch):
    - epochs [0, rampup): linear interpolation from lr_min to lr_max
    - epochs [rampup, rampup + sustain): constant at lr_max
    - afterwards: exponential decay from lr_max toward lr_min with base lr_exp
    """
    epoch = step // steps_per_epoch
    span = lr_max - lr_min
    if epoch < rampup:
        return lr_min + epoch * span / rampup
    if epoch < rampup + sustain:
        return lr_max
    return lr_min + span * lr_exp ** (epoch - rampup - sustain)
19
+
20
+ # SGD #########################################################################
21
+
22
class SGD(torch.optim.Optimizer):
    """Plain gradient descent with a schedule: p -= rate(iteration) * p.grad."""

    def __init__(self, params: list, rate: callable=None, **kwargs) -> None:
        # materialize BEFORE handing the iterable to the parent class: the base
        # Optimizer consumes it, which would leave a generator (the common
        # `model.parameters()` case) empty for our own bookkeeping
        __params = list(params)
        # fall back to the waveform schedule when no rate callable is given
        __rate = rate if rate is not None else functools.partial(learning_rate_waveform, lr_min=0.00001, lr_max=0.0001, lr_exp=0.8, rampup=4, sustain=2, steps_per_epoch=1024)
        super(SGD, self).__init__(__params, {'rate': __rate}, **kwargs)
        self._parameters = __params
        self._rate = __rate
        self._iteration = -1

    def step(self) -> None:
        """Apply one descent update to every parameter that has a gradient."""
        self._iteration += 1
        __rate = self._rate(self._iteration)
        with torch.no_grad():
            for __p in self._parameters:
                # parameters outside the last graph have no gradient: skip them
                if __p.grad is not None:
                    __p += -__rate * __p.grad
35
+
36
+ # GENERIC #####################################################################
37
+
38
def step(model: torch.nn.Module, loss: callable, optimizer: torch.optim.Optimizer, x: torch.Tensor, y: torch.Tensor, epoch: int) -> torch.Tensor:
    """Run a single optimization step: forward, backward, parameter update.

    Returns the loss tensor so the caller can log it.
    """
    # forward pass
    prediction = model(x=x, training=True)
    error = loss(input=prediction, target=y)
    # reset the gradients, then backpropagate
    model.zero_grad(set_to_none=True)
    error.backward()
    # apply the update
    optimizer.step()
    return error
48
+
49
def train(model:torch.nn.Module, loss: callable, optimizer: torch.optim.Optimizer, x: torch.Tensor, y: torch.Tensor, n_epoch: int, n_batch: int) -> None:
    """Train the model for `n_epoch` epochs of random mini-batches of size `n_batch`."""
    # number of optimization steps per epoch
    steps_per_epoch = int(x.shape[0]) // n_batch
    # iterate on the whole dataset
    for epoch in range(n_epoch):
        # iterate on batches
        for current in range(steps_per_epoch):
            # random batch
            batch_x, batch_y = _mtd.batch(x=x, y=y, size=n_batch)
            # forward + backward + update
            batch_loss = step(model=model, loss=loss, optimizer=optimizer, x=batch_x, y=batch_y, epoch=epoch)
            # log the progress on the first step of each epoch
            if current % steps_per_epoch == 0:
                print('[epoch {epoch}] train loss: {train}'.format(epoch=epoch, train=batch_loss.item()))
@@ -0,0 +1,17 @@
1
+ import torch
2
+
3
+ # NGRAMS ######################################################################
4
+
5
+ def _next(model: torch.nn.Module, ngram: list) -> int:
6
+ __logits = model(torch.tensor([ngram]), training=False)
7
+ __probs = torch.nn.functional.softmax(__logits, dim=-1)
8
+ return torch.multinomial(__probs, num_samples=1).item()
9
+
10
def sample(model: torch.nn.Module, context: int, length: int) -> list:
    """Generate `length` token ids by iteratively sampling from the model.

    The rolling n-gram context starts as `context` zeros (presumably a padding
    token — TODO confirm) and is shifted left by one after each draw.
    """
    __result = []
    __ngram = context * [0]
    for __i in range(length):
        # draw the next token conditioned on the current n-gram
        __n = _next(model=model, ngram=__ngram)
        __result.append(__n)
        # slide the context window: drop the oldest token, append the new one
        __ngram = __ngram[1:] + [__n]
    return __result
@@ -0,0 +1,18 @@
1
+ [tool.poetry]
2
+ name = "mlable-torch"
3
+ version = "0.2.0"
4
+ description = "PyTorch libs."
5
+ authors = ["apehex <apehex@protonmail.com>"]
6
+ readme = ".github/README.md"
7
+ packages = [{include = "mlable"}]
8
+
9
+ [tool.poetry.dependencies]
10
+ python = ">=3.10, <3.12"
11
+ torch = ">=2.2"
12
+
13
+ [tool.poetry.group.dev.dependencies]
14
+ pytest = "*"
15
+
16
+ [build-system]
17
+ requires = ["poetry-core"]
18
+ build-backend = "poetry.core.masonry.api"