titans-pytorch 0.0.56__py3-none-any.whl → 0.0.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +65 -23
- {titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.58.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.58.dist-info/RECORD +8 -0
- titans_pytorch-0.0.56.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.58.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.58.dist-info}/licenses/LICENSE +0 -0

titans_pytorch/mac_transformer.py

```diff
@@ -285,6 +285,49 @@ class SegmentedAttention(Module):
 
         return out, orig_v
 
+# Attention + Neural Memory gating configuration, as depicted in Figure 2
+
+class NeuralMemoryGatingWrapper(Module):
+    def __init__(
+        self,
+        dim,
+        attn: SegmentedAttention,
+        neural_mem: NeuralMemory | None = None,
+        gate_attn_output = True
+    ):
+        super().__init__()
+        self.attn = attn
+        self.neural_mem = neural_mem
+        self.gate_attn_output = gate_attn_output
+
+    def forward(
+        self,
+        seq,
+        *args,
+        **kwargs
+    ):
+        batch, seq_len = seq.shape[:2]
+        mem = self.neural_mem
+
+        if not exists(mem):
+            return self.attn(seq, *args, **kwargs), 0.
+
+        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
+
+        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
+
+        if not self.gate_attn_output:
+            seq = seq + retrieved
+
+        # attention
+
+        attn_out, values = self.attn(seq, *args, **kwargs)
+
+        if self.gate_attn_output:
+            attn_out = attn_out * retrieved.sigmoid()
+
+        return (attn_out, values), kv_aux_loss
+
 # MAC transformer
 
 class MemoryAsContextTransformer(Module):
```
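For orientation, here is a minimal, self-contained sketch of the two modes the new wrapper supports. Plain tensors stand in for the outputs of SegmentedAttention and NeuralMemory; the shapes and variable names are illustrative assumptions, and only the `attn_out * retrieved.sigmoid()` gating and the `seq + retrieved` fallback come from the diff above.

```python
import torch

# toy stand-ins for what SegmentedAttention and NeuralMemory would produce (assumed shapes)
batch, seq_len, dim = 2, 16, 64
attn_out  = torch.randn(batch, seq_len, dim)   # attention branch output
retrieved = torch.randn(batch, seq_len, dim)   # memory retrieval for the same positions

# gate_attn_output = True: retrieved memory acts as a sigmoid gate on the attention output
gated = attn_out * retrieved.sigmoid()

# gate_attn_output = False: retrieved memory is instead added to the sequence before attention runs
seq = torch.randn(batch, seq_len, dim)
seq_plus_mem = seq + retrieved

print(gated.shape, seq_plus_mem.shape)   # torch.Size([2, 16, 64]) torch.Size([2, 16, 64])
```

In both modes the wrapper also returns the memory's auxiliary key/value reconstruction loss (or `0.` when no memory is attached), which is what lets the transformer's forward pass further down drop its separate neural-memory branch.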
```diff
@@ -296,6 +339,7 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
+        neural_mem_gate_attn_output = True,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -328,7 +372,6 @@ class MemoryAsContextTransformer(Module):
 
         self.layers = ModuleList([])
 
-        self.neural_mem_layers = ModuleList([])
         self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
@@ -343,7 +386,18 @@ class MemoryAsContextTransformer(Module):
         for layer in layers:
             is_first = layer == 1
 
-            #
+            # attention and feedforward
+
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
+                accept_value_residual = not is_first,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
 
             mem = None
 
@@ -356,21 +410,11 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
-
-
-
-
-
-
-            attn = SegmentedAttention(
-                dim = dim,
-                dim_head = dim_head,
-                heads = heads,
-                segment_len = segment_len,
-                use_flex_attn = use_flex_attn,
-                accept_value_residual = not is_first,
-                num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+            attn = NeuralMemoryGatingWrapper(
+                dim,
+                attn = attn,
+                neural_mem = mem,
+                gate_attn_output = neural_mem_gate_attn_output
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)
```
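From the outside, the visible change is the new `neural_mem_gate_attn_output` constructor flag. A hedged usage sketch follows; apart from that flag, the argument names are taken from this diff and the project README, so treat the exact call (in particular `num_tokens`, `dim` and `return_loss`) as an assumption rather than a verified 0.0.58 API.

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                     # assumed, per the project README
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_longterm_mem_tokens = 16,
    num_persist_mem_tokens = 4,
    neural_mem_gate_attn_output = True    # new in this release: gate attention output with retrieved memory
)

ids = torch.randint(0, 256, (1, 1023))

loss = transformer(ids, return_loss = True)   # return_loss assumed, per the project README
loss.backward()
```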
```diff
@@ -460,19 +504,17 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for
-
-            if exists(maybe_neural_mem):
-                x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + aux_kv_loss
+        for attn, ff in self.layers:
 
-            x, values = attn(
+            (x, values), maybe_mem_kv_aux_loss = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn
             )
 
+            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
+
             value_residual = default(value_residual, values)
 
             x = ff(x)
```
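The loop above now treats every layer uniformly: each wrapped attention block returns `((out, values), aux_loss)`, and the auxiliary losses are summed as the loop goes rather than in a separate neural-memory pass. Below is a toy sketch of that bookkeeping, with stub layers standing in for the wrapped blocks; only the return convention and the accumulation pattern are taken from the diff, everything else is illustrative.

```python
import torch
from torch import nn

class StubBlock(nn.Module):
    # mimics the wrapped attention block's return convention: ((out, values), aux_loss)
    def __init__(self, dim, has_memory):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.has_memory = has_memory

    def forward(self, x, value_residual = None):
        out, values = self.proj(x), x
        aux_loss = out.pow(2).mean() if self.has_memory else 0.   # placeholder for the kv reconstruction loss
        return (out, values), aux_loss

dim = 64
layers = nn.ModuleList([StubBlock(dim, has_memory = i % 2 == 1) for i in range(4)])

x = torch.randn(2, 16, dim)
kv_recon_losses = 0.
value_residual = None

for block in layers:
    (x, values), maybe_aux_loss = block(x, value_residual = value_residual)
    kv_recon_losses = kv_recon_losses + maybe_aux_loss            # stays 0. for memory-less layers
    value_residual = values if value_residual is None else value_residual

print(x.shape, float(kv_recon_losses))
```

As before this change, `kv_recon_losses` is defined and consumed elsewhere in `forward`; only the place where it is accumulated has moved.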
titans_pytorch-0.0.58.dist-info/RECORD

```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
+titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.58.dist-info/RECORD,,
```
titans_pytorch-0.0.56.dist-info/RECORD

```diff
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=_Vsco5YuR6uxouWcjFj-s-zPhrBcaapIzqoyi7qqY0Q,14245
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.56.dist-info/METADATA,sha256=QlCmHqajHiaZTps0W9gKXIHE6dShZER3PqPoYi2zRe4,4457
-titans_pytorch-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.56.dist-info/RECORD,,
```

{titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.58.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.56.dist-info → titans_pytorch-0.0.58.dist-info}/licenses/LICENSE
File without changes