egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,410 @@
1
+ import logging
2
+ from enum import Enum
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ import einops
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import tqdm
10
+
11
+ from baselines.rum.models.bet.gpt import GPT
12
+ from baselines.rum.models.bet.utils import MLP
13
+ from baselines.rum.models.bet.vqvae.vqvae import VqVae
14
+
15
+ GENERATOR_SEED_FIXED = 123456789
16
+
17
+
18
class VQBehaviorTransformer(nn.Module):
    """VQ-BeT policy head: a GPT trunk followed by code-classification and
    offset-prediction MLPs.

    Continuous action chunks are tokenized by a residual VQ-VAE
    (``vqvae_model``); for every observation step this module predicts a
    discrete code index per residual-VQ group plus a continuous offset that
    refines the decoded action.  Training loss combines a focal
    classification loss over the codes with an L1 offset loss.
    """

    # How the goal is fed to the GPT trunk: concatenated along the time axis,
    # stacked along the feature axis, or not used at all (unconditional).
    GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional")

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        goal_dim: int,
        gpt_model: GPT,
        vqvae_model: VqVae,
        offset_loss_multiplier: float = 1.0e2,
        secondary_code_multiplier: float = 0.5,
        gamma: float = 2.0,
        obs_window_size=10,
        act_window_size=10,
        sequentially_select=False,
        use_og_bet_loss=False,
        use_half_and_half_loss=False,
        temperature=1.0,
        device="cuda",
    ):
        """Build the VQ-BeT head.

        Args:
            obs_dim: Per-step observation feature dimension.
            act_dim: Per-step action dimension.
            goal_dim: Goal feature dimension; ``<= 0`` selects the
                unconditional goal mode, otherwise goals are stacked onto the
                observation features (see GOAL_SPEC handling below).
            gpt_model: GPT trunk; its ``config.output_dim`` sizes the MLP heads.
            vqvae_model: Pre-trained residual VQ-VAE used to tokenize actions.
            offset_loss_multiplier: Weight of the L1 offset loss in the total.
            secondary_code_multiplier: Weight of the second residual-VQ
                group's classification loss (the primary group is weighted 5x,
                hard-coded in ``_predict``).
            gamma: Focal-loss focusing parameter.
            obs_window_size: Number of observation steps fed to the trunk.
            act_window_size: Number of future actions predicted per step.
            sequentially_select: If True, predict the two group codes with two
                chained MLPs (code 2 conditioned on code 1) instead of one
                joint head.
            use_og_bet_loss: If True, compute offsets from the ground-truth
                code bins (original-BeT style) during training.
            use_half_and_half_loss: If True, average the offset loss over
                ground-truth-bin and sampled-bin offsets.
            temperature: Softmax temperature used when sampling codes.
            device: Device used for index tensors built in ``_predict``.
        """
        super().__init__()
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        self._goal_dim = goal_dim
        self.obs_window_size = obs_window_size
        self.act_window_size = act_window_size
        self.sequentially_select = sequentially_select
        self._use_og_bet_loss = use_og_bet_loss
        self._use_half_and_half_loss = use_half_and_half_loss
        self.temperature = temperature
        self.device = device

        if goal_dim <= 0:
            self._cbet_method = self.GOAL_SPEC.unconditional
        # elif obs_dim == goal_dim:
        #     self._cbet_method = self.GOAL_SPEC.concat
        # TODO (haritheja): this is a temporary fix, we should be able to handle different types of goals like concat or stack
        else:
            self._cbet_method = self.GOAL_SPEC.stack

        self._gpt_model = gpt_model
        self._vqvae_model = vqvae_model
        self._G = self._vqvae_model.vqvae_groups  # G(number of groups)
        self._C = self._vqvae_model.vqvae_n_embed  # C(number of code integers)
        self._D = self._vqvae_model.embedding_dim  # D(embedding dims)
        self._current_steps = 0
        if self.sequentially_select:
            # Two chained heads: bin1 predicts the first group's code; bin2
            # additionally receives a one-hot of the chosen first code.
            self._map_to_cbet_preds_bin1 = MLP(
                in_channels=gpt_model.config.output_dim,
                hidden_channels=[512, 512, self._C],
            )
            self._map_to_cbet_preds_bin2 = MLP(
                in_channels=gpt_model.config.output_dim + self._C,
                hidden_channels=[512, self._C],
            )
        else:
            # Single joint head over all G*C code logits.
            self._map_to_cbet_preds_bin = MLP(
                in_channels=gpt_model.config.output_dim,
                hidden_channels=[1024, 1024, self._G * self._C],
            )
        # Offset head: one (act_window x act_dim) offset per (group, code) pair.
        self._map_to_cbet_preds_offset = MLP(
            in_channels=gpt_model.config.output_dim,
            hidden_channels=[
                1024,
                1024,
                self._G * self._C * (act_dim * self.act_window_size),
            ],
        )
        self._offset_loss_multiplier = offset_loss_multiplier
        self._secondary_code_multiplier = secondary_code_multiplier
        self._criterion = FocalLoss(gamma=gamma)

    def forward(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        second_half: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        """Predict actions (and, when ``action_seq`` is given, losses).

        Thin wrapper over ``_predict``; see its docstring for shapes.
        """
        # VQ-BeT doesn't use "padding_seq" and "predict_with_offset" input
        return self._predict(obs_seq, goal_seq, action_seq, second_half)

    def _predict(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        second_half: bool,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        """
        Assume dimensions are N T D for N sequences of T timesteps with dimension D.

        obs_seq : (Batch size) X (obs window) X (obs dim)
        goal_seq : (Batch size) X (goal window) X (obs dim)
        action_seq : (Batch size) X (total window) X (action dim)

        assume that
        total_w (total window) of action sequence is
        total_w = obs_w + act_w - 1
        e.g. if observation window is 5, and action pred window is 3,
        the obs seq (o_{t-4}, o_{t-3}, o_{t-2}, o_{t-1}, o_{t}) will predict (a_{t}, a_{t+1}, a_{t+2})

        However, we not only predict (a_{t}, a_{t+1}, a_{t+2}), but also all the action seq for all the actions in the sequence

        o_{t-4} -> |   | -> (a_{t-4}, a_{t-3}, a_{t-2})
        o_{t-3} -> | G | -> (a_{t-3}, a_{t-2}, a_{t-1})
        o_{t-2} -> | P | -> (a_{t-2}, a_{t-1}, a_{t})
        o_{t-1} -> | T | -> (a_{t-1}, a_{t}, a_{t+1})
        o_{t}   -> |   | -> (a_{t}, a_{t+1}, a_{t+2})

        Returns ``(a_hat, loss, loss_dict)``; ``loss`` is None and
        ``loss_dict`` empty when ``action_seq`` is None (inference).
        NOTE(review): ``second_half`` is currently unused except by the
        commented-out cbet_loss zeroing below.
        """
        if obs_seq.shape[1] < self.obs_window_size:
            # if input size is smaller than obs_window size (e.g. the initial steps of env eval episodes,
            # VQ-BeT copy the obs and tile it to match obs_window_size
            obs_seq = torch.cat(
                (
                    torch.tile(
                        obs_seq[:, 0, :],
                        (1, self.obs_window_size - obs_seq.shape[1], 1),
                    ),
                    obs_seq,
                ),
                dim=-2,
            )
            if self._cbet_method == self.GOAL_SPEC.stack:
                # Goals are stacked feature-wise, so they must be padded to
                # the same time length as the observations.
                goal_seq = torch.cat(
                    (
                        torch.tile(
                            goal_seq[:, 0, :],
                            (1, self.obs_window_size - goal_seq.shape[1], 1),
                        ),
                        goal_seq,
                    ),
                    dim=-2,
                )
        if self._cbet_method == self.GOAL_SPEC.unconditional:
            gpt_input = obs_seq
        elif self._cbet_method == self.GOAL_SPEC.concat:
            gpt_input = torch.cat([goal_seq, obs_seq], dim=1)
        elif self._cbet_method == self.GOAL_SPEC.stack:
            gpt_input = torch.cat([goal_seq, obs_seq], dim=-1)
        else:
            raise NotImplementedError

        gpt_output = self._gpt_model(gpt_input)
        if self._cbet_method == self.GOAL_SPEC.concat:
            # Chop off the goal encodings.
            gpt_output = gpt_output[:, goal_seq.size(1) :, :]
        # Flatten batch and time so each step is scored independently: NT rows.
        gpt_output = einops.rearrange(gpt_output, "N T (G C) -> (N T) (G C)", G=self._G)
        obs = einops.rearrange(obs_seq, "N T O -> (N T) O")
        obs = obs.unsqueeze(dim=1)
        # note that output of offset network is G C WA,
        # where G is number of 'layers' of Residual VQ-VAE
        # C is number of words in each layer's dictionary
        # and W, A is predicted action window, and predicted action dims
        if self.sequentially_select:
            cbet_logits1 = self._map_to_cbet_preds_bin1(gpt_output)
            cbet_offsets = self._map_to_cbet_preds_offset(gpt_output)
            cbet_offsets = einops.rearrange(
                cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C
            )
            cbet_probs1 = torch.softmax(cbet_logits1 / self.temperature, dim=-1)
            NT, choices = cbet_probs1.shape
            G = self._G
            # Sample the first group's code, then condition the second head
            # on a one-hot of that sample.
            sampled_centers1 = einops.rearrange(
                torch.multinomial(cbet_probs1.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            cbet_logits2 = self._map_to_cbet_preds_bin2(
                torch.cat(
                    (gpt_output, F.one_hot(sampled_centers1, num_classes=self._C)),
                    axis=1,
                )
            )
            cbet_probs2 = torch.softmax(cbet_logits2 / self.temperature, dim=-1)
            sampled_centers2 = einops.rearrange(
                torch.multinomial(cbet_probs2.view(-1, choices), num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = torch.stack(
                (sampled_centers1, sampled_centers2), axis=1
            )  # NT, G
        else:
            cbet_logits = self._map_to_cbet_preds_bin(gpt_output)
            cbet_offsets = self._map_to_cbet_preds_offset(gpt_output)
            cbet_logits = einops.rearrange(
                cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G
            )
            cbet_offsets = einops.rearrange(
                cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C
            )
            cbet_probs = torch.softmax(cbet_logits / self.temperature, dim=-1)
            NT, G, choices = cbet_probs.shape
            # One independent multinomial draw per (step, group).
            sampled_centers = einops.rearrange(
                torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
                "(NT G) 1 -> NT G",
                NT=NT,
            )
        if action_seq is not None:
            # Slice the (N, total_w, A) action sequence into one act_w-long
            # window per observation step (sliding window, stride 1), then
            # look up the ground-truth residual-VQ codes for each window.
            n, total_w, act_dim = action_seq.shape
            act_w = self._vqvae_model.input_dim_h
            obs_w = total_w + 1 - act_w
            output_shape = (n, obs_w, act_w, act_dim)
            output = torch.empty(output_shape).to(action_seq.device)
            for i in range(obs_w):
                output[:, i, :, :] = action_seq[:, i : i + act_w, :]
            action_seq = einops.rearrange(output, "N T W A -> (N T) W A")
            _, action_bins = self._vqvae_model.get_code(
                action_seq, obs
            )  # action_bins: NT, G

        # Decode the sampled codes to a coarse action (no gradient through
        # the frozen VQ-VAE decoder).
        with torch.no_grad():
            centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
                NT, -1, self._D
            )
            return_decoder_input = einops.rearrange(
                centers.clone().detach(), "NT G D -> NT (G D)"
            )
            decoded_action = (
                self._vqvae_model.get_action_from_latent(return_decoder_input, obs)
                .clone()
                .detach()
            )  # NT, A

        def get_offset_from_centers(centers):
            # Gather the offset predicted for each (step, group, chosen code),
            # sum over groups, and reshape to an (act_window, act_dim) chunk.
            indices = (
                torch.arange(NT, device=self.device).unsqueeze(1),
                torch.arange(self._G, device=self.device).unsqueeze(0),
                centers,
            )
            # Use advanced indexing to sample the values
            sampled_offsets = cbet_offsets[indices]  # NT, G, WA or NT, G, A
            sampled_offsets = sampled_offsets.sum(dim=1)
            sampled_offsets = einops.rearrange(
                sampled_offsets, "NT (W A) -> NT W A", W=self._vqvae_model.input_dim_h
            )
            return sampled_offsets

        if self._use_og_bet_loss and action_seq is not None:
            # Original-BeT style: offsets come from the ground-truth bins.
            sampled_offsets = get_offset_from_centers(action_bins)
        else:
            sampled_offsets = get_offset_from_centers(sampled_centers)

        a_hat = decoded_action + sampled_offsets

        if action_seq is None:
            return a_hat, None, {}
        # Figure out the loss for the actions.
        # First, we need to find the GT VQ codes for each action.

        # Now we can compute the loss.
        if action_seq.ndim == 2:
            action_seq = action_seq.unsqueeze(0)

        # Offset regression target: residual between the true action chunk
        # and the VQ-decoded coarse action.
        offset_target = action_seq - decoded_action
        if self._use_half_and_half_loss and action_seq is not None:
            offset_gt = get_offset_from_centers(action_bins)
            offset_sampled = get_offset_from_centers(sampled_centers)
            offset_loss = 0.5 * torch.nn.L1Loss()(
                offset_target, offset_gt
            ) + 0.5 * torch.nn.L1Loss()(offset_target, offset_sampled)
        else:
            offset_loss = torch.nn.L1Loss()(offset_target, sampled_offsets)

        # NOTE(review): the code-classification losses below index groups 0
        # and 1 explicitly, so this loss path assumes G == 2 residual-VQ
        # groups — confirm against the VqVae configuration.
        if self.sequentially_select:
            cbet_loss1 = self._criterion(  # F.cross_entropy
                cbet_logits1[:, :],
                action_bins[:, 0],
            )
            # Re-run head 2 conditioned on the *ground-truth* first code
            # (teacher forcing) rather than the sampled one.
            cbet_logits2 = self._map_to_cbet_preds_bin2(
                torch.cat(
                    (gpt_output, F.one_hot(action_bins[:, 0], num_classes=self._C)),
                    axis=1,
                )
            )
            cbet_loss2 = self._criterion(  # F.cross_entropy
                cbet_logits2[:, :],
                action_bins[:, 1],
            )
        else:
            cbet_loss1 = self._criterion(  # F.cross_entropy
                cbet_logits[:, 0, :],
                action_bins[:, 0],
            )
            cbet_loss2 = self._criterion(  # F.cross_entropy
                cbet_logits[:, 1, :],
                action_bins[:, 1],
            )
        # Primary code is weighted 5x (hard-coded, as in upstream VQ-BeT);
        # the secondary code weight is configurable.
        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self._secondary_code_multiplier

        # Diagnostics: how often the sampled codes match the ground truth.
        equal_total_code_rate = (
            torch.sum(
                (torch.sum((action_bins == sampled_centers).int(), axis=1) == G).int()
            )
            / NT
        )
        equal_primary_code_rate = torch.sum(
            (action_bins[:, 0] == sampled_centers[:, 0]).int()
        ) / (NT)
        equal_secondary_code_rate = torch.sum(
            (action_bins[:, 1] == sampled_centers[:, 1]).int()
        ) / (NT)
        # if second_half:
        #     cbet_loss = cbet_loss * 0
        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
        action_mse = F.mse_loss(a_hat, action_seq, reduction="none")
        action_l1 = F.l1_loss(a_hat, action_seq, reduction="none")
        # Normalize MSE by the action magnitude; epsilon guards zero actions.
        norm = torch.norm(action_seq, p=2, dim=-1, keepdim=True) + 1e-9
        normalized_mse = (action_mse / norm).mean()

        # Per-component diagnostics; the slicing matches the loss names:
        # dims [0:3] translation, [3:6] rotation, [6:] gripper.
        translation_loss = F.mse_loss(a_hat[:, :, :3], action_seq[:, :, :3]).detach()
        rotation_loss = F.mse_loss(a_hat[:, :, 3:6], action_seq[:, :, 3:6]).detach()
        gripper_loss = F.mse_loss(a_hat[:, :, 6:], action_seq[:, :, 6:]).detach()

        # NOTE(review): the code-rate entries and the component losses are
        # logged as tensors, unlike the scalar .item() entries.
        loss_dict = {
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "loss": loss.detach().cpu().item(),
            "equal_total_code_rate": equal_total_code_rate,
            "equal_primary_code_rate": equal_primary_code_rate,
            "equal_secondary_code_rate": equal_secondary_code_rate,
            "L2_loss": action_mse.mean().detach().cpu().item(),
            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
            "L1_loss": action_l1.mean().detach().cpu().item(),
            "translation_loss": translation_loss,
            "rotation_loss": rotation_loss,
            "gripper_loss": gripper_loss,
        }
        return a_hat, loss, loss_dict

    # def configure_optimizers(self, weight_decay, learning_rate, betas):

    #     optimizer1 = self._gpt_model.configure_optimizers(
    #         weight_decay=weight_decay,
    #         learning_rate=learning_rate,
    #         betas=betas,
    #     )
    #     if self.sequentially_select:
    #         optimizer1.add_param_group({"params": self._map_to_cbet_preds_bin1.parameters()})
    #         optimizer1.add_param_group({"params": self._map_to_cbet_preds_bin2.parameters()})
    #     else:
    #         optimizer1.add_param_group({"params": self._map_to_cbet_preds_bin.parameters()})
    #     optimizer2 = torch.optim.AdamW(self._map_to_cbet_preds_offset.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas)
    #     return {"optimizer1": optimizer1, "optimizer2": optimizer2}

    def _begin_epoch(self, optimizer, **kwargs):
        """Per-epoch hook; currently a no-op that returns None."""
        # log codebook usage rate for debugging
        # lr_0 = optimizer.param_groups[0]["lr"]
        # lr_neg1 = optimizer.param_groups[-1]["lr"]
        # return {"lr_0": lr_0, "lr_neg1": lr_neg1}
        return None

    def _load_from_state_dict(self, *args, **kwargs):
        """Delegate to nn.Module's state-dict loading unchanged."""
        # Don't fit kmeans if we are loading from a state dict.
        # if (path / "cbet_model.pt").exists():
        #     self.load_state_dict(torch.load(path / "cbet_model.pt"))
        # elif (path / "gpt_model.pt").exists():
        #     self._gpt_model.load_state_dict(torch.load(path / "gpt_model.pt"))
        # else:
        #     logging.warning("No model found at %s", path)
        return super()._load_from_state_dict(*args, **kwargs)

    # def load_model(self, path):
    #     if (path / "cbet_model.pt").exists():
    #         self.load_state_dict(torch.load(path / "cbet_model.pt"))
    #     elif (path / "gpt_model.pt").exists():
    #         self._gpt_model.load_state_dict(torch.load(path / "gpt_model.pt"))
    #     else:
    #         logging.warning("No model found at %s", path)
389
+
390
+
391
class FocalLoss(nn.Module):
    """Multi-class focal loss over raw logits.

    With ``gamma == 0`` this reduces exactly to cross-entropy; larger gamma
    down-weights well-classified examples. ``reduction`` follows the usual
    torch convention: "mean", "sum", or "none".
    """

    def __init__(self, gamma: float = 0, reduction: str = "mean"):
        super().__init__()
        # Reject unknown reduction modes up front, before any forward pass.
        if reduction not in ("mean", "sum", "none"):
            raise NotImplementedError
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        """Compute focal loss for 2-D logits ``input`` and class ids ``target``."""
        log_probs = F.log_softmax(input, dim=-1)
        # Pick out the log-probability of each sample's true class.
        target_logp = log_probs.gather(1, target.view(-1, 1)).view(-1)
        # Focal modulation: (1 - p_t)^gamma scales down confident samples.
        focal_weight = (1 - target_logp.exp()) ** self.gamma
        per_sample = -focal_weight * target_logp

        if self.reduction == "sum":
            return per_sample.sum()
        if self.reduction == "mean":
            return per_sample.mean()
        return per_sample
@@ -0,0 +1,3 @@
1
+ from baselines.rum.models.bet.vqvae.vector_quantize_pytorch import VectorQuantize
2
+ from baselines.rum.models.bet.vqvae.residual_vq import ResidualVQ
3
+ from baselines.rum.models.bet.vqvae.vqvae import VqVae