hippoformer-0.0.1-py3-none-any.whl → hippoformer-0.0.3-py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
hippoformer/__init__.py CHANGED
@@ -0,0 +1,4 @@
1
+ from hippoformer.hippoformer import (
2
+ PathIntegration,
3
+ mmTEM
4
+ )
hippoformer/hippoformer.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import torch
4
- from torch import nn, Tensor, stack, einsum
4
+ from torch import nn, Tensor, stack, einsum, tensor
5
5
  import torch.nn.functional as F
6
6
  from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
8
+ from torch.func import vmap, grad, functional_call
8
9
 
9
10
  from einops import repeat, rearrange
10
11
  from einops.layers.torch import Rearrange
@@ -104,13 +105,115 @@ class PathIntegration(Module):
104
105
  class mmTEM(Module):
105
106
  def __init__(
106
107
  self,
107
- dim
108
+ dim,
109
+ *,
110
+ sensory_encoder: Module,
111
+ sensory_decoder: Module,
112
+ dim_sensory,
113
+ dim_action,
114
+ dim_encoded_sensory,
115
+ dim_structure,
116
+ meta_mlp_depth = 2,
117
+ decoder_mlp_depth = 2,
118
+ structure_variance_pred_mlp_depth = 2,
119
+ path_integrate_kwargs: dict = dict(),
120
+ loss_weight_generative = 1.,
121
+ loss_weight_inference = 1.,
122
+ loss_weight_consistency = 1.,
123
+ loss_weight_relational = 1.,
108
124
  ):
109
125
  super().__init__()
110
126
 
127
+ # sensory
128
+
129
+ self.sensory_encoder = sensory_encoder
130
+ self.sensory_decoder = sensory_decoder
131
+
132
+ dim_joint_rep = dim_encoded_sensory + dim_structure
133
+
134
+ self.dim_encoded_sensory = dim_encoded_sensory
135
+ self.dim_structure = dim_structure
136
+ self.joint_dims = (dim_structure, dim_encoded_sensory)
137
+
138
+ # path integrator
139
+
140
+ self.path_integrator = PathIntegration(
141
+ dim_action = dim_action,
142
+ dim_structure = dim_structure,
143
+ **path_integrate_kwargs
144
+ )
145
+
146
+ # meta mlp related
147
+
148
+ self.to_queries = nn.Linear(dim_joint_rep, dim, bias = False)
149
+ self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
150
+ self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
151
+
152
+ self.meta_memory_mlp = create_mlp(
153
+ dim = dim * 2,
154
+ depth = meta_mlp_depth,
155
+ dim_in = dim,
156
+ dim_out = dim,
157
+ activation = nn.ReLU()
158
+ )
159
+
160
+ # mlp decoder (from meta mlp output to joint)
161
+
162
+ self.memory_output_decoder = create_mlp(
163
+ dim = dim * 2,
164
+ dim_in = dim,
165
+ dim_out = dim_joint_rep,
166
+ depth = decoder_mlp_depth,
167
+ activation = nn.ReLU()
168
+ )
169
+
170
+ # the mlp that predicts the variance for the structural code
171
+ # for correcting the generated structural code modeling the feedback from HC to MEC
172
+
173
+ self.structure_variance_pred_mlp_depth = create_mlp(
174
+ dim = dim_structure * 2,
175
+ dim_in = dim_structure * 2 + 1,
176
+ dim_out = dim_structure,
177
+ depth = structure_variance_pred_mlp_depth
178
+ )
179
+
180
+ # loss related
181
+
182
+ self.loss_weight_generative = loss_weight_generative
183
+ self.loss_weight_inference = loss_weight_inference
184
+ self.loss_weight_relational = loss_weight_relational
185
+ self.loss_weight_consistency = loss_weight_consistency
186
+ self.register_buffer('zero', tensor(0.), persistent = False)
111
187
 
112
188
  def forward(
113
189
  self,
114
- data
190
+ sensory,
191
+ actions
115
192
  ):
116
- raise NotImplementedError
193
+ structural_codes = self.path_integrator(actions)
194
+
195
+ # first have the structure code be able to fetch from the meta memory mlp
196
+
197
+ structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
198
+
199
+ queries = self.to_queries(structure_codes_with_zero_sensory)
200
+
201
+ retrieved = self.meta_memory_mlp(queries)
202
+
203
+ decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
204
+
205
+ decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
206
+
207
+ generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
208
+
209
+ # losses
210
+
211
+ total_loss = (
212
+ generative_pred_loss * self.loss_weight_generative
213
+ )
214
+
215
+ losses = (
216
+ generative_pred_loss,
217
+ )
218
+
219
+ return total_loss, losses
hippoformer-0.0.1.dist-info/METADATA → hippoformer-0.0.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
hippoformer-0.0.3.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=XnLDm12BHuJSOcb1Y8EKCSqkvdw20anuGT2O4kFsWHg,6041
3
+ hippoformer-0.0.3.dist-info/METADATA,sha256=L1SJ76ffExU4DQen0ppCv3gzt7waXXaE6znPuon60_g,2773
4
+ hippoformer-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.3.dist-info/RECORD,,
hippoformer-0.0.1.dist-info/RECORD REMOVED
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- hippoformer/hippoformer.py,sha256=6tA4ZWYKbzclpeTUhJtr2OguVOyyAGFxuLf9bfnfO_M,2682
3
- hippoformer-0.0.1.dist-info/METADATA,sha256=4hnfh1oIIlcGsIQ7qD7fZHWfM5ltnHhATAPcN-4vkxQ,2773
4
- hippoformer-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.1.dist-info/RECORD,,