hippoformer 0.0.1__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.1 → hippoformer-0.0.3}/PKG-INFO +1 -1
- hippoformer-0.0.3/hippoformer/__init__.py +4 -0
- hippoformer-0.0.3/hippoformer/hippoformer.py +219 -0
- {hippoformer-0.0.1 → hippoformer-0.0.3}/pyproject.toml +1 -1
- hippoformer-0.0.3/tests/test_hippoformer.py +37 -0
- hippoformer-0.0.1/hippoformer/__init__.py +0 -0
- hippoformer-0.0.1/hippoformer/hippoformer.py +0 -116
- hippoformer-0.0.1/tests/test_hippoformer.py +0 -15
- {hippoformer-0.0.1 → hippoformer-0.0.3}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.1 → hippoformer-0.0.3}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.1 → hippoformer-0.0.3}/.gitignore +0 -0
- {hippoformer-0.0.1 → hippoformer-0.0.3}/LICENSE +0 -0
- {hippoformer-0.0.1 → hippoformer-0.0.3}/README.md +0 -0
- {hippoformer-0.0.1 → hippoformer-0.0.3}/hippoformer-fig6.png +0 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, Tensor, stack, einsum, tensor
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torch.nn import Module
|
|
7
|
+
from torch.jit import ScriptModule, script_method
|
|
8
|
+
from torch.func import vmap, grad, functional_call
|
|
9
|
+
|
|
10
|
+
from einops import repeat, rearrange
|
|
11
|
+
from einops.layers.torch import Rearrange
|
|
12
|
+
|
|
13
|
+
from x_mlps_pytorch import create_mlp
|
|
14
|
+
|
|
15
|
+
from assoc_scan import AssocScan
|
|
16
|
+
|
|
17
|
+
# helpers
|
|
18
|
+
|
|
19
|
+
def exists(v):
    """Return True when *v* is a real value (i.e. not None)."""
    return not (v is None)
|
|
21
|
+
|
|
22
|
+
def default(v, d):
    """Return *v* when it is not None, otherwise fall back to *d*."""
    if v is None:
        return d
    return v
|
|
24
|
+
|
|
25
|
+
def l2norm(t):
    """L2-normalize *t* along its last dimension (unit norm per vector)."""
    return F.normalize(t, p = 2, dim = -1)
|
|
27
|
+
|
|
28
|
+
# path integration
|
|
29
|
+
|
|
30
|
+
class RNN(ScriptModule):
    """TorchScript-compiled recurrent cell used for path integration.

    At each step the hidden state is transformed by a per-step transition
    matrix, rectified, then re-projected onto the unit sphere with l2norm.
    Kept as a ScriptModule with @script_method so the step loop is compiled.
    """

    def __init__(
        self,
        dim,
    ):
        super().__init__()
        # learned initial hidden state, small init; l2-normalized before use
        self.init_hidden = nn.Parameter(torch.randn(1, dim) * 1e-2)

    @script_method
    def forward(
        self,
        transitions: Tensor,
        hidden: Tensor | None = None
    ) -> Tensor:
        # transitions: assumed (batch, seq, dim, dim) — one transition matrix
        # per time step; hidden: (batch, dim) seed state — TODO confirm shapes
        # against PathIntegration.forward

        batch, seq_len = transitions.shape[:2]

        if hidden is None:
            # fall back to the learned (normalized) initial state
            hidden = l2norm(self.init_hidden)
            hidden = hidden.expand(batch, -1)

        hiddens: list[Tensor] = []

        for i in range(seq_len):
            transition = transitions[:, i]

            # apply transition matrix, rectify, renormalize to unit norm
            hidden = einsum('b i, b i j -> b j', hidden, transition)
            hidden = F.relu(hidden)
            hidden = l2norm(hidden)

            hiddens.append(hidden)

        # (batch, seq, dim) — hidden state after every step
        return stack(hiddens, dim = 1)
