titans-pytorch 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +23 -14
- {titans_pytorch-0.1.12.dist-info → titans_pytorch-0.1.14.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.14.dist-info/RECORD +8 -0
- titans_pytorch-0.1.12.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.12.dist-info → titans_pytorch-0.1.14.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.12.dist-info → titans_pytorch-0.1.14.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -323,6 +323,7 @@ class NeuralMemory(Module):
|
|
|
323
323
|
per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
|
|
324
324
|
max_mem_layer_modulation = 1e1, # max of 10.
|
|
325
325
|
attn_pool_chunks = False,
|
|
326
|
+
momentum = True,
|
|
326
327
|
pre_rmsnorm = True,
|
|
327
328
|
post_rmsnorm = True,
|
|
328
329
|
learned_mem_model_weights = True,
|
|
@@ -423,7 +424,7 @@ class NeuralMemory(Module):
|
|
|
423
424
|
self.to_momentum = Sequential(
|
|
424
425
|
LinearNoBias(dim, heads),
|
|
425
426
|
Rearrange('b n h -> (b h) n 1')
|
|
426
|
-
)
|
|
427
|
+
) if momentum else None
|
|
427
428
|
|
|
428
429
|
self.to_adaptive_step = Sequential(
|
|
429
430
|
LinearNoBias(dim, heads),
|
|
@@ -462,12 +463,15 @@ class NeuralMemory(Module):
|
|
|
462
463
|
|
|
463
464
|
self.register_buffer('zero', torch.tensor(0.), persistent = False)
|
|
464
465
|
|
|
465
|
-
def init_weights_and_momentum(self):
|
|
466
|
+
def init_weights_and_momentum(self, zero_weights = False):
|
|
466
467
|
params = TensorDict(dict(self.memory_model.named_parameters()))
|
|
467
468
|
|
|
468
|
-
init_weights = params
|
|
469
|
+
init_weights = params
|
|
469
470
|
init_momentum = params.clone().zero_()
|
|
470
471
|
|
|
472
|
+
if zero_weights:
|
|
473
|
+
init_weights = params.clone().zero_()
|
|
474
|
+
|
|
471
475
|
return init_weights, init_momentum
|
|
472
476
|
|
|
473
477
|
def init_empty_memory_embed(self, batch, seq_len):
|
|
@@ -497,14 +501,10 @@ class NeuralMemory(Module):
|
|
|
497
501
|
|
|
498
502
|
seq = seq[:, :round_down_seq_len]
|
|
499
503
|
|
|
500
|
-
#
|
|
501
|
-
|
|
502
|
-
curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
|
|
504
|
+
# get the weights of the memory network
|
|
503
505
|
|
|
504
506
|
past_state = tuple(TensorDict(d) for d in past_state)
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
curr_weights = curr_weights + past_weights
|
|
507
|
+
curr_weights, past_momentum = past_state
|
|
508
508
|
|
|
509
509
|
# derive learned hparams for optimization of memory network
|
|
510
510
|
|
|
@@ -513,10 +513,13 @@ class NeuralMemory(Module):
|
|
|
513
513
|
|
|
514
514
|
chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
|
|
515
515
|
|
|
516
|
-
adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
|
|
517
516
|
decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
|
|
518
517
|
|
|
519
518
|
need_layer_lr_mod = exists(self.to_layer_modulation)
|
|
519
|
+
has_momentum = exists(self.to_momentum)
|
|
520
|
+
|
|
521
|
+
if has_momentum:
|
|
522
|
+
adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
|
|
520
523
|
|
|
521
524
|
if need_layer_lr_mod:
|
|
522
525
|
layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
|
|
@@ -595,23 +598,29 @@ class NeuralMemory(Module):
|
|
|
595
598
|
|
|
596
599
|
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
|
|
597
600
|
|
|
598
|
-
next_momentum = TensorDict()
|
|
601
|
+
next_momentum = TensorDict() if has_momentum else None
|
|
599
602
|
updates = TensorDict()
|
|
600
603
|
|
|
601
604
|
for param_name, surprise in surprises.items():
|
|
602
605
|
|
|
603
606
|
surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
|
|
604
607
|
|
|
608
|
+
update = surprise
|
|
609
|
+
|
|
605
610
|
# derive momentum with associative scan - eq (10)
|
|
606
611
|
|
|
607
|
-
|
|
612
|
+
if has_momentum:
|
|
613
|
+
update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
|
|
614
|
+
momentum = update
|
|
608
615
|
|
|
609
616
|
# use associative scan again for learned forgetting (weight decay) - eq (13)
|
|
610
617
|
|
|
611
|
-
update = scan_fn(1. - decay_factor,
|
|
618
|
+
update = scan_fn(1. - decay_factor, update)
|
|
612
619
|
|
|
613
620
|
updates[param_name] = inverse_pack(update)
|
|
614
|
-
|
|
621
|
+
|
|
622
|
+
if has_momentum:
|
|
623
|
+
next_momentum[param_name] = inverse_pack(momentum)
|
|
615
624
|
|
|
616
625
|
# compute the next weight per batch
|
|
617
626
|
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
|
|
4
|
+
titans_pytorch/titans.py,sha256=J7UbhhL0YKoQjrKPCgwUYwJd8-YfCZrHTSVjIFvvRRw,22077
|
|
5
|
+
titans_pytorch-0.1.14.dist-info/METADATA,sha256=lM_x9obTibKKMXXdsW0U5GcQLTR4aXCtemEKPzYPqYo,6340
|
|
6
|
+
titans_pytorch-0.1.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
titans_pytorch-0.1.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
titans_pytorch-0.1.14.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=zxknstaI_Uz47Y8WvZ3S7geJ-TNdqKV5Rvj0Jlw8njs,19271
|
|
4
|
-
titans_pytorch/titans.py,sha256=eDTqAIDZjSLd34t8M-dCaqVf_s0wZ9jhVIOfXF7E9ts,21887
|
|
5
|
-
titans_pytorch-0.1.12.dist-info/METADATA,sha256=dL8HpHt6V5gN8p8px7sc2IgJGqXthE7rULKIrRFCwF8,6340
|
|
6
|
-
titans_pytorch-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
titans_pytorch-0.1.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
titans_pytorch-0.1.12.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|