egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/bet/vqvae/residual_vq.py
@@ -0,0 +1,346 @@
from math import ceil
from functools import partial
from itertools import zip_longest
from random import randrange

import torch
from torch import nn
import torch.nn.functional as F
from baselines.rum.models.bet.vqvae.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat, pack, unpack

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def round_up_multiple(num, mult):
    return ceil(num / mult) * mult


# main class


class ResidualVQ(nn.Module):
    """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_dim=None,
        shared_codebook=False,
        heads=1,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        accept_image_fmap=False,
        eval=False,
        **kwargs,
    ):
        super().__init__()
        assert heads == 1, "residual vq is not compatible with multi-headed codes"
        codebook_dim = default(codebook_dim, dim)
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        )

        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
        self.layers = nn.ModuleList(
            [
                VectorQuantize(
                    dim=codebook_dim,
                    codebook_dim=codebook_dim,
                    accept_image_fmap=accept_image_fmap,
                    eval=eval,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )
        if eval:
            self.layers.eval()

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

        if not shared_codebook:
            return

        first_vq, *rest_vq = self.layers
        codebook = first_vq._codebook

        for vq in rest_vq:
            vq._codebook = codebook

    @property
    def codebooks(self):
        codebooks = [layer._codebook.embed for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        codebooks = rearrange(codebooks, "q 1 c d -> q c d")
        return codebooks

    def get_codes_from_indices(self, indices):
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0.0, (
                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            )
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # get ready for gathering

        codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch)
        gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1])

        # take care of quantizer dropout

        mask = gather_indices == -1.0
        gather_indices = gather_indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(mask, 0.0)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes

    def draw_logits_forward(self, encoding_logits):
        # encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header)
        encoding_logits = encoding_logits.to(self.device)
        bs = encoding_logits.shape[0]
        quantized = torch.zeros((bs, self.codebooks.shape[-1])).to(self.device)
        for q in range(encoding_logits.shape[1]):
            quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q]).to(
                self.device
            )
        return quantized

    def forward(
        self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
    ):
        num_quant, quant_dropout_multiple_of, return_loss, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            exists(indices),
            x.device,
        )

        x = self.project_in(x)

        assert not (self.accept_image_fmap and exists(indices))

        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        if return_loss:
            assert not torch.any(indices == -1), (
                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
            )
            ce_losses = []

        should_quantize_dropout = (
            self.training and self.quantize_dropout and not return_loss
        )

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    round_up_multiple(
                        rand_quantize_dropout_index + 1, quant_dropout_multiple_of
                    )
                    - 1
                )

            null_indices_shape = (
                (x.shape[0], *x.shape[-2:])
                if self.accept_image_fmap
                else tuple(x.shape[:2])
            )
            null_indices = torch.full(
                null_indices_shape, -1.0, device=device, dtype=torch.long
            )
            null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)

        # go through the layers

        for quantizer_index, layer in enumerate(self.layers):
            if (
                should_quantize_dropout
                and quantizer_index > rand_quantize_dropout_index
            ):
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            layer_indices = None
            if return_loss:
                layer_indices = indices[..., quantizer_index]

            quantized, *rest = layer(
                residual,
                indices=layer_indices,
                sample_codebook_temp=sample_codebook_temp,
            )

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            if return_loss:
                ce_loss = rest[0]
                ce_losses.append(ce_loss)
                continue

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # whether to early return the cross entropy loss

        if return_loss:
            return quantized_out, sum(ce_losses)

        # stack all losses and indices

        all_losses, all_indices = map(
            partial(torch.stack, dim=-1), (all_losses, all_indices)
        )

        ret = (quantized_out, all_indices, all_losses)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers
            all_codes = self.get_codes_from_indices(all_indices)

            # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
            ret = (*ret, all_codes)

        return ret


# grouped residual vq


class GroupedResidualVQ(nn.Module):
    def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(
                ResidualVQ(
                    dim=dim_per_group, accept_image_fmap=accept_image_fmap, **kwargs
                )
            )

    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    def get_codes_from_indices(self, indices):
        codes = tuple(
            rvq.get_codes_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.stack(codes)

    def forward(
        self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
    ):
        shape = x.shape
        split_dim = 1 if self.accept_image_fmap else -1
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups

        x = x.chunk(self.groups, dim=split_dim)

        indices = default(indices, tuple())
        return_ce_loss = len(indices) > 0
        assert len(indices) == 0 or len(indices) == self.groups

        forward_kwargs = dict(
            return_all_codes=return_all_codes, sample_codebook_temp=sample_codebook_temp
        )

        # invoke residual vq on each group

        out = tuple(
            rvq(chunk, indices=chunk_indices, **forward_kwargs)
            for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices)
        )
        out = tuple(zip(*out))

        # if returning cross entropy loss to rvq codebooks

        if return_ce_loss:
            quantized, ce_losses = out
            return torch.cat(quantized, dim=split_dim), sum(ce_losses)

        # otherwise, get all the zipped outputs and combine them

        quantized, all_indices, commit_losses, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim=split_dim)
        all_indices = torch.stack(all_indices)
        commit_losses = torch.stack(commit_losses)

        ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
        return ret