hippoformer 0.0.7__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.7
3
+ Version: 0.0.8
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: assoc-scan
38
+ Requires-Dist: beartype
38
39
  Requires-Dist: einops>=0.8.1
39
40
  Requires-Dist: einx>=0.3.0
40
41
  Requires-Dist: torch>=2.4
@@ -50,8 +51,6 @@ Description-Content-Type: text/markdown
50
51
 
51
52
  Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
52
53
 
53
- [Temporary Discord](https://discord.gg/MkACrrkrYR)
54
-
55
54
  ## Citations
56
55
 
57
56
  ```bibtex
@@ -4,8 +4,6 @@
4
4
 
5
5
  Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
6
6
 
7
- [Temporary Discord](https://discord.gg/MkACrrkrYR)
8
-
9
7
  ## Citations
10
8
 
11
9
  ```bibtex
@@ -7,6 +7,8 @@ from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
8
8
  from torch.func import vmap, grad, functional_call
9
9
 
10
+ from beartype import beartype
11
+
10
12
  from einx import multiply
11
13
  from einops import repeat, rearrange, pack, unpack
12
14
  from einops.layers.torch import Rearrange
@@ -36,6 +38,80 @@ def pack_with_inverse(t, pattern):
36
38
  def l2norm(t):
37
39
  return F.normalize(t, dim = -1)
38
40
 
41
+ # sensory encoder decoder for 2d
42
+
43
+ grid_sensory_enc_dec = (
44
+ create_mlp(
45
+ dim = 32 * 2,
46
+ dim_in = 9,
47
+ dim_out = 32,
48
+ depth = 3,
49
+ ),
50
+ create_mlp(
51
+ dim = 32 * 2,
52
+ dim_in = 32,
53
+ dim_out = 9,
54
+ depth = 3,
55
+ ),
56
+ )
57
+
58
+ # sensory encoder decoder for 3d maze
59
+
60
+ class EncoderPackTime(Module):
61
+ def __init__(self, fn: Module):
62
+ super().__init__()
63
+ self.fn = fn
64
+
65
+ def forward(self, x):
66
+ x = rearrange(x, 'b c t h w -> b t c h w')
67
+ x, packed_shape = pack([x], '* c h w')
68
+
69
+ x = self.fn(x)
70
+
71
+ x, = unpack(x, packed_shape, '* d')
72
+ print(x.shape)
73
+ return x
74
+
75
+ class DecoderPackTime(Module):
76
+ def __init__(self, fn: Module):
77
+ super().__init__()
78
+ self.fn = fn
79
+
80
+ def forward(self, x):
81
+ x, packed_shape = pack(x, '* d')
82
+
83
+ x = self.fn(x)
84
+
85
+ x = unpack(x, packed_shape, '* c h w')
86
+ x = rearrange(x, 'b t c h w -> b c t h w')
87
+ return x
88
+
89
+ maze_sensory_enc_dec = (
90
+ EncoderPackTime(nn.Sequential(
91
+ nn.Conv2d(3, 16, 7, 2, padding = 3),
92
+ nn.ReLU(),
93
+ nn.Conv2d(16, 32, 3, 2, 1),
94
+ nn.ReLU(),
95
+ nn.Conv2d(32, 64, 3, 2, 1),
96
+ nn.ReLU(),
97
+ nn.Conv2d(64, 128, 3, 2, 1),
98
+ nn.ReLU(),
99
+ Rearrange('b ... -> b (...)'),
100
+ nn.Linear(2048, 32)
101
+ )),
102
+ DecoderPackTime(nn.Sequential(
103
+ nn.Linear(32, 2048),
104
+ Rearrange('b (c h w) -> b c h w', c = 128, h = 4),
105
+ nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = (1, 1)),
106
+ nn.ReLU(),
107
+ nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding = (1, 1)),
108
+ nn.ReLU(),
109
+ nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding = (1, 1)),
110
+ nn.ReLU(),
111
+ nn.ConvTranspose2d(16, 3, 3, 2, 1, output_padding = (1, 1))
112
+ ))
113
+ )
114
+
39
115
  # path integration
