opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/optim/factory.py
ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from opentau.configs.train import TrainPipelineConfig
from opentau.policies.pretrained import PreTrainedPolicy


def make_optimizer_and_scheduler(
    cfg: TrainPipelineConfig, policy: PreTrainedPolicy
) -> tuple[Optimizer, LRScheduler | None]:
    """Generates the optimizer and scheduler based on configs.

    Args:
        cfg (TrainPipelineConfig): The training config that contains optimizer and scheduler configs
        policy (PreTrainedPolicy): The policy from which parameters and training presets are taken.

    Returns:
        tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`.
    """
    params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
    # When using `accelerate`, unused parameters that require grad can result in a RuntimeError("Expected to have
    # finished reduction in the prior iteration before starting a new one.")
    optimizer = cfg.optimizer.build(p for p in params if p.requires_grad)
    lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
    return optimizer, lr_scheduler
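For orientation, here is a minimal sketch of what the factory's two `build` calls amount to, assuming a toy `nn.Linear` as a stand-in for a `PreTrainedPolicy` and using the `AdamConfig` and `ConstantSchedulerConfig` choices registered in the optim modules below:

import torch.nn as nn

from opentau.optim.optimizers import AdamConfig
from opentau.optim.schedulers import ConstantSchedulerConfig

policy = nn.Linear(8, 4)           # toy stand-in for a PreTrainedPolicy
policy.bias.requires_grad_(False)  # e.g. a frozen parameter

# Mirrors make_optimizer_and_scheduler: only parameters that require grad
# reach the optimizer, avoiding DDP reduction errors under `accelerate`.
optimizer = AdamConfig(lr=1e-4).build(p for p in policy.parameters() if p.requires_grad)
lr_scheduler = ConstantSchedulerConfig().build(optimizer, num_training_steps=10_000)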
opentau/optim/optimizers.py
ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterable

import draccus
import torch
from safetensors.torch import load_file, save_file
from torch.nn.parameter import Parameter

from opentau.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)
from opentau.datasets.utils import flatten_dict, unflatten_dict, write_json
from opentau.utils.io_utils import deserialize_json_into_object


@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        return "adam"

    @abc.abstractmethod
    def build(self, params: Iterable[Parameter]) -> torch.optim.Optimizer:
        raise NotImplementedError


@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: Iterable[Parameter]) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.Adam(params, **kwargs)


@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: Iterable[Parameter]) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.AdamW(params, **kwargs)


@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    nesterov: bool = False
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: Iterable[Parameter]) -> torch.optim.Optimizer:
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        return torch.optim.SGD(params, **kwargs)


def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    state = optimizer.state_dict()
    param_groups = state.pop("param_groups")
    flat_state = flatten_dict(state)
    save_file(flat_state, save_dir / OPTIMIZER_STATE)
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)


def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)
    loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}

    if "param_groups" in current_state_dict:
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
        )
        loaded_state_dict["param_groups"] = param_groups

    optimizer.load_state_dict(loaded_state_dict)
    return optimizer
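A round-trip sketch of the checkpoint helpers above (the `/tmp` path and toy model are illustrative only; `OPTIMIZER_STATE` and `OPTIMIZER_PARAM_GROUPS` resolve to filenames defined in `opentau.constants`):

from pathlib import Path

import torch
import torch.nn as nn

from opentau.optim.optimizers import AdamWConfig, load_optimizer_state, save_optimizer_state

model = nn.Linear(4, 2)
optimizer = AdamWConfig(lr=1e-4).build(model.parameters())

# Run one step so the optimizer has per-parameter state tensors to serialize.
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

ckpt_dir = Path("/tmp/opentau_optim_ckpt")  # hypothetical checkpoint directory
ckpt_dir.mkdir(parents=True, exist_ok=True)
save_optimizer_state(optimizer, ckpt_dir)

# Restore into a freshly built optimizer of the same shape.
fresh = AdamWConfig(lr=1e-4).build(model.parameters())
fresh = load_optimizer_state(fresh, ckpt_dir)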
opentau/optim/schedulers.py
ADDED
@@ -0,0 +1,140 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
from dataclasses import asdict, dataclass
from pathlib import Path

import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from opentau.constants import SCHEDULER_STATE
from opentau.datasets.utils import write_json
from opentau.utils.io_utils import deserialize_json_into_object


@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
    num_warmup_steps: int

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @abc.abstractmethod
    def build(self, optimizer: Optimizer, num_training_steps: int) -> LRScheduler | None:
        raise NotImplementedError


@LRSchedulerConfig.register_subclass("diffuser")
@dataclass
class DiffuserSchedulerConfig(LRSchedulerConfig):
    name: str = "cosine"
    num_warmup_steps: int | None = None

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        from diffusers.optimization import get_scheduler

        kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
        return get_scheduler(**kwargs)


@LRSchedulerConfig.register_subclass("vqbet")
@dataclass
class VQBeTSchedulerConfig(LRSchedulerConfig):
    num_warmup_steps: int
    num_vqvae_training_steps: int
    num_cycles: float = 0.5

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        def lr_lambda(current_step):
            if current_step < self.num_vqvae_training_steps:
                return float(1)
            else:
                adjusted_step = current_step - self.num_vqvae_training_steps
                if adjusted_step < self.num_warmup_steps:
                    return float(adjusted_step) / float(max(1, self.num_warmup_steps))
                progress = float(adjusted_step - self.num_warmup_steps) / float(
                    max(1, num_training_steps - self.num_warmup_steps)
                )
                return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

        return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
    """Used by Physical Intelligence to train Pi0"""

    num_warmup_steps: int
    num_decay_steps: int
    peak_lr: float
    decay_lr: float

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        del num_training_steps

        def lr_lambda(current_step):
            def linear_warmup_schedule(current_step):
                if current_step <= 0:
                    return 1 / (self.num_warmup_steps + 1)
                frac = 1 - current_step / self.num_warmup_steps
                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1

            def cosine_decay_schedule(current_step):
                step = min(current_step, self.num_decay_steps)
                cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
                alpha = self.decay_lr / self.peak_lr
                decayed = (1 - alpha) * cosine_decay + alpha
                return decayed

            if current_step < self.num_warmup_steps:
                return linear_warmup_schedule(current_step)

            return cosine_decay_schedule(current_step)

        return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("constant")
@dataclass
class ConstantSchedulerConfig(LRSchedulerConfig):
    """Constant learning rate scheduler that doesn't change the learning rate over time"""

    num_warmup_steps: int = 0

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        del num_training_steps

        def lr_lambda(current_step):
            # Always return 1.0 to keep the learning rate constant
            return 1.0

        return LambdaLR(optimizer, lr_lambda, -1)


def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    state_dict = scheduler.state_dict()
    write_json(state_dict, save_dir / SCHEDULER_STATE)


def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
    state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
    scheduler.load_state_dict(state_dict)
    return scheduler
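A sketch of how the Pi0-style schedule above is consumed (toy optimizer; since `LambdaLR` applies a multiplicative factor to the base learning rate, the optimizer's base lr is presumably expected to equal `peak_lr`):

import torch
import torch.nn as nn

from opentau.optim.schedulers import CosineDecayWithWarmupSchedulerConfig

params = nn.Linear(4, 2).parameters()
optimizer = torch.optim.AdamW(params, lr=2.5e-5)  # base lr == peak_lr

sched_cfg = CosineDecayWithWarmupSchedulerConfig(
    num_warmup_steps=1_000,
    num_decay_steps=30_000,
    peak_lr=2.5e-5,
    decay_lr=2.5e-6,
)
scheduler = sched_cfg.build(optimizer, num_training_steps=30_000)

for _ in range(3):
    optimizer.step()  # training step (gradient computation omitted in this sketch)
    scheduler.step()  # linear warmup, then cosine decay toward decay_lr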
opentau/planner/__init__.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""High-level planning for robots using vision-language models.

This module provides high-level planning capabilities that convert natural
language task descriptions into low-level action plans using vision-language
models (VLMs). It supports both manipulation and navigation tasks, with
integration for both open-source and closed-source models.

The planner acts as a bridge between high-level language commands (e.g., "Pick
up the red block and place it on the table") and low-level action sequences
that can be executed by robot policies. It processes visual observations
(camera images) along with task descriptions to generate structured plans.

Key Features:

- **Multi-model Support**: Works with both open-source models (CogVLM,
  SmolVLM variants) and closed-source models (GPT-4o via OpenAI API).
- **Task-specific Planners**: Specialized planners for manipulation and
  navigation tasks with task-appropriate prompts and image processing.
- **Conversation Memory**: Maintains conversation history for multi-turn
  planning and context-aware plan generation.
- **Cost Tracking**: Automatic cost calculation for GPT-4o API usage.
- **Prompt Library**: YAML-based prompt templates for different task types
  and scenarios.
- **Image Processing**: Automatic conversion of camera tensors to base64
  format for API-based models.

Main Classes:

- **BaseHighLevelPlanner**: Abstract base class defining the planner
  interface with inference and cost calculation methods.
- **HighLevelPlanner**: Planner for manipulation tasks, supporting both
  GPT-4o and open-source vision-language models (CogVLM, SmolVLM variants).
- **NavHighLevelPlanner**: Specialized planner for navigation tasks with
  support for processing multiple camera views.
- **Memory**: Conversation history manager that stores and retrieves
  multi-turn conversations between user and LLM assistant.

Supported Models:

- **Open-source**: CogVLM-Chat-HF, SmolVLM-256M-Instruct,
  SmolVLM-500M-Instruct, SmolVLM2-2.2B-Instruct
- **Closed-source**: GPT-4o (via OpenAI API)

Modules:

- **high_level_planner**: Core planner implementations for manipulation
  and navigation tasks.
- **utils.memory**: Conversation memory management for maintaining context.
- **utils.utils**: Utility functions for image encoding and prompt loading.

Example:
    Create a planner and generate a plan:

    >>> from opentau.planner import HighLevelPlanner, Memory
    >>> planner = HighLevelPlanner()
    >>> memory = Memory()
    >>> image_dict = {"camera0": camera_tensor}
    >>> task = "Pick up the red block and place it on the table"
    >>> plan = planner.inference(
    ...     image_dict=image_dict,
    ...     model_name="gpt4o",
    ...     task=task,
    ...     mem=memory
    ... )
"""

from .high_level_planner import HighLevelPlanner as HighLevelPlanner
from .high_level_planner import NavHighLevelPlanner as NavHighLevelPlanner
from .utils.memory import Memory as Memory