egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,454 @@
1
+ import logging
2
+ from enum import Enum
3
+ from itertools import chain
4
+ from typing import Dict, Optional, Sequence, Tuple, Union
5
+
6
+ import accelerate
7
+ import einops
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from baselines.rum.models.bet.bet import FocalLoss, KMeansDiscretizer
13
+ from baselines.rum.models.bet.gpt import GPT
14
+ from baselines.rum.models.bet.utils import MLP
15
+
16
# Seed for the torch.Generator that draws the placeholder cluster centers
# before k-means has been fit, so the initial placeholders are reproducible.
GENERATOR_SEED_FIXED = 123_456_789
17
+
18
+
19
class TokenizedBehaviorTransformer(nn.Module):
    """Behavior transformer over observation-patch tokens and sub-action tokens.

    Each observation of shape ``(C, H, W)`` (with ``C == obs_dim``) becomes
    ``H * W`` tokens, and each action is split into sub-actions (given by
    ``action_spec`` or ``start_and_ends``) that are each linearly projected to
    one ``obs_dim``-wide token. The token sequence is fed to ``gpt_model``; for
    every sub-action an MLP head maps the GPT output to logits over k-means
    bins plus one offset vector per bin. The bin centers are fit with k-means
    on the actions collected during the first ``kmeans_fit_steps`` calls to
    ``forward`` that provide both actions and a padding mask.

    Goal conditioning is selected from ``goal_dim``:

    * ``goal_dim <= 0`` -- unconditional;
    * ``goal_dim == obs_dim`` -- "concat": goal tokens are prepended;
    * otherwise -- "stack": a single goal vector is linearly encoded to
      ``obs_dim`` and prepended as one token.
    """

    GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional")

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        goal_dim: int,
        gpt_model: GPT,
        action_spec: Optional[Sequence[int]] = None,
        start_and_ends: Optional[Sequence[Tuple[int, int]]] = None,
        num_extra_predicted_actions: Optional[int] = None,
        n_clusters: Union[int, Sequence[int]] = 32,
        kmeans_fit_steps: int = 500,
        kmeans_iters: int = 50,
        offset_loss_multiplier: float = 1.0e3,
        offset_distance_metric: str = "L2",
        representation_height: int = 7,
        representation_width: int = 7,
        gamma: float = 2.0,
        sampling_temperature: float = 1.0,
        **kwargs,
    ):
        """Build the tokenizers, prediction heads and learned embedding tokens.

        Args:
            obs_dim: Channel dimension ``C`` of each observation feature map;
                also the width of every token fed to ``gpt_model``.
            act_dim: Total action dimension; must equal ``sum(action_spec)``
                when ``action_spec`` is given.
            goal_dim: Goal feature size; ``<= 0`` disables goal conditioning.
            gpt_model: Sequence model; its ``config.output_dim`` is the input
                width of the per-sub-action prediction heads.
            action_spec: Sizes of consecutive sub-actions. Exactly one of
                ``action_spec`` / ``start_and_ends`` must be given.
            start_and_ends: Explicit ``(start, end)`` slices of each sub-action.
            num_extra_predicted_actions: Stored; not otherwise used here.
            n_clusters: k-means bins per sub-action (one int for all, or one
                value per sub-action).
            kmeans_fit_steps: Number of action batches collected before
                k-means is fit; until then the loss's gradient is zeroed.
            kmeans_iters: Passed to each ``KMeansDiscretizer``.
            offset_loss_multiplier: Weight of the offset loss relative to the
                bin classification loss.
            offset_distance_metric: ``"L2"`` selects an MSE offset loss; any
                other value selects L1.
            representation_height: Height ``H`` of observation feature maps.
            representation_width: Width ``W`` of observation feature maps.
            gamma: Passed to ``FocalLoss`` as ``gamma``.
            sampling_temperature: Softmax temperature used when sampling bins.
            **kwargs: Ignored.
        """
        super().__init__()
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        self.sampling_temperature = sampling_temperature
        # First ensure either action spec is given or start and ends are given.
        assert (action_spec is not None) != (start_and_ends is not None), (
            "EITHER action_spec OR start_and_ends must be given."
        )
        if action_spec is not None:
            self._action_spec = action_spec
            assert act_dim == sum(action_spec)
            cumsum = [sum(action_spec[:i]) for i in range(len(action_spec) + 1)]
            self._starts = cumsum[:-1]
            self._ends = cumsum[1:]
            # Materialize: a bare zip would be left exhausted after enumerate.
            self._start_ends = list(zip(self._starts, self._ends))
        else:
            # Materialize so a one-shot iterable is not consumed by the
            # comprehension below before it is enumerated.
            self._start_ends = list(start_and_ends)
            self._action_spec = [end - start for (start, end) in self._start_ends]
        # List of (sub_action_index, (start, end)) pairs.
        self._start_and_ends = list(enumerate(self._start_ends))
        self._n_subactions = len(self._start_and_ends)  # Number of sub-actions.

        self._goal_dim = goal_dim
        self._num_extra_predicted_actions = num_extra_predicted_actions

        # Decide goal conditioning style.
        if goal_dim <= 0:
            self._cbet_method = self.GOAL_SPEC.unconditional
        elif obs_dim == goal_dim:
            self._cbet_method = self.GOAL_SPEC.concat
        else:
            self._goal_encoder = nn.Linear(goal_dim, obs_dim, bias=False)
            self._cbet_method = self.GOAL_SPEC.stack

        self._gpt_model = gpt_model
        if isinstance(n_clusters, int):
            n_clusters = [n_clusters] * len(self._start_and_ends)
        assert len(n_clusters) == len(self._start_and_ends)
        # For now, we assume the number of clusters is given.
        assert all(k > 0 for k in n_clusters)
        self._K = n_clusters
        self._kmeans_fit_steps = kmeans_fit_steps
        self._clustering_algos = [
            KMeansDiscretizer(num_bins=k, kmeans_iters=kmeans_iters) for k in n_clusters
        ]
        self._current_steps = 0
        # One head per sub-action: K bin logits followed by K * a offsets.
        self._map_to_cbet_preds = nn.ModuleList(
            [
                MLP(
                    in_channels=gpt_model.config.output_dim,
                    hidden_channels=[(a + 1) * k],
                )
                for (a, k) in zip(self._action_spec, n_clusters)
            ]
        )
        self._collected_actions = []
        self._have_fit_kmeans = False
        # Placeholder for the cluster centers (overwritten once k-means is fit).
        generator = torch.Generator()
        generator.manual_seed(GENERATOR_SEED_FIXED)
        self._cluster_centers = nn.ParameterList(
            [
                nn.Parameter(
                    torch.randn((k, a), generator=generator, dtype=torch.float32),
                    requires_grad=False,
                )
                for (k, a) in zip(n_clusters, self._action_spec)
            ]
        )
        self._criterion = FocalLoss(gamma=gamma, reduction="none")
        self._offset_criterion = (
            nn.MSELoss(reduction="none")
            if offset_distance_metric == "L2"
            else nn.L1Loss(reduction="none")
        )
        self._offset_loss_multiplier = offset_loss_multiplier

        self._action_tokenizer = nn.ModuleList(
            [nn.Linear(a, obs_dim) for a in self._action_spec]
        )
        # Figure out the embedding tokens.
        if self._cbet_method == self.GOAL_SPEC.unconditional:
            self._goal_embedding_token = None
        else:
            self._goal_embedding_token = nn.Parameter(torch.randn(goal_dim))
        self._h, self._w = representation_height, representation_width
        self._obs_embedding_token = nn.Parameter(
            torch.randn([obs_dim, self._h, self._w])
        )
        self._action_embedding_token = nn.Parameter(
            torch.randn([self._n_subactions, obs_dim])
        )
        self._extra_action_token = nn.Parameter(
            torch.randn([self._n_subactions, obs_dim])
        )
        self._end_of_obs_token = nn.Parameter(torch.randn(obs_dim))
        self._accelerator = accelerate.Accelerator()

    def _load_from_state_dict(self, *args, **kwargs):
        """Mark k-means as already fit, then defer to ``nn.Module`` loading."""
        # Don't fit kmeans if we are loading from a state dict.
        self._current_steps = self._kmeans_fit_steps
        self._have_fit_kmeans = True
        return super()._load_from_state_dict(*args, **kwargs)

    def get_start_and_ends(self) -> Sequence[Tuple[int, Tuple[int, int]]]:
        """Return the ``(sub_action_index, (start, end))`` pairs."""
        return self._start_and_ends

    def _tokenize_actions(self, actions: torch.Tensor) -> torch.Tensor:
        """Project every sub-action slice to a token and stack them on dim 2.

        For ``actions`` of shape ``(N, T, act_dim)`` the result is
        ``(N, T, n_subactions, obs_dim)``.
        """
        action_stack = [
            self._action_tokenizer[i](actions[..., start:end])
            for i, (start, end) in self._start_and_ends
        ]
        return torch.stack(action_stack, dim=2)

    def _detokenize_actions(
        self, tokenized_action_output: torch.Tensor
    ) -> Tuple[Sequence[torch.Tensor], Sequence[torch.Tensor]]:
        """Map per-sub-action GPT outputs to bin logits and flat bin offsets.

        Returns:
            Two lists indexed by sub-action: logits of width ``K[i]`` and
            offsets of width ``K[i] * action_spec[i]``.
        """
        action_bin_logits = []
        all_action_offsets = []
        for i, _ in self._start_and_ends:
            tokenized_action_i = tokenized_action_output[..., i, :]
            action_cbet_preds = self._map_to_cbet_preds[i](tokenized_action_i)
            action_center_logits, action_offsets = torch.split(
                action_cbet_preds,
                [self._K[i], self._K[i] * self._action_spec[i]],
                dim=-1,
            )
            action_bin_logits.append(action_center_logits)
            all_action_offsets.append(action_offsets)
        return action_bin_logits, all_action_offsets

    def _begin_epoch(self, optimizer, **kwargs):
        """Return the first and last param-group learning rates for logging."""
        # log learning rate for debugging
        lr_0 = optimizer.param_groups[0]["lr"]
        lr_neg1 = optimizer.param_groups[-1]["lr"]
        return {"lr_0": lr_0, "lr_neg1": lr_neg1}

    def _calculate_cbet_preds_and_loss(
        self,
        cluster_centers: torch.Tensor,
        bin_logits: torch.Tensor,
        action_offsets: torch.Tensor,
        true_actions: torch.Tensor,
        is_padded_action_seq: Optional[torch.Tensor],
        predict_with_offset: bool = True,
        return_loss: bool = True,
        sampling_temperature: float = 1.0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, float]]:
        """Sample actions for one sub-action and optionally compute its losses.

        One bin is sampled per ``(N, T)`` position from the tempered softmax of
        ``bin_logits``; the prediction is that bin's center plus (optionally)
        the offset predicted for it. The loss is a focal classification loss
        toward the center nearest to ``true_actions`` plus a weighted offset
        loss toward the residual from that center, with padded steps zeroed.

        Returns:
            ``(a_hat, loss, loss_dict)``; ``(a_hat, None, {})`` when
            ``return_loss`` is False.
        """
        bin_probs = torch.softmax(bin_logits / sampling_temperature, dim=-1)
        N, T, choices = bin_probs.shape
        # Sample from the multinomial distribution, one per row.
        sampled_centers = einops.rearrange(
            torch.multinomial(bin_probs.view(-1, choices), num_samples=1),
            "(N T) 1 -> N T 1",
            N=N,
        )
        flattened_action_offsets = einops.rearrange(
            action_offsets, "N T (K A) -> (N T) K A", K=choices
        )
        # Row indices built on the same device as the tensor they index.
        row_index = torch.arange(
            flattened_action_offsets.shape[0], device=flattened_action_offsets.device
        )
        sampled_offsets = flattened_action_offsets[
            row_index, sampled_centers.flatten()
        ].view(N, T, -1)
        centers = cluster_centers[sampled_centers.flatten()].view(N, T, -1)
        a_hat = centers + (sampled_offsets if predict_with_offset else 0.0)
        if not return_loss:
            return a_hat, None, {}
        # We are in training, so figure out the loss for the actions.
        # First, we need to find the closest cluster center for each action.
        true_action_bins = self._find_closest_cluster(true_actions, cluster_centers)
        true_offsets = true_actions - cluster_centers[true_action_bins]
        predicted_offsets = flattened_action_offsets[
            row_index, true_action_bins.flatten()
        ].view(N, T, -1)
        offset_loss = self._offset_criterion(predicted_offsets, true_offsets)
        cbet_loss = self._criterion(
            einops.rearrange(bin_logits, "N T D -> (N T) D"),
            einops.rearrange(true_action_bins, "N T -> (N T)"),
        )
        if is_padded_action_seq is not None:
            # Zero out contributions from padded steps (True == padded).
            cbet_loss *= ~is_padded_action_seq.view(-1)
            offset_loss *= ~is_padded_action_seq.unsqueeze(-1)
        cbet_loss, offset_loss = cbet_loss.mean(), offset_loss.mean()
        action_mse = F.mse_loss(a_hat, true_actions, reduction="none")
        action_l1 = F.l1_loss(a_hat, true_actions, reduction="none")
        norm = torch.norm(true_actions, p=2, dim=-1, keepdim=True) + 1e-9
        normalized_mse = (action_mse / norm).mean()
        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
        if self._current_steps < self._kmeans_fit_steps:
            # Keep the graph connected but contribute zero gradient until the
            # cluster centers have been fit.
            loss = loss.detach() + (loss * 0.0)
        loss_dict = {
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "loss": loss.detach().cpu().item(),
            "L2_loss": action_mse.mean().detach().cpu().item(),
            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
            "L1_loss": action_l1.mean().detach().cpu().item(),
        }
        return a_hat, loss, loss_dict

    def forward(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        padding_seq: Optional[torch.Tensor],
        predict_with_offset: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        """Collect actions for k-means while warming up, then predict.

        Args:
            obs_seq: Observation features, unpacked as ``(N, T, C, H, W)``.
            goal_seq: Goal features; unused in unconditional mode.
            action_seq: Actions of length ``T + T'`` along dim 1: the first
                ``T`` steps are interleaved with observations, the rest are
                predicted after the end-of-observation token. Required by
                ``_predict`` (asserted there).
            padding_seq: Boolean mask, ``True`` for padded action steps.
            predict_with_offset: Add the predicted offset to sampled centers.

        Returns:
            ``(a_hat, loss, loss_dict)`` from ``_predict``.
        """
        if self._current_steps == 0:
            self._cluster_centers = self._cluster_centers.to(obs_seq.device)
        if self._current_steps < self._kmeans_fit_steps and (
            action_seq is not None and padding_seq is not None
        ):
            self._current_steps += 1
            self._fit_kmeans(action_seq, padding_seq)
        return self._predict(
            obs_seq,
            goal_seq,
            action_seq,
            padding_seq,
            predict_with_offset=predict_with_offset,
        )

    def _fit_kmeans(
        self,
        action_seq: Optional[torch.Tensor],
        padding_seq: Optional[torch.Tensor],
    ) -> None:
        """Accumulate non-padded actions; fit k-means on the final warm-up step.

        Actions and padding masks are gathered across processes with
        ``self._accelerator.gather``. When ``_current_steps`` reaches
        ``_kmeans_fit_steps``, one ``KMeansDiscretizer`` per sub-action is fit
        and its bin centers replace the placeholder cluster centers.
        """
        assert self._current_steps <= self._kmeans_fit_steps
        if self._current_steps == 1:
            self._cluster_centers = self._cluster_centers.to(action_seq.device)

        all_action_seq = self._accelerator.gather(action_seq)
        all_padding_seq = self._accelerator.gather(padding_seq)
        self._collected_actions.append(
            all_action_seq[torch.logical_not(all_padding_seq)]
        )
        if self._current_steps == self._kmeans_fit_steps:
            logging.info("Fitting KMeans")
            self._collected_actions = torch.cat(self._collected_actions, dim=0)
            for i, (start, end) in self._start_and_ends:
                clustering_algo = self._clustering_algos[i]
                logging.info(f"Fitting KMeans for action {i}")
                clustering_algo.fit(
                    self._collected_actions[:, start:end].view(-1, end - start)
                )
                # Wrap explicitly as a frozen Parameter: ParameterList either
                # rejects plain tensors (older torch) or wraps them with
                # requires_grad=True (newer torch), which would let the offset
                # loss backpropagate into the cluster centers.
                self._cluster_centers[i] = nn.Parameter(
                    clustering_algo.bin_centers.float().to(action_seq.device),
                    requires_grad=False,
                )
            self._have_fit_kmeans = True

    def _predict(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: torch.Tensor,
        is_padded_action_seq: Optional[torch.Tensor],
        predict_with_offset: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        """Build the token sequence, run the GPT and decode every sub-action.

        Sequence layout (after any goal tokens):
        ``T * (H*W obs tokens + A action tokens)``, one end-of-obs token, then
        ``T' * A`` extra action tokens. Action predictions are read from the
        output one position before each action token.

        Returns:
            ``(a_hat, loss, loss_dict)`` where ``a_hat`` has the shape of
            ``action_seq`` and ``loss``/``loss_dict`` sum over sub-actions
            (``loss_dict`` also has per-sub-action ``"{key}_{i}"`` entries).
        """
        batch_size, obs_T, c, h, w = obs_seq.shape
        obs_seq_and_encoding = obs_seq + self._obs_embedding_token[None, None, ...]
        obs_seq_and_encoding = einops.rearrange(
            obs_seq_and_encoding, "N T C H W -> N T (H W) C"
        )
        assert action_seq is not None
        tokenized_action_seq = self._tokenize_actions(action_seq)
        tokenized_action_seq_and_encoding = (
            tokenized_action_seq + self._action_embedding_token[None, None, ...]
        )
        interspersed_seq = torch.cat(
            [obs_seq_and_encoding, tokenized_action_seq_and_encoding[:, :obs_T, ...]],
            dim=2,
        )  # N T (H W + A) C
        interspersed_seq_flattened = einops.rearrange(
            interspersed_seq, "N T (HWA) C -> N (T HWA) C"
        )
        extra_actions_tokenized_plus_embedding = (
            tokenized_action_seq_and_encoding[:, obs_T:, ...]
            + self._extra_action_token[None, None, ...]
        )
        extra_actions_flattened = einops.rearrange(
            extra_actions_tokenized_plus_embedding, "N T A D -> N (T A) D"
        )
        interspersed_seq_and_extra_actions = torch.cat(
            [
                interspersed_seq_flattened,
                einops.repeat(self._end_of_obs_token, "D -> N 1 D", N=batch_size),
                extra_actions_flattened,
            ],
            dim=1,
        )

        # Assume dimensions are N T D for N sequences of T timesteps with dimension D.
        if self._cbet_method == self.GOAL_SPEC.unconditional:
            gpt_input = interspersed_seq_and_extra_actions
        elif self._cbet_method == self.GOAL_SPEC.concat:
            goal_seq_encoded = goal_seq + self._goal_embedding_token[None, None, ...]
            gpt_input = torch.cat(
                [goal_seq_encoded, interspersed_seq_and_extra_actions], dim=1
            )
        elif self._cbet_method == self.GOAL_SPEC.stack:
            goal_seq_embedded = (
                goal_seq[:, None, :] + self._goal_embedding_token[None, None, ...]
            )
            goal_seq_encoded = self._goal_encoder(goal_seq_embedded)
            gpt_input = torch.cat(
                [goal_seq_encoded, interspersed_seq_and_extra_actions], dim=1
            )
        else:
            raise NotImplementedError

        gpt_output = self._gpt_model(gpt_input)
        if self._cbet_method == self.GOAL_SPEC.concat:
            # Chop off the goal encodings.
            gpt_output = gpt_output[:, goal_seq.size(1) :, :]
        elif self._cbet_method == self.GOAL_SPEC.stack:
            gpt_output = gpt_output[:, 1::]

        # Here we have a sequence of shape (N, (T*(H*W + A) + 1 + T'*A), D)
        # where T' is the number of extra actions we want to predict.
        # Separate out the original and the extra actions.
        # We have H*W + A tokens per obs, + 1 for the end of obs token.
        tokens_per_step = self._h * self._w + self._n_subactions
        extra_action_output_flat = gpt_output[
            :, obs_T * tokens_per_step : -1, :
        ]  # N (T' A) D
        extra_action_output = einops.rearrange(
            extra_action_output_flat, "N (T A) D -> N T A D", A=self._n_subactions
        )  # N T' A D
        original_output_tokens = gpt_output[:, : obs_T * tokens_per_step, :]
        original_output_tokens_reshaped = einops.rearrange(
            original_output_tokens, "N (T HWA) D -> N T HWA D", T=obs_T
        )
        original_action_tokens = original_output_tokens_reshaped[
            :, :, -self._n_subactions - 1 : -1, :
        ]  # N T A D
        output_action_tokens = torch.cat(
            [original_action_tokens, extra_action_output], dim=1
        )  # N (T + T') A D
        action_bin_logits, action_offsets = self._detokenize_actions(
            output_action_tokens
        )
        # Now calculate the predicted actions and the losses.
        a_hat, loss, loss_dict = torch.zeros_like(action_seq), None, {}
        for i, (start, end) in self._start_and_ends:
            a_hat_i, loss_i, loss_dict_i = self._calculate_cbet_preds_and_loss(
                self._cluster_centers[i],
                action_bin_logits[i],
                action_offsets[i],
                action_seq[..., start:end],
                # TODO make is_padded_action_seq more flexible,
                # For example, in reality part of the action could be "padded"
                is_padded_action_seq,
                predict_with_offset=predict_with_offset,
                return_loss=True,
                sampling_temperature=self.sampling_temperature,
            )
            a_hat[..., start:end] = a_hat_i
            if i == 0:
                loss = loss_i
                for k, v in loss_dict_i.items():
                    loss_dict[k] = v
            else:
                loss = loss + loss_i
                for k, v in loss_dict_i.items():
                    loss_dict[k] += v
            for k, v in loss_dict_i.items():
                loss_dict[f"{k}_{i}"] = v
        return a_hat, loss, loss_dict

    def _find_closest_cluster(
        self, action_seq: torch.Tensor, cluster_centers: torch.Tensor
    ) -> torch.Tensor:
        """Return the index of the nearest (squared L2) center, shape ``(N, T)``."""
        N, T, _ = action_seq.shape
        flattened_actions = einops.rearrange(action_seq, "N T A -> (N T) A")
        cluster_center_distance = torch.sum(
            (flattened_actions[:, None, :] - cluster_centers[None, :, :]) ** 2,
            dim=2,
        )  # (N T) K A -> (N T) K
        closest_cluster_center = torch.argmin(cluster_center_distance, dim=1)  # (N T)
        discretized_action = einops.rearrange(
            closest_cluster_center, "(N T) -> N T", N=N, T=T
        )
        return discretized_action

    def configure_optimizers(self, weight_decay, learning_rate, betas):
        """Extend the GPT optimizer with the heads and learned embedding tokens.

        Two param groups are appended: the prediction heads and action
        tokenizers (plus the goal encoder in "stack" mode) with the default
        settings, and the embedding tokens with ``weight_decay=0.0``.
        """
        optimizer = self._gpt_model.configure_optimizers(
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            betas=betas,
        )
        head_params = [
            self._map_to_cbet_preds.parameters(),
            self._action_tokenizer.parameters(),
        ]
        if self._cbet_method == self.GOAL_SPEC.stack:
            # The goal encoder only exists in "stack" mode and would otherwise
            # never be optimized.
            head_params.append(self._goal_encoder.parameters())
        optimizer.add_param_group({"params": chain(*head_params)})
        embedding_tokens = [
            self._obs_embedding_token,
            self._goal_embedding_token,
            self._action_embedding_token,
            self._extra_action_token,
            self._end_of_obs_token,
        ]
        optimizer.add_param_group(
            {
                # The goal token is None in unconditional mode; optimizers
                # reject non-tensor params, so skip it.
                "params": [t for t in embedding_tokens if t is not None],
                "weight_decay": 0.0,
            }
        )
        return optimizer
@@ -0,0 +1,124 @@
1
+ import torch
2
+ from typing import Callable, List, Optional
3
+
4
+
5
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.
    Adapted for backward compatibility from the torchvision library:
    https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html

    LICENSE:

    From PyTorch:

    Copyright (c) 2016- Facebook, Inc (Adam Paszke)
    Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
    Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
    Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
    Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
    Copyright (c) 2011-2013 NYU (Clement Farabet)
    Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
    Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
    Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

    From Caffe2:

    Copyright (c) 2016-present, Facebook Inc. All rights reserved.

    All contributions by Facebook:
    Copyright (c) 2016 Facebook Inc.

    All contributions by Google:
    Copyright (c) 2015 Google Inc.
    All rights reserved.

    All contributions by Yangqing Jia:
    Copyright (c) 2015 Yangqing Jia
    All rights reserved.

    All contributions by Kakao Brain:
    Copyright 2019-2020 Kakao Brain

    All contributions by Cruise LLC:
    Copyright (c) 2022 Cruise LLC.
    All rights reserved.

    All contributions from Caffe:
    Copyright(c) 2013, 2014, 2015, the respective contributors
    All rights reserved.

    All other contributions:
    Copyright(c) 2015, 2016 the respective contributors
    All rights reserved.

    Caffe2 uses a copyright model similar to Caffe: each contributor holds
    copyright over their contributions to Caffe2. The project versioning records
    all such contribution and copyright details. If a contributor wants to further
    mark their specific copyright on a particular contribution, they should
    indicate their copyright solely in the commit message of the change when it is
    committed.

    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

    3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
    and IDIAP Research Institute nor the names of its contributors may be
    used to endorse or promote products derived from this software without
    specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
    ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
    LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
    POSSIBILITY OF SUCH DAMAGE.


    Layer layout: for every entry of ``hidden_channels`` except the last,
    ``Linear -> activation -> Dropout``; then a final ``Linear -> Dropout``
    projecting to ``hidden_channels[-1]`` (no activation after the last
    linear layer). Unlike torchvision, no ``norm_layer`` option is supported.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions; the last entry is the output dimension
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of each hidden linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
            Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # Only forward ``inplace`` when explicitly set so each layer keeps its
        # own default otherwise.
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            # Honor the documented contract: ``None`` means no activation.
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)