hippoformer 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -179,7 +179,7 @@ class mmTEM(Module):
179
179
 
180
180
  grad_fn = grad(forward_with_mse_loss)
181
181
 
182
- self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
182
+ self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))
183
183
 
184
184
  # mlp decoder (from meta mlp output to joint)
185
185
 
@@ -213,6 +213,19 @@ class mmTEM(Module):
213
213
 
214
214
  self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
215
215
 
216
+ def init_params_and_momentum(
217
+ self,
218
+ batch_size
219
+ ):
220
+
221
+ params_dict = dict(self.meta_memory_mlp.named_parameters())
222
+
223
+ params = {name: repeat(param, '... -> b ...', b = batch_size) for name, param in params_dict.items()}
224
+
225
+ momentums = {name: zeros_like(param) for name, param in params.items()}
226
+
227
+ return params, momentums
228
+
216
229
  def retrieve(
217
230
  self,
218
231
  structural_codes,
@@ -230,7 +243,9 @@ class mmTEM(Module):
230
243
  self,
231
244
  sensory,
232
245
  actions,
233
- return_losses = False
246
+ memory_mlp_params = None,
247
+ return_losses = False,
248
+ return_memory_mlp_params = False
234
249
  ):
235
250
  batch = actions.shape[0]
236
251
 
@@ -291,22 +306,28 @@ class mmTEM(Module):
291
306
 
292
307
  lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
293
308
 
294
- params = dict(self.meta_memory_mlp.named_parameters())
309
+ if exists(memory_mlp_params):
310
+ params, momentums = memory_mlp_params
311
+ else:
312
+ params, momentums = self.init_params_and_momentum(batch)
313
+
314
+ # store by getting gradients of mse loss of keys and values
315
+
295
316
  grads = self.per_sample_grad_fn(params, keys, values)
296
317
 
297
- # update the meta mlp parameters
318
+ # update the meta mlp parameters and momentums
298
319
 
299
- init_momentums = {k: zeros_like(v) for k, v in params.items()}
300
320
  next_params = dict()
321
+ next_momentum = dict()
301
322
 
302
323
  for (
303
324
  (key, param),
304
325
  (_, grad),
305
- (_, init_momentum)
326
+ (_, momentum)
306
327
  ) in zip(
307
328
  params.items(),
308
329
  grads.items(),
309
- init_momentums.items()
330
+ momentums.items()
310
331
  ):
311
332
 
312
333
  grad, inverse_pack = pack_with_inverse(grad, 'b t *')
@@ -315,9 +336,7 @@ class mmTEM(Module):
315
336
 
316
337
  expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
317
338
 
318
- init_momentum = repeat(init_momentum, '... -> b ...', b = batch)
319
-
320
- update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
339
+ update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
321
340
 
322
341
  expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
323
342
 
@@ -325,7 +344,10 @@ class mmTEM(Module):
325
344
 
326
345
  acc_update = inverse_pack(acc_update)
327
346
 
347
+ # set the next params and momentum, which can be passed back in
348
+
328
349
  next_params[key] = param - acc_update[:, -1]
350
+ next_momentum[key] = update[:, -1]
329
351
 
330
352
  # losses
331
353
 
@@ -343,6 +365,9 @@ class mmTEM(Module):
343
365
  inference_pred_loss
344
366
  )
345
367
 
368
+ if return_memory_mlp_params:
369
+ return next_params, next_momentum
370
+
346
371
  if not return_losses:
347
372
  return total_loss
348
373
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=yYoJ5XO0YVAyp3LcRxpunU-0HA97mpCBeQFyi-NSkF0,11549
3
+ hippoformer-0.0.7.dist-info/METADATA,sha256=Xg6NZ6VAQGmuiOo8mMwIAM39Gf6TpVOpyn7o4PMq7JE,2800
4
+ hippoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.7.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=b6EXU2VXh_ZD7brxpCVNuU-m7cE-zXRR-sOmqfofPCg,10839
3
- hippoformer-0.0.6.dist-info/METADATA,sha256=ufTBdu8ZGggxwfgzphYV56jjaGdI5sLCE_iZF5Bku6s,2800
4
- hippoformer-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.6.dist-info/RECORD,,