|
|
63
|
+
|
|
64
|
+
class PathIntegration(Module):
    """Map a sequence of actions to structural codes via learned transitions.

    An MLP turns each action vector into a flattened
    (dim_structure x dim_structure) transition matrix, which the RNN then
    applies step by step to produce the structural code sequence.

    Args:
        dim_action: dimensionality of each action vector.
        dim_structure: dimensionality of the structural (grid-like) code.
        mlp_hidden_dim: hidden width of the transition MLP; defaults to
            ``dim_action * 4`` when None.
        mlp_depth: depth of the transition MLP.
    """

    def __init__(
        self,
        dim_action,
        dim_structure,
        mlp_hidden_dim = None,
        mlp_depth = 2
    ):
        # they use the same approach from Ruiqi Gao's paper from 2021
        super().__init__()

        # NOTE(review): registered but never read anywhere in this class —
        # confirm whether it is meant to seed the RNN hidden state
        self.init_structure = nn.Parameter(torch.randn(dim_structure))

        # action -> flattened transition matrix
        self.to_transitions = create_mlp(
            default(mlp_hidden_dim, dim_action * 4),
            dim_in = dim_action,
            dim_out = dim_structure * dim_structure,
            depth = mlp_depth
        )

        # unflatten MLP output into a square transition matrix
        self.mlp_out_to_weights = Rearrange('... (i j) -> ... i j', j = dim_structure)

        self.rnn = RNN(dim_structure)

    def forward(
        self,
        actions,               # (b n d)
        prev_structural = None # (b n d) | (b d)
    ):
        """Return structural codes (b n dim_structure) for *actions*.

        When *prev_structural* is a full sequence (b n d), the last step is
        used as the RNN seed so sequences can be continued.
        """
        # fix: dropped unused local `batch` (the RNN derives it itself)

        transitions = self.to_transitions(actions)
        transitions = self.mlp_out_to_weights(transitions)

        if exists(prev_structural) and prev_structural.ndim == 3:
            prev_structural = prev_structural[:, -1]

        return self.rnn(transitions, prev_structural)
|
|
102
|
+
|
|
103
|
+
# proposed mmTEM
|
|
104
|
+
|
|
105
|
+
class mmTEM(Module):
    """Proposed memory-augmented Tolman-Eichenbaum Machine (mmTEM).

    Path-integrates actions into structural codes, queries a meta-memory MLP
    with them, decodes a joint (structure, encoded-sensory) representation,
    and scores the decoded sensory against the ground truth.

    NOTE(review): forward currently only computes the generative loss —
    the inference / consistency / relational loss weights, the
    sensory_encoder, to_keys / to_values, and the structure-variance MLP are
    wired up but not yet used; presumably WIP toward the full objective.
    """

    def __init__(
        self,
        dim,
        *,
        sensory_encoder: Module,
        sensory_decoder: Module,
        dim_sensory,            # TODO(review): unused so far — confirm intended
        dim_action,
        dim_encoded_sensory,
        dim_structure,
        meta_mlp_depth = 2,
        decoder_mlp_depth = 2,
        structure_variance_pred_mlp_depth = 2,
        path_integrate_kwargs: dict = dict(),
        loss_weight_generative = 1.,
        loss_weight_inference = 1.,
        loss_weight_consistency = 1.,
        loss_weight_relational = 1.,
    ):
        super().__init__()

        # sensory

        self.sensory_encoder = sensory_encoder
        self.sensory_decoder = sensory_decoder

        dim_joint_rep = dim_encoded_sensory + dim_structure

        self.dim_encoded_sensory = dim_encoded_sensory
        self.dim_structure = dim_structure
        # split sizes for the joint representation: structure first, sensory second
        self.joint_dims = (dim_structure, dim_encoded_sensory)

        # path integrator

        self.path_integrator = PathIntegration(
            dim_action = dim_action,
            dim_structure = dim_structure,
            **path_integrate_kwargs
        )

        # meta mlp related

        self.to_queries = nn.Linear(dim_joint_rep, dim, bias = False)
        self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
        self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

        self.meta_memory_mlp = create_mlp(
            dim = dim * 2,
            depth = meta_mlp_depth,
            dim_in = dim,
            dim_out = dim,
            activation = nn.ReLU()
        )

        # mlp decoder (from meta mlp output to joint)

        self.memory_output_decoder = create_mlp(
            dim = dim * 2,
            dim_in = dim,
            dim_out = dim_joint_rep,
            depth = decoder_mlp_depth,
            activation = nn.ReLU()
        )

        # the mlp that predicts the variance for the structural code
        # for correcting the generated structural code modeling the feedback from HC to MEC
        # fix: was mis-assigned to `self.structure_variance_pred_mlp_depth`,
        # clobbering the depth hyperparameter's name with the module itself
        self.structure_variance_pred_mlp = create_mlp(
            dim = dim_structure * 2,
            dim_in = dim_structure * 2 + 1,
            dim_out = dim_structure,
            depth = structure_variance_pred_mlp_depth
        )

        # loss related

        self.loss_weight_generative = loss_weight_generative
        self.loss_weight_inference = loss_weight_inference
        self.loss_weight_relational = loss_weight_relational
        self.loss_weight_consistency = loss_weight_consistency
        self.register_buffer('zero', tensor(0.), persistent = False)

    def forward(
        self,
        sensory,
        actions
    ):
        """Return (total_loss, per-loss tuple) for a (sensory, actions) pair."""
        structural_codes = self.path_integrator(actions)

        # first have the structure code be able to fetch from the meta memory mlp
        # (pad zeros where the encoded sensory half of the joint rep would be)

        structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))

        queries = self.to_queries(structure_codes_with_zero_sensory)

        retrieved = self.meta_memory_mlp(queries)

        decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)

        decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)

        generative_pred_loss = F.mse_loss(sensory, decoded_sensory)

        # losses

        total_loss = (
            generative_pred_loss * self.loss_weight_generative
        )

        losses = (
            generative_pred_loss,
        )

        return total_loss, losses
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
def test_path_integrate():
    """PathIntegration maps (b, n, dim_action) actions to (b, n, dim_structure) codes."""
    from hippoformer.hippoformer import PathIntegration

    integrator = PathIntegration(32, 64)

    actions = torch.randn(2, 16, 32)

    codes = integrator(actions)
    # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
    codes = integrator(actions, codes)

    assert codes.shape == (2, 16, 64)
