hippoformer 0.0.9__tar.gz → 0.0.10__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.
- {hippoformer-0.0.9 → hippoformer-0.0.10}/PKG-INFO +1 -1
- {hippoformer-0.0.9 → hippoformer-0.0.10}/hippoformer/hippoformer.py +6 -3
- {hippoformer-0.0.9 → hippoformer-0.0.10}/pyproject.toml +1 -1
- {hippoformer-0.0.9 → hippoformer-0.0.10}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/.gitignore +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/LICENSE +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/README.md +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/hippoformer-fig6.png +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.10}/tests/test_hippoformer.py +0 -0
|
@@ -455,6 +455,10 @@ class mmTEM(Module):
|
|
|
455
455
|
|
|
456
456
|
update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
|
|
457
457
|
|
|
458
|
+
# store next momentum
|
|
459
|
+
|
|
460
|
+
next_momentum[key] = update[:, -1]
|
|
461
|
+
|
|
458
462
|
# maybe muon
|
|
459
463
|
|
|
460
464
|
if self.muon_update:
|
|
@@ -464,14 +468,13 @@ class mmTEM(Module):
|
|
|
464
468
|
|
|
465
469
|
expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
|
|
466
470
|
|
|
467
|
-
acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
|
|
471
|
+
acc_update = self.assoc_scan(-update, expanded_forget.sigmoid(), param)
|
|
468
472
|
|
|
469
473
|
acc_update = inverse_pack(acc_update)
|
|
470
474
|
|
|
471
475
|
# set the next params and momentum, which can be passed back in
|
|
472
476
|
|
|
473
|
-
next_params[key] =
|
|
474
|
-
next_momentum[key] = update[:, -1]
|
|
477
|
+
next_params[key] = acc_update[:, -1]
|
|
475
478
|
|
|
476
479
|
# losses
|
|
477
480
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|