titans-pytorch 0.0.21__tar.gz → 0.0.22__tar.gz


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.21
+Version: 0.0.22
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.21"
+version = "0.0.22"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -2,8 +2,10 @@ import torch
 import pytest
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
-    seq_len
+    seq_len,
+    max_grad_norm
 ):
 
     from titans_pytorch import NeuralMemory

@@ -11,6 +13,7 @@ def test_titans(
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        max_grad_norm = max_grad_norm
     )
 
     seq = torch.randn(2, seq_len, 384)

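For reference, a minimal usage sketch of the new keyword, assembled from the updated test. The constructor arguments (dim, chunk_size, max_grad_norm) are taken directly from the test; calling the module on a (batch, seq, dim) tensor and getting back a same-shaped tensor is an assumption not shown in this diff.

    import torch
    from titans_pytorch import NeuralMemory

    # construction mirrors the updated test; max_grad_norm = None would leave clipping off
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        max_grad_norm = 2.
    )

    seq = torch.randn(2, 1024, 384)

    # assumption: the module is invoked directly on the sequence,
    # returning retrieved memories with the same shape as the input
    retrieved = mem(seq)
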
@@ -100,8 +100,8 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step):
-    return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+    return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()

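The adaptive step transform change is easiest to see numerically: the 0.0.21 transform squashed the learned step through exp(sigmoid(x) * -15), producing values anywhere from roughly exp(-15) ≈ 3e-7 up to 1, while the 0.0.22 transform is a plain sigmoid scaled by max_lr, bounding the step linearly in (0, 0.1) by default. A small self-contained comparison, with both functions copied from the diff above:

    import torch

    def old_adaptive_step_transform(adaptive_step):
        # 0.0.21: exponential squashing, values in roughly (exp(-15), 1) ~ (3e-7, 1)
        return torch.exp(adaptive_step.sigmoid() * -15)

    def new_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
        # 0.0.22: linearly scaled sigmoid, values in (0, max_lr)
        return adaptive_step.sigmoid() * max_lr

    raw = torch.linspace(-6., 6., steps = 5)
    print(old_adaptive_step_transform(raw))   # spans several orders of magnitude
    print(new_adaptive_step_transform(raw))   # stays within (0, 0.1)
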
@@ -67,6 +67,7 @@ titans_neural_memory = NeuralMemory(
     post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
+    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH

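The tests above parametrize max_grad_norm over (None, 2.), which suggests None leaves the 0.0.21 behaviour unchanged while a float caps the norm of the gradients used for the memory update. As a generic illustration of gradient-norm clipping only, not titans_pytorch's internal code, a sketch:

    import torch

    # generic gradient-norm clipping sketch; names and structure are illustrative,
    # not taken from titans_pytorch internals
    def clip_grads_by_norm(grads, max_grad_norm):
        if max_grad_norm is None:
            return grads                                      # None disables clipping
        total_norm = torch.cat([g.flatten() for g in grads]).norm()
        scale = (max_grad_norm / (total_norm + 1e-6)).clamp(max = 1.)
        return [g * scale for g in grads]                     # scale down only when over the cap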