titans-pytorch 0.0.57__py3-none-any.whl → 0.0.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -292,12 +292,13 @@ class NeuralMemoryGatingWrapper(Module):
292
292
  self,
293
293
  dim,
294
294
  attn: SegmentedAttention,
295
- neural_mem: NeuralMemory | None = None
295
+ neural_mem: NeuralMemory | None = None,
296
+ gate_attn_output = True
296
297
  ):
297
298
  super().__init__()
298
299
  self.attn = attn
299
300
  self.neural_mem = neural_mem
300
- self.to_gates = nn.Linear(dim, dim) if exists(neural_mem) else None
301
+ self.gate_attn_output = gate_attn_output
301
302
 
302
303
  def forward(
303
304
  self,
@@ -313,21 +314,19 @@ class NeuralMemoryGatingWrapper(Module):
313
314
 
314
315
  # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
315
316
 
316
- retrieved, first_kv_aux_loss = mem(seq, return_aux_kv_loss = True)
317
+ retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
317
318
 
318
- seq = seq + retrieved
319
+ if not self.gate_attn_output:
320
+ seq = seq + retrieved
319
321
 
320
322
  # attention
321
323
 
322
324
  attn_out, values = self.attn(seq, *args, **kwargs)
323
325
 
324
- # another retrieve, but this time gate the attention output
326
+ if self.gate_attn_output:
327
+ attn_out = attn_out * retrieved.sigmoid()
325
328
 
326
- retrieved, second_kv_aux_loss = mem(attn_out, return_aux_kv_loss = True)
327
-
328
- attn_out = attn_out * self.to_gates(retrieved).sigmoid()
329
-
330
- return (attn_out, values), (first_kv_aux_loss + second_kv_aux_loss)
329
+ return (attn_out, values), kv_aux_loss
331
330
 
332
331
  # MAC transformer
333
332
 
@@ -340,6 +339,7 @@ class MemoryAsContextTransformer(Module):
340
339
  depth,
341
340
  segment_len,
342
341
  neural_memory_segment_len = None,
342
+ neural_mem_gate_attn_output = True,
343
343
  num_longterm_mem_tokens = 0,
344
344
  num_persist_mem_tokens = 0,
345
345
  dim_head = 64,
@@ -414,6 +414,7 @@ class MemoryAsContextTransformer(Module):
414
414
  dim,
415
415
  attn = attn,
416
416
  neural_mem = mem,
417
+ gate_attn_output = neural_mem_gate_attn_output
417
418
  )
418
419
 
419
420
  ff = FeedForward(dim = dim, mult = ff_mult)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.57
3
+ Version: 0.0.58
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=kk8s8Q2WmbJxCVi8PcqSUyJBc8-CDAHrVjt6M0d_kFs,15323
4
+ titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
5
+ titans_pytorch-0.0.58.dist-info/METADATA,sha256=a-Y6MV_89D44HlB7eKpurh-sw5DDiS-pIVei3Uw_uGE,4457
6
+ titans_pytorch-0.0.58.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.58.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.58.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=-WS_SI--_5f1whMaJOH-mYCU37EjYU_iZTurGfs8zgI,15331
4
- titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
5
- titans_pytorch-0.0.57.dist-info/METADATA,sha256=rwLIRndtBo22oJt0Xm9xK9zqOYV50Jfo6g7oVrKq7CQ,4457
6
- titans_pytorch-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.57.dist-info/RECORD,,