hippoformer 0.0.3__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.3 → hippoformer-0.0.5}/PKG-INFO +1 -1
- {hippoformer-0.0.3 → hippoformer-0.0.5}/hippoformer/hippoformer.py +88 -11
- {hippoformer-0.0.3 → hippoformer-0.0.5}/pyproject.toml +1 -1
- {hippoformer-0.0.3 → hippoformer-0.0.5}/tests/test_hippoformer.py +1 -1
- {hippoformer-0.0.3 → hippoformer-0.0.5}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.5}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.5}/.gitignore +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.5}/LICENSE +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.5}/README.md +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.5}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.3 → hippoformer-0.0.5}/hippoformer-fig6.png +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import nn, Tensor, stack, einsum, tensor
|
|
4
|
+
from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
|
|
5
5
|
import torch.nn.functional as F
|
|
6
6
|
from torch.nn import Module
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
@@ -121,6 +121,7 @@ class mmTEM(Module):
|
|
|
121
121
|
loss_weight_inference = 1.,
|
|
122
122
|
loss_weight_consistency = 1.,
|
|
123
123
|
loss_weight_relational = 1.,
|
|
124
|
+
integration_ratio_learned = True
|
|
124
125
|
):
|
|
125
126
|
super().__init__()
|
|
126
127
|
|
|
@@ -157,6 +158,14 @@ class mmTEM(Module):
|
|
|
157
158
|
activation = nn.ReLU()
|
|
158
159
|
)
|
|
159
160
|
|
|
161
|
+
def forward_with_mse_loss(params, keys, values):
|
|
162
|
+
pred = functional_call(self.meta_memory_mlp, params, keys)
|
|
163
|
+
return F.mse_loss(pred, values)
|
|
164
|
+
|
|
165
|
+
grad_fn = grad(forward_with_mse_loss)
|
|
166
|
+
|
|
167
|
+
self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
|
|
168
|
+
|
|
160
169
|
# mlp decoder (from meta mlp output to joint)
|
|
161
170
|
|
|
162
171
|
self.memory_output_decoder = create_mlp(
|
|
@@ -170,7 +179,7 @@ class mmTEM(Module):
|
|
|
170
179
|
# the mlp that predicts the variance for the structural code
|
|
171
180
|
# used to correct the generated structural code, modeling the feedback from HC to MEC
|
|
172
181
|
|
|
173
|
-
self.
|
|
182
|
+
self.structure_variance_pred_mlp = create_mlp(
|
|
174
183
|
dim = dim_structure * 2,
|
|
175
184
|
dim_in = dim_structure * 2 + 1,
|
|
176
185
|
dim_out = dim_structure,
|
|
@@ -185,35 +194,103 @@ class mmTEM(Module):
|
|
|
185
194
|
self.loss_weight_consistency = loss_weight_consistency
|
|
186
195
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
187
196
|
|
|
197
|
+
# there is an integration ratio for error correction, but it is unclear what value it is fixed to, or whether it is learned
|
|
198
|
+
|
|
199
|
+
self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
|
|
200
|
+
|
|
201
|
+
def retrieve(
|
|
202
|
+
self,
|
|
203
|
+
structural_codes,
|
|
204
|
+
encoded_sensory
|
|
205
|
+
):
|
|
206
|
+
joint = cat((structural_codes, encoded_sensory), dim = -1)
|
|
207
|
+
|
|
208
|
+
queries = self.to_queries(joint)
|
|
209
|
+
|
|
210
|
+
retrieved = self.meta_memory_mlp(queries)
|
|
211
|
+
|
|
212
|
+
return self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
|
|
213
|
+
|
|
188
214
|
def forward(
|
|
189
215
|
self,
|
|
190
216
|
sensory,
|
|
191
|
-
actions
|
|
217
|
+
actions,
|
|
218
|
+
return_losses = False
|
|
192
219
|
):
|
|
193
220
|
structural_codes = self.path_integrator(actions)
|
|
194
221
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
|
|
222
|
+
encoded_sensory = self.sensory_encoder(sensory)
|
|
198
223
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
retrieved = self.meta_memory_mlp(queries)
|
|
224
|
+
# 1. first have the structure code be able to fetch from the meta memory mlp
|
|
202
225
|
|
|
203
|
-
|
|
226
|
+
decoded_gen_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
|
|
204
227
|
|
|
205
228
|
decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
|
|
206
229
|
|
|
207
230
|
generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
|
|
208
231
|
|
|
232
|
+
# 2. relational
|
|
233
|
+
|
|
234
|
+
# 2a. structure from content
|
|
235
|
+
|
|
236
|
+
decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
|
|
237
|
+
|
|
238
|
+
structure_from_content_loss = F.mse_loss(decoded_structure, structural_codes)
|
|
239
|
+
|
|
240
|
+
# 2b. structure from structure
|
|
241
|
+
|
|
242
|
+
decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
|
|
243
|
+
|
|
244
|
+
structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
|
|
245
|
+
|
|
246
|
+
relational_loss = structure_from_content_loss + structure_from_structure_loss
|
|
247
|
+
|
|
248
|
+
# 3. consistency - modeling a feedback system from hippocampus to path integration
|
|
249
|
+
|
|
250
|
+
corrected_structural_code, corrected_encoded_sensory = self.retrieve(decoded_gen_structure, encoded_sensory)
|
|
251
|
+
|
|
252
|
+
sensory_sse = (corrected_encoded_sensory - encoded_sensory).norm(dim = -1, keepdim = True).pow(2)
|
|
253
|
+
|
|
254
|
+
pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
|
|
255
|
+
|
|
256
|
+
inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio.sigmoid() * pred_variance
|
|
257
|
+
|
|
258
|
+
consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
|
|
259
|
+
|
|
260
|
+
# 4. final inference loss
|
|
261
|
+
|
|
262
|
+
final_structural_code, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
|
|
263
|
+
|
|
264
|
+
decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
|
|
265
|
+
|
|
266
|
+
inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
|
|
267
|
+
|
|
268
|
+
# 5. store the final structural code from step 4 + encoded sensory
|
|
269
|
+
|
|
270
|
+
joint_code_to_store = cat((final_structural_code, encoded_sensory), dim = -1)
|
|
271
|
+
|
|
272
|
+
keys = self.to_keys(joint_code_to_store)
|
|
273
|
+
values = self.to_values(joint_code_to_store)
|
|
274
|
+
|
|
275
|
+
grads = self.per_sample_grad_fn(dict(self.meta_memory_mlp.named_parameters()), keys, values)
|
|
276
|
+
|
|
209
277
|
# losses
|
|
210
278
|
|
|
211
279
|
total_loss = (
|
|
212
|
-
generative_pred_loss * self.loss_weight_generative
|
|
280
|
+
generative_pred_loss * self.loss_weight_generative +
|
|
281
|
+
relational_loss * self.loss_weight_relational +
|
|
282
|
+
consistency_loss * self.loss_weight_consistency +
|
|
283
|
+
inference_pred_loss * self.loss_weight_inference
|
|
213
284
|
)
|
|
214
285
|
|
|
215
286
|
losses = (
|
|
216
287
|
generative_pred_loss,
|
|
288
|
+
relational_loss,
|
|
289
|
+
consistency_loss,
|
|
290
|
+
inference_pred_loss
|
|
217
291
|
)
|
|
218
292
|
|
|
293
|
+
if not return_losses:
|
|
294
|
+
return total_loss
|
|
295
|
+
|
|
219
296
|
return total_loss, losses
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|