mlable-torch 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlable_torch-0.2.0/.github/README.md +14 -0
- mlable_torch-0.2.0/PKG-INFO +28 -0
- mlable_torch-0.2.0/mlable/__init__.py +0 -0
- mlable_torch-0.2.0/mlable/data.py +7 -0
- mlable_torch-0.2.0/mlable/layers.py +266 -0
- mlable_torch-0.2.0/mlable/optimizers.py +64 -0
- mlable_torch-0.2.0/mlable/sampling.py +17 -0
- mlable_torch-0.2.0/pyproject.toml +18 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# MLable
|
|
2
|
+
|
|
3
|
+
PyTorch libs: layers, blocks, optimizers, etc.
|
|
4
|
+
|
|
5
|
+
## TODO
|
|
6
|
+
|
|
7
|
+
See [TODO](TODO.md).
|
|
8
|
+
|
|
9
|
+
## License
|
|
10
|
+
|
|
11
|
+
Licensed under the [AGPLv3](LICENSE.md).
|
|
12
|
+
|
|
13
|
+
[code-micrograd]: https://github.com/karpathy/micrograd
|
|
14
|
+
[video-karpathy]: https://www.youtube.com/@AndrejKarpathy/videos
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mlable-torch
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: PyTorch libs.
|
|
5
|
+
Author: apehex
|
|
6
|
+
Author-email: apehex@protonmail.com
|
|
7
|
+
Requires-Python: >=3.10,<3.12
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Requires-Dist: torch (>=2.2)
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
|
|
14
|
+
# MLable
|
|
15
|
+
|
|
16
|
+
PyTorch libs: layers, blocks, optimizers, etc.
|
|
17
|
+
|
|
18
|
+
## TODO
|
|
19
|
+
|
|
20
|
+
See [TODO](TODO.md).
|
|
21
|
+
|
|
22
|
+
## License
|
|
23
|
+
|
|
24
|
+
Licensed under the [AGPLv3](LICENSE.md).
|
|
25
|
+
|
|
26
|
+
[code-micrograd]: https://github.com/karpathy/micrograd
|
|
27
|
+
[video-karpathy]: https://www.youtube.com/@AndrejKarpathy/videos
|
|
28
|
+
|
|
File without changes
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
# NORMALIZATION ###############################################################
|
|
6
|
+
|
|
7
|
+
class BatchNorm1d(torch.nn.Module):
|
|
8
|
+
def __init__(self, dim: int, epsilon: float=1e-5, momentum: float=0.1, **kwargs) -> None:
|
|
9
|
+
super(BatchNorm1d, self).__init__(**kwargs)
|
|
10
|
+
self._epsilon = epsilon
|
|
11
|
+
self._momentum = momentum
|
|
12
|
+
# parameters (trained with backprop)
|
|
13
|
+
self._gamma = torch.nn.Parameter(torch.ones(dim), requires_grad=True)
|
|
14
|
+
self._beta = torch.nn.Parameter(torch.zeros(dim), requires_grad=True)
|
|
15
|
+
# buffers (trained with a running 'momentum update')
|
|
16
|
+
self._mean = torch.zeros(dim)
|
|
17
|
+
self._var = torch.ones(dim)
|
|
18
|
+
self.register_buffer("mean", self._mean)
|
|
19
|
+
self.register_buffer("variance", self._var)
|
|
20
|
+
|
|
21
|
+
def forward(self, x: torch.Tensor, training: bool, **kwargs) -> torch.Tensor:
|
|
22
|
+
# current mean
|
|
23
|
+
if training:
|
|
24
|
+
__axes = list(range(x.ndim - 1)) # reduce all axes except the last one
|
|
25
|
+
with torch.no_grad():
|
|
26
|
+
__mean = x.mean(__axes, keepdim=True) # batch mean
|
|
27
|
+
__var = x.var(__axes, keepdim=True) # batch variance
|
|
28
|
+
self._mean = (1. - self._momentum) * self._mean + self._momentum * __mean
|
|
29
|
+
self._var = (1. - self._momentum) * self._var + self._momentum * __var
|
|
30
|
+
# normalize x
|
|
31
|
+
__x = (x - self._mean) / torch.sqrt(self._var + self._epsilon)
|
|
32
|
+
# scale
|
|
33
|
+
return self._gamma * __x + self._beta
|
|
34
|
+
|
|
35
|
+
# ACTIVATION ##################################################################
|
|
36
|
+
|
|
37
|
+
class Tanh(torch.nn.Module):
    """Element-wise hyperbolic tangent; extra keyword arguments are accepted and ignored."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        __activated = x.tanh()
        return __activated
|
|
40
|
+
|
|
41
|
+
class NewGELU(torch.nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    Paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        __scale = math.sqrt(2.0 / math.pi)
        __cubic = x + 0.044715 * torch.pow(x, 3.0)
        __gate = 1.0 + torch.tanh(__scale * __cubic)
        return 0.5 * x * __gate
|
|
45
|
+
|
|
46
|
+
# RESHAPING ###################################################################
|
|
47
|
+
|
|
48
|
+
def _normalize_shape(shape: list) -> list:
|
|
49
|
+
return [-1 if __d is None else __d for __d in shape]
|
|
50
|
+
|
|
51
|
+
def _normalize_dim(dim: int) -> int:
|
|
52
|
+
return -1 if (dim is None or dim < 0) else dim
|
|
53
|
+
|
|
54
|
+
def _multiply_dim(dim_l: int, dim_r: int) -> int:
|
|
55
|
+
return -1 if (dim_l == -1 or dim_r == -1) else dim_l * dim_r
|
|
56
|
+
|
|
57
|
+
def _divide_dim(dim_l: int, dim_r: int) -> int:
|
|
58
|
+
return -1 if (dim_l == -1 or dim_r == -1) else dim_l // dim_r
|
|
59
|
+
|
|
60
|
+
class Divide(torch.nn.Module):
    """Reshape so that one axis shrinks by `factor` and another grows by `factor`.

    Args:
        input_axis: axis whose dimension is divided by `factor`, indexed in the
            OUTPUT shape (i.e. after the optional insertion).
        output_axis: axis whose dimension is multiplied by `factor`, same indexing.
        factor: the ratio applied to both dimensions.
        insert: when True, a new axis of size 1 is first inserted at `output_axis`.
    """

    def __init__(
        self,
        input_axis: int, # relative to the NEW shape / rank
        output_axis: int, # same
        factor: int,
        insert: bool=False,
        **kwargs
    ) -> None:
        super(Divide, self).__init__(**kwargs)
        self._input_axis = input_axis
        self._output_axis = output_axis
        self._factor = factor
        self._insert = insert

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return `inputs.view(...)` with the adjusted shape."""
        __target = _normalize_shape(list(inputs.shape))
        # both axes are resolved against the rank of the OUTPUT tensor
        __rank = len(__target) + (1 if self._insert else 0)
        __src = self._input_axis % __rank
        __dst = self._output_axis % __rank
        # optionally open a new axis of size 1 to receive the data
        if self._insert:
            __target.insert(__dst, 1)
        __target[__src] = _divide_dim(__target[__src], self._factor)
        __target[__dst] = _multiply_dim(__target[__dst], self._factor)
        return inputs.view(*__target)
