titans-pytorch 0.1.14__tar.gz → 0.1.15__tar.gz

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.14
+Version: 0.1.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.7
+Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.14"
+version = "0.1.15"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,7 +26,7 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
-    "axial_positional_embedding>=0.3.7",
+    "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
     "hyper-connections>=0.1.8",
@@ -301,6 +301,45 @@ class MemoryAttention(Module):
 
         return out
 
+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
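Both branches of the new AssocScan.forward compute the same first-order linear recurrence, out[t] = gates[t] * out[t-1] + inputs[t]; the padding to the next power of two (minimum 2**5 = 32) accommodates the fixed-length requirements of the accelerated-scan kernels. A minimal pure-PyTorch reference with the same semantics, as a standalone sketch (the sequential loop and names here are illustrative, not code from the package):

    import torch

    def reference_scan(gates, inputs):
        # out[t] = gates[t] * out[t - 1] + inputs[t], with out[-1] = 0
        # gates, inputs: (batch, seq_len, dim)
        out = torch.zeros_like(inputs[:, 0])
        outs = []
        for t in range(inputs.shape[1]):
            out = gates[:, t] * out + inputs[:, t]
            outs.append(out)
        return torch.stack(outs, dim = 1)

    gates = torch.rand(2, 8, 16)        # gates in [0, 1], e.g. momentum or decay factors
    inputs = torch.randn(2, 8, 16)      # e.g. per-timestep surprises
    out = reference_scan(gates, inputs) # (2, 8, 16)

The associative scan produces the same outputs in O(log n) parallel steps, since the recurrence's combine operation, ((a1, b1), (a2, b2)) -> (a1 * a2, a2 * b1 + b2), is associative.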
@@ -339,6 +378,10 @@ class NeuralMemory(Module):
 
         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
 
+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
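With the wrapper constructed once in __init__, the choice of scan variant moves from per-call control flow (removed in the hunk below) to module configuration. A hypothetical usage sketch, assuming AssocScan and its module-level dependencies (associative_scan, binary_operator) are in scope as in this file:

    scan = AssocScan(use_accelerated = False)

    gates = torch.rand(1, 4, 8)   # (batch, seq_len, dim)
    inputs = torch.randn(1, 4, 8)

    out = scan(gates, inputs)     # (1, 4, 8): out[:, t] = gates[:, t] * out[:, t - 1] + inputs[:, t]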
@@ -564,38 +607,6 @@ class NeuralMemory(Module):
 
         surprises = grads.apply(lambda t: -t)
 
-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict() if has_momentum else None
@@ -610,12 +621,12 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
-                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
                 momentum = update
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update)
 
             updates[param_name] = inverse_pack(update)
 
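For reference, the equations cited in the comments above, as given in the Titans paper (Behrouz et al., 2024), are the momentum/surprise recurrence (eq. 10) and the forgetting update (eq. 13):

    S_t = \eta_t S_{t-1} - \theta_t \nabla\ell(M_{t-1}; x_t)    (10)
    M_t = (1 - \alpha_t) M_{t-1} + S_t                          (13)

Per the code comments, \eta_t appears to correspond to adaptive_momentum, \theta_t to the adaptive step size folded into the surprise, and \alpha_t to decay_factor. Both equations are affine recurrences of the form out_t = gate_t * out_{t-1} + in_t, which is why the same AssocScan module is applied twice in the hunk above.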