titans-pytorch 0.0.46__tar.gz → 0.0.49__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/PKG-INFO +1 -1
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/pyproject.toml +1 -1
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/titans_pytorch/mac_transformer.py +28 -4
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/titans_pytorch/titans.py +14 -10
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/train_mac.py +9 -7
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/.gitignore +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/LICENSE +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/README.md +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/data/README.md +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/fig1.png +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/fig2.png +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/titans_pytorch/mac_transformer.py +28 -4

@@ -95,7 +95,7 @@ class SegmentedAttention(Module):
         dim_head = 64,
         heads = 8,
         accept_value_residual = False,
-        attend_kwargs: dict = dict()
+        attend_kwargs: dict = dict(),
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -201,6 +201,7 @@ class MemoryAsContextTransformer(Module):
         num_residual_streams = 4,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
+        aux_kv_recon_loss_weight = 0.
     ):
         super().__init__()

@@ -276,10 +277,18 @@ class MemoryAsContextTransformer(Module):

         self.to_logits = LinearNoBias(dim, num_tokens)

+        # auxiliary loss on kv recon
+
+        self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
+        self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
+
+        self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
     def forward(
         self,
         x,
-        return_loss = False
+        return_loss = False,
+        return_loss_breakdown = False
     ):

         if return_loss:
@@ -317,6 +326,10 @@ class MemoryAsContextTransformer(Module):

         value_residual = None

+        # aux losses
+
+        kv_recon_losses = self.zero
+
         # expand and reduce streams for hyper connections

         x = self.expand_streams(x)
@@ -324,7 +337,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-                x = maybe_neural_mem(x)
+                x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + aux_kv_loss

             x, values = attn(x, value_residual = value_residual)

@@ -351,4 +365,14 @@ class MemoryAsContextTransformer(Module):
         if not return_loss:
             return logits

-        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
+        ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
+
+        losses = ar_loss
+
+        if self.has_aux_kv_recon_loss:
+            losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
+
+        if not return_loss_breakdown:
+            return losses
+
+        return losses, (ar_loss, kv_recon_losses)
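
Taken together, the mac_transformer.py changes let MemoryAsContextTransformer accumulate an auxiliary key/value reconstruction loss from each neural memory layer, fold it into the training loss with the new aux_kv_recon_loss_weight, and optionally return the breakdown. A minimal usage sketch; the constructor arguments not shown in this diff (num_tokens, dim, depth, segment_len, num_persist_mem_tokens, num_longterm_mem_tokens) are assumed from the project README and train_mac.py, and the values are illustrative:

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 4,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    neural_memory_layers = (2, 4),
    aux_kv_recon_loss_weight = 0.1          # new in 0.0.49 - default 0. disables the aux loss
)

ids = torch.randint(0, 256, (1, 512))

# combined loss, plus a breakdown into (autoregressive loss, kv reconstruction loss)
loss, (ar_loss, kv_recon_loss) = transformer(
    ids,
    return_loss = True,
    return_loss_breakdown = True            # also new in 0.0.49
)

loss.backward()

With aux_kv_recon_loss_weight left at 0. and return_loss_breakdown omitted, the forward call behaves as it did in 0.0.46.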
{titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/titans_pytorch/titans.py +14 -10

@@ -217,10 +217,10 @@ class NeuralMemory(Module):
         def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-
-            return
+            weighted_loss = loss * loss_weights
+            return weighted_loss.sum(), weighted_loss.mean()

-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))

         # queries for retrieving from the model

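
The mechanical core of this change is switching the per-sample gradient function to grad(..., has_aux = True), so the inner loss function returns a differentiable objective together with an auxiliary value (here a mean reconstruction loss) that is carried through vmap; the diff's own comment notes this is used to backpropagate through the key/value projections of the base neural memory module. A self-contained toy sketch of that torch.func pattern; the linear "memory model" and shapes are illustrative stand-ins, not the package's actual module:

import torch
from torch import nn
from torch.func import functional_call, grad, vmap

memory_model = nn.Linear(4, 4, bias = False)

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(memory_model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)   # per-position mse, stand-in for store_memory_loss_fn
    weighted_loss = loss * loss_weights
    # first output is differentiated, second is passed through untouched as the auxiliary value
    return weighted_loss.sum(), weighted_loss.mean()

# gradients w.r.t. params for every batch element, plus a per-element auxiliary loss
per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))

params = {name: param.detach() for name, param in memory_model.named_parameters()}

keys        = torch.randn(8, 16, 4)   # (batch, seq, dim)
values      = torch.randn(8, 16, 4)
adaptive_lr = torch.rand(8, 16)

grads, aux_recon_loss = per_sample_grad_fn(params, keys, adaptive_lr, values)

print(grads['weight'].shape)    # (8, 4, 4) - one gradient per batch element
print(aux_recon_loss.shape)     # (8,)      - one auxiliary loss per batch element

In the package, store_memories then averages this auxiliary value and hands it up through NeuralMemory.forward when return_aux_kv_loss = True, as the remaining hunks below show.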
@@ -282,7 +282,8 @@ class NeuralMemory(Module):
     def store_memories(
         self,
         seq,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
+        return_aux_kv_loss = False
     ):

         seq = self.store_norm(seq)
@@ -330,7 +331,7 @@ class NeuralMemory(Module):

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)

         grads = TensorDict(grads)

@@ -405,7 +406,10 @@ class NeuralMemory(Module):

         next_state = (curr_weights + last_update, next_momentum)

-        return updates, next_state
+        if not return_aux_kv_loss:
+            return updates, next_state
+
+        return updates, next_state, aux_kv_recon_loss.mean()

     def retrieve_memories(
         self,
@@ -484,7 +488,7 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-
+        return_aux_kv_loss = False
     ):
         batch, seq_len = seq.shape[:2]

@@ -499,13 +503,13 @@ class NeuralMemory(Module):

         store_seq = default(store_seq, seq)

-        updates, next_memories = self.store_memories(store_seq, past_state)
+        updates, next_memories, aux_kv_recon_loss = self.store_memories(store_seq, past_state, return_aux_kv_loss = True)

         past_weights, _ = past_state

         retrieved = self.retrieve_memories(seq, past_weights + updates)

-        if not
+        if not return_aux_kv_loss:
             return retrieved

-        return retrieved,
+        return retrieved, aux_kv_recon_loss
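
At the module level, return_aux_kv_loss is the only API change to NeuralMemory: retrieval is unchanged, and the auxiliary reconstruction loss is returned alongside it only when requested. A sketch, assuming the NeuralMemory constructor arguments (dim, chunk_size) shown in the project README rather than in this diff:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)

retrieved = mem(seq)                                                  # default: same as 0.0.46
retrieved, aux_kv_recon_loss = mem(seq, return_aux_kv_loss = True)    # new in 0.0.49

assert retrieved.shape == seq.shape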
{titans_pytorch-0.0.46 → titans_pytorch-0.0.49}/train_mac.py +9 -7

@@ -24,13 +24,14 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE = False
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
-
+KV_RECON_LOSS_WEIGHT = 0.1
+RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS} - kv recon loss 0.1'

 # wandb experiment tracker

@@ -63,6 +64,7 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_segment_len = WINDOW_SIZE // 2,
+    aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
@@ -108,20 +110,20 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(next(train_loader), return_loss = True)
+        loss, (ar_loss, kv_recon_losses) = model(next(train_loader), return_loss = True, return_loss_breakdown = True)
         loss.backward()

-    print(f'training loss: {loss.item()}')
+    print(f'training loss: {ar_loss.item()}')
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
-    wandb.log(dict(loss = loss.item()))
+    wandb.log(dict(loss = ar_loss.item()))

     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            loss = model(next(val_loader), return_loss = True)
-            print(f'validation loss: {loss.item()}')
+            loss, (ar_loss, _) = model(next(val_loader), return_loss = True, return_loss_breakdown = True)
+            print(f'validation loss: {ar_loss.item()}')

     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()
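
Since the training loop now receives the loss breakdown, the reconstruction term could also be logged next to the autoregressive loss; a hypothetical one-line extension of the loop above, not part of this release:

wandb.log(dict(
    loss = ar_loss.item(),
    kv_recon_loss = kv_recon_losses.item()
))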