hippoformer 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- hippoformer/hippoformer.py +56 -3
- {hippoformer-0.0.5.dist-info → hippoformer-0.0.6.dist-info}/METADATA +2 -1
- hippoformer-0.0.6.dist-info/RECORD +6 -0
- hippoformer-0.0.5.dist-info/RECORD +0 -6
- {hippoformer-0.0.5.dist-info → hippoformer-0.0.6.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.5.dist-info → hippoformer-0.0.6.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
@@ -7,7 +7,8 @@ from torch.nn import Module
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call
 
-from 
+from einx import multiply
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from x_mlps_pytorch import create_mlp
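The new `einx` dependency supplies pattern-driven elementwise ops. A minimal sketch of the `multiply` call used later in this diff; the tensor names and shapes below are illustrative, not taken from the package:

```python
# Illustrative only: einx.multiply broadcasts the second operand over the
# trailing dims matched by '...' in the first. Shapes here are made up.
import torch
from einx import multiply

grad = torch.randn(2, 5, 64, 16)   # (batch, time, *param_shape)
lr   = torch.randn(2, 5)           # one learned scalar per (batch, time) step

scaled = multiply('b t ..., b t', grad, lr)
assert scaled.shape == grad.shape  # lr is broadcast across the parameter dims
```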
@@ -22,6 +23,16 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, packed_shape, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
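For reference, the new helper flattens everything matched by `*` with `einops.pack` and returns a closure that undoes the flattening. A small usage sketch, with hypothetical shapes, assuming `pack_with_inverse` from the hunk above is in scope:

```python
# Hypothetical shapes; assumes pack_with_inverse (and its einops imports) are in scope.
import torch

grad = torch.randn(2, 5, 64, 16)                   # (batch, time, *param_shape)
flat, inverse = pack_with_inverse(grad, 'b t *')   # flat: (2, 5, 1024)

flat = flat * 2                                    # any op that preserves the packed layout
restored = inverse(flat)                           # back to (2, 5, 64, 16)
assert restored.shape == grad.shape
```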
@@ -121,7 +132,8 @@ class mmTEM(Module):
         loss_weight_inference = 1.,
         loss_weight_consistency = 1.,
         loss_weight_relational = 1.,
-        integration_ratio_learned = True
+        integration_ratio_learned = True,
+        assoc_scan_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -150,6 +162,9 @@ class mmTEM(Module):
         self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
 
+        self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
+        self.assoc_scan = AssocScan(*assoc_scan_kwargs)
+
         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
             depth = meta_mlp_depth,
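The `AssocScan` module (from the `assoc-scan` dependency) is used in the forward pass below to accumulate per-timestep updates. As a rough reference, the recurrence it parallelizes can be written as a plain loop; this is a sketch of the math only, not the package's API, and the gate/input convention is an assumption:

```python
import torch

def gated_scan_reference(inputs, gates, init = None):
    # h_t = gates_t * h_{t-1} + inputs_t, computed sequentially over the time dim.
    # inputs, gates: (batch, time, width); init: (batch, width) or None.
    batch, _, width = inputs.shape
    h = init if init is not None else torch.zeros(batch, width, dtype = inputs.dtype, device = inputs.device)
    outs = []
    for x_t, g_t in zip(inputs.unbind(dim = 1), gates.unbind(dim = 1)):
        h = g_t * h + x_t
        outs.append(h)
    return torch.stack(outs, dim = 1)   # (batch, time, width)
```

An associative scan computes the same prefix recurrence with logarithmic parallel depth, which is why the updates below are expressed through `self.assoc_scan` rather than an explicit loop over time.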
@@ -217,6 +232,8 @@ class mmTEM(Module):
         actions,
         return_losses = False
     ):
+        batch = actions.shape[0]
+
         structural_codes = self.path_integrator(actions)
 
         encoded_sensory = self.sensory_encoder(sensory)
@@ -272,7 +289,43 @@ class mmTEM(Module):
         keys = self.to_keys(joint_code_to_store)
         values = self.to_values(joint_code_to_store)
 
-
+        lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
+
+        params = dict(self.meta_memory_mlp.named_parameters())
+        grads = self.per_sample_grad_fn(params, keys, values)
+
+        # update the meta mlp parameters
+
+        init_momentums = {k: zeros_like(v) for k, v in params.items()}
+        next_params = dict()
+
+        for (
+            (key, param),
+            (_, grad),
+            (_, init_momentum)
+        ) in zip(
+            params.items(),
+            grads.items(),
+            init_momentums.items()
+        ):
+
+            grad, inverse_pack = pack_with_inverse(grad, 'b t *')
+
+            grad = multiply('b t ..., b t', grad, lr)
+
+            expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
+
+            init_momentum = repeat(init_momentum, '... -> b ...', b = batch)
+
+            update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
+
+            expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
+
+            acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
+
+            acc_update = inverse_pack(acc_update)
+
+            next_params[key] = param - acc_update[:, -1]
 
         # losses
 
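The forward pass above calls `self.per_sample_grad_fn(params, keys, values)`, which is not defined in this diff. Given the `vmap`, `grad`, and `functional_call` imports added at the top of the file, a typical construction looks like the sketch below; the MLP, the MSE store loss, and all shapes are assumptions for illustration, not the package's exact definition:

```python
# Sketch only: building a per-sample gradient function with torch.func.
# The memory_mlp, store_loss, and shapes here are assumptions; the actual
# per_sample_grad_fn inside mmTEM is not shown in this diff.
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import vmap, grad, functional_call

memory_mlp = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))

def store_loss(params, key, value):
    pred = functional_call(memory_mlp, params, (key,))
    return F.mse_loss(pred, value)

# gradient w.r.t. params (first argument), mapped over the batch and time dims of keys / values
per_sample_grad_fn = vmap(vmap(grad(store_loss), in_dims = (None, 0, 0)), in_dims = (None, 0, 0))

params = {k: v.detach() for k, v in memory_mlp.named_parameters()}
keys   = torch.randn(2, 5, 16)    # (batch, time, dim)
values = torch.randn(2, 5, 16)

grads = per_sample_grad_fn(params, keys, values)
print(grads['0.weight'].shape)    # torch.Size([2, 5, 64, 16]) -- one gradient per (batch, time) step
```

Each gradient then carries a leading (batch, time) prefix, which is what the `'b t *'` packing and the gated scans above operate on.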
{hippoformer-0.0.5.dist-info → hippoformer-0.0.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.5
+Version: 0.0.6
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -36,6 +36,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.1
+Requires-Dist: einx>=0.3.0
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
hippoformer-0.0.6.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=b6EXU2VXh_ZD7brxpCVNuU-m7cE-zXRR-sOmqfofPCg,10839
+hippoformer-0.0.6.dist-info/METADATA,sha256=ufTBdu8ZGggxwfgzphYV56jjaGdI5sLCE_iZF5Bku6s,2800
+hippoformer-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.6.dist-info/RECORD,,
hippoformer-0.0.5.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=PP2KmTygOP6MyYuhmr_8iEBbywIaTW4TpoIycYRugMo,9142
-hippoformer-0.0.5.dist-info/METADATA,sha256=83iG4F_6ibQy6XSCWht-aF2ZVYmiEq-KSF4XR9YaBtY,2773
-hippoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.5.dist-info/RECORD,,
{hippoformer-0.0.5.dist-info → hippoformer-0.0.6.dist-info}/WHEEL
File without changes
{hippoformer-0.0.5.dist-info → hippoformer-0.0.6.dist-info}/licenses/LICENSE
File without changes