titans-pytorch 0.0.55__py3-none-any.whl → 0.0.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +69 -24
- {titans_pytorch-0.0.55.dist-info → titans_pytorch-0.0.57.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.57.dist-info/RECORD +8 -0
- titans_pytorch-0.0.55.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.55.dist-info → titans_pytorch-0.0.57.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.55.dist-info → titans_pytorch-0.0.57.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -285,6 +285,50 @@ class SegmentedAttention(Module):
 
         return out, orig_v
 
+# Attention + Neural Memory gating configuration, as depicted in Figure 2
+
+class NeuralMemoryGatingWrapper(Module):
+    def __init__(
+        self,
+        dim,
+        attn: SegmentedAttention,
+        neural_mem: NeuralMemory | None = None
+    ):
+        super().__init__()
+        self.attn = attn
+        self.neural_mem = neural_mem
+        self.to_gates = nn.Linear(dim, dim) if exists(neural_mem) else None
+
+    def forward(
+        self,
+        seq,
+        *args,
+        **kwargs
+    ):
+        batch, seq_len = seq.shape[:2]
+        mem = self.neural_mem
+
+        if not exists(mem):
+            return self.attn(seq, *args, **kwargs), 0.
+
+        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
+
+        retrieved, first_kv_aux_loss = mem(seq, return_aux_kv_loss = True)
+
+        seq = seq + retrieved
+
+        # attention
+
+        attn_out, values = self.attn(seq, *args, **kwargs)
+
+        # another retrieve, but this time gate the attention output
+
+        retrieved, second_kv_aux_loss = mem(attn_out, return_aux_kv_loss = True)
+
+        attn_out = attn_out * self.to_gates(retrieved).sigmoid()
+
+        return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
+
 # MAC transformer
 
 class MemoryAsContextTransformer(Module):
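For orientation, here is a minimal sketch of the gating pattern the new NeuralMemoryGatingWrapper wires up: retrieve from memory, add the retrieval to the input, run attention, retrieve again from the attention output, and gate that output with a sigmoid of a learned projection. The stand-in modules below (nn.MultiheadAttention and a plain nn.Linear acting as the "memory") are illustrative placeholders, not the package's SegmentedAttention or NeuralMemory.

    import torch
    from torch import nn

    class GatedMemoryAttentionSketch(nn.Module):
        # illustrative only: mirrors the wrapper's data flow with stand-in modules
        def __init__(self, dim):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)
            self.memory = nn.Linear(dim, dim)   # placeholder "retrieval", not a NeuralMemory
            self.to_gates = nn.Linear(dim, dim)

        def forward(self, seq):
            # first retrieval, added residually to the sequence
            seq = seq + self.memory(seq)

            # attention over the memory-augmented sequence
            attn_out, _ = self.attn(seq, seq, seq)

            # second retrieval gates the attention output, as in the added class above
            gate = self.to_gates(self.memory(attn_out)).sigmoid()
            return attn_out * gate

    out = GatedMemoryAttentionSketch(dim = 64)(torch.randn(2, 16, 64))
    print(out.shape)  # torch.Size([2, 16, 64])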
@@ -328,7 +372,6 @@ class MemoryAsContextTransformer(Module):
 
         self.layers = ModuleList([])
 
-        self.neural_mem_layers = ModuleList([])
         self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
@@ -343,7 +386,18 @@ class MemoryAsContextTransformer(Module):
         for layer in layers:
             is_first = layer == 1
 
-            #
+            # attention and feedforward
+
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
+                accept_value_residual = not is_first,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
 
             mem = None
 
@@ -356,21 +410,10 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
-
-
-
-
-            # attention and feedforward
-
-            attn = SegmentedAttention(
-                dim = dim,
-                dim_head = dim_head,
-                heads = heads,
-                segment_len = segment_len,
-                use_flex_attn = use_flex_attn,
-                accept_value_residual = not is_first,
-                num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+            attn = NeuralMemoryGatingWrapper(
+                dim,
+                attn = attn,
+                neural_mem = mem,
             )
 
             ff = FeedForward(dim = dim, mult = ff_mult)
@@ -445,8 +488,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(
-
+            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
@@ -461,13 +503,16 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for
+        for attn, ff in self.layers:
 
-
-                x,
-
+            (x, values), maybe_mem_kv_aux_loss = attn(
+                x,
+                value_residual = value_residual,
+                disable_flex_attn = disable_flex_attn,
+                flex_attn_fn = flex_attn_fn
+            )
 
-
+            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
 
             value_residual = default(value_residual, values)
 
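The forward loop above accumulates each wrapped layer's auxiliary key/value reconstruction loss into kv_recon_losses. Purely as a hypothetical illustration of how such an auxiliary term is commonly folded into a training objective (the return convention and weighting below are assumptions, not this package's API):

    import torch.nn.functional as F

    aux_loss_weight = 0.1  # assumed hyperparameter, not taken from the package

    def training_step(model, tokens, labels):
        # assumed convention: the model returns logits plus the accumulated kv reconstruction loss
        logits, kv_recon_loss = model(tokens)
        main_loss = F.cross_entropy(logits.transpose(1, 2), labels)
        loss = main_loss + aux_loss_weight * kv_recon_loss
        loss.backward()
        return loss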
titans_pytorch-0.0.57.dist-info/RECORD

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
+titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.57.dist-info/RECORD,,
titans_pytorch-0.0.55.dist-info/RECORD

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=-VN8bURUaqHXH_96UqGYDhWcfgCaFdHGdM6faVuYDgQ,14159
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.55.dist-info/METADATA,sha256=VYP1B5d9tejIXr7u6ML4cSjvgIDlWYyp5KTyydlUqV8,4457
-titans_pytorch-0.0.55.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.55.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.55.dist-info/RECORD,,
{titans_pytorch-0.0.55.dist-info → titans_pytorch-0.0.57.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.55.dist-info → titans_pytorch-0.0.57.dist-info}/licenses/LICENSE
File without changes
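The RECORD entries above use the standard wheel format: path,sha256=<urlsafe base64 of the digest, padding stripped>,<size in bytes>. A small sketch of checking a file against such an entry (the local path is illustrative):

    import base64
    import hashlib
    from pathlib import Path

    def record_hash(path):
        # sha256 digest, urlsafe-base64 encoded with '=' padding stripped, as RECORD stores it
        digest = hashlib.sha256(Path(path).read_bytes()).digest()
        return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

    # compare against the 0.0.57 entry for mac_transformer.py
    expected = "sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI"
    print(record_hash("titans_pytorch/mac_transformer.py") == expected)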