opentau 0.1.1__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. {opentau-0.1.1/src/opentau.egg-info → opentau-0.1.2}/PKG-INFO +2 -2
  2. {opentau-0.1.1 → opentau-0.1.2}/pyproject.toml +3 -2
  3. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/lerobot_dataset.py +7 -9
  4. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/launch.py +5 -0
  5. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/visualize_dataset.py +46 -30
  6. opentau-0.1.2/src/opentau/utils/transformers_patch.py +276 -0
  7. {opentau-0.1.1 → opentau-0.1.2/src/opentau.egg-info}/PKG-INFO +2 -2
  8. {opentau-0.1.1 → opentau-0.1.2}/src/opentau.egg-info/SOURCES.txt +0 -1
  9. {opentau-0.1.1 → opentau-0.1.2}/src/opentau.egg-info/entry_points.txt +1 -0
  10. {opentau-0.1.1 → opentau-0.1.2}/src/opentau.egg-info/requires.txt +1 -1
  11. opentau-0.1.1/src/opentau/scripts/visualize_dataset_html.py +0 -507
  12. opentau-0.1.1/src/opentau/utils/transformers_patch.py +0 -48
  13. {opentau-0.1.1 → opentau-0.1.2}/LICENSE +0 -0
  14. {opentau-0.1.1 → opentau-0.1.2}/README.md +0 -0
  15. {opentau-0.1.1 → opentau-0.1.2}/setup.cfg +0 -0
  16. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/__init__.py +0 -0
  17. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/__version__.py +0 -0
  18. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/__init__.py +0 -0
  19. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/default.py +0 -0
  20. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/libero.py +0 -0
  21. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/parser.py +0 -0
  22. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/policies.py +0 -0
  23. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/reward.py +0 -0
  24. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/train.py +0 -0
  25. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/configs/types.py +0 -0
  26. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/constants.py +0 -0
  27. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/__init__.py +0 -0
  28. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/backward_compatibility.py +0 -0
  29. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/compute_stats.py +0 -0
  30. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/dataset_mixture.py +0 -0
  31. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/factory.py +0 -0
  32. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/__init__.py +0 -0
  33. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/base.py +0 -0
  34. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/clevr.py +0 -0
  35. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/cocoqa.py +0 -0
  36. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/dummy.py +0 -0
  37. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/pixmo.py +0 -0
  38. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/grounding/vsr.py +0 -0
  39. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/image_writer.py +0 -0
  40. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/online_buffer.py +0 -0
  41. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/push_dataset_to_hub/utils.py +0 -0
  42. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/sampler.py +0 -0
  43. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/standard_data_format_mapping.py +0 -0
  44. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/transforms.py +0 -0
  45. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/utils.py +0 -0
  46. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +0 -0
  47. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/v2/convert_dataset_v1_to_v2.py +0 -0
  48. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/v21/_remove_language_instruction.py +0 -0
  49. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +0 -0
  50. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/v21/convert_dataset_v20_to_v21.py +0 -0
  51. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/v21/convert_stats.py +0 -0
  52. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/datasets/video_utils.py +0 -0
  53. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/envs/__init__.py +0 -0
  54. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/envs/configs.py +0 -0
  55. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/envs/factory.py +0 -0
  56. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/envs/libero.py +0 -0
  57. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/envs/utils.py +0 -0
  58. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/optim/__init__.py +0 -0
  59. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/optim/factory.py +0 -0
  60. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/optim/optimizers.py +0 -0
  61. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/optim/schedulers.py +0 -0
  62. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/planner/__init__.py +0 -0
  63. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/planner/high_level_planner.py +0 -0
  64. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/planner/utils/memory.py +0 -0
  65. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/planner/utils/utils.py +0 -0
  66. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/__init__.py +0 -0
  67. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/factory.py +0 -0
  68. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/normalize.py +0 -0
  69. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi0/__init__.py +0 -0
  70. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi0/configuration_pi0.py +0 -0
  71. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi0/modeling_pi0.py +0 -0
  72. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi0/paligemma_with_expert.py +0 -0
  73. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi05/__init__.py +0 -0
  74. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi05/configuration_pi05.py +0 -0
  75. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi05/modeling_pi05.py +0 -0
  76. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pi05/paligemma_with_expert.py +0 -0
  77. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/pretrained.py +0 -0
  78. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/utils.py +0 -0
  79. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/value/__init__.py +0 -0
  80. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/value/configuration_value.py +0 -0
  81. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/value/modeling_value.py +0 -0
  82. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/value/reward.py +0 -0
  83. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/policies/value/siglip_gemma.py +0 -0
  84. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/actions_mse_loss.py +0 -0
  85. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/bin_to_safetensors.py +0 -0
  86. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/compute_max_token_length.py +0 -0
  87. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/display_sys_info.py +0 -0
  88. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/download_libero_benchmarks.py +0 -0
  89. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/eval.py +0 -0
  90. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/export_to_onnx.py +0 -0
  91. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/fake_tensor_training.py +0 -0
  92. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/get_advantage_and_percentiles.py +0 -0
  93. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/high_level_planner_inference.py +0 -0
  94. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/inference.py +0 -0
  95. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/libero_simulation_parallel.py +0 -0
  96. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/libero_simulation_sequential.py +0 -0
  97. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/nav_high_level_planner_inference.py +0 -0
  98. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/train.py +0 -0
  99. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/scripts/zero_to_fp32.py +0 -0
  100. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/__init__.py +0 -0
  101. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/accelerate_utils.py +0 -0
  102. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/benchmark.py +0 -0
  103. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/fake_tensor.py +0 -0
  104. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/hub.py +0 -0
  105. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/import_utils.py +0 -0
  106. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/io_utils.py +0 -0
  107. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/libero.py +0 -0
  108. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/libero_dataset_recorder.py +0 -0
  109. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/logging_utils.py +0 -0
  110. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/monkey_patch.py +0 -0
  111. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/random_utils.py +0 -0
  112. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/train_utils.py +0 -0
  113. {opentau-0.1.1 → opentau-0.1.2}/src/opentau/utils/utils.py +0 -0
  114. {opentau-0.1.1 → opentau-0.1.2}/src/opentau.egg-info/dependency_links.txt +0 -0
  115. {opentau-0.1.1 → opentau-0.1.2}/src/opentau.egg-info/top_level.txt +0 -0
  116. {opentau-0.1.1 → opentau-0.1.2}/tests/test_available.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opentau
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
5
5
  Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
