titans-pytorch 0.2.11.tar.gz → 0.2.14.tar.gz

This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.11
+Version: 0.2.14
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.11"
+version = "0.2.14"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -284,15 +284,18 @@ class NeuralMemory(Module):
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
-        max_mem_layer_modulation = 1e1, # max of 10.
+        max_mem_layer_modulation = 1., # max of 1.
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
-        post_rmsnorm = True,
+        post_rmsnorm = False,
         qk_rmsnorm = False,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
+        init_adaptive_step_bias = None,
+        init_momentum_bias = None,
+        init_decay_bias = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
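The three new init_*_bias kwargs let a caller pin the starting point of the learned gates. A minimal construction sketch, using the NeuralMemory signature from the project README; the bias values are illustrative, not recommended defaults:

from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    init_adaptive_step_bias = -4.,  # start with small per-token learning rates
    init_momentum_bias = 2.,        # start with strong momentum
    init_decay_bias = -4.           # start with weak weight decay
)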
@@ -411,12 +414,12 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         ) if momentum else None

         self.to_adaptive_step = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n')
         )

@@ -428,7 +431,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            nn.Linear(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
         ) if per_parameter_lr_modulation else None
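The swap from LinearNoBias to nn.Linear across these projections is what enables the bias inits in the next hunk: a bias-free linear has no .bias parameter to set. A quick illustration, assuming LinearNoBias is the usual partial of nn.Linear with bias disabled:

from functools import partial
from torch import nn

LinearNoBias = partial(nn.Linear, bias = False)

assert LinearNoBias(512, 4).bias is None     # nothing for nn.init.constant_ to target
assert nn.Linear(512, 4).bias is not None    # the restored bias term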
@@ -442,10 +445,27 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )

+        # inits
+
+        if exists(init_adaptive_step_bias):
+            linear = self.to_adaptive_step[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_adaptive_step_bias)
+
+        if exists(init_momentum_bias):
+            linear = self.to_momentum[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_momentum_bias)
+
+        if exists(init_decay_bias):
+            linear = self.to_decay_factor[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_decay_bias)
+
         # maybe use accelerated scan

         self.use_accelerated_scan = use_accelerated_scan
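Zeroing the weight while setting a constant bias makes each projection output that constant for every token at initialization, so the sigmoid-activated factors downstream start at a chosen operating point rather than noise. A sketch of the effect (the sigmoid here mirrors how such gating factors are typically consumed; values are illustrative):

import torch
from torch import nn

to_momentum = nn.Linear(512, 4)
nn.init.zeros_(to_momentum.weight)
nn.init.constant_(to_momentum.bias, 2.)

x = torch.randn(1, 8, 512)
momentum = to_momentum(x).sigmoid()
print(momentum.unique())   # tensor([0.8808]) -> momentum begins near 0.88 everywhere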
@@ -34,6 +34,7 @@ NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural mem
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
+NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
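NEURAL_MEM_MAX_LR feeds default_step_transform_max_lr in the next hunk, which caps the per-token learning rate produced by the adaptive step projection. If the default step transform follows the usual sigmoid scaling (an assumption; check the library source), the mapping would look like:

import torch

def default_step_transform(adaptive_step, max_lr = 1e-1):
    # unbounded projection output -> learning rate in (0, max_lr)
    return adaptive_step.sigmoid() * max_lr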
@@ -86,6 +87,7 @@ model = MemoryAsContextTransformer(
     neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
+    default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = MemoryMLP(
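A hypothetical smoke test for the configured model, following the forward-pass pattern the training script itself uses (the return_loss flag and vocab size are assumptions from the surrounding script, not verified API):

import torch

seq = torch.randint(0, 256, (1, 1024))   # dummy token ids
loss = model(seq, return_loss = True)
loss.backward()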