titans-pytorch 0.2.11__tar.gz → 0.2.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/PKG-INFO +1 -1
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/pyproject.toml +1 -1
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/titans_pytorch/neural_memory.py +26 -6
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/train_mac.py +2 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/.gitignore +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/LICENSE +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/README.md +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/data/README.md +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/fig1.png +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/fig2.png +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/tests/test_titans.py +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.2.11 → titans_pytorch-0.2.14}/titans_pytorch/memory_models.py +0 -0
titans_pytorch/neural_memory.py

```diff
@@ -284,15 +284,18 @@ class NeuralMemory(Module):
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
-        max_mem_layer_modulation =
+        max_mem_layer_modulation = 1., # max of 10.
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
-        post_rmsnorm =
+        post_rmsnorm = False,
         qk_rmsnorm = False,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
+        init_adaptive_step_bias = None,
+        init_momentum_bias = None,
+        init_decay_bias = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
```
```diff
@@ -411,12 +414,12 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         ) if momentum else None

         self.to_adaptive_step = Sequential(
-
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n')
         )

```
```diff
@@ -428,7 +431,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-
+            nn.Linear(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
         ) if per_parameter_lr_modulation else None
```
```diff
@@ -442,10 +445,27 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )

+        # inits
+
+        if exists(init_adaptive_step_bias):
+            linear = self.to_adaptive_step[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_adaptive_step_bias)
+
+        if exists(init_momentum_bias):
+            linear = self.to_momentum[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_momentum_bias)
+
+        if exists(init_decay_bias):
+            linear = self.to_decay_factor[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_decay_bias)
+
         # maybe use accelerated scan

         self.use_accelerated_scan = use_accelerated_scan
```
train_mac.py

```diff
@@ -34,6 +34,7 @@ NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural mem
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
+NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
```
```diff
@@ -86,6 +87,7 @@ model = MemoryAsContextTransformer(
     neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
+    default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = MemoryMLP(
```