titans-pytorch 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
titans_pytorch/associative_scan.py

@@ -3,18 +3,39 @@ from typing import Callable
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 import torch.nn.functional as F
 
+from einops import rearrange, repeat, reduce, pack, unpack
+
 # taken from S5-pytorch repository
 # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
 
 # helper functions
 
+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
 def pad_at_dim(t, pad, dim = -1, value = 0.):
     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def pack_one_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)[0]
+
+    return packed, inverse
+
 # the operator that is needed
 
 @torch.jit.script
@@ -88,3 +109,69 @@ def _interleave(a, b):
     interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
 
     return interleaved[:, :output_axis_len]
+
+# associative scan wrapper around naive and accelerated version
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))
+
+        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
+        gates, _ = pack_one_with_inverse(gates, 'b n *')
+
+        if exists(prev):
+            prev, _ = pack_one_with_inverse(prev, 'b *')
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
+
+        if not self.use_accelerated:
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return inverse_pack_weight_shape(out)
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+
+            return outputs
+
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return inverse_pack_weight_shape(out)
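For orientation, the `AssocScan` added above computes the first-order linear recurrence out[t] = gates[t] * out[t-1] + inputs[t] along the sequence dimension, using the naive `associative_scan` unless `use_accelerated` is set, in which case it dispatches to the `accelerated_scan` kernels. A minimal usage sketch (the shapes and values below are illustrative, not taken from the package):

    import torch
    from titans_pytorch.associative_scan import AssocScan

    scan = AssocScan(use_accelerated = False)

    gates = torch.rand(2, 16, 64)     # (batch, seq, dim) decay gates in [0, 1]
    inputs = torch.randn(2, 16, 64)

    out = scan(gates, inputs)         # (2, 16, 64)

    # seeding with a previous state prepends it to the scan; remove_prev defaults
    # to True when prev is given, so the output keeps the original sequence length
    prev = torch.randn(2, 64)
    out = scan(gates, inputs, prev = prev)   # still (2, 16, 64)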
titans_pytorch/memory_models.py

@@ -30,6 +30,25 @@ class LayerNorm(Module):
 
         return self.ln(x) * (gamma + 1.)
 
+# norm + residual wrapper, as used in original TTT paper
+# but could be removed
+
+class ResidualNorm(Module):
+    def __init__(
+        self,
+        dim,
+        model: Module
+    ):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        self.model = model
+
+    def forward(self, x):
+
+        out = self.model(x)
+
+        return self.norm(out) + x
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
@@ -45,8 +64,6 @@ class MemoryMLP(Module):
 
         self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -54,8 +71,6 @@ class MemoryMLP(Module):
         self,
         x
     ):
-        residual = x
-
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
@@ -64,7 +79,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return self.ln(x) + residual
+        return x
 
 # memory mlp, but with gated residual + final projection
 
@@ -97,7 +112,6 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for weight1, weight2, to_gates in self.weights:
             res = x
@@ -111,9 +125,7 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        out = x @ self.final_proj
-
-        return self.ln(out) + residual
+        return x @ self.final_proj
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -143,7 +155,6 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
@@ -153,7 +164,7 @@ class FactorizedMemoryMLP(Module):
 
             x = x @ weight1 @ weight2
 
-        return self.ln(x) + residual
+        return x
 
 # improvised attention as memory module
 
@@ -182,7 +193,6 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
-        residual = x
 
         wq, wk, wv, ffw1, ffw2 = self.weights
 
@@ -202,4 +212,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return self.ln(attn_out + ff_out) + residual
+        return attn_out + ff_out
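The `self.ln(...) + residual` logic deleted from each memory model above is now applied once by the new `ResidualNorm` wrapper (used in `NeuralMemory` when `mem_model_norm_add_residual = True`, per the hunks below). A small sketch of wrapping a memory model directly; the `MemoryMLP(dim = ..., depth = ...)` constructor signature is assumed here:

    import torch
    from titans_pytorch.memory_models import MemoryMLP, ResidualNorm

    mem = ResidualNorm(dim = 64, model = MemoryMLP(dim = 64, depth = 2))

    x = torch.randn(2, 16, 64)
    out = mem(x)                  # LayerNorm(MemoryMLP(x)) + x
    assert out.shape == x.shape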
titans_pytorch/neural_memory.py

@@ -8,19 +8,16 @@ from collections import namedtuple
 import torch
 from torch import nn, cat, tensor, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module, Parameter, ParameterList
+from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
 from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
 
-from titans_pytorch.associative_scan import (
-    associative_scan,
-    binary_operator,
-    pad_at_dim
-)
+from titans_pytorch.associative_scan import AssocScan
 
 from titans_pytorch.memory_models import(
-    MemoryMLP
+    MemoryMLP,
+    ResidualNorm
 )
 
 import einx
@@ -79,8 +76,8 @@ def safe_cat(inputs, dim = -2):
 def is_empty_tensor(t):
     return t.numel() == 0
 
-def dict_get_shape(td):
-    return {k: v.shape for k, v in td.items()}
+def dict_get_value_shapes(td):
+    return [v.shape for k, v in td.items()]
 
 def rearrange_dict_values(td, pattern, **kwargs):
     return td.apply(lambda t: rearrange(t, pattern, **kwargs))
@@ -97,6 +94,11 @@ def round_down_multiple(seq, mult):
 def round_up_multiple(seq, mult):
     return math.ceil(seq / mult) * mult
 
+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
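`pad_at_dim` now also lives in this file; it turns the target `dim` into the flat (left, right) padding tuple that `F.pad` expects. A quick illustrative check, reusing the definition from the hunk above:

    import torch
    import torch.nn.functional as F

    def pad_at_dim(t, pad, dim = -1, value = 0.):
        dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = ((0, 0) * dims_from_right)
        return F.pad(t, (*zeros, *pad), value = value)

    t = torch.randn(2, 5, 3)
    padded = pad_at_dim(t, (1, 0), dim = -2, value = 1.)   # one slot at the front of dim -2
    assert padded.shape == (2, 6, 3)
    assert (padded[:, 0] == 1.).all()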
@@ -197,72 +199,6 @@ class AttentionPool(Module):
 
         return reduce(x * attn, 'b n c d -> b n d', 'sum')
 
-# associative scan wrapper
-
-class AssocScan(Module):
-    def __init__(
-        self,
-        use_accelerated = False
-    ):
-        super().__init__()
-        self.use_accelerated = use_accelerated
-
-    def forward(
-        self,
-        gates,
-        inputs,
-        prev = None,
-        remove_prev = None
-    ):
-        remove_prev = default(remove_prev, exists(prev))
-
-        inputs, inverse_pack_weight_shape = pack_one_with_inverse(inputs, 'b n *')
-        gates, _ = pack_one_with_inverse(gates, 'b n *')
-
-        if exists(prev):
-            prev, _ = pack_one_with_inverse(prev, 'b *')
-
-        if exists(prev):
-            inputs, _ = pack([prev, inputs], 'b * d')
-            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
-
-        if not self.use_accelerated:
-            _, out = associative_scan(binary_operator, (gates, inputs))
-
-            if remove_prev:
-                out = out[:, 1:]
-
-            return inverse_pack_weight_shape(out)
-
-        from accelerated_scan.triton import scan as triton_scan
-        from accelerated_scan.warp import scan as warp_scan
-
-        scan = triton_scan if gates.is_cuda else warp_scan
-
-        def accelerate_scan_fn(gates, inputs):
-            gates = gates.expand_as(inputs)
-            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-            seq_len = gates.shape[-1]
-            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-            outputs = scan(gates.contiguous(), inputs.contiguous())
-
-            outputs = outputs[..., :seq_len]
-            outputs = rearrange(outputs, 'b d n -> b n d')
-
-            return outputs
-
-        out = accelerate_scan_fn(gates, inputs)
-
-        if remove_prev:
-            out = out[:, 1:]
-
-        return inverse_pack_weight_shape(out)
-
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
@@ -285,6 +221,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1., # max of 10.
+        per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
298
235
  init_decay_bias = None,
299
236
  accept_weight_residual = False,
300
237
  gated_transition = False,
238
+ mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
301
239
  default_model_kwargs: dict = dict(
302
240
  depth = 2,
303
241
  expansion_factor = 4.
@@ -368,11 +306,26 @@ class NeuralMemory(Module):
 
         # the memory is the weights of the model
 
+        if mem_model_norm_add_residual:
+            model = ResidualNorm(dim = dim_head, model = model)
+
         self.memory_model = model
 
-        self.num_memory_parameter_tensors = len(set(model.parameters()))
+        mem_model_params = dict(model.named_parameters())
+
+        self.num_memory_parameter_tensors = len(mem_model_params)
 
-        self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+        self.memory_model_parameter_names = [*mem_model_params.keys()]
+
+        memory_model_parameters = [*mem_model_params.values()]
+
+        if per_head_learned_parameters:
+            memory_model_parameters = [repeat(p, '... -> h ...', h = heads) for p in memory_model_parameters]
+
+        self.init_weight_shape = [p.shape for p in memory_model_parameters]
+
+        self.memory_model_parameters = ParameterList(memory_model_parameters)
+        self.per_head_learned_parameters = per_head_learned_parameters
 
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
 
@@ -488,21 +441,32 @@ class NeuralMemory(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
+    @property
+    def memory_model_parameter_dict(self):
+        return TensorDict(dict(zip(self.memory_model_parameter_names, self.memory_model_parameters)))
+
     def init_weights(
         self,
         batch,
     ):
-        weights = TensorDict(dict(self.memory_model.named_parameters()))
-        weights = repeat_dict_values(weights, '... -> bh ...', bh = batch * self.heads)
+        if self.per_head_learned_parameters:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, 'h ... -> (b h) ...', b = batch)
+        else:
+            weights = repeat_dict_values(self.memory_model_parameter_dict, '... -> bh ...', bh = batch * self.heads)
+
         return weights
 
     def init_momentum(
         self,
         batch,
     ):
-        weights = TensorDict(dict(self.memory_model.named_parameters()))
-        zeros = weights.clone().zero_()
-        zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+        zeros = self.memory_model_parameter_dict.clone().zero_()
+
+        if self.per_head_learned_parameters:
+            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+        else:
+            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+
         return zeros
 
     def store_memories(
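The shape bookkeeping behind the two `per_head_learned_parameters` branches above, sketched with `einops.repeat` on a dummy weight (the head, batch, and dim values are illustrative):

    import torch
    from einops import repeat

    heads, batch = 4, 2
    w = torch.randn(64, 64)                                     # one memory weight matrix

    # per-head: a distinct learned copy per head, tiled over the batch
    per_head = repeat(w, '... -> h ...', h = heads)             # (4, 64, 64)
    flat = repeat(per_head, 'h ... -> (b h) ...', b = batch)    # (8, 64, 64)

    # shared: a single learned copy broadcast to every (batch, head) slot
    shared = repeat(w, '... -> bh ...', bh = batch * heads)     # (8, 64, 64)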
@@ -694,7 +658,7 @@ class NeuralMemory(Module):
     ):
         chunk_size = self.retrieve_chunk_size
 
-        weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
+        weights_have_expanded_shape = dict_get_value_shapes(weights) != self.init_weight_shape
 
         batch, seq_len = seq.shape[:2]
 
titans_pytorch-0.3.3.dist-info/METADATA → titans_pytorch-0.3.5.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.3
+Version: 0.3.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.3.5.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=2fma9u0NQmDabgbpG6CLDGBRYzX99yIDQCSYIB0etkU,4989
+titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
+titans_pytorch-0.3.5.dist-info/METADATA,sha256=R6EL4q-zgW7DV5OyLzqz5XP2IvLNpJkaBylwH8GsyII,6815
+titans_pytorch-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.5.dist-info/RECORD,,
titans_pytorch-0.3.3.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
-titans_pytorch/neural_memory.py,sha256=Ff-IBv-CCQAP7IYIpokPDoGtsvpzotAJsHB1d_-xd98,27934
-titans_pytorch-0.3.3.dist-info/METADATA,sha256=CutjohW8xSNycd5W-uyXC4827ubmIpAJCs9xoMbfZzo,6815
-titans_pytorch-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.3.dist-info/RECORD,,