hippoformer 0.0.1__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,4 @@
1
+ from hippoformer.hippoformer import (
2
+ PathIntegration,
3
+ mmTEM
4
+ )
@@ -0,0 +1,219 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn, Tensor, stack, einsum, tensor
5
+ import torch.nn.functional as F
6
+ from torch.nn import Module
7
+ from torch.jit import ScriptModule, script_method
8
+ from torch.func import vmap, grad, functional_call
9
+
10
+ from einops import repeat, rearrange
11
+ from einops.layers.torch import Rearrange
12
+
13
+ from x_mlps_pytorch import create_mlp
14
+
15
+ from assoc_scan import AssocScan
16
+
17
+ # helpers
18
+
19
+ def exists(v):
20
+ return v is not None
21
+
22
+ def default(v, d):
23
+ return v if exists(v) else d
24
+
25
+ def l2norm(t):
26
+ return F.normalize(t, dim = -1)
27
+
28
+ # path integration
29
+
30
+ class RNN(ScriptModule):
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ ):
35
+ super().__init__()
36
+ self.init_hidden = nn.Parameter(torch.randn(1, dim) * 1e-2)
37
+
38
+ @script_method
39
+ def forward(
40
+ self,
41
+ transitions: Tensor,
42
+ hidden: Tensor | None = None
43
+ ) -> Tensor:
44
+
45
+ batch, seq_len = transitions.shape[:2]
46
+
47
+ if hidden is None:
48
+ hidden = l2norm(self.init_hidden)
49
+ hidden = hidden.expand(batch, -1)
50
+
51
+ hiddens: list[Tensor] = []
52
+
53
+ for i in range(seq_len):
54
+ transition = transitions[:, i]
55
+
56
+ hidden = einsum('b i, b i j -> b j', hidden, transition)
57
+ hidden = F.relu(hidden)
58
+ hidden = l2norm(hidden)
59
+
60
+ hiddens.append(hidden)
61
+
62
+ return stack(hiddens, dim = 1)
63
+
64
+ class PathIntegration(Module):
65
+ def __init__(
66
+ self,
67
+ dim_action,
68
+ dim_structure,
69
+ mlp_hidden_dim = None,
70
+ mlp_depth = 2
71
+ ):
72
+ # they use the same approach from Ruiqi Gao's paper from 2021
73
+ super().__init__()
74
+
75
+ self.init_structure = nn.Parameter(torch.randn(dim_structure))
76
+
77
+ self.to_transitions = create_mlp(
78
+ default(mlp_hidden_dim, dim_action * 4),
79
+ dim_in = dim_action,
80
+ dim_out = dim_structure * dim_structure,
81
+ depth = mlp_depth
82
+ )
83
+
84
+ self.mlp_out_to_weights = Rearrange('... (i j) -> ... i j', j = dim_structure)
85
+
86
+ self.rnn = RNN(dim_structure)
87
+
88
+ def forward(
89
+ self,
90
+ actions, # (b n d)
91
+ prev_structural = None # (b n d) | (b d)
92
+ ):
93
+ batch = actions.shape[0]
94
+
95
+ transitions = self.to_transitions(actions)
96
+ transitions = self.mlp_out_to_weights(transitions)
97
+
98
+ if exists(prev_structural) and prev_structural.ndim == 3:
99
+ prev_structural = prev_structural[:, -1]
100
+
101
+ return self.rnn(transitions, prev_structural)
102
+
103
+ # proposed mmTEM
104
+
105
+ class mmTEM(Module):
106
+ def __init__(
107
+ self,
108
+ dim,
109
+ *,
110
+ sensory_encoder: Module,
111
+ sensory_decoder: Module,
112
+ dim_sensory,
113
+ dim_action,
114
+ dim_encoded_sensory,
115
+ dim_structure,
116
+ meta_mlp_depth = 2,
117
+ decoder_mlp_depth = 2,
118
+ structure_variance_pred_mlp_depth = 2,
119
+ path_integrate_kwargs: dict = dict(),
120
+ loss_weight_generative = 1.,
121
+ loss_weight_inference = 1.,
122
+ loss_weight_consistency = 1.,
123
+ loss_weight_relational = 1.,
124
+ ):
125
+ super().__init__()
126
+
127
+ # sensory
128
+
129
+ self.sensory_encoder = sensory_encoder
130
+ self.sensory_decoder = sensory_decoder
131
+
132
+ dim_joint_rep = dim_encoded_sensory + dim_structure
133
+
134
+ self.dim_encoded_sensory = dim_encoded_sensory
135
+ self.dim_structure = dim_structure
136
+ self.joint_dims = (dim_structure, dim_encoded_sensory)
137
+
138
+ # path integrator
139
+
140
+ self.path_integrator = PathIntegration(
141
+ dim_action = dim_action,
142
+ dim_structure = dim_structure,
143
+ **path_integrate_kwargs
144
+ )
145
+
146
+ # meta mlp related
147
+
148
+ self.to_queries = nn.Linear(dim_joint_rep, dim, bias = False)
149
+ self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
150
+ self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
151
+
152
+ self.meta_memory_mlp = create_mlp(
153
+ dim = dim * 2,
154
+ depth = meta_mlp_depth,
155
+ dim_in = dim,
156
+ dim_out = dim,
157
+ activation = nn.ReLU()
158
+ )
159
+
160
+ # mlp decoder (from meta mlp output to joint)
161
+
162
+ self.memory_output_decoder = create_mlp(
163
+ dim = dim * 2,
164
+ dim_in = dim,
165
+ dim_out = dim_joint_rep,
166
+ depth = decoder_mlp_depth,
167
+ activation = nn.ReLU()
168
+ )
169
+
170
+ # the mlp that predicts the variance for the structural code
171
+ # for correcting the generated structural code modeling the feedback from HC to MEC
172
+
173
+ self.structure_variance_pred_mlp_depth = create_mlp(
174
+ dim = dim_structure * 2,
175
+ dim_in = dim_structure * 2 + 1,
176
+ dim_out = dim_structure,
177
+ depth = structure_variance_pred_mlp_depth
178
+ )
179
+
180
+ # loss related
181
+
182
+ self.loss_weight_generative = loss_weight_generative
183
+ self.loss_weight_inference = loss_weight_inference
184
+ self.loss_weight_relational = loss_weight_relational
185
+ self.loss_weight_consistency = loss_weight_consistency
186
+ self.register_buffer('zero', tensor(0.), persistent = False)
187
+
188
+ def forward(
189
+ self,
190
+ sensory,
191
+ actions
192
+ ):
193
+ structural_codes = self.path_integrator(actions)
194
+
195
+ # first have the structure code be able to fetch from the meta memory mlp
196
+
197
+ structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
198
+
199
+ queries = self.to_queries(structure_codes_with_zero_sensory)
200
+
201
+ retrieved = self.meta_memory_mlp(queries)
202
+
203
+ decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
204
+
205
+ decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
206
+
207
+ generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
208
+
209
+ # losses
210
+
211
+ total_loss = (
212
+ generative_pred_loss * self.loss_weight_generative
213
+ )
214
+
215
+ losses = (
216
+ generative_pred_loss,
217
+ )
218
+
219
+ return total_loss, losses
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hippoformer"
3
- version = "0.0.1"
3
+ version = "0.0.3"
4
4
  description = "hippoformer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,37 @@
