hippoformer 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.4 → hippoformer-0.0.6}/PKG-INFO +2 -1
- {hippoformer-0.0.4 → hippoformer-0.0.6}/hippoformer/hippoformer.py +79 -5
- {hippoformer-0.0.4 → hippoformer-0.0.6}/pyproject.toml +2 -1
- {hippoformer-0.0.4 → hippoformer-0.0.6}/tests/test_hippoformer.py +1 -1
- {hippoformer-0.0.4 → hippoformer-0.0.6}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.4 → hippoformer-0.0.6}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.4 → hippoformer-0.0.6}/.gitignore +0 -0
- {hippoformer-0.0.4 → hippoformer-0.0.6}/LICENSE +0 -0
- {hippoformer-0.0.4 → hippoformer-0.0.6}/README.md +0 -0
- {hippoformer-0.0.4 → hippoformer-0.0.6}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.4 → hippoformer-0.0.6}/hippoformer-fig6.png +0 -0

{hippoformer-0.0.4 → hippoformer-0.0.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.4
+Version: 0.0.6
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer

@@ -36,6 +36,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.1
+Requires-Dist: einx>=0.3.0
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
 Provides-Extra: examples

{hippoformer-0.0.4 → hippoformer-0.0.6}/hippoformer/hippoformer.py

@@ -7,7 +7,8 @@ from torch.nn import Module
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call

-from
+from einx import multiply
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from x_mlps_pytorch import create_mlp
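
The new `einx>=0.3.0` dependency is pulled in for `multiply`, which broadcasts one tensor against another according to an axis pattern; it is used later in the memory update loop. A minimal sketch of that pattern, with made-up shapes:

```python
import torch
from einx import multiply

grads = torch.randn(2, 5, 7, 3)  # (batch, time, *param_shape)
lr = torch.rand(2, 5)            # one learned scalar per (batch, time)

# 'b t ..., b t' broadcasts lr over all trailing parameter dims
scaled = multiply('b t ..., b t', grads, lr)
assert scaled.shape == grads.shape
```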
@@ -22,6 +23,16 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, packed_shape, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
 def l2norm(t):
     return F.normalize(t, dim = -1)

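`pack_with_inverse` wraps einops' `pack`/`unpack` so a tensor can be flattened for a shape-agnostic operation and restored afterwards. A self-contained sketch of its behavior (the helper mirrors the diff above; the tensor shapes are illustrative):

```python
import torch
from einops import pack, unpack

def pack_with_inverse(t, pattern):
    # flatten per the pattern, returning a closure that undoes the pack
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        unpacked, = unpack(out, packed_shape, inv_pattern if inv_pattern is not None else pattern)
        return unpacked

    return packed, inverse

t = torch.randn(2, 5, 7, 3)

packed, inverse = pack_with_inverse(t, 'b t *')  # trailing dims folded into one axis
assert packed.shape == (2, 5, 21)

restored = inverse(packed)                       # original shape recovered
assert restored.shape == (2, 5, 7, 3)
```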
@@ -121,7 +132,8 @@ class mmTEM(Module):
         loss_weight_inference = 1.,
         loss_weight_consistency = 1.,
         loss_weight_relational = 1.,
-        integration_ratio_learned = True
+        integration_ratio_learned = True,
+        assoc_scan_kwargs: dict = dict()
     ):
         super().__init__()

@@ -150,6 +162,9 @@ class mmTEM(Module):
         self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

+        self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)
+
         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
             depth = meta_mlp_depth,
@@ -158,6 +173,14 @@ class mmTEM(Module):
             activation = nn.ReLU()
         )

+        def forward_with_mse_loss(params, keys, values):
+            pred = functional_call(self.meta_memory_mlp, params, keys)
+            return F.mse_loss(pred, values)
+
+        grad_fn = grad(forward_with_mse_loss)
+
+        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
+
         # mlp decoder (from meta mlp output to joint)

         self.memory_output_decoder = create_mlp(
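The double `vmap` over `grad` with `functional_call` is the standard `torch.func` recipe for per-sample gradients, here mapped over both batch and time while the parameters stay shared (`in_dims = (None, 0, 0)`). A standalone sketch of the same pattern on a toy module (names and shapes are illustrative, not from the package):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import vmap, grad, functional_call

net = nn.Linear(4, 4)
params = dict(net.named_parameters())

def loss_fn(params, key, value):
    pred = functional_call(net, params, key)
    return F.mse_loss(pred, value)

# inner vmap maps over time, outer over batch; params are not mapped
per_sample_grad = vmap(vmap(grad(loss_fn), in_dims = (None, 0, 0)), in_dims = (None, 0, 0))

keys = torch.randn(2, 5, 4)    # (batch, time, dim)
values = torch.randn(2, 5, 4)

grads = per_sample_grad(params, keys, values)
assert grads['weight'].shape == (2, 5, 4, 4)  # (batch, time) prefix on each param
```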
@@ -206,8 +229,11 @@ class mmTEM(Module):
     def forward(
         self,
         sensory,
-        actions
+        actions,
+        return_losses = False
     ):
+        batch = actions.shape[0]
+
         structural_codes = self.path_integrator(actions)

         encoded_sensory = self.sensory_encoder(sensory)
@@ -244,18 +270,63 @@ class mmTEM(Module):

         pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))

-        inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
+        inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio.sigmoid() * pred_variance

         consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)

         # 4. final inference loss

-
+        final_structural_code, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))

         decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)

         inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)

+        # 5. store the final structural code from step 4 + encoded sensory
+
+        joint_code_to_store = cat((final_structural_code, encoded_sensory), dim = -1)
+
+        keys = self.to_keys(joint_code_to_store)
+        values = self.to_values(joint_code_to_store)
+
+        lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
+
+        params = dict(self.meta_memory_mlp.named_parameters())
+        grads = self.per_sample_grad_fn(params, keys, values)
+
+        # update the meta mlp parameters
+
+        init_momentums = {k: zeros_like(v) for k, v in params.items()}
+        next_params = dict()
+
+        for (
+            (key, param),
+            (_, grad),
+            (_, init_momentum)
+        ) in zip(
+            params.items(),
+            grads.items(),
+            init_momentums.items()
+        ):
+
+            grad, inverse_pack = pack_with_inverse(grad, 'b t *')
+
+            grad = multiply('b t ..., b t', grad, lr)
+
+            expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
+
+            init_momentum = repeat(init_momentum, '... -> b ...', b = batch)
+
+            update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
+
+            expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
+
+            acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
+
+            acc_update = inverse_pack(acc_update)
+
+            next_params[key] = param - acc_update[:, -1]
+
         # losses

         total_loss = (
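The pair of associative scans above implements a gated momentum recurrence over the time axis: lr-scaled gradients are accumulated into a momentum under a learned gate, the resulting updates are accumulated under a learned forget gate, and only the final step is applied to the parameters. A plain-loop reading of what the scans compute, assuming each scan realizes `out_t = gate_t * out_{t-1} + x_t` (this is an interpretation for clarity, not code from the package):

```python
import torch

def gated_scan(x, gates, prev = None):
    # reference (sequential) version of an associative scan:
    # out_t = gates_t * out_{t-1} + x_t
    out = []
    for t in range(x.shape[1]):
        step = x[:, t] if prev is None else gates[:, t] * prev + x[:, t]
        out.append(step)
        prev = step
    return torch.stack(out, dim = 1)

b, t, w = 2, 5, 21
grad = torch.randn(b, t, w)      # flattened per-step gradients
lr = torch.rand(b, t, 1)
beta = torch.randn(b, t, 1)
forget = torch.randn(b, t, 1)

momentum = gated_scan(grad * lr, beta.sigmoid())      # momentum over lr-scaled grads
acc_update = gated_scan(momentum, forget.sigmoid())   # forget-gated accumulation
final_update = acc_update[:, -1]                      # param <- param - final_update
```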
@@ -272,4 +343,7 @@ class mmTEM(Module):
             inference_pred_loss
         )

+        if not return_losses:
+            return total_loss
+
         return total_loss, losses
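With the new `return_losses` flag, the forward pass returns only the scalar objective unless the per-term breakdown is requested. A minimal usage sketch (model construction and tensor shapes are elided, since the constructor arguments are outside this diff):

```python
def train_step(model, sensory, actions):
    # default: a single scalar loss, ready for backprop
    loss = model(sensory, actions)
    loss.backward()
    return loss

def inspect_losses(model, sensory, actions):
    # request the individual loss terms alongside the total
    total_loss, losses = model(sensory, actions, return_losses = True)
    return total_loss, losses
```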

{hippoformer-0.0.4 → hippoformer-0.0.6}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "hippoformer"
-version = "0.0.4"
+version = "0.0.6"
 description = "hippoformer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -25,6 +25,7 @@ classifiers=[

 dependencies = [
     "assoc-scan",
+    "einx>=0.3.0",
     "einops>=0.8.1",
     "torch>=2.4",
     "x-mlps-pytorch",