titans-pytorch 0.0.19__tar.gz → 0.0.21__tar.gz

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.19
+ Version: 0.0.21
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.19"
+ version = "0.0.21"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -40,6 +40,9 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d
  
+ def identity(t):
+     return t
+
  def round_down_multiple(seq, mult):
      return seq // mult * mult
  
@@ -55,6 +58,21 @@ def pack_one_with_inverse(t, pattern):
  
      return packed, inverse
  
+ # softclamping gradients
+
+ def softclamp_max(t, max_value):
+     half_max_value = max_value / 2
+     return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+ def softclamp_grad_norm(t, max_value):
+     t, inverse = pack_one_with_inverse(t, 'bn *')
+
+     norm = t.norm(dim = -1, keepdim = True)
+     clamped_norm = softclamp_max(norm, max_value)
+
+     t = t * (clamped_norm / norm)
+     return inverse(t)
+
  # classes
  
  class MemoryMLP(Module):
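
To illustrate what these two new helpers do, here is a standalone sketch (shapes and values are made up, and a plain reshape stands in for the package's pack_one_with_inverse helper): softclamp_max squashes any input smoothly into (0, max_value) with a shifted tanh, and softclamp_grad_norm rescales a gradient so its flattened norm takes the soft-clamped value while its direction is preserved.

import torch

# standalone sketch of the soft clamp added above
def softclamp_max(t, max_value):
    half_max_value = max_value / 2
    return ((t / half_max_value).tanh() * half_max_value) + half_max_value

def softclamp_grad_norm_sketch(t, max_value):
    flat = t.reshape(t.shape[0], -1)              # stand-in for pack_one_with_inverse(t, 'bn *')
    norm = flat.norm(dim = -1, keepdim = True)
    flat = flat * (softclamp_max(norm, max_value) / norm)
    return flat.reshape(t.shape)

print(softclamp_max(torch.tensor([0.1, 1., 10., 100.]), max_value = 10.))
# roughly [5.10, 5.99, 9.82, 10.00] - bounded above by 10, saturating smoothly

grads = torch.randn(8, 384, 384) * 5.             # hypothetical per-sample gradients
print(softclamp_grad_norm_sketch(grads, 10.).reshape(8, -1).norm(dim = -1))
# every flattened norm now lies strictly below 10
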
@@ -82,6 +100,9 @@ class MemoryMLP(Module):
  
  # main neural memory
  
+ def default_adaptive_step_transform(adaptive_step):
+     return torch.exp(adaptive_step.sigmoid() * -15) # from 1. - 1e-7
+
  def default_loss_fn(pred, target):
      return (pred - target).pow(2).mean(dim = -1).sum()
  
@@ -94,8 +115,10 @@ class NeuralMemory(Module):
          heads = 1,
          model: Module | None = None,
          store_memory_loss_fn: Callable = default_loss_fn,
+         adaptive_step_transform: Callable = default_adaptive_step_transform,
          pre_rmsnorm = True,
          post_rmsnorm = True,
+         max_grad_norm: float | None = None,
          use_accelerated_scan = False,
          default_mlp_kwargs: dict = dict(
              depth = 4
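
The two new keyword arguments can be combined when constructing the memory. A hypothetical construction might look like the following (the lambda and the values are illustrative only, and the forward-call pattern follows the repository README of this era, so the exact return signature may differ between versions):

import torch
from titans_pytorch import NeuralMemory  # top-level export, as in the repository README

mem = NeuralMemory(
    dim = 384,
    chunk_size = 4,
    adaptive_step_transform = lambda t: torch.exp(t.sigmoid() * -10),  # custom learning-rate mapping (illustrative)
    max_grad_norm = 10.                                                 # soft-clamp per-parameter gradient norms when storing
)

seq = torch.randn(2, 64, 384)
retrieved = mem(seq)  # retrieval from the neural memory; see the README for the exact API
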
@@ -172,6 +195,12 @@ class NeuralMemory(Module):
              Rearrange('b n h -> (b h) n')
          )
  
+         self.adaptive_step_transform = adaptive_step_transform
+
+         # allow for softclamp the gradient norms for storing memories
+
+         self.max_grad_norm = max_grad_norm
+
          # weight decay factor
  
          self.to_decay_factor = nn.Sequential(
@@ -222,7 +251,8 @@ class NeuralMemory(Module):
  
          # pack batch and sequence dimension
  
-         adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
+         adaptive_lr = self.to_adaptive_step(seq)
+         adaptive_lr = self.adaptive_step_transform(adaptive_lr)
  
          adaptive_momentum = self.to_momentum(seq).sigmoid()
          decay_factor = self.to_decay_factor(seq).sigmoid()
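
Numerically, the default transform keeps the old behaviour: sigmoid squashes the raw per-token projection into (0, 1), scaling by -15 maps that to (-15, 0), and exponentiating gives an adaptive learning rate in roughly (3e-7, 1). A small range check (sketch):

import torch

raw = torch.tensor([-10., 0., 10.])           # hypothetical raw outputs of to_adaptive_step
adaptive_lr = torch.exp(raw.sigmoid() * -15)  # same expression as default_adaptive_step_transform
print(adaptive_lr)                            # roughly [9.99e-1, 5.53e-4, 3.06e-7]

Passing a different adaptive_step_transform at construction swaps this mapping without touching the rest of the store path.
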
@@ -247,6 +277,11 @@ class NeuralMemory(Module):
  
          grads = TensorDict(grads)
  
+         # maybe softclamp grad norm
+
+         if exists(self.max_grad_norm):
+             grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
          # restore batch and sequence dimension
  
          grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
@@ -62,11 +62,11 @@ def decode_tokens(tokens):
  
  titans_neural_memory = NeuralMemory(
      dim = 384,
-     chunk_size = WINDOW_SIZE,
+     chunk_size = 4,
      pre_rmsnorm = True,
      post_rmsnorm = True,
-     dim_head = 32,
-     heads = 8,
+     dim_head = 64,
+     heads = 4,
      use_accelerated_scan = True,
      default_mlp_kwargs = dict(
          depth = NEURAL_MEMORY_DEPTH