titans-pytorch 0.0.62__py3-none-any.whl → 0.0.64__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -153,6 +153,7 @@ class SegmentedAttention(Module):
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
         total_segment_len = segment_len + num_longterm_mem_tokens
+        self.total_segment_len = total_segment_len
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -209,7 +210,7 @@ class SegmentedAttention(Module):
         # prep flex attention
 
         if not exists(flex_attn_fn):
-            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens)
 
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
@@ -241,7 +242,6 @@ class SegmentedAttention(Module):
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
-        # todo - get rid of logic with flex attention
 
         seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
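The changes above store the combined segment length (segment plus interleaved longterm memory tokens) on the module and use it, rather than the bare segment_len, when building the flex-attention block mask. A minimal sketch of the intra-segment blocking this affects, assuming illustrative names and sizes (same_segment and the lengths below are not the package's API):

import torch

def same_segment(q_idx, kv_idx, total_segment_len):
    # True where query and key positions fall inside the same padded segment
    return (q_idx // total_segment_len) == (kv_idx // total_segment_len)

segment_len, num_longterm_mem_tokens = 128, 4
total_segment_len = segment_len + num_longterm_mem_tokens   # 132, what self.total_segment_len now records

q = torch.arange(512)[:, None]
k = torch.arange(512)[None, :]
mask = same_segment(q, k, total_segment_len)                 # (512, 512) boolean block-diagonal mask

With the memory tokens interleaved per segment, blocks built from segment_len alone would be too small; sizing them by total_segment_len keeps the mask aligned with the padded sequence.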
titans_pytorch/titans.py CHANGED
@@ -168,13 +168,14 @@ class MemoryAttention(Module):
     ):
         super().__init__()
         self.scale = scale
+        dim_ff_hidden = int(dim * expansion_factor)
 
         self.weights = nn.ParameterList([
            nn.Parameter(torch.randn(dim, dim)), # queries
            nn.Parameter(torch.randn(dim, dim)), # keys
            nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim * expansion_factor)), # ff w1
-            nn.Parameter(torch.randn(dim * expansion_factor, dim)), # ff w2
+            nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+            nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
         for weight in self.weights:
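Computing dim_ff_hidden = int(dim * expansion_factor) once gives both feedforward weights a shared integer hidden width, which also appears to allow fractional expansion factors that torch.randn would otherwise reject as sizes. A small illustration with assumed values (not taken from the package):

import torch

dim, expansion_factor = 384, 1.5
dim_ff_hidden = int(dim * expansion_factor)   # 576

w1 = torch.randn(dim, dim_ff_hidden)          # ff w1: dim -> hidden
w2 = torch.randn(dim_ff_hidden, dim)          # ff w2: hidden -> dim

x = torch.randn(2, 10, dim)
out = (x @ w1) @ w2                           # shapes stay consistent: (2, 10, dim)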
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.62
+Version: 0.0.64
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=7voYtbD_ErCD_JjvwhAiunUWtSIsIxGJAaf2aRB3c2s,15349
+titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
+titans_pytorch-0.0.64.dist-info/METADATA,sha256=K63jobSfTdn-aFpEpZgolu4zSIvgUzF2rDuoCHGXkgE,4457
+titans_pytorch-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.64.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.64.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
-titans_pytorch/titans.py,sha256=95J6UL44lOrdZSXdm7p36xC9tDeSmRBwdjig9T82PzI,17452
-titans_pytorch-0.0.62.dist-info/METADATA,sha256=08Blaa9Ehyv09rSA9uWguxbhKpbrd7C53Ya13E1VbpU,4457
-titans_pytorch-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.62.dist-info/RECORD,,