opentau-0.1.0-py3-none-any.whl

This diff shows the contents of a publicly available package version that has been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,67 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Vision-language grounding datasets for multimodal learning.
+
+ This module provides datasets for training vision-language-action models on
+ image-text grounding tasks without requiring robot actions. Grounding datasets
+ are designed to help models learn visual understanding, spatial reasoning,
+ and language grounding capabilities that can be transferred to robotic tasks.
+
+ Grounding datasets differ from standard robot learning datasets in that they:
+ - Provide images, prompts, and responses but no robot actions or states
+ - Use zero-padding for state and action features to maintain compatibility
+ - Focus on visual question answering, spatial reasoning, and object grounding
+ - Enable training on large-scale vision-language data without robot hardware
+
+ The module uses a registration system where datasets are registered via the
+ `@register_grounding_dataset` decorator, making them available through the
+ `available_grounding_datasets` registry.
+
+ Available Datasets:
+     - CLEVR: Compositional Language and Elementary Visual Reasoning dataset
+       for visual question answering with synthetic scenes.
+     - COCO-QA: Visual question answering dataset based on COCO images,
+       filtered for spatial reasoning tasks.
+     - PIXMO: Pixel-level manipulation grounding dataset for object
+       localization and manipulation tasks.
+     - VSR: Visual Spatial Reasoning dataset for true/false statement
+       grounding about spatial relationships in images.
+     - dummy: Synthetic test dataset with simple black, white, and gray
+       images for testing infrastructure.
+
+ Classes:
+     GroundingDataset: Base class for all grounding datasets, providing
+         common functionality for metadata creation, data format conversion,
+         and zero-padding of state/action features.
+
+ Modules:
+     base: Base class and common functionality for grounding datasets.
+     clevr: CLEVR dataset implementation.
+     cocoqa: COCO-QA dataset implementation.
+     dummy: Dummy test dataset implementation.
+     pixmo: PIXMO dataset implementation.
+     vsr: VSR dataset implementation.
+
+ Example:
+     Use a grounding dataset in training configuration:
+     >>> from opentau.configs.default import DatasetConfig
+     >>> cfg = DatasetConfig(grounding="cocoqa")
+     >>> dataset = make_dataset(cfg, train_cfg)
+
+     Access available grounding datasets:
+     >>> from opentau import available_grounding_datasets
+     >>> print(list(available_grounding_datasets.keys()))
+     ['clevr', 'cocoqa', 'dummy', 'pixmo', 'vsr']
+ """
@@ -0,0 +1,154 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Base class for vision-language grounding datasets.
+
+ This module provides the base class for all grounding datasets, which are used
+ for training vision-language-action models on image-text tasks without robot
+ actions. Grounding datasets provide images, prompts, and responses for tasks
+ like visual question answering, spatial reasoning, and object grounding.
+
+ The base class handles common functionality including:
+ - Metadata creation with ImageNet statistics for images
+ - Zero-padding of state and action features for compatibility
+ - Standard data format conversion
+ - Integration with the dataset mixture system
+
+ Classes:
+     GroundingDataset: Abstract base class that all grounding datasets inherit
+         from. Provides common functionality for metadata creation, data format
+         conversion, and zero-padding of missing features.
+
+ Example:
+     Create a custom grounding dataset:
+     >>> from opentau import register_grounding_dataset
+     >>> @register_grounding_dataset("my_dataset")
+     ... class MyGroundingDataset(GroundingDataset):
+     ...     def __getitem_helper__(self, item):
+     ...         return {"image": ..., "task": ..., "postfix": ...}
+ """
+
+ from abc import abstractmethod
+ from copy import deepcopy
+ from typing import final
+
+ import torch
+
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.lerobot_dataset import CODEBASE_VERSION, BaseDataset, GroundingDatasetMetadata
+
+
+ class GroundingDataset(BaseDataset):
+     """Base class for vision-language grounding datasets.
+
+     Grounding datasets are used for training vision-language-action models on
+     image-text tasks without robot actions. They provide images, prompts, and
+     responses for grounding tasks.
+
+     Attributes:
+         num_frames: Number of frames in the dataset.
+         num_episodes: Number of episodes (always 1 for grounding datasets).
+         meta: Dataset metadata containing features and statistics.
+     """
+
+     def __init__(self, cfg: TrainPipelineConfig, num_frames: int = 1, num_episodes: int = 1):
+         super().__init__(cfg)
+         self.num_frames = num_frames
+         self.num_episodes = num_episodes
+         self.meta = self.create_meta()
+
+     def create_meta(self) -> GroundingDatasetMetadata:
+         """Create metadata for the grounding dataset.
+
+         Initializes metadata with ImageNet statistics for images and zero
+         statistics for state and actions (since grounding datasets don't have them).
+
+         Returns:
+             GroundingDatasetMetadata object with initialized info and stats.
+         """
+         from opentau.datasets.factory import IMAGENET_STATS
+
+         info = {
+             "codebase_version": CODEBASE_VERSION,
+             "features": {
+                 "camera0": {
+                     "dtype": "image",
+                     "shape": [3, 224, 224],
+                     "names": ["channel", "height", "width"],
+                 },
+             },
+         }
+         stats = {
+             "image": {
+                 "min": [[[0.0]], [[0.0]], [[0.0]]],
+                 "max": [[[1.0]], [[1.0]], [[1.0]]],
+                 "count": [len(self)],
+                 **deepcopy(IMAGENET_STATS),  # mean and std
+             },
+             "state": {
+                 "min": [0.0],
+                 "max": [0.0],
+                 "mean": [0.0],
+                 "std": [0.0],
+                 "count": [len(self)],
+             },
+             "actions": {
+                 "min": [0.0],
+                 "max": [0.0],
+                 "mean": [0.0],
+                 "std": [0.0],
+                 "count": [len(self)],
+             },
+         }
+         metadata = GroundingDatasetMetadata(info=info, stats=stats)
+         metadata.repo_id = self._get_feature_mapping_key()
+         return metadata
+
+     @abstractmethod
+     def __getitem_helper__(self, item) -> dict:
+         """Helper method to get a dataset item (to be implemented by subclasses).
+
+         Args:
+             item: Index of the item to retrieve.
+
+         Returns:
+             Dictionary containing the raw item data with keys like 'image',
+             'task', 'postfix', 'task_type', 'prompt'.
+         """
+         pass
+
+     @final
+     def __getitem__(self, item):
+         item = self.__getitem_helper__(item)
+
+         # Grounding datasets don't have states or actions. 0-padding is used.
+         item["state"] = torch.zeros(self.max_state_dim)
+         item["actions"] = torch.zeros(self.action_chunk, self.max_action_dim)
+         item["actions_is_pad"] = torch.ones(self.action_chunk, dtype=torch.bool)
+         item = self._to_standard_data_format(item)
+         item["return_bin_idx"] = torch.tensor(0, dtype=torch.long)
+         item["return_continuous"] = torch.tensor(0, dtype=torch.float32)
+         item["advantage"] = torch.tensor(0, dtype=torch.bfloat16)
+         return item
+
+     def _separate_image_in_time(self, item: dict) -> None:
+         """Separate images in time (no-op for grounding datasets).
+
+         Grounding datasets don't have temporal image sequences, so this is a no-op.
+
+         Args:
+             item: Item dictionary (unmodified).
+         """
+         # Grounding datasets have nothing to separate.
+         pass
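
Note: the zero-padding contract implemented in `__getitem__` above can be reproduced standalone. `max_state_dim`, `max_action_dim`, and `action_chunk` come from `BaseDataset` and the training config, which are not part of this diff, so the values below are placeholders.

    import torch

    # Placeholder dimensions; in the package these come from the training config.
    max_state_dim, max_action_dim, action_chunk = 32, 32, 50

    item = {"image": torch.zeros(3, 224, 224), "task": "grounding"}

    # Mirrors GroundingDataset.__getitem__: grounding data has no robot state or
    # actions, so both are zero-padded and every action step is marked as padding.
    item["state"] = torch.zeros(max_state_dim)
    item["actions"] = torch.zeros(action_chunk, max_action_dim)
    item["actions_is_pad"] = torch.ones(action_chunk, dtype=torch.bool)

    print(item["state"].shape, item["actions"].shape, bool(item["actions_is_pad"].all()))
    # torch.Size([32]) torch.Size([50, 32]) True
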
@@ -0,0 +1,110 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """CLEVR dataset for visual reasoning and grounding tasks.
+
+ This module provides the CLEVR (Compositional Language and Elementary Visual
+ Reasoning) dataset implementation for training vision-language models on
+ compositional visual reasoning tasks. The dataset contains synthetic scenes
+ with geometric objects and questions requiring compositional reasoning.
+
+ The dataset is loaded from HuggingFace and formatted for grounding tasks,
+ providing images, questions, and answers for visual reasoning.
+
+ Classes:
+     CLEVRDataset: Dataset class that loads and formats CLEVR data from
+         MMInstruction/Clevr_CoGenT_TrainA_70K_Complex on HuggingFace.
+
+ Functions:
+     _img_to_normalized_tensor: Convert PIL Image to normalized torch tensor
+         with channel-first format and [0, 1] normalization.
+
+ Example:
+     Use CLEVR dataset in training:
+     >>> from opentau.configs.default import DatasetConfig
+     >>> cfg = DatasetConfig(grounding="clevr")
+     >>> dataset = make_dataset(cfg, train_cfg)
+ """
+
+ import logging
+
+ import numpy as np
+ import torch
+ from datasets import load_dataset
+ from PIL import Image
+
+ from opentau import register_grounding_dataset
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.grounding.base import GroundingDataset
+
+ logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
+
+
+ def _img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
+     """Convert a PIL Image to a normalized torch tensor.
+
+     Resizes the image and converts it from (H, W, C) to (C, H, W) format,
+     normalizing pixel values to [0, 1].
+
+     Args:
+         img: PIL Image to convert.
+         img_shape: Target image shape (height, width).
+
+     Returns:
+         Normalized tensor of shape (C, H, W) with values in [0, 1].
+     """
+     img = img.resize(img_shape, Image.BILINEAR)
+
+     # pytorch uses (C, H, W) while PIL uses (H, W, C)
+     return torch.from_numpy(np.array(img))[:, :, :3].permute(2, 0, 1).float() / 255.0
+
+
+ @register_grounding_dataset("clevr")
+ class CLEVRDataset(GroundingDataset):
+     """CLEVR dataset for visual reasoning and grounding tasks.
+
+     Loads the MMInstruction/Clevr_CoGenT_TrainA_70K_Complex dataset from
+     HuggingFace and formats it for grounding tasks.
+     """
+
+     def __init__(self, cfg: TrainPipelineConfig, consecutive_bad_tolerance=100):
+         self.dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_70K_Complex", split="train")
+         super().__init__(cfg)
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def _get_feature_mapping_key(self) -> str:
+         return "clevr"
+
+     def __getitem_helper__(self, item) -> dict:
+         """Get a CLEVR dataset item.
+
+         Args:
+             item: Index of the item to retrieve.
+
+         Returns:
+             Dictionary with image, task, postfix, task_type, and prompt
+             extracted from the CLEVR dataset sample.
+         """
+         sample = self.dataset[item]
+         img = sample["image"]
+
+         return {
+             "image": _img_to_normalized_tensor(img, self.resolution),
+             "task": "grounding",
+             "postfix": f"The answer is {sample['solution'].split('<answer>')[1].split('</answer>')[0]}",
+             "task_type": "grounding",
+             "prompt": f'{{"task": "grounding", "description": "Using the Image, Answer the following question. \n {sample["problem"]}"}}',
+         }
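
Note: `_img_to_normalized_tensor` above documents `img_shape` as (height, width), while PIL's `Image.resize` expects (width, height); for the square resolutions used here the distinction is harmless. The conversion can be checked in isolation with the sketch below (the helper is reproduced locally rather than imported, since it is module-private).

    import numpy as np
    import torch
    from PIL import Image


    def img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
        # Same logic as the helper above: resize, keep the first three channels,
        # move channels first, and scale to [0, 1].
        img = img.resize(img_shape, Image.BILINEAR)
        return torch.from_numpy(np.array(img))[:, :, :3].permute(2, 0, 1).float() / 255.0


    rgb = Image.new("RGB", (640, 480), color=(255, 0, 0))  # synthetic solid-red image
    tensor = img_to_normalized_tensor(rgb, (224, 224))
    print(tensor.shape, float(tensor.min()), float(tensor.max()))
    # torch.Size([3, 224, 224]) 0.0 1.0
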
@@ -0,0 +1,130 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """COCO-QA dataset for visual question answering and grounding tasks.
+
+ This module provides the COCO-QA dataset implementation for training
+ vision-language models on visual question answering tasks. The dataset is
+ filtered to only include 'where' questions, focusing on spatial reasoning
+ tasks that are relevant for robotic manipulation.
+
+ The dataset is loaded from HuggingFace (ThucPD/coco-qa-vi) and automatically
+ filtered to retain only spatial reasoning questions.
+
+ Classes:
+     COCODataset: Dataset class that loads, filters, and formats COCO-QA data
+         for grounding tasks.
+
+ Functions:
+     _img_to_normalized_tensor: Convert PIL Image to normalized torch tensor
+         with channel-first format and [0, 1] normalization.
+     _filter_dataset: Filter dataset samples to only include 'where' questions
+         for spatial reasoning tasks.
+
+ Example:
+     Use COCO-QA dataset in training:
+     >>> from opentau.configs.default import DatasetConfig
+     >>> cfg = DatasetConfig(grounding="cocoqa")
+     >>> dataset = make_dataset(cfg, train_cfg)
+ """
+
+ import logging
+ from typing import List
+
+ import numpy as np
+ import torch
+ from datasets import load_dataset
+ from PIL import Image
+
+ from opentau import register_grounding_dataset
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.grounding.base import GroundingDataset
+
+ logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
+
+
+ def _img_to_normalized_tensor(img: Image.Image, img_shape: tuple) -> torch.Tensor:
+     """Convert a PIL Image to a normalized torch tensor.
+
+     Resizes the image and converts it from (H, W, C) to (C, H, W) format,
+     normalizing pixel values to [0, 1].
+
+     Args:
+         img: PIL Image to convert.
+         img_shape: Target image shape (height, width).
+
+     Returns:
+         Normalized tensor of shape (C, H, W) with values in [0, 1].
+     """
+     img = img.resize(img_shape, Image.BILINEAR)
+     return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
+
+
+ def _filter_dataset(dataset: List) -> List:
+     """Filter dataset to only include samples with 'where' questions.
+
+     Args:
+         dataset: List of dataset samples.
+
+     Returns:
+         Filtered list containing only samples with 'where' in the question.
+     """
+     filtered_dataset = []
+     for sd in dataset:
+         if "where" in sd["question"]:
+             filtered_dataset.append(sd)
+
+     return filtered_dataset
+
+
+ @register_grounding_dataset("cocoqa")
+ class COCODataset(GroundingDataset):
+     """COCO-QA dataset for visual question answering and grounding tasks.
+
+     Loads the ThucPD/coco-qa-vi dataset from HuggingFace and filters it to
+     only include 'where' questions for spatial reasoning tasks.
+     """
+
+     def __init__(self, cfg: TrainPipelineConfig):
+         self.dataset = load_dataset("ThucPD/coco-qa-vi", split="train")
+
+         self.filtered_dataset = _filter_dataset(self.dataset)
+         super().__init__(cfg)
+
+     def __len__(self):
+         return len(self.filtered_dataset)
+
+     def _get_feature_mapping_key(self) -> str:
+         return "cocoqa"
+
+     def __getitem_helper__(self, item) -> dict:
+         """Get a COCO-QA dataset item.
+
+         Args:
+             item: Index of the item to retrieve.
+
+         Returns:
+             Dictionary with image, task, postfix, task_type, and prompt
+             extracted from the COCO-QA dataset sample.
+         """
+         sample = self.filtered_dataset[item]
+         img = sample["image"]
+
+         return {
+             "image": _img_to_normalized_tensor(img, self.resolution),
+             "task": "grounding",
+             "postfix": f"The answer is {sample['answer']}",
+             "task_type": "grounding",
+             "prompt": f'{{"task": "grounding", "description": "Using the Image, Answer the following question. \n {sample["question"]}"}}',
+         }
@@ -0,0 +1,101 @@
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Dummy grounding dataset for testing and development.
+
+ This module provides a simple synthetic grounding dataset for testing the
+ dataset infrastructure without requiring external data sources or network
+ access. The dataset contains three predefined items: black, white, and gray
+ images with corresponding question-answer pairs.
+
+ The dataset cycles through the predefined items, making it useful for
+ testing data loading pipelines, training loops, and debugging.
+
+ Classes:
+     DummyGroundingDataset: Synthetic dataset class that provides simple test
+         data with configurable length.
+
+ Constants:
+     _data: List of three predefined dataset items (black, white, gray images)
+         with corresponding tasks, postfixes, and prompts.
+
+ Example:
+     Use dummy dataset for testing:
+     >>> from opentau.configs.default import DatasetConfig
+     >>> cfg = DatasetConfig(grounding="dummy")
+     >>> dataset = make_dataset(cfg, train_cfg)
+     >>> len(dataset)  # Returns 1000 by default
+ """
+
+ import torch
+
+ from opentau import register_grounding_dataset
+ from opentau.configs.train import TrainPipelineConfig
+ from opentau.datasets.grounding.base import GroundingDataset
+
+ _data = [
+     {
+         "image": torch.zeros(3, 224, 224),
+         "task": "What do you see in the image?",
+         "postfix": "This is a black image",
+         "task_type": "qa",
+         "prompt": '{"task": "qa", "description": "What do you see in the image?"}',
+     },
+     {
+         "image": torch.ones(3, 224, 224),
+         "task": "What do you see in the image?",
+         "postfix": "This is a white image",
+         "task_type": "qa",
+         "prompt": '{"task": "qa", "description": "What do you see in the image?"}',
+     },
+     {
+         "image": torch.ones(3, 224, 224) * 0.5,
+         "task": "What do you see in the image?",
+         "postfix": "This is a gray image",
+         "task_type": "qa",
+         "prompt": '{"task": "qa", "description": "What do you see in the image?"}',
+     },
+ ]
+
+
+ @register_grounding_dataset("dummy")
+ class DummyGroundingDataset(GroundingDataset):
+     """Dummy grounding dataset for testing purposes.
+
+     Provides simple synthetic data with black, white, and gray images
+     for testing the dataset infrastructure.
+     """
+
+     def __init__(self, cfg: TrainPipelineConfig, length: int = 1000):
+         self.length = length
+         super().__init__(cfg)
+
+     def __len__(self):
+         return self.length
+
+     def __getitem_helper__(self, item) -> dict:
+         """Get a dummy dataset item.
+
+         Cycles through a small set of predefined items (black, white, gray images).
+
+         Args:
+             item: Index of the item to retrieve.
+
+         Returns:
+             Dictionary with image, task, postfix, task_type, and prompt.
+         """
+         return _data[item % len(_data)]
+
+     def _get_feature_mapping_key(self) -> str:
+         return "dummy"
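
Note: combining the docstring examples in these files, a minimal smoke test of the grounding pipeline might look like the sketch below. It assumes `make_dataset` is exposed by `opentau.datasets.factory` (listed in this wheel but not shown in this diff) and that a `TrainPipelineConfig` has already been constructed elsewhere; both are assumptions rather than facts established by this diff.

    from torch.utils.data import DataLoader

    from opentau.configs.default import DatasetConfig
    # Assumed import path, based on the file listing above.
    from opentau.datasets.factory import make_dataset


    def grounding_smoke_test(train_cfg):
        """train_cfg: an already-constructed TrainPipelineConfig."""
        cfg = DatasetConfig(grounding="dummy")  # synthetic data, no downloads needed
        dataset = make_dataset(cfg, train_cfg)  # same call shape as the docstring examples
        batch = next(iter(DataLoader(dataset, batch_size=2)))
        print(sorted(batch.keys()))  # expect image, state, actions, actions_is_pad, ...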