opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,177 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Datasets for Image-Text Point Set grounding tasks.
15
+
16
+ This module provides the PixMo points dataset (``allenai/pixmo-points``) implementation
17
+ for training vision-language models on part localization and object grounding tasks.
18
+ """
19
+
20
+ import json
21
+ import logging
22
+ import random
23
+ import warnings
24
+ from io import BytesIO
25
+
26
+ import numpy as np
27
+ import requests
28
+ import torch
29
+ from datasets import load_dataset
30
+ from PIL import Image
31
+ from requests.adapters import HTTPAdapter
32
+ from urllib3.util.retry import Retry
33
+
34
+ from opentau import register_grounding_dataset
35
+ from opentau.configs.train import TrainPipelineConfig
36
+ from opentau.datasets.grounding.base import GroundingDataset
37
+
38
# TODO: add a config to filter the warnings
# Image downloads produce very noisy connection-pool logs; keep only errors.
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
# Silence known-benign PIL decode warnings triggered by some downloaded images.
warnings.filterwarnings(
    "ignore",
    message=r"Palette images with Transparency expressed in bytes should be converted to RGBA images",
    category=UserWarning,
    module=r"PIL\.Image",
)
warnings.filterwarnings(
    "ignore",
    message=r"image file could not be identified because AVIF support not installed",
    category=UserWarning,
    module=r"PIL\.Image",
)

# Side length (pixels) that images are resized to before tensor conversion.
IMG_SIZE = 224
# Point coordinates are quantized onto a POINT_GRID x POINT_GRID grid.
POINT_GRID = 255
# Deliberately small retry/timeout budget: unreachable images are skipped
# (marked bad and resampled), not fought over.
MAX_RETRIES = 1
HTTP_TIMEOUT = 1
# NOTE(review): appears unused in this module — possibly intended for
# rate-limited logging of bad items; confirm before removing.
LOG_EVERY_N_BAD = 1000

# Shared HTTP session that retries transient 5xx responses on image downloads.
_session = requests.Session()
_session.mount(
    "https://",
    HTTPAdapter(
        max_retries=Retry(
            total=MAX_RETRIES,
            backoff_factor=0.5,
            status_forcelist=[500, 502, 503, 504],
        )
    ),
)
70
+
71
+
72
def _pil_from_url(url: str) -> Image.Image | None:
    """Fetch an image over HTTP and decode it to RGB. Returns None on any failure."""
    try:
        response = _session.get(url, timeout=HTTP_TIMEOUT)
        response.raise_for_status()
    except Exception:
        return None
    # TODO: Check against the hash in case the image somehow changed.
    try:
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        return None
81
+
82
+
83
def _get_post_fix(label: str, points: list, orig_w: int, orig_h: int, max_points: int = 16) -> str:
    """Map points from pixel space to grid space and return a JSON postfix string.

    Converts pixel coordinates to a POINT_GRID x POINT_GRID grid, deduplicates
    points (preserving first-seen order), and limits the result to at most
    ``max_points`` distinct points.

    Args:
        label: Label for the points (e.g., object class name).
        points: List of point dictionaries with 'x' and 'y' keys (pixel coords).
        orig_w: Original image width in pixels.
        orig_h: Original image height in pixels.
        max_points: Maximum number of points to include. Defaults to 16.

    Returns:
        JSON string: a list of rows, each with "in_frame", "point" ([gx, gy])
        and "label" keys.
    """
    # use `dict` to deduplicate as `set` is not guaranteed to preserve order
    deduplicated = {
        (int(p["x"] * POINT_GRID / orig_w), int(p["y"] * POINT_GRID / orig_h)): None for p in points
    }
    if len(deduplicated) > max_points:
        # BUGFIX: the previous `random.choices` samples WITH replacement, which
        # could reintroduce duplicate points right after deduplication (and drop
        # coverage below max_points distinct points). `random.sample` draws
        # `max_points` distinct points without replacement.
        selected = random.sample(list(deduplicated), k=max_points)
    else:
        selected = list(deduplicated)
    rows = [{"in_frame": True, "point": pair, "label": label} for pair in selected]
    return json.dumps(rows)
107
+
108
+
109
def _img_to_normalized_tensor(img: Image.Image) -> torch.Tensor:
    """Resize a PIL image to IMG_SIZE x IMG_SIZE and return it as a float tensor.

    The result is channel-first (C, H, W) with values scaled to [0, 1].

    Args:
        img: PIL Image to convert.

    Returns:
        Tensor of shape (C, IMG_SIZE, IMG_SIZE) with values in [0, 1].
    """
    resized = img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
    # PIL arrays are channel-last (H, W, C); pytorch expects channel-first (C, H, W).
    chw = torch.from_numpy(np.array(resized)).permute(2, 0, 1)
    return chw.float() / 255.0
124
+
125
+
126
@register_grounding_dataset("pixmo")
class PixmoDataset(GroundingDataset):
    r"""Dataset for the iterable PixMo dataset implementation, recommended to be used together with PrefetchWrapper"""

    def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100):
        # self.ds must exist before the parent constructor runs, since the
        # parent computes metadata from it.
        self.ds = load_dataset("allenai/pixmo-points", split="train")
        super().__init__(cfg)
        # Indices whose image URL previously failed to download; skipped on retry.
        self.bad_ids = set()
        self.consecutive_bad_tolerance = consecutive_bad_tolerance

    def __len__(self):
        return len(self.ds)

    def _get_feature_mapping_key(self) -> str:
        return "pixmo"

    def __getitem_helper__(self, item) -> dict:
        """Get a PixMo dataset item.

        Downloads the image from its URL and formats it for part localization
        tasks. Whenever the requested index is known-bad or its image cannot be
        downloaded, a fresh random index is tried instead.

        Args:
            item: Index of the item to retrieve.

        Returns:
            Dictionary with image, task, postfix, task_type, and prompt.

        Raises:
            RuntimeError: If too many consecutive items fail to load.
        """
        attempts_left = self.consecutive_bad_tolerance
        while attempts_left > 0:
            attempts_left -= 1
            if item not in self.bad_ids:
                ex = self.ds[item]
                img = _pil_from_url(ex["image_url"])
                if img is not None:
                    return {
                        "image": _img_to_normalized_tensor(img),
                        "task": ex["label"],
                        "postfix": _get_post_fix(ex["label"], ex["points"], *img.size),
                        "task_type": "part",
                        "prompt": f'{{"task": "part", "description": "Find {ex["label"]} in the image"}}',
                    }
                # Download failed: remember this index so we never retry it.
                self.bad_ids.add(item)
            # Resample a fresh index and try again.
            item = np.random.randint(0, len(self.ds))

        raise RuntimeError("Too many consecutive bad items. Please check dataset or increase the tolerance.")
@@ -0,0 +1,141 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """VSR (Visual Spatial Reasoning) dataset for true/false statement grounding.
15
+
16
+ This module provides the VSR dataset implementation for training vision-language
17
+ models on visual spatial reasoning tasks. The dataset contains images with
18
+ statements about spatial relationships, and models must determine whether each
19
+ statement is true or false based on the image content.
20
+ """
21
+
22
+ import logging
23
+ from io import BytesIO
24
+
25
+ import numpy as np
26
+ import requests
27
+ import torch
28
+ from datasets import load_dataset
29
+ from PIL import Image
30
+ from requests.adapters import HTTPAdapter
31
+ from urllib3.util.retry import Retry
32
+
33
+ from opentau import register_grounding_dataset
34
+ from opentau.configs.train import TrainPipelineConfig
35
+ from opentau.datasets.grounding.base import GroundingDataset
36
+
37
# Image downloads produce very noisy connection-pool logs; keep only errors.
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)

# Deliberately small retry/timeout budget: unreachable images are skipped
# (marked bad and resampled), not fought over.
MAX_RETRIES = 1
HTTP_TIMEOUT = 1
# NOTE(review): appears unused in this module — possibly intended for
# rate-limited logging of bad items; confirm before removing.
LOG_EVERY_N_BAD = 1000

# Shared HTTP session that retries transient 5xx responses on image downloads.
_session = requests.Session()
_session.mount(
    "https://",
    HTTPAdapter(
        max_retries=Retry(
            total=MAX_RETRIES,
            backoff_factor=0.5,
            status_forcelist=[500, 502, 503, 504],
        )
    ),
)
54
+
55
+
56
+ def _pil_from_url(url: str) -> Image.Image | None:
57
+ """Download, decode, and resize an image using its URL. Returns None in case of failure."""
58
+ try:
59
+ r = _session.get(url, timeout=HTTP_TIMEOUT)
60
+ r.raise_for_status()
61
+ # TODO: Check against the hash in case the image somehow changed.
62
+ return Image.open(BytesIO(r.content)).convert("RGB")
63
+ except Exception:
64
+ return None
65
+
66
+
67
def _img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
    """Resize a PIL image to ``img_shape`` and return it as a float tensor.

    The result is channel-first (C, H, W) with values scaled to [0, 1].

    Args:
        img: PIL Image to convert.
        img_shape: Target size, passed straight to ``PIL.Image.resize`` which
            interprets it as (width, height) — NOTE(review): confirm the
            caller's ``self.resolution`` uses that order.

    Returns:
        Normalized tensor of shape (C, H, W) with values in [0, 1].
    """
    resized = img.resize(img_shape, Image.BILINEAR)
    # PIL arrays are channel-last (H, W, C); pytorch expects channel-first (C, H, W).
    chw = torch.from_numpy(np.array(resized)).permute(2, 0, 1)
    return chw.float() / 255.0
82
+
83
+
84
@register_grounding_dataset("vsr")
class VSRDataset(GroundingDataset):
    """Visual Spatial Reasoning (VSR) dataset for true/false statement grounding.

    Wraps the HuggingFace ``cambridgeltl/vsr_random`` train split and presents
    each example as an image plus a spatial-relationship statement that the
    model must judge true or false.
    """

    def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100):
        self.dataset = load_dataset("cambridgeltl/vsr_random", split="train")
        super().__init__(cfg)
        # Indices whose image URL previously failed to download; skipped on retry.
        self.bad_ids = set()
        self.consecutive_bad_tolerance = consecutive_bad_tolerance
        # Integer dataset labels rendered as the words used in the postfix.
        self.mapping = {0: "False", 1: "True"}

    def __len__(self):
        return len(self.dataset)

    def _get_feature_mapping_key(self) -> str:
        return "vsr"

    def __getitem_helper__(self, item) -> dict:
        """Get a VSR dataset item.

        Downloads the image from its URL and formats it for true/false
        reasoning tasks. Whenever the requested index is known-bad or its image
        cannot be downloaded, a fresh random index is tried instead.

        Args:
            item: Index of the item to retrieve.

        Returns:
            Dictionary with image, task, postfix, task_type, and prompt.

        Raises:
            RuntimeError: If too many consecutive items fail to load.
        """
        attempts_left = self.consecutive_bad_tolerance
        while attempts_left > 0:
            attempts_left -= 1
            if item not in self.bad_ids:
                sample = self.dataset[item]
                img = _pil_from_url(sample["image_link"])
                if img is not None:
                    # self.resolution is provided by the GroundingDataset base —
                    # presumably (width, height) for PIL resize; TODO confirm.
                    return {
                        "image": _img_to_normalized_tensor(img, self.resolution),
                        "task": sample["label"],
                        "postfix": f"The statement is {self.mapping[sample['label']]}",
                        "task_type": "grounding",
                        "prompt": f'{{"task": "grounding", "description": "Using the Image, Tell me if following statement is true or false. \n {sample["caption"]}"}}',
                    }
                # Download failed: remember this index so we never retry it.
                self.bad_ids.add(item)
            # Resample a fresh index and try again.
            item = np.random.randint(0, len(self.dataset))

        raise RuntimeError("Too many consecutive bad items. Please check dataset or increase the tolerance.")
@@ -0,0 +1,304 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Asynchronous image writing utilities for high-frequency data recording.
18
+
19
+ This module provides functionality for writing images to disk asynchronously
20
+ using multithreading or multiprocessing, which is critical for controlling
21
+ robots and recording data at high frame rates without blocking the main process.
22
+
23
+ The module supports two execution models:
24
+
25
+ 1. Threading mode (num_processes=0): Creates a pool of worker threads
26
+ for concurrent image writing within a single process.
27
+ 2. Multiprocessing mode (num_processes>0): Creates multiple processes,
28
+ each with their own thread pool, for maximum parallelism.
29
+
30
+ Key Features:
31
+ - Asynchronous writing: Images are queued and written in background
32
+ workers, preventing I/O blocking of the main process.
33
+ - Multiple input formats: Supports torch Tensors, numpy arrays, and
34
+ PIL Images with automatic conversion.
35
+ - Format flexibility: Handles both channel-first (C, H, W) and
36
+ channel-last (H, W, C) image formats.
37
+ - Type conversion: Automatically converts float arrays in [0, 1] to
38
+ uint8 in [0, 255] for PIL Image compatibility.
39
+ - Safe cleanup: Decorator ensures image writers are properly stopped
40
+ even when exceptions occur.
41
+
42
+ Classes:
43
+
44
+ AsyncImageWriter
45
+ Main class for asynchronous image writing with configurable threading
46
+ or multiprocessing backends.
47
+
48
+ Functions:
49
+ image_array_to_pil_image
50
+ Convert numpy array to PIL Image with format and type conversion.
51
+ write_image
52
+ Write an image (numpy array or PIL Image) to disk.
53
+ worker_thread_loop
54
+ Worker thread loop for processing image write queue.
55
+ worker_process
56
+ Worker process that manages multiple threads for image writing.
57
+ safe_stop_image_writer
58
+ Decorator to safely stop image writer on exceptions.
59
+
60
+ Example:
61
+ Create an async image writer with threading:
62
+ >>> writer = AsyncImageWriter(num_processes=0, num_threads=4)
63
+ >>> writer.save_image(image_array, Path("output/image.jpg"))
64
+ >>> writer.wait_until_done() # Wait for all images to be written
65
+ >>> writer.stop() # Clean up resources
66
+ """
67
+
68
+ import multiprocessing
69
+ import queue
70
+ import threading
71
+ from pathlib import Path
72
+
73
+ import numpy as np
74
+ import PIL.Image
75
+ import torch
76
+
77
+
78
def safe_stop_image_writer(func):
    """Decorator that stops a dataset's image writer before re-raising an exception.

    If the wrapped function raises, this looks for a ``dataset`` keyword
    argument and, when it exposes a non-None ``image_writer`` attribute, stops
    it so background writer threads/processes do not outlive the failure. The
    original exception always propagates unchanged.

    Args:
        func: Function to wrap. The dataset must be passed via the ``dataset``
            keyword argument for the cleanup to trigger.

    Returns:
        The wrapped function, with ``func``'s metadata preserved.
    """
    import functools  # local import: keeps this module's import surface unchanged

    @functools.wraps(func)  # preserve __name__/__doc__ for introspection and debugging
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                print("Waiting for image writer to terminate...")
                image_writer.stop()
            # Bare `raise` keeps the original exception and traceback intact
            # (the previous `raise e` re-raised from this frame).
            raise

    return wrapper
