hippoformer 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/__init__.py +4 -0
- hippoformer/hippoformer.py +107 -4
- {hippoformer-0.0.1.dist-info → hippoformer-0.0.3.dist-info}/METADATA +1 -1
- hippoformer-0.0.3.dist-info/RECORD +6 -0
- hippoformer-0.0.1.dist-info/RECORD +0 -6
- {hippoformer-0.0.1.dist-info → hippoformer-0.0.3.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.1.dist-info → hippoformer-0.0.3.dist-info}/licenses/LICENSE +0 -0
hippoformer/__init__.py
CHANGED
hippoformer/hippoformer.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import nn, Tensor, stack, einsum
|
|
4
|
+
from torch import nn, Tensor, stack, einsum, tensor
|
|
5
5
|
import torch.nn.functional as F
|
|
6
6
|
from torch.nn import Module
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
8
|
+
from torch.func import vmap, grad, functional_call
|
|
8
9
|
|
|
9
10
|
from einops import repeat, rearrange
|
|
10
11
|
from einops.layers.torch import Rearrange
|
|
@@ -104,13 +105,115 @@ class PathIntegration(Module):
|
|
|
104
105
|
class mmTEM(Module):
|
|
105
106
|
def __init__(
|
|
106
107
|
self,
|
|
107
|
-
dim
|
|
108
|
+
dim,
|
|
109
|
+
*,
|
|
110
|
+
sensory_encoder: Module,
|
|
111
|
+
sensory_decoder: Module,
|
|
112
|
+
dim_sensory,
|
|
113
|
+
dim_action,
|
|
114
|
+
dim_encoded_sensory,
|
|
115
|
+
dim_structure,
|
|
116
|
+
meta_mlp_depth = 2,
|
|
117
|
+
decoder_mlp_depth = 2,
|
|
118
|
+
structure_variance_pred_mlp_depth = 2,
|
|
119
|
+
path_integrate_kwargs: dict = dict(),
|
|
120
|
+
loss_weight_generative = 1.,
|
|
121
|
+
loss_weight_inference = 1.,
|
|
122
|
+
loss_weight_consistency = 1.,
|
|
123
|
+
loss_weight_relational = 1.,
|
|
108
124
|
):
|
|
109
125
|
super().__init__()
|
|
110
126
|
|
|
127
|
+
# sensory
|
|
128
|
+
|
|
129
|
+
self.sensory_encoder = sensory_encoder
|
|
130
|
+
self.sensory_decoder = sensory_decoder
|
|
131
|
+
|
|
132
|
+
dim_joint_rep = dim_encoded_sensory + dim_structure
|
|
133
|
+
|
|
134
|
+
self.dim_encoded_sensory = dim_encoded_sensory
|
|
135
|
+
self.dim_structure = dim_structure
|
|
136
|
+
self.joint_dims = (dim_structure, dim_encoded_sensory)
|
|
137
|
+
|
|
138
|
+
# path integrator
|
|
139
|
+
|
|
140
|
+
self.path_integrator = PathIntegration(
|
|
141
|
+
dim_action = dim_action,
|
|
142
|
+
dim_structure = dim_structure,
|
|
143
|
+
**path_integrate_kwargs
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
# meta mlp related
|
|
147
|
+
|
|
148
|
+
self.to_queries = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
149
|
+
self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
150
|
+
self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
151
|
+
|
|
152
|
+
self.meta_memory_mlp = create_mlp(
|
|
153
|
+
dim = dim * 2,
|
|
154
|
+
depth = meta_mlp_depth,
|
|
155
|
+
dim_in = dim,
|
|
156
|
+
dim_out = dim,
|
|
157
|
+
activation = nn.ReLU()
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
# mlp decoder (from meta mlp output to joint)
|
|
161
|
+
|
|
162
|
+
self.memory_output_decoder = create_mlp(
|
|
163
|
+
dim = dim * 2,
|
|
164
|
+
dim_in = dim,
|
|
165
|
+
dim_out = dim_joint_rep,
|
|
166
|
+
depth = decoder_mlp_depth,
|
|
167
|
+
activation = nn.ReLU()
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
# the mlp that predicts the variance for the structural code
|
|
171
|
+
# for correcting the generated structural code modeling the feedback from HC to MEC
|
|
172
|
+
|
|
173
|
+
self.structure_variance_pred_mlp_depth = create_mlp(
|
|
174
|
+
dim = dim_structure * 2,
|
|
175
|
+
dim_in = dim_structure * 2 + 1,
|
|
176
|
+
dim_out = dim_structure,
|
|
177
|
+
depth = structure_variance_pred_mlp_depth
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
# loss related
|
|
181
|
+
|
|
182
|
+
self.loss_weight_generative = loss_weight_generative
|
|
183
|
+
self.loss_weight_inference = loss_weight_inference
|
|
184
|
+
self.loss_weight_relational = loss_weight_relational
|
|
185
|
+
self.loss_weight_consistency = loss_weight_consistency
|
|
186
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
111
187
|
|
|
112
188
|
def forward(
|
|
113
189
|
self,
|
|
114
|
-
|
|
190
|
+
sensory,
|
|
191
|
+
actions
|
|
115
192
|
):
|
|
116
|
-
|
|
193
|
+
structural_codes = self.path_integrator(actions)
|
|
194
|
+
|
|
195
|
+
# first have the structure code be able to fetch from the meta memory mlp
|
|
196
|
+
|
|
197
|
+
structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
|
|
198
|
+
|
|
199
|
+
queries = self.to_queries(structure_codes_with_zero_sensory)
|
|
200
|
+
|
|
201
|
+
retrieved = self.meta_memory_mlp(queries)
|
|
202
|
+
|
|
203
|
+
decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
|
|
204
|
+
|
|
205
|
+
decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
|
|
206
|
+
|
|
207
|
+
generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
|
|
208
|
+
|
|
209
|
+
# losses
|
|
210
|
+
|
|
211
|
+
total_loss = (
|
|
212
|
+
generative_pred_loss * self.loss_weight_generative
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
losses = (
|
|
216
|
+
generative_pred_loss,
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
return total_loss, losses
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
+
hippoformer/hippoformer.py,sha256=XnLDm12BHuJSOcb1Y8EKCSqkvdw20anuGT2O4kFsWHg,6041
|
|
3
|
+
hippoformer-0.0.3.dist-info/METADATA,sha256=L1SJ76ffExU4DQen0ppCv3gzt7waXXaE6znPuon60_g,2773
|
|
4
|
+
hippoformer-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hippoformer-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
hippoformer-0.0.3.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hippoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
hippoformer/hippoformer.py,sha256=6tA4ZWYKbzclpeTUhJtr2OguVOyyAGFxuLf9bfnfO_M,2682
|
|
3
|
-
hippoformer-0.0.1.dist-info/METADATA,sha256=4hnfh1oIIlcGsIQ7qD7fZHWfM5ltnHhATAPcN-4vkxQ,2773
|
|
4
|
-
hippoformer-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hippoformer-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
hippoformer-0.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|