hippoformer-0.0.2.tar.gz → hippoformer-0.0.3.tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.2 → hippoformer-0.0.3}/PKG-INFO +1 -1
- {hippoformer-0.0.2 → hippoformer-0.0.3}/hippoformer/hippoformer.py +47 -4
- {hippoformer-0.0.2 → hippoformer-0.0.3}/pyproject.toml +1 -1
- {hippoformer-0.0.2 → hippoformer-0.0.3}/tests/test_hippoformer.py +1 -1
- {hippoformer-0.0.2 → hippoformer-0.0.3}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.2 → hippoformer-0.0.3}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.2 → hippoformer-0.0.3}/.gitignore +0 -0
- {hippoformer-0.0.2 → hippoformer-0.0.3}/LICENSE +0 -0
- {hippoformer-0.0.2 → hippoformer-0.0.3}/README.md +0 -0
- {hippoformer-0.0.2 → hippoformer-0.0.3}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.2 → hippoformer-0.0.3}/hippoformer-fig6.png +0 -0
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import nn, Tensor, stack, einsum
|
|
4
|
+
from torch import nn, Tensor, stack, einsum, tensor
|
|
5
5
|
import torch.nn.functional as F
|
|
6
6
|
from torch.nn import Module
|
|
7
7
|
from torch.jit import ScriptModule, script_method
|
|
8
|
+
from torch.func import vmap, grad, functional_call
|
|
8
9
|
|
|
9
10
|
from einops import repeat, rearrange
|
|
10
11
|
from einops.layers.torch import Rearrange
|
|
@@ -123,8 +124,17 @@ class mmTEM(Module):
|
|
|
123
124
|
):
|
|
124
125
|
super().__init__()
|
|
125
126
|
|
|
127
|
+
# sensory
|
|
128
|
+
|
|
129
|
+
self.sensory_encoder = sensory_encoder
|
|
130
|
+
self.sensory_decoder = sensory_decoder
|
|
131
|
+
|
|
126
132
|
dim_joint_rep = dim_encoded_sensory + dim_structure
|
|
127
133
|
|
|
134
|
+
self.dim_encoded_sensory = dim_encoded_sensory
|
|
135
|
+
self.dim_structure = dim_structure
|
|
136
|
+
self.joint_dims = (dim_structure, dim_encoded_sensory)
|
|
137
|
+
|
|
128
138
|
# path integrator
|
|
129
139
|
|
|
130
140
|
self.path_integrator = PathIntegration(
|
|
@@ -139,7 +149,7 @@ class mmTEM(Module):
|
|
|
139
149
|
self.to_keys = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
140
150
|
self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
|
|
141
151
|
|
|
142
|
-
self.
|
|
152
|
+
self.meta_memory_mlp = create_mlp(
|
|
143
153
|
dim = dim * 2,
|
|
144
154
|
depth = meta_mlp_depth,
|
|
145
155
|
dim_in = dim,
|
|
@@ -149,7 +159,7 @@ class mmTEM(Module):
|
|
|
149
159
|
|
|
150
160
|
# mlp decoder (from meta mlp output to joint)
|
|
151
161
|
|
|
152
|
-
self.
|
|
162
|
+
self.memory_output_decoder = create_mlp(
|
|
153
163
|
dim = dim * 2,
|
|
154
164
|
dim_in = dim,
|
|
155
165
|
dim_out = dim_joint_rep,
|
|
@@ -167,10 +177,43 @@ class mmTEM(Module):
|
|
|
167
177
|
depth = structure_variance_pred_mlp_depth
|
|
168
178
|
)
|
|
169
179
|
|
|
180
|
+
# loss related
|
|
181
|
+
|
|
182
|
+
self.loss_weight_generative = loss_weight_generative
|
|
183
|
+
self.loss_weight_inference = loss_weight_inference
|
|
184
|
+
self.loss_weight_relational = loss_weight_relational
|
|
185
|
+
self.loss_weight_consistency = loss_weight_consistency
|
|
186
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
187
|
+
|
|
170
188
|
def forward(
|
|
171
189
|
self,
|
|
172
190
|
sensory,
|
|
173
191
|
actions
|
|
174
192
|
):
|
|
175
193
|
structural_codes = self.path_integrator(actions)
|
|
176
|
-
|
|
194
|
+
|
|
195
|
+
# first have the structure code be able to fetch from the meta memory mlp
|
|
196
|
+
|
|
197
|
+
structure_codes_with_zero_sensory = F.pad(structural_codes, (0, self.dim_encoded_sensory))
|
|
198
|
+
|
|
199
|
+
queries = self.to_queries(structure_codes_with_zero_sensory)
|
|
200
|
+
|
|
201
|
+
retrieved = self.meta_memory_mlp(queries)
|
|
202
|
+
|
|
203
|
+
decoded_structure, decoded_encoded_sensory = self.memory_output_decoder(retrieved).split(self.joint_dims, dim = -1)
|
|
204
|
+
|
|
205
|
+
decoded_sensory = self.sensory_decoder(decoded_encoded_sensory)
|
|
206
|
+
|
|
207
|
+
generative_pred_loss = F.mse_loss(sensory, decoded_sensory)
|
|
208
|
+
|
|
209
|
+
# losses
|
|
210
|
+
|
|
211
|
+
total_loss = (
|
|
212
|
+
generative_pred_loss * self.loss_weight_generative
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
losses = (
|
|
216
|
+
generative_pred_loss,
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
return total_loss, losses
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|