egogym-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/bet/vqvae/vqvae.py
@@ -0,0 +1,313 @@
+import torch
+import torch.nn as nn
+import einops
+from baselines.rum.models.bet.vqvae.residual_vq import ResidualVQ
+from baselines.rum.models.bet.vqvae.vqvae_utils import get_tensor, weights_init_encoder
+
+
+class EncoderMLP(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        output_dim=16,
+        hidden_dim=128,
+        layer_num=1,
+        last_activation=None,
+    ):
+        super(EncoderMLP, self).__init__()
+        layers = []
+
+        layers.append(nn.Linear(input_dim, hidden_dim))
+        layers.append(nn.ReLU())
+        for _ in range(layer_num):
+            layers.append(nn.Linear(hidden_dim, hidden_dim))
+            layers.append(nn.ReLU())
+
+        self.encoder = nn.Sequential(*layers)
+        self.fc = nn.Linear(hidden_dim, output_dim)
+
+        if last_activation is not None:
+            self.last_layer = last_activation
+        else:
+            self.last_layer = None
+        self.apply(weights_init_encoder)
+
+    def forward(self, x):
+        h = self.encoder(x)
+        state = self.fc(h)
+        if self.last_layer:
+            state = self.last_layer(state)
+        return state
+
+
+class CondiitonalEncoderMLP(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        output_dim=16,
+        hidden_dim=512,
+        layer_num=2,
+        last_activation=None,
+        obs_dim=None,
+    ):
+        super(CondiitonalEncoderMLP, self).__init__()
+        layers = []
+
+        layers.append(nn.Linear(input_dim + obs_dim, hidden_dim))
+        layers.append(nn.ReLU())
+        for _ in range(layer_num):
+            layers.append(nn.Linear(hidden_dim, hidden_dim))
+            layers.append(nn.ReLU())
+
+        self.encoder = nn.Sequential(*layers)
+        self.fc = nn.Linear(hidden_dim, output_dim)
+
+        if last_activation is not None:
+            self.last_layer = last_activation
+        else:
+            self.last_layer = None
+        self.apply(weights_init_encoder)
+
+    def forward(self, x, obs):
+        x = torch.cat((x, obs), dim=1)
+        h = self.encoder(x)
+        state = self.fc(h)
+        if self.last_layer:
+            state = self.last_layer(state)
+        return state
+
+
+class VqVae(nn.Module):
+    def __init__(
+        self,
+        obs_dim=60,
+        input_dim_h=10,  # length of action chunk
+        input_dim_w=9,  # action dim
+        n_latent_dims=512,
+        vqvae_n_embed=32,
+        vqvae_groups=4,
+        eval=True,
+        device="cuda",
+        load_dir=None,
+        enc_loss_type="skip_vqlayer",
+        obs_cond=False,
+        encoder_loss_multiplier=1.0,
+        act_scale=1.0,
+    ):
+        super(VqVae, self).__init__()
+        self.n_latent_dims = n_latent_dims  # 64
+        self.input_dim_h = input_dim_h
+        self.input_dim_w = input_dim_w
+        self.rep_dim = self.n_latent_dims
+        self.vqvae_n_embed = vqvae_n_embed  # 120
+        self.vqvae_lr = 1e-3
+        self.vqvae_groups = vqvae_groups
+        self.device = device
+        self.enc_loss_type = enc_loss_type
+        self.obs_cond = obs_cond
+        self.encoder_loss_multiplier = encoder_loss_multiplier
+        self.act_scale = act_scale
+
+        discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed}
+        self.vq_layer = ResidualVQ(
+            dim=self.n_latent_dims,
+            num_quantizers=discrete_cfg["groups"],
+            codebook_size=self.vqvae_n_embed,
+            eval=eval,
+        ).to(self.device)
+        self.embedding_dim = self.n_latent_dims
+        self.vq_layer.device = device
+
+        if self.input_dim_h == 1:
+            if self.obs_cond:
+                self.encoder = CondiitonalEncoderMLP(
+                    input_dim=input_dim_w, output_dim=n_latent_dims, obs_dim=obs_dim
+                ).to(self.device)
+                self.decoder = CondiitonalEncoderMLP(
+                    input_dim=n_latent_dims, output_dim=input_dim_w, obs_dim=obs_dim
+                ).to(self.device)
+            else:
+                self.encoder = EncoderMLP(
+                    input_dim=input_dim_w, output_dim=n_latent_dims
+                ).to(self.device)
+                self.decoder = EncoderMLP(
+                    input_dim=n_latent_dims, output_dim=input_dim_w
+                ).to(self.device)
+        else:
+            if self.obs_cond:
+                self.encoder = CondiitonalEncoderMLP(
+                    input_dim=input_dim_w * self.input_dim_h,
+                    output_dim=n_latent_dims,
+                    obs_dim=obs_dim,
+                ).to(self.device)
+                self.decoder = CondiitonalEncoderMLP(
+                    input_dim=n_latent_dims,
+                    output_dim=input_dim_w * self.input_dim_h,
+                    obs_dim=obs_dim,
+                ).to(self.device)
+            else:
+                self.encoder = EncoderMLP(
+                    input_dim=input_dim_w * self.input_dim_h, output_dim=n_latent_dims
+                ).to(self.device)
+                self.decoder = EncoderMLP(
+                    input_dim=n_latent_dims, output_dim=input_dim_w * self.input_dim_h
+                ).to(self.device)
+
+        # params = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.vq_layer.parameters())
+        # self.vqvae_optimizer = torch.optim.Adam(params, lr=self.vqvae_lr, weight_decay=0.0001)
+
+        if load_dir is not None:
+            try:
+                state_dict = torch.load(load_dir)
+            except RuntimeError:
+                state_dict = torch.load(load_dir, map_location=torch.device("cpu"))
+
+            new_dict = {}
+
+            prefix_to_remove = "_rvq."
+
+            for key, value in state_dict["loss_fn"].items():
+                if key.startswith(prefix_to_remove):
+                    new_key = key[len(prefix_to_remove) :]  # Remove the prefix
+                    new_dict[new_key] = value
+                else:
+                    new_dict[key] = value
+
+            self.load_state_dict(new_dict, strict=False)
+
+        if eval:
+            self.vq_layer.eval()
+        else:
+            self.vq_layer.train()
+
+    def draw_logits_forward(self, encoding_logits):
+        z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
+        return z_embed
+
+    def draw_code_forward(self, encoding_indices):
+        with torch.no_grad():
+            z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
+            z_embed = z_embed.sum(dim=0)
+        return z_embed
+
+    def get_action_from_latent(self, latent, obs=None):
+        if self.obs_cond:
+            output = self.decoder(latent, obs[:, -1]) * self.act_scale
+        else:
+            output = self.decoder(latent) * self.act_scale
+        if self.input_dim_h == 1:
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
+        else:
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
+
+    def preprocess(self, state):
+        if not torch.is_tensor(state):
+            state = get_tensor(state, self.device)
+        if self.input_dim_h == 1:
+            state = state.squeeze(-2)  # state.squeeze(-1)
+        else:
+            state = einops.rearrange(state, "N T A -> N (T A)")
+        return state.to(self.device)
+
+    def get_code(self, state, obs=None, required_recon=False):
+        state = state / self.act_scale
+        state = self.preprocess(state)
+        with torch.no_grad():
+            if self.obs_cond:
+                state_rep = self.encoder(state, obs[:, -1])
+            else:
+                state_rep = self.encoder(state)
+            state_rep_shape = state_rep.shape[:-1]
+            state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
+            state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
+            state_vq = state_rep_flat.view(*state_rep_shape, -1)
+            vq_code = vq_code.view(*state_rep_shape, -1)
+            vq_loss_state = torch.sum(vq_loss_state)
+            if required_recon:
+                if self.obs_cond:
+                    recon_state = self.decoder(state_vq, obs[:, -1]) * self.act_scale
+                    recon_state_ae = (
+                        self.decoder(state_rep, obs[:, -1]) * self.act_scale
+                    )
+                else:
+                    recon_state = self.decoder(state_vq) * self.act_scale
+                    recon_state_ae = self.decoder(state_rep) * self.act_scale
+                if self.input_dim_h == 1:
+                    return state_vq, vq_code, recon_state, recon_state_ae
+                else:
+                    return (
+                        state_vq,
+                        vq_code,
+                        torch.swapaxes(recon_state, -2, -1),
+                        torch.swapaxes(recon_state_ae, -2, -1),
+                    )
+            else:
+                return state_vq, vq_code
+
+    def vqvae_update(self, state, obs=None):
+        state = state / self.act_scale
+        state = self.preprocess(state)
+        if self.obs_cond:
+            state_rep = self.encoder(state, obs[:, -1])
+        else:
+            state_rep = self.encoder(state)
+        state_rep_shape = state_rep.shape[:-1]
+        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
+        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
+        state_vq = state_rep_flat.view(*state_rep_shape, -1)
+        vq_code = vq_code.view(*state_rep_shape, -1)
+        vq_loss_state = torch.sum(vq_loss_state)
+
+        if self.obs_cond:
+            dec_out = self.decoder(state_vq, obs[:, -1])
+        else:
+            dec_out = self.decoder(state_vq)
+        encoder_loss = (state - dec_out).abs().mean()
+
+        rep_loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5)
+
+        # self.vqvae_optimizer.zero_grad()
+        # rep_loss.backward()
+        # self.vqvae_optimizer.step()
+        vqvae_recon_loss = torch.nn.MSELoss()(state, dec_out)
+        loss_dict = {
+            "rep_loss": rep_loss.detach().cpu().item(),
+            "vq_loss_state": vq_loss_state.detach().cpu().item(),
+            "vqvae_recon_loss_l1": encoder_loss.detach().cpu().item(),
+            "vqvae_recon_loss_l2": vqvae_recon_loss.detach().cpu().item(),
+            "n_different_codes": len(torch.unique(vq_code)),
+            "n_different_combinations": len(torch.unique(vq_code, dim=0)),
+        }
+        return rep_loss, loss_dict
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas):
+        params = (
+            list(self.encoder.parameters())
+            + list(self.decoder.parameters())
+            + list(self.vq_layer.parameters())
+        )
+        optimizer = torch.optim.AdamW(
+            params,
+            weight_decay=weight_decay,
+            lr=learning_rate,
+            betas=betas,
+        )
+        return optimizer
+
+    def _begin_epoch(self, optimizer, **kwargs):
+        # log codebook usage rate for debugging
+        # lr_0 = optimizer.param_groups[0]["lr"]
+        # lr_neg1 = optimizer.param_groups[-1]["lr"]
+        # return {"lr_0": lr_0, "lr_neg1": lr_neg1}
+        return None
+
+    # def state_dict(self):
+    #     return {'encoder': self.encoder.state_dict(),
+    #             'decoder': self.decoder.state_dict(),
+    #             'vq_embedding': self.vq_layer.state_dict()}
+
+    # def load_state_dict(self, state_dict):
+    #     self.encoder.load_state_dict(state_dict['encoder'])
+    #     self.decoder.load_state_dict(state_dict['decoder'])
+    #     self.vq_layer.load_state_dict(state_dict['vq_embedding'])
+    #     self.vq_layer.eval()
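A minimal usage sketch for the VqVae class above (illustrative, not part of the packaged code). It assumes the egogym wheel and its torch/einops dependencies are installed and that ResidualVQ returns the (quantized, codes, loss) triple that get_code unpacks; shapes follow the defaults input_dim_h=10 and input_dim_w=9.

import torch

from baselines.rum.models.bet.vqvae.vqvae import VqVae

# A batch of 8 action chunks, each 10 steps of 9-dim actions.
vqvae = VqVae(input_dim_h=10, input_dim_w=9, device="cpu", eval=True)

actions = torch.randn(8, 10, 9)               # (N, T, A)
latent, codes = vqvae.get_code(actions)       # quantized latent and discrete codes
recon = vqvae.get_action_from_latent(latent)  # decoded back to (N, T, A) chunks
print(latent.shape, codes.shape, recon.shape)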
baselines/rum/models/bet/vqvae/vqvae_utils.py
@@ -0,0 +1,30 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import os.path as osp
+
+
+def weights_init_encoder(m):
+    if isinstance(m, nn.Linear):
+        nn.init.orthogonal_(m.weight.data)
+        m.bias.data.fill_(0.0)
+    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+        assert m.weight.size(2) == m.weight.size(3)
+        m.weight.data.fill_(0.0)
+        m.bias.data.fill_(0.0)
+        mid = m.weight.size(2) // 2
+        gain = nn.init.calculate_gain("relu")
+        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
+
+
+def get_tensor(z, device):
+    if z is None:
+        return None
+    if z[0].dtype == np.dtype("O"):
+        return None
+    if len(z.shape) == 1:
+        return torch.FloatTensor(z.copy()).to(device).unsqueeze(0)
+        # return torch.from_numpy(z.copy()).float().to(device).unsqueeze(0)
+    else:
+        return torch.FloatTensor(z.copy()).to(device)
+        # return torch.from_numpy(z.copy()).float().to(device)
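A short sketch of how these helpers are used (vqvae.py above applies weights_init_encoder via self.apply); the module and shapes below are illustrative, not part of the package.

import numpy as np
import torch.nn as nn

from baselines.rum.models.bet.vqvae.vqvae_utils import get_tensor, weights_init_encoder

mlp = nn.Sequential(nn.Linear(9, 32), nn.ReLU(), nn.Linear(32, 9))
mlp.apply(weights_init_encoder)  # orthogonal init for every nn.Linear

t = get_tensor(np.zeros(9, dtype=np.float32), "cpu")
print(t.shape)  # 1-D inputs gain a batch dimension -> torch.Size([1, 9])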
baselines/rum/models/custom.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+class CustomModel(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        # TODO: set model
+        self.model = None
+
+    def to(self, device):
+        # TODO: move model to device
+        if self.model is not None:
+            self.model.to(device)
+        return self
+
+    def forward(self, x):
+        # TODO: forward pass of model
+        pass
+
+    def step(self, data, *args, **kwargs):
+        images, actions = data
+        # images is a tensor of shape (1, image_buffer_size, 3, 256, 256), where the images are in chronological order
+        # actions is a tensor of shape (1, image_buffer_size, 7), where the actions are in chronological order, and the final action is padding
+
+        # TODO: implement a pass that takes in the observations described above and outputs a 7-dimensional action
+        action = torch.zeros(7)
+        logs = {}
+
+        return action, logs
+
+    def reset(self):
+        # TODO: optional; this method is called once the robot is homed
+        pass
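One hypothetical way the TODOs above could be filled in, following the documented interface (images of shape (1, image_buffer_size, 3, 256, 256), actions of shape (1, image_buffer_size, 7), and a 7-dimensional action returned from step). This toy policy is a sketch, not part of the package.

import torch
import torch.nn as nn


class ToyCustomModel(nn.Module):
    """Illustrative fill-in for the CustomModel template above."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Small conv encoder followed by a linear head that emits a 7-dim action.
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 7),
        )

    def to(self, device):
        self.model.to(device)
        return self

    def forward(self, x):
        return self.model(x)

    def step(self, data, *args, **kwargs):
        images, actions = data                  # (1, buffer, 3, 256, 256), (1, buffer, 7)
        latest_frame = images[:, -1]            # act on the most recent observation
        action = self.forward(latest_frame)[0]  # -> (7,)
        logs = {}
        return action, logs

    def reset(self):
        pass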
baselines/rum/models/encoders/__init__.py
File without changes
baselines/rum/models/encoders/abstract_base_encoder.py
@@ -0,0 +1,70 @@
+import pathlib
+import warnings
+from abc import ABC, abstractmethod
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+ENCODER_LOADING_ERROR_MSG = (
+    "Could not load encoder weights: defaulting to pretrained weights"
+)
+
+
+class AbstractEncoder(nn.Module, ABC):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    @abstractmethod
+    def feature_dim(self):
+        pass
+
+    def transform(self, x):
+        return x
+
+    def to(self, device):
+        self.device = device
+        return super().to(device)
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    def load_weights(
+        self, weight_path: Union[pathlib.Path, str], strict: bool = True
+    ) -> None:
+        try:
+            checkpoint = torch.load(weight_path, map_location="cpu")
+            weight_dict = checkpoint.get("model", checkpoint)
+
+            self.model.load_state_dict(weight_dict, strict=strict)
+            return
+
+        except RuntimeError:
+            warnings.warn("Couldn't load full model weights. Trying fallback options...")
+
+        fallback_prefixes = [
+            "model.",
+            "encoder.",
+            "module.base_encoder._orig_mod.",
+        ]
+
+        for prefix in fallback_prefixes:
+            try:
+                filtered_weights = {
+                    k.replace(prefix, ""): v
+                    for k, v in weight_dict.items()
+                    if k.startswith(prefix)
+                }
+
+                filtered_weights.pop("fc.weight", None)
+                filtered_weights.pop("fc.bias", None)
+
+                self.model.load_state_dict(filtered_weights, strict=strict)
+                return
+
+            except RuntimeError:
+                warnings.warn(f"Failed loading with prefix: '{prefix}'")
+
+        warnings.warn(ENCODER_LOADING_ERROR_MSG)
baselines/rum/models/encoders/identity.py
@@ -0,0 +1,45 @@
+"""
+A generic model wrapper for dummy identity encoders.
+"""
+
+import pathlib
+from typing import Union
+from torch import nn
+import torch
+from .abstract_base_encoder import AbstractEncoder
+
+
+class IdentityEncoder(AbstractEncoder):
+    def __init__(
+        self,
+        model_name,
+        pretrained: bool = True,
+        weight_path: Union[None, str, pathlib.Path] = None,
+    ):
+        super().__init__()
+        self._model_name = model_name
+        self.device = None
+        self.model = nn.Identity()
+
+        # Use a placeholder parameter if the model has no parameters
+        if len(list(self.model.parameters())) == 0:
+            self.placeholder_param = nn.Parameter(torch.zeros(1, requires_grad=True))
+        else:
+            self.placeholder_param = None
+
+    def transform(self, x):
+        return x
+
+    @property
+    def feature_dim(self):
+        return 0
+
+    def to(self, device):
+        self.device = device
+        self.model.to(device)
+        if self.placeholder_param is not None:
+            self.placeholder_param.to(device)
+        return self
+
+    def forward(self, x):
+        return self.transform(x)
baselines/rum/models/encoders/timm_encoders.py
@@ -0,0 +1,82 @@
+"""
+A generic model wrapper for timm encoders.
+"""
+
+import pathlib
+from typing import Iterable, Optional, Union
+import warnings
+
+import einops
+import timm
+import torch
+import torch.nn as nn
+
+from baselines.rum.models.encoders.abstract_base_encoder import AbstractEncoder
+from baselines.rum.utils.decord_transforms import create_transform
+
+
+class TimmModel(AbstractEncoder):
+    def __init__(
+        self,
+        model_name: str = "hf-hub:notmahi/dobb-e",
+        pretrained: bool = True,
+        weight_path: Union[None, str, pathlib.Path] = None,
+    ):
+        super().__init__()
+        self._model_name = model_name
+
+        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
+        if weight_path:
+            self.load_weights(weight_path, strict=False)
+
+    def transform(self, x):
+        return x
+
+    @property
+    def feature_dim(self):
+        return self.model.num_features
+
+    def to(self, device):
+        self.model.to(device)
+        return self
+
+    def forward(self, x):
+        return self.model(self.transform(x))
+
+
+class TimmSSL(TimmModel):
+    def __init__(
+        self,
+        model_name: str = "hf-hub:notmahi/dobb-e",
+        pretrained=True,
+        override_aug_kwargs=dict(
+            hflip=0.0, vflip=0.0, scale=(1.0, 1.0), crop_pct=0.875
+        ),
+        weight_path: Union[None, str, pathlib.Path] = None,
+    ):
+        super().__init__(model_name, pretrained=pretrained)
+        data_cfg = timm.data.resolve_data_config(self.model.pretrained_cfg)
+        # Now define the transforms.
+        data_cfg.update(override_aug_kwargs)
+        data_cfg["is_training"] = True
+        self._train_transform = create_transform(**data_cfg)
+        data_cfg["is_training"] = False
+        self._test_transform = create_transform(**data_cfg)
+        if weight_path is not None:
+            self.load_weights(weight_path, strict=True)
+
+    def transform(self, x):
+        return (
+            self._train_transform(x) if self.model.training else self._test_transform(x)
+        )
+
+    def forward(self, x):
+        # Split the input into frames and labels.
+        images, *labels = x
+        # Flatten the frames into a single batch.
+        flattened_images = einops.rearrange(images, "b t c h w -> (b t) c h w")
+        # Transform and pass through the model.
+        result = self.model(self.transform(flattened_images))
+        # Unflatten the result.
+        result = einops.rearrange(result, "(b t) c -> b t c", b=images.shape[0])
+        return result
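A small usage sketch for TimmModel (illustrative, not part of the package). "resnet18" stands in for the default hf-hub:notmahi/dobb-e checkpoint, and pretrained=False keeps the example offline; TimmSSL.forward instead expects a tuple whose first element is a batch of frame sequences shaped (b, t, c, h, w).

import torch

from baselines.rum.models.encoders.timm_encoders import TimmModel

encoder = TimmModel(model_name="resnet18", pretrained=False).to("cpu")

frames = torch.randn(4, 3, 224, 224)
features = encoder(frames)
print(encoder.feature_dim, features.shape)  # 512, torch.Size([4, 512])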