hippoformer 0.0.4__tar.gz → 0.0.5__tar.gz

This diff compares the contents of two publicly available package versions, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -158,6 +158,14 @@ class mmTEM(Module):
158
158
  activation = nn.ReLU()
159
159
  )
160
160
 
161
+ def forward_with_mse_loss(params, keys, values):
162
+ pred = functional_call(self.meta_memory_mlp, params, keys)
163
+ return F.mse_loss(pred, values)
164
+
165
+ grad_fn = grad(forward_with_mse_loss)
166
+
167
+ self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
168
+
161
169
  # mlp decoder (from meta mlp output to joint)
162
170
 
163
171
  self.memory_output_decoder = create_mlp(
@@ -206,7 +214,8 @@ class mmTEM(Module):
206
214
  def forward(
207
215
  self,
208
216
  sensory,
209
- actions
217
+ actions,
218
+ return_losses = False
210
219
  ):
211
220
  structural_codes = self.path_integrator(actions)
212
221
 
@@ -244,18 +253,27 @@ class mmTEM(Module):
244
253
 
245
254
  pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
246
255
 
247
- inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
256
+ inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio.sigmoid() * pred_variance
248
257
 
249
258
  consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
250
259
 
251
260
  # 4. final inference loss
252
261
 
253
- _, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
262
+ final_structural_code, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
254
263
 
255
264
  decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
256
265
 
257
266
  inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
258
267
 
268
+ # 5. store the final structural code from step 4 + encoded sensory
269
+
270
+ joint_code_to_store = cat((final_structural_code, encoded_sensory), dim = -1)
271
+
272
+ keys = self.to_keys(joint_code_to_store)
273
+ values = self.to_values(joint_code_to_store)
274
+
275
+ grads = self.per_sample_grad_fn(dict(self.meta_memory_mlp.named_parameters()), keys, values)
276
+
259
277
  # losses
260
278
 
261
279
  total_loss = (
@@ -272,4 +290,7 @@ class mmTEM(Module):
272
290
  inference_pred_loss
273
291
  )
274
292
 
293
+ if not return_losses:
294
+ return total_loss
295
+
275
296
  return total_loss, losses
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hippoformer"
3
- version = "0.0.4"
3
+ version = "0.0.5"
4
4
  description = "hippoformer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -33,5 +33,5 @@ def test_mm_tem():
33
33
  actions = torch.randn(2, 16, 7)
34
34
  sensory = torch.randn(2, 16, 11)
35
35
 
36
- loss, losses = model(sensory, actions)
36
+ loss = model(sensory, actions)
37
37
  loss.backward()
File without changes
File without changes
File without changes