titans-pytorch 0.0.49__py3-none-any.whl → 0.0.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -162,7 +162,8 @@ class NeuralMemory(Module):
162
162
  heads = 1,
163
163
  model: Module | None = None,
164
164
  store_memory_loss_fn: Callable = default_loss_fn,
165
- adaptive_step_transform: Callable = default_adaptive_step_transform,
165
+ adaptive_step_transform: Callable | None = None,
166
+ default_step_transform_max_lr = 1e-2,
166
167
  pre_rmsnorm = True,
167
168
  post_rmsnorm = True,
168
169
  max_grad_norm: float | None = None,
@@ -250,6 +251,9 @@ class NeuralMemory(Module):
250
251
  Rearrange('b n h -> (b h) n')
251
252
  )
252
253
 
254
+ if not exists(adaptive_step_transform):
255
+ adaptive_step_transform = partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr)
256
+
253
257
  self.adaptive_step_transform = adaptive_step_transform
254
258
 
255
259
  # allow for softclamp the gradient norms for storing memories
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.49
3
+ Version: 0.0.50
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
4
+ titans_pytorch/titans.py,sha256=TklMAxNDxgFBpJZFJa8hEhqA_DITmT6EM0p0ueE1jo8,15712
5
+ titans_pytorch-0.0.50.dist-info/METADATA,sha256=KU7TTrH89eNVPP10NKKTDKnW-ik344_kVQkAXW7NRL8,4210
6
+ titans_pytorch-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.50.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
4
- titans_pytorch/titans.py,sha256=tV2ej2PGUhMjSmDFV_wowX5q9hyp4SM4Jv3eJNu7cy8,15518
5
- titans_pytorch-0.0.49.dist-info/METADATA,sha256=hEpYHDqm_gffXybcotEmsK6o-siKrE7HwT_UgbOd-4o,4210
6
- titans_pytorch-0.0.49.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.49.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.49.dist-info/RECORD,,