opentau-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/datasets/grounding/__init__.py
@@ -0,0 +1,67 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Vision-language grounding datasets for multimodal learning.
+
+This module provides datasets for training vision-language-action models on
+image-text grounding tasks without requiring robot actions. Grounding datasets
+are designed to help models learn visual understanding, spatial reasoning,
+and language grounding capabilities that can be transferred to robotic tasks.
+
+Grounding datasets differ from standard robot learning datasets in that they:
+- Provide images, prompts, and responses but no robot actions or states
+- Use zero-padding for state and action features to maintain compatibility
+- Focus on visual question answering, spatial reasoning, and object grounding
+- Enable training on large-scale vision-language data without robot hardware
+
+The module uses a registration system where datasets are registered via the
+`@register_grounding_dataset` decorator, making them available through the
+`available_grounding_datasets` registry.
+
+Available Datasets:
+    - CLEVR: Compositional Language and Elementary Visual Reasoning dataset
+      for visual question answering with synthetic scenes.
+    - COCO-QA: Visual question answering dataset based on COCO images,
+      filtered for spatial reasoning tasks.
+    - PIXMO: Pixel-level manipulation grounding dataset for object
+      localization and manipulation tasks.
+    - VSR: Visual Spatial Reasoning dataset for true/false statement
+      grounding about spatial relationships in images.
+    - dummy: Synthetic test dataset with simple black, white, and gray
+      images for testing infrastructure.
+
+Classes:
+    GroundingDataset: Base class for all grounding datasets, providing
+        common functionality for metadata creation, data format conversion,
+        and zero-padding of state/action features.
+
+Modules:
+    base: Base class and common functionality for grounding datasets.
+    clevr: CLEVR dataset implementation.
+    cocoqa: COCO-QA dataset implementation.
+    dummy: Dummy test dataset implementation.
+    pixmo: PIXMO dataset implementation.
+    vsr: VSR dataset implementation.
+
+Example:
+    Use a grounding dataset in training configuration:
+    >>> from opentau.configs.default import DatasetConfig
+    >>> cfg = DatasetConfig(grounding="cocoqa")
+    >>> dataset = make_dataset(cfg, train_cfg)
+
+    Access available grounding datasets:
+    >>> from opentau import available_grounding_datasets
+    >>> print(list(available_grounding_datasets.keys()))
+    ['clevr', 'cocoqa', 'dummy', 'pixmo', 'vsr']
+"""
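The registry described in the docstring above can be exercised with a short sketch. This is illustrative only and not part of the wheel; it uses the names the module exposes (`register_grounding_dataset`, `available_grounding_datasets`, `GroundingDataset`), while the dataset class `MyVQADataset` and its key `"my_vqa"` are made up for the example, and the dict-like membership check is an assumption based on the `.keys()` call shown in the docstring.

    # Sketch: registering a custom grounding dataset (assumed usage, not package code).
    from opentau import available_grounding_datasets, register_grounding_dataset
    from opentau.datasets.grounding.base import GroundingDataset


    @register_grounding_dataset("my_vqa")  # hypothetical key for this example
    class MyVQADataset(GroundingDataset):
        def __len__(self):
            return 1

        def _get_feature_mapping_key(self) -> str:
            return "my_vqa"

        def __getitem_helper__(self, item) -> dict:
            # Grounding items carry an image plus text; state/actions are
            # zero-padded later by the GroundingDataset base class.
            return {
                "image": ...,  # (3, H, W) float tensor in [0, 1]
                "task": "grounding",
                "postfix": "The answer is ...",
                "task_type": "grounding",
                "prompt": '{"task": "grounding", "description": "..."}',
            }


    # Assuming the registry behaves like a dict, the new key is now listed
    # alongside the built-in datasets.
    print("my_vqa" in available_grounding_datasets)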
opentau/datasets/grounding/base.py
@@ -0,0 +1,154 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base class for vision-language grounding datasets.
+
+This module provides the base class for all grounding datasets, which are used
+for training vision-language-action models on image-text tasks without robot
+actions. Grounding datasets provide images, prompts, and responses for tasks
+like visual question answering, spatial reasoning, and object grounding.
+
+The base class handles common functionality including:
+- Metadata creation with ImageNet statistics for images
+- Zero-padding of state and action features for compatibility
+- Standard data format conversion
+- Integration with the dataset mixture system
+
+Classes:
+    GroundingDataset: Abstract base class that all grounding datasets inherit
+        from. Provides common functionality for metadata creation, data format
+        conversion, and zero-padding of missing features.
+
+Example:
+    Create a custom grounding dataset:
+    >>> from opentau import register_grounding_dataset
+    >>> @register_grounding_dataset("my_dataset")
+    ... class MyGroundingDataset(GroundingDataset):
+    ...     def __getitem_helper__(self, item):
+    ...         return {"image": ..., "task": ..., "postfix": ...}
+"""
+
+from abc import abstractmethod
+from copy import deepcopy
+from typing import final
+
+import torch
+
+from opentau.configs.train import TrainPipelineConfig
+from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, BaseDataset, GroundingDatasetMetadata
+
+
+class GroundingDataset(BaseDataset):
+    """Base class for vision-language grounding datasets.
+
+    Grounding datasets are used for training vision-language-action models on
+    image-text tasks without robot actions. They provide images, prompts, and
+    responses for grounding tasks.
+
+    Attributes:
+        num_frames: Number of frames in the dataset.
+        num_episodes: Number of episodes (always 1 for grounding datasets).
+        meta: Dataset metadata containing features and statistics.
+    """
+
+    def __init__(self, cfg: TrainPipelineConfig, num_frames: int = 1, num_episodes: int = 1):
+        super().__init__(cfg)
+        self.num_frames = num_frames
+        self.num_episodes = num_episodes
+        self.meta = self.create_meta()
+
+    def create_meta(self) -> GroundingDatasetMetadata:
+        """Create metadata for the grounding dataset.
+
+        Initializes metadata with ImageNet statistics for images and zero
+        statistics for state and actions (since grounding datasets don't have them).
+
+        Returns:
+            GroundingDatasetMetadata object with initialized info and stats.
+        """
+        from opentau.datasets.factory import IMAGENET_STATS
+
+        info = {
+            "codebase_version": CODEBASE_VERSION,
+            "features": {
+                "camera0": {
+                    "dtype": "image",
+                    "shape": [3, 224, 224],
+                    "names": ["channel", "height", "width"],
+                },
+            },
+        }
+        stats = {
+            "image": {
+                "min": [[[0.0]], [[0.0]], [[0.0]]],
+                "max": [[[1.0]], [[1.0]], [[1.0]]],
+                "count": [len(self)],
+                **deepcopy(IMAGENET_STATS),  # mean and std
+            },
+            "state": {
+                "min": [0.0],
+                "max": [0.0],
+                "mean": [0.0],
+                "std": [0.0],
+                "count": [len(self)],
+            },
+            "actions": {
+                "min": [0.0],
+                "max": [0.0],
+                "mean": [0.0],
+                "std": [0.0],
+                "count": [len(self)],
+            },
+        }
+        metadata = GroundingDatasetMetadata(info=info, stats=stats)
+        metadata.repo_id = self._get_feature_mapping_key()
+        return metadata
+
+    @abstractmethod
+    def __getitem_helper__(self, item) -> dict:
+        """Helper method to get a dataset item (to be implemented by subclasses).
+
+        Args:
+            item: Index of the item to retrieve.
+
+        Returns:
+            Dictionary containing the raw item data with keys like 'image',
+            'task', 'postfix', 'task_type', 'prompt'.
+        """
+        pass
+
+    @final
+    def __getitem__(self, item):
+        item = self.__getitem_helper__(item)
+
+        # Grounding datasets don't have states or actions. 0-padding is used.
+        item["state"] = torch.zeros(self.max_state_dim)
+        item["actions"] = torch.zeros(self.action_chunk, self.max_action_dim)
+        item["actions_is_pad"] = torch.ones(self.action_chunk, dtype=torch.bool)
+        item = self._to_standard_data_format(item)
+        item["return_bin_idx"] = torch.tensor(0, dtype=torch.long)
+        item["return_continuous"] = torch.tensor(0, dtype=torch.float32)
+        item["advantage"] = torch.tensor(0, dtype=torch.bfloat16)
+        return item
+
+    def _separate_image_in_time(self, item: dict) -> None:
+        """Separate images in time (no-op for grounding datasets).
+
+        Grounding datasets don't have temporal image sequences, so this is a no-op.
+
+        Args:
+            item: Item dictionary (unmodified).
+        """
+        # Grounding datasets have nothing to separate.
+        pass
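For reference, the zero-padding step in `GroundingDataset.__getitem__` can be reproduced in isolation. The snippet below is a standalone illustration; the values of `max_state_dim`, `max_action_dim`, and `action_chunk` are assumed for the example (in the package they come from the training configuration via `BaseDataset`).

    import torch

    # Assumed example dimensions; opentau reads these from TrainPipelineConfig.
    max_state_dim, max_action_dim, action_chunk = 32, 32, 50

    item = {"image": torch.rand(3, 224, 224), "task": "grounding"}

    # Same padding scheme as in __getitem__ above: zero state and actions, plus
    # an all-True mask marking every action step in the chunk as padding.
    item["state"] = torch.zeros(max_state_dim)
    item["actions"] = torch.zeros(action_chunk, max_action_dim)
    item["actions_is_pad"] = torch.ones(action_chunk, dtype=torch.bool)

    print(item["state"].shape, item["actions"].shape)
    # torch.Size([32]) torch.Size([50, 32])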
opentau/datasets/grounding/clevr.py
@@ -0,0 +1,110 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CLEVR dataset for visual reasoning and grounding tasks.
+
+This module provides the CLEVR (Compositional Language and Elementary Visual
+Reasoning) dataset implementation for training vision-language models on
+compositional visual reasoning tasks. The dataset contains synthetic scenes
+with geometric objects and questions requiring compositional reasoning.
+
+The dataset is loaded from HuggingFace and formatted for grounding tasks,
+providing images, questions, and answers for visual reasoning.
+
+Classes:
+    CLEVRDataset: Dataset class that loads and formats CLEVR data from
+        MMInstruction/Clevr_CoGenT_TrainA_70K_Complex on HuggingFace.
+
+Functions:
+    _img_to_normalized_tensor: Convert PIL Image to normalized torch tensor
+        with channel-first format and [0, 1] normalization.
+
+Example:
+    Use CLEVR dataset in training:
+    >>> from opentau.configs.default import DatasetConfig
+    >>> cfg = DatasetConfig(grounding="clevr")
+    >>> dataset = make_dataset(cfg, train_cfg)
+"""
+
+import logging
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from PIL import Image
+
+from opentau import register_grounding_dataset
+from opentau.configs.train import TrainPipelineConfig
+from opentau.datasets.grounding.base import GroundingDataset
+
+logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
+
+
+def _img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
+    """Convert a PIL Image to a normalized torch tensor.
+
+    Resizes the image and converts it from (H, W, C) to (C, H, W) format,
+    normalizing pixel values to [0, 1].
+
+    Args:
+        img: PIL Image to convert.
+        img_shape: Target image shape (height, width).
+
+    Returns:
+        Normalized tensor of shape (C, H, W) with values in [0, 1].
+    """
+    img = img.resize(img_shape, Image.BILINEAR)
+
+    # pytorch uses (C, H, W) while PIL uses (H, W, C)
+    return torch.from_numpy(np.array(img))[:, :, :3].permute(2, 0, 1).float() / 255.0
+
+
+@register_grounding_dataset("clevr")
+class CLEVRDataset(GroundingDataset):
+    """CLEVR dataset for visual reasoning and grounding tasks.
+
+    Loads the MMInstruction/Clevr_CoGenT_TrainA_70K_Complex dataset from
+    HuggingFace and formats it for grounding tasks.
+    """
+
+    def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100):
+        self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex", split="train")
+        super().__init__(cfg)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def _get_feature_mapping_key(self) -> str:
+        return "clevr"
+
+    def __getitem_helper__(self, item) -> dict:
+        """Get a CLEVR dataset item.
+
+        Args:
+            item: Index of the item to retrieve.
+
+        Returns:
+            Dictionary with image, task, postfix, task_type, and prompt
+            extracted from the CLEVR dataset sample.
+        """
+        sample = self.dataset[item]
+        img = sample["image"]
+
+        return {
+            "image": _img_to_normalized_tensor(img, self.resolution),
+            "task": "grounding",
+            "postfix": f"The answer is {sample['solution'].split('<answer>')[1].split('</answer>')[0]}",
+            "task_type": "grounding",
+            "prompt": f'{{"task": "grounding", "description": "Using the Image, Answer the following question. \n {sample["problem"]}"}}',
+        }
+    }
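The `postfix` construction above pulls the answer text out of `<answer>...</answer>` tags in the sample's `solution` field. A small standalone sketch of that extraction, using a made-up `solution` string (the real field format comes from the HuggingFace dataset):

    # Hypothetical solution string in the format the code above expects.
    solution = "<think>Count the cubes behind the sphere.</think><answer>3</answer>"

    answer = solution.split("<answer>")[1].split("</answer>")[0]
    postfix = f"The answer is {answer}"
    print(postfix)  # The answer is 3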
opentau/datasets/grounding/cocoqa.py
@@ -0,0 +1,130 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""COCO-QA dataset for visual question answering and grounding tasks.
+
+This module provides the COCO-QA dataset implementation for training
+vision-language models on visual question answering tasks. The dataset is
+filtered to only include 'where' questions, focusing on spatial reasoning
+tasks that are relevant for robotic manipulation.
+
+The dataset is loaded from HuggingFace (ThucPD/coco-qa-vi) and automatically
+filtered to retain only spatial reasoning questions.
+
+Classes:
+    COCODataset: Dataset class that loads, filters, and formats COCO-QA data
+        for grounding tasks.
+
+Functions:
+    _img_to_normalized_tensor: Convert PIL Image to normalized torch tensor
+        with channel-first format and [0, 1] normalization.
+    _filter_dataset: Filter dataset samples to only include 'where' questions
+        for spatial reasoning tasks.
+
+Example:
+    Use COCO-QA dataset in training:
+    >>> from opentau.configs.default import DatasetConfig
+    >>> cfg = DatasetConfig(grounding="cocoqa")
+    >>> dataset = make_dataset(cfg, train_cfg)
+"""
+
+import logging
+from typing import List
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from PIL import Image
+
+from opentau import register_grounding_dataset
+from opentau.configs.train import TrainPipelineConfig
+from opentau.datasets.grounding.base import GroundingDataset
+
+logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
+
+
+def _img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
+    """Convert a PIL Image to a normalized torch tensor.
+
+    Resizes the image and converts it from (H, W, C) to (C, H, W) format,
+    normalizing pixel values to [0, 1].
+
+    Args:
+        img: PIL Image to convert.
+        img_shape: Target image shape (height, width).
+
+    Returns:
+        Normalized tensor of shape (C, H, W) with values in [0, 1].
+    """
+    img = img.resize(img_shape, Image.BILINEAR)
+    return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
+
+
+def _filter_dataset(dataset: List) -> List:
+    """Filter dataset to only include samples with 'where' questions.
+
+    Args:
+        dataset: List of dataset samples.
+
+    Returns:
+        Filtered list containing only samples with 'where' in the question.
+    """
+    filtered_dataset = []
+    for sd in dataset:
+        if "where" in sd["question"]:
+            filtered_dataset.append(sd)
+
+    return filtered_dataset
+
+
+@register_grounding_dataset("cocoqa")
+class COCODataset(GroundingDataset):
+    """COCO-QA dataset for visual question answering and grounding tasks.
+
+    Loads the ThucPD/coco-qa-vi dataset from HuggingFace and filters it to
+    only include 'where' questions for spatial reasoning tasks.
+    """
+
+    def __init__(self, cfg: TrainPipelineConfig):
+        self.dataset = load_dataset("ThucPD/coco-qa-vi", split="train")
+
+        self.filtered_dataset = _filter_dataset(self.dataset)
+        super().__init__(cfg)
+
+    def __len__(self):
+        return len(self.filtered_dataset)
+
+    def _get_feature_mapping_key(self) -> str:
+        return "cocoqa"
+
+    def __getitem_helper__(self, item) -> dict:
+        """Get a COCO-QA dataset item.
+
+        Args:
+            item: Index of the item to retrieve.
+
+        Returns:
+            Dictionary with image, task, postfix, task_type, and prompt
+            extracted from the COCO-QA dataset sample.
+        """
+        sample = self.filtered_dataset[item]
+        img = sample["image"]
+
+        return {
+            "image": _img_to_normalized_tensor(img, self.resolution),
+            "task": "grounding",
+            "postfix": f"The answer is {sample['answer']}",
+            "task_type": "grounding",
+            "prompt": f'{{"task": "grounding", "description": "Using the Image, Answer the following question. \n {sample["question"]}"}}',
+        }
    }
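The `_filter_dataset` helper keeps only samples whose question contains the literal lowercase substring "where". A quick illustration with made-up samples (equivalent to the loop above, written as a comprehension):

    samples = [
        {"question": "where is the red mug?", "answer": "on the table"},
        {"question": "what color is the mug?", "answer": "red"},
        {"question": "Where is the cat?", "answer": "on the sofa"},  # capital 'W', no match
    ]

    filtered = [sd for sd in samples if "where" in sd["question"]]
    print(len(filtered))  # 1 -- only the lowercase 'where' question survives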
opentau/datasets/grounding/dummy.py
@@ -0,0 +1,101 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dummy grounding dataset for testing and development.
+
+This module provides a simple synthetic grounding dataset for testing the
+dataset infrastructure without requiring external data sources or network
+access. The dataset contains three predefined items: black, white, and gray
+images with corresponding question-answer pairs.
+
+The dataset cycles through the predefined items, making it useful for
+testing data loading pipelines, training loops, and debugging.
+
+Classes:
+    DummyGroundingDataset: Synthetic dataset class that provides simple test
+        data with configurable length.
+
+Constants:
+    _data: List of three predefined dataset items (black, white, gray images)
+        with corresponding tasks, postfixes, and prompts.
+
+Example:
+    Use dummy dataset for testing:
+    >>> from opentau.configs.default import DatasetConfig
+    >>> cfg = DatasetConfig(grounding="dummy")
+    >>> dataset = make_dataset(cfg, train_cfg)
+    >>> len(dataset)  # Returns 1000 by default
+"""
+
+import torch
+
+from opentau import register_grounding_dataset
+from opentau.configs.train import TrainPipelineConfig
+from opentau.datasets.grounding.base import GroundingDataset
+
+_data = [
+    {
+        "image": torch.zeros(3, 224, 224),
+        "task": "What do you see in the image?",
+        "postfix": "This is a black image",
+        "task_type": "qa",
+        "prompt": '{"task": "qa", "description": "What do you see in the image?"}',
+    },
+    {
+        "image": torch.ones(3, 224, 224),
+        "task": "What do you see in the image?",
+        "postfix": "This is a white image",
+        "task_type": "qa",
+        "prompt": '{"task": "qa", "description": "What do you see in the image?"}',
+    },
+    {
+        "image": torch.ones(3, 224, 224) * 0.5,
+        "task": "What do you see in the image?",
+        "postfix": "This is a gray image",
+        "task_type": "qa",
+        "prompt": '{"task": "qa", "description": "What do you see in the image?"}',
+    },
+]
+
+
+@register_grounding_dataset("dummy")
+class DummyGroundingDataset(GroundingDataset):
+    """Dummy grounding dataset for testing purposes.
+
+    Provides simple synthetic data with black, white, and gray images
+    for testing the dataset infrastructure.
+    """
+
+    def __init__(self, cfg: TrainPipelineConfig, length: int = 1000):
+        self.length = length
+        super().__init__(cfg)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem_helper__(self, item) -> dict:
+        """Get a dummy dataset item.
+
+        Cycles through a small set of predefined items (black, white, gray images).
+
+        Args:
+            item: Index of the item to retrieve.
+
+        Returns:
+            Dictionary with image, task, postfix, task_type, and prompt.
+        """
+        return _data[item % len(_data)]
+
+    def _get_feature_mapping_key(self) -> str:
+        return "dummy"
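Because `__getitem_helper__` indexes `_data` modulo its length, any index maps onto one of the three fixed items. A short usage sketch, assuming `train_cfg` is an already constructed `TrainPipelineConfig` (building one is outside the scope of this example):

    # train_cfg is assumed to be a valid TrainPipelineConfig instance.
    ds = DummyGroundingDataset(train_cfg, length=10)

    print(len(ds))  # 10
    # The raw helper cycles through the three fixed items:
    print(ds.__getitem_helper__(0)["postfix"])  # This is a black image
    print(ds.__getitem_helper__(4)["postfix"])  # 4 % 3 == 1 -> This is a white image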