|
|
89
|
+
|
|
90
|
+
class Merge(torch.nn.Module):
    """Merge two axes into one whose dimension is the product of both (via `view`).

    Args:
        left_axis: first axis to merge.
        right_axis: second axis to merge.
        left: when True the merged axis sits at `left_axis` and `right_axis` is
            removed; when False the merged axis sits at `right_axis` and
            `left_axis` is removed.
    """

    def __init__(self, left_axis: int=-2, right_axis: int=-1, left: bool=True, **kwargs) -> None:
        super(Merge, self).__init__(**kwargs)
        self._left_axis = left_axis
        self._right_axis = right_axis
        self._left = left

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return `inputs.view(...)` with the two axes merged."""
        __target = _normalize_shape(list(inputs.shape))
        __l = self._left_axis % len(__target)
        __r = self._right_axis % len(__target)
        __merged = _multiply_dim(__target[__l], __target[__r])
        if self._left:
            __kept, __dropped = __l, __r
        else:
            __kept, __dropped = __r, __l
        __target[__kept] = __merged
        del __target[__dropped]
        return inputs.view(*__target)
|
|
119
|
+
|
|
120
|
+
# LINEAR ######################################################################
|
|
121
|
+
|
|
122
|
+
class Linear(torch.nn.Module):
    """Affine map `x @ W + b`, with `W` of shape (in_features, out_features).

    `W` is drawn from a standard normal distribution divided by
    `sqrt(in_features)`; the optional bias starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool=True, **kwargs) -> None:
        super(Linear, self).__init__(**kwargs)
        __scale = in_features ** 0.5
        self._weight = torch.nn.Parameter(torch.randn((in_features, out_features)) / __scale, requires_grad=True)
        if bias:
            self._bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=True)
        else:
            self._bias = None

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        __y = x @ self._weight
        if self._bias is None:
            return __y
        return __y + self._bias
