hippoformer 0.0.6__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.6 → hippoformer-0.0.8}/PKG-INFO +2 -3
- {hippoformer-0.0.6 → hippoformer-0.0.8}/README.md +0 -2
- {hippoformer-0.0.6 → hippoformer-0.0.8}/hippoformer/hippoformer.py +115 -12
- {hippoformer-0.0.6 → hippoformer-0.0.8}/pyproject.toml +2 -1
- hippoformer-0.0.8/tests/test_hippoformer.py +64 -0
- hippoformer-0.0.6/tests/test_hippoformer.py +0 -37
- {hippoformer-0.0.6 → hippoformer-0.0.8}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.6 → hippoformer-0.0.8}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.6 → hippoformer-0.0.8}/.gitignore +0 -0
- {hippoformer-0.0.6 → hippoformer-0.0.8}/LICENSE +0 -0
- {hippoformer-0.0.6 → hippoformer-0.0.8}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.6 → hippoformer-0.0.8}/hippoformer-fig6.png +0 -0
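
At a glance: 0.0.8 adds a beartype dependency, drops the temporary Discord link from the README, merges the separate `sensory_encoder` / `sensory_decoder` keyword arguments into a single `sensory_encoder_decoder` tuple, ships prebuilt encoder/decoder pairs for 2d grid and 3d maze observations, and lets `mmTEM.forward` return and re-accept the memory MLP's parameters and momenta so memory state can persist across calls. The test file is rewritten to cover all three sensory configurations.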
{hippoformer-0.0.6 → hippoformer-0.0.8}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.6
+Version: 0.0.8
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: torch>=2.4
@@ -50,8 +51,6 @@ Description-Content-Type: text/markdown
 
 Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
-
 ## Citations
 
 ```bibtex
```
{hippoformer-0.0.6 → hippoformer-0.0.8}/hippoformer/hippoformer.py

```diff
@@ -7,6 +7,8 @@ from torch.nn import Module
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call
 
+from beartype import beartype
+
 from einx import multiply
 from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
@@ -36,6 +38,80 @@ def pack_with_inverse(t, pattern):
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
+# sensory encoder decoder for 2d
+
+grid_sensory_enc_dec = (
+    create_mlp(
+        dim = 32 * 2,
+        dim_in = 9,
+        dim_out = 32,
+        depth = 3,
+    ),
+    create_mlp(
+        dim = 32 * 2,
+        dim_in = 32,
+        dim_out = 9,
+        depth = 3,
+    ),
+)
+
+# sensory encoder decoder for 3d maze
+
+class EncoderPackTime(Module):
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x = rearrange(x, 'b c t h w -> b t c h w')
+        x, packed_shape = pack([x], '* c h w')
+
+        x = self.fn(x)
+
+        x, = unpack(x, packed_shape, '* d')
+        print(x.shape)
+        return x
+
+class DecoderPackTime(Module):
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x, packed_shape = pack(x, '* d')
+
+        x = self.fn(x)
+
+        x = unpack(x, packed_shape, '* c h w')
+        x = rearrange(x, 'b t c h w -> b c t h w')
+        return x
+
+maze_sensory_enc_dec = (
+    EncoderPackTime(nn.Sequential(
+        nn.Conv2d(3, 16, 7, 2, padding = 3),
+        nn.ReLU(),
+        nn.Conv2d(16, 32, 3, 2, 1),
+        nn.ReLU(),
+        nn.Conv2d(32, 64, 3, 2, 1),
+        nn.ReLU(),
+        nn.Conv2d(64, 128, 3, 2, 1),
+        nn.ReLU(),
+        Rearrange('b ... -> b (...)'),
+        nn.Linear(2048, 32)
+    )),
+    DecoderPackTime(nn.Sequential(
+        nn.Linear(32, 2048),
+        Rearrange('b (c h w) -> b c h w', c = 128, h = 4),
+        nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(16, 3, 3, 2, 1, output_padding = (1, 1))
+    ))
+)
+
 # path integration
 
 class RNN(ScriptModule):
```
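
Both new pairs are plain `(encoder, decoder)` tuples at module level. A quick shape walk-through of the maze pair (a sketch; the input shape comes from the new test, and note the encoder still carries a leftover `print(x.shape)` debug line):

```python
import torch
from hippoformer.hippoformer import maze_sensory_enc_dec

encoder, decoder = maze_sensory_enc_dec

video = torch.randn(2, 3, 16, 64, 64)  # (batch, channels, time, height, width), as in the test

encoded = encoder(video)   # time is folded into batch around the conv stack -> (2, 16, 32)
decoded = decoder(encoded) # the conv-transpose stack mirrors it back -> (2, 3, 16, 64, 64)
```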
```diff
@@ -114,12 +190,12 @@ class PathIntegration(Module):
 # proposed mmTEM
 
 class mmTEM(Module):
+    @beartype
     def __init__(
         self,
         dim,
         *,
-        sensory_encoder: Module,
-        sensory_decoder: Module,
+        sensory_encoder_decoder: tuple[Module, Module],
         dim_sensory,
         dim_action,
         dim_encoded_sensory,
@@ -139,6 +215,8 @@ class mmTEM(Module):
 
         # sensory
 
+        sensory_encoder, sensory_decoder = sensory_encoder_decoder
+
         self.sensory_encoder = sensory_encoder
         self.sensory_decoder = sensory_decoder
```
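
Downstream, the constructor now takes one tuple instead of two keyword arguments, with `@beartype` enforcing the `tuple[Module, Module]` annotation at call time. A minimal sketch matching the 'naive' case in the new test:

```python
from torch.nn import Linear
from hippoformer.hippoformer import mmTEM

model = mmTEM(
    dim = 32,
    sensory_encoder_decoder = (Linear(11, 32), Linear(32, 11)),  # was sensory_encoder = ..., sensory_decoder = ...
    dim_sensory = 11,
    dim_action = 7,
    dim_structure = 32,
    dim_encoded_sensory = 32
)
```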
```diff
@@ -179,7 +257,7 @@ class mmTEM(Module):
 
         grad_fn = grad(forward_with_mse_loss)
 
-        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))
 
         # mlp decoder (from meta mlp output to joint)
```
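
The outer `in_dims` changes because the memory MLP parameters now carry a leading batch dimension, so the outer `vmap` must map over them rather than broadcast a single shared copy. A self-contained illustration of the pattern (the toy `mlp` and shapes are mine, not from the package):

```python
import torch
from torch.func import vmap, grad, functional_call
from einops import repeat

mlp = torch.nn.Linear(4, 4)

def mse_loss(params, key, value):
    pred = functional_call(mlp, params, key)
    return torch.nn.functional.mse_loss(pred, value)

grad_fn = grad(mse_loss)

# inner vmap maps over time, outer over batch; with per-batch params the
# outer dims become (0, 0, 0) instead of (None, 0, 0)
per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))

params = {name: repeat(p, '... -> b ...', b = 2) for name, p in mlp.named_parameters()}
keys = torch.randn(2, 16, 4)
values = torch.randn(2, 16, 4)

grads = per_sample_grad_fn(params, keys, values)  # each leaf: (2, 16, *param_shape)
```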
```diff
@@ -213,6 +291,19 @@ class mmTEM(Module):
 
         self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
 
+    def init_params_and_momentum(
+        self,
+        batch_size
+    ):
+
+        params_dict = dict(self.meta_memory_mlp.named_parameters())
+
+        params = {name: repeat(param, '... -> b ...', b = batch_size) for name, param in params_dict.items()}
+
+        momentums = {name: zeros_like(param) for name, param in params.items()}
+
+        return params, momentums
+
     def retrieve(
         self,
         structural_codes,
```
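
The helper builds one optimizer state per batch element: the meta memory MLP's parameters repeated along a new leading batch axis, plus matching zero momenta. Usage is simply:

```python
# batch_size matches actions.shape[0] in forward
params, momentums = model.init_params_and_momentum(batch_size = 2)

# every parameter leaf gains a leading batch dim, e.g. (d_out, d_in) -> (2, d_out, d_in);
# momentums are zeros of the same shapes
```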
```diff
@@ -230,7 +321,9 @@ class mmTEM(Module):
         self,
         sensory,
         actions,
-        return_losses = False
+        memory_mlp_params = None,
+        return_losses = False,
+        return_memory_mlp_params = False
     ):
         batch = actions.shape[0]
 
@@ -291,22 +384,28 @@ class mmTEM(Module):
 
         lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
 
-        params = dict(self.meta_memory_mlp.named_parameters())
+        if exists(memory_mlp_params):
+            params, momentums = memory_mlp_params
+        else:
+            params, momentums = self.init_params_and_momentum(batch)
+
+        # store by getting gradients of mse loss of keys and values
+
         grads = self.per_sample_grad_fn(params, keys, values)
 
-        # update the meta mlp parameters
+        # update the meta mlp parameters and momentums
 
-        init_momentums = {k: zeros_like(v) for k, v in params.items()}
         next_params = dict()
+        next_momentum = dict()
 
         for (
             (key, param),
             (_, grad),
-            (_, init_momentum)
+            (_, momentum)
         ) in zip(
             params.items(),
             grads.items(),
-            init_momentums.items()
+            momentums.items()
         ):
 
             grad, inverse_pack = pack_with_inverse(grad, 'b t *')
@@ -315,9 +414,7 @@ class mmTEM(Module):
 
             expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
 
-
-
-            update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
+            update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
 
             expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
 
```
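
The scan now seeds from the carried momentum instead of a fresh zeros tensor, so momentum survives across forward calls. Assuming the assoc-scan recurrence is the usual gated one, out_t = gate_t * out_{t-1} + in_t (my reading of the `assoc-scan` dependency, not verified against its source), a plain-loop reference of the momentum update:

```python
import torch

def momentum_scan_reference(grads, beta_gates, momentum):
    # grads, beta_gates: (batch, time, width); momentum: (batch, width)
    outputs = []
    for t in range(grads.shape[1]):
        momentum = beta_gates[:, t] * momentum + grads[:, t]
        outputs.append(momentum)
    # the final step is what forward keeps as the next momentum
    # (next_momentum[key] = update[:, -1] in the following hunk)
    return torch.stack(outputs, dim = 1)
```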
```diff
@@ -325,7 +422,10 @@ class mmTEM(Module):
 
             acc_update = inverse_pack(acc_update)
 
+            # set the next params and momentum, which can be passed back in
+
             next_params[key] = param - acc_update[:, -1]
+            next_momentum[key] = update[:, -1]
 
         # losses
 
@@ -343,6 +443,9 @@ class mmTEM(Module):
             inference_pred_loss
         )
 
+        if return_memory_mlp_params:
+            return next_params, next_momentum
+
         if not return_losses:
             return total_loss
 
```
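
Taken together, callers can thread the memory state through successive forward calls, exactly as the new test does (the test calls the carried tuple `next_params`; renamed `state` here since it holds both params and momenta):

```python
# run two segments, carrying (params, momentums) forward, then take a loss
state = model(sensory, actions, return_memory_mlp_params = True)
state = model(sensory, actions, memory_mlp_params = state, return_memory_mlp_params = True)

loss = model(sensory, actions, memory_mlp_params = state)
loss.backward()
```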
{hippoformer-0.0.6 → hippoformer-0.0.8}/pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "hippoformer"
-version = "0.0.6"
+version = "0.0.8"
 description = "hippoformer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,6 +25,7 @@ classifiers=[
 
 dependencies = [
     "assoc-scan",
+    "beartype",
     "einx>=0.3.0",
     "einops>=0.8.1",
     "torch>=2.4",
```
hippoformer-0.0.8/tests/test_hippoformer.py

```diff
@@ -0,0 +1,64 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+
+def test_path_integrate():
+    from hippoformer.hippoformer import PathIntegration
+
+    path_integrator = PathIntegration(32, 64)
+
+    actions = torch.randn(2, 16, 32)
+
+    structure_codes = path_integrator(actions)
+    structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
+
+    assert structure_codes.shape == (2, 16, 64)
+
+@param('sensory_type', ('naive', '2d', '3d'))
+def test_mm_tem(
+    sensory_type
+):
+    import torch
+    from hippoformer.hippoformer import mmTEM
+
+    from torch.nn import Linear
+
+    if sensory_type == 'naive':
+        enc_dec = (
+            Linear(11, 32),
+            Linear(32, 11)
+        )
+        sensory = torch.randn(2, 16, 11)
+
+    elif sensory_type == '2d':
+
+        from hippoformer.hippoformer import grid_sensory_enc_dec
+
+        enc_dec = grid_sensory_enc_dec
+        sensory = torch.randn(2, 16, 9)
+
+    elif sensory_type == '3d':
+
+        from hippoformer.hippoformer import maze_sensory_enc_dec
+
+        enc_dec = maze_sensory_enc_dec
+
+        sensory = torch.randn(2, 3, 16, 64, 64)
+
+    model = mmTEM(
+        dim = 32,
+        sensory_encoder_decoder = enc_dec,
+        dim_sensory = 11,
+        dim_action = 7,
+        dim_structure = 32,
+        dim_encoded_sensory = 32
+    )
+
+    actions = torch.randn(2, 16, 7)
+
+    next_params = model(sensory, actions, return_memory_mlp_params = True)
+    next_params = model(sensory, actions, memory_mlp_params = next_params, return_memory_mlp_params = True)
+
+    loss = model(sensory, actions, memory_mlp_params = next_params)
+    loss.backward()
```
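
The rewritten test parametrizes over the three sensory configurations ('naive', '2d', '3d'), so a single test now exercises the tuple-based constructor, both bundled encoder/decoder pairs, and the memory carry-over path.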
hippoformer-0.0.6/tests/test_hippoformer.py

```diff
@@ -1,37 +0,0 @@
-import pytest
-
-import torch
-
-def test_path_integrate():
-    from hippoformer.hippoformer import PathIntegration
-
-    path_integrator = PathIntegration(32, 64)
-
-    actions = torch.randn(2, 16, 32)
-
-    structure_codes = path_integrator(actions)
-    structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
-
-    assert structure_codes.shape == (2, 16, 64)
-
-def test_mm_tem():
-    import torch
-    from hippoformer.hippoformer import mmTEM
-
-    from torch.nn import Linear
-
-    model = mmTEM(
-        dim = 32,
-        sensory_encoder = Linear(11, 32),
-        sensory_decoder = Linear(32, 11),
-        dim_sensory = 11,
-        dim_action = 7,
-        dim_structure = 32,
-        dim_encoded_sensory = 32
-    )
-
-    actions = torch.randn(2, 16, 7)
-    sensory = torch.randn(2, 16, 11)
-
-    loss = model(sensory, actions)
-    loss.backward()
```