hippoformer 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,8 @@ from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
8
8
  from torch.func import vmap, grad, functional_call
9
9
 
10
- from einops import repeat, rearrange
10
+ from einx import multiply
11
+ from einops import repeat, rearrange, pack, unpack
11
12
  from einops.layers.torch import Rearrange
12
13
 
13
14
  from x_mlps_pytorch import create_mlp
@@ -22,6 +23,16 @@ def exists(v):
22
23
  def default(v, d):
23
24
  return v if exists(v) else d
24
25
 
26
def pack_with_inverse(t, pattern):
    """Pack one tensor with einops and hand back a function that undoes it.

    `t` is wrapped in a single-element list and packed according to
    `pattern`. The shape information returned by `pack` is captured by the
    returned closure so the packing can be reversed later.

    Returns a tuple `(packed, inverse)`, where `inverse(out, inv_pattern = None)`
    unpacks `out` using the captured shapes. When `inv_pattern` is not given,
    `default` falls back to the `pattern` used for packing.
    """
    flattened, captured_shapes = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        # reuse the packing pattern unless the caller supplies another one
        restore_pattern = default(inv_pattern, pattern)

        # exactly one tensor was packed, so exactly one is expected back
        [restored] = unpack(out, captured_shapes, restore_pattern)
        return restored

    return flattened, inverse
35
+
25
36
  def l2norm(t):
26
37
  return F.normalize(t, dim = -1)
27
38
 
@@ -121,7 +132,8 @@ class mmTEM(Module):
121
132
  loss_weight_inference = 1.,
122
133
  loss_weight_consistency = 1.,
123
134
  loss_weight_relational = 1.,
124
- integration_ratio_learned = True
135
+ integration_ratio_learned = True,
136
+ assoc_scan_kwargs: dict = dict()
125
137
  ):
126
138
  super().__init__()
127
139
 
@@ -150,6 +162,9 @@ class mmTEM(Module):
150
162
  self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
151
163
  self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
152
164
 
165
+ self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
166
# `assoc_scan_kwargs` is a dict of keyword arguments, so it must be expanded with `**`;
# a single `*` would pass the dict's keys as positional arguments and drop the values
self.assoc_scan = AssocScan(**assoc_scan_kwargs)
167
+
153
168
  self.meta_memory_mlp = create_mlp(
154
169
  dim = dim * 2,
155
170
  depth = meta_mlp_depth,
@@ -164,7 +179,7 @@ class mmTEM(Module):
164
179
 
165
180
  grad_fn = grad(forward_with_mse_loss)
166
181
 
167
- self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
182
+ self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))
168
183
 
169
184
  # mlp decoder (from meta mlp output to joint)
170
185
 
@@ -198,6 +213,19 @@ class mmTEM(Module):
198
213
 
199
214
  self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
200
215
 
216
def init_params_and_momentum(
    self,
    batch_size
):
    """Build the initial per-batch parameters and momentum state.

    Every named parameter of `self.meta_memory_mlp` is repeated along a new
    leading axis of length `batch_size`. A zero tensor shaped like each
    repeated parameter is created as its starting momentum.

    Returns a tuple `(params, momentums)` of dicts keyed by parameter name.
    """
    params = dict()
    momentums = dict()

    for name, weight in self.meta_memory_mlp.named_parameters():
        # prepend a batch axis of size `batch_size` to the parameter
        batched = repeat(weight, '... -> b ...', b = batch_size)

        params[name] = batched
        momentums[name] = zeros_like(batched)

    return params, momentums
228
+
201
229
  def retrieve(
202
230
  self,
203
231
  structural_codes,
@@ -215,8 +243,12 @@ class mmTEM(Module):
215
243
  self,
216
244
  sensory,
217
245
  actions,
218
- return_losses = False
246
+ memory_mlp_params = None,
247
+ return_losses = False,
248
+ return_memory_mlp_params = False
219
249
  ):
250
+ batch = actions.shape[0]
251
+
220
252
  structural_codes = self.path_integrator(actions)
221
253
 
222
254
  encoded_sensory = self.sensory_encoder(sensory)
@@ -272,7 +304,50 @@ class mmTEM(Module):
272
304
  keys = self.to_keys(joint_code_to_store)
273
305
  values = self.to_values(joint_code_to_store)
274
306
 
275
- grads = self.per_sample_grad_fn(dict(self.meta_memory_mlp.named_parameters()), keys, values)
307
+ lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
308
+
309
+ if exists(memory_mlp_params):
310
+ params, momentums = memory_mlp_params
311
+ else:
312
+ params, momentums = self.init_params_and_momentum(batch)
313
+
314
+ # store by getting gradients of mse loss of keys and values
315
+
316
+ grads = self.per_sample_grad_fn(params, keys, values)
317
+
318
+ # update the meta mlp parameters and momentums
319
+
320
+ next_params = dict()
321
+ next_momentum = dict()
322
+
323
+ for (
324
+ (key, param),
325
+ (_, grad),
326
+ (_, momentum)
327
+ ) in zip(
328
+ params.items(),
329
+ grads.items(),
330
+ momentums.items()
331
+ ):
332
+
333
+ grad, inverse_pack = pack_with_inverse(grad, 'b t *')
334
+
335
+ grad = multiply('b t ..., b t', grad, lr)
336
+
337
+ expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
338
+
339
+ update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
340
+
341
+ expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
342
+
343
+ acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
344
+
345
+ acc_update = inverse_pack(acc_update)
346
+
347
+ # set the next params and momentum, which can be passed back in
348
+
349
+ next_params[key] = param - acc_update[:, -1]
350
+ next_momentum[key] = update[:, -1]
276
351
 
277
352
  # losses
278
353
 
@@ -290,6 +365,9 @@ class mmTEM(Module):
290
365
  inference_pred_loss
291
366
  )
292
367
 
368
+ if return_memory_mlp_params:
369
+ return next_params, next_momentum
370
+
293
371
  if not return_losses:
294
372
  return total_loss
295
373
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -36,6 +36,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: assoc-scan
38
38
  Requires-Dist: einops>=0.8.1
39
+ Requires-Dist: einx>=0.3.0
39
40
  Requires-Dist: torch>=2.4
40
41
  Requires-Dist: x-mlps-pytorch
41
42
  Provides-Extra: examples
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=yYoJ5XO0YVAyp3LcRxpunU-0HA97mpCBeQFyi-NSkF0,11549
3
+ hippoformer-0.0.7.dist-info/METADATA,sha256=Xg6NZ6VAQGmuiOo8mMwIAM39Gf6TpVOpyn7o4PMq7JE,2800
4
+ hippoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.7.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=PP2KmTygOP6MyYuhmr_8iEBbywIaTW4TpoIycYRugMo,9142
3
- hippoformer-0.0.5.dist-info/METADATA,sha256=83iG4F_6ibQy6XSCWht-aF2ZVYmiEq-KSF4XR9YaBtY,2773
4
- hippoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.5.dist-info/RECORD,,