hippoformer 0.0.11__tar.gz → 0.0.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.11 → hippoformer-0.0.14}/PKG-INFO +12 -1
- {hippoformer-0.0.11 → hippoformer-0.0.14}/README.md +11 -0
- {hippoformer-0.0.11 → hippoformer-0.0.14}/hippoformer/hippoformer.py +118 -26
- {hippoformer-0.0.11 → hippoformer-0.0.14}/pyproject.toml +1 -1
- {hippoformer-0.0.11 → hippoformer-0.0.14}/tests/test_hippoformer.py +1 -1
- {hippoformer-0.0.11 → hippoformer-0.0.14}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.11 → hippoformer-0.0.14}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.11 → hippoformer-0.0.14}/.gitignore +0 -0
- {hippoformer-0.0.11 → hippoformer-0.0.14}/LICENSE +0 -0
- {hippoformer-0.0.11 → hippoformer-0.0.14}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.11 → hippoformer-0.0.14}/hippoformer-fig6.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hippoformer
|
|
3
|
-
Version: 0.0.11
|
|
3
|
+
Version: 0.0.14
|
|
4
4
|
Summary: hippoformer
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hippoformer/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hippoformer
|
|
@@ -63,3 +63,14 @@ Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Int
|
|
|
63
63
|
note = {under review}
|
|
64
64
|
}
|
|
65
65
|
```
|
|
66
|
+
|
|
67
|
+
```bibtex
|
|
68
|
+
@article{Li2020GridCA,
|
|
69
|
+
title = {Grid Cells Are Ubiquitous in Neural Networks},
|
|
70
|
+
author = {Songlin Li and Yangdong Deng and Zhihua Wang},
|
|
71
|
+
journal = {ArXiv},
|
|
72
|
+
year = {2020},
|
|
73
|
+
volume = {abs/2003.03482},
|
|
74
|
+
url = {https://api.semanticscholar.org/CorpusID:212634300}
|
|
75
|
+
}
|
|
76
|
+
```
|
|
@@ -16,3 +16,14 @@ Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Int
|
|
|
16
16
|
note = {under review}
|
|
17
17
|
}
|
|
18
18
|
```
|
|
19
|
+
|
|
20
|
+
```bibtex
|
|
21
|
+
@article{Li2020GridCA,
|
|
22
|
+
title = {Grid Cells Are Ubiquitous in Neural Networks},
|
|
23
|
+
author = {Songlin Li and Yangdong Deng and Zhihua Wang},
|
|
24
|
+
journal = {ArXiv},
|
|
25
|
+
year = {2020},
|
|
26
|
+
volume = {abs/2003.03482},
|
|
27
|
+
url = {https://api.semanticscholar.org/CorpusID:212634300}
|
|
28
|
+
}
|
|
29
|
+
```
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
|
|
4
|
+
from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
|
|
5
5
|
import torch.nn.functional as F
|
|
6
|
-
from torch.nn import Module
|
|
6
|
+
from torch.nn import Module, ModuleList
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
8
8
|
from torch.func import vmap, grad, functional_call
|
|
9
9
|
|
|
@@ -221,8 +221,11 @@ class PathIntegration(Module):
|
|
|
221
221
|
|
|
222
222
|
return self.rnn(transitions, prev_structural)
|
|
223
223
|
|
|
224
|
-
# custom transformer
|
|
225
|
-
|
|
224
|
+
# custom transformer proposed by James Whittington that bridges to hippocampal models with a few twists
|
|
225
|
+
|
|
226
|
+
# the mmTEM can be seen as a linear attention / TTT variant of what he proposed
|
|
227
|
+
# needed for the baseline as well as the parallel block to bolster local time prediction
|
|
228
|
+
|
|
226
229
|
# https://arxiv.org/abs/2112.04035
|
|
227
230
|
|
|
228
231
|
def FeedForward(dim, mult = 4.):
|
|
@@ -238,19 +241,32 @@ class Attention(Module):
|
|
|
238
241
|
self,
|
|
239
242
|
dim_q,
|
|
240
243
|
dim_kv,
|
|
244
|
+
window_size,
|
|
241
245
|
dim_head = 64,
|
|
242
|
-
heads = 8
|
|
246
|
+
heads = 8,
|
|
247
|
+
implicit_mlp_expansion = 2 # for fair comparison, the attention should have an implicit mlp of 2 layers with a non-linearity, just like the meta-memory mlp in titans (linear attention)
|
|
243
248
|
):
|
|
244
249
|
super().__init__()
|
|
245
250
|
dim_inner = dim_head * heads
|
|
251
|
+
dim_mlp_inner = dim_head * heads * implicit_mlp_expansion
|
|
252
|
+
|
|
246
253
|
self.scale = dim_head ** -0.5
|
|
247
254
|
|
|
248
255
|
self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
|
|
249
|
-
|
|
256
|
+
|
|
257
|
+
self.to_w1_keys = nn.Linear(dim_kv, dim_inner, bias = False)
|
|
258
|
+
self.to_w1_values = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
|
|
259
|
+
|
|
260
|
+
self.implicit_mlp_activation = nn.SiLU()
|
|
261
|
+
|
|
262
|
+
self.to_w2_keys = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
|
|
263
|
+
self.to_w2_values = nn.Linear(dim_kv, dim_inner, bias = False)
|
|
250
264
|
|
|
251
265
|
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
|
252
266
|
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
|
253
267
|
|
|
268
|
+
self.window_size = window_size
|
|
269
|
+
|
|
254
270
|
self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
|
|
255
271
|
self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and for attention sink
|
|
256
272
|
|
|
@@ -264,43 +280,59 @@ class Attention(Module):
|
|
|
264
280
|
|
|
265
281
|
q = self.to_queries(queries_input)
|
|
266
282
|
|
|
267
|
-
|
|
283
|
+
k1, v1, k2, v2 = [fn(key_values_input) for fn in (self.to_w1_keys, self.to_w1_values, self.to_w2_keys, self.to_w2_values)]
|
|
268
284
|
|
|
269
|
-
q,
|
|
285
|
+
q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))
|
|
270
286
|
|
|
271
287
|
if exists(kv_cache):
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
288
|
+
ck1, cv1, vk2, cv2 = kv_cache
|
|
289
|
+
k1 = cat((ck1, k1), dim = -2)
|
|
290
|
+
v1 = cat((cv1, v1), dim = -2)
|
|
291
|
+
k2 = cat((ck2, k2), dim = -2)
|
|
292
|
+
v2 = cat((cv2, v2), dim = -2)
|
|
293
|
+
|
|
294
|
+
def attend(q, k, v):
|
|
295
|
+
q = q * self.scale
|
|
296
|
+
|
|
297
|
+
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
|
275
298
|
|
|
276
|
-
|
|
299
|
+
# the diagonal is masked out
|
|
277
300
|
|
|
278
|
-
|
|
301
|
+
i, j = sim.shape[-2:]
|
|
279
302
|
|
|
280
|
-
|
|
303
|
+
j_seq = arange(j, device = device)[:, None]
|
|
304
|
+
i_seq = arange(i, device = device)[None, :] + (j - i)
|
|
281
305
|
|
|
282
|
-
|
|
283
|
-
causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)
|
|
306
|
+
windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
|
|
284
307
|
|
|
285
|
-
|
|
308
|
+
sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
|
|
286
309
|
|
|
287
|
-
|
|
310
|
+
# attention sink, for token as well as for attention sinking - from gpt-oss
|
|
288
311
|
|
|
289
|
-
|
|
312
|
+
attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
|
|
290
313
|
|
|
291
|
-
|
|
314
|
+
sim = cat((attn_sink, sim), dim = -1)
|
|
292
315
|
|
|
293
|
-
|
|
316
|
+
attn = sim.softmax(dim = -1)
|
|
294
317
|
|
|
295
|
-
|
|
318
|
+
attn = attn[..., 1:] # remove sink
|
|
296
319
|
|
|
297
|
-
|
|
320
|
+
# aggregate
|
|
298
321
|
|
|
299
|
-
|
|
322
|
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
|
323
|
+
return out
|
|
324
|
+
|
|
325
|
+
# implicit memory mlp w1
|
|
326
|
+
|
|
327
|
+
hiddens = attend(q, k1, v1)
|
|
328
|
+
hiddens = self.implicit_mlp_activation(hiddens)
|
|
329
|
+
out = attend(hiddens, k2, v2)
|
|
330
|
+
|
|
331
|
+
# merge heads
|
|
300
332
|
|
|
301
333
|
out = self.merge_heads(out)
|
|
302
334
|
|
|
303
|
-
return self.to_out(out),
|
|
335
|
+
return self.to_out(out), (k1, v1, k2, v2)
|
|
304
336
|
|
|
305
337
|
class TEMTransformerBlock(Module):
|
|
306
338
|
def __init__(
|
|
@@ -314,7 +346,7 @@ class TEMTransformerBlock(Module):
|
|
|
314
346
|
):
|
|
315
347
|
super().__init__()
|
|
316
348
|
|
|
317
|
-
self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
|
|
349
|
+
self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, window_size, dim_head = dim_head, heads = heads)
|
|
318
350
|
self.ff = FeedForward(dim_structure, ff_expansion_factor)
|
|
319
351
|
|
|
320
352
|
self.window_size = window_size
|
|
@@ -337,6 +369,66 @@ class TEMTransformerBlock(Module):
|
|
|
337
369
|
|
|
338
370
|
return x, next_kv_cache
|
|
339
371
|
|
|
372
|
+
class TEMTransformer(Module):
|
|
373
|
+
def __init__(
|
|
374
|
+
self,
|
|
375
|
+
sensory_encoder_decoder: tuple[Module, Module],
|
|
376
|
+
dim_sensory,
|
|
377
|
+
dim_action,
|
|
378
|
+
dim_encoded_sensory,
|
|
379
|
+
dim_structure,
|
|
380
|
+
depth = 4,
|
|
381
|
+
transformer_kwargs: dict = dict(
|
|
382
|
+
dim_head = 64,
|
|
383
|
+
heads = 8,
|
|
384
|
+
ff_expansion_factor = 4,
|
|
385
|
+
window_size = 32
|
|
386
|
+
),
|
|
387
|
+
):
|
|
388
|
+
super().__init__()
|
|
389
|
+
|
|
390
|
+
self.sensory_encoder, self.sensory_decoder = sensory_encoder_decoder
|
|
391
|
+
|
|
392
|
+
self.path_integrator = nn.GRU(dim_action, dim_structure)
|
|
393
|
+
|
|
394
|
+
self.layers = ModuleList([])
|
|
395
|
+
|
|
396
|
+
for _ in range(depth):
|
|
397
|
+
|
|
398
|
+
block = TEMTransformerBlock(
|
|
399
|
+
dim_structure,
|
|
400
|
+
dim_encoded_sensory,
|
|
401
|
+
**transformer_kwargs
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
layers.append(block)
|
|
405
|
+
|
|
406
|
+
def forward(
|
|
407
|
+
self,
|
|
408
|
+
sensory,
|
|
409
|
+
actions,
|
|
410
|
+
prev_hiddens = None, # for the GRU based path integrator
|
|
411
|
+
prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
|
|
412
|
+
):
|
|
413
|
+
|
|
414
|
+
structure, next_hiddens = self.gru_path_integrator(actions, prev_hiddens)
|
|
415
|
+
|
|
416
|
+
encoded_sensory = self.sensory_encoder(sensory)
|
|
417
|
+
|
|
418
|
+
next_kv_cache = []
|
|
419
|
+
|
|
420
|
+
for layer in self.layers:
|
|
421
|
+
structure, layer_next_cache = layer(structure, encoded_sensory)
|
|
422
|
+
next_kv_cache.append(layer_next_cache)
|
|
423
|
+
|
|
424
|
+
decoded_sensory = self.sensory_decoder(structure)
|
|
425
|
+
|
|
426
|
+
next_memories = (next_hiddens, stack(next_kv_cache))
|
|
427
|
+
|
|
428
|
+
pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
|
|
429
|
+
|
|
430
|
+
return pred_loss
|
|
431
|
+
|
|
340
432
|
# proposed mmTEM
|
|
341
433
|
|
|
342
434
|
class mmTEM(Module):
|
|
@@ -69,7 +69,7 @@ def test_mm_tem(
|
|
|
69
69
|
def test_tem_t():
|
|
70
70
|
from hippoformer.hippoformer import TEMTransformerBlock
|
|
71
71
|
|
|
72
|
-
block = TEMTransformerBlock(32, 16)
|
|
72
|
+
block = TEMTransformerBlock(32, 16, window_size = 3)
|
|
73
73
|
|
|
74
74
|
structural_codes = torch.randn(1, 7, 32)
|
|
75
75
|
encoded_sensory = torch.randn(1, 7, 16)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|