hippoformer 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +17 -7
- {hippoformer-0.0.11.dist-info → hippoformer-0.0.12.dist-info}/METADATA +1 -1
- hippoformer-0.0.12.dist-info/RECORD +6 -0
- hippoformer-0.0.11.dist-info/RECORD +0 -6
- {hippoformer-0.0.11.dist-info → hippoformer-0.0.12.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.11.dist-info → hippoformer-0.0.12.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py CHANGED

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import torch
-from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
+from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.jit import ScriptModule, script_method
@@ -221,8 +221,11 @@ class PathIntegration(Module):
 
         return self.rnn(transitions, prev_structural)
 
-# custom transformer
-
+# custom transformer proposed by James Whittington that bridges to hippocampal models with a few twists
+
+# the mmTEM can be seen as a linear attention / TTT variant of what he proposed
+# needed for the baseline as well as the parallel block to bolster local time prediction
+
 # https://arxiv.org/abs/2112.04035
 
 def FeedForward(dim, mult = 4.):
@@ -238,8 +241,9 @@ class Attention(Module):
         self,
         dim_q,
         dim_kv,
+        window_size,
         dim_head = 64,
-        heads = 8
+        heads = 8,
     ):
         super().__init__()
         dim_inner = dim_head * heads
@@ -251,6 +255,8 @@ class Attention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
+        self.window_size = window_size
+
         self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
         self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and for attention sink
 
@@ -280,9 +286,13 @@
         # the diagonal is masked out
 
         i, j = sim.shape[-2:]
-        causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)
 
-
+        j_seq = arange(j, device = device)[:, None]
+        i_seq = arange(i, device = device)[None, :] + (j - i)
+
+        windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
+
+        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
 
         # attention sink, for token as well as for attention sinking - from gpt-oss
 
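For comparison, the removed line built its mask with `torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)`, which marks every position from the offset diagonal onward, while the new construction, as read from the hunk, only marks positions within `self.window_size` steps past that diagonal. The sketch below is not part of the package; it replays the new mask construction on a small square score matrix with hypothetical sizes so the marked entries can be printed and checked.

```python
# Standalone sketch of the windowed mask introduced in this hunk.
# seq_len and window_size are hypothetical values chosen purely for illustration.
import torch
from torch import arange

seq_len = 6
window_size = 2
i = j = seq_len                        # square case: query and key lengths match

sim = torch.zeros(i, j)                # stand-in for the attention scores

j_seq = arange(j)[:, None]             # row indices, shape (j, 1)
i_seq = arange(i)[None, :] + (j - i)   # column indices shifted by the length difference, shape (1, i)

# True where the column index lies 1..window_size steps past the row index
windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= window_size)

sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

print(windowed_causal_mask_without_diagonal.int())
# tensor([[0, 1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)
```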
@@ -314,7 +324,7 @@ class TEMTransformerBlock(Module):
     ):
         super().__init__()
 
-        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
+        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, window_size, dim_head = dim_head, heads = heads)
         self.ff = FeedForward(dim_structure, ff_expansion_factor)
 
         self.window_size = window_size
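Because `window_size` is now a positional parameter of `Attention` with no default, callers other than `TEMTransformerBlock` must pass it as well. A minimal construction sketch, assuming `Attention` is imported from `hippoformer.hippoformer` and using placeholder dimensions rather than values taken from the package:

```python
# Hypothetical construction under the 0.0.12 signature; the dimension values
# below are placeholders for illustration, not defaults from the package.
from hippoformer.hippoformer import Attention

attn = Attention(
    dim_q = 512,          # placeholder structural dimension
    dim_kv = 512 + 256,   # placeholder, mirroring dim_structure + dim_encoded_sensory
    window_size = 8,      # new required argument in 0.0.12
    dim_head = 64,
    heads = 8,
)
```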
hippoformer-0.0.12.dist-info/RECORD ADDED

@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=DjKAUfMpcoaAOqyuWauKp8n8e2YTzGVSOctNXxagkiA,18166
+hippoformer-0.0.12.dist-info/METADATA,sha256=bH3GaJniFX2zCgxFFg8v0amGEPdWlBbG401Ml3_hDCs,2773
+hippoformer-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.12.dist-info/RECORD,,
hippoformer-0.0.11.dist-info/RECORD DELETED

@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=PLMfdype8AMwlVWrtItDBkE3gU_BCUaL42NMjB4vhAY,17795
-hippoformer-0.0.11.dist-info/METADATA,sha256=6NlqhZSEApQkUKsncBxmDIE03x_xZktHH-JCeYlYfcg,2773
-hippoformer-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.11.dist-info/RECORD,,
{hippoformer-0.0.11.dist-info → hippoformer-0.0.12.dist-info}/WHEEL RENAMED
File without changes

{hippoformer-0.0.11.dist-info → hippoformer-0.0.12.dist-info}/licenses/LICENSE RENAMED
File without changes