egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/bet/vqvae/vqvae.py
@@ -0,0 +1,313 @@
+ import torch
+ import torch.nn as nn
+ import einops
+ from baselines.rum.models.bet.vqvae.residual_vq import ResidualVQ
+ from baselines.rum.models.bet.vqvae.vqvae_utils import get_tensor, weights_init_encoder
+
+
+ class EncoderMLP(nn.Module):
+     def __init__(
+         self,
+         input_dim,
+         output_dim=16,
+         hidden_dim=128,
+         layer_num=1,
+         last_activation=None,
+     ):
+         super(EncoderMLP, self).__init__()
+         layers = []
+
+         layers.append(nn.Linear(input_dim, hidden_dim))
+         layers.append(nn.ReLU())
+         for _ in range(layer_num):
+             layers.append(nn.Linear(hidden_dim, hidden_dim))
+             layers.append(nn.ReLU())
+
+         self.encoder = nn.Sequential(*layers)
+         self.fc = nn.Linear(hidden_dim, output_dim)
+
+         if last_activation is not None:
+             self.last_layer = last_activation
+         else:
+             self.last_layer = None
+         self.apply(weights_init_encoder)
+
+     def forward(self, x):
+         h = self.encoder(x)
+         state = self.fc(h)
+         if self.last_layer:
+             state = self.last_layer(state)
+         return state
+
+
+ class CondiitonalEncoderMLP(nn.Module):
+     def __init__(
+         self,
+         input_dim,
+         output_dim=16,
+         hidden_dim=512,
+         layer_num=2,
+         last_activation=None,
+         obs_dim=None,
+     ):
+         super(CondiitonalEncoderMLP, self).__init__()
+         layers = []
+
+         layers.append(nn.Linear(input_dim + obs_dim, hidden_dim))
+         layers.append(nn.ReLU())
+         for _ in range(layer_num):
+             layers.append(nn.Linear(hidden_dim, hidden_dim))
+             layers.append(nn.ReLU())
+
+         self.encoder = nn.Sequential(*layers)
+         self.fc = nn.Linear(hidden_dim, output_dim)
+
+         if last_activation is not None:
+             self.last_layer = last_activation
+         else:
+             self.last_layer = None
+         self.apply(weights_init_encoder)
+
+     def forward(self, x, obs):
+         x = torch.cat((x, obs), dim=1)
+         h = self.encoder(x)
+         state = self.fc(h)
+         if self.last_layer:
+             state = self.last_layer(state)
+         return state
+
+
+ class VqVae(nn.Module):
+     def __init__(
+         self,
+         obs_dim=60,
+         input_dim_h=10,  # length of action chunk
+         input_dim_w=9,  # action dim
+         n_latent_dims=512,
+         vqvae_n_embed=32,
+         vqvae_groups=4,
+         eval=True,
+         device="cuda",
+         load_dir=None,
+         enc_loss_type="skip_vqlayer",
+         obs_cond=False,
+         encoder_loss_multiplier=1.0,
+         act_scale=1.0,
+     ):
+         super(VqVae, self).__init__()
+         self.n_latent_dims = n_latent_dims  # 64
+         self.input_dim_h = input_dim_h
+         self.input_dim_w = input_dim_w
+         self.rep_dim = self.n_latent_dims
+         self.vqvae_n_embed = vqvae_n_embed  # 120
+         self.vqvae_lr = 1e-3
+         self.vqvae_groups = vqvae_groups
+         self.device = device
+         self.enc_loss_type = enc_loss_type
+         self.obs_cond = obs_cond
+         self.encoder_loss_multiplier = encoder_loss_multiplier
+         self.act_scale = act_scale
+
+         discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed}
+         self.vq_layer = ResidualVQ(
+             dim=self.n_latent_dims,
+             num_quantizers=discrete_cfg["groups"],
+             codebook_size=self.vqvae_n_embed,
+             eval=eval,
+         ).to(self.device)
+         self.embedding_dim = self.n_latent_dims
+         self.vq_layer.device = device
+
+         if self.input_dim_h == 1:
+             if self.obs_cond:
+                 self.encoder = CondiitonalEncoderMLP(
+                     input_dim=input_dim_w, output_dim=n_latent_dims, obs_dim=obs_dim
+                 ).to(self.device)
+                 self.decoder = CondiitonalEncoderMLP(
+                     input_dim=n_latent_dims, output_dim=input_dim_w, obs_dim=obs_dim
+                 ).to(self.device)
+             else:
+                 self.encoder = EncoderMLP(
+                     input_dim=input_dim_w, output_dim=n_latent_dims
+                 ).to(self.device)
+                 self.decoder = EncoderMLP(
+                     input_dim=n_latent_dims, output_dim=input_dim_w
+                 ).to(self.device)
+         else:
+             if self.obs_cond:
+                 self.encoder = CondiitonalEncoderMLP(
+                     input_dim=input_dim_w * self.input_dim_h,
+                     output_dim=n_latent_dims,
+                     obs_dim=obs_dim,
+                 ).to(self.device)
+                 self.decoder = CondiitonalEncoderMLP(
+                     input_dim=n_latent_dims,
+                     output_dim=input_dim_w * self.input_dim_h,
+                     obs_dim=obs_dim,
+                 ).to(self.device)
+             else:
+                 self.encoder = EncoderMLP(
+                     input_dim=input_dim_w * self.input_dim_h, output_dim=n_latent_dims
+                 ).to(self.device)
+                 self.decoder = EncoderMLP(
+                     input_dim=n_latent_dims, output_dim=input_dim_w * self.input_dim_h
+                 ).to(self.device)
+
+         # params = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.vq_layer.parameters())
+         # self.vqvae_optimizer = torch.optim.Adam(params, lr=self.vqvae_lr, weight_decay=0.0001)
+
+         if load_dir is not None:
+             try:
+                 state_dict = torch.load(load_dir)
+             except RuntimeError:
+                 state_dict = torch.load(load_dir, map_location=torch.device("cpu"))
+
+             new_dict = {}
+
+             prefix_to_remove = "_rvq."
+
+             for key, value in state_dict["loss_fn"].items():
+                 if key.startswith(prefix_to_remove):
+                     new_key = key[len(prefix_to_remove) :]  # Remove the prefix
+                     new_dict[new_key] = value
+                 else:
+                     new_dict[key] = value
+
+             self.load_state_dict(new_dict, strict=False)
+
+         if eval:
+             self.vq_layer.eval()
+         else:
+             self.vq_layer.train()
+
+     def draw_logits_forward(self, encoding_logits):
+         z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
+         return z_embed
+
+     def draw_code_forward(self, encoding_indices):
+         with torch.no_grad():
+             z_embed = self.vq_layer.get_codes_from_indices(encoding_indices)
+             z_embed = z_embed.sum(dim=0)
+         return z_embed
+
+     def get_action_from_latent(self, latent, obs=None):
+         if self.obs_cond:
+             output = self.decoder(latent, obs[:, -1]) * self.act_scale
+         else:
+             output = self.decoder(latent) * self.act_scale
+         if self.input_dim_h == 1:
+             return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
+         else:
+             return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
+
+     def preprocess(self, state):
+         if not torch.is_tensor(state):
+             state = get_tensor(state, self.device)
+         if self.input_dim_h == 1:
+             state = state.squeeze(-2)  # state.squeeze(-1)
+         else:
+             state = einops.rearrange(state, "N T A -> N (T A)")
+         return state.to(self.device)
+
+     def get_code(self, state, obs=None, required_recon=False):
+         state = state / self.act_scale
+         state = self.preprocess(state)
+         with torch.no_grad():
+             if self.obs_cond:
+                 state_rep = self.encoder(state, obs[:, -1])
+             else:
+                 state_rep = self.encoder(state)
+             state_rep_shape = state_rep.shape[:-1]
+             state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
+             state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
+             state_vq = state_rep_flat.view(*state_rep_shape, -1)
+             vq_code = vq_code.view(*state_rep_shape, -1)
+             vq_loss_state = torch.sum(vq_loss_state)
+             if required_recon:
+                 if self.obs_cond:
+                     recon_state = self.decoder(state_vq, obs[:, -1]) * self.act_scale
+                     recon_state_ae = (
+                         self.decoder(state_rep, obs[:, -1]) * self.act_scale
+                     )
+                 else:
+                     recon_state = self.decoder(state_vq) * self.act_scale
+                     recon_state_ae = self.decoder(state_rep) * self.act_scale
+                 if self.input_dim_h == 1:
+                     return state_vq, vq_code, recon_state, recon_state_ae
+                 else:
+                     return (
+                         state_vq,
+                         vq_code,
+                         torch.swapaxes(recon_state, -2, -1),
+                         torch.swapaxes(recon_state_ae, -2, -1),
+                     )
+             else:
+                 return state_vq, vq_code
+
+     def vqvae_update(self, state, obs=None):
+         state = state / self.act_scale
+         state = self.preprocess(state)
+         if self.obs_cond:
+             state_rep = self.encoder(state, obs[:, -1])
+         else:
+             state_rep = self.encoder(state)
+         state_rep_shape = state_rep.shape[:-1]
+         state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
+         state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
+         state_vq = state_rep_flat.view(*state_rep_shape, -1)
+         vq_code = vq_code.view(*state_rep_shape, -1)
+         vq_loss_state = torch.sum(vq_loss_state)
+
+         if self.obs_cond:
+             dec_out = self.decoder(state_vq, obs[:, -1])
+         else:
+             dec_out = self.decoder(state_vq)
+         encoder_loss = (state - dec_out).abs().mean()
+
+         rep_loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5)
+
+         # self.vqvae_optimizer.zero_grad()
+         # rep_loss.backward()
+         # self.vqvae_optimizer.step()
+         vqvae_recon_loss = torch.nn.MSELoss()(state, dec_out)
+         loss_dict = {
+             "rep_loss": rep_loss.detach().cpu().item(),
+             "vq_loss_state": vq_loss_state.detach().cpu().item(),
+             "vqvae_recon_loss_l1": encoder_loss.detach().cpu().item(),
+             "vqvae_recon_loss_l2": vqvae_recon_loss.detach().cpu().item(),
+             "n_different_codes": len(torch.unique(vq_code)),
+             "n_different_combinations": len(torch.unique(vq_code, dim=0)),
+         }
+         return rep_loss, loss_dict
+
+     def configure_optimizers(self, weight_decay, learning_rate, betas):
+         params = (
+             list(self.encoder.parameters())
+             + list(self.decoder.parameters())
+             + list(self.vq_layer.parameters())
+         )
+         optimizer = torch.optim.AdamW(
+             params,
+             weight_decay=weight_decay,
+             lr=learning_rate,
+             betas=betas,
+         )
+         return optimizer
+
+     def _begin_epoch(self, optimizer, **kwargs):
+         # log codebook usage rate for debugging
+         # lr_0 = optimizer.param_groups[0]["lr"]
+         # lr_neg1 = optimizer.param_groups[-1]["lr"]
+         # return {"lr_0": lr_0, "lr_neg1": lr_neg1}
+         return None
+
+     # def state_dict(self):
+     #     return {'encoder': self.encoder.state_dict(),
+     #             'decoder': self.decoder.state_dict(),
+     #             'vq_embedding': self.vq_layer.state_dict()}
+
+     # def load_state_dict(self, state_dict):
+     #     self.encoder.load_state_dict(state_dict['encoder'])
+     #     self.decoder.load_state_dict(state_dict['decoder'])
+     #     self.vq_layer.load_state_dict(state_dict['vq_embedding'])
+     #     self.vq_layer.eval()
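
A minimal usage sketch of the VqVae action tokenizer defined above, assuming the bundled ResidualVQ dependency is importable, running on CPU with obs_cond=False; the batch size and tensor values are illustrative only, and the shapes follow the constructor comments (input_dim_h is the action-chunk length, input_dim_w the per-step action dimension):

import torch
from baselines.rum.models.bet.vqvae.vqvae import VqVae

# Illustrative configuration; not the package defaults for any particular task.
vqvae = VqVae(
    input_dim_h=10,      # action chunk length
    input_dim_w=9,       # per-step action dimension
    n_latent_dims=512,
    vqvae_n_embed=32,
    vqvae_groups=4,
    eval=True,
    device="cpu",
)

actions = torch.randn(4, 10, 9)                   # (batch, chunk, action dim)
state_vq, vq_code = vqvae.get_code(actions)       # quantized latent and discrete codes
recon = vqvae.get_action_from_latent(state_vq)    # back to (batch, chunk, action dim)
print(vq_code.shape, recon.shape)

The same object exposes vqvae_update(actions) for training, which returns the combined reconstruction-plus-commitment loss and a dictionary of logging scalars.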
baselines/rum/models/bet/vqvae/vqvae_utils.py
@@ -0,0 +1,30 @@
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import os.path as osp
+
+
+ def weights_init_encoder(m):
+     if isinstance(m, nn.Linear):
+         nn.init.orthogonal_(m.weight.data)
+         m.bias.data.fill_(0.0)
+     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+         assert m.weight.size(2) == m.weight.size(3)
+         m.weight.data.fill_(0.0)
+         m.bias.data.fill_(0.0)
+         mid = m.weight.size(2) // 2
+         gain = nn.init.calculate_gain("relu")
+         nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
+
+
+ def get_tensor(z, device):
+     if z is None:
+         return None
+     if z[0].dtype == np.dtype("O"):
+         return None
+     if len(z.shape) == 1:
+         return torch.FloatTensor(z.copy()).to(device).unsqueeze(0)
+         # return torch.from_numpy(z.copy()).float().to(device).unsqueeze(0)
+     else:
+         return torch.FloatTensor(z.copy()).to(device)
+         # return torch.from_numpy(z.copy()).float().to(device)
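
As a quick illustration of get_tensor (a sketch; the numpy arrays are hypothetical), 1-D inputs gain a leading batch dimension while higher-rank arrays are converted as-is:

import numpy as np
from baselines.rum.models.bet.vqvae.vqvae_utils import get_tensor

single = get_tensor(np.zeros(9, dtype=np.float32), device="cpu")      # shape (1, 9)
batch = get_tensor(np.zeros((4, 9), dtype=np.float32), device="cpu")  # shape (4, 9)
print(single.shape, batch.shape)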
baselines/rum/models/custom.py
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn as nn
+
+ class CustomModel(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         # TODO: set model
+         self.model = None
+
+     def to(self, device):
+         # TODO: move model to device
+         if self.model is not None:
+             self.model.to(device)
+         return self
+
+     def forward(self, x):
+         # TODO: forward pass of model
+         pass
+
+     def step(self, data, *args, **kwargs):
+         images, actions = data
+         # images is a tensor of shape (1, image_buffer_size, 3, 256, 256), where the images are in chronological order
+         # actions is a tensor of shape (1, image_buffer_size, 7), where the actions are in chronological order, and the final action is padding
+
+         # TODO: implement a pass that takes in the observations described above and outputs a 7-dimensional action
+         action = torch.zeros(7)
+         logs = {}
+
+         return action, logs
+
+     def reset(self):
+         # TODO: optional; this method is called once the robot is homed
+         pass
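
The file above is a user-facing template: step() receives the last image_buffer_size RGB frames plus past actions and must return a 7-dimensional action. A minimal sketch of one way to fill it in (the tiny CNN-plus-linear head is purely illustrative and not part of the package):

import torch
import torch.nn as nn

class TinyCustomModel(nn.Module):
    def __init__(self, image_buffer_size: int = 2):
        super().__init__()
        # Illustrative backbone: pool each frame down to a 16-dim feature vector.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=4),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Map the concatenated per-frame features to a 7-dim action.
        self.head = nn.Linear(16 * image_buffer_size, 7)

    def step(self, data, *args, **kwargs):
        images, actions = data                        # (1, T, 3, 256, 256), (1, T, 7)
        b, t = images.shape[:2]
        feats = self.backbone(images.flatten(0, 1))   # (T, 16)
        action = self.head(feats.reshape(b, -1))[0]   # (7,)
        return action, {}

    def reset(self):
        pass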
baselines/rum/models/encoders/__init__.py
File without changes
baselines/rum/models/encoders/abstract_base_encoder.py
@@ -0,0 +1,70 @@
+ import pathlib
+ import warnings
+ from abc import ABC, abstractmethod
+ from typing import Union
+
+ import torch
+ import torch.nn as nn
+
+ ENCODER_LOADING_ERROR_MSG = (
+     "Could not load encoder weights: defaulting to pretrained weights"
+ )
+
+
+ class AbstractEncoder(nn.Module, ABC):
+     def __init__(self):
+         super().__init__()
+
+     @property
+     @abstractmethod
+     def feature_dim(self):
+         pass
+
+     def transform(self, x):
+         return x
+
+     def to(self, device):
+         self.device = device
+         return super().to(device)
+
+     @abstractmethod
+     def forward(self, x):
+         pass
+
+     def load_weights(
+         self, weight_path: Union[pathlib.Path, str], strict: bool = True
+     ) -> None:
+         try:
+             checkpoint = torch.load(weight_path, map_location="cpu")
+             weight_dict = checkpoint.get("model", checkpoint)
+
+             self.model.load_state_dict(weight_dict, strict=strict)
+             return
+
+         except RuntimeError:
+             warnings.warn("Couldn't load full model weights. Trying fallback options...")
+
+         fallback_prefixes = [
+             "model.",
+             "encoder.",
+             "module.base_encoder._orig_mod.",
+         ]
+
+         for prefix in fallback_prefixes:
+             try:
+                 filtered_weights = {
+                     k.replace(prefix, ""): v
+                     for k, v in weight_dict.items()
+                     if k.startswith(prefix)
+                 }
+
+                 filtered_weights.pop("fc.weight", None)
+                 filtered_weights.pop("fc.bias", None)
+
+                 self.model.load_state_dict(filtered_weights, strict=strict)
+                 return
+
+             except RuntimeError:
+                 warnings.warn(f"Failed loading with prefix: '{prefix}'")
+
+         warnings.warn(ENCODER_LOADING_ERROR_MSG)
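
A concrete encoder only needs to set self.model (which load_weights operates on) and implement the feature_dim property and forward; a minimal sketch, with layer sizes chosen purely for illustration:

import torch.nn as nn
from baselines.rum.models.encoders.abstract_base_encoder import AbstractEncoder

class FlattenMLPEncoder(AbstractEncoder):
    """Illustrative encoder: flattens the input image and applies a single linear layer."""

    def __init__(self, input_dim: int = 3 * 256 * 256, feature_dim: int = 128):
        super().__init__()
        self._feature_dim = feature_dim
        self.model = nn.Sequential(nn.Flatten(), nn.Linear(input_dim, feature_dim))

    @property
    def feature_dim(self):
        return self._feature_dim

    def forward(self, x):
        return self.model(self.transform(x))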
baselines/rum/models/encoders/identity.py
@@ -0,0 +1,45 @@
+ """
+ A generic model wrapper for dummy identity encoders.
+ """
+
+ import pathlib
+ from typing import Union
+ from torch import nn
+ import torch
+ from .abstract_base_encoder import AbstractEncoder
+
+
+ class IdentityEncoder(AbstractEncoder):
+     def __init__(
+         self,
+         model_name,
+         pretrained: bool = True,
+         weight_path: Union[None, str, pathlib.Path] = None,
+     ):
+         super().__init__()
+         self._model_name = model_name
+         self.device = None
+         self.model = nn.Identity()
+
+         # Use a placeholder parameter if the model has no parameters
+         if len(list(self.model.parameters())) == 0:
+             self.placeholder_param = nn.Parameter(torch.zeros(1, requires_grad=True))
+         else:
+             self.placeholder_param = None
+
+     def transform(self, x):
+         return x
+
+     @property
+     def feature_dim(self):
+         return 0
+
+     def to(self, device):
+         self.device = device
+         self.model.to(device)
+         if self.placeholder_param is not None:
+             self.placeholder_param.to(device)
+         return self
+
+     def forward(self, x):
+         return self.transform(x)
baselines/rum/models/encoders/timm_encoders.py
@@ -0,0 +1,82 @@
+ """
+ A generic model wrapper for timm encoders.
+ """
+
+ import pathlib
+ from typing import Iterable, Optional, Union
+ import warnings
+
+ import einops
+ import timm
+ import torch
+ import torch.nn as nn
+
+ from baselines.rum.models.encoders.abstract_base_encoder import AbstractEncoder
+ from baselines.rum.utils.decord_transforms import create_transform
+
+
+ class TimmModel(AbstractEncoder):
+     def __init__(
+         self,
+         model_name: str = "hf-hub:notmahi/dobb-e",
+         pretrained: bool = True,
+         weight_path: Union[None, str, pathlib.Path] = None,
+     ):
+         super().__init__()
+         self._model_name = model_name
+
+         self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
+         if weight_path:
+             self.load_weights(weight_path, strict=False)
+
+     def transform(self, x):
+         return x
+
+     @property
+     def feature_dim(self):
+         return self.model.num_features
+
+     def to(self, device):
+         self.model.to(device)
+         return self
+
+     def forward(self, x):
+         return self.model(self.transform(x))
+
+
+ class TimmSSL(TimmModel):
+     def __init__(
+         self,
+         model_name: str = "hf-hub:notmahi/dobb-e",
+         pretrained=True,
+         override_aug_kwargs=dict(
+             hflip=0.0, vflip=0.0, scale=(1.0, 1.0), crop_pct=0.875
+         ),
+         weight_path: Union[None, str, pathlib.Path] = None,
+     ):
+         super().__init__(model_name, pretrained=pretrained)
+         data_cfg = timm.data.resolve_data_config(self.model.pretrained_cfg)
+         # Now define the transforms.
+         data_cfg.update(override_aug_kwargs)
+         data_cfg["is_training"] = True
+         self._train_transform = create_transform(**data_cfg)
+         data_cfg["is_training"] = False
+         self._test_transform = create_transform(**data_cfg)
+         if weight_path is not None:
+             self.load_weights(weight_path, strict=True)
+
+     def transform(self, x):
+         return (
+             self._train_transform(x) if self.model.training else self._test_transform(x)
+         )
+
+     def forward(self, x):
+         # Split the input into frames and labels.
+         images, *labels = x
+         # Flatten the frames into a single batch.
+         flattened_images = einops.rearrange(images, "b t c h w -> (b t) c h w")
+         # Transform and pass through the model.
+         result = self.model(self.transform(flattened_images))
+         # Unflatten the result.
+         result = einops.rearrange(result, "(b t) c -> b t c", b=images.shape[0])
+         return result
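
A minimal sketch of running the TimmModel wrapper on a batch of frames, using a small generic timm architecture with pretrained=False so nothing is downloaded (the model name and input size are illustrative, not the package defaults):

import torch
from baselines.rum.models.encoders.timm_encoders import TimmModel

encoder = TimmModel(model_name="resnet18", pretrained=False).to("cpu")
frames = torch.randn(4, 3, 224, 224)   # (batch, channels, height, width)
features = encoder(frames)             # (batch, encoder.feature_dim)
print(encoder.feature_dim, features.shape)

TimmSSL wraps the same backbone for video clips: its forward expects an (images, *labels) tuple with images shaped (batch, time, channels, height, width) and returns per-frame features shaped (batch, time, feature_dim).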