hippoformer 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +35 -10
- {hippoformer-0.0.6.dist-info → hippoformer-0.0.7.dist-info}/METADATA +1 -1
- hippoformer-0.0.7.dist-info/RECORD +6 -0
- hippoformer-0.0.6.dist-info/RECORD +0 -6
- {hippoformer-0.0.6.dist-info → hippoformer-0.0.7.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.6.dist-info → hippoformer-0.0.7.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
|
@@ -179,7 +179,7 @@ class mmTEM(Module):
|
|
|
179
179
|
|
|
180
180
|
grad_fn = grad(forward_with_mse_loss)
|
|
181
181
|
|
|
182
|
-
self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
|
|
182
|
+
self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))
|
|
183
183
|
|
|
184
184
|
# mlp decoder (from meta mlp output to joint)
|
|
185
185
|
|
|
@@ -213,6 +213,19 @@ class mmTEM(Module):
|
|
|
213
213
|
|
|
214
214
|
self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
|
|
215
215
|
|
|
216
|
+
def init_params_and_momentum(
|
|
217
|
+
self,
|
|
218
|
+
batch_size
|
|
219
|
+
):
|
|
220
|
+
|
|
221
|
+
params_dict = dict(self.meta_memory_mlp.named_parameters())
|
|
222
|
+
|
|
223
|
+
params = {name: repeat(param, '... -> b ...', b = batch_size) for name, param in params_dict.items()}
|
|
224
|
+
|
|
225
|
+
momentums = {name: zeros_like(param) for name, param in params.items()}
|
|
226
|
+
|
|
227
|
+
return params, momentums
|
|
228
|
+
|
|
216
229
|
def retrieve(
|
|
217
230
|
self,
|
|
218
231
|
structural_codes,
|
|
@@ -230,7 +243,9 @@ class mmTEM(Module):
|
|
|
230
243
|
self,
|
|
231
244
|
sensory,
|
|
232
245
|
actions,
|
|
233
|
-
|
|
246
|
+
memory_mlp_params = None,
|
|
247
|
+
return_losses = False,
|
|
248
|
+
return_memory_mlp_params = False
|
|
234
249
|
):
|
|
235
250
|
batch = actions.shape[0]
|
|
236
251
|
|
|
@@ -291,22 +306,28 @@ class mmTEM(Module):
|
|
|
291
306
|
|
|
292
307
|
lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
|
|
293
308
|
|
|
294
|
-
|
|
309
|
+
if exists(memory_mlp_params):
|
|
310
|
+
params, momentums = memory_mlp_params
|
|
311
|
+
else:
|
|
312
|
+
params, momentums = self.init_params_and_momentum(batch)
|
|
313
|
+
|
|
314
|
+
# store by getting gradients of mse loss of keys and values
|
|
315
|
+
|
|
295
316
|
grads = self.per_sample_grad_fn(params, keys, values)
|
|
296
317
|
|
|
297
|
-
# update the meta mlp parameters
|
|
318
|
+
# update the meta mlp parameters and momentums
|
|
298
319
|
|
|
299
|
-
init_momentums = {k: zeros_like(v) for k, v in params.items()}
|
|
300
320
|
next_params = dict()
|
|
321
|
+
next_momentum = dict()
|
|
301
322
|
|
|
302
323
|
for (
|
|
303
324
|
(key, param),
|
|
304
325
|
(_, grad),
|
|
305
|
-
(_, init_momentum)
|
|
326
|
+
(_, momentum)
|
|
306
327
|
) in zip(
|
|
307
328
|
params.items(),
|
|
308
329
|
grads.items(),
|
|
309
|
-
|
|
330
|
+
momentums.items()
|
|
310
331
|
):
|
|
311
332
|
|
|
312
333
|
grad, inverse_pack = pack_with_inverse(grad, 'b t *')
|
|
@@ -315,9 +336,7 @@ class mmTEM(Module):
|
|
|
315
336
|
|
|
316
337
|
expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
|
|
317
338
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
|
|
339
|
+
update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
|
|
321
340
|
|
|
322
341
|
expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
|
|
323
342
|
|
|
@@ -325,7 +344,10 @@ class mmTEM(Module):
|
|
|
325
344
|
|
|
326
345
|
acc_update = inverse_pack(acc_update)
|
|
327
346
|
|
|
347
|
+
# set the next params and momentum, which can be passed back in
|
|
348
|
+
|
|
328
349
|
next_params[key] = param - acc_update[:, -1]
|
|
350
|
+
next_momentum[key] = update[:, -1]
|
|
329
351
|
|
|
330
352
|
# losses
|
|
331
353
|
|
|
@@ -343,6 +365,9 @@ class mmTEM(Module):
|
|
|
343
365
|
inference_pred_loss
|
|
344
366
|
)
|
|
345
367
|
|
|
368
|
+
if return_memory_mlp_params:
|
|
369
|
+
return next_params, next_momentum
|
|
370
|
+
|
|
346
371
|
if not return_losses:
|
|
347
372
|
return total_loss
|
|
348
373
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
+
hippoformer/hippoformer.py,sha256=yYoJ5XO0YVAyp3LcRxpunU-0HA97mpCBeQFyi-NSkF0,11549
|
|
3
|
+
hippoformer-0.0.7.dist-info/METADATA,sha256=Xg6NZ6VAQGmuiOo8mMwIAM39Gf6TpVOpyn7o4PMq7JE,2800
|
|
4
|
+
hippoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hippoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
hippoformer-0.0.7.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
-
hippoformer/hippoformer.py,sha256=b6EXU2VXh_ZD7brxpCVNuU-m7cE-zXRR-sOmqfofPCg,10839
|
|
3
|
-
hippoformer-0.0.6.dist-info/METADATA,sha256=ufTBdu8ZGggxwfgzphYV56jjaGdI5sLCE_iZF5Bku6s,2800
|
|
4
|
-
hippoformer-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hippoformer-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
hippoformer-0.0.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|