titans-pytorch 0.0.30__tar.gz → 0.0.32__tar.gz

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.30
+Version: 0.0.32
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.30"
+version = "0.0.32"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
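
The hunk above pulls the proposed NeuralMemory module into the MAC transformer file. For orientation, a standalone call looks roughly like the sketch below; the tensor sizes are illustrative assumptions, and only the dim and chunk_size keyword arguments are actually shown in this diff. The transformer code later in the diff relies on the output keeping the input shape.

# hedged sketch of using NeuralMemory on its own; shapes are illustrative assumptions
import torch
from titans_pytorch.titans import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)   # same two kwargs the transformer passes below
seq = torch.randn(2, 1024, 384)                  # (batch, sequence length, dim)
retrieved = mem(seq)                             # assumed to return a tensor of the same shape
print(retrieved.shape)                           # torch.Size([2, 1024, 384])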
@@ -29,9 +33,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
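
The new pad_and_segment_with_inverse helper factors out the pad-to-multiple and fold-windows-into-batch logic that was previously repeated inline, returning the segmented tensor together with a closure that undoes the operation. A minimal shape check, assuming the helper ends up importable from titans_pytorch.mac_transformer (the file path is not shown in this diff):

# minimal shape check for the helper added above; the module path is an assumption
import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

seq = torch.randn(2, 100, 512)                              # (batch, length, dim)
segments, inverse = pad_and_segment_with_inverse(seq, 64)   # pads 100 -> 128, folds 2 windows into batch
print(segments.shape)                                       # torch.Size([4, 64, 512])
print(inverse(segments).shape)                              # torch.Size([2, 100, 512])

One caveat as written: when seq_len is already an exact multiple of segment_len, padding is 0 and inverse slices out[:, :-0], which is an empty tensor, so exact-multiple inputs are not round-tripped by this version of the helper.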
@@ -77,9 +105,6 @@ class SegmentedAttention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
@@ -91,16 +116,7 @@ class SegmentedAttention(Module):
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= total_segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, total_segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -130,10 +146,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-        if need_segment:
-            out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out[:, :seq_len]
+        return out
 
 # MAC transformer
 
@@ -150,7 +165,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
@@ -170,8 +187,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
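
With the hunk above, the constructor builds a NeuralMemory branch, wrapped in a hyper connection, for every layer listed in neural_memory_layers (defaulting to all layers), and only when num_longterm_mem_tokens > 0. A hedged construction sketch follows; apart from num_residual_streams, neural_memory_kwargs and neural_memory_layers, the keyword arguments are assumptions inferred from how they are used elsewhere in the changed file, and the module path is likewise assumed.

# hedged construction sketch; most argument names are inferred, not shown in this diff
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                 # assumed vocabulary-size argument
    dim = 256,
    depth = 4,
    segment_len = 128,
    num_longterm_mem_tokens = 16,     # must be > 0 for any NeuralMemory to be created
    num_persist_mem_tokens = 4,
    neural_memory_layers = (2, 4),    # only layers 2 and 4 receive a NeuralMemory branch
)

Note that in the hunk shown, neural_memory_kwargs is accepted by the constructor but not yet forwarded into the NeuralMemory instantiation.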
@@ -207,40 +241,54 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-        need_segment = seq_len >= segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                x = F.pad(x, (0, 0, 0, padding))
-
-            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = torch.cat((mems, x), dim = -2)
+        x = cat((mems, x), dim = -2)
 
-        if need_segment:
-            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
-            x = x[:, :seq_len]
+        x = inverse_segment(x)
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:seq_len]
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
             x = ff(x)
 
         x = self.reduce_streams(x)
 
+        # excise out the memories
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+        x = x[:, num_longterm_mem_tokens:]
+
+        x = inverse_segment(x)
+
 
         # to logits
         x = self.norm(x)
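
In the reworked forward pass, each window now leads with num_longterm_mem_tokens memory tokens followed by segment_len ordinary tokens; per layer, those memory tokens are sliced off, gathered across windows, passed through that layer's NeuralMemory when one exists, and concatenated back, and after the layer stack they are excised before the logits. The toy sketch below mirrors only the slice-and-regroup pattern from the loop; all shapes are illustrative assumptions.

# toy illustration of the per-window memory regrouping used in the layer loop above
import torch
from einops import rearrange

batch, windows, segment_len, num_longterm_mem_tokens, dim = 2, 3, 16, 4, 8

# after pad_and_segment_with_inverse, x is (batch * windows, mem tokens + segment, dim)
x = torch.randn(batch * windows, num_longterm_mem_tokens + segment_len, dim)

mems, rest = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
mems = rearrange(mems, '(b w) n d -> b (w n) d', b = batch)   # all windows' memories, per batch element
# ... mems would be passed through the layer's NeuralMemory here ...
mems = rearrange(mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
x = torch.cat((mems, rest), dim = -2)                         # prepend processed memories back per window
print(x.shape)                                                # torch.Size([6, 20, 8])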