opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/utils/train_utils.py
ADDED
@@ -0,0 +1,198 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for training checkpoint management and state persistence.

This module provides functions for saving and loading training checkpoints,
managing checkpoint directories, and pruning old checkpoints.
"""

import logging
import shutil
from pathlib import Path

from termcolor import colored

from opentau.configs.train import TrainPipelineConfig
from opentau.constants import (
    CHECKPOINTS_DIR,
    LAST_CHECKPOINT_LINK,
    TRAINING_STEP,
)
from opentau.datasets.utils import load_json, write_json
from opentau.utils.random_utils import load_rng_state, save_rng_state


def log_output_dir(out_dir):
    """Log the output directory path with colored formatting.

    Args:
        out_dir: Output directory path to log.
    """
    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")


def get_step_identifier(step: int, total_steps: int) -> str:
    """Generate a zero-padded step identifier string.

    Args:
        step: Current step number.
        total_steps: Total number of steps (used to determine padding width).

    Returns:
        Zero-padded string representation of the step number.
    """
    num_digits = max(6, len(str(total_steps)))
    return f"{step:0{num_digits}d}"


def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Return the checkpoint sub-directory corresponding to the step number."""
    step_identifier = get_step_identifier(step, total_steps)
    return output_dir / CHECKPOINTS_DIR / step_identifier


def save_training_step(step: int, save_dir: Path) -> None:
    """Save the current training step number to a file.

    Args:
        step: Current training step number.
        save_dir: Directory where the step file will be saved.
    """
    write_json({"step": step}, save_dir / TRAINING_STEP)


def load_training_step(save_dir: Path) -> int:
    """Load the training step number from a file.

    Args:
        save_dir: Directory containing the step file.

    Returns:
        Training step number.
    """
    training_step = load_json(save_dir / TRAINING_STEP)
    return training_step["step"]


def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Update the symlink pointing to the last checkpoint.

    Args:
        checkpoint_dir: Path to the checkpoint directory to link.

    Returns:
        Path to the symlink that was created or updated.
    """
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)
    return last_checkpoint_dir


def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
) -> None:
    """Save training checkpoint including config and RNG state.

    Note: accelerate saves the model and training run. This method saves all
    other auxiliary objects such as configs and RNG state.

    Args:
        checkpoint_dir: Directory where the checkpoint will be saved.
        step: Current training step number.
        cfg: The training config used for this run.
    """
    cfg.save_pretrained(checkpoint_dir)
    save_training_step(step, checkpoint_dir)
    save_rng_state(checkpoint_dir)


def load_training_state(checkpoint_dir: Path) -> int:
    """Load training state, i.e. the step number and RNG state.

    This is used to resume a training run. Note: optimizer and scheduler states
    are loaded by accelerate, not by this function.

    Args:
        checkpoint_dir: The checkpoint directory. Should contain a 'training_state' dir.

    Returns:
        The training step number. Note: optimizer and scheduler
        are loaded separately by accelerate.

    Raises:
        NotADirectoryError: If checkpoint_dir doesn't exist or is not a directory.
    """
    if not checkpoint_dir.is_dir():
        raise NotADirectoryError(checkpoint_dir)

    load_rng_state(checkpoint_dir)
    step = load_training_step(checkpoint_dir)

    return step


def prune_old_checkpoints(latest_checkpoint_path: str) -> None:
    """Delete all checkpoint directories except the specified one.

    Recursively deletes all checkpoint directories in a parent folder except
    for the specified one. This function is designed to clean up old model
    checkpoints, preserving only the most recent one. It includes safety checks
    to ensure it only deletes directories and handles potential filesystem errors.

    Args:
        latest_checkpoint_path: The full path to the checkpoint directory
            that should be kept.
    """
    try:
        latest_checkpoint = Path(latest_checkpoint_path).resolve()
        parent_dir = latest_checkpoint.parent

        if not parent_dir.is_dir():
            logging.error(f"Parent directory '{parent_dir.resolve()}' does not exist. Aborting cleanup.")
            return

        if not latest_checkpoint.is_dir():
            logging.warning(
                f"Checkpoint '{latest_checkpoint.resolve()}' is not a valid directory. Aborting cleanup."
            )
            return

        logging.info(
            f"Starting cleanup in '{parent_dir.resolve()}'. Keeping checkpoint: '{latest_checkpoint.name}'"
        )

        # Iterate and delete other directories
        for item in parent_dir.iterdir():
            # Skip the checkpoint we want to keep and any files
            if item.resolve() == latest_checkpoint.resolve() or not item.is_dir():
                continue

            try:
                logging.info(f"Deleting old checkpoint directory: {item.name}")
                shutil.rmtree(item)
                logging.info(f"Successfully deleted {item.name}")
            except OSError as e:
                logging.error(f"Failed to delete '{item.name}'. Error: {e}")

    except Exception as e:
        logging.critical(f"An unexpected error occurred during checkpoint pruning setup: {e}")
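For context, a minimal sketch of how these helpers could compose into a save/resume cycle. This is not code from the package: `output_dir`, `total_steps`, `on_save`, and `on_resume` are hypothetical names, `cfg` is assumed to be a TrainPipelineConfig providing `save_pretrained` as the docstrings above describe, and the checkpoint directory is assumed to need creating before writing into it.

# Hypothetical save/resume sketch using the helpers above (not from the package).
from pathlib import Path

from opentau.utils.train_utils import (
    get_step_checkpoint_dir,
    load_training_state,
    prune_old_checkpoints,
    save_checkpoint,
    update_last_checkpoint,
)

output_dir = Path("outputs/run1")  # hypothetical run directory
total_steps = 100_000

def on_save(step, cfg):
    # Resolves e.g. outputs/run1/checkpoints/000500 for step 500.
    checkpoint_dir = get_step_checkpoint_dir(output_dir, total_steps, step)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    save_checkpoint(checkpoint_dir, step, cfg)      # config + step + RNG state
    last = update_last_checkpoint(checkpoint_dir)   # refresh the 'last' symlink
    prune_old_checkpoints(str(checkpoint_dir))      # keep only the newest checkpoint
    return last

def on_resume(checkpoint_dir: Path) -> int:
    # Restores RNG state and returns the step to resume from; accelerate is
    # expected to restore model/optimizer/scheduler state separately.
    return load_training_state(checkpoint_dir)
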
opentau/utils/utils.py
ADDED
@@ -0,0 +1,471 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utility functions for device management, logging, and common operations.

This module provides utilities for device selection, logging initialization,
number formatting, platform-specific operations, and various helper functions
used throughout the OpenTau codebase.
"""

import enum
import inspect
import logging
import os
import platform
import warnings
from copy import copy
from dataclasses import fields, is_dataclass
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable

import accelerate
import numpy as np
import torch


def inside_slurm() -> bool:
    """Check whether the Python process was launched through SLURM.

    Returns:
        True if the process is running in a SLURM environment, False otherwise.
    """
    # TODO(rcadene): return False for interactive mode `--pty bash`
    return "SLURM_JOB_ID" in os.environ


def auto_torch_device() -> torch.device:
    """Automatically select the best available torch device.

    Returns:
        torch.device instance. Priority: CUDA > MPS > CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def get_safe_torch_device(
    try_device: str, log: bool = False, accelerator: accelerate.Accelerator | None = None
) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available."""
    match try_device:
        case "cuda":
            assert torch.cuda.is_available()
            device = accelerator.device if accelerator else torch.device("cuda")
        case "mps":
            assert torch.backends.mps.is_available()
            device = torch.device("mps")
        case "cpu":
            device = torch.device("cpu")
            if log:
                logging.warning("Using CPU, this will be slow.")
        case _:
            device = torch.device(try_device)
            if log:
                logging.warning(f"Using custom {try_device} device.")

    return device


def get_safe_dtype(dtype: torch.dtype, device: str | torch.device) -> torch.dtype:
    """Get a dtype that is compatible with the given device.

    MPS is currently not compatible with float64, so this function converts
    float64 to float32 for MPS devices.

    Args:
        dtype: Desired dtype.
        device: Target device (string or torch.device).

    Returns:
        Compatible dtype for the device.
    """
    if isinstance(device, torch.device):
        device = device.type
    if device == "mps" and dtype == torch.float64:
        return torch.float32
    else:
        return dtype


def is_torch_device_available(try_device: str) -> bool:
    """Check if a torch device is available.

    Args:
        try_device: Device name to check ('cuda', 'mps', or 'cpu').

    Returns:
        True if the device is available, False otherwise.

    Raises:
        ValueError: If try_device is not one of the recognized device types.
    """
    if try_device == "cuda":
        return torch.cuda.is_available()
    elif try_device == "mps":
        return torch.backends.mps.is_available()
    elif try_device == "cpu":
        return True
    else:
        raise ValueError(f"Unknown device '{try_device}'.")


def is_amp_available(device: str) -> bool:
    """Check if automatic mixed precision (AMP) is available for a device.

    Args:
        device: Device name to check ('cuda', 'mps', or 'cpu').

    Returns:
        True if AMP is available for the device, False otherwise.

    Raises:
        ValueError: If device is not one of the recognized device types.
    """
    if device in ["cuda", "cpu"]:
        return True
    elif device == "mps":
        return False
    else:
        raise ValueError(f"Unknown device '{device}'.")


# Global variable to ensure logging is initialized only once
_logging_init_stack = ""


def _format_stack(stack: list[inspect.FrameInfo]) -> str:
    return "\n".join(
        f"  File '{frame.filename}', line {frame.lineno}, in {frame.function}"
        for frame in stack[1:]  # skip the current frame
    )


def init_logging(accelerator: accelerate.Accelerator | None = None, level=logging.INFO) -> None:
    """Initialize logging configuration with custom formatter.

    This function sets up logging with a custom formatter that includes
    timestamp, filename, and line number. It can only be initialized once
    per process and will warn if called multiple times.

    Args:
        accelerator: Optional Accelerator instance. If provided, logging level
            is set to WARNING on non-main processes to avoid duplicate logs.
        level: Logging level to use. Defaults to logging.INFO.
    """
    global _logging_init_stack
    stack = inspect.stack()

    if _logging_init_stack:
        warnings.warn(
            f"""Logging was already initialized through the following stack:
{_logging_init_stack}
Not initializing again through the following stack:
{_format_stack(stack)}""",
            stacklevel=2,
        )
    else:
        _logging_init_stack = _format_stack(stack)

        class CustomFormatter(logging.Formatter):
            def format(self, record):
                dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                fnameline = f"{record.pathname}:{record.lineno}"
                return f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(CustomFormatter())

        logging.basicConfig(level=level, force=True, handlers=[console_handler])

        if accelerator and not accelerator.is_main_process:
            # Disable duplicate logging on non-main processes
            logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
            logging.getLogger().setLevel(logging.WARNING)


def format_big_number(num: float | int, precision: int = 0) -> str:
    """Format a large number with appropriate suffix (K, M, B, T, Q).

    Args:
        num: Number to format.
        precision: Number of decimal places. Defaults to 0.

    Returns:
        Formatted string with suffix (e.g., "1.5K", "2.3M").
    """
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    for suffix in suffixes:
        if abs(num) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor

    # Past the largest suffix: undo the final division and clamp to "Q".
    return f"{num * divisor:.{precision}f}{suffixes[-1]}"


def capture_timestamp_utc() -> datetime:
    """Capture the current UTC timestamp.

    Returns:
        datetime object representing the current UTC time.
    """
    return datetime.now(timezone.utc)


def say(text: str, blocking: bool = False) -> None:
    """Use text-to-speech to speak text (platform-specific).

    Args:
        text: Text to speak.
        blocking: If True, wait for speech to complete before returning.
            Defaults to False.
    """
    # Check if mac, linux, or windows.
    if platform.system() == "Darwin":
        cmd = f'say "{text}"'
        if not blocking:
            cmd += " &"
    elif platform.system() == "Linux":
        cmd = f'spd-say "{text}"'
        if blocking:
            cmd += " --wait"
    elif platform.system() == "Windows":
        # TODO(rcadene): Make blocking option work for Windows
        cmd = (
            'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
        )
    else:
        # Unsupported platform: nothing to run, avoid NameError on `cmd` below.
        logging.warning(f"Text-to-speech is not supported on platform '{platform.system()}'.")
        return

    os.system(cmd)  # nosec: B605


def log_say(text: str, play_sounds: bool, blocking: bool = False) -> None:
    """Log text and optionally speak it using text-to-speech.

    Args:
        text: Text to log and optionally speak.
        play_sounds: If True, also speak the text using text-to-speech.
        blocking: If True, wait for speech to complete. Defaults to False.
    """
    logging.info(text)

    if play_sounds:
        say(text, blocking)


def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    """Convert image shape from HWC to CHW format if needed.

    Args:
        image_shape: Image shape tuple, either (H, W, C) or (C, H, W).

    Returns:
        Image shape in CHW format (C, H, W).

    Raises:
        ValueError: If the input shape is not in a recognized format.
    """
    shape = copy(image_shape)
    if shape[2] < shape[0] and shape[2] < shape[1]:  # (h, w, c) -> (c, h, w)
        shape = (shape[2], shape[0], shape[1])
    elif not (shape[0] < shape[1] and shape[0] < shape[2]):
        raise ValueError(image_shape)

    return shape


def has_method(cls: object, method_name: str) -> bool:
    """Check if a class or object has a specific method.

    Args:
        cls: Class or object to check.
        method_name: Name of the method to check for.

    Returns:
        True if the method exists and is callable, False otherwise.
    """
    return hasattr(cls, method_name) and callable(getattr(cls, method_name))


def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """Check if a string can be converted to a numpy dtype.

    Args:
        dtype_str: String to check.

    Returns:
        True if the string is a valid numpy dtype, False otherwise.
    """
    try:
        # Attempt to convert the string to a numpy dtype
        np.dtype(dtype_str)
        return True
    except TypeError:
        # If a TypeError is raised, the string is not a valid dtype
        return False


def is_launched_with_accelerate() -> bool:
    """Check if the process was launched with accelerate.

    Returns:
        True if ACCELERATE_MIXED_PRECISION is in the environment, False otherwise.
    """
    return "ACCELERATE_MIXED_PRECISION" in os.environ


def attempt_torch_compile(fn: Callable, device_hint=None) -> Callable:
    """Attempt to compile a PyTorch function using torch.compile.

    The argument device_hint is used to check if torch.compile works reliably
    on the device. Compilation is skipped if the device is MPS (Metal Performance
    Shaders) as it is experimental.

    Args:
        fn: Function to compile.
        device_hint: Optional device hint to check compatibility.

    Returns:
        Compiled function if compilation succeeds, otherwise the original function.
    """
    if device_hint and "mps" in str(device_hint):
        logging.warning("torch.compile is experimental on MPS devices. Compilation skipped.")
        return fn

    if hasattr(torch, "compile"):
        logging.info("Attempting to compile the policy with torch.compile()...")
        try:
            # Other options: "default", "max-autotune" (longer compile time)
            fn = torch.compile(fn)
            logging.info("Policy compiled successfully.")
        except Exception as e:
            logging.warning(f"torch.compile failed with error: {e}. Proceeding without compilation.")
    else:
        logging.warning(
            "torch.compile is not available. Requires PyTorch 2.0+. Proceeding without compilation."
        )

    return fn


def create_dummy_observation(cfg, device, dtype=torch.bfloat16) -> dict:
    """Create a dummy observation dictionary for testing or initialization.

    Args:
        cfg: Configuration object with num_cams, resolution, max_state_dim,
            and action_chunk attributes.
        device: Device to create tensors on.
        dtype: Data type for tensors. Defaults to torch.bfloat16.

    Returns:
        Dictionary containing dummy camera observations, state, prompt, and
        padding flags.
    """
    camera_observations = {
        f"camera{i}": torch.zeros((1, 3, *cfg.resolution), dtype=dtype, device=device)
        for i in range(cfg.num_cams)
    }
    return {
        **camera_observations,
        "state": torch.zeros((1, cfg.max_state_dim), dtype=dtype, device=device),
        "prompt": ["Pick up yellow lego block and put it in the bin"],
        "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
        "action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device),
    }


def encode_accelerator_state_dict(obj) -> Any:
    """Encode an object into a JSON/YAML-compatible primitive type.

    Args:
        obj: Object to encode (can be Enum, dict, list, tuple, dataclass, etc.).

    Returns:
        Encoded object with all nested structures converted to primitives.
    """
    if isinstance(obj, enum.Enum):
        return encode_accelerator_state_dict(obj.value)
    elif isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    elif isinstance(obj, (list, tuple)):
        return [encode_accelerator_state_dict(item) for item in obj]
    elif isinstance(obj, dict):
        return {key.replace(".", "_"): encode_accelerator_state_dict(value) for key, value in obj.items()}
    elif is_dataclass(obj):
        return {f.name: encode_accelerator_state_dict(getattr(obj, f.name)) for f in fields(obj)}
    else:
        return str(obj)  # Fallback to string representation for unsupported types


def on_accelerate_main_proc(*, local=False, _sync=False):
    r"""Returns a decorator to run a function only on the main process when using `accelerate`.

    If `local` is True (defaults to False), the function will run on the main process of each node
    (useful for multi-node setups).
    If `_sync` is True (defaults to False), the output of the function will be broadcasted to all processes.
    If `_sync` is True, you must ensure that all processes call the decorated function, otherwise it will deadlock.

    YOU SHOULD BE EXTREMELY CAREFUL WHEN USING THIS DECORATOR with _sync=True. Consider the following example::

        @on_accelerate_main_proc()
        def f():
            return g()

        @on_accelerate_main_proc(_sync=True)
        def g():
            return 42

    In this case, if f() is called on all processes, they will deadlock at g() because child processes don't even
    enter f(), hence never call g(), and thus won't reach the broadcast.

    Another example::

        @on_accelerate_main_proc(_sync=cond())
        def f():
            print("hi")

    Here `_sync=cond()` is evaluated once, at decoration time; if `cond()` returns different values on
    different processes, the processes that got `_sync=True` will wait at the broadcast while the others
    never reach it, and the program deadlocks.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            state = accelerate.state.PartialState()
            if not is_launched_with_accelerate() or not state.use_distributed:
                return func(*args, **kwargs)

            output, exception = None, None
            flag = state.is_local_main_process if local else state.is_main_process
            if flag:
                try:
                    output = func(*args, **kwargs)
                except Exception as e:
                    exception = e

            if _sync:
                payload = [output, exception]
                accelerate.utils.broadcast_object_list(payload, from_process=0)
                output, exception = payload

            if exception is not None:
                raise RuntimeError("An exception occurred in the main process.") from exception
            return output

        return wrapper

    return decorator