titans-pytorch 0.0.49__py3-none-any.whl → 0.0.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,16 +7,48 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear
 
+ # flex attention
+ # https://pytorch.org/blog/flexattention/
+
+ flex_attention = None
+
+ try:
+     from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+     if torch.cuda.is_available():
+         flex_attention = torch.compile(flex_attention)
+ except ImportError:
+     pass
+
+ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+     def create_mac_mask(b, h, q_idx, kv_idx):
+         is_persist_mem = kv_idx < persist_mem_len
+         causal_mask = q_idx >= (kv_idx - is_persist_mem)
+         block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+         return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+     return block_mask
+
+ # einstein notation related
+
  from einops import einsum, repeat, rearrange, pack, unpack
  from einops.layers.torch import Rearrange
 
- from hyper_connections import get_init_and_expand_reduce_stream_functions
+ # b - batch
+ # n - sequence
+ # h - heads
+ # d - feature dimension
 
  # absolute and relative positions
 
  from axial_positional_embedding import ContinuousAxialPositionalEmbedding
  from rotary_embedding_torch import RotaryEmbedding
+
+ # hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
  from x_transformers.attend import Attend
+ from hyper_connections import get_init_and_expand_reduce_stream_functions
 
  # proposed neural memory
 
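The new `create_mac_block_mask` builds a block-sparse mask in which every query can attend to the prepended persistent-memory tokens and, causally, to keys inside its own window. Below is a minimal usage sketch, not part of the diff: it assumes a PyTorch build that ships `torch.nn.attention.flex_attention`, a CUDA device (the default for `create_block_mask`), an importable module path for the helper, and purely illustrative shapes.

```python
# Minimal sketch (assumptions: PyTorch with flex attention, CUDA device,
# module path of the helper as shown below; shapes are illustrative).
import torch
from torch.nn.attention.flex_attention import flex_attention

# the helper added in the hunk above (import path assumed)
from titans_pytorch.mac_transformer import create_mac_block_mask

seq_len, window_size, persist_mem_len = 256, 64, 4
heads, dim_head = 8, 64

# persistent memory tokens are prepended to keys / values, so the key / value
# length is seq_len + persist_mem_len while the query length stays at seq_len
block_mask = create_mac_block_mask(seq_len, window_size, persist_mem_len)

q = torch.randn(1, heads, seq_len, dim_head, device = 'cuda')
k = torch.randn(1, heads, seq_len + persist_mem_len, dim_head, device = 'cuda')
v = torch.randn(1, heads, seq_len + persist_mem_len, dim_head, device = 'cuda')

out = flex_attention(q, k, v, block_mask = block_mask)  # (1, heads, seq_len, dim_head)
```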
titans_pytorch/titans.py CHANGED
@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):
 
      return packed, inverse
 
+ def Sequential(*modules):
+     modules = [*filter(exists, modules)]
+
+     if len(modules) == 0:
+         return nn.Identity()
+
+     if len(modules) == 1:
+         return modules[0]
+
+     return nn.Sequential(*modules)
+
  # softclamping gradients
 
  def softclamp_max(t, max_value):
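The `Sequential` helper filters out `None` before deciding what to return, which is what lets the optional `activation` added later in this diff be passed straight through the projections. A small self-contained illustration of that behavior (the definition is copied from the hunk above; the dimensions are made up):

```python
# Behavior of the Sequential helper (definition copied from the hunk above,
# dimensions are illustrative only)
import torch
from torch import nn

def exists(v):
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()

    if len(modules) == 1:
        return modules[0]

    return nn.Sequential(*modules)

proj  = Sequential(nn.Linear(512, 768, bias = False), None)       # None filtered -> the Linear itself
gated = Sequential(nn.Linear(512, 768, bias = False), nn.SiLU())  # two modules -> nn.Sequential
noop  = Sequential(None, None)                                    # nothing left -> nn.Identity()

x = torch.randn(2, 16, 512)
assert gated(x).shape == (2, 16, 768)
```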
@@ -124,9 +135,6 @@ class MemoryAttention(Module):
          ])
 
      def forward(self, x):
-
-         assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
          wq, wk, wv, ffw1, ffw2 = self.weights
 
          q = F.normalize(x @ wq, dim = -1)
@@ -162,11 +170,13 @@ class NeuralMemory(Module):
          heads = 1,
          model: Module | None = None,
          store_memory_loss_fn: Callable = default_loss_fn,
-         adaptive_step_transform: Callable = default_adaptive_step_transform,
+         adaptive_step_transform: Callable | None = None,
+         default_step_transform_max_lr = 1e-2,
          pre_rmsnorm = True,
          post_rmsnorm = True,
          max_grad_norm: float | None = None,
          use_accelerated_scan = False,
+         activation: Module | None = None,
          default_model_kwargs: dict = dict(
              depth = 2
          )
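With this signature change, `adaptive_step_transform` becomes optional and two new knobs appear: `default_step_transform_max_lr`, which caps the default step transform, and `activation`, an optional module applied after the query and key/value projections (see the next hunk). A hedged construction sketch following the repository's README-style usage; all concrete values are illustrative:

```python
# Hedged sketch of the new constructor arguments; dim / chunk_size values are
# illustrative and mirror README-style usage, not a prescribed configuration.
import torch
from torch import nn
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    default_step_transform_max_lr = 1e-2,  # max lr used by the default adaptive step transform
    activation = nn.GELU()                 # optional nonlinearity after the query / key-value projections
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)  # retrieved memories, same shape as the input sequence
```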
@@ -224,11 +234,11 @@ class NeuralMemory(Module):
 
          # queries for retrieving from the model
 
-         self.to_queries = LinearNoBias(dim, dim_inner)
+         self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
 
          # keys and values for storing to the model
 
-         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
          self.store_memory_loss_fn = store_memory_loss_fn
 
          # empty memory embed
@@ -250,6 +260,9 @@ class NeuralMemory(Module):
              Rearrange('b n h -> (b h) n')
          )
 
+         if not exists(adaptive_step_transform):
+             adaptive_step_transform = partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr)
+
          self.adaptive_step_transform = adaptive_step_transform
 
          # allow for softclamp the gradient norms for storing memories
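When no `adaptive_step_transform` is supplied, the default is now bound with `functools.partial` so that `default_step_transform_max_lr` reaches it as `max_lr`. Supplying your own callable bypasses that default entirely. The sketch below is an assumption-laden illustration: it presumes the transform receives the raw per-token step tensor and returns the transformed learning rates, mirroring how the partial-bound default is invoked, and the squashing choice is arbitrary.

```python
# Hedged sketch of a custom adaptive_step_transform. Assumption: one tensor in,
# one tensor out (as suggested by the partial-bound default); the clamp to
# [0, 5e-3] below is purely illustrative.
import torch
from titans_pytorch import NeuralMemory

def my_step_transform(adaptive_step: torch.Tensor) -> torch.Tensor:
    # squash the learned per-token step into [0, 5e-3]
    return adaptive_step.sigmoid() * 5e-3

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    adaptive_step_transform = my_step_transform  # overrides the partial-bound default
)
```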
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.49
+ Version: 0.0.51
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -135,3 +135,12 @@ $ python train_mac.py
      year = {2024}
  }
  ```
+
+ ```bibtex
+ @inproceedings{Yang2024GatedDN,
+     title  = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+     author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+     year   = {2024},
+     url    = {https://api.semanticscholar.org/CorpusID:274598177}
+ }
+ ```
@@ -0,0 +1,8 @@
+ titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=R_lWRe_iBkiWZ1iZAt1tNyjaTUyB5mb80mcYZHUKww0,11369
+ titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
+ titans_pytorch-0.0.51.dist-info/METADATA,sha256=OKqRYWucpjsLKKOksBrdHnvCAqgHgv2FUtL6l8hN2-Y,4484
+ titans_pytorch-0.0.51.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.51.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.51.dist-info/RECORD,,
@@ -1,8 +0,0 @@
- titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
- titans_pytorch/titans.py,sha256=tV2ej2PGUhMjSmDFV_wowX5q9hyp4SM4Jv3eJNu7cy8,15518
- titans_pytorch-0.0.49.dist-info/METADATA,sha256=hEpYHDqm_gffXybcotEmsK6o-siKrE7HwT_UgbOd-4o,4210
- titans_pytorch-0.0.49.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.0.49.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.0.49.dist-info/RECORD,,