hippoformer 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +24 -3
- {hippoformer-0.0.4.dist-info → hippoformer-0.0.5.dist-info}/METADATA +1 -1
- hippoformer-0.0.5.dist-info/RECORD +6 -0
- hippoformer-0.0.4.dist-info/RECORD +0 -6
- {hippoformer-0.0.4.dist-info → hippoformer-0.0.5.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.4.dist-info → hippoformer-0.0.5.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
|
@@ -158,6 +158,14 @@ class mmTEM(Module):
|
|
|
158
158
|
activation = nn.ReLU()
|
|
159
159
|
)
|
|
160
160
|
|
|
161
|
+
def forward_with_mse_loss(params, keys, values):
|
|
162
|
+
pred = functional_call(self.meta_memory_mlp, params, keys)
|
|
163
|
+
return F.mse_loss(pred, values)
|
|
164
|
+
|
|
165
|
+
grad_fn = grad(forward_with_mse_loss)
|
|
166
|
+
|
|
167
|
+
self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
|
|
168
|
+
|
|
161
169
|
# mlp decoder (from meta mlp output to joint)
|
|
162
170
|
|
|
163
171
|
self.memory_output_decoder = create_mlp(
|
|
@@ -206,7 +214,8 @@ class mmTEM(Module):
|
|
|
206
214
|
def forward(
|
|
207
215
|
self,
|
|
208
216
|
sensory,
|
|
209
|
-
actions
|
|
217
|
+
actions,
|
|
218
|
+
return_losses = False
|
|
210
219
|
):
|
|
211
220
|
structural_codes = self.path_integrator(actions)
|
|
212
221
|
|
|
@@ -244,18 +253,27 @@ class mmTEM(Module):
|
|
|
244
253
|
|
|
245
254
|
pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
|
|
246
255
|
|
|
247
|
-
inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
|
|
256
|
+
inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio.sigmoid() * pred_variance
|
|
248
257
|
|
|
249
258
|
consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
|
|
250
259
|
|
|
251
260
|
# 4. final inference loss
|
|
252
261
|
|
|
253
|
-
|
|
262
|
+
final_structural_code, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
|
|
254
263
|
|
|
255
264
|
decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
|
|
256
265
|
|
|
257
266
|
inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
|
|
258
267
|
|
|
268
|
+
# 5. store the final structural code from step 4 + encoded sensory
|
|
269
|
+
|
|
270
|
+
joint_code_to_store = cat((final_structural_code, encoded_sensory), dim = -1)
|
|
271
|
+
|
|
272
|
+
keys = self.to_keys(joint_code_to_store)
|
|
273
|
+
values = self.to_values(joint_code_to_store)
|
|
274
|
+
|
|
275
|
+
grads = self.per_sample_grad_fn(dict(self.meta_memory_mlp.named_parameters()), keys, values)
|
|
276
|
+
|
|
259
277
|
# losses
|
|
260
278
|
|
|
261
279
|
total_loss = (
|
|
@@ -272,4 +290,7 @@ class mmTEM(Module):
|
|
|
272
290
|
inference_pred_loss
|
|
273
291
|
)
|
|
274
292
|
|
|
293
|
+
if not return_losses:
|
|
294
|
+
return total_loss
|
|
295
|
+
|
|
275
296
|
return total_loss, losses
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
+
hippoformer/hippoformer.py,sha256=PP2KmTygOP6MyYuhmr_8iEBbywIaTW4TpoIycYRugMo,9142
|
|
3
|
+
hippoformer-0.0.5.dist-info/METADATA,sha256=83iG4F_6ibQy6XSCWht-aF2ZVYmiEq-KSF4XR9YaBtY,2773
|
|
4
|
+
hippoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hippoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
hippoformer-0.0.5.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
-
hippoformer/hippoformer.py,sha256=r_kn8kQ8js_Fd5wufj-I8EbE3w8b7SUZm47rUJtt4aY,8329
|
|
3
|
-
hippoformer-0.0.4.dist-info/METADATA,sha256=8geT7mVp0r4WHw3uf860xwWGMpYYX43Rum_PNDUMfmw,2773
|
|
4
|
-
hippoformer-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hippoformer-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
hippoformer-0.0.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|