egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/bet/tokenized_bet.py
@@ -0,0 +1,454 @@
+import logging
+from enum import Enum
+from itertools import chain
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import accelerate
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from baselines.rum.models.bet.bet import FocalLoss, KMeansDiscretizer
+from baselines.rum.models.bet.gpt import GPT
+from baselines.rum.models.bet.utils import MLP
+
+GENERATOR_SEED_FIXED = 123456789
+
+
+class TokenizedBehaviorTransformer(nn.Module):
+    GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional")
+
+    def __init__(
+        self,
+        obs_dim: int,
+        act_dim: int,
+        goal_dim: int,
+        gpt_model: GPT,
+        action_spec: Optional[Sequence[int]] = None,
+        start_and_ends: Optional[Sequence[Tuple[int, int]]] = None,
+        num_extra_predicted_actions: Optional[int] = None,
+        n_clusters: Union[int, Sequence[int]] = 32,
+        kmeans_fit_steps: int = 500,
+        kmeans_iters: int = 50,
+        offset_loss_multiplier: float = 1.0e3,
+        offset_distance_metric: str = "L2",
+        representation_height: int = 7,
+        representation_width: int = 7,
+        gamma: float = 2.0,
+        sampling_temperature: float = 1.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self._obs_dim = obs_dim
+        self._act_dim = act_dim
+        self.sampling_temperature = sampling_temperature
+        # First, ensure that exactly one of action_spec and start_and_ends is given.
+        assert (action_spec is not None) != (start_and_ends is not None), (
+            "EITHER action_spec OR start_and_ends must be given."
+        )
+        if action_spec is not None:
+            self._action_spec = action_spec
+            assert act_dim == sum(action_spec)
+            cumsum = [sum(action_spec[:i]) for i in range(len(action_spec) + 1)]
+            self._starts = cumsum[:-1]
+            self._ends = cumsum[1:]
+            self._start_ends = zip(self._starts, self._ends)
+        else:
+            self._start_ends = start_and_ends
+            self._action_spec = [end - start for (start, end) in start_and_ends]
+        self._start_and_ends = list(enumerate(self._start_ends))
+        self._n_subactions = len(self._start_and_ends)  # Number of sub-actions.
+
+        self._goal_dim = goal_dim
+        self._num_extra_predicted_actions = num_extra_predicted_actions
+
+        # Decide goal conditioning style.
+        if goal_dim <= 0:
+            self._cbet_method = self.GOAL_SPEC.unconditional
+        elif obs_dim == goal_dim:
+            self._cbet_method = self.GOAL_SPEC.concat
+        else:
+            self._goal_encoder = nn.Linear(goal_dim, obs_dim, bias=False)
+            self._cbet_method = self.GOAL_SPEC.stack
+
+        self._gpt_model = gpt_model
+        if isinstance(n_clusters, int):
+            n_clusters = [n_clusters] * len(self._start_and_ends)
+        assert len(n_clusters) == len(self._start_and_ends)
+        # For now, we assume the number of clusters is given.
+        assert all(k > 0 for k in n_clusters)
+        self._K = n_clusters
+        self._kmeans_fit_steps = kmeans_fit_steps
+        self._clustering_algos = [
+            KMeansDiscretizer(num_bins=k, kmeans_iters=kmeans_iters) for k in n_clusters
+        ]
+        self._current_steps = 0
+        self._map_to_cbet_preds = nn.ModuleList(
+            [
+                MLP(
+                    in_channels=gpt_model.config.output_dim,
+                    hidden_channels=[(a + 1) * k],
+                )
+                for (a, k) in zip(self._action_spec, n_clusters)
+            ]
+        )
+        self._collected_actions = []
+        self._have_fit_kmeans = False
+        # Placeholder for the cluster centers.
+        generator = torch.Generator()
+        generator.manual_seed(GENERATOR_SEED_FIXED)
+        self._cluster_centers = nn.ParameterList(
+            [
+                nn.Parameter(
+                    torch.randn((k, a), generator=generator, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                for (k, a) in zip(n_clusters, self._action_spec)
+            ]
+        )
+        self._criterion = FocalLoss(gamma=gamma, reduction="none")
+        self._offset_criterion = (
+            nn.MSELoss(reduction="none")
+            if offset_distance_metric == "L2"
+            else nn.L1Loss(reduction="none")
+        )
+        self._offset_loss_multiplier = offset_loss_multiplier
+
+        self._action_tokenizer = nn.ModuleList(
+            [nn.Linear(a, obs_dim) for a in self._action_spec]
+        )
+        # Figure out the embedding tokens.
+        if self._cbet_method == self.GOAL_SPEC.unconditional:
+            self._goal_embedding_token = None
+        else:
+            self._goal_embedding_token = nn.Parameter(torch.randn(goal_dim))
+        self._h, self._w = representation_height, representation_width
+        self._obs_embedding_token = nn.Parameter(
+            torch.randn([obs_dim, self._h, self._w])
+        )
+        self._action_embedding_token = nn.Parameter(
+            torch.randn([self._n_subactions, obs_dim])
+        )
+        self._extra_action_token = nn.Parameter(
+            torch.randn([self._n_subactions, obs_dim])
+        )
+        self._end_of_obs_token = nn.Parameter(torch.randn(obs_dim))
+        self._accelerator = accelerate.Accelerator()
+
+    def _load_from_state_dict(self, *args, **kwargs):
+        # Don't fit kmeans if we are loading from a state dict.
+        self._current_steps = self._kmeans_fit_steps
+        self._have_fit_kmeans = True
+        return super()._load_from_state_dict(*args, **kwargs)
+
+    def get_start_and_ends(self) -> Sequence[Tuple[int, Tuple[int, int]]]:
+        return self._start_and_ends
+
+    def _tokenize_actions(self, actions: torch.Tensor) -> torch.Tensor:
+        action_stack = [
+            self._action_tokenizer[i](actions[..., start:end])
+            for i, (start, end) in self._start_and_ends
+        ]
+        return torch.stack(action_stack, dim=2)
+
+    def _detokenize_actions(
+        self, tokenized_action_output: torch.Tensor
+    ) -> Tuple[Sequence[torch.Tensor], Sequence[torch.Tensor]]:
+        action_bin_logits = []
+        all_action_offsets = []
+        for i, _ in self._start_and_ends:
+            tokenized_action_i = tokenized_action_output[..., i, :]
+            action_cbet_preds = self._map_to_cbet_preds[i](tokenized_action_i)
+            action_center_logits, action_offsets = torch.split(
+                action_cbet_preds,
+                [self._K[i], self._K[i] * self._action_spec[i]],
+                dim=-1,
+            )
+            action_bin_logits.append(action_center_logits)
+            all_action_offsets.append(action_offsets)
+        return action_bin_logits, all_action_offsets
+
+    def _begin_epoch(self, optimizer, **kwargs):
+        # Log the learning rates for debugging.
+        lr_0 = optimizer.param_groups[0]["lr"]
+        lr_neg1 = optimizer.param_groups[-1]["lr"]
+        return {"lr_0": lr_0, "lr_neg1": lr_neg1}
+
+    def _calculate_cbet_preds_and_loss(
+        self,
+        cluster_centers: torch.Tensor,
+        bin_logits: torch.Tensor,
+        action_offsets: torch.Tensor,
+        true_actions: torch.Tensor,
+        is_padded_action_seq: Optional[torch.Tensor],
+        predict_with_offset: bool = True,
+        return_loss: bool = True,
+        sampling_temperature: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, float]]:
+        bin_probs = torch.softmax(bin_logits / sampling_temperature, dim=-1)
+        N, T, choices = bin_probs.shape
+        # Sample from the multinomial distribution, one per row.
+        sampled_centers = einops.rearrange(
+            torch.multinomial(bin_probs.view(-1, choices), num_samples=1),
+            "(N T) 1 -> N T 1",
+            N=N,
+        )
+        flattened_action_offsets = einops.rearrange(
+            action_offsets, "N T (K A) -> (N T) K A", K=choices
+        )
+        sampled_offsets = flattened_action_offsets[
+            torch.arange(flattened_action_offsets.shape[0]), sampled_centers.flatten()
+        ].view(N, T, -1)
+        centers = cluster_centers[sampled_centers.flatten()].view(N, T, -1)
+        a_hat = centers + (sampled_offsets if predict_with_offset else 0.0)
+        if not return_loss:
+            return a_hat, None, {}
+        # We are in training, so figure out the loss for the actions.
+        # First, we need to find the closest cluster center for each action.
+        true_action_bins = self._find_closest_cluster(true_actions, cluster_centers)
+        true_offsets = true_actions - cluster_centers[true_action_bins]
+        predicted_offsets = flattened_action_offsets[
+            torch.arange(flattened_action_offsets.shape[0]), true_action_bins.flatten()
+        ].view(N, T, -1)
+        offset_loss = self._offset_criterion(predicted_offsets, true_offsets)
+        cbet_loss = self._criterion(
+            einops.rearrange(bin_logits, "N T D -> (N T) D"),
+            einops.rearrange(true_action_bins, "N T -> (N T)"),
+        )
+        if is_padded_action_seq is not None:
+            cbet_loss *= ~is_padded_action_seq.view(-1)
+            offset_loss *= ~is_padded_action_seq.unsqueeze(-1)
+        cbet_loss, offset_loss = cbet_loss.mean(), offset_loss.mean()
+        action_mse = F.mse_loss(a_hat, true_actions, reduction="none")
+        action_l1 = F.l1_loss(a_hat, true_actions, reduction="none")
+        norm = torch.norm(true_actions, p=2, dim=-1, keepdim=True) + 1e-9
+        normalized_mse = (action_mse / norm).mean()
+        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
+        if self._current_steps < self._kmeans_fit_steps:
+            # Zero the gradient contribution (but keep the value) until k-means is fit.
+            loss = loss.detach() + (loss * 0.0)
+        loss_dict = {
+            "classification_loss": cbet_loss.detach().cpu().item(),
+            "offset_loss": offset_loss.detach().cpu().item(),
+            "loss": loss.detach().cpu().item(),
+            "L2_loss": action_mse.mean().detach().cpu().item(),
+            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
+            "L1_loss": action_l1.mean().detach().cpu().item(),
+        }
+        return a_hat, loss, loss_dict
+
+    def forward(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: Optional[torch.Tensor],
+        padding_seq: Optional[torch.Tensor],
+        predict_with_offset: bool = True,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, float]]:
+        if self._current_steps == 0:
+            self._cluster_centers = self._cluster_centers.to(obs_seq.device)
+        if self._current_steps < self._kmeans_fit_steps and (
+            action_seq is not None and padding_seq is not None
+        ):
+            self._current_steps += 1
+            self._fit_kmeans(action_seq, padding_seq)
+        return self._predict(
+            obs_seq,
+            goal_seq,
+            action_seq,
+            padding_seq,
+            predict_with_offset=predict_with_offset,
+        )
+
+    def _fit_kmeans(
+        self,
+        action_seq: Optional[torch.Tensor],
+        padding_seq: Optional[torch.Tensor],
+    ) -> None:
+        assert self._current_steps <= self._kmeans_fit_steps
+        if self._current_steps == 1:
+            self._cluster_centers = self._cluster_centers.to(action_seq.device)
+
+        all_action_seq = self._accelerator.gather(action_seq)
+        all_padding_seq = self._accelerator.gather(padding_seq)
+        self._collected_actions.append(
+            all_action_seq[torch.logical_not(all_padding_seq)]
+        )
+        if self._current_steps == self._kmeans_fit_steps:
+            logging.info("Fitting KMeans")
+            self._collected_actions = torch.cat(self._collected_actions, dim=0)
+            for i, (start, end) in self._start_and_ends:
+                clustering_algo = self._clustering_algos[i]
+                logging.info(f"Fitting KMeans for action {i}")
+                clustering_algo.fit(
+                    self._collected_actions[:, start:end].view(-1, end - start)
+                )
+                self._cluster_centers[i] = clustering_algo.bin_centers.float().to(
+                    action_seq.device
+                )
+            self._have_fit_kmeans = True
+
+    def _predict(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: torch.Tensor,
+        is_padded_action_seq: Optional[torch.Tensor],
+        predict_with_offset: bool = True,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, float]]:
+        batch_size, obs_T, c, h, w = obs_seq.shape
+        obs_seq_and_encoding = obs_seq + self._obs_embedding_token[None, None, ...]
+        obs_seq_and_encoding = einops.rearrange(
+            obs_seq_and_encoding, "N T C H W -> N T (H W) C"
+        )
+        assert action_seq is not None
+        tokenized_action_seq = self._tokenize_actions(action_seq)
+        tokenized_action_seq_and_encoding = (
+            tokenized_action_seq + self._action_embedding_token[None, None, ...]
+        )
+        interspersed_seq = torch.cat(
+            [obs_seq_and_encoding, tokenized_action_seq_and_encoding[:, :obs_T, ...]],
+            dim=2,
+        )  # N T (H W + A) C
+        interspersed_seq_flattened = einops.rearrange(
+            interspersed_seq, "N T (HWA) C -> N (T HWA) C"
+        )
+        extra_actions_tokenized_plus_embedding = (
+            tokenized_action_seq_and_encoding[:, obs_T:, ...]
+            + self._extra_action_token[None, None, ...]
+        )
+        extra_actions_flattened = einops.rearrange(
+            extra_actions_tokenized_plus_embedding, "N T A D -> N (T A) D"
+        )
+        interspersed_seq_and_extra_actions = torch.cat(
+            [
+                interspersed_seq_flattened,
+                einops.repeat(self._end_of_obs_token, "D -> N 1 D", N=batch_size),
+                extra_actions_flattened,
+            ],
+            dim=1,
+        )
+
+        # Assume dimensions are N T D for N sequences of T timesteps with dimension D.
+        if self._cbet_method == self.GOAL_SPEC.unconditional:
+            gpt_input = interspersed_seq_and_extra_actions
+        elif self._cbet_method == self.GOAL_SPEC.concat:
+            goal_seq_encoded = goal_seq + self._goal_embedding_token[None, None, ...]
+            gpt_input = torch.cat(
+                [goal_seq_encoded, interspersed_seq_and_extra_actions], dim=1
+            )
+        elif self._cbet_method == self.GOAL_SPEC.stack:
+            goal_seq_embedded = (
+                goal_seq[:, None, :] + self._goal_embedding_token[None, None, ...]
+            )
+            goal_seq_encoded = self._goal_encoder(goal_seq_embedded)
+            gpt_input = torch.cat(
+                [goal_seq_encoded, interspersed_seq_and_extra_actions], dim=1
+            )
+        else:
+            raise NotImplementedError
+
+        gpt_output = self._gpt_model(gpt_input)
+        if self._cbet_method == self.GOAL_SPEC.concat:
+            # Chop off the goal encodings.
+            gpt_output = gpt_output[:, goal_seq.size(1) :, :]
+        elif self._cbet_method == self.GOAL_SPEC.stack:
+            gpt_output = gpt_output[:, 1:]
+
+        # Here we have a sequence of shape (N, (T*(H*W + A) + 1 + T'*A), D),
+        # where T' is the number of extra actions we want to predict.
+        # Separate out the original and the extra actions.
+        # We have H*W + A tokens per obs, + 1 for the end-of-obs token.
+        extra_action_output_flat = gpt_output[
+            :, obs_T * (self._h * self._w + self._n_subactions) : -1, :
+        ]  # N (T' A) D
+        extra_action_output = einops.rearrange(
+            extra_action_output_flat, "N (T A) D -> N T A D", A=self._n_subactions
+        )  # N T' A D
+        original_output_tokens = gpt_output[
+            :, : obs_T * (self._h * self._w + self._n_subactions), :
+        ]
+        original_output_tokens_reshaped = einops.rearrange(
+            original_output_tokens, "N (T HWA) D -> N T HWA D", T=obs_T
+        )
+        original_action_tokens = original_output_tokens_reshaped[
+            :, :, -self._n_subactions - 1 : -1, :
+        ]  # N T A D
+        output_action_tokens = torch.cat(
+            [original_action_tokens, extra_action_output], dim=1
+        )  # N (T + T') A D
+        action_bin_logits, action_offsets = self._detokenize_actions(
+            output_action_tokens
+        )
+        # Now calculate the predicted actions and the losses.
+        a_hat, loss, loss_dict = torch.zeros_like(action_seq), None, {}
+        for i, (start, end) in self._start_and_ends:
+            a_hat_i, loss_i, loss_dict_i = self._calculate_cbet_preds_and_loss(
+                self._cluster_centers[i],
+                action_bin_logits[i],
+                action_offsets[i],
+                action_seq[..., start:end],
+                # TODO: make is_padded_action_seq more flexible; in reality,
+                # part of the action could be "padded".
+                is_padded_action_seq,
+                predict_with_offset=predict_with_offset,
+                return_loss=True,
+                sampling_temperature=self.sampling_temperature,
+            )
+            a_hat[..., start:end] = a_hat_i
+            if i == 0:
+                loss = loss_i
+                for k, v in loss_dict_i.items():
+                    loss_dict[k] = v
+            else:
+                loss += loss_i
+                for k, v in loss_dict_i.items():
+                    loss_dict[k] += v
+            # Also log the per-sub-action losses.
+            for k, v in loss_dict_i.items():
+                loss_dict[f"{k}_{i}"] = v
+        return a_hat, loss, loss_dict
+
+    def _find_closest_cluster(
+        self, action_seq: torch.Tensor, cluster_centers: torch.Tensor
+    ) -> torch.Tensor:
+        N, T, _ = action_seq.shape
+        flattened_actions = einops.rearrange(action_seq, "N T A -> (N T) A")
+        cluster_center_distance = torch.sum(
+            (flattened_actions[:, None, :] - cluster_centers[None, :, :]) ** 2,
+            dim=2,
+        )  # (N T) K A -> (N T) K
+        closest_cluster_center = torch.argmin(cluster_center_distance, dim=1)  # (N T)
+        discretized_action = einops.rearrange(
+            closest_cluster_center, "(N T) -> N T", N=N, T=T
+        )
+        return discretized_action
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas):
+        optimizer = self._gpt_model.configure_optimizers(
+            weight_decay=weight_decay,
+            learning_rate=learning_rate,
+            betas=betas,
+        )
+        optimizer.add_param_group(
+            {
+                "params": chain(
+                    self._map_to_cbet_preds.parameters(),
+                    self._action_tokenizer.parameters(),
+                )
+            }
+        )
+        optimizer.add_param_group(
+            {
+                "params": [
+                    self._obs_embedding_token,
+                    self._goal_embedding_token,
+                    self._action_embedding_token,
+                    self._extra_action_token,
+                    self._end_of_obs_token,
+                ],
+                "weight_decay": 0.0,
+            }
+        )
+        return optimizer
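For orientation, here is a minimal sketch of how the module above might be exercised. The real GPT class in baselines/rum/models/bet/gpt.py is not part of this excerpt, so a hypothetical stub stands in for it (TokenizedBehaviorTransformer only needs a module that maps (N, L, C) token sequences to (N, L, C) and exposes a config.output_dim attribute); the tensor shapes are read off _predict above, and the tiny kmeans_fit_steps and n_clusters values are purely for illustration.

import types

import torch

from baselines.rum.models.bet.tokenized_bet import TokenizedBehaviorTransformer


class StubGPT(torch.nn.Module):
    # Hypothetical stand-in for baselines.rum.models.bet.gpt.GPT (not shown in
    # this diff): it echoes its (N, L, C) input and exposes the
    # `config.output_dim` attribute that the constructor above reads.
    def __init__(self, dim: int):
        super().__init__()
        self.config = types.SimpleNamespace(output_dim=dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


N, T, C, H, W = 4, 3, 32, 7, 7  # batch, obs timesteps, channels, 7x7 patch grid
extra = 2                       # extra future actions appended to action_seq

model = TokenizedBehaviorTransformer(
    obs_dim=C,
    act_dim=7,
    goal_dim=0,                 # goal_dim <= 0 selects GOAL_SPEC.unconditional
    gpt_model=StubGPT(C),
    action_spec=[6, 1],         # two sub-actions, tokenized and clustered separately
    num_extra_predicted_actions=extra,
    n_clusters=8,
    kmeans_fit_steps=1,         # tiny value purely for this sketch
)

obs_seq = torch.randn(N, T, C, H, W)       # per-patch visual features
action_seq = torch.randn(N, T + extra, 7)  # T observed + `extra` future actions
padding_seq = torch.zeros(N, T + extra, dtype=torch.bool)

# Calls during the first `kmeans_fit_steps` steps also collect actions and fit
# the per-sub-action k-means discretizers that define the action bins.
a_hat, loss, loss_dict = model(
    obs_seq, goal_seq=None, action_seq=action_seq, padding_seq=padding_seq
)
print(a_hat.shape)  # torch.Size([4, 5, 7]): one action per observed + extra step

Note that the slice `-self._n_subactions - 1 : -1` in _predict is the causal shift: the GPT output at the position just before each action token is what predicts that token, so the read positions sit one step to the left of the action tokens themselves.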
baselines/rum/models/bet/utils.py
@@ -0,0 +1,124 @@
+import torch
+from typing import Callable, List, Optional
+
+
+class MLP(torch.nn.Sequential):
+    """This block implements the multi-layer perceptron (MLP) module.
+
+    Adapted for backward compatibility from the torchvision library:
+    https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html
+
+    LICENSE:
+
+    From PyTorch:
+
+    Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+    Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+    Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+    Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+    Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+    Copyright (c) 2011-2013 NYU (Clement Farabet)
+    Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+    Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+    Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+    From Caffe2:
+
+    Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+    All contributions by Facebook:
+    Copyright (c) 2016 Facebook Inc.
+
+    All contributions by Google:
+    Copyright (c) 2015 Google Inc.
+    All rights reserved.
+
+    All contributions by Yangqing Jia:
+    Copyright (c) 2015 Yangqing Jia
+    All rights reserved.
+
+    All contributions by Kakao Brain:
+    Copyright 2019-2020 Kakao Brain
+
+    All contributions by Cruise LLC:
+    Copyright (c) 2022 Cruise LLC.
+    All rights reserved.
+
+    All contributions from Caffe:
+    Copyright(c) 2013, 2014, 2015, the respective contributors
+    All rights reserved.
+
+    All other contributions:
+    Copyright(c) 2015, 2016 the respective contributors
+    All rights reserved.
+
+    Caffe2 uses a copyright model similar to Caffe: each contributor holds
+    copyright over their contributions to Caffe2. The project versioning records
+    all such contribution and copyright details. If a contributor wants to further
+    mark their specific copyright on a particular contribution, they should
+    indicate their copyright solely in the commit message of the change when it is
+    committed.
+
+    All rights reserved.
+
+    Redistribution and use in source and binary forms, with or without
+    modification, are permitted provided that the following conditions are met:
+
+    1. Redistributions of source code must retain the above copyright
+       notice, this list of conditions and the following disclaimer.
+
+    2. Redistributions in binary form must reproduce the above copyright
+       notice, this list of conditions and the following disclaimer in the
+       documentation and/or other materials provided with the distribution.
+
+    3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+       and IDIAP Research Institute nor the names of its contributors may be
+       used to endorse or promote products derived from this software without
+       specific prior written permission.
+
+    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+    ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+    LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+    POSSIBILITY OF SUCH DAMAGE.
+
+
+    Args:
+        in_channels (int): Number of channels of the input
+        hidden_channels (List[int]): List of the hidden channel dimensions
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function stacked on top of each hidden linear layer. Default: ``torch.nn.ReLU``
+        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
+            Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
+        bias (bool): Whether to use bias in the linear layer. Default ``True``
+        dropout (float): The probability for the dropout layer. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: List[int],
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        inplace: Optional[bool] = None,
+        bias: bool = True,
+        dropout: float = 0.0,
+    ):
+        params = {} if inplace is None else {"inplace": inplace}
+
+        layers = []
+        in_dim = in_channels
+        for hidden_dim in hidden_channels[:-1]:
+            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
+            layers.append(activation_layer(**params))
+            layers.append(torch.nn.Dropout(dropout, **params))
+            in_dim = hidden_dim
+
+        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
+        layers.append(torch.nn.Dropout(dropout, **params))
+
+        super().__init__(*layers)
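A quick usage note on the adaptation above: as in the torchvision original, `hidden_channels` includes the output dimension, and a `Dropout` (inactive at the default `p=0.0`) follows every layer, including the final `Linear`. A small sketch:

import torch

from baselines.rum.models.bet.utils import MLP

# 64 -> 128 -> 128 -> 10: two hidden Linear+ReLU(+Dropout) blocks, then a
# final Linear, itself followed by a Dropout (a no-op at the default p=0.0).
mlp = MLP(in_channels=64, hidden_channels=[128, 128, 10])
y = mlp(torch.randn(32, 64))
assert y.shape == (32, 10)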