titans-pytorch 0.0.56__py3-none-any.whl → 0.0.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +64 -23
- {titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.57.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.57.dist-info/RECORD +8 -0
- titans_pytorch-0.0.56.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.57.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.57.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -285,6 +285,50 @@ class SegmentedAttention(Module):
 
         return out, orig_v
 
+# Attention + Neural Memory gating configuration, as depicted in Figure 2
+
+class NeuralMemoryGatingWrapper(Module):
+    def __init__(
+        self,
+        dim,
+        attn: SegmentedAttention,
+        neural_mem: NeuralMemory | None = None
+    ):
+        super().__init__()
+        self.attn = attn
+        self.neural_mem = neural_mem
+        self.to_gates = nn.Linear(dim, dim) if exists(neural_mem) else None
+
+    def forward(
+        self,
+        seq,
+        *args,
+        **kwargs
+    ):
+        batch, seq_len = seq.shape[:2]
+        mem = self.neural_mem
+
+        if not exists(mem):
+            return self.attn(seq, *args, **kwargs), 0.
+
+        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
+
+        retrieved, first_kv_aux_loss = mem(seq, return_aux_kv_loss = True)
+
+        seq = seq + retrieved
+
+        # attention
+
+        attn_out, values = self.attn(seq, *args, **kwargs)
+
+        # another retrieve, but this time gate the attention output
+
+        retrieved, second_kv_aux_loss = mem(attn_out, return_aux_kv_loss = True)
+
+        attn_out = attn_out * self.to_gates(retrieved).sigmoid()
+
+        return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
+
 # MAC transformer
 
 class MemoryAsContextTransformer(Module):
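The new wrapper's second retrieval is turned into per-channel sigmoid gates on the attention output. A minimal standalone sketch of just that gating step; the shapes and model dimension below are illustrative, not taken from the package:

import torch
from torch import nn

# illustration of the gating step used in NeuralMemoryGatingWrapper:
# the retrieved memory is projected to per-channel gates in (0, 1),
# which modulate the attention output elementwise

dim = 384                                 # illustrative model dimension
to_gates = nn.Linear(dim, dim)

attn_out  = torch.randn(2, 128, dim)      # stand-in for the attention output (batch, seq, dim)
retrieved = torch.randn(2, 128, dim)      # stand-in for the neural memory retrieval, same shape

gates = to_gates(retrieved).sigmoid()     # values strictly between 0 and 1
gated_out = attn_out * gates              # mirrors: attn_out * self.to_gates(retrieved).sigmoid()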
@@ -328,7 +372,6 @@ class MemoryAsContextTransformer(Module):
 
         self.layers = ModuleList([])
 
-        self.neural_mem_layers = ModuleList([])
         self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
@@ -343,7 +386,18 @@ class MemoryAsContextTransformer(Module):
         for layer in layers:
             is_first = layer == 1
 
-            #
+            # attention and feedforward
+
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
+                accept_value_residual = not is_first,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
 
             mem = None
 
@@ -356,21 +410,10 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
-
-
-
-
-            # attention and feedforward
-
-            attn = SegmentedAttention(
-                dim = dim,
-                dim_head = dim_head,
-                heads = heads,
-                segment_len = segment_len,
-                use_flex_attn = use_flex_attn,
-                accept_value_residual = not is_first,
-                num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+            attn = NeuralMemoryGatingWrapper(
+                dim,
+                attn = attn,
+                neural_mem = mem,
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)
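To see the net effect of the two hunks above, the wrapper can be exercised standalone. A minimal sketch, assuming SegmentedAttention and NeuralMemoryGatingWrapper are importable from titans_pytorch.mac_transformer; the NeuralMemory keyword arguments and all hyperparameter values here are illustrative assumptions, not taken from this diff:

import torch
from titans_pytorch.titans import NeuralMemory
from titans_pytorch.mac_transformer import SegmentedAttention, NeuralMemoryGatingWrapper

dim = 384

# attention block, constructed with the same keyword arguments the transformer passes above
attn = SegmentedAttention(
    dim = dim,
    dim_head = 64,
    heads = 8,
    segment_len = 32,
    use_flex_attn = False,
    accept_value_residual = False,    # as for the first layer (is_first)
    num_longterm_mem_tokens = 4,
    num_persist_mem_tokens = 4
)

# per-layer neural memory; constructor kwargs assumed from the package README
mem = NeuralMemory(dim = dim, chunk_size = 64)

# attention and memory are now wrapped together, one wrapper per transformer layer
wrapped = NeuralMemoryGatingWrapper(dim, attn = attn, neural_mem = mem)

seq = torch.randn(1, 576, dim)                    # length chosen divisible by the segment and chunk sizes
(attn_out, values), kv_aux_loss = wrapped(seq)    # gated output, value residual, summed aux loss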
@@ -460,19 +503,17 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for
-
-            if exists(maybe_neural_mem):
-                x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + aux_kv_loss
+        for attn, ff in self.layers:
 
-            x, values = attn(
+            (x, values), maybe_mem_kv_aux_loss = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn
             )
 
+            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
+
             value_residual = default(value_residual, values)
 
             x = ff(x)
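The refactor leaves the transformer's public interface unchanged; the per-layer aux KV reconstruction losses are simply accumulated from each wrapper's return value in the loop above. A minimal end-to-end sketch, with constructor and forward arguments following the project README of this era (values illustrative; the handling of return_loss is assumed, not shown in this diff):

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,               # vocab size (illustrative)
    dim = 256,
    depth = 2,
    segment_len = 128,              # local attention window
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

token_ids = torch.randint(0, 256, (1, 1023))

loss = transformer(token_ids, return_loss = True)   # training loss; the accumulated kv recon losses are assumed to be folded in
loss.backward()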
titans_pytorch-0.0.57.dist-info/RECORD

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
+titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.57.dist-info/RECORD,,
titans_pytorch-0.0.56.dist-info/RECORD

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=_Vsco5YuR6uxouWcjFj-s-zPhrBcaapIzqoyi7qqY0Q,14245
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.56.dist-info/METADATA,sha256=QlCmHqajHiaZTps0W9gKXIHE6dShZER3PqPoYi2zRe4,4457
-titans_pytorch-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.56.dist-info/RECORD,,
{titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.57.dist-info}/WHEEL: file without changes

{titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.57.dist-info}/licenses/LICENSE: file without changes