egogym 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- egogym-0.1.0/PKG-INFO +52 -0
- egogym-0.1.0/README.md +21 -0
- egogym-0.1.0/baselines/pi_policy.py +110 -0
- egogym-0.1.0/baselines/rum/__init__.py +1 -0
- egogym-0.1.0/baselines/rum/loss_fns/__init__.py +37 -0
- egogym-0.1.0/baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- egogym-0.1.0/baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- egogym-0.1.0/baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- egogym-0.1.0/baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- egogym-0.1.0/baselines/rum/models/__init__.py +1 -0
- egogym-0.1.0/baselines/rum/models/bet/__init__.py +3 -0
- egogym-0.1.0/baselines/rum/models/bet/bet.py +347 -0
- egogym-0.1.0/baselines/rum/models/bet/gpt.py +277 -0
- egogym-0.1.0/baselines/rum/models/bet/tokenized_bet.py +454 -0
- egogym-0.1.0/baselines/rum/models/bet/utils.py +124 -0
- egogym-0.1.0/baselines/rum/models/bet/vqbet.py +410 -0
- egogym-0.1.0/baselines/rum/models/bet/vqvae/__init__.py +3 -0
- egogym-0.1.0/baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- egogym-0.1.0/baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- egogym-0.1.0/baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- egogym-0.1.0/baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- egogym-0.1.0/baselines/rum/models/custom.py +33 -0
- egogym-0.1.0/baselines/rum/models/encoders/__init__.py +0 -0
- egogym-0.1.0/baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- egogym-0.1.0/baselines/rum/models/encoders/identity.py +45 -0
- egogym-0.1.0/baselines/rum/models/encoders/timm_encoders.py +82 -0
- egogym-0.1.0/baselines/rum/models/policies/diffusion_policy.py +881 -0
- egogym-0.1.0/baselines/rum/models/policies/open_loop.py +122 -0
- egogym-0.1.0/baselines/rum/models/policies/simple_open_loop.py +108 -0
- egogym-0.1.0/baselines/rum/molmo/server.py +144 -0
- egogym-0.1.0/baselines/rum/policy.py +293 -0
- egogym-0.1.0/baselines/rum/utils/__init__.py +212 -0
- egogym-0.1.0/baselines/rum/utils/action_transforms.py +22 -0
- egogym-0.1.0/baselines/rum/utils/decord_transforms.py +135 -0
- egogym-0.1.0/baselines/rum/utils/rpc.py +249 -0
- egogym-0.1.0/baselines/rum/utils/schedulers.py +71 -0
- egogym-0.1.0/baselines/rum/utils/trajectory_vis.py +128 -0
- egogym-0.1.0/baselines/rum/utils/zmq_utils.py +281 -0
- egogym-0.1.0/baselines/rum_policy.py +108 -0
- egogym-0.1.0/egogym/__init__.py +8 -0
- egogym-0.1.0/egogym/assets/constants.py +1804 -0
- egogym-0.1.0/egogym/components/__init__.py +1 -0
- egogym-0.1.0/egogym/components/object.py +94 -0
- egogym-0.1.0/egogym/egogym.py +106 -0
- egogym-0.1.0/egogym/embodiments/__init__.py +10 -0
- egogym-0.1.0/egogym/embodiments/arms/__init__.py +4 -0
- egogym-0.1.0/egogym/embodiments/arms/arm.py +65 -0
- egogym-0.1.0/egogym/embodiments/arms/droid.py +49 -0
- egogym-0.1.0/egogym/embodiments/grippers/__init__.py +4 -0
- egogym-0.1.0/egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym-0.1.0/egogym/embodiments/grippers/rum.py +6 -0
- egogym-0.1.0/egogym/embodiments/robot.py +95 -0
- egogym-0.1.0/egogym/evaluate.py +216 -0
- egogym-0.1.0/egogym/managers/__init__.py +2 -0
- egogym-0.1.0/egogym/managers/objects_managers.py +30 -0
- egogym-0.1.0/egogym/managers/textures_manager.py +21 -0
- egogym-0.1.0/egogym/misc/molmo_client.py +49 -0
- egogym-0.1.0/egogym/misc/molmo_server.py +197 -0
- egogym-0.1.0/egogym/policies/__init__.py +1 -0
- egogym-0.1.0/egogym/policies/base_policy.py +13 -0
- egogym-0.1.0/egogym/scripts/analayze.py +834 -0
- egogym-0.1.0/egogym/scripts/plot.py +87 -0
- egogym-0.1.0/egogym/scripts/plot_correlation.py +392 -0
- egogym-0.1.0/egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym-0.1.0/egogym/scripts/plot_failure.py +248 -0
- egogym-0.1.0/egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym-0.1.0/egogym/scripts/plot_failure_vlm.py +257 -0
- egogym-0.1.0/egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym-0.1.0/egogym/scripts/plot_line.py +303 -0
- egogym-0.1.0/egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym-0.1.0/egogym/scripts/plot_pi0_bars.py +169 -0
- egogym-0.1.0/egogym/tasks/close.py +84 -0
- egogym-0.1.0/egogym/tasks/open.py +85 -0
- egogym-0.1.0/egogym/tasks/pick.py +121 -0
- egogym-0.1.0/egogym/utils.py +969 -0
- egogym-0.1.0/egogym/wrappers/__init__.py +20 -0
- egogym-0.1.0/egogym/wrappers/episode_monitor.py +282 -0
- egogym-0.1.0/egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym-0.1.0/egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym-0.1.0/egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym-0.1.0/egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0/egogym.egg-info/PKG-INFO +52 -0
- egogym-0.1.0/egogym.egg-info/SOURCES.txt +86 -0
- egogym-0.1.0/egogym.egg-info/dependency_links.txt +1 -0
- egogym-0.1.0/egogym.egg-info/requires.txt +14 -0
- egogym-0.1.0/egogym.egg-info/top_level.txt +2 -0
- egogym-0.1.0/pyproject.toml +47 -0
- egogym-0.1.0/setup.cfg +4 -0
egogym-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,52 @@
+Metadata-Version: 2.4
+Name: egogym
+Version: 0.1.0
+Summary: EgoGym: A robotics environment for egocentric control tasks
+License: MIT
+Project-URL: Homepage, https://github.com/omarrayyann/EgoGym
+Project-URL: Repository, https://github.com/omarrayyann/EgoGym
+Keywords: robotics,reinforcement-learning,mujoco,gymnasium
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+Requires-Dist: mujoco
+Requires-Dist: gymnasium
+Requires-Dist: opencv-python
+Requires-Dist: scipy
+Requires-Dist: gdown
+Requires-Dist: numpy
+Requires-Dist: torch
+Requires-Dist: transformers
+Requires-Dist: tqdm
+Provides-Extra: dev
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: black; extra == "dev"
+Requires-Dist: flake8; extra == "dev"
+
+# EgoGym
+
+EgoGym is a lightweight benchmark suite for egocentric robot policies
+
+```python
+env = gym.make(
+    "Egogym-Pick-v0",       # Options: "Egogym-Pick-v0", "Egogym-Open-v0", "Egogym-Close-v0"
+    robot="rum",            # Options: "rum", "droid"
+    action_space="delta",   # Options: "delta", "absolute"
+    num_objs=5,             # Options: 1-5
+)
+```
+
+**Actions (17-dim):**
+- `[0-15]`: flattened 4x4 transformation
+- `[16]`: continuous gripper (0=open, 1=close)
+
+**Coordinate Frame:**
+- **x**: Right of camera
+- **y**: Up of camera
+- **z**: Backward of camera

egogym-0.1.0/README.md
ADDED
@@ -0,0 +1,21 @@
+# EgoGym
+
+EgoGym is a lightweight benchmark suite for egocentric robot policies
+
+```python
+env = gym.make(
+    "Egogym-Pick-v0",       # Options: "Egogym-Pick-v0", "Egogym-Open-v0", "Egogym-Close-v0"
+    robot="rum",            # Options: "rum", "droid"
+    action_space="delta",   # Options: "delta", "absolute"
+    num_objs=5,             # Options: 1-5
+)
+```
+
+**Actions (17-dim):**
+- `[0-15]`: flattened 4x4 transformation
+- `[16]`: continuous gripper (0=open, 1=close)
+
+**Coordinate Frame:**
+- **x**: Right of camera
+- **y**: Up of camera
+- **z**: Backward of camera

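The 17 action dimensions pair a flattened 4x4 homogeneous end-effector transform with one scalar gripper command. Assembling one such action might look like the sketch below (illustrative only: the pose values are made up, and `env.step` assumes the environment created in the README snippet above):

```python
import numpy as np

# Hedged sketch: build a 17-dim EgoGym action from a 4x4 pose + gripper scalar.
pose = np.eye(4)                 # 4x4 end-effector transform in the camera frame
pose[:3, 3] = [0.0, 0.0, -0.01]  # z points backward, so -z moves 1 cm forward
gripper = np.array([1.0])        # 0 = open, 1 = close

action = np.concatenate([pose.flatten(), gripper])
assert action.shape == (17,)     # [0-15] transform, [16] gripper
# obs, reward, terminated, truncated, info = env.step(action)
```
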
egogym-0.1.0/baselines/pi_policy.py
ADDED
@@ -0,0 +1,110 @@
+import numpy as np
+import time
+from egogym.policies.base_policy import BasePolicy
+from egogym.utils import resize_with_pad
+
+class PIPolicy(BasePolicy):
+
+    def __init__(self, config=None, **kwargs):
+        super().__init__()
+        if config is None:
+            config = get_config()
+
+        for key, value in kwargs.items():
+            config[key] = value
+
+        self.config = config
+        self.name = "pi"
+        self.host = self.config["host"]
+        self.port = self.config["port"]
+        self.grasping_style = self.config["grasping_style"]
+        self.buffer_length = self.config["buffer_length"]
+        self.batch_size = 0
+
+        try:
+            from openpi_client import websocket_client_policy
+        except ImportError:
+            raise ImportError("Please install the openpi-client.")
+
+        max_retries = 5
+        for attempt in range(max_retries):
+            try:
+                self.model = websocket_client_policy.WebsocketClientPolicy(
+                    host=self.host,
+                    port=self.port,
+                )
+                print(f"Connected to PI model at {self.host}:{self.port}")
+                break
+            except Exception as e:
+                if attempt < max_retries - 1:
+                    print(f"Connection attempt {attempt + 1} failed: {e}. Retrying...")
+                    time.sleep(1)
+                else:
+                    print(f"Failed to connect to remote model after {max_retries} attempts")
+                    raise
+
+    def get_action(self, obs):
+
+        if self.batch_size == 0:
+            if obs["rgb_exo"].ndim == 3:
+                self.batch_size = 1
+            else:
+                self.batch_size = obs["rgb_exo"].shape[0]
+            self.action_buffer = [None] * self.batch_size
+            self.current_buffer_index = [0] * self.batch_size
+
+        if self.batch_size == 1:
+            obs = {key: np.expand_dims(value, axis=0) if isinstance(value, np.ndarray) else [value] for key, value in obs.items()}
+
+        actions = []
+        for i in range(self.batch_size):
+
+            if self.action_buffer[i] is None or self.current_buffer_index[i] >= self.buffer_length:
+                model_input = {
+                    "observation/exterior_image_1_left": resize_with_pad(obs["rgb_exo"][i], 224, 224),
+                    "observation/wrist_image_left": resize_with_pad(obs["rgb_ego"][i], 224, 224),
+                    "observation/joint_position": np.array(obs["joint_positions"][i][:7]).reshape(7,),
+                    "observation/gripper_position": np.array(obs["grasp"][i]).reshape(1,),
+                    "prompt": f'pick up the {obs["object_name"][i]} and place it on a plate.',
+                }
+                self.action_buffer[i] = self.model.infer(model_input)["actions"]
+                self.current_buffer_index[i] = 0
+            current_action = self.action_buffer[i][self.current_buffer_index[i]]
+            self.current_buffer_index[i] += 1
+
+            gripper_pos = np.array([np.clip(current_action[7], 0.0, 1.0)])
+
+            if self.grasping_style == "binary":
+                gripper_pos = np.array([1.0]) if gripper_pos >= self.config["gripper_threshold"] else np.array([0.0])
+            elif self.grasping_style == "semi_binary":
+                gripper_pos = gripper_pos if gripper_pos <= self.config["gripper_threshold"] else np.array([1.0])
+            elif not self.grasping_style == "continuous":
+                raise ValueError(f"Invalid grasping style: {self.grasping_style}")
+
+            arm_output = current_action[:7].reshape(7)
+            actions.append(np.concatenate([arm_output, gripper_pos], axis=0).reshape(1, -1))
+
+        return np.array(actions).reshape(self.batch_size, 8)
+
+    def reset(self, indicies=None):
+        if indicies is not None:
+            for i in indicies:
+                self.action_buffer[i] = None
+                self.current_buffer_index[i] = 0
+        else:
+            self.batch_size = 0
+            self.action_buffer = None
+            self.current_buffer_index = None
+
+
+def get_config():
+    config = {
+        "host": "localhost",
+        "port": 8000,
+        "gripper_threshold": 0.5,
+        "grasping_style": "binary",
+        "buffer_length": 8,
+    }
+    return config

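`PIPolicy` is a thin client for a remote openpi policy server: each query returns a chunk of actions, which is replayed from a per-environment buffer and refreshed once `buffer_length` steps are consumed. A hypothetical single-environment driver (assumes a server is already listening on `localhost:8000`; the observation dict uses dummy arrays in place of real renders):

```python
import numpy as np

policy = PIPolicy(grasping_style="semi_binary")  # kwargs override get_config()

# Dummy observation carrying the keys get_action() reads.
obs = {
    "rgb_exo": np.zeros((480, 640, 3), dtype=np.uint8),  # exterior camera
    "rgb_ego": np.zeros((480, 640, 3), dtype=np.uint8),  # wrist camera
    "joint_positions": np.zeros(7),
    "grasp": np.array([0.0]),
    "object_name": "mug",
}

action = policy.get_action(obs)  # shape (1, 8): 7 joint values + 1 gripper value
policy.reset()                   # drop buffered chunks between episodes
```
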
egogym-0.1.0/baselines/rum/__init__.py
ADDED
@@ -0,0 +1 @@
+from .policy import Policy

egogym-0.1.0/baselines/rum/loss_fns/__init__.py
ADDED
@@ -0,0 +1,37 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+class BinarizeGripper(ABC):
+    def __init__(
+        self,
+        model: Optional[torch.nn.Module],
+        binarize_gripper: bool = False,
+        threshold: float = 0.5,
+        upper_value: float = 1.0,
+        lower_value: float = 0.0,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        assert (
+            not binarize_gripper or not model.relative_gripper
+        ), "Binarize gripper and relative gripper cannot be used together"
+
+        self.binarize_gripper = binarize_gripper
+        self.threshold = threshold
+        self.upper_value = upper_value
+        self.lower_value = lower_value
+
+    @abstractmethod
+    def binarize_gripper(self, actions):
+        # here the actions will be of shape (batch_size, seq_len, action_space)
+        # last element in action space is gripper values between 0 and 1
+
+        # binarize the gripper values
+        if self.binarize_gripper:
+            actions[:, :, -1] = torch.where(
+                actions[:, :, -1] > self.threshold, self.upper_value, self.lower_value
+            )
+        return actions

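Isolated from the mixin, the thresholding in `binarize_gripper` reduces to the following self-contained toy (made-up values):

```python
import torch

# (batch=1, seq_len=3, action_dim=3); the last channel is the gripper in [0, 1].
actions = torch.tensor([[[0.1, 0.2, 0.30],
                         [0.1, 0.2, 0.55],
                         [0.1, 0.2, 0.90]]])

threshold, upper_value, lower_value = 0.5, 1.0, 0.0
actions[:, :, -1] = torch.where(
    actions[:, :, -1] > threshold, upper_value, lower_value
)
print(actions[0, :, -1])  # tensor([0., 1., 1.])
```
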
egogym-0.1.0/baselines/rum/loss_fns/abstract_loss_fn.py
ADDED
@@ -0,0 +1,13 @@
+import abc
+
+import torch
+
+
+class AbstractLossFn(torch.nn.Module, abc.ABC):
+    def __init__(self, model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        pass
+
+    @abc.abstractmethod
+    def forward(self, data, output, *args, **kwargs):
+        raise NotImplementedError

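The concrete losses that follow all implement this contract: they take the backbone `model` at construction and map a `(data, output)` pair, where `output` holds the encoder's features, to `(loss, loss_dict)`. A hypothetical minimal subclass, purely to show the shape of the interface:

```python
import torch


class MSELossFn(AbstractLossFn):
    """Hypothetical example, not part of the package: plain MSE to actions."""

    def __init__(self, model, feature_dim: int = 512, action_dim: int = 8):
        super().__init__(model)
        self._head = torch.nn.Linear(feature_dim, action_dim)

    def forward(self, data, output, *args, **kwargs):
        *_, padding, actions = data   # same unpacking the concrete losses use
        a_hat = self._head(output)    # output: encoder features
        loss = torch.nn.functional.mse_loss(a_hat, actions)
        return loss, {"mse_loss": loss.item()}
```
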
egogym-0.1.0/baselines/rum/loss_fns/diffusion_policy_loss_fn.py
ADDED
@@ -0,0 +1,114 @@
+from typing import Optional, Sequence, Tuple
+
+import einops
+import torch
+import torch.nn as nn
+from itertools import chain
+
+from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn
+from baselines.rum.models.policies.diffusion_policy import DiffusionPolicy
+
+
+class DiffusionPolicyLossFn(AbstractLossFn):
+    def __init__(
+        self,
+        tokenized_bet: bool,
+        action_dim: int,
+        obs_dim: int,
+        xyz_only: bool,
+        mask_last_min: int = 0,
+        mask_last_max: int = 0,
+        learned_mask: bool = True,
+        use_depth: bool = False,
+        model: Optional[torch.nn.Module] = None,
+        obs_window_size: int = 10,
+        action_sequence_length: int = 1,
+        data_act_scale: float = 1.0,
+        data_obs_scale: float = 1.0,
+        policy_type: str = "cnn",
+        device: str = "cuda",
+    ):
+        super().__init__(model)
+        assert mask_last_max >= mask_last_min
+        assert mask_last_min >= 0
+        obs_dim = model.feature_dim if not use_depth else model.feature_dim * 2
+        if use_depth:
+            self._depth_net = DepthNet(model.feature_dim)
+        self._use_depth = use_depth
+        self._true_action_dim = action_dim
+        action_dim = 3 if xyz_only else action_dim
+
+        self._diffusionpolicy = DiffusionPolicy(
+            obs_dim=obs_dim,
+            act_dim=action_dim,
+            obs_horizon=obs_window_size,
+            pred_horizon=(obs_window_size + action_sequence_length - 1),
+            action_horizon=action_sequence_length,
+            data_act_scale=data_act_scale,
+            data_obs_scale=data_obs_scale,
+            policy_type=policy_type,
+            device=device,
+        )
+
+        self._obs_mask_token = (
+            nn.Parameter(torch.ones(obs_dim), requires_grad=False)
+            if not learned_mask or mask_last_max == 0
+            else nn.Parameter(torch.randn(obs_dim), requires_grad=True)
+        )
+        self._obs_adapter = nn.Linear(obs_dim, obs_dim, bias=False)
+        self._adapt_obs = (
+            self._adapt_obs_linear if not tokenized_bet else self._adapt_obs_tokenized
+        )
+        self._mask_last = (mask_last_min, mask_last_max)
+        self._action_dim = action_dim
+
+        self.step = self._step if not tokenized_bet else self._step_tokenized
+
+        self._seen_action_stack = []
+        self._start_and_ends = None
+
+    def ema_step(self):
+        self._diffusionpolicy.ema_step()
+
+    def _adapt_obs_tokenized(self, obs):
+        return einops.rearrange(
+            self._obs_adapter(einops.rearrange(obs, "... c h w -> ... h w c")),
+            "... h w c -> ... c h w",
+        )
+
+    def _adapt_obs_linear(self, obs):
+        return self._obs_adapter(obs)
+
+    def _begin_epoch(self, *args, **kwargs):
+        return self._diffusionpolicy._begin_epoch(*args, **kwargs)
+
+    def forward(self, data, output, eval=False, *args, **kwargs):
+        *_, padding, actions = data
+        if self._use_depth:
+            *_, depths, padding, actions = data
+            output = torch.cat([output, self._depth_net(depths)], dim=-1)
+        adapted_obs = self._adapt_obs(output)
+        if actions is not None:
+            action_seq = actions[..., : self._action_dim].contiguous()
+        else:
+            action_seq = None
+        _, loss, loss_dict = self._diffusionpolicy(
+            adapted_obs, action_seq=action_seq, eval=eval
+        )
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def _step(self, data, output, *args, **kwargs):
+        if self._use_depth:
+            _, depths, *_ = data
+            output = torch.cat([output, self._depth_net(depths)], dim=-1)
+        adapted_obs = self._adapt_obs(output)
+        a_hat, _, _ = self._diffusionpolicy(
+            adapted_obs,
+            action_seq=None,
+            eval=True,
+        )
+
+        return a_hat[0], {}
+
+        # return a_hat, {}

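One detail worth noting above is the horizon arithmetic: `pred_horizon = obs_window_size + action_sequence_length - 1`, i.e. the denoiser predicts one action aligned with each observed step plus the remaining future steps of the final chunk. On toy numbers:

```python
obs_window_size = 10         # obs_horizon: observed steps conditioned on
action_sequence_length = 6   # action_horizon: actions executed per query

pred_horizon = obs_window_size + action_sequence_length - 1
# 10 predictions aligned with the observed steps, plus the 5 remaining
# future steps of the last action chunk.
assert pred_horizon == 15
```
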
egogym-0.1.0/baselines/rum/loss_fns/rvq_loss_fn.py
ADDED
@@ -0,0 +1,104 @@
+from typing import Optional, Sequence, Tuple
+
+import einops
+import torch
+import torch.nn as nn
+
+from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn
+from baselines.rum.models.bet import (
+    GPT,
+    BehaviorTransformer,
+    TokenizedBehaviorTransformer,
+)
+from baselines.rum.models.bet.vqvae.vqvae import VqVae
+
+
+class RVQLossFn(AbstractLossFn):
+    def __init__(
+        self,
+        tokenized_bet: bool,
+        action_dim: int,
+        xyz_only: bool,
+        mask_last_min: int = 0,
+        mask_last_max: int = 0,
+        gpt_input_dim: int = 0,
+        learned_mask: bool = True,
+        model: Optional[torch.nn.Module] = None,
+        predict_with_offsets: bool = True,
+        sampling_temperature: float = 1.0,
+        action_sequence_length: int = 1,
+        vqvae_n_latent_dims: int = 512,
+        vqvae_n_embed: int = 16,
+        vqvae_groups: int = 2,
+        obs_cond: bool = False,
+    ):
+        super().__init__(model)
+        assert mask_last_max >= mask_last_min
+        assert mask_last_min >= 0
+        obs_dim = model.feature_dim
+        self._true_action_dim = action_dim
+        action_dim = 3 if xyz_only else action_dim
+        self.obs_cond = obs_cond
+        self._model = model
+        self._rvq = VqVae(
+            obs_dim=gpt_input_dim,
+            input_dim_h=action_sequence_length,
+            input_dim_w=action_dim,
+            n_latent_dims=vqvae_n_latent_dims,
+            vqvae_n_embed=vqvae_n_embed,
+            vqvae_groups=vqvae_groups,
+            eval=False,
+            device=model.device,
+            enc_loss_type="through_vqlayer",
+            obs_cond=obs_cond,
+        )
+
+        if self.obs_cond:
+            self._obs_mask_token = (
+                nn.Parameter(torch.ones(gpt_input_dim), requires_grad=False)
+                if not learned_mask or mask_last_max == 0
+                else nn.Parameter(torch.randn(gpt_input_dim), requires_grad=True)
+            )
+            self._obs_adapter = nn.Linear(obs_dim, gpt_input_dim, bias=False)
+            self._adapt_obs = (
+                self._adapt_obs_linear
+                if not tokenized_bet
+                else self._adapt_obs_tokenized
+            )
+        self._mask_last = (mask_last_min, mask_last_max)
+        self._action_dim = action_dim
+
+        self.step = self._step if not tokenized_bet else self._step_tokenized
+
+        self._seen_action_stack = []
+        self._start_and_ends = None
+
+    def _adapt_obs_tokenized(self, obs):
+        return einops.rearrange(
+            self._obs_adapter(einops.rearrange(obs, "... c h w -> ... h w c")),
+            "... h w c -> ... c h w",
+        )
+
+    def _adapt_obs_linear(self, obs):
+        return self._obs_adapter(obs)
+
+    def _begin_epoch(self, *args, **kwargs):
+        return self._rvq._begin_epoch(*args, **kwargs)
+
+    def forward(self, data, output, *args, **kwargs):
+        *_, padding, actions = data
+        action_seq = actions[..., : self._action_dim].contiguous()
+        if self.obs_cond:
+            adapted_obs = self._adapt_obs(output)
+            loss, loss_dict = self._rvq.vqvae_update(action_seq, adapted_obs)
+        else:
+            loss, loss_dict = self._rvq.vqvae_update(action_seq, None)
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def _step(self, data, output, *args, **kwargs):
+        pass
+
+    @torch.no_grad()
+    def _step_tokenized(self, data, output, *args, **kwargs):
+        pass

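This loss only trains the action tokenizer: `forward` delegates to `vqvae_update`, and both `_step` variants are no-ops, consistent with a pretraining stage that produces codebooks rather than rollout actions. The residual quantization being learned works roughly like this self-contained toy (plain torch, not the package's `VqVae`; sizes are illustrative):

```python
import torch

torch.manual_seed(0)
latent = torch.randn(4, 16)                          # encoded action chunks (toy)
codebooks = [torch.randn(8, 16) for _ in range(2)]   # 2 groups of 8 codes each

quantized = torch.zeros_like(latent)
for book in codebooks:
    residual = latent - quantized                    # what earlier groups missed
    idx = torch.cdist(residual, book).argmin(dim=1)  # nearest code per sample
    quantized = quantized + book[idx]                # add this group's codeword

print((latent - quantized).norm(dim=1))  # reconstruction error after two groups
```
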
egogym-0.1.0/baselines/rum/loss_fns/vqbet_loss_fn.py
ADDED
@@ -0,0 +1,202 @@
+from typing import Optional
+
+import einops
+import torch
+import torch.nn as nn
+from einops.layers.torch import Rearrange
+from einops import rearrange
+
+from baselines.rum.loss_fns.abstract_loss_fn import AbstractLossFn
+from baselines.rum.models.bet import GPT
+from baselines.rum.models.bet.vqvae.vqvae import VqVae
+from baselines.rum.models.bet.vqbet import VQBehaviorTransformer
+
+
+class GoalAdapter(nn.Module):
+    def __init__(self, goal_dim, gpt_input_dim):
+        super().__init__()
+        self.linear = nn.Linear(goal_dim, gpt_input_dim, bias=False)
+
+    def forward(self, x):
+        # x: [b, t, g]
+        b, t, g = x.shape
+        x = rearrange(x, "b t g -> (b t) g")
+        x = self.linear(x)
+        x = rearrange(x, "(b t) g -> b t g", b=b, t=t)
+        return x
+
+
+class VQBeTLossFn(AbstractLossFn):
+    def __init__(
+        self,
+        tokenized_bet: bool,
+        action_dim: int,
+        xyz_only: bool,
+        vqvae_load_dir: str,
+        gpt_model: GPT,
+        goal_dim: int = 0,
+        mask_last_min: int = 0,
+        mask_last_max: int = 0,
+        learned_mask: bool = True,
+        use_depth: bool = False,
+        model: Optional[torch.nn.Module] = None,
+        action_sequence_length: int = 1,
+        vqvae_n_latent_dims: int = 512,
+        vqvae_n_embed: int = 16,
+        vqvae_groups: int = 2,
+        obs_cond: bool = False,
+        offset_loss_multiplier: float = 100.0,
+        secondary_code_multiplier: float = 0.5,
+        gamma: float = 2.0,
+        obs_window_size: int = 10,
+        sequentially_select: bool = False,
+        temperature: float = 1.0,
+        device: str = "cuda",
+    ):
+        super().__init__(model)
+        assert mask_last_max >= mask_last_min
+        assert mask_last_min >= 0
+        obs_dim = model.feature_dim if not use_depth else model.feature_dim * 2
+        gpt_input_dim = gpt_model.config.input_dim
+        if use_depth:
+            self._depth_net = DepthNet(model.feature_dim)
+        self._use_depth = use_depth
+        self.goal_dim = goal_dim
+
+        # TODO (mahi): currently, we are casting everything to a concat style goal
+        # but we should be able to handle different types of goals like concat or stack
+        self._use_goals = goal_dim > 0
+        if self._use_goals:
+            gpt_input_dim //= 2
+            self._goal_adapter = GoalAdapter(goal_dim, gpt_input_dim)
+            goal_dim = gpt_input_dim
+        else:
+            self._goal_adapter = Rearrange("b g -> b 1 g")
+        self._true_action_dim = action_dim
+        action_dim = 3 if xyz_only else action_dim
+        self._rvq = VqVae(
+            obs_dim=gpt_input_dim,
+            input_dim_h=action_sequence_length,
+            input_dim_w=action_dim,
+            n_latent_dims=vqvae_n_latent_dims,
+            vqvae_n_embed=vqvae_n_embed,
+            vqvae_groups=vqvae_groups,
+            device=device,
+            eval=True,
+            enc_loss_type="through_vqlayer",
+            obs_cond=obs_cond,
+            load_dir=vqvae_load_dir,
+        )
+
+        for param in self._rvq.parameters():
+            param.requires_grad = False
+        self._vqbet = VQBehaviorTransformer(
+            obs_dim=gpt_input_dim,
+            act_dim=action_dim,
+            goal_dim=goal_dim,
+            gpt_model=gpt_model,
+            vqvae_model=self._rvq,
+            offset_loss_multiplier=offset_loss_multiplier,
+            secondary_code_multiplier=secondary_code_multiplier,
+            gamma=gamma,
+            obs_window_size=obs_window_size,
+            act_window_size=action_sequence_length,
+            sequentially_select=sequentially_select,
+            temperature=temperature,
+            device=device,
+        )
+
+        self._obs_mask_token = (
+            nn.Parameter(torch.ones(gpt_input_dim), requires_grad=False)
+            if not learned_mask or mask_last_max == 0
+            else nn.Parameter(torch.randn(gpt_input_dim), requires_grad=True)
+        )
+        self._obs_adapter = nn.Linear(obs_dim, gpt_input_dim, bias=False)
+        # self._adapt_obs = nn.Linear(obs_dim, gpt_input_dim, bias=False)
+        self._mask_last = (mask_last_min, mask_last_max)
+        self._action_dim = action_dim
+
+        self._adapt_obs = (
+            self._adapt_obs_linear if not tokenized_bet else self._adapt_obs_tokenized
+        )
+        self._mask_last = (mask_last_min, mask_last_max)
+        self._action_dim = action_dim
+
+        self.step = self._step if not tokenized_bet else self._step_tokenized
+
+        self._seen_action_stack = []
+        self._start_and_ends = None
+
+    def _adapt_obs_tokenized(self, obs):
+        return einops.rearrange(
+            self._obs_adapter(einops.rearrange(obs, "... c h w -> ... h w c")),
+            "... h w c -> ... c h w",
+        )
+
+    def _adapt_obs_linear(self, obs):
+        return self._obs_adapter(obs)
+
+    def _begin_epoch(self, *args, **kwargs):
+        return self._vqbet._begin_epoch(*args, **kwargs)
+
+    def forward(self, data, output, *args, **kwargs):
+        # TODO Mahi fix the order of depth and goals.
+        if self._use_goals:
+            _, goals, *_, padding, actions = data
+            if self._use_depth:
+                _, goals, *_, depths, padding, actions = data
+                output = torch.cat([output, self._depth_net(depths)], dim=-1)
+            goals = self._goal_adapter(goals)
+        else:
+            *_, padding, actions = data
+            goals = None
+            if self._use_depth:
+                *_, depths, padding, actions = data
+                output = torch.cat([output, self._depth_net(depths)], dim=-1)
+        adapted_obs = self._adapt_obs(output)
+        if "second_half" in kwargs:
+            second_half = kwargs["second_half"]
+        else:
+            second_half = False
+        action_seq = actions[..., : self._action_dim].contiguous()
+        _, loss, loss_dict = self._vqbet(
+            adapted_obs,
+            goal_seq=goals,
+            action_seq=action_seq,
+            second_half=second_half,
+        )
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def _step(self, data, output, return_all=False, *args, **kwargs):
+        if self._use_depth:
+            *_, depths, padding, actions = data
+            output = torch.cat([output, self._depth_net(depths)], dim=-1)
+        goals = data[1] if self._use_goals else None
+        adapted_obs = self._adapt_obs(output)
+        goals = self._goal_adapter(goals) if self._use_goals else None
+        a_hat, _, _ = self._vqbet(
+            adapted_obs,
+            goal_seq=goals,
+            action_seq=None,
+        )
+        if a_hat.shape[2] != self._true_action_dim:
+            # append n zeros and a 1 to the end of y_hat
+            a_hat = torch.cat(
+                [
+                    a_hat,
+                    torch.zeros(
+                        a_hat.shape[0],
+                        a_hat.shape[1],
+                        self._true_action_dim - a_hat.shape[2],
+                    ).to(a_hat.device),
+                ],
+                dim=2,
+            )
+            a_hat[:, :, -1] = 1.0
+        if return_all:
+            return a_hat.reshape(
+                adapted_obs.shape[0], adapted_obs.shape[1], self._true_action_dim
+            )[:, -1, :], {}
+        # Finally, return the final action prediction only.
+        return a_hat[-1, -1, :], {}

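This is the second stage of the VQ-BeT pipeline: the codebooks trained by `RVQLossFn` are loaded frozen (`eval=True`, `requires_grad = False`), and the GPT learns to classify codes and regress small continuous offsets on top of the chosen codewords, with `offset_loss_multiplier` weighting the regression term. The decode step, reduced to a self-contained toy (made-up sizes; the real model samples one code per RVQ group):

```python
import torch

torch.manual_seed(0)
codebook = torch.randn(8, 3)      # 8 codewords over a 3-dim action space (toy)
logits = torch.randn(8)           # stand-in for the GPT code head
offset = 0.05 * torch.randn(3)    # stand-in for the offset head

code = logits.argmax()
action = codebook[code] + offset  # coarse codeword + fine offset
print(code.item(), action)
```
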
egogym-0.1.0/baselines/rum/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from . import *