egogym-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/bet/vqvae/residual_vq.py
@@ -0,0 +1,346 @@
+ from math import ceil
+ from functools import partial
+ from itertools import zip_longest
+ from random import randrange
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from baselines.rum.models.bet.vqvae.vector_quantize_pytorch import VectorQuantize
+
+ from einops import rearrange, repeat, pack, unpack
+
+ # helper functions
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ def round_up_multiple(num, mult):
+     return ceil(num / mult) * mult
+
+
+ # main class
+
+
+ class ResidualVQ(nn.Module):
+     """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
+
+     def __init__(
+         self,
+         *,
+         dim,
+         num_quantizers,
+         codebook_dim=None,
+         shared_codebook=False,
+         heads=1,
+         quantize_dropout=False,
+         quantize_dropout_cutoff_index=0,
+         quantize_dropout_multiple_of=1,
+         accept_image_fmap=False,
+         eval=False,
+         **kwargs,
+     ):
+         super().__init__()
+         assert heads == 1, "residual vq is not compatible with multi-headed codes"
+         codebook_dim = default(codebook_dim, dim)
+         codebook_input_dim = codebook_dim * heads
+
+         requires_projection = codebook_input_dim != dim
+         self.project_in = (
+             nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
+         )
+         self.project_out = (
+             nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
+         )
+
+         self.num_quantizers = num_quantizers
+
+         self.accept_image_fmap = accept_image_fmap
+         self.layers = nn.ModuleList(
+             [
+                 VectorQuantize(
+                     dim=codebook_dim,
+                     codebook_dim=codebook_dim,
+                     accept_image_fmap=accept_image_fmap,
+                     eval=eval,
+                     **kwargs,
+                 )
+                 for _ in range(num_quantizers)
+             ]
+         )
+         if eval:
+             self.layers.eval()
+
+         self.quantize_dropout = quantize_dropout and num_quantizers > 1
+
+         assert quantize_dropout_cutoff_index >= 0
+
+         self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
+         self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4
+
+         if not shared_codebook:
+             return
+
+         first_vq, *rest_vq = self.layers
+         codebook = first_vq._codebook
+
+         for vq in rest_vq:
+             vq._codebook = codebook
+
+     @property
+     def codebooks(self):
+         codebooks = [layer._codebook.embed for layer in self.layers]
+         codebooks = torch.stack(codebooks, dim=0)
+         codebooks = rearrange(codebooks, "q 1 c d -> q c d")
+         return codebooks
+
+     def get_codes_from_indices(self, indices):
+         batch, quantize_dim = indices.shape[0], indices.shape[-1]
+
+         # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
+
+         indices, ps = pack([indices], "b * q")
+
+         # because of quantize dropout, one can pass in indices that are coarse
+         # and the network should be able to reconstruct
+
+         if quantize_dim < self.num_quantizers:
+             assert self.quantize_dropout > 0.0, (
+                 "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+             )
+             indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
+
+         # get ready for gathering
+
+         codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch)
+         gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1])
+
+         # take care of quantizer dropout
+
+         mask = gather_indices == -1.0
+         gather_indices = gather_indices.masked_fill(
+             mask, 0
+         )  # have it fetch a dummy code to be masked out later
+
+         all_codes = codebooks.gather(2, gather_indices)  # gather all codes
+
+         # mask out any codes that were dropout-ed
+
+         all_codes = all_codes.masked_fill(mask, 0.0)
+
+         # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
+
+         (all_codes,) = unpack(all_codes, ps, "q b * d")
+
+         return all_codes
+
+     def draw_logits_forward(self, encoding_logits):
+         # encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header)
+         encoding_logits = encoding_logits.to(self.device)
+         bs = encoding_logits.shape[0]
+         quantized = torch.zeros((bs, self.codebooks.shape[-1])).to(self.device)
+         for q in range(encoding_logits.shape[1]):
+             quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q]).to(
+                 self.device
+             )
+         return quantized
+
+     def forward(
+         self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
+     ):
+         num_quant, quant_dropout_multiple_of, return_loss, device = (
+             self.num_quantizers,
+             self.quantize_dropout_multiple_of,
+             exists(indices),
+             x.device,
+         )
+
+         x = self.project_in(x)
+
+         assert not (self.accept_image_fmap and exists(indices))
+
+         quantized_out = 0.0
+         residual = x
+
+         all_losses = []
+         all_indices = []
+
+         if return_loss:
+             assert not torch.any(indices == -1), (
+                 "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
+             )
+             ce_losses = []
+
+         should_quantize_dropout = (
+             self.training and self.quantize_dropout and not return_loss
+         )
+
+         # sample a layer index at which to dropout further residual quantization
+         # also prepare null indices and loss
+
+         if should_quantize_dropout:
+             rand_quantize_dropout_index = randrange(
+                 self.quantize_dropout_cutoff_index, num_quant
+             )
+
+             if quant_dropout_multiple_of != 1:
+                 rand_quantize_dropout_index = (
+                     round_up_multiple(
+                         rand_quantize_dropout_index + 1, quant_dropout_multiple_of
+                     )
+                     - 1
+                 )
+
+             null_indices_shape = (
+                 (x.shape[0], *x.shape[-2:])
+                 if self.accept_image_fmap
+                 else tuple(x.shape[:2])
+             )
+             null_indices = torch.full(
+                 null_indices_shape, -1.0, device=device, dtype=torch.long
+             )
+             null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
+
+         # go through the layers
+
+         for quantizer_index, layer in enumerate(self.layers):
+             if (
+                 should_quantize_dropout
+                 and quantizer_index > rand_quantize_dropout_index
+             ):
+                 all_indices.append(null_indices)
+                 all_losses.append(null_loss)
+                 continue
+
+             layer_indices = None
+             if return_loss:
+                 layer_indices = indices[..., quantizer_index]
+
+             quantized, *rest = layer(
+                 residual,
+                 indices=layer_indices,
+                 sample_codebook_temp=sample_codebook_temp,
+             )
+
+             residual = residual - quantized.detach()
+             quantized_out = quantized_out + quantized
+
+             if return_loss:
+                 ce_loss = rest[0]
+                 ce_losses.append(ce_loss)
+                 continue
+
+             embed_indices, loss = rest
+
+             all_indices.append(embed_indices)
+             all_losses.append(loss)
+
+         # project out, if needed
+
+         quantized_out = self.project_out(quantized_out)
+
+         # whether to early return the cross entropy loss
+
+         if return_loss:
+             return quantized_out, sum(ce_losses)
+
+         # stack all losses and indices
+
+         all_losses, all_indices = map(
+             partial(torch.stack, dim=-1), (all_losses, all_indices)
+         )
+
+         ret = (quantized_out, all_indices, all_losses)
+
+         if return_all_codes:
+             # whether to return all codes from all codebooks across layers
+             all_codes = self.get_codes_from_indices(all_indices)
+
+             # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
+             ret = (*ret, all_codes)
+
+         return ret
+
+
+ # grouped residual vq
+
+
+ class GroupedResidualVQ(nn.Module):
+     def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
+         super().__init__()
+         self.dim = dim
+         self.groups = groups
+         assert (dim % groups) == 0
+         dim_per_group = dim // groups
+
+         self.accept_image_fmap = accept_image_fmap
+
+         self.rvqs = nn.ModuleList([])
+
+         for _ in range(groups):
+             self.rvqs.append(
+                 ResidualVQ(
+                     dim=dim_per_group, accept_image_fmap=accept_image_fmap, **kwargs
+                 )
+             )
+
+     @property
+     def codebooks(self):
+         return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
+
+     def get_codes_from_indices(self, indices):
+         codes = tuple(
+             rvq.get_codes_from_indices(chunk_indices)
+             for rvq, chunk_indices in zip(self.rvqs, indices)
+         )
+         return torch.stack(codes)
+
+     def forward(
+         self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
+     ):
+         shape = x.shape
+         split_dim = 1 if self.accept_image_fmap else -1
+         assert shape[split_dim] == self.dim
+
+         # split the feature dimension into groups
+
+         x = x.chunk(self.groups, dim=split_dim)
+
+         indices = default(indices, tuple())
+         return_ce_loss = len(indices) > 0
+         assert len(indices) == 0 or len(indices) == self.groups
+
+         forward_kwargs = dict(
+             return_all_codes=return_all_codes, sample_codebook_temp=sample_codebook_temp
+         )
+
+         # invoke residual vq on each group
+
+         out = tuple(
+             rvq(chunk, indices=chunk_indices, **forward_kwargs)
+             for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices)
+         )
+         out = tuple(zip(*out))
+
+         # if returning cross entropy loss to rvq codebooks
+
+         if return_ce_loss:
+             quantized, ce_losses = out
+             return torch.cat(quantized, dim=split_dim), sum(ce_losses)
+
+         # otherwise, get all the zipped outputs and combine them
+
+         quantized, all_indices, commit_losses, *maybe_all_codes = out
+
+         quantized = torch.cat(quantized, dim=split_dim)
+         all_indices = torch.stack(all_indices)
+         commit_losses = torch.stack(commit_losses)
+
+         ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
+         return ret
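
For orientation, a minimal usage sketch of the ResidualVQ module added in this file follows. It is not part of the package: the import path is inferred from the file location above, and the input dimensions, quantizer count, and the codebook_size keyword (forwarded with the other **kwargs to the vendored VectorQuantize layers) are illustrative assumptions.

# Usage sketch -- not part of egogym 0.1.0. Shapes, codebook_size, and the
# import path are assumptions based on the file layout and forward() above.
import torch

from baselines.rum.models.bet.vqvae.residual_vq import ResidualVQ

rvq = ResidualVQ(
    dim=256,             # feature dimension of the inputs
    num_quantizers=4,    # number of residual quantization stages
    codebook_size=1024,  # assumed kwarg, forwarded to each VectorQuantize layer
)

x = torch.randn(8, 32, 256)  # (batch, sequence, dim)

# Default call: the quantized output (sum of per-stage quantizations), one code
# index per stage, and the per-stage commitment losses.
quantized, indices, commit_losses = rvq(x)
# quantized: (8, 32, 256); indices: (8, 32, 4), one codebook index per stage

# Passing previously derived indices back in returns the quantized output
# together with the summed cross-entropy loss against the codebooks, as the
# return_loss branch of forward() above suggests.
quantized_out, ce_loss = rvq(x, indices=indices)

# GroupedResidualVQ splits the feature dimension into independent groups, each
# handled by its own ResidualVQ, and concatenates the results along that dim.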