hippoformer 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -455,6 +455,10 @@ class mmTEM(Module):
455
455
 
456
456
  update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
457
457
 
458
+ # store next momentum
459
+
460
+ next_momentum[key] = update[:, -1]
461
+
458
462
  # maybe muon
459
463
 
460
464
  if self.muon_update:
@@ -464,14 +468,13 @@ class mmTEM(Module):
464
468
 
465
469
  expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
466
470
 
467
- acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
471
+ acc_update = self.assoc_scan(-update, expanded_forget.sigmoid(), param)
468
472
 
469
473
  acc_update = inverse_pack(acc_update)
470
474
 
471
475
  # set the next params and momentum, which can be passed back in
472
476
 
473
- next_params[key] = param - acc_update[:, -1]
474
- next_momentum[key] = update[:, -1]
477
+ next_params[key] = acc_update[:, -1]
475
478
 
476
479
  # losses
477
480
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.9
3
+ Version: 0.0.10
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=GWFJy2idp0FWBoVFw8T_6inTXYtY4i47hfhKj88_I0A,14463
3
+ hippoformer-0.0.10.dist-info/METADATA,sha256=IB7iybYMwOkee3Q5ji-B_dnOB62LyK_6t1FPM_UT-FM,2773
4
+ hippoformer-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.10.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=m7luQGFdMWOkZUorjd5v34hx_vjOQqpJOAGCL0njHUE,14426
3
- hippoformer-0.0.9.dist-info/METADATA,sha256=owgkDcdTf0_N5IbUr3e_yt7u5sIWfOMha-hA5LQWnus,2772
4
- hippoformer-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.9.dist-info/RECORD,,