egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,881 @@
|
|
|
1
|
+
from typing import Tuple, Sequence, Dict, Union, Optional
|
|
2
|
+
import numpy as np
|
|
3
|
+
import math
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
|
7
|
+
from diffusers.training_utils import EMAModel
|
|
8
|
+
from diffusers.optimization import get_scheduler
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import einops
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
|
|
14
|
+
# NOTE(review): presumably a fixed seed for a random generator; it is not
# referenced anywhere in the portion of the file visible here - confirm usage.
GENERATOR_SEED_FIXED = 123456789
|
|
15
|
+
|
|
16
|
+
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalar positions.

    For an input ``x`` of shape ``(B,)`` the output has shape
    ``(B, 2 * (dim // 2))``: the first half holds ``sin(x * f_i)`` and the
    second half ``cos(x * f_i)``, with frequencies ``f_i`` spaced
    geometrically from 1 down to 1/10000. For odd ``dim`` the output is
    therefore one column narrower than ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (shape ``(B,)``) on the same device as ``x``."""
        device = x.device
        half_dim = self.dim // 2
        # With dim in {2, 3}, half_dim is 1 and the plain ``half_dim - 1``
        # denominator would raise ZeroDivisionError; clamp it to 1 so a
        # single-frequency embedding (sin(x), cos(x)) is produced instead.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # Outer product: (B, 1) * (1, half_dim) -> (B, half_dim)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Downsample1d(nn.Module):
    """Channel-preserving strided convolution (kernel 3, stride 2, padding 1).

    Applied to a ``(B, dim, T)`` tensor it yields ``(B, dim, T // 2)`` for
    even ``T``.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled
|
|
38
|
+
|
|
39
|
+
class Upsample1d(nn.Module):
    """Channel-preserving transposed convolution (kernel 4, stride 2, padding 1).

    Applied to a ``(B, dim, T)`` tensor it yields ``(B, dim, 2 * T)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Conv1dBlock(nn.Module):
    '''
        Length-preserving (for odd kernel sizes) building block:
        Conv1d --> GroupNorm --> Mish

        ``n_groups`` must divide ``out_channels`` (GroupNorm requirement).
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        # "same"-style padding: half the kernel on each side
        conv = nn.Conv1d(
            inp_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        norm = nn.GroupNorm(n_groups, out_channels)
        activation = nn.Mish()
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        features = self.block(x)
        return features
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ConditionalResidualBlock1D(nn.Module):
    '''
        Residual block of two Conv1dBlocks with FiLM conditioning applied
        between them (https://arxiv.org/abs/1709.07871).
    '''

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first_conv = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_conv = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_conv, second_conv])

        self.out_channels = out_channels

        # The conditioning vector is mapped to one scale and one bias per
        # output channel, hence twice ``out_channels`` outputs.
        film_width = 2 * out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, film_width),
            nn.Unflatten(-1, (-1, 1)),
        )

        # The skip path needs a 1x1 conv only when the channel count changes.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        # index 0 along dim 1 is the scale, index 1 the bias
        scale, bias = film[:, 0, ...], film[:, 1, ...]
        hidden = scale * hidden + bias

        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)
|
|
114
|
+
|
|
115
|
+
class ModuleAttrMixin(nn.Module):
    """nn.Module base that exposes ``device`` and ``dtype`` properties.

    An empty dummy parameter is registered so that ``parameters()`` always
    yields at least one entry; both properties read the first parameter.
    """

    def __init__(self):
        super().__init__()
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self):
        first_param = next(self.parameters())
        return first_param.dtype
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class TransformerForDiffusion(ModuleAttrMixin):
    """Transformer that maps a noisy sequence and a diffusion step to a
    per-token prediction.

    Two layouts are built depending on the constructor arguments:

    * encoder/decoder (``time_as_cond=True``): the diffusion-step embedding,
      optionally followed by embedded observations, forms the memory that a
      ``nn.TransformerDecoder`` cross-attends to.
    * encoder only (``time_as_cond=False``): the diffusion-step embedding is
      prepended to the input tokens and a ``nn.TransformerEncoder`` is used.
    """

    def __init__(self,
            input_dim: int,
            output_dim: int,
            horizon: int,
            n_obs_steps: Optional[int] = None,
            cond_dim: int = 0,
            n_layer: int = 12,
            n_head: int = 12,
            n_emb: int = 768,
            p_drop_emb: float = 0.1,
            p_drop_attn: float = 0.1,
            causal_attn: bool = False,
            time_as_cond: bool = True,
            obs_as_cond: bool = False,
            n_cond_layers: int = 0
        ) -> None:
        """
        input_dim: feature size of each input token.
        output_dim: feature size of each output token.
        horizon: number of input tokens.
        n_obs_steps: number of conditioning tokens; defaults to ``horizon``.
        cond_dim: feature size of each conditioning token; 0 disables
            observation conditioning.
        n_layer, n_head, n_emb: depth, heads and width of the main trunk.
        p_drop_emb, p_drop_attn: dropout on embeddings / inside the layers.
        causal_attn: build causal self-attention (and memory) masks.
        time_as_cond: feed the diffusion step through the condition path
            instead of prepending it to the input tokens.
        obs_as_cond: accepted for interface compatibility only; the value is
            overwritten by ``cond_dim > 0``.
        n_cond_layers: transformer layers in the condition encoder; 0 uses
            an MLP instead.
        """
        super().__init__()

        # compute number of tokens for main trunk and condition encoder
        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 1
        if not time_as_cond:
            # the time token moves from the condition path into the trunk
            T += 1
            T_cond -= 1
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond, "observation conditioning requires time_as_cond=True"
            T_cond += n_obs_steps

        # input embedding stem
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # cond encoder
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None

        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        self.cond_pos_emb = None
        self.encoder = None
        self.decoder = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4*n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )
            # decoder
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True # important for stability
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )
        else:
            # encoder only BERT
            encoder_only = True

            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # attention mask
        if causal_attn:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t, s = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij'
                )
                mask = t >= (s-1) # add one dimension since time is the first token in cond
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        # decoder head
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # constants
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        # init
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Per-module initializer used with ``self.apply``.

        Linear/Embedding/attention weights and the learned position
        embeddings are drawn from N(0, 0.02); biases are zeroed; LayerNorm is
        reset to identity. Any module type not listed raises RuntimeError so
        that newly added layers are not silently left at their defaults.
        """
        ignore_types = (nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential)
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = [
                'in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, TransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            # Gate on the parameter that is actually initialized. Checking
            # cond_obs_emb instead left cond_pos_emb all-zero whenever the
            # time token was the only condition (time_as_cond, cond_dim == 0).
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            # no param
            pass
        else:
            raise RuntimeError("Unaccounted module {}".format(module))

    def get_optim_groups(self, weight_decay: float=1e-3):
        """
        Split every parameter of the model into two buckets: those that get
        weight decay (Linear / MultiheadAttention weights) and those that do
        not (biases, LayerNorm / Embedding weights, position embeddings and
        the dummy parameter).

        Returns a list of two optimizer parameter-group dicts (decay first,
        no-decay second). Asserts that every parameter landed in exactly one
        bucket.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name

                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.startswith("bias"):
                    # MultiheadAttention bias starts with "bias"
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the root-level parameters as not decayed
        no_decay.add("pos_emb")
        no_decay.add("_dummy_variable")
        if self.cond_pos_emb is not None:
            no_decay.add("cond_pos_emb")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # sorted for a deterministic parameter order inside each group
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(decay)],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return optim_groups

    def configure_optimizers(self,
            learning_rate: float=1e-4,
            weight_decay: float=1e-3,
            betas: Tuple[float, float]=(0.9,0.95)):
        """Return an AdamW optimizer over the groups from ``get_optim_groups``."""
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas
        )
        return optimizer

    def forward(self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor]=None, **kwargs):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        cond: (B,T',cond_dim)
        output: (B,T,output_dim)

        Raises ValueError when the model was built with observation
        conditioning (cond_dim > 0) and ``cond`` is not supplied.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        else:
            # keep the step tensor on the sample's device for every tensor shape,
            # not just 0-dim ones
            timesteps = timesteps.to(sample.device)
            if len(timesteps.shape) == 0:
                timesteps = timesteps[None]
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)
        # (B,1,n_emb)

        # process input
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            # BERT
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T+1,n_emb)
            x = self.encoder(src=x, mask=self.mask)
            # (B,T+1,n_emb)
            x = x[:,1:,:]
            # (B,T,n_emb)
        else:
            # encoder
            cond_embeddings = time_emb
            if self.obs_as_cond:
                if cond is None:
                    raise ValueError(
                        "cond must be provided when the model is built with cond_dim > 0")
                cond_obs_emb = self.cond_obs_emb(cond)
                # (B,To,n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            tc = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[
                :, :tc, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x
            # (B,T_cond,n_emb)

            # decoder
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T,n_emb)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )
            # (B,T,n_emb)

        # head
        x = self.ln_f(x)
        x = self.head(x)
        # (B,T,n_out)
        return x
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class ConditionalUnet1D(nn.Module):
    """1-D U-Net over the time axis whose residual blocks are FiLM-conditioned
    on the diffusion-step embedding concatenated with a global condition."""

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level (any sequence of ints).
          The length of this sequence determines number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        last_level = len(in_out) - 1
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # the deepest level keeps its resolution
            is_last = ind >= last_level
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        # The up path has one level fewer than the down path (in_out[1:]), so
        # every up level upsamples: len(in_out) - 1 upsamples undo the
        # len(in_out) - 1 downsamples. An "is last" check against
        # len(in_out) - 1 can never be true here and is therefore omitted.
        for dim_in, dim_out in reversed(in_out[1:]):
            up_modules.append(nn.ModuleList([
                # input channels are doubled by the concatenated skip connection
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in)
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        else:
            # keep the step tensor on the sample's device for every tensor shape,
            # not just 0-dim ones
            timesteps = timesteps.to(sample.device)
            if len(timesteps.shape) == 0:
                timesteps = timesteps[None]
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], dim=-1)

        x = sample
        # skip connections, consumed in reverse order by the up path
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x
|
|
608
|
+
|
|
609
|
+
class DiffusionPolicy(nn.Module):
|
|
610
|
+
# observation and action dimensions corresponding to
|
|
611
|
+
# the output of PushTEnv
|
|
612
|
+
def __init__(
|
|
613
|
+
self,
|
|
614
|
+
obs_dim: int,
|
|
615
|
+
act_dim: int,
|
|
616
|
+
obs_horizon: int,
|
|
617
|
+
pred_horizon: int,
|
|
618
|
+
action_horizon: int,
|
|
619
|
+
data_act_scale = 1.0,
|
|
620
|
+
data_obs_scale = 1.0,
|
|
621
|
+
policy_type = 'cnn',
|
|
622
|
+
device = 'cuda',
|
|
623
|
+
):
|
|
624
|
+
super().__init__()
|
|
625
|
+
self.obs_dim = obs_dim
|
|
626
|
+
self.action_dim = act_dim
|
|
627
|
+
self.obs_horizon = obs_horizon
|
|
628
|
+
self.pred_horizon = pred_horizon
|
|
629
|
+
self.action_horizon = action_horizon
|
|
630
|
+
self.data_act_scale = data_act_scale
|
|
631
|
+
self.data_obs_scale = data_obs_scale
|
|
632
|
+
self.policy_type = policy_type
|
|
633
|
+
self.device = device
|
|
634
|
+
if self.policy_type == "cnn":
|
|
635
|
+
if self.action_horizon == 4:
|
|
636
|
+
self.pad_before = 1
|
|
637
|
+
self.pad_after = 2
|
|
638
|
+
self.pred_horizon = pred_horizon + self.pad_before + self.pad_after
|
|
639
|
+
if self.action_horizon == 6:
|
|
640
|
+
self.pad_before = 0
|
|
641
|
+
self.pad_after = 1
|
|
642
|
+
self.pred_horizon = pred_horizon + self.pad_before + self.pad_after
|
|
643
|
+
# create network object
|
|
644
|
+
if self.policy_type == "cnn":
|
|
645
|
+
self.noise_pred_net = ConditionalUnet1D(
|
|
646
|
+
input_dim=self.action_dim,
|
|
647
|
+
global_cond_dim=self.obs_dim*self.obs_horizon
|
|
648
|
+
).to(self.device)
|
|
649
|
+
elif self.policy_type == "transformer":
|
|
650
|
+
self.noise_pred_net = TransformerForDiffusion(
|
|
651
|
+
input_dim=self.action_dim,
|
|
652
|
+
output_dim=self.action_dim,
|
|
653
|
+
horizon=pred_horizon,
|
|
654
|
+
n_obs_steps=obs_horizon,
|
|
655
|
+
cond_dim=self.obs_dim,
|
|
656
|
+
n_layer=8,
|
|
657
|
+
n_head=4,
|
|
658
|
+
n_emb=768,
|
|
659
|
+
p_drop_emb=0.0,
|
|
660
|
+
p_drop_attn=0.1,
|
|
661
|
+
causal_attn=True,
|
|
662
|
+
time_as_cond=True,
|
|
663
|
+
obs_as_cond=True,
|
|
664
|
+
n_cond_layers=0
|
|
665
|
+
).to(self.device)
|
|
666
|
+
else:
|
|
667
|
+
raise NotImplementedError
|
|
668
|
+
|
|
669
|
+
# for this demo, we use DDPMScheduler with 100 diffusion iterations
|
|
670
|
+
self.num_diffusion_iters = 100
|
|
671
|
+
self.noise_scheduler = DDPMScheduler(
|
|
672
|
+
num_train_timesteps=self.num_diffusion_iters,
|
|
673
|
+
# the choise of beta schedule has big impact on performance
|
|
674
|
+
# we found squared cosine works the best
|
|
675
|
+
beta_schedule='squaredcos_cap_v2',
|
|
676
|
+
# clip output to [-1,1] to improve stability
|
|
677
|
+
clip_sample=True,
|
|
678
|
+
# our network predicts noise (instead of denoised action)
|
|
679
|
+
prediction_type='epsilon'
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
self.ema = EMAModel(
|
|
683
|
+
model=self.noise_pred_net,
|
|
684
|
+
inv_gamma= 1.0,
|
|
685
|
+
max_value= 0.9999,
|
|
686
|
+
min_value= 0.0,
|
|
687
|
+
power= 0.75,
|
|
688
|
+
update_after_step= 0,
|
|
689
|
+
)
|
|
690
|
+
self.ema_noise_pred_net = self.get_ema_average()
|
|
691
|
+
# self.ema = EMAModel(
|
|
692
|
+
# parameters=self.noise_pred_net.parameters(),
|
|
693
|
+
# power=0.75)
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
def forward(
|
|
697
|
+
self,
|
|
698
|
+
obs_seq: torch.Tensor,
|
|
699
|
+
action_seq: Optional[torch.Tensor],
|
|
700
|
+
eval = False
|
|
701
|
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
702
|
+
if eval:
|
|
703
|
+
return self._predict(obs_seq, None, action_seq)
|
|
704
|
+
else:
|
|
705
|
+
return self._update(obs_seq, None, action_seq)
|
|
706
|
+
|
|
707
|
+
def _update(
|
|
708
|
+
self,
|
|
709
|
+
obs_seq: torch.Tensor,
|
|
710
|
+
goal_seq: Optional[torch.Tensor],
|
|
711
|
+
action_seq: Optional[torch.Tensor],
|
|
712
|
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
|
|
713
|
+
# Assume dimensions are N T D for N sequences of T timesteps with dimension D.
|
|
714
|
+
if obs_seq.shape[1] < self.obs_horizon:
|
|
715
|
+
obs_seq = torch.cat((torch.tile(obs_seq[:, 0, :], (1, self.obs_horizon-obs_seq.shape[1], 1)), obs_seq), dim=-2)
|
|
716
|
+
if self.policy_type == "cnn":
|
|
717
|
+
action_seq = torch.cat((torch.zeros_like(action_seq[:, :self.pad_before]), action_seq, torch.zeros_like(action_seq[:, :self.pad_after])), dim=1)
|
|
718
|
+
naction = self.normalize_act_data(action_seq).to(self.device)
|
|
719
|
+
nobs = self.normalize_obs_data(obs_seq).to(self.device)
|
|
720
|
+
# nobs = obs_seq.to(self.device)
|
|
721
|
+
# naction = action_seq.to(self.device)
|
|
722
|
+
B = nobs.shape[0]
|
|
723
|
+
|
|
724
|
+
# observation as FiLM conditioning
|
|
725
|
+
# (B, obs_horizon, obs_dim)
|
|
726
|
+
obs_cond = nobs[:,:self.obs_horizon,:]
|
|
727
|
+
# (B, obs_horizon * obs_dim)
|
|
728
|
+
# obs_cond = obs_cond.flatten(start_dim=1)
|
|
729
|
+
|
|
730
|
+
# sample noise to add to actions
|
|
731
|
+
noise = torch.randn(naction.shape, device="cuda")
|
|
732
|
+
|
|
733
|
+
# sample a diffusion iteration for each data point
|
|
734
|
+
timesteps = torch.randint(
|
|
735
|
+
0, self.noise_scheduler.config.num_train_timesteps,
|
|
736
|
+
(B,), device="cuda"
|
|
737
|
+
).long()
|
|
738
|
+
|
|
739
|
+
# add noise to the clean images according to the noise magnitude at each diffusion iteration
|
|
740
|
+
# (this is the forward diffusion process)
|
|
741
|
+
noisy_actions = self.noise_scheduler.add_noise(
|
|
742
|
+
naction, noise, timesteps)
|
|
743
|
+
|
|
744
|
+
# predict the noise residual
|
|
745
|
+
if self.policy_type == "cnn":
|
|
746
|
+
obs_cond = obs_cond.flatten(start_dim=1)
|
|
747
|
+
noise_pred = self.noise_pred_net(
|
|
748
|
+
noisy_actions, timesteps, global_cond=obs_cond)
|
|
749
|
+
elif self.policy_type == "transformer":
|
|
750
|
+
noise_pred = self.noise_pred_net(
|
|
751
|
+
noisy_actions, timesteps, cond=obs_cond)
|
|
752
|
+
else:
|
|
753
|
+
raise NotImplementedError
|
|
754
|
+
# L2 loss
|
|
755
|
+
loss = nn.functional.mse_loss(noise_pred, noise)
|
|
756
|
+
loss_dict = {
|
|
757
|
+
"total_loss": loss.detach().cpu().item(),
|
|
758
|
+
}
|
|
759
|
+
return None, loss, loss_dict
|
|
760
|
+
|
|
761
|
+
def normalize_obs_data(self, data):
    """Return ``data`` divided by ``self.data_obs_scale``."""
    obs_scale = self.data_obs_scale
    return data / obs_scale
|
|
763
|
+
|
|
764
|
+
def unnormalize_obs_data(self, data):
    """Return ``data`` multiplied by ``self.data_obs_scale``."""
    obs_scale = self.data_obs_scale
    return data * obs_scale
|
|
766
|
+
|
|
767
|
+
def normalize_act_data(self, data):
    """Return ``data`` divided by ``self.data_act_scale``."""
    act_scale = self.data_act_scale
    return data / act_scale
|
|
769
|
+
|
|
770
|
+
def unnormalize_act_data(self, data):
    """Return ``data`` multiplied by ``self.data_act_scale``."""
    act_scale = self.data_act_scale
    return data * act_scale
|
|
772
|
+
|
|
773
|
+
def _predict(
|
|
774
|
+
self,
|
|
775
|
+
obs_seq: torch.Tensor,
|
|
776
|
+
goal_seq: Optional[torch.Tensor],
|
|
777
|
+
action_seq: Optional[torch.Tensor],
|
|
778
|
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
|
|
779
|
+
self.ema_noise_pred_net = self.get_ema_average()
|
|
780
|
+
B = obs_seq.shape[0]
|
|
781
|
+
# stack the last obs_horizon (2) number of observations
|
|
782
|
+
if obs_seq.shape[1] < self.obs_horizon:
|
|
783
|
+
obs_seq = torch.cat((torch.tile(obs_seq[:, 0, :], (1, self.obs_horizon-obs_seq.shape[1], 1)), obs_seq), dim=-2)
|
|
784
|
+
# normalize observation
|
|
785
|
+
nobs = self.normalize_obs_data(obs_seq)
|
|
786
|
+
# device transfer
|
|
787
|
+
# nobs = torch.from_numpy(nobs).to("cuda", dtype=torch.float32)
|
|
788
|
+
|
|
789
|
+
# infer action
|
|
790
|
+
with torch.no_grad():
|
|
791
|
+
# reshape observation to (B,obs_horizon*obs_dim)
|
|
792
|
+
# obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)
|
|
793
|
+
|
|
794
|
+
# initialize action from Guassian noise
|
|
795
|
+
noisy_action = torch.randn(
|
|
796
|
+
(B, self.pred_horizon, self.action_dim), device=self.device)
|
|
797
|
+
naction = noisy_action
|
|
798
|
+
|
|
799
|
+
# init scheduler
|
|
800
|
+
self.noise_scheduler.set_timesteps(self.num_diffusion_iters)
|
|
801
|
+
|
|
802
|
+
for k in self.noise_scheduler.timesteps:
|
|
803
|
+
# predict noise
|
|
804
|
+
if self.policy_type == "cnn":
|
|
805
|
+
# (B, obs_horizon, obs_dim)
|
|
806
|
+
#################################
|
|
807
|
+
obs_cond = nobs.flatten(start_dim=1)
|
|
808
|
+
# obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)
|
|
809
|
+
#################################
|
|
810
|
+
noise_pred = self.ema_noise_pred_net(
|
|
811
|
+
sample=naction,
|
|
812
|
+
timestep=k,
|
|
813
|
+
global_cond=obs_cond
|
|
814
|
+
)
|
|
815
|
+
elif self.policy_type == "transformer":
|
|
816
|
+
obs_cond = nobs
|
|
817
|
+
noise_pred = self.ema_noise_pred_net(
|
|
818
|
+
sample=naction,
|
|
819
|
+
timestep=k,
|
|
820
|
+
cond=obs_cond
|
|
821
|
+
)
|
|
822
|
+
else:
|
|
823
|
+
raise NotImplementedError
|
|
824
|
+
|
|
825
|
+
# inverse diffusion step (remove noise)
|
|
826
|
+
naction = self.noise_scheduler.step(
|
|
827
|
+
model_output=noise_pred,
|
|
828
|
+
timestep=k,
|
|
829
|
+
sample=naction
|
|
830
|
+
).prev_sample
|
|
831
|
+
|
|
832
|
+
# unnormalize action
|
|
833
|
+
if self.policy_type == "cnn":
|
|
834
|
+
naction = naction[:, self.pad_before : -self.pad_after]
|
|
835
|
+
naction = self.unnormalize_act_data(naction)
|
|
836
|
+
action_pred = naction.detach().to(self.device)
|
|
837
|
+
# (B, pred_horizon, action_dim)
|
|
838
|
+
action_pred = action_pred[0]
|
|
839
|
+
start = self.obs_horizon - 1
|
|
840
|
+
end = start + self.action_horizon
|
|
841
|
+
a_hat = action_pred[start:end,:]
|
|
842
|
+
|
|
843
|
+
if action_seq is None:
|
|
844
|
+
return a_hat, None, {}
|
|
845
|
+
action_mse = F.mse_loss(naction, action_seq, reduction="none")
|
|
846
|
+
action_l1 = F.l1_loss(naction, action_seq, reduction="none")
|
|
847
|
+
norm = torch.norm(action_seq, p=2, dim=-1, keepdim=True) + 1e-9
|
|
848
|
+
normalized_mse = (action_mse / norm).mean()
|
|
849
|
+
|
|
850
|
+
translation_loss = F.mse_loss(
|
|
851
|
+
naction[:, :, :3], action_seq[:, :, :3]
|
|
852
|
+
).detach()
|
|
853
|
+
rotation_loss = F.mse_loss(
|
|
854
|
+
naction[:, :, 3:6], action_seq[:, :, 3:6]
|
|
855
|
+
).detach()
|
|
856
|
+
gripper_loss = F.mse_loss(
|
|
857
|
+
naction[:, :, 6:], action_seq[:, :, 6:]
|
|
858
|
+
).detach()
|
|
859
|
+
|
|
860
|
+
loss_dict = {
|
|
861
|
+
"L2_loss": action_mse.mean().detach().cpu().item(),
|
|
862
|
+
"L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
|
|
863
|
+
"L1_loss": action_l1.mean().detach().cpu().item(),
|
|
864
|
+
"translation_loss": translation_loss,
|
|
865
|
+
"rotation_loss": rotation_loss,
|
|
866
|
+
"gripper_loss": gripper_loss,
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
return a_hat, action_mse.mean(), loss_dict
|
|
870
|
+
|
|
871
|
+
def ema_step(self):
    """Advance the EMA tracker by one step using the live noise network."""
    live_net = self.noise_pred_net
    self.ema.step(live_net)
|
|
873
|
+
|
|
874
|
+
def get_ema_average(self):
    """Return the ``averaged_model`` held by the EMA tracker."""
    averaged = self.ema.averaged_model
    return averaged
|
|
876
|
+
|
|
877
|
+
def _begin_epoch(self, optimizer, **kwargs):
|
|
878
|
+
return None
|
|
879
|
+
|
|
880
|
+
def _load_from_state_dict(self, *args, **kwargs):
    """Forward every argument unchanged to the parent class implementation."""
    parent_loader = super()._load_from_state_dict
    return parent_loader(*args, **kwargs)
|