titans-pytorch 0.0.20__tar.gz → 0.0.21__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.20
3
+ Version: 0.0.21
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.20"
3
+ version = "0.0.21"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -40,6 +40,9 @@ def exists(v):
40
40
  def default(v, d):
41
41
  return v if exists(v) else d
42
42
 
43
+ def identity(t):
44
+ return t
45
+
43
46
  def round_down_multiple(seq, mult):
44
47
  return seq // mult * mult
45
48
 
@@ -97,6 +100,9 @@ class MemoryMLP(Module):
97
100
 
98
101
  # main neural memory
99
102
 
103
+ def default_adaptive_step_transform(adaptive_step):
104
+ return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
105
+
100
106
  def default_loss_fn(pred, target):
101
107
  return (pred - target).pow(2).mean(dim = -1).sum()
102
108
 
@@ -109,6 +115,7 @@ class NeuralMemory(Module):
109
115
  heads = 1,
110
116
  model: Module | None = None,
111
117
  store_memory_loss_fn: Callable = default_loss_fn,
118
+ adaptive_step_transform: Callable = default_adaptive_step_transform,
112
119
  pre_rmsnorm = True,
113
120
  post_rmsnorm = True,
114
121
  max_grad_norm: float | None = None,
@@ -188,6 +195,8 @@ class NeuralMemory(Module):
188
195
  Rearrange('b n h -> (b h) n')
189
196
  )
190
197
 
198
+ self.adaptive_step_transform = adaptive_step_transform
199
+
191
200
  # allow for softclamp the gradient norms for storing memories
192
201
 
193
202
  self.max_grad_norm = max_grad_norm
@@ -242,7 +251,8 @@ class NeuralMemory(Module):
242
251
 
243
252
  # pack batch and sequence dimension
244
253
 
245
- adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
254
+ adaptive_lr = self.to_adaptive_step(seq)
255
+ adaptive_lr = self.adaptive_step_transform(adaptive_lr)
246
256
 
247
257
  adaptive_momentum = self.to_momentum(seq).sigmoid()
248
258
  decay_factor = self.to_decay_factor(seq).sigmoid()
@@ -62,11 +62,11 @@ def decode_tokens(tokens):
62
62
 
63
63
  titans_neural_memory = NeuralMemory(
64
64
  dim = 384,
65
- chunk_size = WINDOW_SIZE,
65
+ chunk_size = 4,
66
66
  pre_rmsnorm = True,
67
67
  post_rmsnorm = True,
68
- dim_head = 32,
69
- heads = 8,
68
+ dim_head = 64,
69
+ heads = 4,
70
70
  use_accelerated_scan = True,
71
71
  default_mlp_kwargs = dict(
72
72
  depth = NEURAL_MEMORY_DEPTH
File without changes