6
6
  License: Apache-2.0
@@ -52,7 +52,7 @@ Requires-Dist: onnxruntime>=1.22.1; sys_platform == "darwin" or platform_machine
52
52
  Requires-Dist: onnxruntime-gpu>=1.22.0; (sys_platform == "linux" and platform_machine == "x86_64") or (sys_platform == "win32" and (platform_machine == "AMD64" or platform_machine == "x86_64"))
53
53
  Requires-Dist: onnxscript>=0.3.1
54
54
  Requires-Dist: onnx-ir>=0.1.4
55
- Requires-Dist: opentau-transformers==4.53.3
55
+ Requires-Dist: transformers==4.53.3
56
56
  Requires-Dist: scipy>=1.15.2
57
57
  Requires-Dist: pytest>=8.1.0
58
58
  Requires-Dist: pytest-cov>=5.0.0
@@ -20,7 +20,7 @@ huggingface = "https://huggingface.co/TensorAuto"
20
20
 
21
21
  [project]
22
22
  name = "opentau"
23
- version = "0.1.1"
23
+ version = "0.1.2"
24
24
  description = "OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch"
25
25
  authors = [
26
26
  { name = "Shuheng Liu", email = "wish1104@icloud.com" },
@@ -76,7 +76,7 @@ dependencies = [
76
76
  "onnxruntime-gpu>=1.22.0 ; ((sys_platform == 'linux' and platform_machine == 'x86_64') or (sys_platform == 'win32' and (platform_machine == 'AMD64' or platform_machine == 'x86_64'))) ",
77
77
  "onnxscript>=0.3.1",
78
78
  "onnx-ir>=0.1.4",
79
- "opentau-transformers==4.53.3",
79
+ "transformers==4.53.3",
80
80
  "scipy>=1.15.2",
81
81
  "pytest>=8.1.0",
82
82
  "pytest-cov>=5.0.0",
@@ -92,6 +92,7 @@ dependencies = [
92
92
  opentau-train = "opentau.scripts.launch:train"
93
93
  opentau-eval = "opentau.scripts.launch:eval"
94
94
  opentau-export = "opentau.scripts.launch:export"
95
+ opentau-dataset-viz = "opentau.scripts.launch:visualize"
95
96
 
96
97
  [project.optional-dependencies]
97
98
  dev = ["pre-commit>=3.7.0",
@@ -633,7 +633,9 @@ class BaseDataset(torch.utils.data.Dataset):
633
633
  For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad": [False, True] } will become
634
634
  {
635
635
  "image_key": torch.zeros(3, 224, 224),
636
+ "image_key_local": torch.zeros(3, 224, 224),
636
637
  "image_key_is_pad": True,
638
+ "image_key_local_is_pad": False,
637
639
  }.
638
640
  """
639
641
  raise NotImplementedError
@@ -1787,16 +1789,12 @@ class LeRobotDataset(BaseDataset):
1787
1789
  cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
1788
1790
  for k in cam_keys:
1789
1791
  images = item.pop(k)
1790
- assert len(images) == 2, (
1791
- f"{k} in {self.__class__} is expected to have length 2, got shape={images.shape}"
1792
- )
1793
- item[k + "_local"], item[k] = images
1792
+ if len(images) == 2:
1793
+ item[k + "_local"], item[k] = images
1794
1794
 
1795
- pads = item.pop(k + "_is_pad")
1796
- assert len(pads) == 2, (
1797
- f"{k} in {self.__class__} is expected to have length 2, got shape={pads.shape}"
1798
- )
1799
- item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
1795
+ pads = item.get(k + "_is_pad")
1796
+ if hasattr(pads, "__len__") and len(pads) == 2:
1797
+ item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
1800
1798
 
1801
1799
  @staticmethod
1802
1800
  def compute_delta_params(
@@ -21,6 +21,7 @@ from types import ModuleType
21
21
  import opentau.scripts.eval as eval_script
22
22
  import opentau.scripts.export_to_onnx as export_script
23
23
  import opentau.scripts.train as train_script
24
+ import opentau.scripts.visualize_dataset as visualize_script
24
25
 
25
26
 
26
27
  def launch(script_module: ModuleType, description: str, use_accelerate: bool = True):
@@ -77,3 +78,7 @@ def eval():
77
78
 
78
79
  def export():
79
80
  launch(export_script, "Launch OpenTau ONNX export", use_accelerate=False)
81
+
82
+
83
def visualize():
    """Console-script entry point registered as ``opentau-dataset-viz``.

    Runs the dataset visualization script through ``launch`` without accelerate.
    """
    launch(
        script_module=visualize_script,
        description="Launch OpenTau visualization",
        use_accelerate=False,
    )
@@ -14,7 +14,7 @@
14
14
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
- """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
17
+ """Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
18
18
 
19
19
  Note: The last frame of the episode doesn't always correspond to a final state.
20
20
  That's because our datasets are composed of transition from state to state up to
@@ -30,34 +30,21 @@ Examples:
30
30
 
31
31
  - Visualize data stored on a local machine:
32
32
  ```
33
- local$ python src/opentau/scripts/visualize_dataset.py \
34
- --repo-id lerobot/pusht \
35
- --episode-index 0
33
+ local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0
36
34
  ```
37
35
 
38
36
  - Visualize data stored on a distant machine with a local viewer:
39
37
  ```
40
- distant$ python src/opentau/scripts/visualize_dataset.py \
41
- --repo-id lerobot/pusht \
42
- --episode-index 0 \
43
- --save 1 \
44
- --output-dir path/to/directory
38
+ distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --save 1 --output-dir path/to/directory
45
39
 
46
40
  local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
47
41
  local$ rerun lerobot_pusht_episode_0.rrd
48
42
  ```
49
43
 
50
44
  - Visualize data stored on a distant machine through streaming:
51
- (You need to forward the websocket port to the distant machine, with
52
- `ssh -L 9087:localhost:9087 username@remote-host`)
53
45
  ```
54
- distant$ python src/opentau/scripts/visualize_dataset.py \
55
- --repo-id lerobot/pusht \
56
- --episode-index 0 \
57
- --mode distant \
58
- --ws-port 9087
59
46
 
60
- local$ rerun ws://localhost:9087
47
+ distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode distant --web-port 9090
61
48
  ```
62
49
 
63
50
  """
@@ -75,8 +62,34 @@ import torch
75
62
  import torch.utils.data
76
63
  import tqdm
77
64
 
65
+ from opentau.configs.default import DatasetMixtureConfig, WandBConfig
66
+ from opentau.configs.train import TrainPipelineConfig
78
67
  from opentau.datasets.lerobot_dataset import LeRobotDataset
79
- from opentau.scripts.visualize_dataset_html import create_mock_train_config
68
+
69
+
70
+ def create_mock_train_config() -> TrainPipelineConfig:
71
+ """Create a mock TrainPipelineConfig for dataset visualization.
72
+
73
+ Returns:
74
+ TrainPipelineConfig: A mock config with default values.
75
+ """
76
+ return TrainPipelineConfig(
77
+ dataset_mixture=DatasetMixtureConfig(), # Will be set by the dataset
78
+ resolution=(224, 224),
79
+ num_cams=2,
80
+ max_state_dim=32,
81
+ max_action_dim=32,
82
+ action_chunk=50,
83
+ loss_weighting={"MSE": 1, "CE": 1},
84
+ num_workers=4,
85
+ batch_size=8,
86
+ steps=100_000,
87
+ log_freq=200,
88
+ save_checkpoint=True,
89
+ save_freq=20_000,
90
+ use_policy_training_preset=True,
91
+ wandb=WandBConfig(),
92
+ )
80
93
 
81
94
 
82
95
  class EpisodeSampler(torch.utils.data.Sampler):
@@ -108,7 +121,6 @@ def visualize_dataset(
108
121
  num_workers: int = 0,
109
122
  mode: str = "local",
110
123
  web_port: int = 9090,
111
- ws_port: int = 9087,
112
124
  save: bool = False,
113
125
  output_dir: Path | None = None,
114
126
  ) -> Path | None:
@@ -142,7 +154,7 @@ def visualize_dataset(
142
154
  gc.collect()
143
155
 
144
156
  if mode == "distant":
145
- rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
157
+ rr.serve_web_viewer(open_browser=False, web_port=web_port)
146
158
 
147
159
  logging.info("Logging to Rerun")
148
160
 
@@ -194,7 +206,7 @@ def visualize_dataset(
194
206
  print("Ctrl-C received. Exiting.")
195
207
 
196
208
 
197
- def main():
209
+ def parse_args() -> dict:
198
210
  parser = argparse.ArgumentParser()
199
211
 
200
212
  parser.add_argument(
@@ -250,12 +262,6 @@ def main():
250
262
  default=9090,
251
263
  help="Web port for rerun.io when `--mode distant` is set.",
252
264
  )
253
- parser.add_argument(
254
- "--ws-port",
255
- type=int,
256
- default=9087,
257
- help="Web socket port for rerun.io when `--mode distant` is set.",
258
- )
259
265
  parser.add_argument(
260
266
  "--save",
261
267
  type=int,
@@ -279,15 +285,25 @@ def main():
279
285
  )
280
286
 
281
287
  args = parser.parse_args()
282
- kwargs = vars(args)
288
+ return vars(args)
289
+
290
+
291
def main():
    """CLI entry point: parse arguments, load the dataset, and pass it to ``visualize_dataset``."""
    cli_kwargs = parse_args()
    # These three options configure dataset loading; everything left over is
    # forwarded unchanged to ``visualize_dataset``.
    repo_id, root, tolerance_s = (cli_kwargs.pop(key) for key in ("repo_id", "root", "tolerance_s"))

    logging.info("Loading dataset")
    dataset = LeRobotDataset(
        create_mock_train_config(),
        repo_id,
        root=root,
        tolerance_s=tolerance_s,
        standardize=False,
    )
    visualize_dataset(dataset, **cli_kwargs)
291
307
 
292
308
 
293
309
  if __name__ == "__main__":
@@ -0,0 +1,276 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module for patching transformers
16
+
17
+ Most patches come from the branch fix/lerobot-openpi
18
+ """
19
+
20
+ from typing import Optional, Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+ from transformers.models.gemma import modeling_gemma
25
+ from transformers.models.gemma.configuration_gemma import GemmaConfig
26
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaModel
27
+
28
+ # Monkey patch __init__ of GemmaConfig to fix or modify its behavior as needed.
29
+
30
_original_gemma_config_init = GemmaConfig.__init__


def patched_gemma_config_init(
    self, *args, use_adarms: bool = False, adarms_cond_dim: Optional[int] = None, **kwargs
):
    """Run the stock GemmaConfig initializer, then attach ADARMS settings.

    ``use_adarms`` and ``adarms_cond_dim`` are keyword-only and consumed here, so
    the original initializer never receives them.

    Args:
        self: The GemmaConfig instance.
        *args: Positional arguments forwarded to the original initializer.
        use_adarms: Whether to use Adaptive RMS normalization.
        adarms_cond_dim: Width of the ADARMS conditioning vector. When
            ``use_adarms`` is set and this is None, ``hidden_size`` is used.
        **kwargs: Remaining keyword arguments forwarded to the original initializer.
    """
    _original_gemma_config_init(self, *args, **kwargs)

    if use_adarms and adarms_cond_dim is None:
        # hidden_size has been populated by the original initializer above.
        adarms_cond_dim = self.hidden_size

    self.use_adarms = use_adarms
    self.adarms_cond_dim = adarms_cond_dim


# Install the patch; ``_original_gemma_config_init`` keeps a handle on the stock initializer.
GemmaConfig.__init__ = patched_gemma_config_init
59
+
60
+
61
+ # --- Modeling Patches ---
62
+
63
+
64
def _gated_residual(x, y, gate):
    """Add an update ``y`` onto a residual ``x``, optionally scaled by ``gate``.

    Args:
        x: Residual input, or None.
        y: Update to add, or None.
        gate: Optional multiplier applied to ``y`` before the addition.

    Returns:
        None when both ``x`` and ``y`` are None; the non-None operand when only
        one is present; otherwise ``x + y`` (no gate) or ``x + y * gate``.
    """
    if x is None or y is None:
        # At most one operand is present: pass it through (None if neither is).
        return y if x is None else x
    return x + y if gate is None else x + y * gate
83
+
84
+
85
# Expose the helper as an attribute of the upstream gemma module so code that looks it
# up through ``modeling_gemma`` finds it (presumably opentau's own layer code — TODO confirm).
modeling_gemma._gated_residual = _gated_residual
86
+
87
+
88
class PatchedGemmaRMSNorm(nn.Module):
    """RMS normalization with optional adaptive support (ADARMS).

    The mode is fixed at construction time:

    * ``cond_dim is None``: plain Gemma-style RMSNorm. A learned ``weight``
      (initialized to zeros) is applied as ``(1 + weight)``.
    * ``cond_dim`` given: adaptive RMSNorm. A ``dense`` layer maps the
      conditioning tensor to per-feature ``scale``, ``shift`` and ``gate``.
      This mode has no ``weight`` parameter.

    ``forward`` always returns ``(normed, gate)``; ``gate`` is None in plain mode.
    """

    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
        """Initializes the PatchedGemmaRMSNorm.

        Args:
            dim: Size of the last (feature) axis being normalized.
            eps: Added to the mean square before ``rsqrt`` for numerical stability.
            cond_dim: Size of the conditioning vector; enables ADARMS when not None.
        """
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.cond_dim = cond_dim

        if cond_dim is not None:
            # One projection yields scale, shift and gate (3 * dim outputs).
            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
            # Initialize with zeros (matches source implementation)
            nn.init.zeros_(self.dense.weight)
        else:
            self.weight = nn.Parameter(torch.zeros(dim))
            self.dense = None

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Applies RMS normalization over the last axis.

        Args:
            x: The input tensor.

        Returns:
            The normalized tensor (float32, because ``x`` is multiplied by a float32 factor).
        """
        # Mean square over the feature axis, computed in float32.
        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
        # Multiplying by the float32 rsqrt promotes half-precision ``x`` to float32.
        return x * torch.rsqrt(var + self.eps)

    def forward(
        self, x: torch.Tensor, cond: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass of the normalization layer.

        Args:
            x: Input tensor; normalization is over its last axis.
            cond: Conditioning tensor whose last axis is ``cond_dim``. Required in
                adaptive mode; ignored in plain mode.

        Returns:
            ``(normed, gate)`` cast back to ``x``'s dtype. ``gate`` is None in plain mode.

        Raises:
            ValueError: In adaptive mode, if ``cond`` is missing or its last dimension
                does not match ``cond_dim``.
        """
        dtype = x.dtype  # original dtype, could be half-precision
        normed_inputs = self._norm(x)

        if self.dense is None:
            # Plain RMSNorm; any ``cond`` passed in is ignored.
            normed_inputs = normed_inputs * (1.0 + self.weight.float())
            return normed_inputs.to(dtype), None

        if cond is None:
            # Adaptive norms have no ``weight``, so there is no plain fallback to use.
            raise ValueError(
                f"{type(self).__name__} was built with cond_dim={self.cond_dim} "
                "(adaptive mode) and requires a `cond` tensor"
            )
        if cond.shape[-1] != self.cond_dim:
            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")

        modulation = self.dense(cond)
        # A per-sample cond ([batch, cond_dim]) must broadcast over the sequence axis of a
        # [batch, seq, features] input. Other cond ranks (e.g. per-token [batch, seq, cond_dim],
        # or an unbatched [cond_dim]) already broadcast correctly and are left untouched.
        if x.dim() == 3 and modulation.dim() == 2:
            modulation = modulation.unsqueeze(1)

        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)

        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)

        return normed_inputs.to(dtype), gate.to(dtype)

    def extra_repr(self) -> str:
        """Returns the extra representation of the module."""
        # Built from ``self.dim`` rather than ``self.weight.shape``: adaptive norms have no
        # ``weight``, and reading it made printing such a module raise AttributeError.
        repr_str = f"{(self.dim,)}, eps={self.eps}"
        if self.dense is not None:
            repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
        return repr_str
174
+
175
+
176
# Apply patches
# Rebind the upstream module attribute so later ``modeling_gemma.GemmaRMSNorm(...)`` calls
# build the ADARMS-capable norm. Names imported from the module before this line still
# refer to the original class.
modeling_gemma.GemmaRMSNorm = PatchedGemmaRMSNorm
178
+
179
+
180
def patched_gemma_decoder_layer_init(self, config: GemmaConfig, layer_idx: int):
    """Construct a Gemma decoder layer whose two RMSNorms may be adaptive (ADARMS).

    Args:
        self: The GemmaDecoderLayer being initialized.
        config: Model config. ``use_adarms`` / ``adarms_cond_dim`` (attributes added by
            the GemmaConfig patch) select adaptive norms.
        layer_idx: Index of this layer in the decoder stack, forwarded to attention.
    """
    modeling_gemma.GradientCheckpointingLayer.__init__(self)
    self.hidden_size = config.hidden_size

    self.self_attn = modeling_gemma.GemmaAttention(config=config, layer_idx=layer_idx)

    self.mlp = modeling_gemma.GemmaMLP(config)

    use_adarms = getattr(config, "use_adarms", False)
    norm_cond_dim = getattr(config, "adarms_cond_dim", None) if use_adarms else None

    def make_norm():
        # Both norms share identical settings.
        return modeling_gemma.GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=norm_cond_dim)

    # Created in the original order (input, then post-attention) so registration is unchanged.
    self.input_layernorm = make_norm()
    self.post_attention_layernorm = make_norm()


modeling_gemma.GemmaDecoderLayer.__init__ = patched_gemma_decoder_layer_init
205
+
206
+
207
def patched_gemma_model_init(self, config: GemmaConfig):
    """Build a GemmaModel whose decoder layers and final norm may use ADARMS.

    Args:
        self: The GemmaModel instance.
        config: The configuration object; ``use_adarms`` / ``adarms_cond_dim`` control
            whether the final norm is adaptive.
    """
    modeling_gemma.GemmaPreTrainedModel.__init__(self, config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    # Submodules keep their original construction order (embedding, layers, norm, rotary),
    # which determines parameter registration order and random-init consumption.
    self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
    self.layers = nn.ModuleList(
        modeling_gemma.GemmaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
    )

    use_adarms = getattr(config, "use_adarms", False)
    final_cond_dim = getattr(config, "adarms_cond_dim", None) if use_adarms else None
    self.norm = modeling_gemma.GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=final_cond_dim)
    self.rotary_emb = modeling_gemma.GemmaRotaryEmbedding(config=config)
    self.gradient_checkpointing = False

    # Initialize weights and apply final processing
    self.post_init()


modeling_gemma.GemmaModel.__init__ = patched_gemma_model_init
233
+
234
+
235
def patched_gemma_pretrained_model_init_weights(self, module: nn.Module):
    """Initializes the weights of a single submodule of a GemmaPreTrainedModel.

    Linear and Embedding weights are drawn from ``N(0, config.initializer_range)``;
    Linear biases and the Embedding padding row are zeroed. Plain (non-adaptive)
    RMSNorm weights are zeroed, which is the identity scale for this norm.

    Args:
        self: The GemmaPreTrainedModel instance.
        module: The module to initialize.
    """
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, modeling_gemma.GemmaRMSNorm):
        # The patched norm scales by ``(1 + weight)`` and creates ``weight`` as zeros, so
        # zero is the identity scale; filling with 1.0 doubled activations at init.
        # Adaptive norms have no ``weight`` (they are driven by ``dense``), hence the check.
        if hasattr(module, "weight"):
            module.weight.data.zero_()


modeling_gemma.GemmaPreTrainedModel._init_weights = patched_gemma_pretrained_model_init_weights
257
+
258
+
259
def patched_paligemma_model_get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Encode images with the vision tower and project them with the multimodal projector.

    The projector output is returned as-is; no further rescaling is applied here.
    NOTE(review): presumably this is the reason for overriding the upstream method —
    confirm against the installed transformers version.

    Args:
        self: The PaliGemmaModel instance.
        pixel_values: The tensors corresponding to the input images.
            Shape: (batch_size, channels, height, width).

    Returns:
        Image feature tensor of shape (num_images, image_length, embed_dim).
    """
    vision_hidden = self.vision_tower(pixel_values).last_hidden_state
    return self.multi_modal_projector(vision_hidden)
274
+
275
+
276
# Install the override on the class, so instances (and subclasses that do not override
# ``get_image_features`` themselves) use the patched version.
PaliGemmaModel.get_image_features = patched_paligemma_model_get_image_features
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opentau
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
5
5
  Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
6
6
  License: Apache-2.0
@@ -52,7 +52,7 @@ Requires-Dist: onnxruntime>=1.22.1; sys_platform == "darwin" or platform_machine
52
52
  Requires-Dist: onnxruntime-gpu>=1.22.0; (sys_platform == "linux" and platform_machine == "x86_64") or (sys_platform == "win32" and (platform_machine == "AMD64" or platform_machine == "x86_64"))
53
53
  Requires-Dist: onnxscript>=0.3.1
54
54
  Requires-Dist: onnx-ir>=0.1.4
55
- Requires-Dist: opentau-transformers==4.53.3
55
+ Requires-Dist: transformers==4.53.3
56
56
  Requires-Dist: scipy>=1.15.2
57
57
  Requires-Dist: pytest>=8.1.0
58
58
  Requires-Dist: pytest-cov>=5.0.0
@@ -93,7 +93,6 @@ src/opentau/scripts/libero_simulation_sequential.py
93
93
  src/opentau/scripts/nav_high_level_planner_inference.py
94
94
  src/opentau/scripts/train.py
95
95
  src/opentau/scripts/visualize_dataset.py
96
- src/opentau/scripts/visualize_dataset_html.py
97
96
  src/opentau/scripts/zero_to_fp32.py
98
97
  src/opentau/utils/__init__.py
99
98
  src/opentau/utils/accelerate_utils.py
@@ -1,4 +1,5 @@
1
1
  [console_scripts]
2
+ opentau-dataset-viz = opentau.scripts.launch:visualize
2
3
  opentau-eval = opentau.scripts.launch:eval
3
4
  opentau-export = opentau.scripts.launch:export
4
5
  opentau-train = opentau.scripts.launch:train
@@ -27,7 +27,7 @@ scikit-learn>=1.7.1
27
27
  onnx>=1.18.0
28
28
  onnxscript>=0.3.1
29
29
  onnx-ir>=0.1.4
30
- opentau-transformers==4.53.3
30
+ transformers==4.53.3
31
31
  scipy>=1.15.2
32
32
  pytest>=8.1.0
33
33
  pytest-cov>=5.0.0