40
116
 
41
117
  class RNN(ScriptModule):
@@ -114,12 +190,12 @@ class PathIntegration(Module):
114
190
  # proposed mmTEM
115
191
 
116
192
  class mmTEM(Module):
193
+ @beartype
117
194
  def __init__(
118
195
  self,
119
196
  dim,
120
197
  *,
121
- sensory_encoder: Module,
122
- sensory_decoder: Module,
198
+ sensory_encoder_decoder: tuple[Module, Module],
123
199
  dim_sensory,
124
200
  dim_action,
125
201
  dim_encoded_sensory,
@@ -139,6 +215,8 @@ class mmTEM(Module):
139
215
 
140
216
  # sensory
141
217
 
218
+ sensory_encoder, sensory_decoder = sensory_encoder_decoder
219
+
142
220
  self.sensory_encoder = sensory_encoder
143
221
  self.sensory_decoder = sensory_decoder
144
222
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hippoformer"
3
- version = "0.0.7"
3
+ version = "0.0.8"
4
4
  description = "hippoformer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,6 +25,7 @@ classifiers=[
25
25
 
26
26
  dependencies = [
27
27
  "assoc-scan",
28
+ "beartype",
28
29
  "einx>=0.3.0",
29
30
  "einops>=0.8.1",
30
31
  "torch>=2.4",
@@ -1,4 +1,5 @@
1
1
  import pytest
2
+ param = pytest.mark.parametrize
2
3
 
3
4
  import torch
4
5
 
@@ -14,16 +15,40 @@ def test_path_integrate():
14
15
 
15
16
  assert structure_codes.shape == (2, 16, 64)
16
17
 
17
- def test_mm_tem():
18
+ @param('sensory_type', ('naive', '2d', '3d'))
19
+ def test_mm_tem(
20
+ sensory_type
21
+ ):
18
22
  import torch
19
23
  from hippoformer.hippoformer import mmTEM
20
24
 
21
25
  from torch.nn import Linear
22
26
 
27
+ if sensory_type == 'naive':
28
+ enc_dec = (
29
+ Linear(11, 32),
30
+ Linear(32, 11)
31
+ )
32
+ sensory = torch.randn(2, 16, 11)
33
+
34
+ elif sensory_type == '2d':
35
+
36
+ from hippoformer.hippoformer import grid_sensory_enc_dec
37
+
38
+ enc_dec = grid_sensory_enc_dec
39
+ sensory = torch.randn(2, 16, 9)
40
+
41
+ elif sensory_type == '3d':
42
+
43
+ from hippoformer.hippoformer import maze_sensory_enc_dec
44
+
45
+ enc_dec = maze_sensory_enc_dec
46
+
47
+ sensory = torch.randn(2, 3, 16, 64, 64)
48
+
23
49
  model = mmTEM(
24
50
  dim = 32,
25
- sensory_encoder = Linear(11, 32),
26
- sensory_decoder = Linear(32, 11),
51
+ sensory_encoder_decoder = enc_dec,
27
52
  dim_sensory = 11,
28
53
  dim_action = 7,
29
54
  dim_structure = 32,
@@ -31,10 +56,6 @@ def test_mm_tem():
31
56
  )
32
57
 
33
58
  actions = torch.randn(2, 16, 7)
34
- sensory = torch.randn(2, 16, 11)
35
-
36
- actions = torch.randn(2, 16, 7)
37
- sensory = torch.randn(2, 16, 11)
38
59
 
39
60
  next_params = model(sensory, actions, return_memory_mlp_params = True)
40
61
  next_params = model(sensory, actions, memory_mlp_params = next_params, return_memory_mlp_params = True)
File without changes
File without changes