|
|
133
|
+
|
|
134
|
+
class Embedding(torch.nn.Module):
    """Learned lookup table mapping integer indices to vectors.

    Args:
        num_embeddings: number of rows in the table.
        embedding_dim: size of each embedding vector.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs) -> None:
        super(Embedding, self).__init__(**kwargs)
        self._depth = num_embeddings # number of rows, kept for compatibility
        self._weight = torch.nn.Parameter(torch.randn((num_embeddings, embedding_dim)), requires_grad=True)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the rows selected by the integer tensor `x`, shape `x.shape + (embedding_dim,)`.

        A direct lookup gives the same values and gradients as multiplying a
        one-hot encoding by the table, without materializing a
        (..., num_embeddings) float tensor; int32 and int64 indices are accepted.
        """
        return torch.nn.functional.embedding(x, self._weight)
|
|
143
|
+
|
|
144
|
+
class PositionalEmbedding(torch.nn.Module):
    """Learned additive bias for every (position, feature) pair of a sequence.

    Each index along the sequence axis gets its own bias vector (unlike a dense
    bias, which is shared by all positions).

    Args:
        time_dim: maximum sequence length (rows of the kernel).
        embed_dim: size of the embedding axis (columns of the kernel).
        input_axis: axis of the inputs that indexes the sequence positions.
        output_axis: axis of the inputs that holds the embedding features.
    """

    def __init__(
        self,
        time_dim: int,
        embed_dim: int,
        input_axis: int=1, # axis of the sequence
        output_axis: int=-1, # axis of the embedding
        **kwargs
    ) -> None:
        super(PositionalEmbedding, self).__init__(**kwargs)
        # weights
        self._input_axis = input_axis
        self._output_axis = output_axis
        self._kernel = torch.nn.Parameter(torch.randn((time_dim, embed_dim)), requires_grad=True)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Add the positional kernel to `inputs`, broadcasting over all other axes.

        Sequences shorter than `time_dim` use the first rows of the kernel.
        """
        __rank = inputs.ndim
        __axis_t = self._input_axis % __rank
        __axis_e = self._output_axis % __rank
        # only the first positions are needed for shorter sequences
        __kernel = self._kernel[:inputs.shape[__axis_t]]
        # the kernel is stored (time, embed): swap it when the embedding axis comes first
        if __axis_t > __axis_e:
            __kernel = __kernel.transpose(0, 1)
        __shape = [(__d if __i in (__axis_t, __axis_e) else 1) for __i, __d in enumerate(inputs.shape)]
        return inputs + __kernel.reshape(__shape)
|
|
165
|
+
|
|
166
|
+
# RECURRENT ###################################################################
|
|
167
|
+
|
|
168
|
+
class RNNCell(torch.nn.Module):
    """Vanilla recurrent cell: h_next = tanh(Linear([x, h])).

    Args:
        embed_dim: size of the last axis of the input `x`.
        state_dim: size of the last axis of the hidden state `h`.
    """

    def __init__(self, embed_dim: int, state_dim: int, **kwargs) -> None:
        super(RNNCell, self).__init__(**kwargs)
        self._weights = Linear(in_features=embed_dim + state_dim, out_features=state_dim)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Return the next hidden state from the input `x` and the current state `h`."""
        __joint = torch.cat([x, h], dim=-1)
        __pre_activation = self._weights(__joint)
        return torch.tanh(__pre_activation)
