titans-pytorch 0.0.54__py3-none-any.whl → 0.0.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +34 -5
- {titans_pytorch-0.0.54.dist-info → titans_pytorch-0.0.56.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.56.dist-info/RECORD +8 -0
- titans_pytorch-0.0.54.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.54.dist-info → titans_pytorch-0.0.56.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.54.dist-info → titans_pytorch-0.0.56.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
@@ -227,9 +227,10 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        disable_flex_attn = False
     ):
-        if seq.is_cuda and self.use_flex_attn:
+        if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
             return self.forward_flex(seq, value_residual, flex_attn_fn)

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -303,7 +304,8 @@ class MemoryAsContextTransformer(Module):
         num_residual_streams = 4,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0
+        aux_kv_recon_loss_weight = 0.,
+        use_flex_attn = False
     ):
         super().__init__()

@@ -336,6 +338,8 @@ class MemoryAsContextTransformer(Module):

         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'

+        # mem, attn, and feedforward layers
+
         for layer in layers:
             is_first = layer == 1

@@ -363,6 +367,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
@@ -386,11 +391,20 @@ class MemoryAsContextTransformer(Module):

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
     def forward(
         self,
         x,
         return_loss = False,
-        return_loss_breakdown = False
+        return_loss_breakdown = False,
+        disable_flex_attn = False
     ):

         if return_loss:
@@ -424,6 +438,16 @@ class MemoryAsContextTransformer(Module):

         x = x + pos_emb[:seq_len_with_mem]

+        # prep flex attention
+
+        use_flex_attn = x.is_cuda and self.use_flex_attn and not disable_flex_attn
+
+        flex_attn_fn = None
+
+        if use_flex_attn:
+            block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
         # value residual

         value_residual = None
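The hunk above builds a flex attention block mask once per forward pass and bakes it into a partially applied flex_attention, so every attention layer reuses the same mask. Below is a minimal sketch of that pattern, assuming a recent PyTorch with torch.nn.attention.flex_attention and a CUDA device (matching the assert added in the constructor); the causal mask_mod is a stand-in for illustration, not the library's create_mac_block_mask.

from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# hypothetical mask_mod - create_mac_block_mask would instead encode the segmented MAC pattern
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

seq_len, heads, dim_head = 256, 8, 64

# build the block mask once for the current sequence length ...
block_mask = create_block_mask(causal_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cuda')

# ... then hand downstream layers a partially applied attention function, as the diff does
flex_attn_fn = partial(flex_attention, block_mask = block_mask)

q = k = v = torch.randn(1, heads, seq_len, dim_head, device = 'cuda')
out = flex_attn_fn(q, k, v)  # (1, heads, seq_len, dim_head)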
@@ -442,7 +466,12 @@ class MemoryAsContextTransformer(Module):
             x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
             kv_recon_losses = kv_recon_losses + aux_kv_loss

-            x, values = attn(x, value_residual = value_residual)
+            x, values = attn(
+                x,
+                value_residual = value_residual,
+                disable_flex_attn = disable_flex_attn,
+                flex_attn_fn = flex_attn_fn
+            )

             value_residual = default(value_residual, values)

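Taken together, the mac_transformer.py changes make flex attention an opt-in: pass use_flex_attn = True at construction, and pass disable_flex_attn = True on an individual forward call to fall back to the regular attention path. A hedged usage sketch follows; num_tokens, dim, and depth are existing constructor arguments not shown in this diff and are assumed unchanged, and a CUDA device is assumed as required by the new assert.

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,              # assumed pre-existing args, not part of this diff
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    use_flex_attn = True           # new in 0.0.56: opt in to flex attention
).cuda()

ids = torch.randint(0, 256, (1, 1023)).cuda()

# flex attention path is taken when on CUDA, use_flex_attn is set, and it is not disabled
loss = transformer(ids, return_loss = True)

# per-call escape hatch added in 0.0.56: force the regular attention path
loss = transformer(ids, return_loss = True, disable_flex_attn = True)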
titans_pytorch-0.0.56.dist-info/RECORD
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=_Vsco5YuR6uxouWcjFj-s-zPhrBcaapIzqoyi7qqY0Q,14245
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.56.dist-info/METADATA,sha256=QlCmHqajHiaZTps0W9gKXIHE6dShZER3PqPoYi2zRe4,4457
+titans_pytorch-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.56.dist-info/RECORD,,
titans_pytorch-0.0.54.dist-info/RECORD
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=GdUAYq6MDRGY0l2ESBH_kM01AEzVztiKmWfblSKxBEM,13212
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.54.dist-info/METADATA,sha256=bxbC3NBO4Sjii7DpFPcmNsO9M1kX76vj947H2DeUceg,4457
-titans_pytorch-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.54.dist-info/RECORD,,
{titans_pytorch-0.0.54.dist-info → titans_pytorch-0.0.56.dist-info}/WHEEL: file without changes
{titans_pytorch-0.0.54.dist-info → titans_pytorch-0.0.56.dist-info}/licenses/LICENSE: file without changes