hippoformer 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import torch
4
- from torch import nn, Tensor, stack, einsum
4
+ from torch import nn, Tensor, stack, einsum, tensor
5
5
  import torch.nn.functional as F
6
6
  from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
8
+ from torch.func import vmap, grad, functional_call
8
9
 
9
10
  from einops import repeat, rearrange
10
11
  from einops.layers.torch import Rearrange
@@ -123,8 +124,17 @@ class mmTEM(Module):
123
124
  ):
124
125
  super().__init__()
125
126
 
127
+ # sensory
128
+
129
+ self.sensory_encoder = sensory_encoder
130
+ self.sensory_decoder = sensory_decoder
131
+
126
132
  dim_joint_rep = dim_encoded_sensory + dim_structure
127
133
 
134
+ self.dim_encoded_sensory = dim_encoded_sensory
135
+ self.dim_structure = dim_structure
136
+ self.joint_dims = (dim_structure, dim_encoded_sensory)
137
+
128
138
  # path integrator
129
139
 
130
140
  self.path_integrator = PathIntegration(
@@ -139,7 +149,7 @@ class mmTEM(Module):
139
149
  self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
140
150
  self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
141
151
 
142
- self.meta_mlp = create_mlp(
152
+ self.meta_memory_mlp = create_mlp(
143
153
  dim = dim * 2,
144
154
  depth = meta_mlp_depth,
145
155
  dim_in = dim,
@@ -149,7 +159,7 @@ class mmTEM(Module):
149
159
 
150
160
  # mlp decoder (from meta mlp output to joint)
151
161
 
152
- self.meta_mlp_output_decoder = create_mlp(
162
+ self.memory_output_decoder = create_mlp(
153
163
  dim = dim * 2,
154
164
  dim_in = dim,
155
165
  dim_out = dim_joint_rep,
@@ -167,10 +177,43 @@ class mmTEM(Module):
167
177
  depth = structure_variance_pred_mlp_depth
168
178
  )
169
179
 
180
+ # loss related
181
+
182
+ self.loss_weight_generative = loss_weight_generative
183
+ self.loss_weight_inference = loss_weight_inference
184
+ self.loss_weight_relational = loss_weight_relational
185
+ self.loss_weight_consistency = loss_weight_consistency
186
+ self.register_buffer('zero', tensor(0.), persistent = False)
187
+
170
188
  def forward(
171
189
  self,
172
190
  sensory,
173
191
  actions
174
192
  ):
175
193
  structural_codes = self.path_integrator(actions)
176
- return structural_codes.sum()
194
+
195
+ # first have the structure code be able to fetch from the meta memory mlp
196
+
197
+ structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
198
+
199
+ queries = self.to_queries(structure_codes_with_zero_sensory)
200
+
201
+ retrieved = self.meta_memory_mlp(queries)
202
+
203
+ decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
204
+
205
+ decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
206
+
207
+ generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
208
+
209
+ # losses
210
+
211
+ total_loss = (
212
+ generative_pred_loss * self.loss_weight_generative
213
+ )
214
+
215
+ losses = (
216
+ generative_pred_loss,
217
+ )
218
+
219
+ return total_loss, losses
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=XnLDm12BHuJSOcb1Y8EKCSqkvdw20anuGT2O4kFsWHg,6041
3
+ hippoformer-0.0.3.dist-info/METADATA,sha256=L1SJ76ffExU4DQen0ppCv3gzt7waXXaE6znPuon60_g,2773
4
+ hippoformer-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.3.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=yywVJJrrB1IilD_hGALRblBlBhoYGDPIYpwjNCvL3u8,4616
3
- hippoformer-0.0.2.dist-info/METADATA,sha256=5E0PLeUouF-6iq8Zrw5sSuFP5xdHm0NYGlE7lyo2-ls,2773
4
- hippoformer-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.2.dist-info/RECORD,,