103
+
104
+
105
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
    """Convert a numpy array to a PIL Image.

    Accepts channel-first ``(3, H, W)`` or channel-last ``(H, W, 3)`` layouts.
    Non-uint8 arrays are treated as floats in ``[0, 1]`` and rescaled to uint8.

    Args:
        image_array: Input image array of shape (C, H, W) or (H, W, C).
        range_check: If True, validates that non-uint8 arrays fall in [0, 1].
            Defaults to True.

    Returns:
        PIL Image object.

    Raises:
        ValueError: If the array is not 3-dimensional, or if float values fall
            outside the [0, 1] range while ``range_check`` is enabled.
        NotImplementedError: If the image does not have 3 channels.
    """
    # TODO(aliberts): handle 1 channel and 4 for depth images
    if image_array.ndim != 3:
        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

    channels_first = image_array.shape[0] == 3
    if channels_first:
        # Rearrange the pytorch layout (C, H, W) into PIL's (H, W, C).
        image_array = image_array.transpose(1, 2, 0)
    elif image_array.shape[-1] != 3:
        raise NotImplementedError(
            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
        )

    if image_array.dtype != np.uint8:
        if range_check:
            min_ = image_array.min().item()
            max_ = image_array.max().item()
            if max_ > 1.0 or min_ < 0.0:
                raise ValueError(
                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
                    "provide a uint8 image with values in the range [0, 255]."
                )
        image_array = (image_array * 255).astype(np.uint8)

    return PIL.Image.fromarray(image_array)
