hippoformer 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +126 -2
- {hippoformer-0.0.7.dist-info → hippoformer-0.0.9.dist-info}/METADATA +2 -3
- hippoformer-0.0.9.dist-info/RECORD +6 -0
- hippoformer-0.0.7.dist-info/RECORD +0 -6
- {hippoformer-0.0.7.dist-info → hippoformer-0.0.9.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.7.dist-info → hippoformer-0.0.9.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py CHANGED
@@ -7,6 +7,8 @@ from torch.nn import Module
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call

+from beartype import beartype
+
 from einx import multiply
 from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
@@ -36,6 +38,114 @@ def pack_with_inverse(t, pattern):
 def l2norm(t):
     return F.normalize(t, dim = -1)

+# Muon - Jordan et al from oss community - applied to the latest version of titans
+
+def newtonschulz5(
+    t,
+    steps = 5,
+    eps = 1e-7,
+    coefs = (3.4445, -4.7750, 2.0315)
+):
+    not_weights = t.ndim <= 3
+
+    if not_weights:
+        return t
+
+    shape = t.shape
+    should_transpose = shape[-2] > shape[-1]
+
+    if should_transpose:
+        t = t.transpose(-1, -2)
+
+    t, inv_pack = pack_with_inverse(t, '* i j')
+    t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)
+
+    a, b, c = coefs
+
+    for _ in range(steps):
+        A = t @ t.transpose(-1, -2)
+        B = b * A + c * A @ A
+        t = a * t + B @ t
+
+    if should_transpose:
+        t = t.transpose(-1, -2)
+
+    return inv_pack(t)
+
+# sensory encoder decoder for 2d
+
+grid_sensory_enc_dec = (
+    create_mlp(
+        dim = 32 * 2,
+        dim_in = 9,
+        dim_out = 32,
+        depth = 3,
+    ),
+    create_mlp(
+        dim = 32 * 2,
+        dim_in = 32,
+        dim_out = 9,
+        depth = 3,
+    ),
+)
+
+# sensory encoder decoder for 3d maze
+
+class EncoderPackTime(Module):
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x = rearrange(x, 'b c t h w -> b t c h w')
+        x, packed_shape = pack([x], '* c h w')
+
+        x = self.fn(x)
+
+        x, = unpack(x, packed_shape, '* d')
+        print(x.shape)
+        return x
+
+class DecoderPackTime(Module):
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x, packed_shape = pack(x, '* d')
+
+        x = self.fn(x)
+
+        x = unpack(x, packed_shape, '* c h w')
+        x = rearrange(x, 'b t c h w -> b c t h w')
+        return x
+
+maze_sensory_enc_dec = (
+    EncoderPackTime(nn.Sequential(
+        nn.Conv2d(3, 16, 7, 2, padding = 3),
+        nn.ReLU(),
+        nn.Conv2d(16, 32, 3, 2, 1),
+        nn.ReLU(),
+        nn.Conv2d(32, 64, 3, 2, 1),
+        nn.ReLU(),
+        nn.Conv2d(64, 128, 3, 2, 1),
+        nn.ReLU(),
+        Rearrange('b ... -> b (...)'),
+        nn.Linear(2048, 32)
+    )),
+    DecoderPackTime(nn.Sequential(
+        nn.Linear(32, 2048),
+        Rearrange('b (c h w) -> b c h w', c = 128, h = 4),
+        nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(16, 3, 3, 2, 1, output_padding = (1, 1))
+    ))
+)
+
 # path integration

 class RNN(ScriptModule):
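For orientation on the new sensory modules, a hedged usage sketch of the maze encoder: EncoderPackTime folds the time axis into the batch, runs the Conv2d stack per frame, then unpacks, so each frame becomes a 32-dimensional code (it also prints the unpacked shape as a side effect in this release). The 64x64 frame size below is an assumption, not something pinned down by this diff; it is what makes the flatten to 2048 = 128 * 4 * 4 features line up.

```python
# Hedged usage sketch, assuming 64x64 RGB maze frames (the frame size is not
# fixed by this diff, but 64x64 matches the 2048-wide flatten).
import torch
from hippoformer.hippoformer import maze_sensory_enc_dec

encoder, _decoder = maze_sensory_enc_dec

video = torch.randn(2, 3, 10, 64, 64)    # (batch, channels, time, height, width)

encoded = encoder(video)                 # time folded into batch for the conv stack,
                                         # then unpacked to one code per frame
print(encoded.shape)                     # torch.Size([2, 10, 32])
```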
@@ -114,12 +224,12 @@ class PathIntegration(Module):
 # proposed mmTEM

 class mmTEM(Module):
+    @beartype
     def __init__(
         self,
         dim,
         *,
-        sensory_encoder: Module,
-        sensory_decoder: Module,
+        sensory_encoder_decoder: tuple[Module, Module],
         dim_sensory,
         dim_action,
         dim_encoded_sensory,
@@ -133,12 +243,15 @@ class mmTEM(Module):
         loss_weight_consistency = 1.,
         loss_weight_relational = 1.,
         integration_ratio_learned = True,
+        muon_update = False,
         assoc_scan_kwargs: dict = dict()
     ):
         super().__init__()

         # sensory

+        sensory_encoder, sensory_decoder = sensory_encoder_decoder
+
         self.sensory_encoder = sensory_encoder
         self.sensory_decoder = sensory_decoder

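The hunks above are the one public API difference between 0.0.7 and 0.0.9 visible in this diff: the encoder and decoder are now handed to mmTEM as a single sensory_encoder_decoder tuple (unpacked inside __init__) instead of separate sensory_encoder / sensory_decoder keywords, and a muon_update flag is added. A hedged sketch of a call site, kept mostly as comments because the diff does not show the full constructor signature and the values here are placeholders:

```python
# Illustrative sketch only: mmTEM takes further constructor arguments not shown
# in this diff, so the calls are left as comments; values are placeholders.
from hippoformer.hippoformer import mmTEM, grid_sensory_enc_dec

encoder, decoder = grid_sensory_enc_dec

# hippoformer 0.0.7
# model = mmTEM(dim = ..., sensory_encoder = encoder, sensory_decoder = decoder, ...)

# hippoformer 0.0.9
# model = mmTEM(dim = ..., sensory_encoder_decoder = (encoder, decoder), muon_update = True, ...)
```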
@@ -209,6 +322,10 @@ class mmTEM(Module):
         self.loss_weight_consistency = loss_weight_consistency
         self.register_buffer('zero', tensor(0.), persistent = False)

+        # update with muon
+
+        self.muon_update = muon_update
+
         # there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned

         self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
@@ -338,6 +455,13 @@ class mmTEM(Module):

         update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)

+        # maybe muon
+
+        if self.muon_update:
+            update = newtonschulz5(update)
+
+        # with forget gating
+
         expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])

         acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
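When muon_update is set, the scanned update is run through newtonschulz5 before forget gating (the hunk above). A hedged sketch of what that call does to a tensor, using only the function as added in this diff; the shapes below are placeholders, not the ones mmTEM actually produces:

```python
# Hedged sketch: effect of the Muon-style Newton-Schulz step on a stacked
# batch of matrix-shaped updates. Shapes here are placeholders.
import torch
from hippoformer.hippoformer import newtonschulz5

update = torch.randn(2, 4, 16, 32)        # ndim > 3, so it is treated as weights
orth = newtonschulz5(update)              # same shape back: (2, 4, 16, 32)

# anything with ndim <= 3 is returned untouched
vec = torch.randn(2, 4, 16)
assert newtonschulz5(vec) is vec

# each 16x32 slice is pushed toward the nearest semi-orthogonal matrix,
# i.e. its singular values end up clustered around 1
print(torch.linalg.svdvals(update[0, 0]).round(decimals = 2))
print(torch.linalg.svdvals(orth[0, 0]).round(decimals = 2))
```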
{hippoformer-0.0.7.dist-info → hippoformer-0.0.9.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.7
+Version: 0.0.9
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: torch>=2.4
@@ -50,8 +51,6 @@ Description-Content-Type: text/markdown

 Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers

-[Temporary Discord](https://discord.gg/MkACrrkrYR)
-
 ## Citations

 ```bibtex
hippoformer-0.0.9.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=m7luQGFdMWOkZUorjd5v34hx_vjOQqpJOAGCL0njHUE,14426
+hippoformer-0.0.9.dist-info/METADATA,sha256=owgkDcdTf0_N5IbUr3e_yt7u5sIWfOMha-hA5LQWnus,2772
+hippoformer-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.9.dist-info/RECORD,,
hippoformer-0.0.7.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=yYoJ5XO0YVAyp3LcRxpunU-0HA97mpCBeQFyi-NSkF0,11549
-hippoformer-0.0.7.dist-info/METADATA,sha256=Xg6NZ6VAQGmuiOo8mMwIAM39Gf6TpVOpyn7o4PMq7JE,2800
-hippoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.7.dist-info/RECORD,,
{hippoformer-0.0.7.dist-info → hippoformer-0.0.9.dist-info}/WHEEL: File without changes
{hippoformer-0.0.7.dist-info → hippoformer-0.0.9.dist-info}/licenses/LICENSE: File without changes