hippoformer 0.0.8__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.8 → hippoformer-0.0.9}/PKG-INFO +1 -1
- {hippoformer-0.0.8 → hippoformer-0.0.9}/hippoformer/hippoformer.py +46 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/pyproject.toml +1 -1
- {hippoformer-0.0.8 → hippoformer-0.0.9}/tests/test_hippoformer.py +5 -2
- {hippoformer-0.0.8 → hippoformer-0.0.9}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/.gitignore +0 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/LICENSE +0 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/README.md +0 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.8 → hippoformer-0.0.9}/hippoformer-fig6.png +0 -0
|
@@ -38,6 +38,40 @@ def pack_with_inverse(t, pattern):
|
|
|
38
38
|
def l2norm(t):
    # scale every vector along the final axis to unit Euclidean length
    # (p = 2 is F.normalize's default; written out here for readability)
    return F.normalize(t, p = 2, dim = -1)
|
|
40
40
|
|
|
41
|
+
# Muon - Jordan et al from oss community - applied to the latest version of titans
|
|
42
|
+
|
|
43
|
+
def newtonschulz5(
    t,
    steps = 5,
    eps = 1e-7,
    coefs = (3.4445, -4.7750, 2.0315)
):
    """
    Muon-style quintic Newton-Schulz iteration over the last two dims of `t`.

    Tensors with ndim <= 3 are returned as-is (the very same object).
    Otherwise each trailing (i, j) matrix X is:
      1. transposed if it is taller than wide (so X @ X^T is the smaller Gram matrix),
      2. divided by its Frobenius norm (clamped below at `eps`),
      3. updated `steps` times with X <- a*X + (b*A + c*A@A) @ X, where A = X @ X^T
         and (a, b, c) = `coefs`,
      4. transposed back and restored to the original leading dims.

    NOTE(review): in mmTEM the argument at the call site looks 3-D (it is scanned
    against `repeat(forget, 'b t -> b t w', ...)`), which would hit the early
    return below and make `muon_update` a no-op -- confirm the intended shape.
    """

    # rank <= 3 inputs pass straight through
    if t.ndim < 4:
        return t

    rows, cols = t.shape[-2:]
    is_tall = rows > cols

    x = t.transpose(-1, -2) if is_tall else t

    # flatten all leading dims into one batch dim; keep the inverse to undo it
    x, unpack_leading = pack_with_inverse(x, '* i j')

    frob = x.norm(dim = (-1, -2), keepdim = True)
    x = x / frob.clamp(min = eps)

    ca, cb, cc = coefs

    for _ in range(steps):
        gram = x @ x.transpose(-1, -2)
        poly = cb * gram + cc * gram @ gram
        x = ca * x + poly @ x

    if is_tall:
        x = x.transpose(-1, -2)

    return unpack_leading(x)
|
|
74
|
+
|
|
41
75
|
# sensory encoder decoder for 2d
|
|
42
76
|
|
|
43
77
|
grid_sensory_enc_dec = (
|
|
@@ -209,6 +243,7 @@ class mmTEM(Module):
|
|
|
209
243
|
loss_weight_consistency = 1.,
|
|
210
244
|
loss_weight_relational = 1.,
|
|
211
245
|
integration_ratio_learned = True,
|
|
246
|
+
muon_update = False,
|
|
212
247
|
assoc_scan_kwargs: dict = dict()
|
|
213
248
|
):
|
|
214
249
|
super().__init__()
|
|
@@ -287,6 +322,10 @@ class mmTEM(Module):
|
|
|
287
322
|
self.loss_weight_consistency = loss_weight_consistency
|
|
288
323
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
289
324
|
|
|
325
|
+
# update with muon
|
|
326
|
+
|
|
327
|
+
self.muon_update = muon_update
|
|
328
|
+
|
|
290
329
|
# there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
|
|
291
330
|
|
|
292
331
|
self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
|
|
@@ -416,6 +455,13 @@ class mmTEM(Module):
|
|
|
416
455
|
|
|
417
456
|
update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
|
|
418
457
|
|
|
458
|
+
# maybe muon
|
|
459
|
+
|
|
460
|
+
if self.muon_update:
|
|
461
|
+
update = newtonschulz5(update)
|
|
462
|
+
|
|
463
|
+
# with forget gating
|
|
464
|
+
|
|
419
465
|
expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
|
|
420
466
|
|
|
421
467
|
acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
|
|
@@ -16,8 +16,10 @@ def test_path_integrate():
|
|
|
16
16
|
assert structure_codes.shape == (2, 16, 64)
|
|
17
17
|
|
|
18
18
|
@param('sensory_type', ('naive', '2d', '3d'))
|
|
19
|
+
@param('muon_update', (True, False))
|
|
19
20
|
def test_mm_tem(
|
|
20
|
-
sensory_type
|
|
21
|
+
sensory_type,
|
|
22
|
+
muon_update
|
|
21
23
|
):
|
|
22
24
|
import torch
|
|
23
25
|
from hippoformer.hippoformer import mmTEM
|
|
@@ -52,7 +54,8 @@ def test_mm_tem(
|
|
|
52
54
|
dim_sensory = 11,
|
|
53
55
|
dim_action = 7,
|
|
54
56
|
dim_structure = 32,
|
|
55
|
-
dim_encoded_sensory = 32
|
|
57
|
+
dim_encoded_sensory = 32,
|
|
58
|
+
muon_update = muon_update
|
|
56
59
|
)
|
|
57
60
|
|
|
58
61
|
actions = torch.randn(2, 16, 7)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|