hippoformer 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -38,6 +38,40 @@ def pack_with_inverse(t, pattern):
38
38
  def l2norm(t):
39
39
  return F.normalize(t, dim = -1)
40
40
 
41
+ # Muon - Jordan et al from oss community - applied to the latest version of titans
42
+
43
def newtonschulz5(
    t,
    steps = 5,
    eps = 1e-7,
    coefs = (3.4445, -4.7750, 2.0315)
):
    """
    Approximately orthogonalize the trailing (i, j) matrices of `t` using the
    quintic Newton-Schulz iteration from Muon (Jordan et al.).

    Args:
        t: tensor to process. Inputs with `t.ndim <= 3` are returned as-is
           (the very same object); only higher-rank tensors are treated as
           stacks of weight matrices.
        steps: number of Newton-Schulz iterations to run.
        eps: lower clamp on each matrix's Frobenius norm before iterating,
             guarding against division by zero.
        coefs: (a, b, c) polynomial coefficients. Each step computes
               X <- a * X + (b * A + c * A @ A) @ X, with A = X @ X^T.

    Returns:
        The iterated tensor, reassembled through the inverse returned by
        `pack_with_inverse` (presumably restoring the shape of `t`; confirm
        against that helper).

    NOTE(review): inside mmTEM the update passed here is scanned alongside
    gates expanded to 'b t w', which suggests it is packed to 3 dims before
    this call. If so, the early return below makes `muon_update = True` a
    no-op. Confirm.
    """

    # rank <= 3 tensors are not treated as weights and pass straight through
    if t.ndim <= 3:
        return t

    num_rows, num_cols = t.shape[-2:]
    is_tall = num_rows > num_cols

    # work on the wide orientation so the gram matrix x @ x^T is the smaller one
    x = t.transpose(-1, -2) if is_tall else t

    # collapse every leading dim into one batch of (i, j) matrices
    x, unpack_matrices = pack_with_inverse(x, '* i j')

    # divide by the Frobenius norm, which bounds every singular value by 1
    frobenius = x.norm(dim = (-1, -2), keepdim = True)
    x = x / frobenius.clamp(min = eps)

    coef_a, coef_b, coef_c = coefs

    for _ in range(steps):
        gram = x @ x.transpose(-1, -2)
        # `*` and `@` share precedence and associate left: (coef_c * gram) @ gram
        poly = coef_b * gram + coef_c * gram @ gram
        x = coef_a * x + poly @ x

    if is_tall:
        x = x.transpose(-1, -2)

    return unpack_matrices(x)
74
+
41
75
  # sensory encoder decoder for 2d
42
76
 
43
77
  grid_sensory_enc_dec = (
@@ -209,6 +243,7 @@ class mmTEM(Module):
209
243
  loss_weight_consistency = 1.,
210
244
  loss_weight_relational = 1.,
211
245
  integration_ratio_learned = True,
246
+ muon_update = False,
212
247
  assoc_scan_kwargs: dict = dict()
213
248
  ):
214
249
  super().__init__()
@@ -287,6 +322,10 @@ class mmTEM(Module):
287
322
  self.loss_weight_consistency = loss_weight_consistency
288
323
  self.register_buffer('zero', tensor(0.), persistent = False)
289
324
 
325
+ # update with muon
326
+
327
+ self.muon_update = muon_update
328
+
290
329
  # there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
291
330
 
292
331
  self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
@@ -416,16 +455,26 @@ class mmTEM(Module):
416
455
 
417
456
  update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
418
457
 
458
+ # store next momentum
459
+
460
+ next_momentum[key] = update[:, -1]
461
+
462
+ # maybe muon
463
+
464
+ if self.muon_update:
465
+ update = newtonschulz5(update)
466
+
467
+ # with forget gating
468
+
419
469
  expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
420
470
 
421
- acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
471
+ acc_update = self.assoc_scan(-update, expanded_forget.sigmoid(), param)
422
472
 
423
473
  acc_update = inverse_pack(acc_update)
424
474
 
425
475
  # set the next params and momentum, which can be passed back in
426
476
 
427
- next_params[key] = param - acc_update[:, -1]
428
- next_momentum[key] = update[:, -1]
477
+ next_params[key] = acc_update[:, -1]
429
478
 
430
479
  # losses
431
480
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.8
3
+ Version: 0.0.10
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=GWFJy2idp0FWBoVFw8T_6inTXYtY4i47hfhKj88_I0A,14463
3
+ hippoformer-0.0.10.dist-info/METADATA,sha256=IB7iybYMwOkee3Q5ji-B_dnOB62LyK_6t1FPM_UT-FM,2773
4
+ hippoformer-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.10.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=jTR7XRB4bgP6i2VkzOg_Z20Xx7IDg5ifBXYjs6tmrFs,13473
3
- hippoformer-0.0.8.dist-info/METADATA,sha256=DCgss4I14pexPOjxk319tiqQeBpz5aVRGRQDNvskL9g,2772
4
- hippoformer-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.8.dist-info/RECORD,,