hippoformer 0.0.6.tar.gz → 0.0.8.tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.6
+Version: 0.0.8
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: torch>=2.4
@@ -50,8 +51,6 @@ Description-Content-Type: text/markdown
 
 Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
-
 ## Citations
 
 ```bibtex
--- a/README.md
+++ b/README.md
@@ -4,8 +4,6 @@
 
 Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
-
 ## Citations
 
 ```bibtex
--- a/hippoformer/hippoformer.py
+++ b/hippoformer/hippoformer.py
@@ -7,6 +7,8 @@ from torch.nn import Module
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call
 
+from beartype import beartype
+
 from einx import multiply
 from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
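The new `beartype` import backs runtime type-checking of the `mmTEM` constructor (see the `@beartype` decorator further down). A minimal sketch of what this enforcement looks like, using a hypothetical function rather than anything from the package:

```python
from beartype import beartype
from beartype.roar import BeartypeCallHintParamViolation
from torch.nn import Module, Linear

@beartype
def check_pair(pair: tuple[Module, Module]):
    # beartype validates the annotation at call time
    return pair

check_pair((Linear(11, 32), Linear(32, 11)))  # ok

try:
    check_pair(Linear(11, 32))  # a bare Module, not a pair - rejected
except BeartypeCallHintParamViolation as exc:
    print('rejected:', type(exc).__name__)
```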
@@ -36,6 +38,79 @@ def pack_with_inverse(t, pattern):
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
+# sensory encoder decoder for 2d
+
+grid_sensory_enc_dec = (
+    create_mlp(
+        dim = 32 * 2,
+        dim_in = 9,
+        dim_out = 32,
+        depth = 3,
+    ),
+    create_mlp(
+        dim = 32 * 2,
+        dim_in = 32,
+        dim_out = 9,
+        depth = 3,
+    ),
+)
+
+# sensory encoder decoder for 3d maze
+
+class EncoderPackTime(Module):
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x = rearrange(x, 'b c t h w -> b t c h w')
+        x, packed_shape = pack([x], '* c h w')
+
+        x = self.fn(x)
+
+        x, = unpack(x, packed_shape, '* d')
+        return x
+
+class DecoderPackTime(Module):
+    def __init__(self, fn: Module):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x, packed_shape = pack([x], '* d')
+
+        x = self.fn(x)
+
+        x, = unpack(x, packed_shape, '* c h w')
+        x = rearrange(x, 'b t c h w -> b c t h w')
+        return x
+
+maze_sensory_enc_dec = (
+    EncoderPackTime(nn.Sequential(
+        nn.Conv2d(3, 16, 7, 2, padding = 3),
+        nn.ReLU(),
+        nn.Conv2d(16, 32, 3, 2, 1),
+        nn.ReLU(),
+        nn.Conv2d(32, 64, 3, 2, 1),
+        nn.ReLU(),
+        nn.Conv2d(64, 128, 3, 2, 1),
+        nn.ReLU(),
+        Rearrange('b ... -> b (...)'),
+        nn.Linear(2048, 32)
+    )),
+    DecoderPackTime(nn.Sequential(
+        nn.Linear(32, 2048),
+        Rearrange('b (c h w) -> b c h w', c = 128, h = 4),
+        nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding = (1, 1)),
+        nn.ReLU(),
+        nn.ConvTranspose2d(16, 3, 3, 2, 1, output_padding = (1, 1))
+    ))
+)
+
 # path integration
 
 class RNN(ScriptModule):
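The new maze encoder/decoder pair folds the time axis into the batch, runs a small CNN per frame, and compresses each 64×64 RGB frame to a 32-dim code. The `nn.Linear(2048, 32)` works because four stride-2 convolutions reduce 64×64 to 4×4 at 128 channels, and 128·4·4 = 2048. A minimal shape check against the definitions above, with the input shape taken from the test file later in this diff:

```python
import torch
from hippoformer.hippoformer import maze_sensory_enc_dec

encoder, decoder = maze_sensory_enc_dec

video = torch.randn(2, 3, 16, 64, 64)  # (batch, channels, time, height, width)

encoded = encoder(video)    # one 32-dim code per timestep
decoded = decoder(encoded)  # video reconstructed from the codes

assert encoded.shape == (2, 16, 32)
assert decoded.shape == video.shape
```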
@@ -114,12 +190,12 @@ class PathIntegration(Module):
 # proposed mmTEM
 
 class mmTEM(Module):
+    @beartype
     def __init__(
         self,
         dim,
         *,
-        sensory_encoder: Module,
-        sensory_decoder: Module,
+        sensory_encoder_decoder: tuple[Module, Module],
         dim_sensory,
         dim_action,
         dim_encoded_sensory,
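This is a breaking constructor change: the separate `sensory_encoder` / `sensory_decoder` keywords become a single `sensory_encoder_decoder` tuple, validated by `@beartype`. A sketch of migrating a 0.0.6-style call, with dims taken from the tests at the bottom of this diff:

```python
from torch.nn import Linear
from hippoformer.hippoformer import mmTEM

model = mmTEM(
    dim = 32,
    # 0.0.6 took sensory_encoder = Linear(11, 32), sensory_decoder = Linear(32, 11)
    sensory_encoder_decoder = (Linear(11, 32), Linear(32, 11)),
    dim_sensory = 11,
    dim_action = 7,
    dim_structure = 32,
    dim_encoded_sensory = 32
)
```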
@@ -139,6 +215,8 @@ class mmTEM(Module):
 
         # sensory
 
+        sensory_encoder, sensory_decoder = sensory_encoder_decoder
+
         self.sensory_encoder = sensory_encoder
         self.sensory_decoder = sensory_decoder
 
@@ -179,7 +257,7 @@ class mmTEM(Module):
 
         grad_fn = grad(forward_with_mse_loss)
 
-        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))
 
         # mlp decoder (from meta mlp output to joint)
 
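The outer `in_dims` changes from `(None, 0, 0)` to `(0, 0, 0)` because the memory MLP parameters now carry a leading batch dimension (see `init_params_and_momentum` below), so they must be mapped over per sample rather than broadcast. A standalone sketch of the same pattern, with hypothetical shapes:

```python
import torch
from torch.func import grad, vmap

def mse_loss(params, x, y):
    return ((x @ params - y) ** 2).mean()

grad_fn = grad(mse_loss)  # gradient with respect to params

# inner vmap maps over time with shared params; outer vmap maps over the
# batch, now including dim 0 of the per-sample params
per_sample_grad_fn = vmap(vmap(grad_fn, in_dims = (None, 0, 0)), in_dims = (0, 0, 0))

params = torch.randn(2, 8, 4)  # one parameter copy per batch sample
x = torch.randn(2, 16, 8)      # (batch, time, dim_in)
y = torch.randn(2, 16, 4)      # (batch, time, dim_out)

grads = per_sample_grad_fn(params, x, y)  # (2, 16, 8, 4): per sample, per timestep
```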
@@ -213,6 +291,19 @@ class mmTEM(Module):
 
         self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
 
+    def init_params_and_momentum(
+        self,
+        batch_size
+    ):
+
+        params_dict = dict(self.meta_memory_mlp.named_parameters())
+
+        params = {name: repeat(param, '... -> b ...', b = batch_size) for name, param in params_dict.items()}
+
+        momentums = {name: zeros_like(param) for name, param in params.items()}
+
+        return params, momentums
+
     def retrieve(
         self,
         structural_codes,
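The new helper materializes one copy of the memory MLP parameters per batch sample, each paired with a zero-initialized momentum buffer, so the fast-weight state can be handed back into later forward calls. A small sketch of the returned state, assuming `model` is an `mmTEM` instance:

```python
import torch

params, momentums = model.init_params_and_momentum(batch_size = 2)

for name, param in params.items():
    assert param.shape[0] == 2                                    # leading batch dim
    assert torch.equal(momentums[name], torch.zeros_like(param))  # momentum starts at zero
```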
@@ -230,7 +321,9 @@ class mmTEM(Module):
         self,
         sensory,
         actions,
-        return_losses = False
+        memory_mlp_params = None,
+        return_losses = False,
+        return_memory_mlp_params = False
     ):
         batch = actions.shape[0]
 
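Together the new keywords let callers thread the fast-weight memory across segments: request the updated `(params, momentums)` with `return_memory_mlp_params = True`, then feed the pair back through `memory_mlp_params`. Note that when state is requested, the call returns the state instead of the loss. A usage sketch mirroring the test at the bottom of this diff:

```python
# model, sensory, actions as in the earlier sketches
state = model(sensory, actions, return_memory_mlp_params = True)  # (params, momentums)
state = model(sensory, actions, memory_mlp_params = state, return_memory_mlp_params = True)

loss = model(sensory, actions, memory_mlp_params = state)
loss.backward()
```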
@@ -291,22 +384,28 @@ class mmTEM(Module):
 
         lr, forget, beta = self.to_learned_optim_hparams(joint_code_to_store).unbind(dim = -1)
 
-        params = dict(self.meta_memory_mlp.named_parameters())
+        if exists(memory_mlp_params):
+            params, momentums = memory_mlp_params
+        else:
+            params, momentums = self.init_params_and_momentum(batch)
+
+        # store by getting gradients of mse loss of keys and values
+
         grads = self.per_sample_grad_fn(params, keys, values)
 
-        # update the meta mlp parameters
+        # update the meta mlp parameters and momentums
 
-        init_momentums = {k: zeros_like(v) for k, v in params.items()}
         next_params = dict()
+        next_momentum = dict()
 
         for (
             (key, param),
             (_, grad),
-            (_, init_momentum)
+            (_, momentum)
         ) in zip(
             params.items(),
             grads.items(),
-            init_momentums.items()
+            momentums.items()
         ):
 
             grad, inverse_pack = pack_with_inverse(grad, 'b t *')
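With the momentum buffer threaded through, the associative scan below continues the per-parameter momentum recurrence from the previous call rather than restarting at zero. As a rough reference, and assuming the scan realizes the usual gated recurrence (this is an interpretation for intuition only, not the package's `assoc_scan` API):

```python
import torch

def naive_momentum_scan(grads, betas, init_momentum):
    # grads, betas: (batch, time, width); init_momentum: (batch, width)
    m = init_momentum
    out = []
    for t in range(grads.shape[1]):
        # m_t = sigmoid(beta_t) * m_(t-1) + grad_t
        m = betas[:, t].sigmoid() * m + grads[:, t]
        out.append(m)
    return torch.stack(out, dim = 1)  # momentum at every timestep
```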
@@ -315,9 +414,7 @@ class mmTEM(Module):
 
             expanded_beta = repeat(beta, 'b t -> b t w', w = grad.shape[-1])
 
-            init_momentum = repeat(init_momentum, '... -> b ...', b = batch)
-
-            update = self.assoc_scan(grad, expanded_beta.sigmoid(), init_momentum)
+            update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
 
             expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
 
@@ -325,7 +422,10 @@ class mmTEM(Module):
 
             acc_update = inverse_pack(acc_update)
 
+            # set the next params and momentum, which can be passed back in
+
             next_params[key] = param - acc_update[:, -1]
+            next_momentum[key] = update[:, -1]
 
         # losses
 
@@ -343,6 +443,9 @@ class mmTEM(Module):
             inference_pred_loss
         )
 
+        if return_memory_mlp_params:
+            return next_params, next_momentum
+
         if not return_losses:
             return total_loss
 
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "hippoformer"
-version = "0.0.6"
+version = "0.0.8"
 description = "hippoformer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,6 +25,7 @@ classifiers=[
 
 dependencies = [
     "assoc-scan",
+    "beartype",
    "einx>=0.3.0",
    "einops>=0.8.1",
    "torch>=2.4",
@@ -0,0 +1,64 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+
+def test_path_integrate():
+    from hippoformer.hippoformer import PathIntegration
+
+    path_integrator = PathIntegration(32, 64)
+
+    actions = torch.randn(2, 16, 32)
+
+    structure_codes = path_integrator(actions)
+    structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes; the last one is automatically used as the hidden state for the RNN
+
+    assert structure_codes.shape == (2, 16, 64)
+
+@param('sensory_type', ('naive', '2d', '3d'))
+def test_mm_tem(
+    sensory_type
+):
+    import torch
+    from hippoformer.hippoformer import mmTEM
+
+    from torch.nn import Linear
+
+    if sensory_type == 'naive':
+        enc_dec = (
+            Linear(11, 32),
+            Linear(32, 11)
+        )
+        sensory = torch.randn(2, 16, 11)
+
+    elif sensory_type == '2d':
+
+        from hippoformer.hippoformer import grid_sensory_enc_dec
+
+        enc_dec = grid_sensory_enc_dec
+        sensory = torch.randn(2, 16, 9)
+
+    elif sensory_type == '3d':
+
+        from hippoformer.hippoformer import maze_sensory_enc_dec
+
+        enc_dec = maze_sensory_enc_dec
+
+        sensory = torch.randn(2, 3, 16, 64, 64)
+
+    model = mmTEM(
+        dim = 32,
+        sensory_encoder_decoder = enc_dec,
+        dim_sensory = 11,
+        dim_action = 7,
+        dim_structure = 32,
+        dim_encoded_sensory = 32
+    )
+
+    actions = torch.randn(2, 16, 7)
+
+    next_params = model(sensory, actions, return_memory_mlp_params = True)
+    next_params = model(sensory, actions, memory_mlp_params = next_params, return_memory_mlp_params = True)
+
+    loss = model(sensory, actions, memory_mlp_params = next_params)
+    loss.backward()
@@ -1,37 +0,0 @@
-import pytest
-
-import torch
-
-def test_path_integrate():
-    from hippoformer.hippoformer import PathIntegration
-
-    path_integrator = PathIntegration(32, 64)
-
-    actions = torch.randn(2, 16, 32)
-
-    structure_codes = path_integrator(actions)
-    structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
-
-    assert structure_codes.shape == (2, 16, 64)
-
-def test_mm_tem():
-    import torch
-    from hippoformer.hippoformer import mmTEM
-
-    from torch.nn import Linear
-
-    model = mmTEM(
-        dim = 32,
-        sensory_encoder = Linear(11, 32),
-        sensory_decoder = Linear(32, 11),
-        dim_sensory = 11,
-        dim_action = 7,
-        dim_structure = 32,
-        dim_encoded_sensory = 32
-    )
-
-    actions = torch.randn(2, 16, 7)
-    sensory = torch.randn(2, 16, 11)
-
-    loss = model(sensory, actions)
-    loss.backward()