opentau-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/utils/train_utils.py ADDED
@@ -0,0 +1,198 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Utilities for training checkpoint management and state persistence.
+
+ This module provides functions for saving and loading training checkpoints,
+ managing checkpoint directories, and pruning old checkpoints.
+ """
+
+ import logging
+ import shutil
+ from pathlib import Path
+
+ from termcolor import colored
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.constants import (
+     CHECKPOINTS_DIR,
+     LAST_CHECKPOINT_LINK,
+     TRAINING_STEP,
+ )
+ from opentau.datasets.utils import load_json, write_json
+ from opentau.utils.random_utils import load_rng_state, save_rng_state
+
+
+ def log_output_dir(out_dir):
+     """Log the output directory path with colored formatting.
+
+     Args:
+         out_dir: Output directory path to log.
+     """
+     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+
+
+ def get_step_identifier(step: int, total_steps: int) -> str:
+     """Generate a zero-padded step identifier string.
+
+     Args:
+         step: Current step number.
+         total_steps: Total number of steps (used to determine padding width).
+
+     Returns:
+         Zero-padded string representation of the step number.
+     """
+     num_digits = max(6, len(str(total_steps)))
+     return f"{step:0{num_digits}d}"
+
+
+ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
+     """Returns the checkpoint sub-directory corresponding to the step number."""
+     step_identifier = get_step_identifier(step, total_steps)
+     return output_dir / CHECKPOINTS_DIR / step_identifier
+
+
+ def save_training_step(step: int, save_dir: Path) -> None:
+     """Save the current training step number to a file.
+
+     Args:
+         step: Current training step number.
+         save_dir: Directory where the step file will be saved.
+     """
+     write_json({"step": step}, save_dir / TRAINING_STEP)
+
+
+ def load_training_step(save_dir: Path) -> int:
+     """Load the training step number from a file.
+
+     Args:
+         save_dir: Directory containing the step file.
+
+     Returns:
+         Training step number.
+     """
+     training_step = load_json(save_dir / TRAINING_STEP)
+     return training_step["step"]
+
+
+ def update_last_checkpoint(checkpoint_dir: Path) -> Path:
+     """Update the symlink pointing to the last checkpoint.
+
+     Args:
+         checkpoint_dir: Path to the checkpoint directory to link.
+
+     Returns:
+         Path to the symlink that was created or updated.
+     """
+     last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
+     if last_checkpoint_dir.is_symlink():
+         last_checkpoint_dir.unlink()
+     relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
+     last_checkpoint_dir.symlink_to(relative_target)
+     return last_checkpoint_dir
+
+
+ def save_checkpoint(
+     checkpoint_dir: Path,
+     step: int,
+     cfg: TrainPipelineConfig,
+ ) -> None:
+     """Save training checkpoint including config and RNG state.
+
+     Note: accelerate saves the model and training run. This function saves all
+     other auxiliary objects such as configs and RNG state.
+
+     Args:
+         checkpoint_dir: Directory where the checkpoint will be saved.
+         step: Current training step number.
+         cfg: The training config used for this run.
+     """
+     cfg.save_pretrained(checkpoint_dir)
+     save_training_step(step, checkpoint_dir)
+     save_rng_state(checkpoint_dir)
+
+
+ def load_training_state(checkpoint_dir: Path) -> int:
+     """Load training state (step number and RNG state) to resume a training run.
+
+     Note: optimizer and scheduler states are loaded by accelerate, not by this
+     function.
+
+     Args:
+         checkpoint_dir: The checkpoint directory. Should contain a 'training_state' dir.
+
+     Returns:
+         The training step number. Optimizer and scheduler states are loaded
+         separately by accelerate.
+
+     Raises:
+         NotADirectoryError: If checkpoint_dir doesn't exist or is not a directory.
+     """
+     if not checkpoint_dir.is_dir():
+         raise NotADirectoryError(checkpoint_dir)
+
+     load_rng_state(checkpoint_dir)
+     step = load_training_step(checkpoint_dir)
+
+     return step
+
+
+ def prune_old_checkpoints(latest_checkpoint_path: str) -> None:
+     """Delete all checkpoint directories except the specified one.
+
+     Recursively deletes all checkpoint directories in a parent folder except
+     for the specified one. This function is designed to clean up old model
+     checkpoints, preserving only the most recent one. It includes safety checks
+     to ensure it only deletes directories and handles potential filesystem errors.
+
+     Args:
+         latest_checkpoint_path: The full path to the checkpoint directory
+             that should be kept.
+     """
+     try:
+         latest_checkpoint = Path(latest_checkpoint_path).resolve()
+         parent_dir = latest_checkpoint.parent
+
+         if not parent_dir.is_dir():
+             logging.error(f"Parent directory '{parent_dir.resolve()}' does not exist. Aborting cleanup.")
+             return
+
+         if not latest_checkpoint.is_dir():
+             logging.warning(
+                 f"Checkpoint '{latest_checkpoint.resolve()}' is not a valid directory. Aborting cleanup."
+             )
+             return
+
+         logging.info(
+             f"Starting cleanup in '{parent_dir.resolve()}'. Keeping checkpoint: '{latest_checkpoint.name}'"
+         )
+
+         # Iterate and delete other directories
+         for item in parent_dir.iterdir():
+             # Skip the checkpoint we want to keep and any files
+             if item.resolve() == latest_checkpoint.resolve() or not item.is_dir():
+                 continue
+
+             try:
+                 logging.info(f"Deleting old checkpoint directory: {item.name}")
+                 shutil.rmtree(item)
+                 logging.info(f"Successfully deleted {item.name}")
+             except OSError as e:
+                 logging.error(f"Failed to delete '{item.name}'. Error: {e}")
+
+     except Exception as e:
+         logging.critical(f"An unexpected error occurred during checkpoint pruning setup: {e}")
opentau/utils/utils.py ADDED
@@ -0,0 +1,471 @@
+ #!/usr/bin/env python
+
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """General utility functions for device management, logging, and common operations.
+
+ This module provides utilities for device selection, logging initialization,
+ number formatting, platform-specific operations, and various helper functions
+ used throughout the OpenTau codebase.
+ """
+
+ import enum
+ import inspect
+ import logging
+ import os
+ import platform
+ import warnings
+ from copy import copy
+ from dataclasses import fields, is_dataclass
+ from datetime import datetime, timezone
+ from functools import wraps
+ from typing import Any, Callable
+
+ import accelerate
+ import numpy as np
+ import torch
+
+
+ def inside_slurm() -> bool:
+     """Check whether the Python process was launched through SLURM.
+
+     Returns:
+         True if the process is running in a SLURM environment, False otherwise.
+     """
+     # TODO(rcadene): return False for interactive mode `--pty bash`
+     return "SLURM_JOB_ID" in os.environ
+
+
+ def auto_torch_device() -> torch.device:
+     """Automatically select the best available torch device.
+
+     Returns:
+         torch.device instance. Priority: CUDA > MPS > CPU.
+     """
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     if torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ def get_safe_torch_device(try_device: str, log: bool = False, accelerator: Callable = None) -> torch.device:
+     """Given a string, return a torch.device with checks on whether the device is available."""
+     match try_device:
+         case "cuda":
+             assert torch.cuda.is_available()
+             device = accelerator.device if accelerator else torch.device("cuda")
+         case "mps":
+             assert torch.backends.mps.is_available()
+             device = torch.device("mps")
+         case "cpu":
+             device = torch.device("cpu")
+             if log:
+                 logging.warning("Using CPU, this will be slow.")
+         case _:
+             device = torch.device(try_device)
+             if log:
+                 logging.warning(f"Using custom {try_device} device.")
+
+     return device
+
+
+ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device) -> torch.dtype:
+     """Get a dtype that is compatible with the given device.
+
+     MPS is currently not compatible with float64, so this function converts
+     float64 to float32 for MPS devices.
+
+     Args:
+         dtype: Desired dtype.
+         device: Target device (string or torch.device).
+
+     Returns:
+         Compatible dtype for the device.
+     """
+     if isinstance(device, torch.device):
+         device = device.type
+     if device == "mps" and dtype == torch.float64:
+         return torch.float32
+     else:
+         return dtype
+
+
+ def is_torch_device_available(try_device: str) -> bool:
+     """Check if a torch device is available.
+
+     Args:
+         try_device: Device name to check ('cuda', 'mps', or 'cpu').
+
+     Returns:
+         True if the device is available, False otherwise.
+
+     Raises:
+         ValueError: If try_device is not one of the recognized device types.
+     """
+     if try_device == "cuda":
+         return torch.cuda.is_available()
+     elif try_device == "mps":
+         return torch.backends.mps.is_available()
+     elif try_device == "cpu":
+         return True
+     else:
+         raise ValueError(f"Unknown device '{try_device}'.")
+
+
+ def is_amp_available(device: str) -> bool:
+     """Check if automatic mixed precision (AMP) is available for a device.
+
+     Args:
+         device: Device name to check ('cuda', 'mps', or 'cpu').
+
+     Returns:
+         True if AMP is available for the device, False otherwise.
+
+     Raises:
+         ValueError: If device is not one of the recognized device types.
+     """
+     if device in ["cuda", "cpu"]:
+         return True
+     elif device == "mps":
+         return False
+     else:
+         raise ValueError(f"Unknown device '{device}'.")
+
+
+ # Global variable to ensure logging is initialized only once
+ _logging_init_stack = ""
+
+
+ def _format_stack(stack: list[inspect.FrameInfo]) -> str:
+     return "\n".join(
+         f" File '{frame.filename}', line {frame.lineno}, in {frame.function}"
+         for frame in stack[1:]  # skip the current frame
+     )
+
+
+ def init_logging(accelerator: accelerate.Accelerator | None = None, level=logging.INFO) -> None:
+     """Initialize logging configuration with a custom formatter.
+
+     This function sets up logging with a custom formatter that includes
+     timestamp, filename, and line number. It can only be initialized once
+     per process and will warn if called multiple times.
+
+     Args:
+         accelerator: Optional Accelerator instance. If provided, the logging level
+             is set to WARNING on non-main processes to avoid duplicate logs.
+         level: Logging level to use. Defaults to logging.INFO.
+     """
+     global _logging_init_stack
+     stack = inspect.stack()
+
+     if _logging_init_stack:
+         warnings.warn(
+             f"""Logging was already initialized through the following stack:
+ {_logging_init_stack}
+ Not initializing again through the following stack:
+ {_format_stack(stack)}""",
+             stacklevel=2,
+         )
+     else:
+         _logging_init_stack = _format_stack(stack)
+
+     class CustomFormatter(logging.Formatter):
+         def format(self, record):
+             dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             fnameline = f"{record.pathname}:{record.lineno}"
+             return f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
+
+     console_handler = logging.StreamHandler()
+     console_handler.setFormatter(CustomFormatter())
+
+     logging.basicConfig(level=level, force=True, handlers=[console_handler])
+
+     if accelerator and not accelerator.is_main_process:
+         # Disable duplicate logging on non-main processes
+         logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
+         logging.getLogger().setLevel(logging.WARNING)
+
+
+ def format_big_number(num: float | int, precision: int = 0) -> str:
+     """Format a large number with an appropriate suffix (K, M, B, T, Q).
+
+     Args:
+         num: Number to format.
+         precision: Number of decimal places. Defaults to 0.
+
+     Returns:
+         Formatted string with suffix (e.g., "1.5K", "2.3M").
+     """
+     suffixes = ["", "K", "M", "B", "T", "Q"]
+     divisor = 1000.0
+
+     for suffix in suffixes:
+         if abs(num) < divisor:
+             return f"{num:.{precision}f}{suffix}"
+         num /= divisor
+
+     return num
+
+
+ def capture_timestamp_utc() -> datetime:
+     """Capture the current UTC timestamp.
+
+     Returns:
+         datetime object representing the current UTC time.
+     """
+     return datetime.now(timezone.utc)
+
+
+ def say(text: str, blocking: bool = False) -> None:
+     """Use text-to-speech to speak text (platform-specific).
+
+     Args:
+         text: Text to speak.
+         blocking: If True, wait for speech to complete before returning.
+             Defaults to False.
+     """
+     # Check if mac, linux, or windows.
+     if platform.system() == "Darwin":
+         cmd = f'say "{text}"'
+         if not blocking:
+             cmd += " &"
+     elif platform.system() == "Linux":
+         cmd = f'spd-say "{text}"'
+         if blocking:
+             cmd += " --wait"
+     elif platform.system() == "Windows":
+         # TODO(rcadene): Make blocking option work for Windows
+         cmd = (
+             'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
+             f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
+         )
+
+     os.system(cmd)  # nosec: B605
+
+
+ def log_say(text: str, play_sounds: bool, blocking: bool = False) -> None:
+     """Log text and optionally speak it using text-to-speech.
+
+     Args:
+         text: Text to log and optionally speak.
+         play_sounds: If True, also speak the text using text-to-speech.
+         blocking: If True, wait for speech to complete. Defaults to False.
+     """
+     logging.info(text)
+
+     if play_sounds:
+         say(text, blocking)
+
+
+ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
+     """Convert an image shape from HWC to CHW format if needed.
+
+     Args:
+         image_shape: Image shape tuple, either (H, W, C) or (C, H, W).
+
+     Returns:
+         Image shape in CHW format (C, H, W).
+
+     Raises:
+         ValueError: If the input shape is not in a recognized format.
+     """
+     shape = copy(image_shape)
+     if shape[2] < shape[0] and shape[2] < shape[1]:  # (h, w, c) -> (c, h, w)
+         shape = (shape[2], shape[0], shape[1])
+     elif not (shape[0] < shape[1] and shape[0] < shape[2]):
+         raise ValueError(image_shape)
+
+     return shape
+
+
+ def has_method(cls: object, method_name: str) -> bool:
+     """Check if a class or object has a specific method.
+
+     Args:
+         cls: Class or object to check.
+         method_name: Name of the method to check for.
+
+     Returns:
+         True if the method exists and is callable, False otherwise.
+     """
+     return hasattr(cls, method_name) and callable(getattr(cls, method_name))
+
+
+ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
+     """Check if a string can be converted to a numpy dtype.
+
+     Args:
+         dtype_str: String to check.
+
+     Returns:
+         True if the string is a valid numpy dtype, False otherwise.
+     """
+     try:
+         # Attempt to convert the string to a numpy dtype
+         np.dtype(dtype_str)
+         return True
+     except TypeError:
+         # If a TypeError is raised, the string is not a valid dtype
+         return False
+
+
+ def is_launched_with_accelerate() -> bool:
+     """Check if the process was launched with accelerate.
+
+     Returns:
+         True if ACCELERATE_MIXED_PRECISION is in the environment, False otherwise.
+     """
+     return "ACCELERATE_MIXED_PRECISION" in os.environ
+
+
+ def attempt_torch_compile(fn: callable, device_hint=None) -> callable:
+     """Attempt to compile a PyTorch function using torch.compile.
+
+     The device_hint argument is used to check whether torch.compile works reliably
+     on the device. Compilation is skipped if the device is MPS (Metal Performance
+     Shaders), where torch.compile support is experimental.
+
+     Args:
+         fn: Function to compile.
+         device_hint: Optional device hint to check compatibility.
+
+     Returns:
+         Compiled function if compilation succeeds, otherwise the original function.
+     """
+     if device_hint and "mps" in str(device_hint):
+         logging.warning("torch.compile is experimental on MPS devices. Compilation skipped.")
+         return fn
+
+     if hasattr(torch, "compile"):
+         logging.info("Attempting to compile the policy with torch.compile()...")
+         try:
+             # Other options: "default", "max-autotune" (longer compile time)
+             fn = torch.compile(fn)
+             logging.info("Policy compiled successfully.")
+         except Exception as e:
+             logging.warning(f"torch.compile failed with error: {e}. Proceeding without compilation.")
+     else:
+         logging.warning(
+             "torch.compile is not available. Requires PyTorch 2.0+. Proceeding without compilation."
+         )
+
+     return fn
+
+
+ def create_dummy_observation(cfg, device, dtype=torch.bfloat16) -> dict:
+     """Create a dummy observation dictionary for testing or initialization.
+
+     Args:
+         cfg: Configuration object with num_cams, resolution, max_state_dim,
+             and action_chunk attributes.
+         device: Device to create tensors on.
+         dtype: Data type for tensors. Defaults to torch.bfloat16.
+
+     Returns:
+         Dictionary containing dummy camera observations, state, prompt, and
+         padding flags.
+     """
+     camera_observations = {
+         f"camera{i}": torch.zeros((1, 3, *cfg.resolution), dtype=dtype, device=device)
+         for i in range(cfg.num_cams)
+     }
+     return {
+         **camera_observations,
+         "state": torch.zeros((1, cfg.max_state_dim), dtype=dtype, device=device),
+         "prompt": ["Pick up yellow lego block and put it in the bin"],
+         "img_is_pad": torch.zeros((1, cfg.num_cams), dtype=torch.bool, device=device),
+         "action_is_pad": torch.zeros((1, cfg.action_chunk), dtype=torch.bool, device=device),
+     }
+
+
+ def encode_accelerator_state_dict(obj) -> Any:
+     """Encode an object into a JSON/YAML-compatible primitive type.
+
+     Args:
+         obj: Object to encode (can be Enum, dict, list, tuple, dataclass, etc.).
+
+     Returns:
+         Encoded object with all nested structures converted to primitives.
+     """
+     if isinstance(obj, enum.Enum):
+         return encode_accelerator_state_dict(obj.value)
+     elif isinstance(obj, (str, int, float, bool)) or obj is None:
+         return obj
+     elif isinstance(obj, (list, tuple)):
+         return [encode_accelerator_state_dict(item) for item in obj]
+     elif isinstance(obj, dict):
+         return {key.replace(".", "_"): encode_accelerator_state_dict(value) for key, value in obj.items()}
+     elif is_dataclass(obj):
+         return {f.name: encode_accelerator_state_dict(getattr(obj, f.name)) for f in fields(obj)}
+     else:
+         return str(obj)  # Fallback to string representation for unsupported types
+
+
+ def on_accelerate_main_proc(*, local=False, _sync=False):
+     r"""Returns a decorator to run a function only on the main process when using `accelerate`.
+
+     If `local` is True (defaults to False), the function will run on the main process of each node
+     (useful for multi-node setups).
+     If `_sync` is True (defaults to False), the output of the function will be broadcast to all processes.
+     If `_sync` is True, you must ensure that all processes call the decorated function, otherwise it will deadlock.
+
+     YOU SHOULD BE EXTREMELY CAREFUL WHEN USING THIS DECORATOR with _sync=True. Consider the following example::
+
+         @on_accelerate_main_proc()
+         def f():
+             return g()
+
+         @on_accelerate_main_proc(_sync=True)
+         def g():
+             return 42
+
+     In this case, if f() is called on all processes, they will deadlock at g() because child processes don't even
+     enter f(), hence never call g(), and thus won't reach the broadcast.
+
+     Another example::
+
+         @on_accelerate_main_proc(_sync=cond())
+         def f():
+             print("hi")
+
+     Here `_sync` is fixed at decoration time by `cond()`; if `cond()` evaluates differently across processes,
+     only some of them will reach the broadcast and the run will hang.
+     """
+
+     def decorator(func):
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             state = accelerate.state.PartialState()
+             if not is_launched_with_accelerate() or not state.use_distributed:
+                 return func(*args, **kwargs)
+
+             output, exception = None, None
+             flag = state.is_local_main_process if local else state.is_main_process
+             if flag:
+                 try:
+                     output = func(*args, **kwargs)
+                 except Exception as e:
+                     exception = e
+
+             if _sync:
+                 payload = [output, exception]
+                 accelerate.utils.broadcast_object_list(payload, from_process=0)
+                 output, exception = payload
+
+             if exception is not None:
+                 raise RuntimeError("An exception occurred in the main process.") from exception
+             return output
+
+         return wrapper
+
+     return decorator
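
To make the intended use of on_accelerate_main_proc concrete, here is a small, hypothetical sketch (not part of the package). The decorated functions and their bodies are invented for illustration; only the decorator and its local/_sync arguments come from the module above.

# Illustrative sketch only; the decorated functions are hypothetical.
import random

from opentau.utils.utils import on_accelerate_main_proc


@on_accelerate_main_proc()
def log_run_started(run_name: str) -> None:
    # Runs only on the global main process; other processes skip the body
    # and receive None.
    print(f"run started: {run_name}")


@on_accelerate_main_proc(local=True)
def warm_local_cache(path: str) -> None:
    # Runs once per node (local main process), e.g. to stage assets onto
    # node-local storage.
    print(f"warming cache at {path}")


@on_accelerate_main_proc(_sync=True)
def draw_shared_seed() -> int:
    # Computed on the main process only, then broadcast so every process
    # receives the same value. Every process must call this function,
    # otherwise the broadcast deadlocks.
    return random.randint(0, 2**31 - 1)

Because _sync=True routes the return value through broadcast_object_list, draw_shared_seed() must be invoked on every process; calling it from inside another main-process-only function reproduces the deadlock described in the docstring.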