154
+
155
+
156
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path) -> None:
    """Persist a single image to disk.

    Numpy arrays are converted to PIL Images first; PIL Images are saved as-is.
    Any failure (including unsupported input types) is printed rather than
    raised, so one bad frame cannot abort a recording session.

    Args:
        image: Image to save (numpy array or PIL Image).
        fpath: Destination path for the image file.
    """
    try:
        if isinstance(image, PIL.Image.Image):
            pil_img = image
        elif isinstance(image, np.ndarray):
            pil_img = image_array_to_pil_image(image)
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")
        pil_img.save(fpath)
    except Exception as e:
        print(f"Error writing image {fpath}: {e}")
178
+
179
+
180
def worker_thread_loop(queue: queue.Queue) -> None:
    """Drain the write queue until a None sentinel arrives.

    Each queued item is an (image_array, file_path) tuple; a ``None`` item
    tells the worker to exit. ``task_done`` is called for every item consumed,
    including the sentinel, so ``Queue.join()`` can track completion.

    Args:
        queue: Queue containing (image_array, file_path) tuples or None sentinel.
    """
    while True:
        work = queue.get()
        if work is None:
            queue.task_done()
            break
        frame, destination = work
        write_image(frame, destination)
        queue.task_done()
197
+
198
+
199
def worker_process(queue: queue.Queue, num_threads: int) -> None:
    """Entry point for a writer subprocess.

    Spins up ``num_threads`` daemon worker threads that all consume from the
    shared queue, then blocks until every thread has exited (each exits upon
    receiving a None sentinel).

    Args:
        queue: Queue containing (image_array, file_path) tuples or None sentinels.
        num_threads: Number of worker threads to run in this process.
    """
    workers = [threading.Thread(target=worker_thread_loop, args=(queue,)) for _ in range(num_threads)]
    for worker in workers:
        worker.daemon = True
        worker.start()
    for worker in workers:
        worker.join()
