titans-pytorch 0.0.57__py3-none-any.whl → 0.0.58__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- titans_pytorch/mac_transformer.py +11 -10
- {titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.58.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.58.dist-info/RECORD +8 -0
- titans_pytorch-0.0.57.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.58.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.58.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -292,12 +292,13 @@ class NeuralMemoryGatingWrapper(Module):
         self,
         dim,
         attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None
+        neural_mem: NeuralMemory | None = None,
+        gate_attn_output = True
     ):
         super().__init__()
         self.attn = attn
         self.neural_mem = neural_mem
-        self.
+        self.gate_attn_output = gate_attn_output

     def forward(
         self,
@@ -313,21 +314,19 @@ class NeuralMemoryGatingWrapper(Module):

         # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory

-        retrieved,
+        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)

-
+        if not self.gate_attn_output:
+            seq = seq + retrieved

         # attention

         attn_out, values = self.attn(seq, *args, **kwargs)

-
+        if self.gate_attn_output:
+            attn_out = attn_out * retrieved.sigmoid()

-
-
-        attn_out = attn_out * self.to_gates(retrieved).sigmoid()
-
-        return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
+        return (attn_out, values), kv_aux_loss

 # MAC transformer

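The net effect of this hunk: 0.0.57 summed two auxiliary kv losses (first_kv_aux_loss + second_kv_aux_loss) and gated through a learned to_gates projection, while 0.0.58 performs a single retrieval, returns its aux loss directly, and selects one of two wirings via the new flag. Below is a minimal sketch of that control flow, using stub modules in place of the package's NeuralMemory and SegmentedAttention; only the mem(seq, return_aux_kv_loss = True) call shape and the two branches are taken from the diff, while the stub internals are invented for illustration:

    import torch
    from torch import nn

    class StubMemory(nn.Module):
        # stand-in for NeuralMemory; the real retrieval and aux loss are not shown here
        def forward(self, seq, return_aux_kv_loss = False):
            retrieved = torch.tanh(seq)           # placeholder retrieval
            aux_loss = retrieved.square().mean()  # placeholder aux kv loss
            return (retrieved, aux_loss) if return_aux_kv_loss else retrieved

    class StubAttention(nn.Module):
        # stand-in for SegmentedAttention; returns (attn_out, values) as the diff expects
        def forward(self, seq):
            return seq, seq

    def gated_forward(seq, attn, mem, gate_attn_output = True):
        # single retrieval; its aux loss is now returned directly
        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)

        if not gate_attn_output:
            # alternative path: add retrieved memories into the residual stream
            seq = seq + retrieved

        attn_out, values = attn(seq)

        if gate_attn_output:
            # default path: gate the attention output by sigmoid of the retrieved memories
            attn_out = attn_out * retrieved.sigmoid()

        return (attn_out, values), kv_aux_loss

    (out, values), loss = gated_forward(torch.randn(1, 16, 64), StubAttention(), StubMemory())

So with gate_attn_output = True (the default) the retrieved memories act as a sigmoid gate on the attention output; with False they are instead added to the sequence before attention.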
@@ -340,6 +339,7 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
+        neural_mem_gate_attn_output = True,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -414,6 +414,7 @@ class MemoryAsContextTransformer(Module):
             dim,
             attn = attn,
             neural_mem = mem,
+            gate_attn_output = neural_mem_gate_attn_output
         )

         ff = FeedForward(dim = dim, mult = ff_mult)
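For completeness, a hypothetical instantiation that flips the new flag. num_tokens and dim do not appear in this diff and are assumed, as are all of the specific values; only the parameter names shown in the hunks above come from the source:

    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,                    # assumed; not visible in this diff
        dim = 64,                            # assumed; not visible in this diff
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        neural_mem_gate_attn_output = False  # new in 0.0.58; defaults to True
    )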
titans_pytorch-0.0.58.dist-info/RECORD

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
+titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.58.dist-info/RECORD,,
titans_pytorch-0.0.57.dist-info/RECORD

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
-titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
-titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
-titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.57.dist-info/RECORD,,
{titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.58.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.57.dist-info → titans_pytorch-0.0.58.dist-info}/licenses/LICENSE
File without changes