hippoformer 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +83 -5
- {hippoformer-0.0.5.dist-info → hippoformer-0.0.7.dist-info}/METADATA +2 -1
- hippoformer-0.0.7.dist-info/RECORD +6 -0
- hippoformer-0.0.5.dist-info/RECORD +0 -6
- {hippoformer-0.0.5.dist-info → hippoformer-0.0.7.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.5.dist-info → hippoformer-0.0.7.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
@@ -7,7 +7,8 @@ from torch.nn import Module
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call

-from
+from einx import multiply
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from x_mlps_pytorch import create_mlp
@@ -22,6 +23,16 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, packed_shape, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
 def l2norm(t):
     return F.normalize(t, dim = -1)

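For context, `pack_with_inverse` just closes over the einops `pack` metadata so a caller can flatten all trailing dimensions and later restore them. A minimal standalone sketch of the round trip it wraps (shapes are illustrative, not from the repo):

```python
import torch
from einops import pack, unpack

t = torch.randn(2, 5, 4, 4)                 # (batch, time, *param_shape)

packed, packed_shape = pack([t], 'b t *')   # flatten everything after (b, t) -> (2, 5, 16)
assert packed.shape == (2, 5, 16)

restored, = unpack(packed, packed_shape, 'b t *')  # invert using the saved shapes
assert restored.shape == t.shape
```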
@@ -121,7 +132,8 @@ class mmTEM(Module):
         loss_weight_inference = 1.,
         loss_weight_consistency = 1.,
         loss_weight_relational = 1.,
-        integration_ratio_learned = True
+        integration_ratio_learned = True,
+        assoc_scan_kwargs: dict = dict()
     ):
         super().__init__()

@@ -150,6 +162,9 @@ class mmTEM(Module):
         self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

+        self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)
+
         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
             depth = meta_mlp_depth,
@@ -164,7 +179,7 @@ class mmTEM(Module):

         grad_fn = grad(forward_with_mse_loss)

-        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (
+        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))

         # mlp decoder (from meta mlp output to joint)

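The fixed nested `vmap` now maps the weights over the batch dimension as well (outer `in_dims = (0, 0, 0)`), while the inner `vmap` shares those weights across time. A self-contained sketch with a toy loss standing in for the repo's `forward_with_mse_loss` (names and shapes here are illustrative):

```python
import torch
from torch.func import vmap, grad

def mse_loss(w, key, value):
    # toy stand-in for forward_with_mse_loss: scalar loss for one (key, value) pair
    return ((key @ w - value) ** 2).mean()

grad_fn = grad(mse_loss)  # gradient w.r.t. the first argument, w

# inner vmap: iterate time steps, sharing w   -> in_dims = (None, 0, 0)
# outer vmap: iterate batch, with per-batch w -> in_dims = (0, 0, 0)
per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))

b, t, d = 2, 5, 8
w = torch.randn(b, d, d)          # one fast-weight copy per batch element
keys = torch.randn(b, t, d)
values = torch.randn(b, t, d)

grads = per_sample_grad_fn(w, keys, values)
assert grads.shape == (b, t, d, d)  # one gradient per (batch, time) position
```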
@@ -198,6 +213,19 @@ class mmTEM(Module):

         self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)

+    def init_params_and_momentum(
+        self,
+        batch_size
+    ):
+
+        params_dict = dict(self.meta_memory_mlp.named_parameters())
+
+        params = {name: repeat(param, '... -> b ...', b = batch_size) for name, param in params_dict.items()}
+
+        momentums = {name: zeros_like(param) for name, param in params.items()}
+
+        return params, momentums
+
     def retrieve(
         self,
         structural_codes,
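`init_params_and_momentum` gives every batch element its own fast-weight copy of the meta memory MLP, with momentum buffers starting at zero. Per parameter, the operation reduces to the following (illustrative shapes):

```python
import torch
from einops import repeat

param = torch.randn(64, 32)                      # one shared meta-MLP weight
batched = repeat(param, '... -> b ...', b = 4)   # (4, 64, 32): a copy per batch element
momentum = torch.zeros_like(batched)             # zero-initialized momentum buffer
```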
@@ -215,8 +243,12 @@ class mmTEM(Module):
         self,
         sensory,
         actions,
-        return_losses = False
+        memory_mlp_params = None,
+        return_losses = False,
+        return_memory_mlp_params = False
     ):
+        batch = actions.shape[0]
+
         structural_codes = self.path_integrator(actions)

         encoded_sensory = self.sensory_encoder(sensory)
@@ -272,7 +304,50 @@ class mmTEM(Module):
         keys = self.to_keys(joint_code_to_store)
         values = self.to_values(joint_code_to_store)

-
+        lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
+
+        if exists(memory_mlp_params):
+            params, momentums = memory_mlp_params
+        else:
+            params, momentums = self.init_params_and_momentum(batch)
+
+        # store by getting gradients of mse loss of keys and values
+
+        grads = self.per_sample_grad_fn(params, keys, values)
+
+        # update the meta mlp parameters and momentums
+
+        next_params = dict()
+        next_momentum = dict()
+
+        for (
+            (key, param),
+            (_, grad),
+            (_, momentum)
+        ) in zip(
+            params.items(),
+            grads.items(),
+            momentums.items()
+        ):
+
+            grad, inverse_pack = pack_with_inverse(grad, 'b t *')
+
+            grad = multiply('b t ..., b t', grad, lr)
+
+            expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
+
+            update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
+
+            expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
+
+            acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
+
+            acc_update = inverse_pack(acc_update)
+
+            # set the next params and momentum, which can be passed back in
+
+            next_params[key] = param - acc_update[:, -1]
+            next_momentum[key] = update[:, -1]

         # losses

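The storage path is a learned-optimizer step: per-token learning rates scale the gradients, one scan accumulates them into momentum (gated by `beta`), and a second scan applies a forget gate while summing the updates. A plain-loop reference for what each scan computes, assuming `AssocScan` realizes the linear recurrence `state = gate * state + input` in parallel (a sketch, not the library's implementation):

```python
import torch

def naive_assoc_scan(inputs, gates, prev = None):
    # inputs, gates: (batch, time, width); prev: optional (batch, width) initial state
    state = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    out = []
    for x, g in zip(inputs.unbind(dim = 1), gates.unbind(dim = 1)):
        state = g * state + x
        out.append(state)
    return torch.stack(out, dim = 1)

b, t, w = 2, 5, 16
grad = torch.randn(b, t, w)            # lr-scaled, flattened per-sample gradients
beta = torch.rand(b, t, w)             # momentum gate, already in (0, 1)
forget = torch.rand(b, t, w)           # forget gate, already in (0, 1)

update = naive_assoc_scan(grad, beta)          # momentum-smoothed gradients
acc_update = naive_assoc_scan(update, forget)  # decayed running sum of updates
# the diff then applies param - acc_update[:, -1] and keeps update[:, -1] as momentum
```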
@@ -290,6 +365,9 @@ class mmTEM(Module):
             inference_pred_loss
         )

+        if return_memory_mlp_params:
+            return next_params, next_momentum
+
         if not return_losses:
             return total_loss

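With `return_memory_mlp_params = True`, the forward pass hands back `(next_params, next_momentum)`, which is exactly the pair `memory_mlp_params` expects, so the fast-weight memory can be threaded across trajectory segments. A hypothetical driver loop (assumes an `mmTEM` instance and an iterable of `(sensory, actions)` segments):

```python
def rollout_with_memory(model, segments):
    # carry the fast-weight memory (params, momentums) across segments
    memory = None
    for sensory, actions in segments:
        memory = model(
            sensory,
            actions,
            memory_mlp_params = memory,
            return_memory_mlp_params = True
        )
    return memory
```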
{hippoformer-0.0.5.dist-info → hippoformer-0.0.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.5
+Version: 0.0.7
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -36,6 +36,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.1
+Requires-Dist: einx>=0.3.0
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples
hippoformer-0.0.7.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=yYoJ5XO0YVAyp3LcRxpunU-0HA97mpCBeQFyi-NSkF0,11549
+hippoformer-0.0.7.dist-info/METADATA,sha256=Xg6NZ6VAQGmuiOo8mMwIAM39Gf6TpVOpyn7o4PMq7JE,2800
+hippoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.7.dist-info/RECORD,,
hippoformer-0.0.5.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=PP2KmTygOP6MyYuhmr_8iEBbywIaTW4TpoIycYRugMo,9142
-hippoformer-0.0.5.dist-info/METADATA,sha256=83iG4F_6ibQy6XSCWht-aF2ZVYmiEq-KSF4XR9YaBtY,2773
-hippoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.5.dist-info/RECORD,,
{hippoformer-0.0.5.dist-info → hippoformer-0.0.7.dist-info}/WHEEL
File without changes

{hippoformer-0.0.5.dist-info → hippoformer-0.0.7.dist-info}/licenses/LICENSE
File without changes