titans-pytorch 0.0.47__py3-none-any.whl → 0.0.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -95,7 +95,7 @@ class SegmentedAttention(Module):
95
95
  dim_head = 64,
96
96
  heads = 8,
97
97
  accept_value_residual = False,
98
- attend_kwargs: dict = dict()
98
+ attend_kwargs: dict = dict(),
99
99
  ):
100
100
  super().__init__()
101
101
  self.norm = nn.RMSNorm(dim)
@@ -201,6 +201,7 @@ class MemoryAsContextTransformer(Module):
201
201
  num_residual_streams = 4,
202
202
  neural_memory_kwargs: dict = dict(),
203
203
  neural_memory_layers: tuple[int, ...] | None = None,
204
+ aux_kv_recon_loss_weight = 0.
204
205
  ):
205
206
  super().__init__()
206
207
 
@@ -276,10 +277,18 @@ class MemoryAsContextTransformer(Module):
276
277
 
277
278
  self.to_logits = LinearNoBias(dim, num_tokens)
278
279
 
280
+ # auxiliary loss on kv recon
281
+
282
+ self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
283
+ self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
284
+
285
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
286
+
279
287
  def forward(
280
288
  self,
281
289
  x,
282
- return_loss = False
290
+ return_loss = False,
291
+ return_loss_breakdown = False
283
292
  ):
284
293
 
285
294
  if return_loss:
@@ -317,6 +326,10 @@ class MemoryAsContextTransformer(Module):
317
326
 
318
327
  value_residual = None
319
328
 
329
+ # aux losses
330
+
331
+ kv_recon_losses = self.zero
332
+
320
333
  # expand and reduce streams for hyper connections
321
334
 
322
335
  x = self.expand_streams(x)
@@ -324,7 +337,8 @@ class MemoryAsContextTransformer(Module):
324
337
  for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
325
338
 
326
339
  if exists(maybe_neural_mem):
327
- x = maybe_neural_mem(x)
340
+ x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
341
+ kv_recon_losses = kv_recon_losses + aux_kv_loss
328
342
 
329
343
  x, values = attn(x, value_residual = value_residual)
330
344
 
@@ -351,4 +365,14 @@ class MemoryAsContextTransformer(Module):
351
365
  if not return_loss:
352
366
  return logits
353
367
 
354
- return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
368
+ ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
369
+
370
+ losses = ar_loss
371
+
372
+ if self.has_aux_kv_recon_loss:
373
+ losses = losses + kv_recon_losses * self.aux_kv_recon_loss_weight
374
+
375
+ if not return_loss_breakdown:
376
+ return losses
377
+
378
+ return losses, (ar_loss, kv_recon_losses)
titans_pytorch/titans.py CHANGED
@@ -162,7 +162,8 @@ class NeuralMemory(Module):
162
162
  heads = 1,
163
163
  model: Module | None = None,
164
164
  store_memory_loss_fn: Callable = default_loss_fn,
165
- adaptive_step_transform: Callable = default_adaptive_step_transform,
165
+ adaptive_step_transform: Callable | None = None,
166
+ default_step_transform_max_lr = 1e-2,
166
167
  pre_rmsnorm = True,
167
168
  post_rmsnorm = True,
168
169
  max_grad_norm: float | None = None,
@@ -218,7 +219,7 @@ class NeuralMemory(Module):
218
219
  pred = functional_call(self.memory_model, params, inputs)
219
220
  loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
220
221
  weighted_loss = loss * loss_weights
221
- return weighted_loss.sum(), loss.mean()
222
+ return weighted_loss.sum(), weighted_loss.mean()
222
223
 
223
224
  self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
224
225
 
@@ -250,6 +251,9 @@ class NeuralMemory(Module):
250
251
  Rearrange('b n h -> (b h) n')
251
252
  )
252
253
 
254
+ if not exists(adaptive_step_transform):
255
+ adaptive_step_transform = partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr)
256
+
253
257
  self.adaptive_step_transform = adaptive_step_transform
254
258
 
255
259
  # allow for softclamp the gradient norms for storing memories
@@ -409,7 +413,7 @@ class NeuralMemory(Module):
409
413
  if not return_aux_kv_loss:
410
414
  return updates, next_state
411
415
 
412
- return updates, next_state, aux_kv_recon_loss
416
+ return updates, next_state, aux_kv_recon_loss.mean()
413
417
 
414
418
  def retrieve_memories(
415
419
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.47
3
+ Version: 0.0.50
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
4
+ titans_pytorch/titans.py,sha256=TklMAxNDxgFBpJZFJa8hEhqA_DITmT6EM0p0ueE1jo8,15712
5
+ titans_pytorch-0.0.50.dist-info/METADATA,sha256=KU7TTrH89eNVPP10NKKTDKnW-ik344_kVQkAXW7NRL8,4210
6
+ titans_pytorch-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.50.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=mF8PYAjeAjLas1gkYybgzZX1AVK82A_ps_LY00ofYYs,9565
4
- titans_pytorch/titans.py,sha256=ZcWxx6n-f8ttojRnK9fExavmT1bS-QSCRHQn7ldv7J0,15502
5
- titans_pytorch-0.0.47.dist-info/METADATA,sha256=HjZxbJlnqsSgbioQz6KHWJb--8n18WDdL2T-jz-CFKc,4210
6
- titans_pytorch-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.47.dist-info/RECORD,,