titans-pytorch 0.0.50__py3-none-any.whl → 0.0.51__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -7,16 +7,48 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
+def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+    def create_mac_mask(b, h, q_idx, kv_idx):
+        is_persist_mem = kv_idx < persist_mem_len
+        causal_mask = q_idx >= (kv_idx - is_persist_mem)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+    block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+    return block_mask
+
+# einstein notation related
+
 from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
-from hyper_connections import get_init_and_expand_reduce_stream_functions
+# b - batch
+# n - sequence
+# h - heads
+# d - feature dimension
 
 # absolute and relative positions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+
+# hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
 from x_transformers.attend import Attend
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
titans_pytorch/titans.py CHANGED
@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+
+    if len(modules) == 0:
+        return nn.Identity()
+
+    if len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 # softclamping gradients
 
 def softclamp_max(t, max_value):
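The new `Sequential` helper filters out `None` entries before wrapping, so an optional activation can be passed straight through without special-casing. A small self-contained sketch of that behaviour follows; the `exists` helper is restated here under the assumption that it matches the usual `None` check used elsewhere in the file.

```python
import torch.nn as nn

def exists(v):
    # assumed to match the codebase's helper: True when the value is not None
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()      # nothing supplied: behave as a no-op

    if len(modules) == 1:
        return modules[0]         # single module: return it unwrapped

    return nn.Sequential(*modules)

print(Sequential(nn.Linear(4, 8), nn.SiLU()))  # nn.Sequential of both modules
print(Sequential(nn.Linear(4, 8), None))       # just the Linear, None filtered away
print(Sequential(None, None))                  # nn.Identity()
```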
@@ -124,9 +135,6 @@ class MemoryAttention(Module):
         ])
 
     def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
         wq, wk, wv, ffw1, ffw2 = self.weights
 
         q = F.normalize(x @ wq, dim = -1)
@@ -168,6 +176,7 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
+        activation: Module | None = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -225,11 +234,11 @@ class NeuralMemory(Module):
 
         # queries for retrieving from the model
 
-        self.to_queries = LinearNoBias(dim, dim_inner)
+        self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
 
         # keys and values for storing to the model
 
-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
 
         self.store_memory_loss_fn = store_memory_loss_fn
 
         # empty memory embed
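Together, the two hunks above let an optional non-linearity follow both the query and the key/value projections; when `activation` is left as `None`, `Sequential` collapses to the bare `LinearNoBias`, so the default behaviour is unchanged. A hedged construction sketch follows: the `dim` and `chunk_size` values are illustrative, and the forward call is assumed from the repository README usage rather than anything shown in this diff.

```python
import torch
from torch import nn
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    activation = nn.SiLU()  # new in 0.0.51: applied after the query and key/value projections
)

seq = torch.randn(2, 1024, 384)

# retrieve from the neural memory (usage assumed from the repository README)
retrieved = mem(seq)
```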
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.50
+Version: 0.0.51
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -135,3 +135,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R_lWRe_iBkiWZ1iZAt1tNyjaTUyB5mb80mcYZHUKww0,11369
+titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
+titans_pytorch-0.0.51.dist-info/METADATA,sha256=OKqRYWucpjsLKKOksBrdHnvCAqgHgv2FUtL6l8hN2-Y,4484
+titans_pytorch-0.0.51.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.51.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.51.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
-titans_pytorch/titans.py,sha256=TklMAxNDxgFBpJZFJa8hEhqA_DITmT6EM0p0ueE1jo8,15712
-titans_pytorch-0.0.50.dist-info/METADATA,sha256=KU7TTrH89eNVPP10NKKTDKnW-ik344_kVQkAXW7NRL8,4210
-titans_pytorch-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.50.dist-info/RECORD,,