titans-pytorch 0.0.55__py3-none-any.whl → 0.0.57__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -285,6 +285,50 @@ class SegmentedAttention(Module):
 
         return out, orig_v
 
+# Attention + Neural Memory gating configuration, as depicted in Figure 2
+
+class NeuralMemoryGatingWrapper(Module):
+    def __init__(
+        self,
+        dim,
+        attn: SegmentedAttention,
+        neural_mem: NeuralMemory | None = None
+    ):
+        super().__init__()
+        self.attn = attn
+        self.neural_mem = neural_mem
+        self.to_gates = nn.Linear(dim, dim) if exists(neural_mem) else None
+
+    def forward(
+        self,
+        seq,
+        *args,
+        **kwargs
+    ):
+        batch, seq_len = seq.shape[:2]
+        mem = self.neural_mem
+
+        if not exists(mem):
+            return self.attn(seq, *args, **kwargs), 0.
+
+        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
+
+        retrieved, first_kv_aux_loss = mem(seq, return_aux_kv_loss = True)
+
+        seq = seq + retrieved
+
+        # attention
+
+        attn_out, values = self.attn(seq, *args, **kwargs)
+
+        # another retrieve, but this time gate the attention output
+
+        retrieved, second_kv_aux_loss = mem(attn_out, return_aux_kv_loss = True)
+
+        attn_out = attn_out * self.to_gates(retrieved).sigmoid()
+
+        return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
+
 # MAC transformer
 
 class MemoryAsContextTransformer(Module):
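
The control flow of the added wrapper is: retrieve from the neural memory and add the result to the residual stream, run attention over that augmented stream, retrieve a second time from the attention output, and gate that output with a sigmoid of a linear projection of the second retrieval. Below is a minimal, self-contained sketch of the same pattern; `ToyMemory` and `ToyGatingWrapper` are stand-ins invented here for illustration and are not the package's `NeuralMemory`, `SegmentedAttention`, or `NeuralMemoryGatingWrapper` APIs.

```python
import torch
from torch import nn

class ToyMemory(nn.Module):
    # stand-in for NeuralMemory: returns a retrieved tensor plus a scalar auxiliary loss
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        retrieved = self.proj(x)
        aux_loss = retrieved.pow(2).mean()  # placeholder for the KV reconstruction loss
        return retrieved, aux_loss

class ToyGatingWrapper(nn.Module):
    # mirrors the control flow of the NeuralMemoryGatingWrapper hunk above
    def __init__(self, dim, attn, mem):
        super().__init__()
        self.attn = attn
        self.mem = mem
        self.to_gates = nn.Linear(dim, dim)

    def forward(self, seq):
        # 1. retrieve from memory and add it to the residual stream
        retrieved, first_loss = self.mem(seq)
        seq = seq + retrieved

        # 2. attention over the memory-augmented stream
        attn_out, _ = self.attn(seq, seq, seq)

        # 3. retrieve again from the attention output and gate that output
        retrieved, second_loss = self.mem(attn_out)
        attn_out = attn_out * self.to_gates(retrieved).sigmoid()

        return attn_out, first_loss + second_loss

dim = 64
wrapper = ToyGatingWrapper(dim, nn.MultiheadAttention(dim, 4, batch_first = True), ToyMemory(dim))
out, aux = wrapper(torch.randn(2, 128, dim))
print(out.shape, float(aux))
```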
@@ -328,7 +372,6 @@ class MemoryAsContextTransformer(Module):
 
         self.layers = ModuleList([])
 
-        self.neural_mem_layers = ModuleList([])
         self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
@@ -343,7 +386,18 @@ class MemoryAsContextTransformer(Module):
         for layer in layers:
             is_first = layer == 1
 
-            # neural memory
+            # attention and feedforward
+
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
+                accept_value_residual = not is_first,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
 
             mem = None
 
@@ -356,21 +410,10 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
-                mem = init_hyper_conn(dim = dim, branch = mem)
-
-            self.neural_mem_layers.append(mem)
-
-            # attention and feedforward
-
-            attn = SegmentedAttention(
-                dim = dim,
-                dim_head = dim_head,
-                heads = heads,
-                segment_len = segment_len,
-                use_flex_attn = use_flex_attn,
-                accept_value_residual = not is_first,
-                num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+            attn = NeuralMemoryGatingWrapper(
+                dim,
+                attn = attn,
+                neural_mem = mem,
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)
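
The net effect of the two construction hunks above is that per-layer wiring now happens once, at build time: attention is created for every layer, a neural memory only for the configured layers, and both are handed to the gating wrapper (with `neural_mem = None` where no memory exists), which removes the need for the separate `self.neural_mem_layers` list. A rough sketch of that build pattern, using hypothetical stand-in modules rather than the package's classes:

```python
from torch import nn

class GateStub(nn.Module):
    # minimal stand-in for NeuralMemoryGatingWrapper: holds attention plus an optional memory
    def __init__(self, dim, attn, neural_mem = None):
        super().__init__()
        self.attn = attn
        self.neural_mem = neural_mem
        self.to_gates = nn.Linear(dim, dim) if neural_mem is not None else None

dim, depth = 64, 4
neural_memory_layers = (2, 4)  # hypothetical config: only these layers get a memory

layers = nn.ModuleList([])

for layer in range(1, depth + 1):
    attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)  # stand-in for SegmentedAttention
    mem = nn.Linear(dim, dim) if layer in neural_memory_layers else None   # stand-in for NeuralMemory

    # the wrapper is applied to every layer; layers without memory simply pass neural_mem = None,
    # so there is no longer a parallel self.neural_mem_layers list to keep in sync
    layers.append(nn.ModuleList([GateStub(dim, attn, mem), nn.Linear(dim, dim)]))

print(len(layers), [l[0].neural_mem is not None for l in layers])
```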
@@ -445,8 +488,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
-
+            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
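
This one-line change builds the flex-attention block mask for the memory-augmented sequence length rather than the raw input length, since attention runs after the long-term memory tokens have been interspersed. Assuming memory tokens are added per segment (an assumption for illustration, not taken from the diff), the length arithmetic looks roughly like this:

```python
import math

def seq_len_with_mem(seq_len, segment_len, num_longterm_mem_tokens):
    # hypothetical helper: each segment of the input gains num_longterm_mem_tokens extra tokens,
    # so attention (and therefore the block mask) covers a longer sequence than the raw input
    num_segments = math.ceil(seq_len / segment_len)
    return seq_len + num_segments * num_longterm_mem_tokens

print(seq_len_with_mem(seq_len = 1024, segment_len = 128, num_longterm_mem_tokens = 4))  # 1056
```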
@@ -461,13 +503,16 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+        for attn, ff in self.layers:
 
-            if exists(maybe_neural_mem):
-                x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + aux_kv_loss
+            (x, values), maybe_mem_kv_aux_loss = attn(
+                x,
+                value_residual = value_residual,
+                disable_flex_attn = disable_flex_attn,
+                flex_attn_fn = flex_attn_fn
+            )
 
-            x, values = attn(x, value_residual = value_residual, disable_flex_attn = disable_flex_attn, flex_attn_fn = flex_attn_fn)
+            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
 
             value_residual = default(value_residual, values)
 
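
With the gating wrapper in place, the forward pass iterates a single `self.layers` list: each wrapped attention call returns both the attention result (including the values needed for the value-residual trick) and the memory's auxiliary KV loss, which the loop accumulates. A schematic of that loop shape, with a dummy module standing in for the real wrapper:

```python
import torch
from torch import nn

class DummyWrappedAttn(nn.Module):
    # stands in for NeuralMemoryGatingWrapper: returns ((out, values), aux_loss)
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, value_residual = None):
        out = self.proj(x)
        values = x                    # pretend these are the attention values for the value-residual trick
        aux_loss = out.pow(2).mean()  # placeholder for the memory's KV reconstruction loss
        return (out, values), aux_loss

dim = 32
layers = nn.ModuleList([
    nn.ModuleList([DummyWrappedAttn(dim), nn.Linear(dim, dim)]) for _ in range(3)
])

x = torch.randn(2, 16, dim)
kv_recon_losses = x.new_zeros(())  # scalar accumulator, as in the diff
value_residual = None

for attn, ff in layers:
    (x, values), aux_loss = attn(x, value_residual = value_residual)
    kv_recon_losses = kv_recon_losses + aux_loss

    value_residual = value_residual if value_residual is not None else values  # i.e. default(...)
    x = ff(x) + x                  # schematic feedforward residual

print(x.shape, float(kv_recon_losses))
```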
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.55
+Version: 0.0.57
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
+titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.57.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=-VN8bURUaqHXH_96UqGYDhWcfgCaFdHGdM6faVuYDgQ,14159
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.55.dist-info/METADATA,sha256=VYP1B5d9tejIXr7u6ML4cSjvgIDlWYyp5KTyydlUqV8,4457
-titans_pytorch-0.0.55.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.55.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.55.dist-info/RECORD,,