egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/policies/diffusion_policy.py
@@ -0,0 +1,881 @@
+from typing import Tuple, Sequence, Dict, Union, Optional
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.training_utils import EMAModel
+from diffusers.optimization import get_scheduler
+from tqdm.auto import tqdm
+import einops
+import torch.nn.functional as F
+
+GENERATOR_SEED_FIXED = 123456789
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
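Editor's note: a quick sanity check for the embedding above (an editorial sketch, not part of the package):

    import torch
    emb = SinusoidalPosEmb(dim=256)
    t = torch.arange(4)   # e.g. four diffusion timesteps
    print(emb(t).shape)   # torch.Size([4, 256]): sin half then cos half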
+class Downsample1d(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+class Upsample1d(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Conv1dBlock(nn.Module):
+    '''
+        Conv1d --> GroupNorm --> Mish
+    '''
+
+    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+            nn.GroupNorm(n_groups, out_channels),
+            nn.Mish(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
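Editor's note: Downsample1d halves the time axis (stride-2 conv) and Upsample1d doubles it (stride-2 transposed conv). A shape check (editorial sketch):

    import torch
    x = torch.randn(1, 32, 16)                           # (B, C, T)
    print(Downsample1d(32)(x).shape)                     # torch.Size([1, 32, 8])
    print(Upsample1d(32)(torch.randn(1, 32, 8)).shape)   # torch.Size([1, 32, 16])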
+class ConditionalResidualBlock1D(nn.Module):
+    def __init__(self,
+            in_channels,
+            out_channels,
+            cond_dim,
+            kernel_size=3,
+            n_groups=8):
+        super().__init__()
+
+        self.blocks = nn.ModuleList([
+            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
+            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
+        ])
+
+        # FiLM modulation https://arxiv.org/abs/1709.07871
+        # predicts per-channel scale and bias
+        cond_channels = out_channels * 2
+        self.out_channels = out_channels
+        self.cond_encoder = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(cond_dim, cond_channels),
+            nn.Unflatten(-1, (-1, 1))
+        )
+
+        # make sure dimensions compatible
+        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
+            if in_channels != out_channels else nn.Identity()
+
+    def forward(self, x, cond):
+        '''
+            x : [ batch_size x in_channels x horizon ]
+            cond : [ batch_size x cond_dim ]
+
+            returns:
+            out : [ batch_size x out_channels x horizon ]
+        '''
+        out = self.blocks[0](x)
+        embed = self.cond_encoder(cond)
+
+        embed = embed.reshape(
+            embed.shape[0], 2, self.out_channels, 1)
+        scale = embed[:,0,...]
+        bias = embed[:,1,...]
+        out = scale * out + bias
+
+        out = self.blocks[1](out)
+        out = out + self.residual_conv(x)
+        return out
+
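Editor's note: the FiLM comment above is the heart of this block: cond is projected to 2*out_channels numbers, split into a per-channel scale and bias, and applied between the two conv blocks. A shape check with toy sizes (editorial sketch):

    import torch
    block = ConditionalResidualBlock1D(in_channels=8, out_channels=16, cond_dim=32)
    x = torch.randn(2, 8, 10)    # (batch, in_channels, horizon)
    cond = torch.randn(2, 32)    # (batch, cond_dim)
    print(block(x, cond).shape)  # torch.Size([2, 16, 10])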
+class ModuleAttrMixin(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._dummy_variable = nn.Parameter()
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    @property
+    def dtype(self):
+        return next(iter(self.parameters())).dtype
+
+
+
+class TransformerForDiffusion(ModuleAttrMixin):
+    def __init__(self,
+            input_dim: int,
+            output_dim: int,
+            horizon: int,
+            n_obs_steps: int = None,
+            cond_dim: int = 0,
+            n_layer: int = 12,
+            n_head: int = 12,
+            n_emb: int = 768,
+            p_drop_emb: float = 0.1,
+            p_drop_attn: float = 0.1,
+            causal_attn: bool=False,
+            time_as_cond: bool=True,
+            obs_as_cond: bool=False,
+            n_cond_layers: int = 0
+        ) -> None:
+        super().__init__()
+
+        # compute number of tokens for main trunk and condition encoder
+        if n_obs_steps is None:
+            n_obs_steps = horizon
+
+        T = horizon
+        T_cond = 1
+        if not time_as_cond:
+            T += 1
+            T_cond -= 1
+        obs_as_cond = cond_dim > 0  # inferred from cond_dim; overrides the constructor flag
+        if obs_as_cond:
+            assert time_as_cond
+            T_cond += n_obs_steps
+
+        # input embedding stem
+        self.input_emb = nn.Linear(input_dim, n_emb)
+        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
+        self.drop = nn.Dropout(p_drop_emb)
+
+        # cond encoder
+        self.time_emb = SinusoidalPosEmb(n_emb)
+        self.cond_obs_emb = None
+
+        if obs_as_cond:
+            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
+
+        self.cond_pos_emb = None
+        self.encoder = None
+        self.decoder = None
+        encoder_only = False
+        if T_cond > 0:
+            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
+            if n_cond_layers > 0:
+                encoder_layer = nn.TransformerEncoderLayer(
+                    d_model=n_emb,
+                    nhead=n_head,
+                    dim_feedforward=4*n_emb,
+                    dropout=p_drop_attn,
+                    activation='gelu',
+                    batch_first=True,
+                    norm_first=True
+                )
+                self.encoder = nn.TransformerEncoder(
+                    encoder_layer=encoder_layer,
+                    num_layers=n_cond_layers
+                )
+            else:
+                self.encoder = nn.Sequential(
+                    nn.Linear(n_emb, 4 * n_emb),
+                    nn.Mish(),
+                    nn.Linear(4 * n_emb, n_emb)
+                )
+            # decoder
+            decoder_layer = nn.TransformerDecoderLayer(
+                d_model=n_emb,
+                nhead=n_head,
+                dim_feedforward=4*n_emb,
+                dropout=p_drop_attn,
+                activation='gelu',
+                batch_first=True,
+                norm_first=True  # important for stability
+            )
+            self.decoder = nn.TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=n_layer
+            )
+        else:
+            # encoder only BERT
+            encoder_only = True
+
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=n_emb,
+                nhead=n_head,
+                dim_feedforward=4*n_emb,
+                dropout=p_drop_attn,
+                activation='gelu',
+                batch_first=True,
+                norm_first=True
+            )
+            self.encoder = nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=n_layer
+            )
+
+        # attention mask
+        if causal_attn:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
+            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
+            sz = T
+            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+            self.register_buffer("mask", mask)
+
+            if time_as_cond and obs_as_cond:
+                S = T_cond
+                t, s = torch.meshgrid(
+                    torch.arange(T),
+                    torch.arange(S),
+                    indexing='ij'
+                )
+                mask = t >= (s-1)  # add one dimension since time is the first token in cond
+                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+                self.register_buffer('memory_mask', mask)
+            else:
+                self.memory_mask = None
+        else:
+            self.mask = None
+            self.memory_mask = None
+
+        # decoder head
+        self.ln_f = nn.LayerNorm(n_emb)
+        self.head = nn.Linear(n_emb, output_dim)
+
+        # constants
+        self.T = T
+        self.T_cond = T_cond
+        self.horizon = horizon
+        self.time_as_cond = time_as_cond
+        self.obs_as_cond = obs_as_cond
+        self.encoder_only = encoder_only
+
+        # init
+        self.apply(self._init_weights)
+        # logger.info(
+        #     "number of parameters: %e", sum(p.numel() for p in self.parameters())
+        # )
+
+    def _init_weights(self, module):
+        ignore_types = (nn.Dropout,
+            SinusoidalPosEmb,
+            nn.TransformerEncoderLayer,
+            nn.TransformerDecoderLayer,
+            nn.TransformerEncoder,
+            nn.TransformerDecoder,
+            nn.ModuleList,
+            nn.Mish,
+            nn.Sequential)
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.MultiheadAttention):
+            weight_names = [
+                'in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
+            for name in weight_names:
+                weight = getattr(module, name)
+                if weight is not None:
+                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
+
+            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
+            for name in bias_names:
+                bias = getattr(module, name)
+                if bias is not None:
+                    torch.nn.init.zeros_(bias)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+        elif isinstance(module, TransformerForDiffusion):
+            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+            if module.cond_obs_emb is not None:
+                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
+        elif isinstance(module, ignore_types):
+            # no param
+            pass
+        else:
+            raise RuntimeError("Unaccounted module {}".format(module))
+
+    def get_optim_groups(self, weight_decay: float=1e-3):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the optimizer parameter groups.
+        """
+
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
+
+                if pn.endswith("bias"):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.startswith("bias"):
+                    # MultiheadAttention bias starts with "bias"
+                    no_decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add("pos_emb")
+        no_decay.add("_dummy_variable")
+        if self.cond_pos_emb is not None:
+            no_decay.add("cond_pos_emb")
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert (
+            len(inter_params) == 0
+        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
+        assert (
+            len(param_dict.keys() - union_params) == 0
+        ), "parameters %s were not separated into either decay/no_decay set!" % (
+            str(param_dict.keys() - union_params),
+        )
+
+        # assemble the optimizer parameter groups
+        optim_groups = [
+            {
+                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "weight_decay": weight_decay,
+            },
+            {
+                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "weight_decay": 0.0,
+            },
+        ]
+        return optim_groups
+
+
+    def configure_optimizers(self,
+            learning_rate: float=1e-4,
+            weight_decay: float=1e-3,
+            betas: Tuple[float, float]=(0.9,0.95)):
+        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
+        optimizer = torch.optim.AdamW(
+            optim_groups, lr=learning_rate, betas=betas
+        )
+        return optimizer
+
+    def forward(self,
+            sample: torch.Tensor,
+            timestep: Union[torch.Tensor, float, int],
+            cond: Optional[torch.Tensor]=None, **kwargs):
+        """
+        x: (B,T,input_dim)
+        timestep: (B,) or int, diffusion step
+        cond: (B,T',cond_dim)
+        output: (B,T,input_dim)
+        """
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+        time_emb = self.time_emb(timesteps).unsqueeze(1)
+        # (B,1,n_emb)
+
+        # process input
+        input_emb = self.input_emb(sample)
+
+        if self.encoder_only:
+            # BERT
+            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
+            t = token_embeddings.shape[1]
+            position_embeddings = self.pos_emb[
+                :, :t, :
+            ]  # each position maps to a (learnable) vector
+            x = self.drop(token_embeddings + position_embeddings)
+            # (B,T+1,n_emb)
+            x = self.encoder(src=x, mask=self.mask)
+            # (B,T+1,n_emb)
+            x = x[:,1:,:]
+            # (B,T,n_emb)
+        else:
+            # encoder
+            cond_embeddings = time_emb
+            if self.obs_as_cond:
+                cond_obs_emb = self.cond_obs_emb(cond)
+                # (B,To,n_emb)
+                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
+            tc = cond_embeddings.shape[1]
+            position_embeddings = self.cond_pos_emb[
+                :, :tc, :
+            ]  # each position maps to a (learnable) vector
+            x = self.drop(cond_embeddings + position_embeddings)
+            x = self.encoder(x)
+            memory = x
+            # (B,T_cond,n_emb)
+
+            # decoder
+            token_embeddings = input_emb
+            t = token_embeddings.shape[1]
+            position_embeddings = self.pos_emb[
+                :, :t, :
+            ]  # each position maps to a (learnable) vector
+            x = self.drop(token_embeddings + position_embeddings)
+            # (B,T,n_emb)
+            x = self.decoder(
+                tgt=x,
+                memory=memory,
+                tgt_mask=self.mask,
+                memory_mask=self.memory_mask
+            )
+            # (B,T,n_emb)
+
+        # head
+        x = self.ln_f(x)
+        x = self.head(x)
+        # (B,T,n_out)
+        return x
+
+
+
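Editor's note: a minimal forward pass through TransformerForDiffusion with toy sizes (editorial sketch; the package's own defaults are larger):

    import torch
    net = TransformerForDiffusion(
        input_dim=7, output_dim=7, horizon=8,
        n_obs_steps=2, cond_dim=32,
        n_layer=2, n_head=4, n_emb=64,
        causal_attn=True, time_as_cond=True, obs_as_cond=True,
    )
    sample = torch.randn(4, 8, 7)    # (B, T, input_dim) noisy action sequence
    cond = torch.randn(4, 2, 32)     # (B, n_obs_steps, cond_dim) observation features
    t = torch.randint(0, 100, (4,))
    print(net(sample, t, cond=cond).shape)   # torch.Size([4, 8, 7])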
+class ConditionalUnet1D(nn.Module):
+    def __init__(self,
+            input_dim,
+            global_cond_dim,
+            diffusion_step_embed_dim=256,
+            down_dims=[256,512,1024],
+            kernel_size=5,
+            n_groups=8
+        ):
+        """
+        input_dim: Dim of actions.
+        global_cond_dim: Dim of global conditioning applied with FiLM
+            in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
+        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
+        down_dims: Channel size for each UNet level.
+            The length of this array determines the number of levels.
+        kernel_size: Conv kernel size
+        n_groups: Number of groups for GroupNorm
+        """
+
+        super().__init__()
+        all_dims = [input_dim] + list(down_dims)
+        start_dim = down_dims[0]
+
+        dsed = diffusion_step_embed_dim
+        diffusion_step_encoder = nn.Sequential(
+            SinusoidalPosEmb(dsed),
+            nn.Linear(dsed, dsed * 4),
+            nn.Mish(),
+            nn.Linear(dsed * 4, dsed),
+        )
+        cond_dim = dsed + global_cond_dim
+
+        in_out = list(zip(all_dims[:-1], all_dims[1:]))
+        mid_dim = all_dims[-1]
+        self.mid_modules = nn.ModuleList([
+            ConditionalResidualBlock1D(
+                mid_dim, mid_dim, cond_dim=cond_dim,
+                kernel_size=kernel_size, n_groups=n_groups
+            ),
+            ConditionalResidualBlock1D(
+                mid_dim, mid_dim, cond_dim=cond_dim,
+                kernel_size=kernel_size, n_groups=n_groups
+            ),
+        ])
+
+        down_modules = nn.ModuleList([])
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (len(in_out) - 1)
+            down_modules.append(nn.ModuleList([
+                ConditionalResidualBlock1D(
+                    dim_in, dim_out, cond_dim=cond_dim,
+                    kernel_size=kernel_size, n_groups=n_groups),
+                ConditionalResidualBlock1D(
+                    dim_out, dim_out, cond_dim=cond_dim,
+                    kernel_size=kernel_size, n_groups=n_groups),
+                Downsample1d(dim_out) if not is_last else nn.Identity()
+            ]))
+
+        up_modules = nn.ModuleList([])
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            is_last = ind >= (len(in_out) - 1)
+            up_modules.append(nn.ModuleList([
+                ConditionalResidualBlock1D(
+                    dim_out*2, dim_in, cond_dim=cond_dim,
+                    kernel_size=kernel_size, n_groups=n_groups),
+                ConditionalResidualBlock1D(
+                    dim_in, dim_in, cond_dim=cond_dim,
+                    kernel_size=kernel_size, n_groups=n_groups),
+                Upsample1d(dim_in) if not is_last else nn.Identity()
+            ]))
+
+        final_conv = nn.Sequential(
+            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
+            nn.Conv1d(start_dim, input_dim, 1),
+        )
+
+        self.diffusion_step_encoder = diffusion_step_encoder
+        self.up_modules = up_modules
+        self.down_modules = down_modules
+        self.final_conv = final_conv
+
+        print("number of parameters: {:e}".format(
+            sum(p.numel() for p in self.parameters()))
+        )
+
+    def forward(self,
+            sample: torch.Tensor,
+            timestep: Union[torch.Tensor, float, int],
+            global_cond=None):
+        """
+        x: (B,T,input_dim)
+        timestep: (B,) or int, diffusion step
+        global_cond: (B,global_cond_dim)
+        output: (B,T,input_dim)
+        """
+        # (B,T,C)
+        sample = sample.moveaxis(-1,-2)
+        # (B,C,T)
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        global_feature = self.diffusion_step_encoder(timesteps)
+
+        if global_cond is not None:
+            global_feature = torch.cat([
+                global_feature, global_cond
+            ], axis=-1)
+
+        x = sample
+        h = []
+        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
+            x = resnet(x, global_feature)
+            x = resnet2(x, global_feature)
+            h.append(x)
+            x = downsample(x)
+
+        for mid_module in self.mid_modules:
+            x = mid_module(x, global_feature)
+
+        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
+            x = torch.cat((x, h.pop()), dim=1)
+            x = resnet(x, global_feature)
+            x = resnet2(x, global_feature)
+            x = upsample(x)
+
+        x = self.final_conv(x)
+
+        # (B,C,T)
+        x = x.moveaxis(-1,-2)
+        # (B,T,C)
+        return x
+
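Editor's note: a shape round-trip through the UNet (editorial sketch). With the default three levels the horizon is halved twice, so it should be divisible by 4:

    import torch
    net = ConditionalUnet1D(input_dim=7, global_cond_dim=2 * 19)  # e.g. obs_horizon=2, obs_dim=19
    sample = torch.randn(4, 16, 7)    # (B, pred_horizon, action_dim)
    cond = torch.randn(4, 2 * 19)     # flattened observation window
    print(net(sample, timestep=10, global_cond=cond).shape)  # torch.Size([4, 16, 7])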
+class DiffusionPolicy(nn.Module):
+    # observation and action dimensions corresponding to
+    # the output of PushTEnv
+    def __init__(
+        self,
+        obs_dim: int,
+        act_dim: int,
+        obs_horizon: int,
+        pred_horizon: int,
+        action_horizon: int,
+        data_act_scale = 1.0,
+        data_obs_scale = 1.0,
+        policy_type = 'cnn',
+        device = 'cuda',
+    ):
+        super().__init__()
+        self.obs_dim = obs_dim
+        self.action_dim = act_dim
+        self.obs_horizon = obs_horizon
+        self.pred_horizon = pred_horizon
+        self.action_horizon = action_horizon
+        self.data_act_scale = data_act_scale
+        self.data_obs_scale = data_obs_scale
+        self.policy_type = policy_type
+        self.device = device
+        if self.policy_type == "cnn":
+            # pad the prediction window for the CNN backbone; the padding is trimmed again in _predict
+            if self.action_horizon == 4:
+                self.pad_before = 1
+                self.pad_after = 2
+                self.pred_horizon = pred_horizon + self.pad_before + self.pad_after
+            if self.action_horizon == 6:
+                self.pad_before = 0
+                self.pad_after = 1
+                self.pred_horizon = pred_horizon + self.pad_before + self.pad_after
+        # create network object
+        if self.policy_type == "cnn":
+            self.noise_pred_net = ConditionalUnet1D(
+                input_dim=self.action_dim,
+                global_cond_dim=self.obs_dim*self.obs_horizon
+            ).to(self.device)
+        elif self.policy_type == "transformer":
+            self.noise_pred_net = TransformerForDiffusion(
+                input_dim=self.action_dim,
+                output_dim=self.action_dim,
+                horizon=pred_horizon,
+                n_obs_steps=obs_horizon,
+                cond_dim=self.obs_dim,
+                n_layer=8,
+                n_head=4,
+                n_emb=768,
+                p_drop_emb=0.0,
+                p_drop_attn=0.1,
+                causal_attn=True,
+                time_as_cond=True,
+                obs_as_cond=True,
+                n_cond_layers=0
+            ).to(self.device)
+        else:
+            raise NotImplementedError
+
+        # for this demo, we use DDPMScheduler with 100 diffusion iterations
+        self.num_diffusion_iters = 100
+        self.noise_scheduler = DDPMScheduler(
+            num_train_timesteps=self.num_diffusion_iters,
+            # the choice of beta schedule has a big impact on performance
+            # we found squared cosine works the best
+            beta_schedule='squaredcos_cap_v2',
+            # clip output to [-1,1] to improve stability
+            clip_sample=True,
+            # our network predicts noise (instead of denoised action)
+            prediction_type='epsilon'
+        )
+
+        self.ema = EMAModel(
+            model=self.noise_pred_net,
+            inv_gamma=1.0,
+            max_value=0.9999,
+            min_value=0.0,
+            power=0.75,
+            update_after_step=0,
+        )
+        self.ema_noise_pred_net = self.get_ema_average()
+        # self.ema = EMAModel(
+        #     parameters=self.noise_pred_net.parameters(),
+        #     power=0.75)
+
+
+    def forward(
+        self,
+        obs_seq: torch.Tensor,
+        action_seq: Optional[torch.Tensor],
+        eval = False
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
+        if eval:
+            return self._predict(obs_seq, None, action_seq)
+        else:
+            return self._update(obs_seq, None, action_seq)
+
+    def _update(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: Optional[torch.Tensor],
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
+        # Assume dimensions are N T D for N sequences of T timesteps with dimension D.
+        if obs_seq.shape[1] < self.obs_horizon:
+            obs_seq = torch.cat((torch.tile(obs_seq[:, 0, :], (1, self.obs_horizon-obs_seq.shape[1], 1)), obs_seq), dim=-2)
+        if self.policy_type == "cnn":
+            action_seq = torch.cat((torch.zeros_like(action_seq[:, :self.pad_before]), action_seq, torch.zeros_like(action_seq[:, :self.pad_after])), dim=1)
+        naction = self.normalize_act_data(action_seq).to(self.device)
+        nobs = self.normalize_obs_data(obs_seq).to(self.device)
+        # nobs = obs_seq.to(self.device)
+        # naction = action_seq.to(self.device)
+        B = nobs.shape[0]
+
+        # observation as FiLM conditioning
+        # (B, obs_horizon, obs_dim)
+        obs_cond = nobs[:,:self.obs_horizon,:]
+        # (B, obs_horizon * obs_dim)
+        # obs_cond = obs_cond.flatten(start_dim=1)
+
+        # sample noise to add to actions
+        noise = torch.randn(naction.shape, device=self.device)
+
+        # sample a diffusion iteration for each data point
+        timesteps = torch.randint(
+            0, self.noise_scheduler.config.num_train_timesteps,
+            (B,), device=self.device
+        ).long()
+
+        # add noise to the clean actions according to the noise magnitude at each diffusion iteration
+        # (this is the forward diffusion process)
+        noisy_actions = self.noise_scheduler.add_noise(
+            naction, noise, timesteps)
+
+        # predict the noise residual
+        if self.policy_type == "cnn":
+            obs_cond = obs_cond.flatten(start_dim=1)
+            noise_pred = self.noise_pred_net(
+                noisy_actions, timesteps, global_cond=obs_cond)
+        elif self.policy_type == "transformer":
+            noise_pred = self.noise_pred_net(
+                noisy_actions, timesteps, cond=obs_cond)
+        else:
+            raise NotImplementedError
+        # L2 loss
+        loss = nn.functional.mse_loss(noise_pred, noise)
+        loss_dict = {
+            "total_loss": loss.detach().cpu().item(),
+        }
+        return None, loss, loss_dict
+
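Editor's note: _update implements the standard DDPM epsilon-prediction objective. In the code's notation, with normalized actions $a$, observation window $o$, noise $\epsilon \sim \mathcal{N}(0, I)$, and a uniformly drawn step $t$, add_noise forms the noisy sample and the loss is

    $\mathcal{L} = \mathbb{E}_{a,\,o,\,\epsilon,\,t}\,\big\lVert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\,a + \sqrt{1-\bar\alpha_t}\,\epsilon,\; t,\; o\big)\big\rVert_2^2$

where $\epsilon_\theta$ is noise_pred_net and $\bar\alpha_t$ follows the squared-cosine beta schedule configured in __init__.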
+    def normalize_obs_data(self, data):
+        return data / self.data_obs_scale
+
+    def unnormalize_obs_data(self, data):
+        return data * self.data_obs_scale
+
+    def normalize_act_data(self, data):
+        return data / self.data_act_scale
+
+    def unnormalize_act_data(self, data):
+        return data * self.data_act_scale
+
+    def _predict(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: Optional[torch.Tensor],
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
+        self.ema_noise_pred_net = self.get_ema_average()
+        B = obs_seq.shape[0]
+        # stack the last obs_horizon observations
+        if obs_seq.shape[1] < self.obs_horizon:
+            obs_seq = torch.cat((torch.tile(obs_seq[:, 0, :], (1, self.obs_horizon-obs_seq.shape[1], 1)), obs_seq), dim=-2)
+        # normalize observation
+        nobs = self.normalize_obs_data(obs_seq)
+        # device transfer
+        # nobs = torch.from_numpy(nobs).to("cuda", dtype=torch.float32)
+
+        # infer action
+        with torch.no_grad():
+            # reshape observation to (B,obs_horizon*obs_dim)
+            # obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)
+
+            # initialize action from Gaussian noise
+            noisy_action = torch.randn(
+                (B, self.pred_horizon, self.action_dim), device=self.device)
+            naction = noisy_action
+
+            # init scheduler
+            self.noise_scheduler.set_timesteps(self.num_diffusion_iters)
+
+            for k in self.noise_scheduler.timesteps:
+                # predict noise
+                if self.policy_type == "cnn":
+                    # (B, obs_horizon, obs_dim)
+                    #################################
+                    obs_cond = nobs.flatten(start_dim=1)
+                    # obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)
+                    #################################
+                    noise_pred = self.ema_noise_pred_net(
+                        sample=naction,
+                        timestep=k,
+                        global_cond=obs_cond
+                    )
+                elif self.policy_type == "transformer":
+                    obs_cond = nobs
+                    noise_pred = self.ema_noise_pred_net(
+                        sample=naction,
+                        timestep=k,
+                        cond=obs_cond
+                    )
+                else:
+                    raise NotImplementedError
+
+                # inverse diffusion step (remove noise)
+                naction = self.noise_scheduler.step(
+                    model_output=noise_pred,
+                    timestep=k,
+                    sample=naction
+                ).prev_sample
+
+        # unnormalize action
+        if self.policy_type == "cnn":
+            naction = naction[:, self.pad_before : -self.pad_after]
+        naction = self.unnormalize_act_data(naction)
+        action_pred = naction.detach().to(self.device)
+        # (B, pred_horizon, action_dim)
+        action_pred = action_pred[0]
+        start = self.obs_horizon - 1
+        end = start + self.action_horizon
+        a_hat = action_pred[start:end,:]
+
+        if action_seq is None:
+            return a_hat, None, {}
+        action_mse = F.mse_loss(naction, action_seq, reduction="none")
+        action_l1 = F.l1_loss(naction, action_seq, reduction="none")
+        norm = torch.norm(action_seq, p=2, dim=-1, keepdim=True) + 1e-9
+        normalized_mse = (action_mse / norm).mean()
+
+        translation_loss = F.mse_loss(
+            naction[:, :, :3], action_seq[:, :, :3]
+        ).detach()
+        rotation_loss = F.mse_loss(
+            naction[:, :, 3:6], action_seq[:, :, 3:6]
+        ).detach()
+        gripper_loss = F.mse_loss(
+            naction[:, :, 6:], action_seq[:, :, 6:]
+        ).detach()
+
+        loss_dict = {
+            "L2_loss": action_mse.mean().detach().cpu().item(),
+            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
+            "L1_loss": action_l1.mean().detach().cpu().item(),
+            "translation_loss": translation_loss,
+            "rotation_loss": rotation_loss,
+            "gripper_loss": gripper_loss,
+        }
+
+        return a_hat, action_mse.mean(), loss_dict
+
+    def ema_step(self):
+        self.ema.step(self.noise_pred_net)
+
+    def get_ema_average(self):
+        return self.ema.averaged_model
+
+    def _begin_epoch(self, optimizer, **kwargs):
+        return None
+
+    def _load_from_state_dict(self, *args, **kwargs):
+        return super()._load_from_state_dict(*args, **kwargs)
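Editor's note: an end-to-end sketch of one training step and one rollout (editorial, not package code). It assumes the older diffusers EMAModel API this file targets (model=..., .averaged_model), and picks pred_horizon=13 so the CNN's padded window (13 + pad_before 1 + pad_after 2 = 16) divides evenly under the UNet's two stride-2 downsamples:

    import torch
    policy = DiffusionPolicy(obs_dim=19, act_dim=7, obs_horizon=2,
                             pred_horizon=13, action_horizon=4, device='cpu')

    obs = torch.randn(4, 2, 19)    # (B, obs_horizon, obs_dim)
    acts = torch.randn(4, 13, 7)   # (B, pred_horizon, action_dim)

    _, loss, logs = policy(obs, acts)   # training mode: diffusion loss
    loss.backward()
    policy.ema_step()                   # refresh the EMA copy of the noise net

    a_hat, _, _ = policy(obs, None, eval=True)   # denoise from pure noise
    print(a_hat.shape)   # torch.Size([4, 7]): action_horizon steps; note _predict
                         # returns only batch element 0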