titans-pytorch 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py
@@ -29,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
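
The new pad_and_segment_with_inverse helper factors out the pad-to-multiple and window-segmentation logic that the hunks below remove from SegmentedAttention and MemoryAsContextTransformer, returning the segmented tensor together with a closure that undoes the segmentation. A minimal usage sketch follows, assuming the helper can be imported from titans_pytorch.mac_transformer as defined above; the shapes are illustrative only.

import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

x = torch.randn(2, 70, 512)                         # (batch, seq_len, dim); 70 is not a multiple of 32

segments, inverse = pad_and_segment_with_inverse(x, 32)
print(segments.shape)                               # torch.Size([6, 32, 512]): padded to 96, 3 windows per batch element

restored = inverse(segments)                        # merge the windows back and strip the padding
print(restored.shape)                               # torch.Size([2, 70, 512])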
@@ -77,9 +101,6 @@ class SegmentedAttention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
@@ -91,16 +112,7 @@ class SegmentedAttention(Module):
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= total_segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, total_segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -130,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-        if need_segment:
-            out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out[:, :seq_len]
+        return out
 
 # MAC transformer
 
@@ -207,29 +218,18 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-        need_segment = seq_len >= segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                x = F.pad(x, (0, 0, 0, padding))
-
-            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x = torch.cat((mems, x), dim = -2)
 
-        if need_segment:
-            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
-            x = x[:, :seq_len]
+        x = inverse_segment(x)
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:seq_len]
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
@@ -241,6 +241,14 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # excise out the memories
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+        x = x[:, self.num_longterm_mem_tokens:]
+
+        x = inverse_segment(x)
+
         # to logits
 
         x = self.norm(x)
titans_pytorch-0.0.30.dist-info/METADATA → titans_pytorch-0.0.31.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.30
+Version: 0.0.31
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.0.30.dist-info/RECORD → titans_pytorch-0.0.31.dist-info/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=XDLc9NWXpVMza03XjU0lkw5lRvtJ25ReTPKNoGslOOk,6773
+titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.30.dist-info/METADATA,sha256=7H6WPsgfBE9ByEUP7r6C-cfBX6K13yNCaVhZMzwUvf8,3938
-titans_pytorch-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.30.dist-info/RECORD,,
+titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
+titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.31.dist-info/RECORD,,