titans-pytorch 0.0.30__py3-none-any.whl → 0.0.31__py3-none-any.whl


@@ -29,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
@@ -77,9 +101,6 @@ class SegmentedAttention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
@@ -91,16 +112,7 @@ class SegmentedAttention(Module):
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= total_segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, total_segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -130,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-        if need_segment:
-            out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out[:, :seq_len]
+        return out
 
 # MAC transformer
 
@@ -207,29 +218,18 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-        need_segment = seq_len >= segment_len
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                x = F.pad(x, (0, 0, 0, padding))
-
-            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x = torch.cat((mems, x), dim = -2)
 
-        if need_segment:
-            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
-            x = x[:, :seq_len]
+        x = inverse_segment(x)
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:seq_len]
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
@@ -241,6 +241,14 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # excise out the memories
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+        x = x[:, self.num_longterm_mem_tokens:]
+
+        x = inverse_segment(x)
+
         # to logits
 
         x = self.norm(x)
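
The mac_transformer.py changes above replace the duplicated pad / segment / un-segment logic in SegmentedAttention.forward and MemoryAsContextTransformer.forward with the new pad_and_segment_with_inverse helper, which returns the segmented sequence together with a closure that undoes the padding and segmentation. As a rough sanity check (not shipped in the wheel, just a sketch that copies the helper verbatim from the diff above and exercises it on a dummy tensor), the round trip behaves like this:

import torch
import torch.nn.functional as F
from math import ceil
from einops import rearrange

def identity(t):
    return t

def round_up_multiple(seq, mult):
    return ceil(seq / mult) * mult

# copied from the 0.0.31 diff above: pad to a multiple of segment_len,
# fold the windows into the batch dimension, and return a closure that undoes both
def pad_and_segment_with_inverse(seq, segment_len):
    batch, seq_len = seq.shape[:2]

    need_segment = seq_len >= segment_len

    if not need_segment:
        return seq, identity

    next_seq_len_mult = round_up_multiple(seq_len, segment_len)

    padding = next_seq_len_mult - seq_len
    seq = F.pad(seq, (0, 0, 0, padding))

    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
        return out[:, :-padding]

    return seq, inverse

# seq_len = 10 is deliberately not a multiple of segment_len = 4; when the length
# already divides evenly, padding is 0 and out[:, :-padding] would slice to empty
x = torch.randn(2, 10, 8)

segments, inverse_segment = pad_and_segment_with_inverse(x, 4)
print(segments.shape)                # torch.Size([6, 4, 8]) -> (batch * windows, segment_len, dim)

restored = inverse_segment(segments)
print(restored.shape)                # torch.Size([2, 10, 8])
print(torch.allclose(restored, x))   # True

The diff uses the helper at three sites: SegmentedAttention.forward segments with total_segment_len, while MemoryAsContextTransformer.forward segments with segment_len before prepending the long-term memory tokens and with total_segment_len again when excising them ahead of the logits.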
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.30
+Version: 0.0.31
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=XDLc9NWXpVMza03XjU0lkw5lRvtJ25ReTPKNoGslOOk,6773
+titans_pytorch/mac_transformer.py,sha256=bq5RbCgA0GWLFHTrDTIKUSQhkkuCkdjEykOwjfHDs0M,6747
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.30.dist-info/METADATA,sha256=7H6WPsgfBE9ByEUP7r6C-cfBX6K13yNCaVhZMzwUvf8,3938
-titans_pytorch-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.30.dist-info/RECORD,,
+titans_pytorch-0.0.31.dist-info/METADATA,sha256=bN1fVL2S_vML1oqLIA92tvBhkVvnpQN11fU4e1QVI4s,3938
+titans_pytorch-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.31.dist-info/RECORD,,