egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/pi_policy.py ADDED
@@ -0,0 +1,110 @@
+ import numpy as np
+ import time
+ from egogym.policies.base_policy import BasePolicy
+ from egogym.utils import resize_with_pad
+
+ class PIPolicy(BasePolicy):
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__()
+         if config is None:
+             config = get_config()
+
+         for key, value in kwargs.items():
+             config[key] = value
+
+         self.config = config
+         self.name = "pi"
+         self.host = self.config["host"]
+         self.port = self.config["port"]
+         self.grasping_style = self.config["grasping_style"]
+         self.buffer_length = self.config["buffer_length"]
+         self.batch_size = 0
+
+         try:
+             from openpi_client import websocket_client_policy
+         except ImportError:
+             raise ImportError("Please install the openpi-client.")
+
+         max_retries = 5
+         for attempt in range(max_retries):
+             try:
+                 self.model = websocket_client_policy.WebsocketClientPolicy(
+                     host=self.host,
+                     port=self.port,
+                 )
+                 print(f"Connected to PI model at {self.host}:{self.port}")
+                 break
+             except Exception as e:
+                 if attempt < max_retries - 1:
+                     print(f"Connection attempt {attempt + 1} failed: {e}. Retrying...")
+                     time.sleep(1)
+                 else:
+                     print(f"Failed to connect to remote model after {max_retries} attempts")
+                     raise
+
+     def get_action(self, obs):
+
+         if self.batch_size == 0:
+             if obs["rgb_exo"].ndim == 3:
+                 self.batch_size = 1
+             else:
+                 self.batch_size = obs["rgb_exo"].shape[0]
+             self.action_buffer = [None] * self.batch_size
+             self.current_buffer_index = [0] * self.batch_size
+
+         if self.batch_size == 1:
+             obs = {key: np.expand_dims(value, axis=0) if isinstance(value, np.ndarray) else [value] for key, value in obs.items()}
+
+
+         actions = []
+         for i in range(self.batch_size):
+
+             if self.action_buffer[i] is None or self.current_buffer_index[i] >= self.buffer_length:
+                 model_input = {
+                     "observation/exterior_image_1_left": resize_with_pad(obs["rgb_exo"][i], 224, 224),
+                     "observation/wrist_image_left": resize_with_pad(obs["rgb_ego"][i], 224, 224),
+                     "observation/joint_position": np.array(obs["joint_positions"][i][:7]).reshape(7,),
+                     "observation/gripper_position": np.array(obs["grasp"][i]).reshape(1,),
+                     "prompt": f'pick up the {obs["object_name"][i]} and place it on a plate.',
+                 }
+                 self.action_buffer[i] = self.model.infer(model_input)["actions"]
+                 self.current_buffer_index[i] = 0
+             current_action = self.action_buffer[i][self.current_buffer_index[i]]
+             self.current_buffer_index[i] += 1
+
+             gripper_pos = np.array([np.clip(current_action[7], 0.0, 1.0)])
+
+             if self.grasping_style == "binary":
+                 gripper_pos = np.array([1.0]) if gripper_pos >= self.config["gripper_threshold"] else np.array([0.0])
+             elif self.grasping_style == "semi_binary":
+                 gripper_pos = gripper_pos if gripper_pos <= self.config["gripper_threshold"] else np.array([1.0])
+             elif not self.grasping_style == "continuous":
+                 raise ValueError(f"Invalid grasping style: {self.grasping_style}")
+
+
+             arm_output = current_action[:7].reshape(7)
+             actions.append(np.concatenate([arm_output, gripper_pos], axis=0).reshape(1, -1))
+
+         return np.array(actions).reshape(self.batch_size, 8)
+
+     def reset(self, indices=None):
+         if indices is not None:
+             for i in indices:
+                 self.action_buffer[i] = None
+                 self.current_buffer_index[i] = 0
+         else:
+             self.batch_size = 0
+             self.action_buffer = None
+             self.current_buffer_index = None
+
+
+ def get_config():
+     config = {
+         "host": "localhost",
+         "port": 8000,
+         "gripper_threshold": 0.5,
+         "grasping_style": "binary",
+         "buffer_length": 8,
+     }
+     return config
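
For orientation, a minimal usage sketch of the PIPolicy client above (not part of the package): it assumes an openpi policy server is already reachable at the configured host and port, and that the observation dict uses the keys read in get_action (rgb_exo, rgb_ego, joint_positions, grasp, object_name); shapes below are illustrative.

import numpy as np
from baselines.pi_policy import PIPolicy, get_config

# Hypothetical single-environment observation; keys mirror those consumed in get_action.
obs = {
    "rgb_exo": np.zeros((480, 640, 3), dtype=np.uint8),   # exterior camera image
    "rgb_ego": np.zeros((480, 640, 3), dtype=np.uint8),   # wrist camera image
    "joint_positions": np.zeros(7, dtype=np.float32),     # 7-DoF arm joint state
    "grasp": np.zeros(1, dtype=np.float32),                # gripper value in [0, 1]
    "object_name": "mug",                                  # substituted into the language prompt
}

config = get_config()                                      # defaults: localhost:8000, binary grasping, buffer of 8
policy = PIPolicy(config, grasping_style="semi_binary")    # keyword args override config entries
action = policy.get_action(obs)                            # shape (1, 8): 7 arm values + 1 gripper value
policy.reset()                                             # clear the action buffer between episodes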
baselines/rum/__init__.py ADDED
@@ -0,0 +1 @@
+ from .policy import Policy
baselines/rum/loss_fns/__init__.py ADDED
@@ -0,0 +1,37 @@
+ from abc import ABC, abstractmethod
+ from typing import Optional
+
+ import torch
+
+ class BinarizeGripper(ABC):
+     def __init__(
+         self,
+         model: Optional[torch.nn.Module],
+         binarize_gripper: bool = False,
+         threshold: float = 0.5,
+         upper_value: float = 1.0,
+         lower_value: float = 0.0,
+         *args,
+         **kwargs,
+     ):
+         super().__init__()
+         assert (
+             not binarize_gripper or not model.relative_gripper
+         ), "Binarize gripper and relative gripper cannot be used together"
+
+         self.binarize_gripper = binarize_gripper
+         self.threshold = threshold
+         self.upper_value = upper_value
+         self.lower_value = lower_value
+
+     @abstractmethod
+     def binarize_gripper(self, actions):
+         # here the actions will be of shape (batch_size, seq_len, action_space)
+         # last element in action space is gripper values between 0 and 1
+
+         # binarize the gripper values
+         if self.binarize_gripper:
+             actions[:, :, -1] = torch.where(
+                 actions[:, :, -1] > self.threshold, self.upper_value, self.lower_value
+             )
+         return actions
baselines/rum/loss_fns/abstract_loss_fn.py ADDED
@@ -0,0 +1,13 @@
+ import abc
+
+ import torch
+
+
+ class AbstractLossFn(torch.nn.Module, abc.ABC):
+     def __init__(self, model, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         pass
+
+     @abc.abstractmethod
+     def forward(self, data, output, *args, **kwargs):
+         raise NotImplementedError
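
Each of the loss functions that follow specializes this interface: the constructor receives the upstream encoder model, and forward consumes the training data tuple together with the encoder output and returns a (loss, loss_dict) pair. A minimal sketch of a conforming subclass, purely illustrative (MSELossFn is not part of the package, and it assumes the encoder exposes a feature_dim attribute, as the bundled loss functions do):

import torch

from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn


class MSELossFn(AbstractLossFn):
    # Hypothetical example: regress actions directly from encoder features.
    def __init__(self, action_dim, model=None):
        super().__init__(model)
        self._head = torch.nn.Linear(model.feature_dim, action_dim)

    def forward(self, data, output, *args, **kwargs):
        # The loss functions below unpack data the same way: the tuple ends with (padding, actions).
        *_, padding, actions = data
        loss = torch.nn.functional.mse_loss(self._head(output), actions)
        return loss, {"mse_loss": loss.item()}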
baselines/rum/loss_fns/diffusion_policy_loss_fn.py ADDED
@@ -0,0 +1,114 @@
+ from typing import Optional, Sequence, Tuple
+
+ import einops
+ import torch
+ import torch.nn as nn
+ from itertools import chain
+
+ from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn
+ from baselines.rum.models.policies.diffusion_policy import DiffusionPolicy
+
+
+ class DiffusionPolicyLossFn(AbstractLossFn):
+     def __init__(
+         self,
+         tokenized_bet: bool,
+         action_dim: int,
+         obs_dim: int,
+         xyz_only: bool,
+         mask_last_min: int = 0,
+         mask_last_max: int = 0,
+         learned_mask: bool = True,
+         use_depth: bool = False,
+         model: Optional[torch.nn.Module] = None,
+         obs_window_size: int = 10,
+         action_sequence_length: int = 1,
+         data_act_scale: float = 1.0,
+         data_obs_scale: float = 1.0,
+         policy_type: str = "cnn",
+         device: str = "cuda",
+     ):
+         super().__init__(model)
+         assert mask_last_max >= mask_last_min
+         assert mask_last_min >= 0
+         obs_dim = model.feature_dim if not use_depth else model.feature_dim * 2
+         if use_depth:
+             self._depth_net = DepthNet(model.feature_dim)
+         self._use_depth = use_depth
+         self._true_action_dim = action_dim
+         action_dim = 3 if xyz_only else action_dim
+
+         self._diffusionpolicy = DiffusionPolicy(
+             obs_dim=obs_dim,
+             act_dim=action_dim,
+             obs_horizon=obs_window_size,
+             pred_horizon=(obs_window_size + action_sequence_length - 1),
+             action_horizon=action_sequence_length,
+             data_act_scale=data_act_scale,
+             data_obs_scale=data_obs_scale,
+             policy_type=policy_type,
+             device=device,
+         )
+
+         self._obs_mask_token = (
+             nn.Parameter(torch.ones(obs_dim), requires_grad=False)
+             if not learned_mask or mask_last_max == 0
+             else nn.Parameter(torch.randn(obs_dim), requires_grad=True)
+         )
+         self._obs_adapter = nn.Linear(obs_dim, obs_dim, bias=False)
+         self._adapt_obs = (
+             self._adapt_obs_linear if not tokenized_bet else self._adapt_obs_tokenized
+         )
+         self._mask_last = (mask_last_min, mask_last_max)
+         self._action_dim = action_dim
+
+         self.step = self._step if not tokenized_bet else self._step_tokenized
+
+         self._seen_action_stack = []
+         self._start_and_ends = None
+
+     def ema_step(self):
+         self._diffusionpolicy.ema_step()
+
+     def _adapt_obs_tokenized(self, obs):
+         return einops.rearrange(
+             self._obs_adapter(einops.rearrange(obs, "... c h w -> ... h w c")),
+             "... h w c -> ... c h w",
+         )
+
+     def _adapt_obs_linear(self, obs):
+         return self._obs_adapter(obs)
+
+     def _begin_epoch(self, *args, **kwargs):
+         return self._diffusionpolicy._begin_epoch(*args, **kwargs)
+
+     def forward(self, data, output, eval=False, *args, **kwargs):
+         *_, padding, actions = data
+         if self._use_depth:
+             *_, depths, padding, actions = data
+             output = torch.cat([output, self._depth_net(depths)], dim=-1)
+         adapted_obs = self._adapt_obs(output)
+         if actions is not None:
+             action_seq = actions[..., : self._action_dim].contiguous()
+         else:
+             action_seq = None
+         _, loss, loss_dict = self._diffusionpolicy(
+             adapted_obs, action_seq=action_seq, eval=eval
+         )
+         return loss, loss_dict
+
+     @torch.no_grad()
+     def _step(self, data, output, *args, **kwargs):
+         if self._use_depth:
+             _, depths, *_ = data
+             output = torch.cat([output, self._depth_net(depths)], dim=-1)
+         adapted_obs = self._adapt_obs(output)
+         a_hat, _, _ = self._diffusionpolicy(
+             adapted_obs,
+             action_seq=None,
+             eval=True,
+         )
+
+         return a_hat[0], {}
+
+         # return a_hat, {}
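
Note on the horizons above: the prediction horizon is tied to the observation window and action sequence length via pred_horizon = obs_window_size + action_sequence_length - 1, so with the defaults (obs_window_size=10, action_sequence_length=1) the diffusion policy predicts over a horizon of 10 + 1 - 1 = 10 steps while executing one action per window.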
baselines/rum/loss_fns/rvq_loss_fn.py ADDED
@@ -0,0 +1,104 @@
+ from typing import Optional, Sequence, Tuple
+
+ import einops
+ import torch
+ import torch.nn as nn
+
+ from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn
+ from baselines.rum.models.bet import (
+     GPT,
+     BehaviorTransformer,
+     TokenizedBehaviorTransformer,
+ )
+ from baselines.rum.models.bet.vqvae.vqvae import VqVae
+
+
+ class RVQLossFn(AbstractLossFn):
+     def __init__(
+         self,
+         tokenized_bet: bool,
+         action_dim: int,
+         xyz_only: bool,
+         mask_last_min: int = 0,
+         mask_last_max: int = 0,
+         gpt_input_dim: int = 0,
+         learned_mask: bool = True,
+         model: Optional[torch.nn.Module] = None,
+         predict_with_offsets: bool = True,
+         sampling_temperature: float = 1.0,
+         action_sequence_length: int = 1,
+         vqvae_n_latent_dims: int = 512,
+         vqvae_n_embed: int = 16,
+         vqvae_groups: int = 2,
+         obs_cond: bool = False,
+     ):
+         super().__init__(model)
+         assert mask_last_max >= mask_last_min
+         assert mask_last_min >= 0
+         obs_dim = model.feature_dim
+         self._true_action_dim = action_dim
+         action_dim = 3 if xyz_only else action_dim
+         self.obs_cond = obs_cond
+         self._model = model
+         self._rvq = VqVae(
+             obs_dim=gpt_input_dim,
+             input_dim_h=action_sequence_length,
+             input_dim_w=action_dim,
+             n_latent_dims=vqvae_n_latent_dims,
+             vqvae_n_embed=vqvae_n_embed,
+             vqvae_groups=vqvae_groups,
+             eval=False,
+             device=model.device,
+             enc_loss_type="through_vqlayer",
+             obs_cond=obs_cond,
+         )
+
+         if self.obs_cond:
+             self._obs_mask_token = (
+                 nn.Parameter(torch.ones(gpt_input_dim), requires_grad=False)
+                 if not learned_mask or mask_last_max == 0
+                 else nn.Parameter(torch.randn(gpt_input_dim), requires_grad=True)
+             )
+             self._obs_adapter = nn.Linear(obs_dim, gpt_input_dim, bias=False)
+             self._adapt_obs = (
+                 self._adapt_obs_linear
+                 if not tokenized_bet
+                 else self._adapt_obs_tokenized
+             )
+         self._mask_last = (mask_last_min, mask_last_max)
+         self._action_dim = action_dim
+
+         self.step = self._step if not tokenized_bet else self._step_tokenized
+
+         self._seen_action_stack = []
+         self._start_and_ends = None
+
+     def _adapt_obs_tokenized(self, obs):
+         return einops.rearrange(
+             self._obs_adapter(einops.rearrange(obs, "... c h w -> ... h w c")),
+             "... h w c -> ... c h w",
+         )
+
+     def _adapt_obs_linear(self, obs):
+         return self._obs_adapter(obs)
+
+     def _begin_epoch(self, *args, **kwargs):
+         return self._rvq._begin_epoch(*args, **kwargs)
+
+     def forward(self, data, output, *args, **kwargs):
+         *_, padding, actions = data
+         action_seq = actions[..., : self._action_dim].contiguous()
+         if self.obs_cond:
+             adapted_obs = self._adapt_obs(output)
+             loss, loss_dict = self._rvq.vqvae_update(action_seq, adapted_obs)
+         else:
+             loss, loss_dict = self._rvq.vqvae_update(action_seq, None)
+         return loss, loss_dict
+
+     @torch.no_grad()
+     def _step(self, data, output, *args, **kwargs):
+         pass
+
+     @torch.no_grad()
+     def _step_tokenized(self, data, output, *args, **kwargs):
+         pass
baselines/rum/loss_fns/vqbet_loss_fn.py ADDED
@@ -0,0 +1,202 @@
+ from typing import Optional
+
+ import einops
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+ from einops import rearrange
+
+ from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn
+ from baselines.rum.models.bet import GPT
+ from baselines.rum.models.bet.vqvae.vqvae import VqVae
+ from baselines.rum.models.bet.vqbet import VQBehaviorTransformer
+
+
+ class GoalAdapter(nn.Module):
+     def __init__(self, goal_dim, gpt_input_dim):
+         super().__init__()
+         self.linear = nn.Linear(goal_dim, gpt_input_dim, bias=False)
+
+     def forward(self, x):
+         # x: [b, t, g]
+         b, t, g = x.shape
+         x = rearrange(x, "b t g -> (b t) g")
+         x = self.linear(x)
+         x = rearrange(x, "(b t) g -> b t g", b=b, t=t)
+         return x
+
+
+ class VQBeTLossFn(AbstractLossFn):
+     def __init__(
+         self,
+         tokenized_bet: bool,
+         action_dim: int,
+         xyz_only: bool,
+         vqvae_load_dir: str,
+         gpt_model: GPT,
+         goal_dim: int = 0,
+         mask_last_min: int = 0,
+         mask_last_max: int = 0,
+         learned_mask: bool = True,
+         use_depth: bool = False,
+         model: Optional[torch.nn.Module] = None,
+         action_sequence_length: int = 1,
+         vqvae_n_latent_dims: int = 512,
+         vqvae_n_embed: int = 16,
+         vqvae_groups: int = 2,
+         obs_cond: bool = False,
+         offset_loss_multiplier: float = 100.0,
+         secondary_code_multiplier: float = 0.5,
+         gamma: float = 2.0,
+         obs_window_size: int = 10,
+         sequentially_select: bool = False,
+         temperature: float = 1.0,
+         device: str = "cuda",
+     ):
+         super().__init__(model)
+         assert mask_last_max >= mask_last_min
+         assert mask_last_min >= 0
+         obs_dim = model.feature_dim if not use_depth else model.feature_dim * 2
+         gpt_input_dim = gpt_model.config.input_dim
+         if use_depth:
+             self._depth_net = DepthNet(model.feature_dim)
+         self._use_depth = use_depth
+         self.goal_dim = goal_dim
+
+         # TODO (mahi): currently, we are casting everything to a concat style goal
+         # but we should be able to handle different types of goals like concat or stack
+         self._use_goals = goal_dim > 0
+         if self._use_goals:
+             gpt_input_dim //= 2
+             self._goal_adapter = GoalAdapter(goal_dim, gpt_input_dim)
+             goal_dim = gpt_input_dim
+         else:
+             self._goal_adapter = Rearrange("b g -> b 1 g")
+         self._true_action_dim = action_dim
+         action_dim = 3 if xyz_only else action_dim
+         self._rvq = VqVae(
+             obs_dim=gpt_input_dim,
+             input_dim_h=action_sequence_length,
+             input_dim_w=action_dim,
+             n_latent_dims=vqvae_n_latent_dims,
+             vqvae_n_embed=vqvae_n_embed,
+             vqvae_groups=vqvae_groups,
+             device=device,
+             eval=True,
+             enc_loss_type="through_vqlayer",
+             obs_cond=obs_cond,
+             load_dir=vqvae_load_dir,
+         )
+
+         for param in self._rvq.parameters():
+             param.requires_grad = False
+         self._vqbet = VQBehaviorTransformer(
+             obs_dim=gpt_input_dim,
+             act_dim=action_dim,
+             goal_dim=goal_dim,
+             gpt_model=gpt_model,
+             vqvae_model=self._rvq,
+             offset_loss_multiplier=offset_loss_multiplier,
+             secondary_code_multiplier=secondary_code_multiplier,
+             gamma=gamma,
+             obs_window_size=obs_window_size,
+             act_window_size=action_sequence_length,
+             sequentially_select=sequentially_select,
+             temperature=temperature,
+             device=device,
+         )
+
+         self._obs_mask_token = (
+             nn.Parameter(torch.ones(gpt_input_dim), requires_grad=False)
+             if not learned_mask or mask_last_max == 0
+             else nn.Parameter(torch.randn(gpt_input_dim), requires_grad=True)
+         )
+         self._obs_adapter = nn.Linear(obs_dim, gpt_input_dim, bias=False)
+         # self._adapt_obs = nn.Linear(obs_dim, gpt_input_dim, bias=False)
+         self._mask_last = (mask_last_min, mask_last_max)
+         self._action_dim = action_dim
+
+         self._adapt_obs = (
+             self._adapt_obs_linear if not tokenized_bet else self._adapt_obs_tokenized
+         )
+         self._mask_last = (mask_last_min, mask_last_max)
+         self._action_dim = action_dim
+
+         self.step = self._step if not tokenized_bet else self._step_tokenized
+
+         self._seen_action_stack = []
+         self._start_and_ends = None
+
+     def _adapt_obs_tokenized(self, obs):
+         return einops.rearrange(
+             self._obs_adapter(einops.rearrange(obs, "... c h w -> ... h w c")),
+             "... h w c -> ... c h w",
+         )
+
+     def _adapt_obs_linear(self, obs):
+         return self._obs_adapter(obs)
+
+     def _begin_epoch(self, *args, **kwargs):
+         return self._vqbet._begin_epoch(*args, **kwargs)
+
+     def forward(self, data, output, *args, **kwargs):
+         # TODO Mahi fix the order of depth and goals.
+         if self._use_goals:
+             _, goals, *_, padding, actions = data
+             if self._use_depth:
+                 _, goals, *_, depths, padding, actions = data
+                 output = torch.cat([output, self._depth_net(depths)], dim=-1)
+             goals = self._goal_adapter(goals)
+         else:
+             *_, padding, actions = data
+             goals = None
+             if self._use_depth:
+                 *_, depths, padding, actions = data
+                 output = torch.cat([output, self._depth_net(depths)], dim=-1)
+         adapted_obs = self._adapt_obs(output)
+         if "second_half" in kwargs:
+             second_half = kwargs["second_half"]
+         else:
+             second_half = False
+         action_seq = actions[..., : self._action_dim].contiguous()
+         _, loss, loss_dict = self._vqbet(
+             adapted_obs,
+             goal_seq=goals,
+             action_seq=action_seq,
+             second_half=second_half,
+         )
+         return loss, loss_dict
+
+     @torch.no_grad()
+     def _step(self, data, output, return_all=False, *args, **kwargs):
+         if self._use_depth:
+             *_, depths, padding, actions = data
+             output = torch.cat([output, self._depth_net(depths)], dim=-1)
+         goals = data[1] if self._use_goals else None
+         adapted_obs = self._adapt_obs(output)
+         goals = self._goal_adapter(goals) if self._use_goals else None
+         a_hat, _, _ = self._vqbet(
+             adapted_obs,
+             goal_seq=goals,
+             action_seq=None,
+         )
+         if a_hat.shape[2] != self._true_action_dim:
+             # append n zeros and a 1 to the end of y_hat
+             a_hat = torch.cat(
+                 [
+                     a_hat,
+                     torch.zeros(
+                         a_hat.shape[0],
+                         a_hat.shape[1],
+                         self._true_action_dim - a_hat.shape[2],
+                     ).to(a_hat.device),
+                 ],
+                 dim=2,
+             )
+             a_hat[:, :, -1] = 1.0
+         if return_all:
+             return a_hat.reshape(
+                 adapted_obs.shape[0], adapted_obs.shape[1], self._true_action_dim
+             )[:, -1, :], {}
+         # Finally, return the final action prediction only.
+         return a_hat[-1, -1, :], {}
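
The zero-padding branch in _step above widens xyz-only predictions back to the true action dimension and pins the final channel to a constant 1.0; a standalone illustration of that tensor manipulation with dummy shapes (not package code):

import torch

# Suppose the head predicts 3-D (xyz-only) actions while the environment expects 8-D actions.
a_hat = torch.randn(2, 10, 3)            # (batch, window, predicted action dim)
true_action_dim = 8

if a_hat.shape[2] != true_action_dim:
    pad = torch.zeros(a_hat.shape[0], a_hat.shape[1], true_action_dim - a_hat.shape[2])
    a_hat = torch.cat([a_hat, pad], dim=2)   # append zeros for the missing dimensions
    a_hat[:, :, -1] = 1.0                    # final channel fixed to 1.0, as in _step

print(a_hat.shape)                            # torch.Size([2, 10, 8])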
baselines/rum/models/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import *
baselines/rum/models/bet/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from baselines.rum.models.bet.bet import BehaviorTransformer
+ from baselines.rum.models.bet.gpt import GPT
+ from baselines.rum.models.bet.tokenized_bet import TokenizedBehaviorTransformer