|
|
16
|
+
|
|
17
|
+
def test_mm_tem():
    """mmTEM forward returns a scalar loss that backpropagates cleanly."""
    # fix: removed redundant function-local `import torch` — the test module
    # already imports torch at the top of the file
    from hippoformer.hippoformer import mmTEM

    from torch.nn import Linear

    model = mmTEM(
        dim = 32,
        sensory_encoder = Linear(11, 32),
        sensory_decoder = Linear(32, 11),
        dim_sensory = 11,
        dim_action = 7,
        dim_structure = 32,
        dim_encoded_sensory = 32
    )

    actions = torch.randn(2, 16, 7)
    sensory = torch.randn(2, 16, 11)

    loss, losses = model(sensory, actions)
    loss.backward()
|
|
File without changes
|
|
@@ -1,116 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import nn, Tensor, stack, einsum
|
|
5
|
-
import torch.nn.functional as F
|
|
6
|
-
from torch.nn import Module
|
|
7
|
-
from torch.jit import ScriptModule, script_method
|
|
8
|
-
|
|
9
|
-
from einops import repeat, rearrange
|
|
10
|
-
from einops.layers.torch import Rearrange
|
|
11
|
-
|
|
12
|
-
from x_mlps_pytorch import create_mlp
|
|
13
|
-
|
|
14
|
-
from assoc_scan import AssocScan
|
|
15
|
-
|
|
16
|
-
# helpers
|
|
17
|
-
|
|
18
|
-
def exists(v):
|
|
19
|
-
return v is not None
|
|
20
|
-
|
|
21
|
-
def default(v, d):
|
|
22
|
-
return v if exists(v) else d
|
|
23
|
-
|
|
24
|
-
def l2norm(t):
|
|
25
|
-
return F.normalize(t, dim = -1)
|
|
26
|
-
|
|
27
|
-
# path integration
|
|
28
|
-
|
|
29
|
-
class RNN(ScriptModule):
|
|
30
|
-
def __init__(
|
|
31
|
-
self,
|
|
32
|
-
dim,
|
|
33
|
-
):
|
|
34
|
-
super().__init__()
|
|
35
|
-
self.init_hidden = nn.Parameter(torch.randn(1, dim) * 1e-2)
|
|
36
|
-
|
|
37
|
-
@script_method
|
|
38
|
-
def forward(
|
|
39
|
-
self,
|
|
40
|
-
transitions: Tensor,
|
|
41
|
-
hidden: Tensor | None = None
|
|
42
|
-
) -> Tensor:
|
|
43
|
-
|
|
44
|
-
batch, seq_len = transitions.shape[:2]
|
|
45
|
-
|
|
46
|
-
if hidden is None:
|
|
47
|
-
hidden = l2norm(self.init_hidden)
|
|
48
|
-
hidden = hidden.expand(batch, -1)
|
|
49
|
-
|
|
50
|
-
hiddens: list[Tensor] = []
|
|
51
|
-
|
|
52
|
-
for i in range(seq_len):
|
|
53
|
-
transition = transitions[:, i]
|
|
54
|
-
|
|
55
|
-
hidden = einsum('b i, b i j -> b j', hidden, transition)
|
|
56
|
-
hidden = F.relu(hidden)
|
|
57
|
-
hidden = l2norm(hidden)
|
|
58
|
-
|
|
59
|
-
hiddens.append(hidden)
|
|
60
|
-
|
|
61
|
-
return stack(hiddens, dim = 1)
|
|
62
|
-
|
|
63
|
-
class PathIntegration(Module):
|
|
64
|
-
def __init__(
|
|
65
|
-
self,
|
|
66
|
-
dim_action,
|
|
67
|
-
dim_structure,
|
|
68
|
-
mlp_hidden_dim = None,
|
|
69
|
-
mlp_depth = 2
|
|
70
|
-
):
|
|
71
|
-
# they use the same approach from Ruiqi Gao's paper from 2021
|
|
72
|
-
super().__init__()
|
|
73
|
-
|
|
74
|
-
self.init_structure = nn.Parameter(torch.randn(dim_structure))
|
|
75
|
-
|
|
76
|
-
self.to_transitions = create_mlp(
|
|
77
|
-
default(mlp_hidden_dim, dim_action * 4),
|
|
78
|
-
dim_in = dim_action,
|
|
79
|
-
dim_out = dim_structure * dim_structure,
|
|
80
|
-
depth = mlp_depth
|
|
81
|
-
)
|
|
82
|
-
|
|
83
|
-
self.mlp_out_to_weights = Rearrange('... (i j) -> ... i j', j = dim_structure)
|
|
84
|
-
|
|
85
|
-
self.rnn = RNN(dim_structure)
|
|
86
|
-
|
|
87
|
-
def forward(
|
|
88
|
-
self,
|
|
89
|
-
actions, # (b n d)
|
|
90
|
-
prev_structural = None # (b n d) | (b d)
|
|
91
|
-
):
|
|
92
|
-
batch = actions.shape[0]
|
|
93
|
-
|
|
94
|
-
transitions = self.to_transitions(actions)
|
|
95
|
-
transitions = self.mlp_out_to_weights(transitions)
|
|
96
|
-
|
|
97
|
-
if exists(prev_structural) and prev_structural.ndim == 3:
|
|
98
|
-
prev_structural = prev_structural[:, -1]
|
|
99
|
-
|
|
100
|
-
return self.rnn(transitions, prev_structural)
|
|
101
|
-
|
|
102
|
-
# proposed mmTEM
|
|
103
|
-
|
|
104
|
-
class mmTEM(Module):
|
|
105
|
-
def __init__(
|
|
106
|
-
self,
|
|
107
|
-
dim
|
|
108
|
-
):
|
|
109
|
-
super().__init__()
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
def forward(
|
|
113
|
-
self,
|
|
114
|
-
data
|
|
115
|
-
):
|
|
116
|
-
raise NotImplementedError
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
import pytest
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
def test_path_integrate():
|
|
6
|
-
from hippoformer.hippoformer import PathIntegration
|
|
7
|
-
|
|
8
|
-
path_integrator = PathIntegration(32, 64)
|
|
9
|
-
|
|
10
|
-
actions = torch.randn(2, 16, 32)
|
|
11
|
-
|
|
12
|
-
structure_codes = path_integrator(actions)
|
|
13
|
-
structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
|
|
14
|
-
|
|
15
|
-
assert structure_codes.shape == (2, 16, 64)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|