hippoformer 0.0.3__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import torch
4
- from torch import nn, Tensor, stack, einsum, tensor
4
+ from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
5
5
  import torch.nn.functional as F
6
6
  from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
@@ -121,6 +121,7 @@ class mmTEM(Module):
121
121
  loss_weight_inference = 1.,
122
122
  loss_weight_consistency = 1.,
123
123
  loss_weight_relational = 1.,
124
+ integration_ratio_learned = True
124
125
  ):
125
126
  super().__init__()
126
127
 
@@ -157,6 +158,14 @@ class mmTEM(Module):
157
158
  activation = nn.ReLU()
158
159
  )
159
160
 
161
+ def forward_with_mse_loss(params, keys, values):
162
+ pred = functional_call(self.meta_memory_mlp, params, keys)
163
+ return F.mse_loss(pred, values)
164
+
165
+ grad_fn = grad(forward_with_mse_loss)
166
+
167
+ self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
168
+
160
169
  # mlp decoder (from meta mlp output to joint)
161
170
 
162
171
  self.memory_output_decoder = create_mlp(
@@ -170,7 +179,7 @@ class mmTEM(Module):
170
179
  # the mlp that predicts the variance for the structural code
171
180
  # for correcting the generated structural code modeling the feedback from HC to MEC
172
181
 
173
- self.structure_variance_pred_mlp_depth = create_mlp(
182
+ self.structure_variance_pred_mlp = create_mlp(
174
183
  dim = dim_structure * 2,
175
184
  dim_in = dim_structure * 2 + 1,
176
185
  dim_out = dim_structure,
@@ -185,35 +194,103 @@ class mmTEM(Module):
185
194
  self.loss_weight_consistency = loss_weight_consistency
186
195
  self.register_buffer('zero', tensor(0.), persistent = False)
187
196
 
197
+ # there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
198
+
199
+ self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
200
+
201
+ def retrieve(
202
+ self,
203
+ structural_codes,
204
+ encoded_sensory
205
+ ):
206
+ joint = cat((structural_codes, encoded_sensory), dim = -1)
207
+
208
+ queries = self.to_queries(joint)
209
+
210
+ retrieved = self.meta_memory_mlp(queries)
211
+
212
+ return self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
213
+
188
214
  def forward(
189
215
  self,
190
216
  sensory,
191
- actions
217
+ actions,
218
+ return_losses = False
192
219
  ):
193
220
  structural_codes = self.path_integrator(actions)
194
221
 
195
- # first have the structure code be able to fetch from the meta memory mlp
196
-
197
- structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
222
+ encoded_sensory = self.sensory_encoder(sensory)
198
223
 
199
- queries = self.to_queries(structure_codes_with_zero_sensory)
200
-
201
- retrieved = self.meta_memory_mlp(queries)
224
+ # 1. first have the structure code be able to fetch from the meta memory mlp
202
225
 
203
- decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
226
+ decoded_gen_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
204
227
 
205
228
  decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
206
229
 
207
230
  generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
208
231
 
232
+ # 2. relational
233
+
234
+ # 2a. structure from content
235
+
236
+ decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
237
+
238
+ structure_from_content_loss = F.mse_loss(decoded_structure, structural_codes)
239
+
240
+ # 2b. structure from structure
241
+
242
+ decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
243
+
244
+ structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
245
+
246
+ relational_loss = structure_from_content_loss + structure_from_structure_loss
247
+
248
+ # 3. consistency - modeling a feedback system from hippocampus to path integration
249
+
250
+ corrected_structural_code, corrected_encoded_sensory = self.retrieve(decoded_gen_structure, encoded_sensory)
251
+
252
+ sensory_sse = (corrected_encoded_sensory - encoded_sensory).norm(dim = -1, keepdim = True).pow(2)
253
+
254
+ pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
255
+
256
+ inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio.sigmoid() * pred_variance
257
+
258
+ consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
259
+
260
+ # 4. final inference loss
261
+
262
+ final_structural_code, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
263
+
264
+ decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
265
+
266
+ inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
267
+
268
+ # 5. store the final structural code from step 4 + encoded sensory
269
+
270
+ joint_code_to_store = cat((final_structural_code, encoded_sensory), dim = -1)
271
+
272
+ keys = self.to_keys(joint_code_to_store)
273
+ values = self.to_values(joint_code_to_store)
274
+
275
+ grads = self.per_sample_grad_fn(dict(self.meta_memory_mlp.named_parameters()), keys, values)
276
+
209
277
  # losses
210
278
 
211
279
  total_loss = (
212
- generative_pred_loss * self.loss_weight_generative
280
+ generative_pred_loss * self.loss_weight_generative +
281
+ relational_loss * self.loss_weight_relational +
282
+ consistency_loss * self.loss_weight_consistency +
283
+ inference_pred_loss * self.loss_weight_inference
213
284
  )
214
285
 
215
286
  losses = (
216
287
  generative_pred_loss,
288
+ relational_loss,
289
+ consistency_loss,
290
+ inference_pred_loss
217
291
  )
218
292
 
293
+ if not return_losses:
294
+ return total_loss
295
+
219
296
  return total_loss, losses
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hippoformer"
3
- version = "0.0.3"
3
+ version = "0.0.5"
4
4
  description = "hippoformer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -33,5 +33,5 @@ def test_mm_tem():
33
33
  actions = torch.randn(2, 16, 7)
34
34
  sensory = torch.randn(2, 16, 11)
35
35
 
36
- loss, losses = model(sensory, actions)
36
+ loss = model(sensory, actions)
37
37
  loss.backward()
File without changes
File without changes
File without changes