|
|
176
|
+
|
|
177
|
+
class GRUCell(torch.nn.Module):
    """Gated recurrent unit cell.

        r     = sigmoid(W_r [x, h])          reset gate
        z     = sigmoid(W_z [x, h])          update gate
        h_hat = tanh(W_h [x, r * h])         candidate state
        h'    = (1 - z) * h + z * h_hat

    Args:
        embed_dim: size of the last axis of the input `x`.
        state_dim: size of the last axis of the hidden state `h`.
    """

    def __init__(self, embed_dim: int, state_dim: int, **kwargs) -> None:
        super(GRUCell, self).__init__(**kwargs)
        __fan_in = embed_dim + state_dim
        # update gate, reset gate and candidate projections (created in this order)
        self._xh_to_z = Linear(in_features=__fan_in, out_features=state_dim)
        self._xh_to_r = Linear(in_features=__fan_in, out_features=state_dim)
        self._xh_to_hhat = Linear(in_features=__fan_in, out_features=state_dim)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Return the next hidden state from the input `x` and the current state `h`."""
        __xh = torch.cat([x, h], dim=-1)
        __reset = torch.sigmoid(self._xh_to_r(__xh))
        __update = torch.sigmoid(self._xh_to_z(__xh))
        # the candidate only sees the part of the state let through by the reset gate
        __candidate = torch.tanh(self._xh_to_hhat(torch.cat([x, __reset * h], dim=-1)))
        return (1. - __update) * h + __update * __candidate
|
|
198
|
+
|
|
199
|
+
# ATTENTION ###################################################################
|
|
200
|
+
|
|
201
|
+
class CausalSelfAttention(torch.nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier positions.

    Args:
        time_dim: maximum sequence length (size of the causal mask).
        embed_dim: size of the embedding axis; must be divisible by `num_heads`.
        num_heads: number of attention heads.
    """

    def __init__(self, time_dim: int, embed_dim: int, num_heads: int, **kwargs) -> None:
        super(CausalSelfAttention, self).__init__(**kwargs)
        assert embed_dim % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self._attention = Linear(in_features=embed_dim, out_features=3 * embed_dim)
        # output projection
        self._projection = Linear(in_features=embed_dim, out_features=embed_dim)
        # causal mask: lower triangular (1, 1, T, T) matrix; always read through the
        # registered buffer so that it follows `.to(device)` and `state_dict()`
        self.register_buffer("mask", torch.tril(torch.ones(time_dim, time_dim)).view(1, 1, time_dim, time_dim))
        # save the shape
        self._head_count = num_heads
        self._head_dim = embed_dim # NOTE: full embedding size, used as the split size for q, k, v

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply causal self-attention to `x` of shape (B, T, E), with T <= time_dim."""
        __time = x.shape[1]
        # insert a new axis to group by attention head: (B, T, H, E/H)
        __shape = list(x.shape)
        __shape.insert(2, self._head_count)
        __shape[-1] = __shape[-1] // self._head_count
        # calculate query, key, values for all heads in batch
        __q, __k, __v = self._attention(x).split(self._head_dim, dim=-1)
        # group by head rather than time
        __k = __k.view(*__shape).transpose(1, 2) # (B, H, T, E/H)
        __q = __q.view(*__shape).transpose(1, 2) # (B, H, T, E/H)
        __v = __v.view(*__shape).transpose(1, 2) # (B, H, T, E/H)
        # self-attention
        __w = (__q @ __k.transpose(-2, -1)) * (1.0 / math.sqrt(__shape[-1])) # (B, H, T, T)
        # causal: only attend to past tokens; crop the mask to the actual sequence length
        __w = __w.masked_fill(self.mask[:, :, :__time, :__time] == 0, float('-inf'))
        __w = torch.nn.functional.softmax(__w, dim=-1)
        # values
        __y = __w @ __v # (B, H, T, T) x (B, H, T, E/H) -> (B, H, T, E/H)
        # assemble heads
        __y = __y.transpose(1, 2).contiguous().view(*x.shape) # original shape (B, T, E)
        # output projection
        return self._projection(__y)
|
|
238
|
+
|
|
239
|
+
# BLOCKS ######################################################################
|
|
240
|
+
|
|
241
|
+
class Sequential(torch.nn.Module):
    """Chain of layers applied one after the other.

    Each layer is called as `layer(x=..., training=..., **kwargs)`, so it must
    accept these keyword arguments.

    Args:
        layers: the modules to chain, in order.
    """

    def __init__(self, layers: list, **kwargs) -> None:
        super(Sequential, self).__init__(**kwargs)
        # a ModuleList registers the layers as submodules, so their parameters and
        # buffers are visible to `parameters()`, `state_dict()`, `.to()`, `.train()`...
        self._layers = torch.nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, training: bool=True, **kwargs) -> torch.Tensor:
        __x = x
        # forward
        for __l in self._layers:
            __x = __l(x=__x, training=training, **kwargs)
        # conclude
        return __x
|
|
253
|
+
|
|
254
|
+
class TransformerBlock(torch.nn.Module):
    """Pre-norm transformer decoder block with residual connections.

        x = x + attention(layer_norm_1(x))
        x = x + mlp(layer_norm_2(x))
    where mlp = Linear(E, 4E) -> GELU -> Linear(4E, E).

    Args:
        time_dim: maximum sequence length.
        embed_dim: size of the embedding axis.
        num_heads: number of attention heads.
    """

    def __init__(self, time_dim: int, embed_dim: int, num_heads: int, **kwargs) -> None:
        super(TransformerBlock, self).__init__(**kwargs)
        self._attention_norm = torch.nn.LayerNorm(embed_dim)
        self._attention = CausalSelfAttention(time_dim=time_dim, embed_dim=embed_dim, num_heads=num_heads)
        self._mlp_norm = torch.nn.LayerNorm(embed_dim)
        # the non-linearity sits between the two projections
        self._mlp = torch.nn.Sequential(
            Linear(embed_dim, 4 * embed_dim),
            NewGELU(),
            Linear(4 * embed_dim, embed_dim))

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the block to `x`; extra keyword arguments (e.g. `training`) are ignored."""
        __x = x + self._attention(self._attention_norm(x))
        return __x + self._mlp(self._mlp_norm(__x))
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
import mlable.torch.data as _mtd
|
|
6
|
+
|
|
7
|
+
# LEARNING RATE ###############################################################
|
|
8
|
+
|
|
9
|
+
def learning_rate_waveform(step: int, lr_min: float, lr_max: float, lr_exp: float, rampup: int, sustain: int, steps_per_epoch: int=1024) -> float:
    """Epoch-based learning-rate schedule: linear warm-up, plateau, then exponential decay.

    Args:
        step: global training step; the epoch is `step // steps_per_epoch`.
        lr_min: rate at epoch 0, and the offset added to the decaying term.
        lr_max: rate reached after the warm-up and held during the plateau.
        lr_exp: per-epoch multiplicative factor of the decay phase.
        rampup: number of warm-up epochs.
        sustain: number of plateau epochs.
        steps_per_epoch: number of steps in one epoch.
    """
    __epoch = step // steps_per_epoch
    # linear warm-up from lr_min towards lr_max
    if __epoch < rampup:
        return lr_min + (__epoch * (lr_max - lr_min) / rampup)
    # plateau at lr_max
    if __epoch < rampup + sustain:
        return lr_max
    # exponential decay of the gap between lr_max and lr_min
    return lr_min + (lr_max - lr_min) * lr_exp ** (__epoch - rampup - sustain)
