hippoformer 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.2
+Version: 0.0.4
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -1,10 +1,11 @@
 from __future__ import annotations

 import torch
-from torch import nn, Tensor, stack, einsum
+from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.jit import ScriptModule, script_method
+from torch.func import vmap, grad, functional_call

 from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
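
The `torch.func` utilities added above (`vmap`, `grad`, `functional_call`) are not referenced by any hunk in this diff, so they are presumably staged for later work, likely treating the meta memory MLP's weights as a learnable fast-weight store. As a hedged illustration only, `functional_call` runs a module under an explicit parameter dict (the `mlp`, `params`, and `x` names here are hypothetical, not from the package):

    import torch
    from torch import nn
    from torch.func import functional_call

    mlp = nn.Linear(4, 4)

    # an explicit copy of the parameters, decoupled from the module
    params = {name: p.clone() for name, p in mlp.named_parameters()}

    x = torch.randn(2, 4)

    # evaluates mlp with `params` substituted for its registered parameters
    out = functional_call(mlp, params, (x,))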
@@ -120,11 +121,21 @@ class mmTEM(Module):
         loss_weight_inference = 1.,
         loss_weight_consistency = 1.,
         loss_weight_relational = 1.,
+        integration_ratio_learned = True
     ):
         super().__init__()

+        # sensory
+
+        self.sensory_encoder = sensory_encoder
+        self.sensory_decoder = sensory_decoder
+
         dim_joint_rep = dim_encoded_sensory + dim_structure

+        self.dim_encoded_sensory = dim_encoded_sensory
+        self.dim_structure = dim_structure
+        self.joint_dims = (dim_structure, dim_encoded_sensory)
+
         # path integrator

         self.path_integrator = PathIntegration(
@@ -139,7 +150,7 @@ class mmTEM(Module):
         self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

-        self.meta_mlp = create_mlp(
+        self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
             depth = meta_mlp_depth,
             dim_in = dim,
@@ -149,7 +160,7 @@ class mmTEM(Module):

         # mlp decoder (from meta mlp output to joint)

-        self.meta_mlp_output_decoder = create_mlp(
+        self.memory_output_decoder = create_mlp(
             dim = dim * 2,
             dim_in = dim,
             dim_out = dim_joint_rep,
@@ -160,17 +171,105 @@ class mmTEM(Module):
         # the mlp that predicts the variance for the structural code
         # for correcting the generated structural code, modeling the feedback from HC to MEC

-        self.structure_variance_pred_mlp_depth = create_mlp(
+        self.structure_variance_pred_mlp = create_mlp(
             dim = dim_structure * 2,
             dim_in = dim_structure * 2 + 1,
             dim_out = dim_structure,
             depth = structure_variance_pred_mlp_depth
         )

+        # loss related
+
+        self.loss_weight_generative = loss_weight_generative
+        self.loss_weight_inference = loss_weight_inference
+        self.loss_weight_relational = loss_weight_relational
+        self.loss_weight_consistency = loss_weight_consistency
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+        # there is an integration ratio for error correction, but it is unclear what value it is fixed to or whether it is learned
+
+        self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
+
+    def retrieve(
+        self,
+        structural_codes,
+        encoded_sensory
+    ):
+        joint = cat((structural_codes, encoded_sensory), dim = -1)
+
+        queries = self.to_queries(joint)
+
+        retrieved = self.meta_memory_mlp(queries)
+
+        return self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
+
     def forward(
         self,
         sensory,
         actions
     ):
         structural_codes = self.path_integrator(actions)
-        return structural_codes.sum()
+
+        encoded_sensory = self.sensory_encoder(sensory)
+
+        # 1. first have the structure code be able to fetch from the meta memory mlp
+
+        decoded_gen_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
+
+        decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
+
+        generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
+
+        # 2. relational
+
+        # 2a. structure from content
+
+        decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
+
+        structure_from_content_loss = F.mse_loss(decoded_structure, structural_codes)
+
+        # 2b. structure from structure
+
+        decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
+
+        structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
+
+        relational_loss = structure_from_content_loss + structure_from_structure_loss
+
+        # 3. consistency - modeling a feedback system from hippocampus to path integration
+
+        corrected_structural_code, corrected_encoded_sensory = self.retrieve(decoded_gen_structure, encoded_sensory)
+
+        sensory_sse = (corrected_encoded_sensory - encoded_sensory).norm(dim = -1, keepdim = True).pow(2)
+
+        pred_variance = self.structure_variance_pred_mlp(cat((corrected_structural_code, decoded_gen_structure, sensory_sse), dim = -1))
+
+        inf_structural_code = decoded_gen_structure + (corrected_structural_code - decoded_gen_structure) * self.integration_ratio * pred_variance
+
+        consistency_loss = F.mse_loss(decoded_gen_structure, inf_structural_code)
+
+        # 4. final inference loss
+
+        _, inf_encoded_sensory = self.retrieve(inf_structural_code, zeros_like(encoded_sensory))
+
+        decoded_inf_sensory = self.sensory_decoder(inf_encoded_sensory)
+
+        inference_pred_loss = F.mse_loss(sensory, decoded_inf_sensory)
+
+        # losses
+
+        total_loss = (
+            generative_pred_loss * self.loss_weight_generative +
+            relational_loss * self.loss_weight_relational +
+            consistency_loss * self.loss_weight_consistency +
+            inference_pred_loss * self.loss_weight_inference
+        )
+
+        losses = (
+            generative_pred_loss,
+            relational_loss,
+            consistency_loss,
+            inference_pred_loss
+        )
+
+        return total_loss, losses
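
The step-3 update above interpolates the path-integrated structural code toward the code retrieved from memory, gated by the learned scalar `integration_ratio` and the MLP-predicted variance. A minimal standalone sketch of that update rule, with hypothetical shapes:

    import torch

    gen = torch.randn(2, 8)        # stands in for decoded_gen_structure
    retrieved = torch.randn(2, 8)  # stands in for corrected_structural_code
    ratio = torch.tensor(0.)       # self.integration_ratio, initialized at zero
    variance = torch.rand(2, 8)    # stands in for structure_variance_pred_mlp output

    # same form as inf_structural_code in forward(): a gate of zero keeps the
    # generated code; a larger gate moves it toward the retrieved code
    inf = gen + (retrieved - gen) * ratio * variance

Note that with `integration_ratio` initialized at 0., the corrected code starts out equal to the generated one, so the consistency loss begins at zero and the gate has to be learned before the feedback pathway has any effect.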
@@ -1,6 +1,6 @@
 [project]
 name = "hippoformer"
-version = "0.0.2"
+version = "0.0.4"
 description = "hippoformer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -33,5 +33,5 @@ def test_mm_tem():
     actions = torch.randn(2, 16, 7)
     sensory = torch.randn(2, 16, 11)

-    loss = model(sensory, actions)
+    loss, losses = model(sensory, actions)
     loss.backward()
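
The updated test reflects the new return contract of `mmTEM.forward`: a weighted total loss plus the four unweighted component losses, in the order generative, relational, consistency, inference. Unpacking them explicitly (reusing `model`, `sensory`, and `actions` from the test above):

    # total_loss is the weighted sum; the tuple carries the raw components
    total_loss, (gen_loss, rel_loss, cons_loss, inf_loss) = model(sensory, actions)

    total_loss.backward()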