titans-pytorch 0.2.4__tar.gz → 0.2.5__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.4
+Version: 0.2.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.4"
+version = "0.2.5"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0.,
+        aux_kv_recon_loss_weight = 1.,
         use_flex_attn = False,
         sliding_window_attn = False,
         weight_tie_memory_model = False,
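
The notable change in this hunk is the default for `aux_kv_recon_loss_weight` going from 0. to 1., so the auxiliary key/value reconstruction loss now contributes to training out of the box. A minimal sketch of how a weight like this typically folds into the objective; the loss names below are illustrative stand-ins, not the library's internals:

    import torch

    # illustrative tensors standing in for the two training losses
    ar_loss = torch.tensor(2.31)        # main autoregressive cross entropy
    kv_recon_loss = torch.tensor(0.87)  # auxiliary key/value reconstruction term

    aux_kv_recon_loss_weight = 1.       # new default in 0.2.5 (previously 0.)

    # with the old default of 0. the auxiliary term vanished from the total;
    # at 1. it contributes fully to the gradient
    total_loss = ar_loss + aux_kv_recon_loss_weight * kv_recon_loss
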
@@ -1,5 +1,5 @@
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
@@ -304,13 +304,26 @@ class NeuralMemory(Module):
             nn.Sigmoid()
         ) if heads > 1 else None
 
-        # memory mlp
+        # memory model
 
         if not exists(model):
            model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        # validate memory model
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
+        test_shape = (3, 2, dim_head)
+
+        with torch.no_grad():
+            try:
+                test_input = torch.randn(test_shape)
+                mem_model_output = model(test_input)
+            except:
+                raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
+
+            assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
+
         # the memory is the weights of the model
 
         self.memory_model = model
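
The larger addition here is the validation block: any memory model, custom or default, is now probed with a dummy `(3, 2, dim_head)` tensor under `torch.no_grad()` and must return an output of the same shape while holding no buffers. A sketch of a custom memory model that satisfies this contract; the class name and internals are illustrative, only the shape and no-buffer requirements come from the diff:

    import torch
    from torch import nn

    class ResidualMemoryMLP(nn.Module):
        # illustrative custom memory model: any Module works as long as it
        # maps (..., dim_head) -> (..., dim_head) and registers no buffers
        def __init__(self, dim, expansion = 2):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, dim * expansion),
                nn.SiLU(),
                nn.Linear(dim * expansion, dim)
            )

        def forward(self, x):
            return x + self.net(x)

    dim_head = 64
    model = ResidualMemoryMLP(dim_head)

    # mirrors the validation added in this release
    test_shape = (3, 2, dim_head)
    with torch.no_grad():
        out = model(torch.randn(test_shape))
    assert out.shape == test_shape
    assert next(model.buffers(), None) is None  # no buffers allowed
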
@@ -30,17 +30,18 @@ SEQ_LEN = 512
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory, can add more
+NEURAL_MEM_LAYERS = (2, 4, 6)                  # layers 2, 4, 6 have neural memory, can add more
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
-NEURAL_MEM_QK_NORM = False
+NEURAL_MEM_QK_NORM = True
 WINDOW_SIZE = 32
-NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
+NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2      # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
-WEIGHT_TIE_MEMORY_MODEL = True # set to have memory MLP shared across layers
-STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
+WEIGHT_TIE_MEMORY_MODEL = False                # set to have memory MLP shared across layers
+PREV_MEM_UPDATE_FOR_WEIGHTS = True
+STORE_ATTN_POOL_CHUNKS = True                  # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
-KV_RECON_LOSS_WEIGHT = 0.
+KV_RECON_LOSS_WEIGHT = 1.
 
 # experiment related
 
@@ -90,6 +91,7 @@ model = MemoryAsContextTransformer(
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
+    prev_neural_mem_update_for_weights = PREV_MEM_UPDATE_FOR_WEIGHTS,
     neural_memory_model = MemoryMLP(
         dim = 64,
         depth = NEURAL_MEMORY_DEPTH
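
The training-script hunks thread the new constant through to the constructor. A hedged sketch of the resulting instantiation; the `num_tokens`, `dim`, `depth`, and `segment_len` values here are assumptions, since the diff only shows the tail of the call:

    from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

    model = MemoryAsContextTransformer(
        num_tokens = 256,    # assumed, not shown in this diff
        dim = 384,           # assumed, not shown in this diff
        depth = 8,           # assumed, not shown in this diff
        segment_len = 32,    # WINDOW_SIZE in the training script
        neural_memory_layers = (2, 4, 6),
        aux_kv_recon_loss_weight = 1.,                # new default in 0.2.5
        weight_tie_memory_model = False,
        prev_neural_mem_update_for_weights = True,    # flag surfaced in 0.2.5
        neural_memory_model = MemoryMLP(
            dim = 64,
            depth = 2
        )
    )
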