hippoformer 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,8 @@ from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
8
8
  from torch.func import vmap, grad, functional_call
9
9
 
10
- from einops import repeat, rearrange
10
+ from einx import multiply
11
+ from einops import repeat, rearrange, pack, unpack
11
12
  from einops.layers.torch import Rearrange
12
13
 
13
14
  from x_mlps_pytorch import create_mlp
@@ -22,6 +23,16 @@ def exists(v):
22
23
  def default(v, d):
23
24
  return v if exists(v) else d
24
25
 
26
+ def pack_with_inverse(t, pattern):
27
+ packed, packed_shape = pack([t], pattern)
28
+
29
+ def inverse(out, inv_pattern = None):
30
+ inv_pattern = default(inv_pattern, pattern)
31
+ unpacked, = unpack(out, packed_shape, inv_pattern)
32
+ return unpacked
33
+
34
+ return packed, inverse
35
+
25
36
  def l2norm(t):
26
37
  return F.normalize(t, dim = -1)
27
38
 
@@ -121,7 +132,8 @@ class mmTEM(Module):
121
132
  loss_weight_inference = 1.,
122
133
  loss_weight_consistency = 1.,
123
134
  loss_weight_relational = 1.,
124
- integration_ratio_learned = True
135
+ integration_ratio_learned = True,
136
+ assoc_scan_kwargs: dict = dict()
125
137
  ):
126
138
  super().__init__()
127
139
 
@@ -150,6 +162,9 @@ class mmTEM(Module):
150
162
  self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
151
163
  self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
152
164
 
165
+ self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
166
+ self.assoc_scan = AssocScan(*assoc_scan_kwargs)
167
+
153
168
  self.meta_memory_mlp = create_mlp(
154
169
  dim = dim * 2,
155
170
  depth = meta_mlp_depth,
@@ -217,6 +232,8 @@ class mmTEM(Module):
217
232
  actions,
218
233
  return_losses = False
219
234
  ):
235
+ batch = actions.shape[0]
236
+
220
237
  structural_codes = self.path_integrator(actions)
221
238
 
222
239
  encoded_sensory = self.sensory_encoder(sensory)
@@ -272,7 +289,43 @@ class mmTEM(Module):
272
289
  keys = self.to_keys(joint_code_to_store)
273
290
  values = self.to_values(joint_code_to_store)
274
291
 
275
- grads = self.per_sample_grad_fn(dict(self.meta_memory_mlp.named_parameters()), keys, values)
292
+ lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
293
+
294
+ params = dict(self.meta_memory_mlp.named_parameters())
295
+ grads = self.per_sample_grad_fn(params, keys, values)
296
+
297
+ # update the meta mlp parameters
298
+
299
+ init_momentums = {k: zeros_like(v) for k, v in params.items()}
300
+ next_params = dict()
301
+
302
+ for (
303
+ (key, param),
304
+ (_, grad),
305
+ (_, init_momentum)
306
+ ) in zip(
307
+ params.items(),
308
+ grads.items(),
309
+ init_momentums.items()
310
+ ):
311
+
312
+ grad, inverse_pack = pack_with_inverse(grad, 'b t *')
313
+
314
+ grad = multiply('b t ..., b t', grad, lr)
315
+
316
+ expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
317
+
318
+ init_momentum = repeat(init_momentum, '... -> b ...', b = batch)
319
+
320
+ update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
321
+
322
+ expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
323
+
324
+ acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
325
+
326
+ acc_update = inverse_pack(acc_update)
327
+
328
+ next_params[key] = param - acc_update[:, -1]
276
329
 
277
330
  # losses
278
331
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -36,6 +36,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: assoc-scan
38
38
  Requires-Dist: einops>=0.8.1
39
+ Requires-Dist: einx>=0.3.0
39
40
  Requires-Dist: torch>=2.4
40
41
  Requires-Dist: x-mlps-pytorch
41
42
  Provides-Extra: examples
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=b6EXU2VXh_ZD7brxpCVNuU-m7cE-zXRR-sOmqfofPCg,10839
3
+ hippoformer-0.0.6.dist-info/METADATA,sha256=ufTBdu8ZGggxwfgzphYV56jjaGdI5sLCE_iZF5Bku6s,2800
4
+ hippoformer-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.6.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=PP2KmTygOP6MyYuhmr_8iEBbywIaTW4TpoIycYRugMo,9142
3
- hippoformer-0.0.5.dist-info/METADATA,sha256=83iG4F_6ibQy6XSCWht-aF2ZVYmiEq-KSF4XR9YaBtY,2773
4
- hippoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.5.dist-info/RECORD,,