hippoformer 0.0.3__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.3 → hippoformer-0.0.4}/PKG-INFO +1 -1
- {hippoformer-0.0.3 → hippoformer-0.0.4}/hippoformer/hippoformer.py +66 -10
- {hippoformer-0.0.3 → hippoformer-0.0.4}/pyproject.toml +1 -1
- {hippoformer-0.0.3 → hippoformer-0.0.4}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/.gitignore +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/LICENSE +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/README.md +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/hippoformer-fig6.png +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.4}/tests/test_hippoformer.py +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import nn, Tensor, stack, einsum, tensor
|
|
4
|
+
from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
|
|
5
5
|
import torch.nn.functional as F
|
|
6
6
|
from torch.nn import Module
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
@@ -121,6 +121,7 @@ class mmTEM(Module):
|
|
|
121
121
|
loss_weight_inference = 1.,
|
|
122
122
|
loss_weight_consistency = 1.,
|
|
123
123
|
loss_weight_relational = 1.,
|
|
124
|
+
integration_ratio_learned = True
|
|
124
125
|
):
|
|
125
126
|
super().__init__()
|
|
126
127
|
|
|
@@ -170,7 +171,7 @@ class mmTEM(Module):
|
|
|
170
171
|
# the mlp that predicts the variance for the structural code
|
|
171
172
|
# for correcting the generated structural code modeling the feedback from HC to MEC
|
|
172
173
|
|
|
173
|
-
self.
|
|
174
|
+
self.structure_variance_pred_mlp = create_mlp(
|
|
174
175
|
dim = dim_structure * 2,
|
|
175
176
|
dim_in = dim_structure * 2 + 1,
|
|
176
177
|
dim_out = dim_structure,
|
|
@@ -185,6 +186,23 @@ class mmTEM(Module):
|
|
|
185
186
|
self.loss_weight_consistency = loss_weight_consistency
|
|
186
187
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
187
188
|
|
|
189
|
+
# there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
|
|
190
|
+
|
|
191
|
+
self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
|
|
192
|
+
|
|
193
|
+
def retrieve(
|
|
194
|
+
self,
|
|
195
|
+
structural_codes,
|
|
196
|
+
encoded_sensory
|
|
197
|
+
):
|
|
198
|
+
joint = cat((structural_codes, encoded_sensory), dim = -1)
|
|
199
|
+
|
|
200
|
+
queries = self.to_queries(joint)
|
|
201
|
+
|
|
202
|
+
retrieved = self.meta_memory_mlp(queries)
|
|
203
|
+
|
|
204
|
+
return self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
|
|
205
|
+
|
|
188
206
|
def forward(
|
|
189
207
|
self,
|
|
190
208
|
sensory,
|
|
@@ -192,28 +210,66 @@ class mmTEM(Module):
|
|
|
192
210
|
):
|
|
193
211
|
structural_codes = self.path_integrator(actions)
|
|
194
212
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
|
|
213
|
+
encoded_sensory = self.sensory_encoder(sensory)
|
|
198
214
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
retrieved = self.meta_memory_mlp(queries)
|
|
215
|
+
# 1. first have the structure code be able to fetch from the meta memory mlp
|
|
202
216
|
|
|
203
|
-
|
|
217
|
+
decoded_gen_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
|
|
204
218
|
|
|
205
219
|
decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
|
|
206
220
|
|
|
207
221
|
generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
|
|
208
222
|
|
|
223
|
+
# 2. relational
|
|
224
|
+
|
|
225
|
+
# 2a. structure from content
|
|
226
|
+
|
|
227
|
+
decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
|
|
228
|
+
|
|
229
|
+
structure_from_content_loss = F.mse_loss(decoded_structure, structural_codes)
|
|
230
|
+
|
|
231
|
+
# 2b. structure from structure
|
|
232
|
+
|
|
233
|
+
decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
|
|
234
|
+
|
|
235
|
+
structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
|
|
236
|
+
|
|
237
|
+
relational_loss = structure_from_content_loss + structure_from_structure_loss
|
|
238
|
+
|
|
239
|
+
# 3. consistency - modeling a feedback system from hippocampus to path integration
|
|
240
|
+
|
|
241
|
+
corrected_structural_code, corrected_encoded_sensory = self.retrieve(decoded_gen_structure, encoded_sensory)
|
|
242
|
+
|
|
243
|
+
sensory_sse = (corrected_encoded_sensory - encoded_sensory).norm(dim = -1, keepdim = True).pow(2)
|
|
244
|
+
|
|
245
|
+
pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
|
|
246
|
+
|
|
247
|
+
inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
|
|
248
|
+
|
|
249
|
+
consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
|
|
250
|
+
|
|
251
|
+
# 4. final inference loss
|
|
252
|
+
|
|
253
|
+
_, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
|
|
254
|
+
|
|
255
|
+
decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
|
|
256
|
+
|
|
257
|
+
inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
|
|
258
|
+
|
|
209
259
|
# losses
|
|
210
260
|
|
|
211
261
|
total_loss = (
|
|
212
|
-
generative_pred_loss * self.loss_weight_generative
|
|
262
|
+
generative_pred_loss * self.loss_weight_generative +
|
|
263
|
+
relational_loss * self.loss_weight_relational +
|
|
264
|
+
consistency_loss * self.loss_weight_consistency +
|
|
265
|
+
inference_pred_loss * self.loss_weight_inference
|
|
213
266
|
)
|
|
214
267
|
|
|
215
268
|
losses = (
|
|
216
269
|
generative_pred_loss,
|
|
270
|
+
relational_loss,
|
|
271
|
+
consistency_loss,
|
|
272
|
+
inference_pred_loss
|
|
217
273
|
)
|
|
218
274
|
|
|
219
275
|
return total_loss, losses
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|