hippoformer 0.0.7__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.7 → hippoformer-0.0.8}/PKG-INFO +2 -3
- {hippoformer-0.0.7 → hippoformer-0.0.8}/README.md +0 -2
- {hippoformer-0.0.7 → hippoformer-0.0.8}/hippoformer/hippoformer.py +80 -2
- {hippoformer-0.0.7 → hippoformer-0.0.8}/pyproject.toml +2 -1
- {hippoformer-0.0.7 → hippoformer-0.0.8}/tests/test_hippoformer.py +28 -7
- {hippoformer-0.0.7 → hippoformer-0.0.8}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.7 → hippoformer-0.0.8}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.7 → hippoformer-0.0.8}/.gitignore +0 -0
- {hippoformer-0.0.7 → hippoformer-0.0.8}/LICENSE +0 -0
- {hippoformer-0.0.7 → hippoformer-0.0.8}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.7 → hippoformer-0.0.8}/hippoformer-fig6.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hippoformer
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: hippoformer
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/hippoformer/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/hippoformer
|
|
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: assoc-scan
|
|
38
|
+
Requires-Dist: beartype
|
|
38
39
|
Requires-Dist: einops>=0.8.1
|
|
39
40
|
Requires-Dist: einx>=0.3.0
|
|
40
41
|
Requires-Dist: torch>=2.4
|
|
@@ -50,8 +51,6 @@ Description-Content-Type: text/markdown
|
|
|
50
51
|
|
|
51
52
|
Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
|
|
52
53
|
|
|
53
|
-
[Temporary Discord](https://discord.gg/MkACrrkrYR)
|
|
54
|
-
|
|
55
54
|
## Citations
|
|
56
55
|
|
|
57
56
|
```bibtex
|
|
@@ -7,6 +7,8 @@ from torch.nn import Module
|
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
8
8
|
from torch.func import vmap, grad, functional_call
|
|
9
9
|
|
|
10
|
+
from beartype import beartype
|
|
11
|
+
|
|
10
12
|
from einx import multiply
|
|
11
13
|
from einops import repeat, rearrange, pack, unpack
|
|
12
14
|
from einops.layers.torch import Rearrange
|
|
@@ -36,6 +38,80 @@ def pack_with_inverse(t, pattern):
|
|
|
36
38
|
def l2norm(t):
|
|
37
39
|
return F.normalize(t, dim = -1)
|
|
38
40
|
|
|
41
|
+
# sensory encoder decoder for 2d
|
|
42
|
+
|
|
43
|
+
grid_sensory_enc_dec = (
|
|
44
|
+
create_mlp(
|
|
45
|
+
dim = 32 * 2,
|
|
46
|
+
dim_in = 9,
|
|
47
|
+
dim_out = 32,
|
|
48
|
+
depth = 3,
|
|
49
|
+
),
|
|
50
|
+
create_mlp(
|
|
51
|
+
dim = 32 * 2,
|
|
52
|
+
dim_in = 32,
|
|
53
|
+
dim_out = 9,
|
|
54
|
+
depth = 3,
|
|
55
|
+
),
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# sensory encoder decoder for 3d maze
|
|
59
|
+
|
|
60
|
+
class EncoderPackTime(Module):
|
|
61
|
+
def __init__(self, fn: Module):
|
|
62
|
+
super().__init__()
|
|
63
|
+
self.fn = fn
|
|
64
|
+
|
|
65
|
+
def forward(self, x):
|
|
66
|
+
x = rearrange(x, 'b c t h w -> b t c h w')
|
|
67
|
+
x, packed_shape = pack([x], '* c h w')
|
|
68
|
+
|
|
69
|
+
x = self.fn(x)
|
|
70
|
+
|
|
71
|
+
x, = unpack(x, packed_shape, '* d')
|
|
72
|
+
print(x.shape)
|
|
73
|
+
return x
|
|
74
|
+
|
|
75
|
+
class DecoderPackTime(Module):
|
|
76
|
+
def __init__(self, fn: Module):
|
|
77
|
+
super().__init__()
|
|
78
|
+
self.fn = fn
|
|
79
|
+
|
|
80
|
+
def forward(self, x):
|
|
81
|
+
x, packed_shape = pack(x, '* d')
|
|
82
|
+
|
|
83
|
+
x = self.fn(x)
|
|
84
|
+
|
|
85
|
+
x = unpack(x, packed_shape, '* c h w')
|
|
86
|
+
x = rearrange(x, 'b t c h w -> b c t h w')
|
|
87
|
+
return x
|
|
88
|
+
|
|
89
|
+
maze_sensory_enc_dec = (
|
|
90
|
+
EncoderPackTime(nn.Sequential(
|
|
91
|
+
nn.Conv2d(3, 16, 7, 2, padding = 3),
|
|
92
|
+
nn.ReLU(),
|
|
93
|
+
nn.Conv2d(16, 32, 3, 2, 1),
|
|
94
|
+
nn.ReLU(),
|
|
95
|
+
nn.Conv2d(32, 64, 3, 2, 1),
|
|
96
|
+
nn.ReLU(),
|
|
97
|
+
nn.Conv2d(64, 128, 3, 2, 1),
|
|
98
|
+
nn.ReLU(),
|
|
99
|
+
Rearrange('b ... -> b (...)'),
|
|
100
|
+
nn.Linear(2048, 32)
|
|
101
|
+
)),
|
|
102
|
+
DecoderPackTime(nn.Sequential(
|
|
103
|
+
nn.Linear(32, 2048),
|
|
104
|
+
Rearrange('b (c h w) -> b c h w', c = 128, h = 4),
|
|
105
|
+
nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = (1, 1)),
|
|
106
|
+
nn.ReLU(),
|
|
107
|
+
nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding = (1, 1)),
|
|
108
|
+
nn.ReLU(),
|
|
109
|
+
nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding = (1, 1)),
|
|
110
|
+
nn.ReLU(),
|
|
111
|
+
nn.ConvTranspose2d(16, 3, 3, 2, 1, output_padding = (1, 1))
|
|
112
|
+
))
|
|
113
|
+
)
|
|
114
|
+
|
|
39
115
|
# path integration
|
|
40
116
|
|
|
41
117
|
class RNN(ScriptModule):
|
|
@@ -114,12 +190,12 @@ class PathIntegration(Module):
|
|
|
114
190
|
# proposed mmTEM
|
|
115
191
|
|
|
116
192
|
class mmTEM(Module):
|
|
193
|
+
@beartype
|
|
117
194
|
def __init__(
|
|
118
195
|
self,
|
|
119
196
|
dim,
|
|
120
197
|
*,
|
|
121
|
-
|
|
122
|
-
sensory_decoder: Module,
|
|
198
|
+
sensory_encoder_decoder: tuple[Module, Module],
|
|
123
199
|
dim_sensory,
|
|
124
200
|
dim_action,
|
|
125
201
|
dim_encoded_sensory,
|
|
@@ -139,6 +215,8 @@ class mmTEM(Module):
|
|
|
139
215
|
|
|
140
216
|
# sensory
|
|
141
217
|
|
|
218
|
+
sensory_encoder, sensory_decoder = sensory_encoder_decoder
|
|
219
|
+
|
|
142
220
|
self.sensory_encoder = sensory_encoder
|
|
143
221
|
self.sensory_decoder = sensory_decoder
|
|
144
222
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "hippoformer"
|
|
3
|
-
version = "0.0.7"
|
|
3
|
+
version = "0.0.8"
|
|
4
4
|
description = "hippoformer"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -25,6 +25,7 @@ classifiers=[
|
|
|
25
25
|
|
|
26
26
|
dependencies = [
|
|
27
27
|
"assoc-scan",
|
|
28
|
+
"beartype",
|
|
28
29
|
"einx>=0.3.0",
|
|
29
30
|
"einops>=0.8.1",
|
|
30
31
|
"torch>=2.4",
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import pytest
|
|
2
|
+
param = pytest.mark.parametrize
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
|
|
@@ -14,16 +15,40 @@ def test_path_integrate():
|
|
|
14
15
|
|
|
15
16
|
assert structure_codes.shape == (2, 16, 64)
|
|
16
17
|
|
|
17
|
-
|
|
18
|
+
@param('sensory_type', ('naive', '2d', '3d'))
|
|
19
|
+
def test_mm_tem(
|
|
20
|
+
sensory_type
|
|
21
|
+
):
|
|
18
22
|
import torch
|
|
19
23
|
from hippoformer.hippoformer import mmTEM
|
|
20
24
|
|
|
21
25
|
from torch.nn import Linear
|
|
22
26
|
|
|
27
|
+
if sensory_type == 'naive':
|
|
28
|
+
enc_dec = (
|
|
29
|
+
Linear(11, 32),
|
|
30
|
+
Linear(32, 11)
|
|
31
|
+
)
|
|
32
|
+
sensory = torch.randn(2, 16, 11)
|
|
33
|
+
|
|
34
|
+
elif sensory_type == '2d':
|
|
35
|
+
|
|
36
|
+
from hippoformer.hippoformer import grid_sensory_enc_dec
|
|
37
|
+
|
|
38
|
+
enc_dec = grid_sensory_enc_dec
|
|
39
|
+
sensory = torch.randn(2, 16, 9)
|
|
40
|
+
|
|
41
|
+
elif sensory_type == '3d':
|
|
42
|
+
|
|
43
|
+
from hippoformer.hippoformer import maze_sensory_enc_dec
|
|
44
|
+
|
|
45
|
+
enc_dec = maze_sensory_enc_dec
|
|
46
|
+
|
|
47
|
+
sensory = torch.randn(2, 3, 16, 64, 64)
|
|
48
|
+
|
|
23
49
|
model = mmTEM(
|
|
24
50
|
dim = 32,
|
|
25
|
-
|
|
26
|
-
sensory_decoder = Linear(32, 11),
|
|
51
|
+
sensory_encoder_decoder = enc_dec,
|
|
27
52
|
dim_sensory = 11,
|
|
28
53
|
dim_action = 7,
|
|
29
54
|
dim_structure = 32,
|
|
@@ -31,10 +56,6 @@ def test_mm_tem():
|
|
|
31
56
|
)
|
|
32
57
|
|
|
33
58
|
actions = torch.randn(2, 16, 7)
|
|
34
|
-
sensory = torch.randn(2, 16, 11)
|
|
35
|
-
|
|
36
|
-
actions = torch.randn(2, 16, 7)
|
|
37
|
-
sensory = torch.randn(2, 16, 11)
|
|
38
59
|
|
|
39
60
|
next_params = model(sensory, actions, return_memory_mlp_params = True)
|
|
40
61
|
next_params = model(sensory, actions, memory_mlp_params = next_params, return_memory_mlp_params = True)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|