hippoformer 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +104 -5
- {hippoformer-0.0.2.dist-info → hippoformer-0.0.4.dist-info}/METADATA +1 -1
- hippoformer-0.0.4.dist-info/RECORD +6 -0
- hippoformer-0.0.2.dist-info/RECORD +0 -6
- {hippoformer-0.0.2.dist-info → hippoformer-0.0.4.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.2.dist-info → hippoformer-0.0.4.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import nn, Tensor, stack, einsum
|
|
4
|
+
from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
|
|
5
5
|
import torch.nn.functional as F
|
|
6
6
|
from torch.nn import Module
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
8
|
+
from torch.func import vmap, grad, functional_call
|
|
8
9
|
|
|
9
10
|
from einops import repeat, rearrange
|
|
10
11
|
from einops.layers.torch import Rearrange
|
|
@@ -120,11 +121,21 @@ class mmTEM(Module):
|
|
|
120
121
|
loss_weight_inference = 1.,
|
|
121
122
|
loss_weight_consistency = 1.,
|
|
122
123
|
loss_weight_relational = 1.,
|
|
124
|
+
integration_ratio_learned = True
|
|
123
125
|
):
|
|
124
126
|
super().__init__()
|
|
125
127
|
|
|
128
|
+
# sensory
|
|
129
|
+
|
|
130
|
+
self.sensory_encoder = sensory_encoder
|
|
131
|
+
self.sensory_decoder = sensory_decoder
|
|
132
|
+
|
|
126
133
|
dim_joint_rep = dim_encoded_sensory + dim_structure
|
|
127
134
|
|
|
135
|
+
self.dim_encoded_sensory = dim_encoded_sensory
|
|
136
|
+
self.dim_structure = dim_structure
|
|
137
|
+
self.joint_dims = (dim_structure, dim_encoded_sensory)
|
|
138
|
+
|
|
128
139
|
# path integrator
|
|
129
140
|
|
|
130
141
|
self.path_integrator = PathIntegration(
|
|
@@ -139,7 +150,7 @@ class mmTEM(Module):
|
|
|
139
150
|
self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
140
151
|
self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
141
152
|
|
|
142
|
-
self.
|
|
153
|
+
self.meta_memory_mlp = create_mlp(
|
|
143
154
|
dim = dim * 2,
|
|
144
155
|
depth = meta_mlp_depth,
|
|
145
156
|
dim_in = dim,
|
|
@@ -149,7 +160,7 @@ class mmTEM(Module):
|
|
|
149
160
|
|
|
150
161
|
# mlp decoder (from meta mlp output to joint)
|
|
151
162
|
|
|
152
|
-
self.
|
|
163
|
+
self.memory_output_decoder = create_mlp(
|
|
153
164
|
dim = dim * 2,
|
|
154
165
|
dim_in = dim,
|
|
155
166
|
dim_out = dim_joint_rep,
|
|
@@ -160,17 +171,105 @@ class mmTEM(Module):
|
|
|
160
171
|
# the mlp that predicts the variance for the structural code
|
|
161
172
|
# for correcting the generated structural code modeling the feedback from HC to MEC
|
|
162
173
|
|
|
163
|
-
self.
|
|
174
|
+
self.structure_variance_pred_mlp = create_mlp(
|
|
164
175
|
dim = dim_structure * 2,
|
|
165
176
|
dim_in = dim_structure * 2 + 1,
|
|
166
177
|
dim_out = dim_structure,
|
|
167
178
|
depth = structure_variance_pred_mlp_depth
|
|
168
179
|
)
|
|
169
180
|
|
|
181
|
+
# loss related
|
|
182
|
+
|
|
183
|
+
self.loss_weight_generative = loss_weight_generative
|
|
184
|
+
self.loss_weight_inference = loss_weight_inference
|
|
185
|
+
self.loss_weight_relational = loss_weight_relational
|
|
186
|
+
self.loss_weight_consistency = loss_weight_consistency
|
|
187
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
188
|
+
|
|
189
|
+
# there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
|
|
190
|
+
|
|
191
|
+
self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
|
|
192
|
+
|
|
193
|
+
def retrieve(
    self,
    structural_codes,
    encoded_sensory
):
    """Read from the meta-memory MLP using a (structure, sensory) query.

    The structural code and the encoded sensory are joined along the last
    dimension, projected to a query, passed through the meta-memory MLP and
    decoded back into the joint space. The decoded joint vector is split
    along the last dimension by `self.joint_dims` and the pieces are returned
    as a tuple (structure part first, then encoded-sensory part).
    """
    # structure occupies the leading slice and sensory the trailing one;
    # self.joint_dims lists the structure size first, so the split mirrors it
    joint_query_input = cat((structural_codes, encoded_sensory), dim = -1)

    memory_readout = self.meta_memory_mlp(self.to_queries(joint_query_input))

    decoded_joint = self.memory_output_decoder(memory_readout)

    return decoded_joint.split(self.joint_dims, dim = -1)
|
|
205
|
+
|
|
170
206
|
def forward(
    self,
    sensory,
    actions
):
    """Compute the mmTEM training losses for one batch.

    Four unweighted losses are computed from memory retrievals:
      1. generative - sensory reconstructed from memory queried with the
         path-integrated structural code alone (sensory slot zeroed)
      2. relational - structural code recovered from memory when queried with
         (a) the encoded sensory alone and (b) the structural code alone
      3. consistency - MSE between the memory-generated structural code and
         the error-corrected (inferred) structural code
      4. inference - sensory reconstructed from memory queried with the
         inferred structural code

    Args:
        sensory: raw sensory input, fed to `self.sensory_encoder` and used as
            the reconstruction target (shape is not checked here - TODO confirm
            it matches the output of `self.sensory_decoder`).
        actions: input to `self.path_integrator`, which yields the structural codes.

    Returns:
        (total_loss, losses): `total_loss` is the sum of each loss times its
        `self.loss_weight_*`; `losses` is the tuple
        (generative_pred_loss, relational_loss, consistency_loss, inference_pred_loss).
    """
    structural_codes = self.path_integrator(actions)

    encoded_sensory = self.sensory_encoder(sensory)

    # 1. first have the structure code be able to fetch from the meta memory mlp

    decoded_gen_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))

    decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)

    generative_pred_loss = F.mse_loss(sensory, decoded_sensory)

    # 2. relational

    # 2a. structure from content - query with sensory only (structure slot zeroed)

    decoded_structure, _ = self.retrieve(zeros_like(structural_codes), encoded_sensory)

    structure_from_content_loss = F.mse_loss(decoded_structure, structural_codes)

    # 2b. structure from structure - the mirror of 2a: query with the structural
    # code only (sensory slot zeroed); repeating 2a here would just double-count
    # the content loss

    decoded_structure, _ = self.retrieve(structural_codes, zeros_like(encoded_sensory))

    structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)

    relational_loss = structure_from_content_loss + structure_from_structure_loss

    # 3. consistency - modeling a feedback system from hippocampus to path integration

    corrected_structural_code, corrected_encoded_sensory = self.retrieve(decoded_gen_structure, encoded_sensory)

    # squared L2 error of the sensory readout, one scalar per row (keepdim), used
    # as the extra input feature of the variance mlp; summing squares directly
    # avoids the sqrt-then-square round trip (sqrt has an unbounded derivative at 0)
    sensory_sse = (corrected_encoded_sensory - encoded_sensory).pow(2).sum(dim = -1, keepdim = True)

    pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))

    # move the generated code toward the memory-corrected code, scaled by the
    # (possibly learned) integration ratio and the predicted variance
    inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance

    consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)

    # 4. final inference loss

    _, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))

    decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)

    inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)

    # losses

    total_loss = (
        generative_pred_loss * self.loss_weight_generative +
        relational_loss * self.loss_weight_relational +
        consistency_loss * self.loss_weight_consistency +
        inference_pred_loss * self.loss_weight_inference
    )

    losses = (
        generative_pred_loss,
        relational_loss,
        consistency_loss,
        inference_pred_loss
    )

    return total_loss, losses
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
+
hippoformer/hippoformer.py,sha256=r_kn8kQ8js_Fd5wufj-I8EbE3w8b7SUZm47rUJtt4aY,8329
|
|
3
|
+
hippoformer-0.0.4.dist-info/METADATA,sha256=8geT7mVp0r4WHw3uf860xwWGMpYYX43Rum_PNDUMfmw,2773
|
|
4
|
+
hippoformer-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
hippoformer-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
hippoformer-0.0.4.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
-
hippoformer/hippoformer.py,sha256=yywVJJrrB1IilD_hGALRblBlBhoYGDPIYpwjNCvL3u8,4616
|
|
3
|
-
hippoformer-0.0.2.dist-info/METADATA,sha256=5E0PLeUouF-6iq8Zrw5sSuFP5xdHm0NYGlE7lyo2-ls,2773
|
|
4
|
-
hippoformer-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
hippoformer-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
hippoformer-0.0.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|