hippoformer 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import torch
4
- from torch import nn, Tensor, stack, einsum, tensor
4
+ from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
5
5
  import torch.nn.functional as F
6
6
  from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
@@ -121,6 +121,7 @@ class mmTEM(Module):
121
121
  loss_weight_inference = 1.,
122
122
  loss_weight_consistency = 1.,
123
123
  loss_weight_relational = 1.,
124
+ integration_ratio_learned = True
124
125
  ):
125
126
  super().__init__()
126
127
 
@@ -170,7 +171,7 @@ class mmTEM(Module):
170
171
  # the mlp that predicts the variance for the structural code
171
172
  # for correcting the generated structural code modeling the feedback from HC to MEC
172
173
 
173
- self.structure_variance_pred_mlp_depth = create_mlp(
174
+ self.structure_variance_pred_mlp = create_mlp(
174
175
  dim = dim_structure * 2,
175
176
  dim_in = dim_structure * 2 + 1,
176
177
  dim_out = dim_structure,
@@ -185,6 +186,23 @@ class mmTEM(Module):
185
186
  self.loss_weight_consistency = loss_weight_consistency
186
187
  self.register_buffer('zero', tensor(0.), persistent = False)
187
188
 
189
+ # there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
190
+
191
+ self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
192
+
193
+ def retrieve(
194
+ self,
195
+ structural_codes,
196
+ encoded_sensory
197
+ ):
198
+ joint = cat((structural_codes, encoded_sensory), dim = -1)
199
+
200
+ queries = self.to_queries(joint)
201
+
202
+ retrieved = self.meta_memory_mlp(queries)
203
+
204
+ return self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
205
+
188
206
  def forward(
189
207
  self,
190
208
  sensory,
@@ -192,28 +210,66 @@ class mmTEM(Module):
192
210
  ):
193
211
  structural_codes = self.path_integrator(actions)
194
212
 
195
- # first have the structure code be able to fetch from the meta memory mlp
196
-
197
- structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
213
+ encoded_sensory = self.sensory_encoder(sensory)
198
214
 
199
- queries = self.to_queries(structure_codes_with_zero_sensory)
200
-
201
- retrieved = self.meta_memory_mlp(queries)
215
+ # 1. first have the structure code be able to fetch from the meta memory mlp
202
216
 
203
- decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
217
+ decoded_gen_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
204
218
 
205
219
  decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
206
220
 
207
221
  generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
208
222
 
223
+ # 2. relational
224
+
225
+ # 2a. structure from content
226
+
227
+ decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
228
+
229
+ structure_from_content_loss = F.mse_loss(decoded_structure, structural_codes)
230
+
231
+ # 2b. structure from structure
232
+
233
+ decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
234
235
+ structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
236
+
237
+ relational_loss = structure_from_content_loss + structure_from_structure_loss
238
+
239
+ # 3. consistency - modeling a feedback system from hippocampus to path integration
240
+
241
+ corrected_structural_code, corrected_encoded_sensory = self.retrieve(decoded_gen_structure, encoded_sensory)
242
+
243
+ sensory_sse = (corrected_encoded_sensory - encoded_sensory).norm(dim = -1, keepdim = True).pow(2)
244
+
245
+ pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
246
+
247
+ inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
248
+
249
+ consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
250
+
251
+ # 4. final inference loss
252
+
253
+ _, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
254
+
255
+ decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
256
+
257
+ inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
258
+
209
259
  # losses
210
260
 
211
261
  total_loss = (
212
- generative_pred_loss * self.loss_weight_generative
262
+ generative_pred_loss * self.loss_weight_generative +
263
+ relational_loss * self.loss_weight_relational +
264
+ consistency_loss * self.loss_weight_consistency +
265
+ inference_pred_loss * self.loss_weight_inference
213
266
  )
214
267
 
215
268
  losses = (
216
269
  generative_pred_loss,
270
+ relational_loss,
271
+ consistency_loss,
272
+ inference_pred_loss
217
273
  )
218
274
 
219
275
  return total_loss, losses
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hippoformer"
3
- version = "0.0.3"
3
+ version = "0.0.4"
4
4
  description = "hippoformer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
File without changes
File without changes
File without changes