titans-pytorch 0.0.1__tar.gz → 0.0.8__tar.gz
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/PKG-INFO +12 -7
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/README.md +9 -6
- titans_pytorch-0.0.8/data/README.md +3 -0
- titans_pytorch-0.0.8/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/pyproject.toml +5 -2
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/titans_pytorch/titans.py +91 -29
- titans_pytorch-0.0.8/train.py +132 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/.gitignore +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/LICENSE +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/fig1.png +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/fig2.png +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.1 → titans_pytorch-0.0.8}/titans_pytorch/associative_scan.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.1
+Version: 0.0.8
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,6 +39,8 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: tensordict>=0.6.2
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
+Requires-Dist: local-attention>=1.10.0; extra == 'examples'
+Requires-Dist: taylor-series-linear-attention; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
@@ -49,7 +51,7 @@ Description-Content-Type: text/markdown
 
 ## Titans - Pytorch (wip)
 
-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
 ## Install
 
@@ -63,13 +65,16 @@ $ pip install titans-pytorch
 import torch
 from titans_pytorch import NeuralMemory
 
-
+mem = NeuralMemory(
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
 
-
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
 
-
-
-assert x.shape == out.shape
+assert seq.shape == retrieved.shape
 ```
 
 ## Citations
README.md

@@ -4,7 +4,7 @@
 
 ## Titans - Pytorch (wip)
 
-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
 ## Install
 
@@ -18,13 +18,16 @@ $ pip install titans-pytorch
 import torch
 from titans_pytorch import NeuralMemory
 
-
+mem = NeuralMemory(
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
 
-
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
 
-
-
-assert x.shape == out.shape
+assert seq.shape == retrieved.shape
 ```
 
 ## Citations
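For reference, the updated usage snippet also appears to work on CPU if the `.cuda()` calls are dropped; that is an untested assumption based on this diff, which contains nothing CUDA-specific. With `chunk_size = 64`, the 1024-token sequence divides evenly into 16 chunks, which is the intended way to size inputs.

```python
# a minimal sketch of the new README usage, assuming titans-pytorch 0.0.8 is installed;
# the .cuda() calls from the README are dropped here so it can run without a GPU
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,        # feature dimension of the sequence
    chunk_size = 64,  # tokens within a chunk share one adaptive step / momentum / decay
    pre_rmsnorm = True
)

seq = torch.randn(2, 1024, 384)   # 1024 is a multiple of chunk_size
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```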
data/enwik8.gz

Binary file (no diff shown).
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.1"
+version = "0.0.8"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,10 @@ Homepage = "https://pypi.org/project/titans-pytorch/"
 Repository = "https://github.com/lucidrains/titans-pytorch"
 
 [project.optional-dependencies]
-examples = []
+examples = [
+    "local-attention>=1.10.0",
+    "taylor-series-linear-attention"
+]
 test = [
     "pytest"
 ]
titans_pytorch/titans.py

@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from functools import partial
 
 import torch
@@ -11,12 +12,13 @@ from tensordict import TensorDict
 
 from titans_pytorch.associative_scan import (
     associative_scan,
-    binary_operator
+    binary_operator,
+    pad_at_dim
 )
 
 import einx
 from einops import rearrange, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
@@ -41,6 +43,9 @@ def default(v, d):
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
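`round_up_multiple` complements the existing `round_down_multiple`: storing truncates the sequence down to whole chunks, while retrieval (later in this diff) pads it up to whole chunks. A worked check of the arithmetic, copying both helpers from the diff:

```python
import math

def round_down_multiple(seq, mult):
    return seq // mult * mult

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

# with chunk_size = 64: storing keeps only complete chunks, retrieval pads up to the next chunk boundary
assert round_down_multiple(1000, 64) == 960   # 15 full chunks kept when storing
assert round_up_multiple(1000, 64) == 1024    # padded out to 16 chunks when retrieving
```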
@@ -50,6 +55,10 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def softclamp_max(t, max_value):
+    range_value = max_value / 2
+    return ((t / range_value).tanh() * range_value) + range_value
+
 # classes
 
 class MLP(Module):
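`softclamp_max` squashes any real input smoothly into the open interval (0, max_value), with an input of 0 mapping to max_value / 2; later in this diff it bounds the learned adaptive step size by `max_adaptive_step_size` (default 1e-5). A small check, using max_value = 1.0 for readability:

```python
import torch

def softclamp_max(t, max_value):
    range_value = max_value / 2
    return ((t / range_value).tanh() * range_value) + range_value

t = torch.tensor([-3., 0., 3.])
out = softclamp_max(t, 1.0)

# outputs stay inside (0, max_value); an input of 0 lands at max_value / 2
assert torch.all(out > 0) and torch.all(out < 1.0)
assert torch.isclose(out[1], torch.tensor(0.5))
```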
@@ -84,11 +93,17 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
+        chunk_size = 1,
         model: Module | None = None,
-        store_memory_loss_fn: Callable = default_loss_fn
+        store_memory_loss_fn: Callable = default_loss_fn,
+        pre_rmsnorm = False,
+        max_adaptive_step_size = 1e-5
     ):
         super().__init__()
 
+        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
         if not exists(model):
             model = MLP(dim, depth = 4)
 
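`pre_rmsnorm` simply toggles an RMSNorm on the inputs to the store and retrieve paths; with it off, the norm is a no-op `nn.Identity`. A minimal sketch of the toggle, outside the class:

```python
import torch
from torch import nn

dim, pre_rmsnorm = 384, True

# same pattern as the constructor: normalize the incoming sequence only when asked to
store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

x = torch.randn(2, 16, dim)
assert store_norm(x).shape == x.shape   # shape is preserved either way
```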
@@ -98,11 +113,15 @@ class NeuralMemory(Module):
 
         self.memory_model = model
 
+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
         # prepare function for per sample gradients from model above, using torch.func
 
         def forward_and_loss(params, inputs, target):
             pred = functional_call(self.memory_model, params, inputs)
-            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k)
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss
 
         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
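The stored `chunk_size` aside, the heart of this hunk is the standard torch.func recipe for per-sample (here: per-chunk) gradients. A self-contained sketch, with a plain `nn.Linear` standing in for the memory MLP and mean squared error standing in for `default_loss_fn` (both stand-ins are assumptions for illustration only):

```python
import torch
from torch import nn
from torch.func import functional_call, vmap, grad_and_value

model = nn.Linear(4, 4, bias = False)   # hypothetical stand-in for the memory MLP

def forward_and_loss(params, inputs, target):
    pred = functional_call(model, params, inputs)
    return ((pred - target) ** 2).mean()   # stand-in for default_loss_fn

# gradients w.r.t. params (dim None), mapped over the leading batch dim of inputs and targets
per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

params = {k: v.detach() for k, v in model.named_parameters()}
keys = torch.randn(8, 16, 4)     # (batch of chunks, chunk length, dim)
values = torch.randn(8, 16, 4)

grads, losses = per_sample_grad_and_value_fn(params, keys, values)

assert grads['weight'].shape == (8, 4, 4)   # one gradient per chunk
assert losses.shape == (8,)                 # one loss per chunk
```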
@@ -119,9 +138,25 @@
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
-        self.to_momentum =
-
-
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+        self.max_adaptive_step_size = max_adaptive_step_size
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )
 
     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
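Each of these three heads mean-pools the tokens within a chunk before projecting to a single scalar, so every token in a chunk shares one learned momentum gate, step size, and decay factor, matching the constructor comment above. The pooling step in isolation:

```python
import torch
from einops.layers.torch import Reduce

chunk_size = 4
pool = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)   # same pattern as the diff

tokens = torch.randn(2, 12, 8)       # batch 2, 12 tokens, dim 8 -> 3 chunks of 4 tokens
per_chunk = pool(tokens)             # mean over each chunk

assert per_chunk.shape == (2, 3, 8)  # one pooled vector per chunk, later projected to a scalar
```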
@@ -137,6 +172,18 @@
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):
 
+        seq = self.store_norm(seq)
+
+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len, chunk_size = seq.shape[-2], self.chunk_size
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
 
         past_state = tuple(TensorDict(d) for d in past_state)
@@ -148,16 +195,19 @@
 
         batch = seq.shape[0]
 
-        adaptive_lr = self.to_adaptive_step(seq)
-        adaptive_momentum = self.to_momentum(seq)
+        adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)
 
-
+        adaptive_momentum = self.to_momentum(seq).sigmoid()
+        decay_factor = self.to_decay_factor(seq).sigmoid()
 
         # keys and values
 
-        seq = rearrange(seq, 'b n d -> (b n) d')
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
 
+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
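The new rearrange folds each chunk into the leading batch dimension, so the vmapped per-sample gradient above now treats every chunk as an independent "sample" rather than every token. The shape effect, in isolation:

```python
import torch
from einops import rearrange

batch, seq_len, dim, chunk_size = 2, 8, 4, 4

keys = torch.randn(batch, seq_len, dim)

# same pattern as the diff: each chunk becomes its own entry along the batch axis
keys = rearrange(keys, 'b (n c) d -> (b n) c d', c = chunk_size)

assert keys.shape == (batch * (seq_len // chunk_size), chunk_size, dim)   # (4, 4, 4)
```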
@@ -172,31 +222,24 @@
 
         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
 
-        #
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict()
+        updates = TensorDict()
 
         for param_name, surprise in surprises.items():
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
 
-
-
-            momentum = inverse_pack(momentum)
-
-            next_momentum[param_name] = momentum
-
-            # use associative scan again for learned forgetting (weight decay) - eq (13)
+            # derive momentum with associative scan - eq (10)
 
-
+            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper
 
-
-            momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+            # use associative scan again for learned forgetting (weight decay) - eq (13)
 
             _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # momentum is S / surprise in the paper
 
-
-
-            updates[param_name] = update
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)
 
         # compute the next weight per batch
 
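Both scans in this hunk compute a gated running sum. Assuming `binary_operator` implements the usual combine for a first-order linear recurrence (an assumption; only its name appears in this diff), the first scan realizes eq (10), S_t = momentum_t * S_{t-1} + surprise_t, and the second realizes eq (13) with gate (1 - decay_factor_t). A reference loop, purely as a sketch of what the parallel scan computes (zero initial state is also an assumption):

```python
import torch

def linear_recurrence(gates, inputs):
    # sketch of h_t = gate_t * h_{t-1} + input_t, which the associative scan parallelizes
    h = torch.zeros_like(inputs[:, 0])
    out = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        out.append(h)
    return torch.stack(out, dim = 1)

# eq (10): momentum accumulates surprise through a learned, per-chunk gate
adaptive_momentum = torch.rand(2, 6, 1)
surprise = torch.randn(2, 6, 3)
momentum = linear_recurrence(adaptive_momentum, surprise)

# eq (13): updates decay through a learned forgetting gate (1 - decay_factor)
decay_factor = torch.rand(2, 6, 1)
update = linear_recurrence(1. - decay_factor, momentum)

assert update.shape == surprise.shape
```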
@@ -204,14 +247,28 @@
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state, aux_store_loss.mean()
+        return updates, next_state, aux_store_loss.mean() / chunk_size
 
     def retrieve_memories(
         self,
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-
+        chunk_size = self.chunk_size
+        batch, seq_len = seq.shape[:2]
+
+        seq = self.retrieve_norm(seq)
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
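Retrieval first drops the leading chunk_size - 1 tokens (a token should only read memories written by earlier, complete chunks), then pads the remainder back up to a chunk multiple so it can be reshaped per chunk. The same arithmetic, mirrored outside the class:

```python
import math

chunk_size, seq_len = 64, 100

curtailed_seq_len = seq_len - (chunk_size - 1)                            # 37
next_seq_len = math.ceil(curtailed_seq_len / chunk_size) * chunk_size     # round_up_multiple -> 64
padding = next_seq_len - curtailed_seq_len                                # 27

assert (curtailed_seq_len, next_seq_len, padding) == (37, 64, 27)
```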
@@ -231,7 +288,7 @@
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b n d -> (b n) d')
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
 
         # forward functional call
 
@@ -239,7 +296,12 @@
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) d -> b n d', b = batch)
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]
 
         return values
 
train.py

@@ -0,0 +1,132 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from local_attention import LocalTransformer
+
+from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
+from titans_pytorch.titans import NeuralMemory
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
+SEQ_LEN = 512
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+titans_neural_memory = NeuralMemory(
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+)
+
+titans_neural_memory = nn.Sequential(
+    titans_neural_memory,
+    nn.RMSNorm(384)
+)
+
+linear_attn = TaylorSeriesLinearAttn(
+    dim = 384,
+    dim_head = 16,
+    heads = 16,
+    causal = True
+)
+
+model = LocalTransformer(
+    num_tokens = 256,
+    dim = 384,
+    depth = 8,
+    causal = True,
+    local_attn_window_size = 64,
+    max_seq_len = SEQ_LEN,
+    global_attn_layer = titans_neural_memory,
+    layers_insert_global_attn = (4,)
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
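Running train.py presumably requires the new `examples` extras declared in pyproject.toml (local-attention, taylor-series-linear-attention) plus the bundled data/enwik8.gz and a CUDA device. The data preparation reads the first 95 MB of enwik8 and splits it into 90 MB of training bytes and 5 MB of validation bytes; the split arithmetic, isolated with a stand-in buffer:

```python
import numpy as np

# stand-in zeros instead of the real decompressed enwik8 bytes, just to show the split
data = np.zeros(int(95e6), dtype = np.uint8)
data_train, data_val = np.split(data, [int(90e6)])

assert len(data_train) == 90_000_000
assert len(data_val) == 5_000_000
```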
The remaining files are unchanged: .github/workflows/python-publish.yml, .gitignore, LICENSE, fig1.png, fig2.png, titans_pytorch/__init__.py, titans_pytorch/associative_scan.py.