opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
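A wheel is a plain zip archive, so the manifest above can be reproduced locally. A minimal sketch, assuming the wheel has been downloaded under its standard PEP 427 filename (adjust the path as needed):

# Sketch: list the files inside the opentau wheel using only the standard library.
import zipfile

with zipfile.ZipFile("opentau-0.1.0-py3-none-any.whl") as whl:
    for info in whl.infolist():
        print(f"{info.file_size:>10}  {info.filename}")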
opentau/datasets/grounding/pixmo.py
@@ -0,0 +1,177 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datasets for image-text point-set grounding tasks.

This module provides the PixMo points dataset implementation for training
vision-language models on part localization and object grounding tasks.
"""

import json
import logging
import random
import warnings
from io import BytesIO

import numpy as np
import requests
import torch
from datasets import load_dataset
from PIL import Image
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from opentau import register_grounding_dataset
from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.grounding.base import GroundingDataset

# TODO: add a config to filter the warnings
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
warnings.filterwarnings(
    "ignore",
    message=r"Palette images with Transparency expressed in bytes should be converted to RGBA images",
    category=UserWarning,
    module=r"PIL\.Image",
)
warnings.filterwarnings(
    "ignore",
    message=r"image file could not be identified because AVIF support not installed",
    category=UserWarning,
    module=r"PIL\.Image",
)

IMG_SIZE = 224
POINT_GRID = 255
MAX_RETRIES = 1
HTTP_TIMEOUT = 1
LOG_EVERY_N_BAD = 1000

_session = requests.Session()
_session.mount(
    "https://",
    HTTPAdapter(
        max_retries=Retry(
            total=MAX_RETRIES,
            backoff_factor=0.5,
            status_forcelist=[500, 502, 503, 504],
        )
    ),
)


def _pil_from_url(url: str) -> Image.Image | None:
    """Download and decode an image from its URL. Returns None in case of failure."""
    try:
        r = _session.get(url, timeout=HTTP_TIMEOUT)
        r.raise_for_status()
        # TODO: Check against the hash in case the image somehow changed.
        return Image.open(BytesIO(r.content)).convert("RGB")
    except Exception:
        return None


def _get_post_fix(label: str, points: list, orig_w: int, orig_h: int, max_points: int = 16) -> str:
    """Map points from pixel space to grid space and return a JSON postfix string.

    Converts pixel coordinates to a 255x255 grid, deduplicates points, and
    limits them to max_points. Returns a JSON string with point coordinates and labels.

    Args:
        label: Label for the points (e.g., object class name).
        points: List of point dictionaries with 'x' and 'y' keys.
        orig_w: Original image width.
        orig_h: Original image height.
        max_points: Maximum number of points to include. Defaults to 16.

    Returns:
        JSON string containing point coordinates and labels.
    """
    # use `dict` to deduplicate as `set` is not guaranteed to preserve order
    deduplicated = {
        (int(p["x"] * POINT_GRID / orig_w), int(p["y"] * POINT_GRID / orig_h)): None for p in points
    }
    if len(deduplicated) > max_points:
        # sample without replacement so the deduplication is preserved
        deduplicated = random.sample(list(deduplicated), k=max_points)
    rows = [{"in_frame": True, "point": pair, "label": label} for pair in deduplicated]
    return json.dumps(rows)


def _img_to_normalized_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL Image to a normalized torch tensor.

    Resizes the image to IMG_SIZE and converts it from (H, W, C) to (C, H, W)
    format, normalizing pixel values to [0, 1].

    Args:
        img: PIL Image to convert.

    Returns:
        Normalized tensor of shape (C, IMG_SIZE, IMG_SIZE) with values in [0, 1].
    """
    img = img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
    # pytorch uses (C, H, W) while PIL uses (H, W, C)
    return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0


@register_grounding_dataset("pixmo")
class PixmoDataset(GroundingDataset):
    r"""Iterable PixMo dataset implementation; recommended to be used together with PrefetchWrapper."""

    def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100):
        # self.ds is needed for metadata, which is computed in the parent constructor
        self.ds = load_dataset("allenai/pixmo-points", split="train")
        super().__init__(cfg)
        self.bad_ids = set()
        self.consecutive_bad_tolerance = consecutive_bad_tolerance

    def __len__(self):
        return len(self.ds)

    def _get_feature_mapping_key(self) -> str:
        return "pixmo"

    def __getitem_helper__(self, item) -> dict:
        """Get a PixMo dataset item.

        Downloads the image from its URL and formats it for part localization tasks.
        Retries with random indices if the image download fails.

        Args:
            item: Index of the item to retrieve.

        Returns:
            Dictionary with image, task, postfix, task_type, and prompt.

        Raises:
            RuntimeError: If too many consecutive items fail to load.
        """
        for _ in range(self.consecutive_bad_tolerance):
            if item in self.bad_ids:
                item = np.random.randint(0, len(self.ds))
                continue
            ex = self.ds[item]
            img = _pil_from_url(ex["image_url"])
            if img is None:
                self.bad_ids.add(item)
                item = np.random.randint(0, len(self.ds))
                continue

            return {
                "image": _img_to_normalized_tensor(img),
                "task": ex["label"],
                "postfix": _get_post_fix(ex["label"], ex["points"], *img.size),
                "task_type": "part",
                "prompt": f'{{"task": "part", "description": "Find {ex["label"]} in the image"}}',
            }

        raise RuntimeError("Too many consecutive bad items. Please check the dataset or increase the tolerance.")
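To make the pixel-to-grid mapping in `_get_post_fix` concrete, here is a minimal worked sketch; the image size, points, and label are hypothetical, and the arithmetic simply mirrors the `int(p * POINT_GRID / orig_dim)` expression above:

# Hypothetical input: a 640x480 image with two near-duplicate points.
points = [{"x": 320.0, "y": 240.0}, {"x": 320.4, "y": 240.2}]
orig_w, orig_h = 640, 480
POINT_GRID = 255

grid = {(int(p["x"] * POINT_GRID / orig_w), int(p["y"] * POINT_GRID / orig_h)): None for p in points}
print(list(grid))  # [(127, 127)] -- both points collapse to the same grid cell

# For a hypothetical label such as "handle", the postfix then serializes to:
# [{"in_frame": true, "point": [127, 127], "label": "handle"}]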
opentau/datasets/grounding/vsr.py
@@ -0,0 +1,141 @@
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VSR (Visual Spatial Reasoning) dataset for true/false statement grounding.

This module provides the VSR dataset implementation for training vision-language
models on visual spatial reasoning tasks. The dataset contains images with
statements about spatial relationships, and models must determine whether each
statement is true or false based on the image content.
"""

import logging
from io import BytesIO

import numpy as np
import requests
import torch
from datasets import load_dataset
from PIL import Image
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from opentau import register_grounding_dataset
from opentau.configs.train import TrainPipelineConfig
from opentau.datasets.grounding.base import GroundingDataset

logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)

MAX_RETRIES = 1
HTTP_TIMEOUT = 1
LOG_EVERY_N_BAD = 1000

_session = requests.Session()
_session.mount(
    "https://",
    HTTPAdapter(
        max_retries=Retry(
            total=MAX_RETRIES,
            backoff_factor=0.5,
            status_forcelist=[500, 502, 503, 504],
        )
    ),
)


def _pil_from_url(url: str) -> Image.Image | None:
    """Download and decode an image from its URL. Returns None in case of failure."""
    try:
        r = _session.get(url, timeout=HTTP_TIMEOUT)
        r.raise_for_status()
        # TODO: Check against the hash in case the image somehow changed.
        return Image.open(BytesIO(r.content)).convert("RGB")
    except Exception:
        return None


def _img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
    """Convert a PIL Image to a normalized torch tensor.

    Resizes the image and converts it from (H, W, C) to (C, H, W) format,
    normalizing pixel values to [0, 1].

    Args:
        img: PIL Image to convert.
        img_shape: Target image shape (height, width).

    Returns:
        Normalized tensor of shape (C, H, W) with values in [0, 1].
    """
    img = img.resize(img_shape, Image.BILINEAR)
    return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0


@register_grounding_dataset("vsr")
class VSRDataset(GroundingDataset):
    """Visual Spatial Reasoning (VSR) dataset for true/false statement grounding.

    Loads the cambridgeltl/vsr_random dataset from HuggingFace and formats it
    for visual reasoning tasks where models must determine if statements about
    images are true or false.
    """

    def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100):
        self.dataset = load_dataset("cambridgeltl/vsr_random", split="train")
        super().__init__(cfg)
        self.bad_ids = set()
        self.consecutive_bad_tolerance = consecutive_bad_tolerance
        self.mapping = {0: "False", 1: "True"}

    def __len__(self):
        return len(self.dataset)

    def _get_feature_mapping_key(self) -> str:
        return "vsr"

    def __getitem_helper__(self, item) -> dict:
        """Get a VSR dataset item.

        Downloads the image from its URL and formats it for true/false reasoning tasks.
        Retries with random indices if the image download fails.

        Args:
            item: Index of the item to retrieve.

        Returns:
            Dictionary with image, task, postfix, task_type, and prompt.

        Raises:
            RuntimeError: If too many consecutive items fail to load.
        """
        for _ in range(self.consecutive_bad_tolerance):
            if item in self.bad_ids:
                item = np.random.randint(0, len(self.dataset))
                continue

            sample = self.dataset[item]
            img = _pil_from_url(sample["image_link"])
            if img is None:
                self.bad_ids.add(item)
                item = np.random.randint(0, len(self.dataset))
                continue

            return {
                "image": _img_to_normalized_tensor(img, self.resolution),
                "task": sample["label"],
                "postfix": f"The statement is {self.mapping[sample['label']]}",
                "task_type": "grounding",
                "prompt": f'{{"task": "grounding", "description": "Using the image, tell me if the following statement is true or false. \n {sample["caption"]}"}}',
            }

        raise RuntimeError("Too many consecutive bad items. Please check the dataset or increase the tolerance.")
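Both grounding modules above build the same retrying HTTPS session. For reference, here is a self-contained sketch of that pattern with the Retry parameters annotated; the values are copied from the modules, and the URL is a placeholder:

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

session = requests.Session()
session.mount(
    "https://",
    HTTPAdapter(
        max_retries=Retry(
            total=1,  # at most one retry per request (MAX_RETRIES above)
            backoff_factor=0.5,  # sleep backoff_factor * 2**(n-1) seconds between retries
            status_forcelist=[500, 502, 503, 504],  # retry only on these server errors
        )
    ),
)
# resp = session.get("https://example.com/image.jpg", timeout=1)  # placeholder URL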
opentau/datasets/image_writer.py
@@ -0,0 +1,304 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Asynchronous image writing utilities for high-frequency data recording.

This module provides functionality for writing images to disk asynchronously
using multithreading or multiprocessing, which is critical for controlling
robots and recording data at high frame rates without blocking the main process.

The module supports two execution models:

1. Threading mode (num_processes=0): Creates a pool of worker threads
   for concurrent image writing within a single process.
2. Multiprocessing mode (num_processes>0): Creates multiple processes,
   each with their own thread pool, for maximum parallelism.

Key Features:
    - Asynchronous writing: Images are queued and written in background
      workers, preventing I/O blocking of the main process.
    - Multiple input formats: Supports torch Tensors, numpy arrays, and
      PIL Images with automatic conversion.
    - Format flexibility: Handles both channel-first (C, H, W) and
      channel-last (H, W, C) image formats.
    - Type conversion: Automatically converts float arrays in [0, 1] to
      uint8 in [0, 255] for PIL Image compatibility.
    - Safe cleanup: Decorator ensures image writers are properly stopped
      even when exceptions occur.

Classes:
    AsyncImageWriter
        Main class for asynchronous image writing with configurable threading
        or multiprocessing backends.

Functions:
    image_array_to_pil_image
        Convert a numpy array to a PIL Image with format and type conversion.
    write_image
        Write an image (numpy array or PIL Image) to disk.
    worker_thread_loop
        Worker thread loop for processing the image write queue.
    worker_process
        Worker process that manages multiple threads for image writing.
    safe_stop_image_writer
        Decorator to safely stop the image writer on exceptions.

Example:
    Create an async image writer with threading:
    >>> writer = AsyncImageWriter(num_processes=0, num_threads=4)
    >>> writer.save_image(image_array, Path("output/image.jpg"))
    >>> writer.wait_until_done()  # Wait for all images to be written
    >>> writer.stop()  # Clean up resources
"""

import multiprocessing
import queue
import threading
from pathlib import Path

import numpy as np
import PIL.Image
import torch


def safe_stop_image_writer(func):
    """Decorator to safely stop the image writer on exceptions.

    Ensures that the image writer is properly stopped if an exception occurs
    during function execution.

    Args:
        func: Function to wrap.

    Returns:
        Wrapped function that stops the image writer on exceptions.
    """

    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                print("Waiting for image writer to terminate...")
                image_writer.stop()
            raise e

    return wrapper


def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
    """Convert a numpy array to a PIL Image.

    Supports channel-first (C, H, W) and channel-last (H, W, C) formats.
    Converts float arrays in [0, 1] to uint8 in [0, 255].

    Args:
        image_array: Input image array of shape (C, H, W) or (H, W, C).
        range_check: If True, validates that float arrays are in the [0, 1] range.
            Defaults to True.

    Returns:
        PIL Image object.

    Raises:
        ValueError: If the array has the wrong number of dimensions, or if float
            values are outside the [0, 1] range.
        NotImplementedError: If the image doesn't have 3 channels.
    """
    # TODO(aliberts): handle 1 channel and 4 for depth images
    if image_array.ndim != 3:
        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")

    if image_array.shape[0] == 3:
        # Transpose from pytorch convention (C, H, W) to (H, W, C)
        image_array = image_array.transpose(1, 2, 0)

    elif image_array.shape[-1] != 3:
        raise NotImplementedError(
            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
        )

    if image_array.dtype != np.uint8:
        if range_check:
            max_ = image_array.max().item()
            min_ = image_array.min().item()
            if max_ > 1.0 or min_ < 0.0:
                raise ValueError(
                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
                    "provide a uint8 image with values in the range [0, 255]."
                )

        image_array = (image_array * 255).astype(np.uint8)

    return PIL.Image.fromarray(image_array)


def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path) -> None:
    """Write an image to disk.

    Converts numpy arrays to PIL Images if needed, then saves to the specified path.

    Args:
        image: Image to save (numpy array or PIL Image).
        fpath: Path where the image will be saved.

    Raises:
        TypeError: If the image type is not supported.
    """
    try:
        if isinstance(image, np.ndarray):
            img = image_array_to_pil_image(image)
        elif isinstance(image, PIL.Image.Image):
            img = image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")
        img.save(fpath)
    except Exception as e:
        print(f"Error writing image {fpath}: {e}")


def worker_thread_loop(queue: queue.Queue) -> None:
    """Worker thread loop for asynchronous image writing.

    Continuously processes items from the queue until receiving None (sentinel).
    Each item should be a tuple of (image_array, file_path).

    Args:
        queue: Queue containing (image_array, file_path) tuples or a None sentinel.
    """
    while True:
        item = queue.get()
        if item is None:
            queue.task_done()
            break
        image_array, fpath = item
        write_image(image_array, fpath)
        queue.task_done()


def worker_process(queue: queue.Queue, num_threads: int) -> None:
    """Worker process that manages multiple threads for image writing.

    Creates and manages a pool of worker threads that process items from the queue.

    Args:
        queue: Queue containing (image_array, file_path) tuples or None sentinels.
        num_threads: Number of worker threads to create in this process.
    """
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(target=worker_thread_loop, args=(queue,))
        t.daemon = True
        t.start()
        threads.append(t)
    for t in threads:
        t.join()


class AsyncImageWriter:
    """
    This class abstracts away the initialisation of the processes and/or threads used to
    save images on disk asynchronously, which is critical to control a robot and record data
    at a high frame rate.

    When `num_processes=0`, it creates a thread pool of size `num_threads`.
    When `num_processes>0`, it creates a process pool of size `num_processes`, where each
    subprocess starts its own thread pool of size `num_threads`.

    The optimal number of processes and threads depends on your computer's capabilities.
    We advise using 4 threads per camera with 0 processes. If the fps is not stable, try increasing
    or lowering the number of threads. If it is still not stable, try using 1 subprocess, or more.
    """

    def __init__(self, num_processes: int = 0, num_threads: int = 1):
        self.num_processes = num_processes
        self.num_threads = num_threads
        self.queue = None
        self.threads = []
        self.processes = []
        self._stopped = False

        if num_threads <= 0 and num_processes <= 0:
            raise ValueError("Number of threads and processes must be greater than zero.")

        if self.num_processes == 0:
            # Use threading
            self.queue = queue.Queue()
            for _ in range(self.num_threads):
                t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
                t.daemon = True
                t.start()
                self.threads.append(t)
        else:
            # Use multiprocessing
            self.queue = multiprocessing.JoinableQueue()
            for _ in range(self.num_processes):
                p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
                p.daemon = True
                p.start()
                self.processes.append(p)

    def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
        """Queue an image for asynchronous writing.

        Converts torch tensors to numpy arrays and adds the image to the write queue.

        Args:
            image: Image to save (torch Tensor, numpy array, or PIL Image).
            fpath: Path where the image will be saved.
        """
        if isinstance(image, torch.Tensor):
            # Convert tensor to numpy array to minimize main process time
            image = image.cpu().numpy()
        self.queue.put((image, fpath))

    def wait_until_done(self) -> None:
        """Wait until all queued images have been written to disk."""
        self.queue.join()

    def stop(self) -> None:
        """Stop all worker threads/processes and clean up resources.

        Sends sentinel values to all workers and waits for them to finish.
        Terminates processes if they don't respond.
        """
        if self._stopped:
            return

        if self.num_processes == 0:
            for _ in self.threads:
                self.queue.put(None)
            for t in self.threads:
                t.join()
        else:
            num_nones = self.num_processes * self.num_threads
            for _ in range(num_nones):
                self.queue.put(None)
            for p in self.processes:
                p.join()
                if p.is_alive():
                    p.terminate()
            self.queue.close()
            self.queue.join_thread()

        self._stopped = True
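The module docstring's example, expanded into a small end-to-end sketch; the random frames and the output directory are placeholders:

from pathlib import Path

import numpy as np

from opentau.datasets.image_writer import AsyncImageWriter

out_dir = Path("output")
out_dir.mkdir(exist_ok=True)

writer = AsyncImageWriter(num_processes=0, num_threads=4)  # threading mode
try:
    for i in range(10):
        # Float frames in [0, 1] are converted to uint8 by image_array_to_pil_image;
        # channel-first (C, H, W) arrays are transposed automatically.
        frame = np.random.rand(3, 96, 96).astype(np.float32)
        writer.save_image(frame, out_dir / f"frame_{i:04d}.jpg")
    writer.wait_until_done()  # block until every queued image is written
finally:
    writer.stop()  # send sentinels and join the worker threads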