|
|
19
|
+
|
|
20
|
+
# SGD #########################################################################
|
|
21
|
+
|
|
22
|
+
class SGD(torch.optim.Optimizer):
    """Plain stochastic gradient descent with an iteration-dependent learning rate.

    Args:
        params: iterable of tensors (or parameter groups) to optimize; one-shot
            iterables such as `model.parameters()` are supported.
        rate: callable mapping the iteration index (0 on the first `step`) to a
            learning rate; defaults to a `learning_rate_waveform` schedule.
    """

    def __init__(self, params: list, rate: callable=None, **kwargs) -> None:
        if rate is None:
            rate = functools.partial(learning_rate_waveform, lr_min=0.00001, lr_max=0.0001, lr_exp=0.8, rampup=4, sustain=2, steps_per_epoch=1024)
        # the base class consumes `params` (possibly a generator) and stores the
        # tensors in `self.param_groups`: they must be read from there afterwards
        super(SGD, self).__init__(params, {'rate': rate}, **kwargs)
        self._parameters = [__p for __g in self.param_groups for __p in __g['params']]
        self._rate = rate
        self._iteration = -1

    def step(self, closure: callable=None):
        """Update every parameter in place: p <- p - rate(iteration) * p.grad.

        Parameters without a gradient are skipped. When given, `closure` is
        evaluated with gradients enabled and its result is returned.
        """
        __loss = None
        if closure is not None:
            with torch.enable_grad():
                __loss = closure()
        self._iteration += 1
        with torch.no_grad():
            for __g in self.param_groups:
                __lr = __g['rate'](self._iteration)
                for __p in __g['params']:
                    if __p.grad is not None:
                        __p += -__lr * __p.grad
        return __loss
