hippoformer 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- hippoformer/hippoformer.py
+++ hippoformer/hippoformer.py
@@ -7,7 +7,8 @@ from torch.nn import Module
  from torch.jit import ScriptModule, script_method
  from torch.func import vmap, grad, functional_call

- from einops import repeat, rearrange
+ from einx import multiply
+ from einops import repeat, rearrange, pack, unpack
  from einops.layers.torch import Rearrange

  from x_mlps_pytorch import create_mlp
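
The new einx dependency supplies `multiply`, an elementwise product under an einops-style pattern that broadcasts the second argument over the trailing dimensions of the first. A minimal sketch of the call pattern this diff uses later (shapes are illustrative):

import torch
from einx import multiply

x = torch.randn(2, 5, 3, 4)  # (batch, time, *rest)
y = torch.randn(2, 5)        # (batch, time)

# scale x by y, broadcast over the trailing dims; same as x * y[..., None, None]
out = multiply('b t ..., b t', x, y)
assert out.shape == x.shape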
@@ -22,6 +23,16 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

+ def pack_with_inverse(t, pattern):
+     packed, packed_shape = pack([t], pattern)
+
+     def inverse(out, inv_pattern = None):
+         inv_pattern = default(inv_pattern, pattern)
+         unpacked, = unpack(out, packed_shape, inv_pattern)
+         return unpacked
+
+     return packed, inverse
+
  def l2norm(t):
      return F.normalize(t, dim = -1)

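The added `pack_with_inverse` helper pairs einops `pack`/`unpack` so a caller can flatten trailing dimensions with a pattern such as 'b t *', operate on the flat view, and restore the original shape through the returned closure. A minimal usage sketch (the tensor shape is illustrative):

import torch

x = torch.randn(2, 5, 3, 4)

flat, inverse = pack_with_inverse(x, 'b t *')  # flat has shape (2, 5, 12)
flat = flat * 2.                               # any op on the flattened view
restored = inverse(flat)                       # back to (2, 5, 3, 4)
assert restored.shape == x.shape
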
@@ -121,7 +132,8 @@ class mmTEM(Module):
          loss_weight_inference = 1.,
          loss_weight_consistency = 1.,
          loss_weight_relational = 1.,
-         integration_ratio_learned = True
+         integration_ratio_learned = True,
+         assoc_scan_kwargs: dict = dict()
      ):
          super().__init__()

@@ -150,6 +162,9 @@ class mmTEM(Module):
          self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
          self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

+         self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
+         self.assoc_scan = AssocScan(**assoc_scan_kwargs)
+
          self.meta_memory_mlp = create_mlp(
              dim = dim * 2,
              depth = meta_mlp_depth,
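
The new `to_learned_optim_hparams` head predicts three scalars per token, which the forward pass later splits with `unbind(dim = -1)` into a per-token learning rate, forget gate, and momentum gate for the memory update. A minimal sketch of that split (the dimension 16 is illustrative):

import torch
import torch.nn as nn

joint_code = torch.randn(2, 5, 16)  # (batch, time, dim_joint_rep)
head = nn.Linear(16, 3, bias = False)

# three (batch, time) tensors, one per learned optimizer hyperparameter
lr, forget, beta = head(joint_code).unbind(dim = -1)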
@@ -158,6 +173,14 @@ class mmTEM(Module):
              activation = nn.ReLU()
          )

+         def forward_with_mse_loss(params, keys, values):
+             pred = functional_call(self.meta_memory_mlp, params, keys)
+             return F.mse_loss(pred, values)
+
+         grad_fn = grad(forward_with_mse_loss)
+
+         self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
+
          # mlp decoder (from meta mlp output to joint)

          self.memory_output_decoder = create_mlp(
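
`per_sample_grad_fn` follows the standard torch.func recipe for per-sample gradients: `functional_call` runs the module with an explicit parameter dict, `grad` differentiates the scalar loss with respect to that dict, and `vmap` with `in_dims = (None, 0, 0)` shares the parameters while mapping over a leading dimension of keys and values; nesting `vmap` twice, as above, maps over both the batch and sequence dimensions. A single-level sketch on a toy module (the module and shapes are illustrative):

import torch
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

net = torch.nn.Linear(4, 4)
params = dict(net.named_parameters())

def loss_fn(params, x, y):
    pred = functional_call(net, params, x)
    return F.mse_loss(pred, y)

xs, ys = torch.randn(8, 4), torch.randn(8, 4)

# a dict of gradients, each entry with a leading dim of 8: one gradient per sample
per_sample_grads = vmap(grad(loss_fn), in_dims = (None, 0, 0))(params, xs, ys)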
@@ -206,8 +229,11 @@ class mmTEM(Module):
      def forward(
          self,
          sensory,
-         actions
+         actions,
+         return_losses = False
      ):
+         batch = actions.shape[0]
+
          structural_codes = self.path_integrator(actions)

          encoded_sensory = self.sensory_encoder(sensory)
@@ -244,18 +270,63 @@ class mmTEM(Module):
          pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))

-         inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
+         inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio.sigmoid() * pred_variance

          consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)

          # 4. final inference loss

-         _, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
+         final_structural_code, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))

          decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)

          inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)

+         # 5. store the final structural code from step 4 + encoded sensory
+
+         joint_code_to_store = cat((final_structural_code, encoded_sensory), dim = -1)
+
+         keys = self.to_keys(joint_code_to_store)
+         values = self.to_values(joint_code_to_store)
+
+         lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
+
+         params = dict(self.meta_memory_mlp.named_parameters())
+         grads = self.per_sample_grad_fn(params, keys, values)
+
+         # update the meta mlp parameters
+
+         init_momentums = {k: zeros_like(v) for k, v in params.items()}
+         next_params = dict()
+
+         for (
+             (key, param),
+             (_, grad),
+             (_, init_momentum)
+         ) in zip(
+             params.items(),
+             grads.items(),
+             init_momentums.items()
+         ):
+
+             grad, inverse_pack = pack_with_inverse(grad, 'b t *')
+
+             grad = multiply('b t ..., b t', grad, lr)
+
+             expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
+
+             init_momentum = repeat(init_momentum, '... -> b ...', b = batch)
+
+             update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
+
+             expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
+
+             acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
+
+             acc_update = inverse_pack(acc_update)
+
+             next_params[key] = param - acc_update[:, -1]
+
          # losses

          total_loss = (
@@ -272,4 +343,7 @@ class mmTEM(Module):
              inference_pred_loss
          )

+         if not return_losses:
+             return total_loss
+
          return total_loss, losses
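
The two `assoc_scan` calls in the parameter-update loop above evaluate gated linear recurrences over the time dimension: the first folds the lr-scaled per-sample gradients into a momentum-like state gated by `sigmoid(beta)` and seeded with `init_momentum`, the second accumulates those updates under a `sigmoid(forget)` gate, and only the final time step is subtracted from each parameter. A naive reference loop with the recurrence these scans compute, assuming `(batch, time, width)` inputs and gates (a sketch only; the `assoc-scan` package computes the same recurrence as a parallel scan, and its exact argument order may differ):

import torch

def gated_scan_reference(inputs, gates, prev = None):
    # state_t = gates_t * state_{t-1} + inputs_t, scanned over dim 1 (time)
    state = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    out = []
    for t in range(inputs.shape[1]):
        state = gates[:, t] * state + inputs[:, t]
        out.append(state)
    return torch.stack(out, dim = 1)

With the gates fixed at 1 this reduces to a running sum; passing `prev` continues the recurrence from an existing state, which is how the initial momentum is threaded into the first scan.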
--- hippoformer-0.0.4.dist-info/METADATA
+++ hippoformer-0.0.6.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hippoformer
- Version: 0.0.4
+ Version: 0.0.6
  Summary: hippoformer
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -36,6 +36,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: assoc-scan
  Requires-Dist: einops>=0.8.1
+ Requires-Dist: einx>=0.3.0
  Requires-Dist: torch>=2.4
  Requires-Dist: x-mlps-pytorch
  Provides-Extra: examples
--- /dev/null
+++ hippoformer-0.0.6.dist-info/RECORD
@@ -0,0 +1,6 @@
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+ hippoformer/hippoformer.py,sha256=b6EXU2VXh_ZD7brxpCVNuU-m7cE-zXRR-sOmqfofPCg,10839
+ hippoformer-0.0.6.dist-info/METADATA,sha256=ufTBdu8ZGggxwfgzphYV56jjaGdI5sLCE_iZF5Bku6s,2800
+ hippoformer-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ hippoformer-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ hippoformer-0.0.6.dist-info/RECORD,,
--- hippoformer-0.0.4.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
- hippoformer/hippoformer.py,sha256=r_kn8kQ8js_Fd5wufj-I8EbE3w8b7SUZm47rUJtt4aY,8329
- hippoformer-0.0.4.dist-info/METADATA,sha256=8geT7mVp0r4WHw3uf860xwWGMpYYX43Rum_PNDUMfmw,2773
- hippoformer-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- hippoformer-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- hippoformer-0.0.4.dist-info/RECORD,,