1
+ import pytest
2
+
3
+ import torch
4
+
5
+ def test_path_integrate():
6
+ from hippoformer.hippoformer import PathIntegration
7
+
8
+ path_integrator = PathIntegration(32, 64)
9
+
10
+ actions = torch.randn(2, 16, 32)
11
+
12
+ structure_codes = path_integrator(actions)
13
+ structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
14
+
15
+ assert structure_codes.shape == (2, 16, 64)
16
+
17
+ def test_mm_tem():
18
+ import torch
19
+ from hippoformer.hippoformer import mmTEM
20
+
21
+ from torch.nn import Linear
22
+
23
+ model = mmTEM(
24
+ dim = 32,
25
+ sensory_encoder = Linear(11, 32),
26
+ sensory_decoder = Linear(32, 11),
27
+ dim_sensory = 11,
28
+ dim_action = 7,
29
+ dim_structure = 32,
30
+ dim_encoded_sensory = 32
31
+ )
32
+
33
+ actions = torch.randn(2, 16, 7)
34
+ sensory = torch.randn(2, 16, 11)
35
+
36
+ loss, losses = model(sensory, actions)
37
+ loss.backward()
File without changes
@@ -1,116 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- from torch import nn, Tensor, stack, einsum
5
- import torch.nn.functional as F
6
- from torch.nn import Module
7
- from torch.jit import ScriptModule, script_method
8
-
9
- from einops import repeat, rearrange
10
- from einops.layers.torch import Rearrange
11
-
12
- from x_mlps_pytorch import create_mlp
13
-
14
- from assoc_scan import AssocScan
15
-
16
- # helpers
17
-
18
- def exists(v):
19
- return v is not None
20
-
21
- def default(v, d):
22
- return v if exists(v) else d
23
-
24
- def l2norm(t):
25
- return F.normalize(t, dim = -1)
26
-
27
- # path integration
28
-
29
- class RNN(ScriptModule):
30
- def __init__(
31
- self,
32
- dim,
33
- ):
34
- super().__init__()
35
- self.init_hidden = nn.Parameter(torch.randn(1, dim) * 1e-2)
36
-
37
- @script_method
38
- def forward(
39
- self,
40
- transitions: Tensor,
41
- hidden: Tensor | None = None
42
- ) -> Tensor:
43
-
44
- batch, seq_len = transitions.shape[:2]
45
-
46
- if hidden is None:
47
- hidden = l2norm(self.init_hidden)
48
- hidden = hidden.expand(batch, -1)
49
-
50
- hiddens: list[Tensor] = []
51
-
52
- for i in range(seq_len):
53
- transition = transitions[:, i]
54
-
55
- hidden = einsum('b i, b i j -> b j', hidden, transition)
56
- hidden = F.relu(hidden)
57
- hidden = l2norm(hidden)
58
-
59
- hiddens.append(hidden)
60
-
61
- return stack(hiddens, dim = 1)
62
-
63
- class PathIntegration(Module):
64
- def __init__(
65
- self,
66
- dim_action,
67
- dim_structure,
68
- mlp_hidden_dim = None,
69
- mlp_depth = 2
70
- ):
71
- # they use the same approach from Ruiqi Gao's paper from 2021
72
- super().__init__()
73
-
74
- self.init_structure = nn.Parameter(torch.randn(dim_structure))
75
-
76
- self.to_transitions = create_mlp(
77
- default(mlp_hidden_dim, dim_action * 4),
78
- dim_in = dim_action,
79
- dim_out = dim_structure * dim_structure,
80
- depth = mlp_depth
81
- )
82
-
83
- self.mlp_out_to_weights = Rearrange('... (i j) -> ... i j', j = dim_structure)
84
-
85
- self.rnn = RNN(dim_structure)
86
-
87
- def forward(
88
- self,
89
- actions, # (b n d)
90
- prev_structural = None # (b n d) | (b d)
91
- ):
92
- batch = actions.shape[0]
93
-
94
- transitions = self.to_transitions(actions)
95
- transitions = self.mlp_out_to_weights(transitions)
96
-
97
- if exists(prev_structural) and prev_structural.ndim == 3:
98
- prev_structural = prev_structural[:, -1]
99
-
100
- return self.rnn(transitions, prev_structural)
101
-
102
- # proposed mmTEM
103
-
104
- class mmTEM(Module):
105
- def __init__(
106
- self,
107
- dim
108
- ):
109
- super().__init__()
110
-
111
-
112
- def forward(
113
- self,
114
- data
115
- ):
116
- raise NotImplementedError
@@ -1,15 +0,0 @@
1
- import pytest
2
-
3
- import torch
4
-
5
- def test_path_integrate():
6
- from hippoformer.hippoformer import PathIntegration
7
-
8
- path_integrator = PathIntegration(32, 64)
9
-
10
- actions = torch.randn(2, 16, 32)
11
-
12
- structure_codes = path_integrator(actions)
13
- structure_codes = path_integrator(actions, structure_codes) # pass in previous structure codes, it will auto use the last one as hidden and pass it to the RNN
14
-
15
- assert structure_codes.shape == (2, 16, 64)
File without changes
File without changes
File without changes