titans-pytorch 0.0.52__py3-none-any.whl → 0.0.54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/__init__.py CHANGED
@@ -1,6 +1,7 @@
 from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
+    MemoryAttention
 )
 
 from titans_pytorch.mac_transformer import (
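With this change, MemoryAttention is re-exported at the package top level alongside NeuralMemory and MemoryMLP. A minimal import, assuming titans-pytorch 0.0.54 or later is installed:

    # MemoryAttention is now importable from the package root (0.0.54+)
    from titans_pytorch import NeuralMemory, MemoryMLP, MemoryAttention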
titans_pytorch/mac_transformer.py CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 from math import ceil
 from functools import partial
 
@@ -32,7 +33,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
 # einstein notation related
 
-from einops import einsum, repeat, rearrange, pack, unpack
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 # b - batch
titans_pytorch/titans.py CHANGED
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 import math
 from functools import partial
 
@@ -16,7 +17,6 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
-import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -123,9 +123,12 @@ class MemoryMLP(Module):
 class MemoryAttention(Module):
     def __init__(
         self,
-        dim
+        dim,
+        scale = 8.
     ):
         super().__init__()
+        self.scale = scale
+
         self.weights = nn.ParameterList([
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
@@ -143,6 +146,7 @@ class MemoryAttention(Module):
 
         attn_out = F.scaled_dot_product_attention(
             q, k, v,
+            scale = self.scale,
             is_causal = True
         )
 
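Taken together, the two MemoryAttention hunks add an explicit softmax scale (default 8.) that is forwarded to F.scaled_dot_product_attention in place of the default 1/sqrt(d) scaling. Below is a hedged sketch of using it as the memory model of NeuralMemory; the dim, chunk_size, and tensor shapes are illustrative rather than taken from the diff, and the memory model's dim is assumed to match the module's head dimension under default settings:

    import torch
    from titans_pytorch import NeuralMemory, MemoryAttention

    # attention-based memory model with the new explicit attention scale
    attn_memory = MemoryAttention(dim = 384, scale = 8.)

    # pass it in place of the default MemoryMLP via the existing `model` argument
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        model = attn_memory
    )

    seq = torch.randn(2, 1024, 384)
    retrieved = mem(seq)  # README-style usage; shapes are illustrative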
@@ -174,6 +178,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -212,6 +217,9 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        if not learned_mem_model_weights:
+            model.requires_grad_(False)
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
         # the memory is the weights of the model
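The new learned_mem_model_weights flag (default True) does nothing more than call requires_grad_(False) on the memory model when disabled, so the model's initial weights are kept fixed rather than trained, presumably leaving the test-time memory updates, which are computed against per-sequence copies of those weights, unaffected. A minimal sketch with illustrative dimensions:

    from titans_pytorch import NeuralMemory

    # keep the inner memory model's initial weights frozen;
    # the surrounding projections and gates remain trainable
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        learned_mem_model_weights = False
    )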
@@ -338,9 +346,9 @@ class NeuralMemory(Module):
 
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
 
-        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
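The last hunk swaps self.chunk_size for a local chunk_size in the chunking rearranges, presumably so the chunk length used here no longer has to be the one fixed at construction time. For reference, the einops pattern splits the sequence into fixed-length chunks and folds the chunk count into the batch dimension; a standalone sketch with toy shapes:

    import torch
    from einops import rearrange

    b, n, d, chunk_size = 2, 8, 4, 4   # toy sizes; n must be divisible by chunk_size

    keys = torch.randn(b, n, d)
    adaptive_lr = torch.randn(b, n)

    # 'b (n c) d -> (b n) c d': every chunk of length c becomes its own batch row
    keys_chunked = rearrange(keys, 'b (n c) d -> (b n) c d', c = chunk_size)
    lr_chunked = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

    print(keys_chunked.shape)  # torch.Size([4, 4, 4])
    print(lr_chunked.shape)    # torch.Size([4, 4])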
{titans_pytorch-0.0.52.dist-info → titans_pytorch-0.0.54.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.52
+Version: 0.0.54
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,7 +37,6 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
-Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
titans_pytorch-0.0.54.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=GdUAYq6MDRGY0l2ESBH_kM01AEzVztiKmWfblSKxBEM,13212
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.54.dist-info/METADATA,sha256=bxbC3NBO4Sjii7DpFPcmNsO9M1kX76vj947H2DeUceg,4457
+titans_pytorch-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.54.dist-info/RECORD,,
titans_pytorch-0.0.52.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=khfjpbsy-uT9NIG3dZLsLOG_XSEi7EqcyfbPr7EQc2Q,13192
-titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
-titans_pytorch-0.0.52.dist-info/METADATA,sha256=coC9ExIuNvmab0BktSE1NwUgxRaBUV7h_cTHeoJkRJo,4484
-titans_pytorch-0.0.52.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.52.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.52.dist-info/RECORD,,