hippoformer 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hippoformer/hippoformer.py

@@ -1,7 +1,7 @@
  from __future__ import annotations

  import torch
- from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
+ from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
  import torch.nn.functional as F
  from torch.nn import Module
  from torch.jit import ScriptModule, script_method
@@ -221,8 +221,11 @@ class PathIntegration(Module):

  return self.rnn(transitions, prev_structural)

- # custom transformer
- # with the connections proposed by James Whittington that bridges to hippocampal models
+ # custom transformer proposed by James Whittington that bridges to hippocampal models with a few twists
+
+ # the mmTEM can be seen as a linear attention / TTT variant of what he proposed
+ # needed for the baseline as well as the parallel block to bolster local time prediction
+
  # https://arxiv.org/abs/2112.04035

  def FeedForward(dim, mult = 4.):
@@ -238,8 +241,9 @@ class Attention(Module):
  self,
  dim_q,
  dim_kv,
+ window_size,
  dim_head = 64,
- heads = 8
+ heads = 8,
  ):
  super().__init__()
  dim_inner = dim_head * heads
@@ -251,6 +255,8 @@ class Attention(Module):
  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
  self.merge_heads = Rearrange('b h n d -> b n (h d)')

+ self.window_size = window_size
+
  self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
  self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and for attention sink

@@ -280,9 +286,13 @@
  # the diagonal is masked out

  i, j = sim.shape[-2:]
- causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)

- sim = sim.masked_fill(causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+ j_seq = arange(j, device = device)[:, None]
+ i_seq = arange(i, device = device)[None, :] + (j - i)
+
+ windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
+
+ sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

  # attention sink, for token as well as for attention sinking - from gpt-oss

@@ -314,7 +324,7 @@ class TEMTransformerBlock(Module):
  ):
  super().__init__()

- self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
+ self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, window_size, dim_head = dim_head, heads = heads)
  self.ff = FeedForward(dim_structure, ff_expansion_factor)

  self.window_size = window_size
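
Note on the attention change above: in 0.0.12 the Attention module takes a new required window_size argument (passed positionally by TEMTransformerBlock), and the full causal mask without diagonal is replaced by a windowed construction built with arange. As a rough illustration only, and not the package's exact construction, the sketch below builds a boolean mask that keeps strictly-past keys within window_size steps and masks out the diagonal, the future, and anything further back, which is the behavior the name windowed_causal_mask_without_diagonal and the surrounding comments point at. The helper name is hypothetical.

import torch

def windowed_causal_mask_without_diagonal(i, j, window_size, device = None):
    # hypothetical helper, not taken verbatim from the package:
    # rows index queries, columns index keys, with the query positions offset
    # by (j - i) so the last query lines up with the last key
    q_seq = torch.arange(i, device = device)[:, None] + (j - i)
    k_seq = torch.arange(j, device = device)[None, :]

    # True entries are the ones to fill with -inf: the diagonal, the future,
    # and any key more than window_size steps in the past
    return (k_seq >= q_seq) | ((q_seq - k_seq) > window_size)

mask = windowed_causal_mask_without_diagonal(6, 6, window_size = 2)
# sim = sim.masked_fill(mask, -torch.finfo(sim.dtype).max)  # as in Attention.forward

Callers that construct Attention directly (rather than through TEMTransformerBlock) now need to supply the window, e.g. Attention(dim_q, dim_kv, window_size, dim_head = 64, heads = 8), mirroring the TEMTransformerBlock change in the hunk above.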

hippoformer-0.0.12.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hippoformer
- Version: 0.0.11
+ Version: 0.0.12
  Summary: hippoformer
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
  Project-URL: Repository, https://github.com/lucidrains/hippoformer

hippoformer-0.0.12.dist-info/RECORD

@@ -0,0 +1,6 @@
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+ hippoformer/hippoformer.py,sha256=DjKAUfMpcoaAOqyuWauKp8n8e2YTzGVSOctNXxagkiA,18166
+ hippoformer-0.0.12.dist-info/METADATA,sha256=bH3GaJniFX2zCgxFFg8v0amGEPdWlBbG401Ml3_hDCs,2773
+ hippoformer-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ hippoformer-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ hippoformer-0.0.12.dist-info/RECORD,,

hippoformer-0.0.11.dist-info/RECORD

@@ -1,6 +0,0 @@
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
- hippoformer/hippoformer.py,sha256=PLMfdype8AMwlVWrtItDBkE3gU_BCUaL42NMjB4vhAY,17795
- hippoformer-0.0.11.dist-info/METADATA,sha256=6NlqhZSEApQkUKsncBxmDIE03x_xZktHH-JCeYlYfcg,2773
- hippoformer-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- hippoformer-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- hippoformer-0.0.11.dist-info/RECORD,,