titans-pytorch 0.3.21.tar.gz → 0.3.22.tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.21
+Version: 0.3.22
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.21"
+version = "0.3.22"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
+@pytest.mark.parametrize('num_kv_per_token', (1, 2))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 @pytest.mark.parametrize('per_head_learned_parameters', (False, True))
 def test_titans(
@@ -40,6 +41,7 @@ def test_titans(
     momentum,
     qk_rmsnorm,
     max_grad_norm,
+    num_kv_per_token,
     per_parameter_lr_modulation,
     per_head_learned_parameters
 ):
@@ -49,6 +51,7 @@ def test_titans(
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
+        num_kv_per_token = num_kv_per_token,
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
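The test matrix now also sweeps num_kv_per_token over (1, 2) and forwards it to the module under test. A minimal usage sketch of the new argument, patterned on the README-style NeuralMemory example (dim, chunk_size and the tensor sizes below are illustrative assumptions, not values taken from the test):

    import torch
    from titans_pytorch import NeuralMemory

    # num_kv_per_token = 2 lets every token write two key / value pairs
    # into the memory model per update step; other settings are assumed
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        num_kv_per_token = 2,
    )

    seq = torch.randn(2, 1024, 384)
    retrieved, state = mem(seq)
    assert retrieved.shape == seq.shape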
@@ -35,6 +35,7 @@ d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
 o - momentum orders
+u - key / value updates - allowing a token to emit multiple key / values
 """

 LinearNoBias = partial(Linear, bias = False)
@@ -231,6 +232,7 @@ class NeuralMemory(Module):
         momentum_order = 1,
         learned_momentum_combine = False,
         learned_combine_include_zeroth = False,
+        num_kv_per_token = 1, # whether a single token can do multiple updates to the memory model
         qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
         pre_rmsnorm = True,
         post_rmsnorm = False,
@@ -363,11 +365,22 @@ class NeuralMemory(Module):

         # keys and values for storing to the model

-        self.to_keys = Sequential(LinearNoBias(dim, dim_inner), activation)
-        self.to_values = Sequential(LinearNoBias(dim, dim_inner), activation)
+        assert num_kv_per_token > 0
+
+        self.to_keys = Sequential(
+            LinearNoBias(dim, dim_inner * num_kv_per_token),
+            activation,
+        )
+
+        self.to_values = Sequential(
+            LinearNoBias(dim, dim_inner * num_kv_per_token),
+            activation,
+        )

         self.store_memory_loss_fn = store_memory_loss_fn

+        self.num_kv_per_token = num_kv_per_token
+
         # `chunk_size` refers to chunk size used for storing to memory model weights

         chunk_size = self.store_chunk_size
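The key / value projections are widened from dim_inner to dim_inner * num_kv_per_token, so one token now emits num_kv_per_token keys and values. A standalone shape sketch of that widening (the dim and dim_inner values are assumptions for illustration):

    import torch
    from torch import nn

    dim, dim_inner, num_kv_per_token = 384, 256, 2

    # same construction as to_keys above: a bias-free linear, widened by num_kv_per_token
    to_keys = nn.Linear(dim, dim_inner * num_kv_per_token, bias = False)

    tokens = torch.randn(2, 1024, dim)
    keys = to_keys(tokens)
    assert keys.shape == (2, 1024, dim_inner * num_kv_per_token)  # u keys' worth of features per token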
@@ -384,8 +397,8 @@ class NeuralMemory(Module):
         # learned adaptive learning rate

         self.to_adaptive_step = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n')
+            nn.Linear(dim, heads * num_kv_per_token),
+            Rearrange('b n (h u) -> (b h) (n u)', u = num_kv_per_token)
         )

         if not exists(adaptive_step_transform):
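to_adaptive_step correspondingly predicts heads * num_kv_per_token step sizes per token, and the new Rearrange groups them per head as n * u values. A quick check of that einops pattern on dummy shapes (all dimensions are assumed):

    import torch
    from einops.layers.torch import Rearrange

    b, n, heads, u = 2, 6, 4, 2

    # stand-in for the output of nn.Linear(dim, heads * num_kv_per_token)
    x = torch.randn(b, n, heads * u)

    out = Rearrange('b n (h u) -> (b h) (n u)', u = u)(x)
    assert out.shape == (b * heads, n * u)  # one adaptive lr per key / value update, per head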
@@ -518,7 +531,7 @@ class NeuralMemory(Module):

         # shapes and variables

-        heads, chunk_size = self.heads, self.store_chunk_size
+        heads, chunk_size, num_updates = self.heads, self.store_chunk_size, self.num_kv_per_token

         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
@@ -587,15 +600,17 @@ class NeuralMemory(Module):

         batch = keys.shape[0]

+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b h (n c) (u d) -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
+
         # maybe qk rmsnorm

         keys = self.k_norm(keys)

-        # take care of chunking
-
-        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
+        # adaptive lr

-        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)

         # maybe add previous layer weight

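When storing, the widened key / value features are unpacked so that each chunk holds chunk_size * num_kv_per_token rows, and the adaptive learning rates are reshaped to line up with those rows. A standalone shape check of the two new rearranges (all sizes below are assumed):

    import torch
    from einops import rearrange

    b, h, n, c, u, d = 2, 4, 3, 8, 2, 16  # batch, heads, chunks, chunk size, kv per token, head dim

    keys = torch.randn(b, h, n * c, u * d)  # per-head keys before chunking (assumed layout)
    keys = rearrange(keys, 'b h (n c) (u d) -> (b h n) (c u) d', c = c, u = u)
    assert keys.shape == (b * h * n, c * u, d)  # u extra rows per token inside each chunk

    adaptive_lr = torch.randn(b * h, n * c * u)  # one step size per key / value update
    adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = c, u = u)
    assert adaptive_lr.shape == (b * h * n, c * u)  # matches the rows of keys / values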