egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Dict, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import einops
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
import tqdm
|
|
10
|
+
|
|
11
|
+
from baselines.rum.models.bet.gpt import GPT
|
|
12
|
+
from baselines.rum.models.bet.utils import MLP
|
|
13
|
+
from baselines.rum.models.bet.vqvae.vqvae import VqVae
|
|
14
|
+
|
|
15
|
+
GENERATOR_SEED_FIXED = 123456789
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class VQBehaviorTransformer(nn.Module):
    """VQ-BeT head: a GPT trunk followed by code-classification and offset MLPs.

    The trunk output for every timestep is mapped to (a) logits over the
    residual VQ-VAE codes and (b) a continuous offset for every (group, code)
    pair. A code is sampled per group, decoded by the VQ-VAE into a coarse
    action window, and the offset selected by the code is added on top.

    NOTE(review): the loss code indexes code groups 0 and 1 only, so it
    presumably expects a VQ-VAE with exactly two groups - confirm.
    """

    GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional")

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        goal_dim: int,
        gpt_model: GPT,
        vqvae_model: VqVae,
        offset_loss_multiplier: float = 1.0e2,
        secondary_code_multiplier: float = 0.5,
        gamma: float = 2.0,
        obs_window_size=10,
        act_window_size=10,
        sequentially_select=False,
        use_og_bet_loss=False,
        use_half_and_half_loss=False,
        temperature=1.0,
        device="cuda",
    ):
        """Build the classification / offset heads around the given models.

        Args:
            obs_dim: Observation feature size (stored only).
            act_dim: Action feature size; sizes the offset head output.
            goal_dim: Goal feature size. ``<= 0`` selects the unconditional
                mode; any positive value selects the ``stack`` mode.
            gpt_model: Trunk network; must expose ``config.output_dim``.
            vqvae_model: Residual VQ-VAE; must expose ``vqvae_groups``,
                ``vqvae_n_embed``, ``embedding_dim``, ``input_dim_h`` and the
                ``get_code`` / ``draw_code_forward`` /
                ``get_action_from_latent`` methods.
            offset_loss_multiplier: Weight of the offset L1 term in the loss.
            secondary_code_multiplier: Weight of the second code group's
                classification loss (the first group is weighted by 5).
            gamma: Focusing parameter passed to ``FocalLoss``.
            obs_window_size: Shorter observation windows are left-padded up
                to this length.
            act_window_size: Predicted action window; sizes the offset head.
            sequentially_select: If true, the second code is predicted from
                the trunk output concatenated with a one-hot of the first.
            use_og_bet_loss: If true and actions are given, offsets are read
                at the ground-truth codes instead of the sampled ones.
            use_half_and_half_loss: If true, the offset loss averages the L1
                error at the ground-truth codes and at the sampled codes.
            temperature: Softmax temperature used when sampling codes.
            device: Stored as ``self.device``.
        """
        super().__init__()
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        self._goal_dim = goal_dim
        self.obs_window_size = obs_window_size
        self.act_window_size = act_window_size
        self.sequentially_select = sequentially_select
        self._use_og_bet_loss = use_og_bet_loss
        self._use_half_and_half_loss = use_half_and_half_loss
        self.temperature = temperature
        self.device = device

        if goal_dim <= 0:
            self._cbet_method = self.GOAL_SPEC.unconditional
        # TODO (haritheja): this is a temporary fix, we should be able to
        # handle different types of goals like concat or stack
        # (previously: obs_dim == goal_dim selected GOAL_SPEC.concat).
        else:
            self._cbet_method = self.GOAL_SPEC.stack

        self._gpt_model = gpt_model
        self._vqvae_model = vqvae_model
        self._G = self._vqvae_model.vqvae_groups  # G(number of groups)
        self._C = self._vqvae_model.vqvae_n_embed  # C(number of code integers)
        self._D = self._vqvae_model.embedding_dim  # D(embedding dims)
        self._current_steps = 0
        if self.sequentially_select:
            self._map_to_cbet_preds_bin1 = MLP(
                in_channels=gpt_model.config.output_dim,
                hidden_channels=[512, 512, self._C],
            )
            # The second head also sees a one-hot of the first code (+ C inputs).
            self._map_to_cbet_preds_bin2 = MLP(
                in_channels=gpt_model.config.output_dim + self._C,
                hidden_channels=[512, self._C],
            )
        else:
            self._map_to_cbet_preds_bin = MLP(
                in_channels=gpt_model.config.output_dim,
                hidden_channels=[1024, 1024, self._G * self._C],
            )
        # One offset vector of size (act window * act dim) per (group, code).
        self._map_to_cbet_preds_offset = MLP(
            in_channels=gpt_model.config.output_dim,
            hidden_channels=[
                1024,
                1024,
                self._G * self._C * (act_dim * self.act_window_size),
            ],
        )
        self._offset_loss_multiplier = offset_loss_multiplier
        self._secondary_code_multiplier = secondary_code_multiplier
        self._criterion = FocalLoss(gamma=gamma)

    def forward(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        second_half: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, object]]:
        """Run ``_predict``; returns ``(a_hat, loss, loss_dict)``.

        VQ-BeT doesn't use "padding_seq" and "predict_with_offset" input.
        """
        return self._predict(obs_seq, goal_seq, action_seq, second_half)

    @staticmethod
    def _left_pad_to_window(seq: torch.Tensor, window: int) -> torch.Tensor:
        """Prepend copies of the first timestep until ``seq`` has ``window`` steps.

        ``seq`` is indexed as (batch, time, feature). Sequences that already
        have at least ``window`` steps are returned unchanged.
        """
        missing = window - seq.shape[1]
        if missing <= 0:
            return seq
        # Keep the time axis (``:1``) so the padding is built per batch row;
        # slicing with ``0`` would drop it and mix rows for batch sizes > 1.
        first_step = seq[:, :1, :].expand(-1, missing, -1)
        return torch.cat((first_step, seq), dim=-2)

    @staticmethod
    def _sliding_action_windows(action_seq: torch.Tensor, act_w: int) -> torch.Tensor:
        """Return every length-``act_w`` window of ``action_seq``.

        Input is (N, total_w, A); output is (N, total_w + 1 - act_w, act_w, A),
        allocated on the input's device with the default dtype.
        """
        n, total_w, act_dim = action_seq.shape
        obs_w = total_w + 1 - act_w
        windows = torch.empty((n, obs_w, act_w, act_dim), device=action_seq.device)
        for start in range(obs_w):
            windows[:, start, :, :] = action_seq[:, start : start + act_w, :]
        return windows

    def _predict(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        second_half: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, object]]:
        """Predict action windows and, when actions are given, the training loss.

        Assume dimensions are N T D for N sequences of T timesteps with dimension D.

        obs_seq : (Batch size) X (obs window) X (obs dim)
        goal_seq : (Batch size) X (goal window) X (obs dim)
        action_seq : (Batch size) X (total window) X (action dim)

        assume that
        total_w (total window) of action sequence is
        total_w = obs_w + act_w - 1
        e.g. if the observation window is 5, and the action pred window is 3,
        the obs seq (o_{t-4}, o_{t-3}, o_{t-2}, o_{t-1}, o_{t}) will predict (a_{t}, a_{t+1}, a_{t+2})

        However, we not only predict (a_{t}, a_{t+1}, a_{t+2}), but also all the action seq for all the actions in the sequence

        o_{t-4} -> |   | -> (a_{t-4}, a_{t-3}, a_{t-2})
        o_{t-3} -> | G | -> (a_{t-3}, a_{t-2}, a_{t-1})
        o_{t-2} -> | P | -> (a_{t-2}, a_{t-1}, a_{t})
        o_{t-1} -> | T | -> (a_{t-1}, a_{t}, a_{t+1})
        o_{t}   -> |   | -> (a_{t}, a_{t+1}, a_{t+2})

        ``second_half`` is accepted for interface compatibility and is
        currently unused.

        Returns:
            ``(a_hat, loss, loss_dict)``. Without ``action_seq`` the loss is
            ``None`` and the dict is empty. Some ``loss_dict`` entries (the
            code-match rates and the translation / rotation / gripper
            losses) are detached tensors rather than Python floats.
        """
        if obs_seq.shape[1] < self.obs_window_size:
            # If the input is shorter than obs_window_size (e.g. the initial
            # steps of env eval episodes), VQ-BeT copies the first step and
            # tiles it to match obs_window_size.
            obs_seq = self._left_pad_to_window(obs_seq, self.obs_window_size)
            if self._cbet_method == self.GOAL_SPEC.stack:
                goal_seq = self._left_pad_to_window(goal_seq, self.obs_window_size)

        if self._cbet_method == self.GOAL_SPEC.unconditional:
            gpt_input = obs_seq
        elif self._cbet_method == self.GOAL_SPEC.concat:
            gpt_input = torch.cat([goal_seq, obs_seq], dim=1)
        elif self._cbet_method == self.GOAL_SPEC.stack:
            gpt_input = torch.cat([goal_seq, obs_seq], dim=-1)
        else:
            raise NotImplementedError

        gpt_output = self._gpt_model(gpt_input)
        if self._cbet_method == self.GOAL_SPEC.concat:
            # Chop off the goal encodings.
            gpt_output = gpt_output[:, goal_seq.size(1) :, :]
        gpt_output = einops.rearrange(gpt_output, "N T (G C) -> (N T) (G C)", G=self._G)
        obs = einops.rearrange(obs_seq, "N T O -> (N T) O")
        obs = obs.unsqueeze(dim=1)

        # note that output of offset network is G C WA,
        # where G is number of 'layers' of Residual VQ-VAE
        # C is number of words in each layer's dictionary
        # and W, A is predicted action window, and predicted action dims
        if self.sequentially_select:
            cbet_logits1 = self._map_to_cbet_preds_bin1(gpt_output)
            cbet_offsets = self._map_to_cbet_preds_offset(gpt_output)
            cbet_offsets = einops.rearrange(
                cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C
            )
            cbet_probs1 = torch.softmax(cbet_logits1 / self.temperature, dim=-1)
            NT, choices = cbet_probs1.shape
            G = self._G
            sampled_centers1 = einops.rearrange(
                torch.multinomial(cbet_probs1.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            # Second code is conditioned on the sampled first code.
            cbet_logits2 = self._map_to_cbet_preds_bin2(
                torch.cat(
                    (gpt_output, F.one_hot(sampled_centers1, num_classes=self._C)),
                    axis=1,
                )
            )
            cbet_probs2 = torch.softmax(cbet_logits2 / self.temperature, dim=-1)
            sampled_centers2 = einops.rearrange(
                torch.multinomial(cbet_probs2.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack(
                (sampled_centers1, sampled_centers2), axis=1
            )  # NT, G
        else:
            cbet_logits = self._map_to_cbet_preds_bin(gpt_output)
            cbet_offsets = self._map_to_cbet_preds_offset(gpt_output)
            cbet_logits = einops.rearrange(
                cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G
            )
            cbet_offsets = einops.rearrange(
                cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C
            )
            cbet_probs = torch.softmax(cbet_logits / self.temperature, dim=-1)
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
                "(NT G) 1 -> NT G",
                NT=NT,
            )

        if action_seq is not None:
            # Unroll the action sequence into one target window per obs step
            # and look up the ground-truth VQ codes for each window.
            action_windows = self._sliding_action_windows(
                action_seq, self._vqvae_model.input_dim_h
            )
            action_seq = einops.rearrange(action_windows, "N T W A -> (N T) W A")
            _, action_bins = self._vqvae_model.get_code(
                action_seq, obs
            )  # action_bins: NT, G

        # The coarse (decoded) action is treated as a constant: no gradient
        # flows into the VQ-VAE through it.
        with torch.no_grad():
            centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
                NT, -1, self._D
            )
            return_decoder_input = einops.rearrange(
                centers.clone().detach(), "NT G D -> NT (G D)"
            )
            decoded_action = (
                self._vqvae_model.get_action_from_latent(return_decoder_input, obs)
                .clone()
                .detach()
            )

        def get_offset_from_centers(centers):
            """Pick each group's offset at the given codes and sum over groups."""
            # Build the index grids on the tensor's own device rather than on
            # the constructor's ``device`` string, which may be stale.
            index_device = cbet_offsets.device
            indices = (
                torch.arange(NT, device=index_device).unsqueeze(1),
                torch.arange(self._G, device=index_device).unsqueeze(0),
                centers,
            )
            # Use advanced indexing to sample the values
            sampled_offsets = cbet_offsets[indices]  # NT, G, WA
            sampled_offsets = sampled_offsets.sum(dim=1)
            sampled_offsets = einops.rearrange(
                sampled_offsets, "NT (W A) -> NT W A", W=self._vqvae_model.input_dim_h
            )
            return sampled_offsets

        if self._use_og_bet_loss and action_seq is not None:
            sampled_offsets = get_offset_from_centers(action_bins)
        else:
            sampled_offsets = get_offset_from_centers(sampled_centers)

        a_hat = decoded_action + sampled_offsets

        if action_seq is None:
            return a_hat, None, {}

        # Offset loss: the offset head should explain what the decoded
        # coarse action misses.
        offset_target = action_seq - decoded_action
        if self._use_half_and_half_loss:
            offset_gt = get_offset_from_centers(action_bins)
            offset_sampled = get_offset_from_centers(sampled_centers)
            offset_loss = 0.5 * torch.nn.L1Loss()(
                offset_target, offset_gt
            ) + 0.5 * torch.nn.L1Loss()(offset_target, offset_sampled)
        else:
            offset_loss = torch.nn.L1Loss()(offset_target, sampled_offsets)

        # Classification loss on the code indices (focal loss per group).
        if self.sequentially_select:
            cbet_loss1 = self._criterion(
                cbet_logits1[:, :],
                action_bins[:, 0],
            )
            # Teacher forcing: condition the second head on the true first code.
            cbet_logits2 = self._map_to_cbet_preds_bin2(
                torch.cat(
                    (gpt_output, F.one_hot(action_bins[:, 0], num_classes=self._C)),
                    axis=1,
                )
            )
            cbet_loss2 = self._criterion(
                cbet_logits2[:, :],
                action_bins[:, 1],
            )
        else:
            cbet_loss1 = self._criterion(
                cbet_logits[:, 0, :],
                action_bins[:, 0],
            )
            cbet_loss2 = self._criterion(
                cbet_logits[:, 1, :],
                action_bins[:, 1],
            )
        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self._secondary_code_multiplier

        # Fraction of rows whose sampled codes match the ground-truth codes.
        equal_total_code_rate = (
            torch.sum(
                (torch.sum((action_bins == sampled_centers).int(), axis=1) == G).int()
            )
            / NT
        )
        equal_primary_code_rate = torch.sum(
            (action_bins[:, 0] == sampled_centers[:, 0]).int()
        ) / (NT)
        equal_secondary_code_rate = torch.sum(
            (action_bins[:, 1] == sampled_centers[:, 1]).int()
        ) / (NT)

        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
        action_mse = F.mse_loss(a_hat, action_seq, reduction="none")
        action_l1 = F.l1_loss(a_hat, action_seq, reduction="none")
        norm = torch.norm(action_seq, p=2, dim=-1, keepdim=True) + 1e-9
        normalized_mse = (action_mse / norm).mean()

        # Per-component diagnostics over action dims [0:3], [3:6] and [6:].
        translation_loss = F.mse_loss(a_hat[:, :, :3], action_seq[:, :, :3]).detach()
        rotation_loss = F.mse_loss(a_hat[:, :, 3:6], action_seq[:, :, 3:6]).detach()
        gripper_loss = F.mse_loss(a_hat[:, :, 6:], action_seq[:, :, 6:]).detach()

        loss_dict = {
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "loss": loss.detach().cpu().item(),
            "equal_total_code_rate": equal_total_code_rate,
            "equal_primary_code_rate": equal_primary_code_rate,
            "equal_secondary_code_rate": equal_secondary_code_rate,
            "L2_loss": action_mse.mean().detach().cpu().item(),
            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
            "L1_loss": action_l1.mean().detach().cpu().item(),
            "translation_loss": translation_loss,
            "rotation_loss": rotation_loss,
            "gripper_loss": gripper_loss,
        }
        return a_hat, loss, loss_dict

    def _begin_epoch(self, optimizer, **kwargs):
        """Epoch-start hook; currently a no-op that returns ``None``."""
        return None

    def _load_from_state_dict(self, *args, **kwargs):
        """Defer to ``nn.Module._load_from_state_dict`` unchanged."""
        # Don't fit kmeans if we are loading from a state dict.
        return super()._load_from_state_dict(*args, **kwargs)
|
|
381
|
+
|
|
382
|
+
# def load_model(self, path):
|
|
383
|
+
# if (path / "cbet_model.pt").exists():
|
|
384
|
+
# self.load_state_dict(torch.load(path / "cbet_model.pt"))
|
|
385
|
+
# elif (path / "gpt_model.pt").exists():
|
|
386
|
+
# self._gpt_model.load_state_dict(torch.load(path / "gpt_model.pt"))
|
|
387
|
+
# else:
|
|
388
|
+
# logging.warning("No model found at %s", path)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class FocalLoss(nn.Module):
    """Focal loss over class logits: ``-(1 - p_t) ** gamma * log(p_t)``.

    With ``gamma == 0`` this reduces to ordinary cross-entropy.
    """

    def __init__(self, gamma: float = 0, reduction: str = "mean"):
        """
        Args:
            gamma: Focusing exponent applied to ``(1 - p_t)``.
            reduction: ``"mean"``, ``"sum"`` or ``"none"``.

        Raises:
            NotImplementedError: If ``reduction`` is not one of the above.
        """
        super().__init__()
        self.gamma = gamma
        if reduction not in ("mean", "sum", "none"):
            raise NotImplementedError(f"unsupported reduction: {reduction!r}")
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Logits whose last dimension indexes the classes; any
                number of leading dimensions is accepted.
            target: Integer class indices with one entry per logit row.

        Returns:
            A scalar for ``"mean"`` / ``"sum"``; for ``"none"`` a flat 1-D
            tensor with one loss value per logit row.
        """
        # Flatten leading dims so the class axis is always dim 1; the
        # softmax is over the last dim, so the gather must match it.
        flat_logits = input.reshape(-1, input.shape[-1])
        logpt = F.log_softmax(flat_logits, dim=-1)
        logpt = logpt.gather(1, target.reshape(-1, 1)).view(-1)
        pt = logpt.exp()

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
|