216
+
217
+
218
class AsyncImageWriter:
    """
    This class abstracts away the initialisation of processes and/or threads to
    save images on disk asynchronously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it creates a threads pool of size `num_threads`.
    When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
    their own threads pool of size `num_threads`.

    The optimal number of processes and threads depends on your computer capabilities.
    We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
    the number of threads. If it is still not stable, try to use 1 subprocess, or more.
    """

    def __init__(self, num_processes: int = 0, num_threads: int = 1):
        """Start the worker pool.

        Args:
            num_processes: Number of writer subprocesses; 0 selects thread-only mode.
            num_threads: Number of writer threads (per subprocess when num_processes > 0).

        Raises:
            ValueError: If both num_threads and num_processes are <= 0.
        """
        self.num_processes = num_processes
        self.num_threads = num_threads
        self.queue = None
        self.threads = []
        self.processes = []
        # Guards against running the shutdown sequence twice (see stop()).
        self._stopped = False

        # NOTE(review): num_threads <= 0 with num_processes > 0 passes this check
        # but would spawn subprocesses with no worker threads — confirm intended.
        if num_threads <= 0 and num_processes <= 0:
            raise ValueError("Number of threads and processes must be greater than zero.")

        if self.num_processes == 0:
            # Use threading
            self.queue = queue.Queue()
            for _ in range(self.num_threads):
                t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
                # Daemon threads so an abrupt interpreter exit is not blocked by workers.
                t.daemon = True
                t.start()
                self.threads.append(t)
        else:
            # Use multiprocessing
            # JoinableQueue supports task_done()/join(), which wait_until_done() relies on.
            self.queue = multiprocessing.JoinableQueue()
            for _ in range(self.num_processes):
                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
                p.daemon = True
                p.start()
                self.processes.append(p)

    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
        """Queue an image for asynchronous writing.

        Converts torch tensors to numpy arrays and adds the image to the write queue.
        The actual disk write happens later in a worker thread/process.

        Args:
            image: Image to save (torch Tensor, numpy array, or PIL Image).
            fpath: Path where the image will be saved.
        """
        if isinstance(image, torch.Tensor):
            # Convert tensor to numpy array to minimize main process time
            image = image.cpu().numpy()
        self.queue.put((image, fpath))

    def wait_until_done(self) -> None:
        """Block until every queued image has been picked up and written to disk."""
        # Relies on workers calling task_done() once per consumed item.
        self.queue.join()

    def stop(self) -> None:
        """Stop all worker threads/processes and clean up resources.

        Sends one None sentinel per worker thread and waits for workers to
        finish. Safe to call more than once; subsequent calls are no-ops.
        """
        if self._stopped:
            return

        if self.num_processes == 0:
            # One sentinel per thread: each worker consumes exactly one None and exits.
            for _ in self.threads:
                self.queue.put(None)
            for t in self.threads:
                t.join()
        else:
            # Every thread in every subprocess needs its own sentinel to exit.
            num_nones = self.num_processes * self.num_threads
            for _ in range(num_nones):
                self.queue.put(None)
            for p in self.processes:
                p.join()
                # NOTE(review): join() has no timeout, so is_alive() here can only
                # be True if the process died uncleanly — confirm whether a
                # join(timeout) was intended before the terminate() fallback.
                if p.is_alive():
                    p.terminate()
            # Multiprocessing queues own a feeder thread; release it explicitly.
            self.queue.close()
            self.queue.join_thread()

        self._stopped = True