|
|
35
|
+
|
|
36
|
+
# GENERIC #####################################################################
|
|
37
|
+
|
|
38
|
+
def step(model: torch.nn.Module, loss: callable, optimizer: torch.optim.Optimizer, x: torch.Tensor, y: torch.Tensor, epoch: int) -> torch.Tensor:
    """Run one optimization step on the batch (x, y) and return the loss tensor.

    The model is called as `model(x=x, training=True)` and the loss as
    `loss(input=prediction, target=y)`. `epoch` is accepted but not used.
    """
    __prediction = model(x=x, training=True)
    __value = loss(input=__prediction, target=y)
    # clear the gradients of the previous step before accumulating new ones
    model.zero_grad(set_to_none=True)
    __value.backward()
    optimizer.step()
    return __value
|
|
48
|
+
|
|
49
|
+
def train(model:torch.nn.Module, loss: callable, optimizer: torch.optim.Optimizer, x: torch.Tensor, y: torch.Tensor, n_epoch: int, n_batch: int) -> None:
    """Train `model` for `n_epoch` epochs of `len(x) // n_batch` random batches each.

    The loss of the first batch of every epoch is printed.
    """
    # NOTE(review): `_mtd` is imported at the top of this module as
    # `mlable.torch.data`, but the package ships `mlable/data.py` — verify the path.
    __batches_per_epoch = int(x.shape[0]) // n_batch
    for __epoch in range(n_epoch):
        for __batch in range(__batches_per_epoch):
            # random batch
            __x, __y = _mtd.batch(x=x, y=y, size=n_batch)
            __loss = step(model=model, loss=loss, optimizer=optimizer, x=__x, y=__y, epoch=__epoch)
            # only the first batch of each epoch is logged
            if __batch == 0:
                print('[epoch {epoch}] train loss: {train}'.format(epoch=__epoch, train=__loss.item()))
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
# NGRAMS ######################################################################
|
|
4
|
+
|
|
5
|
+
def _next(model: torch.nn.Module, ngram: list) -> int:
|
|
6
|
+
__logits = model(torch.tensor([ngram]), training=False)
|
|
7
|
+
__probs = torch.nn.functional.softmax(__logits, dim=-1)
|
|
8
|
+
return torch.multinomial(__probs, num_samples=1).item()
|
|
9
|
+
|
|
10
|
+
def sample(model: torch.nn.Module, context: int, length: int) -> list[int]:
    """Generate `length` indices autoregressively with `_next`.

    The context window starts as `context` zeros. After each draw, the new
    index is appended and the oldest one is dropped.

    Returns:
        The list of sampled indices (ints), in generation order.
    """
    __result = []
    __ngram = context * [0]
    for _ in range(length):
        __n = _next(model=model, ngram=__ngram)
        __result.append(__n)
        # slide the window by one position
        __ngram = __ngram[1:] + [__n]
    return __result
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "mlable-torch"
|
|
3
|
+
version = "0.2.0"
|
|
4
|
+
description = "PyTorch libs."
|
|
5
|
+
authors = ["apehex <apehex@protonmail.com>"]
|
|
6
|
+
readme = ".github/README.md"
|
|
7
|
+
packages = [{include = "mlable"}]
|
|
8
|
+
|
|
9
|
+
[tool.poetry.dependencies]
|
|
10
|
+
python = ">=3.10, <3.12"
|
|
11
|
+
torch = ">=2.2"
|
|
12
|
+
|
|
13
|
+
[tool.poetry.group.dev.dependencies]
|
|
14
|
+
pytest = "*"
|
|
15
|
+
|
|
16
|
+
[build-system]
|
|
17
|
+
requires = ["poetry-core"]
|
|
18
|
+
build-backend = "poetry.core.masonry.api"
|