opentau 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. {opentau-0.1.0/src/opentau.egg-info → opentau-0.1.2}/PKG-INFO +8 -5
  2. {opentau-0.1.0 → opentau-0.1.2}/README.md +3 -3
  3. {opentau-0.1.0 → opentau-0.1.2}/pyproject.toml +9 -3
  4. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/__init__.py +1 -0
  5. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/lerobot_dataset.py +7 -9
  6. opentau-0.1.2/src/opentau/scripts/launch.py +84 -0
  7. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/train.py +8 -8
  8. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/visualize_dataset.py +46 -30
  9. opentau-0.1.2/src/opentau/utils/transformers_patch.py +276 -0
  10. {opentau-0.1.0 → opentau-0.1.2/src/opentau.egg-info}/PKG-INFO +8 -5
  11. {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/SOURCES.txt +2 -2
  12. opentau-0.1.2/src/opentau.egg-info/entry_points.txt +5 -0
  13. {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/requires.txt +4 -1
  14. opentau-0.1.0/src/opentau/scripts/launch_train.py +0 -63
  15. opentau-0.1.0/src/opentau/scripts/visualize_dataset_html.py +0 -507
  16. opentau-0.1.0/src/opentau.egg-info/entry_points.txt +0 -2
  17. {opentau-0.1.0 → opentau-0.1.2}/LICENSE +0 -0
  18. {opentau-0.1.0 → opentau-0.1.2}/setup.cfg +0 -0
  19. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/__version__.py +0 -0
  20. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/__init__.py +0 -0
  21. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/default.py +0 -0
  22. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/libero.py +0 -0
  23. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/parser.py +0 -0
  24. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/policies.py +0 -0
  25. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/reward.py +0 -0
  26. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/train.py +0 -0
  27. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/types.py +0 -0
  28. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/constants.py +0 -0
  29. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/__init__.py +0 -0
  30. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/backward_compatibility.py +0 -0
  31. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/compute_stats.py +0 -0
  32. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/dataset_mixture.py +0 -0
  33. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/factory.py +0 -0
  34. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/__init__.py +0 -0
  35. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/base.py +0 -0
  36. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/clevr.py +0 -0
  37. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/cocoqa.py +0 -0
  38. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/dummy.py +0 -0
  39. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/pixmo.py +0 -0
  40. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/vsr.py +0 -0
  41. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/image_writer.py +0 -0
  42. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/online_buffer.py +0 -0
  43. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/push_dataset_to_hub/utils.py +0 -0
  44. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/sampler.py +0 -0
  45. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/standard_data_format_mapping.py +0 -0
  46. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/transforms.py +0 -0
  47. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/utils.py +0 -0
  48. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +0 -0
  49. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v2/convert_dataset_v1_to_v2.py +0 -0
  50. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/_remove_language_instruction.py +0 -0
  51. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +0 -0
  52. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/convert_dataset_v20_to_v21.py +0 -0
  53. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/convert_stats.py +0 -0
  54. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/video_utils.py +0 -0
  55. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/__init__.py +0 -0
  56. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/configs.py +0 -0
  57. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/factory.py +0 -0
  58. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/libero.py +0 -0
  59. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/utils.py +0 -0
  60. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/__init__.py +0 -0
  61. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/factory.py +0 -0
  62. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/optimizers.py +0 -0
  63. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/schedulers.py +0 -0
  64. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/__init__.py +0 -0
  65. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/high_level_planner.py +0 -0
  66. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/utils/memory.py +0 -0
  67. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/utils/utils.py +0 -0
  68. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/__init__.py +0 -0
  69. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/factory.py +0 -0
  70. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/normalize.py +0 -0
  71. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/__init__.py +0 -0
  72. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/configuration_pi0.py +0 -0
  73. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/modeling_pi0.py +0 -0
  74. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/paligemma_with_expert.py +0 -0
  75. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/__init__.py +0 -0
  76. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/configuration_pi05.py +0 -0
  77. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/modeling_pi05.py +0 -0
  78. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/paligemma_with_expert.py +0 -0
  79. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pretrained.py +0 -0
  80. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/utils.py +0 -0
  81. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/__init__.py +0 -0
  82. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/configuration_value.py +0 -0
  83. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/modeling_value.py +0 -0
  84. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/reward.py +0 -0
  85. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/siglip_gemma.py +0 -0
  86. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/actions_mse_loss.py +0 -0
  87. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/bin_to_safetensors.py +0 -0
  88. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/compute_max_token_length.py +0 -0
  89. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/display_sys_info.py +0 -0
  90. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/download_libero_benchmarks.py +0 -0
  91. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/eval.py +0 -0
  92. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/export_to_onnx.py +0 -0
  93. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/fake_tensor_training.py +0 -0
  94. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/get_advantage_and_percentiles.py +0 -0
  95. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/high_level_planner_inference.py +0 -0
  96. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/inference.py +0 -0
  97. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/libero_simulation_parallel.py +0 -0
  98. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/libero_simulation_sequential.py +0 -0
  99. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/nav_high_level_planner_inference.py +0 -0
  100. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/zero_to_fp32.py +0 -0
  101. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/__init__.py +0 -0
  102. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/accelerate_utils.py +0 -0
  103. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/benchmark.py +0 -0
  104. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/fake_tensor.py +0 -0
  105. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/hub.py +0 -0
  106. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/import_utils.py +0 -0
  107. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/io_utils.py +0 -0
  108. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/libero.py +0 -0
  109. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/libero_dataset_recorder.py +0 -0
  110. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/logging_utils.py +0 -0
  111. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/monkey_patch.py +0 -0
  112. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/random_utils.py +0 -0
  113. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/train_utils.py +0 -0
  114. {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/utils.py +0 -0
  115. {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/dependency_links.txt +0 -0
  116. {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/top_level.txt +0 -0
  117. {opentau-0.1.0 → opentau-0.1.2}/tests/test_available.py +0 -0
{opentau-0.1.0/src/opentau.egg-info → opentau-0.1.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: opentau
-Version: 0.1.0
+Version: 0.1.2
 Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
 Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
 License: Apache-2.0
@@ -52,7 +52,7 @@ Requires-Dist: onnxruntime>=1.22.1; sys_platform == "darwin" or platform_machine
 Requires-Dist: onnxruntime-gpu>=1.22.0; (sys_platform == "linux" and platform_machine == "x86_64") or (sys_platform == "win32" and (platform_machine == "AMD64" or platform_machine == "x86_64"))
 Requires-Dist: onnxscript>=0.3.1
 Requires-Dist: onnx-ir>=0.1.4
-Requires-Dist: opentau-transformers==4.53.3
+Requires-Dist: transformers==4.53.3
 Requires-Dist: scipy>=1.15.2
 Requires-Dist: pytest>=8.1.0
 Requires-Dist: pytest-cov>=5.0.0
@@ -94,6 +94,9 @@ Requires-Dist: numpy<2; extra == "libero"
 Requires-Dist: gym<0.27,>=0.25; extra == "libero"
 Requires-Dist: pyopengl-accelerate==3.1.7; sys_platform == "linux" and extra == "libero"
 Requires-Dist: gymnasium[other]>=0.29; extra == "libero"
+Requires-Dist: mujoco>=3.1.6; sys_platform == "linux" and extra == "libero"
+Requires-Dist: pyopengl==3.1.7; sys_platform == "linux" and extra == "libero"
+Requires-Dist: numpy==1.26.4; sys_platform == "linux" and extra == "libero"
 Dynamic: license-file
 
 <p align="center">
@@ -134,10 +137,10 @@ OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we
 ## Quick Start
 If you are familiar with LeRobot, getting started with OpenTau is very easy.
 Because OpenTau is a fork of the popular LeRobot repository, any LeRobot-compliant policy and dataset can be used directly with OpenTau.
-Check out our documentation to get started quickly.
-We provide a quick start guide to help you get started with OpenTau.
+Check out our [documentation](https://opentau.readthedocs.io/) to get started quickly.
+We provide a [quick start guide](https://opentau.readthedocs.io/en/latest/getting_started.html) to help you get started with OpenTau.
 
-For using local notebooks to train and evaluate models, find the notebooks at `notebooks/pi05_training.ipynb` and `notebooks/pi05_evaluation_only.ipynb`.
+For using local notebooks to train and evaluate models, find the notebooks at [notebooks/pi05_training.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_training.ipynb) and [notebooks/pi05_evaluation_only.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_evaluation_only.ipynb).
 
 For using the Google Colab notebooks to train and evaluate models, find the colab notebooks here: [pi05_training](https://colab.research.google.com/drive/1DeU0lNnEzs1KHo0Nkgh4YKBr-xu9moBM?usp=sharing) and [pi05_evaluation_only](https://colab.research.google.com/drive/1U_AyuH9WYMT4anEWvsOtIT7g01jA0WGm?usp=sharing) respectively.
 
{opentau-0.1.0 → opentau-0.1.2}/README.md
@@ -36,10 +36,10 @@ OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we
 ## Quick Start
 If you are familiar with LeRobot, getting started with OpenTau is very easy.
 Because OpenTau is a fork of the popular LeRobot repository, any LeRobot-compliant policy and dataset can be used directly with OpenTau.
-Check out our documentation to get started quickly.
-We provide a quick start guide to help you get started with OpenTau.
+Check out our [documentation](https://opentau.readthedocs.io/) to get started quickly.
+We provide a [quick start guide](https://opentau.readthedocs.io/en/latest/getting_started.html) to help you get started with OpenTau.
 
-For using local notebooks to train and evaluate models, find the notebooks at `notebooks/pi05_training.ipynb` and `notebooks/pi05_evaluation_only.ipynb`.
+For using local notebooks to train and evaluate models, find the notebooks at [notebooks/pi05_training.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_training.ipynb) and [notebooks/pi05_evaluation_only.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_evaluation_only.ipynb).
 
 For using the Google Colab notebooks to train and evaluate models, find the colab notebooks here: [pi05_training](https://colab.research.google.com/drive/1DeU0lNnEzs1KHo0Nkgh4YKBr-xu9moBM?usp=sharing) and [pi05_evaluation_only](https://colab.research.google.com/drive/1U_AyuH9WYMT4anEWvsOtIT7g01jA0WGm?usp=sharing) respectively.
 
{opentau-0.1.0 → opentau-0.1.2}/pyproject.toml
@@ -20,7 +20,7 @@ huggingface = "https://huggingface.co/TensorAuto"
 
 [project]
 name = "opentau"
-version = "0.1.0"
+version = "0.1.2"
 description = "OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch"
 authors = [
     { name = "Shuheng Liu", email = "wish1104@icloud.com" },
@@ -76,7 +76,7 @@ dependencies = [
     "onnxruntime-gpu>=1.22.0 ; ((sys_platform == 'linux' and platform_machine == 'x86_64') or (sys_platform == 'win32' and (platform_machine == 'AMD64' or platform_machine == 'x86_64'))) ",
     "onnxscript>=0.3.1",
     "onnx-ir>=0.1.4",
-    "opentau-transformers==4.53.3",
+    "transformers==4.53.3",
     "scipy>=1.15.2",
     "pytest>=8.1.0",
     "pytest-cov>=5.0.0",
@@ -89,7 +89,10 @@ dependencies = [
 ]
 
 [project.scripts]
-opentau-train = "opentau.scripts.launch_train:main"
+opentau-train = "opentau.scripts.launch:train"
+opentau-eval = "opentau.scripts.launch:eval"
+opentau-export = "opentau.scripts.launch:export"
+opentau-dataset-viz = "opentau.scripts.launch:visualize"
 
 [project.optional-dependencies]
 dev = ["pre-commit>=3.7.0",
@@ -122,6 +125,9 @@ libero = [
    "gym>=0.25,<0.27",
    "pyopengl-accelerate==3.1.7 ; sys_platform == 'linux'",
    "gymnasium[other]>=0.29",
+    "mujoco>=3.1.6 ; sys_platform == 'linux'",
+    "pyopengl==3.1.7 ; sys_platform == 'linux'",
+    "numpy==1.26.4 ; sys_platform == 'linux'",
 ]
 
 [tool.uv.sources]
{opentau-0.1.0 → opentau-0.1.2}/src/opentau/__init__.py
@@ -56,6 +56,7 @@ When implementing a new policy class (e.g., `DiffusionPolicy`), follow these ste
 import itertools
 
 from opentau.__version__ import __version__  # noqa: F401
+from opentau.utils import transformers_patch  # noqa: F401
 
 # TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
 # refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
{opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/lerobot_dataset.py
@@ -633,7 +633,9 @@ class BaseDataset(torch.utils.data.Dataset):
         For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad": [False, True] } will become
         {
             "image_key": torch.zeros(3, 224, 224),
+            "image_key_local": torch.zeros(3, 224, 224),
             "image_key_is_pad: False,
+            "image_key_local_is_pad": True,
         }.
         """
         raise NotImplementedError
@@ -1787,16 +1789,12 @@ class LeRobotDataset(BaseDataset):
         cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
         for k in cam_keys:
             images = item.pop(k)
-            assert len(images) == 2, (
-                f"{k} in {self.__class__} is expected to have length 2, got shape={images.shape}"
-            )
-            item[k + "_local"], item[k] = images
+            if len(images) == 2:
+                item[k + "_local"], item[k] = images
 
-            pads = item.pop(k + "_is_pad")
-            assert len(pads) == 2, (
-                f"{k} in {self.__class__} is expected to have length 2, got shape={pads.shape}"
-            )
-            item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
+            pads = item.get(k + "_is_pad")
+            if hasattr(pads, "__len__") and len(pads) == 2:
+                item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
 
     @staticmethod
     def compute_delta_params(
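
A note on the behavioral change above: the hard `assert len(...) == 2` checks become conditional splitting, so a camera that provides a single frame (or no `_is_pad` entry) no longer raises. A minimal self-contained sketch of the relaxed logic, where `"camera0"` and the tensor shapes are illustrative placeholders rather than values from a real dataset:

```python
# Sketch of the relaxed camera splitting from the hunk above.
import torch

item = {
    "camera0": torch.zeros(2, 3, 224, 224),     # [temporal, C, H, W]
    "camera0_is_pad": torch.tensor([False, True]),
}
k = "camera0"

images = item.pop(k)
if len(images) == 2:                    # was: assert len(images) == 2
    item[k + "_local"], item[k] = images

pads = item.get(k + "_is_pad")          # was: item.pop(...) + assert
if hasattr(pads, "__len__") and len(pads) == 2:
    item[k + "_local_is_pad"], item[k + "_is_pad"] = pads

print(sorted(item))
# ['camera0', 'camera0_is_pad', 'camera0_local', 'camera0_local_is_pad']
```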
opentau-0.1.2/src/opentau/scripts/launch.py
@@ -0,0 +1,84 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import subprocess
+import sys
+from pathlib import Path
+from types import ModuleType
+
+import opentau.scripts.eval as eval_script
+import opentau.scripts.export_to_onnx as export_script
+import opentau.scripts.train as train_script
+import opentau.scripts.visualize_dataset as visualize_script
+
+
+def launch(script_module: ModuleType, description: str, use_accelerate: bool = True):
+    """Generic launcher for OpenTau scripts using Accelerate or Python."""
+    parser = argparse.ArgumentParser(
+        description=description,
+        usage=f"{Path(sys.argv[0]).name} {'[--accelerate-config CONFIG] ' if use_accelerate else ''}[ARGS]",
+    )
+    if use_accelerate:
+        parser.add_argument(
+            "--accelerate-config", type=str, help="Path to accelerate config file (yaml)", default=None
+        )
+
+    # We use parse_known_args so that all other arguments are collected
+    # These will be passed to the target script
+    args, unknown_args = parser.parse_known_args()
+
+    # Base command
+    if use_accelerate:
+        cmd = ["accelerate", "launch"]
+        # Add accelerate config if provided
+        if args.accelerate_config:
+            cmd.extend(["--config_file", args.accelerate_config])
+    else:
+        cmd = [sys.executable]
+
+    # Add the path to the target script
+    # We resolve the path to ensure it's absolute
+    script_path = Path(script_module.__file__).resolve()
+    cmd.append(str(script_path))
+
+    # Add all other arguments (passed to the target script)
+    cmd.extend(unknown_args)
+
+    # Print the command for transparency
+    print(f"Executing: {' '.join(cmd)}")
+
+    # Run the launch command as a subprocess and propagate its exit status
+    try:
+        subprocess.run(cmd, check=True)
+    except subprocess.CalledProcessError as e:
+        sys.exit(e.returncode)
+    except KeyboardInterrupt:
+        sys.exit(130)
+
+
+def train():
+    launch(train_script, "Launch OpenTau training with Accelerate")
+
+
+def eval():
+    launch(eval_script, "Launch OpenTau evaluation with Accelerate")
+
+
+def export():
+    launch(export_script, "Launch OpenTau ONNX export", use_accelerate=False)
+
+
+def visualize():
+    launch(visualize_script, "Launch OpenTau visualization", use_accelerate=False)
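
The launcher parses only `--accelerate-config` and forwards everything else verbatim to the target script. A sketch of the command it builds for `opentau-train`, where `acc.yaml` and `--steps=1000` are hypothetical placeholders for a user's config file and pass-through arguments:

```python
# Reproduces the command construction in launch() above for the train entry
# point; "acc.yaml" and "--steps=1000" are illustrative, not real defaults.
from pathlib import Path

import opentau.scripts.train as train_script

unknown_args = ["--steps=1000"]  # whatever parse_known_args() did not consume
cmd = [
    "accelerate", "launch", "--config_file", "acc.yaml",
    str(Path(train_script.__file__).resolve()),
    *unknown_args,
]
print("Executing:", " ".join(cmd))
```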
{opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/train.py
@@ -73,16 +73,16 @@ def update_policy(
         train_config.loss_weighting["MSE"] * losses["MSE"] + train_config.loss_weighting["CE"] * losses["CE"]
     )
 
-    # accelerator.backward(loss)
-    # accelerator.unscale_gradients(optimizer=optimizer)
+    accelerator.backward(loss)
+    accelerator.unscale_gradients(optimizer=optimizer)
 
-    # if accelerator.sync_gradients:
-    #     grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
-    #     if accelerator.is_main_process:
-    #         train_metrics.grad_norm = grad_norm
+    if accelerator.sync_gradients:
+        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
+        if accelerator.is_main_process:
+            train_metrics.grad_norm = grad_norm
 
-    # optimizer.step()
-    # optimizer.zero_grad()
+    optimizer.step()
+    optimizer.zero_grad()
 
     # Step through pytorch scheduler at every batch instead of epoch
     if lr_scheduler is not None:
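
The previously commented-out lines restore the standard Accelerate optimization step: backward, gradient unscaling, clipping on sync steps, then `optimizer.step()`. For context, a self-contained toy version of that sequence, with a stand-in model, optimizer, and `max_norm` rather than OpenTau's policy and config:

```python
# Toy reproduction of the re-enabled update step above.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(8, 4, device=accelerator.device)
y = torch.randn(8, 1, device=accelerator.device)
loss = torch.nn.functional.mse_loss(model(x), y)

accelerator.backward(loss)                          # handles mixed-precision scaling
accelerator.unscale_gradients(optimizer=optimizer)
if accelerator.sync_gradients:                      # False mid gradient-accumulation
    grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```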
{opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/visualize_dataset.py
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
 
 Note: The last frame of the episode doesn't always correspond to a final state.
 That's because our datasets are composed of transition from state to state up to
@@ -30,34 +30,21 @@ Examples:
 
 - Visualize data stored on a local machine:
 ```
-local$ python src/opentau/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0
+local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0
 ```
 
 - Visualize data stored on a distant machine with a local viewer:
 ```
-distant$ python src/opentau/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --save 1 \
-    --output-dir path/to/directory
+distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --save 1 --output-dir path/to/directory
 
 local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
 local$ rerun lerobot_pusht_episode_0.rrd
 ```
 
 - Visualize data stored on a distant machine through streaming:
-(You need to forward the websocket port to the distant machine, with
-`ssh -L 9087:localhost:9087 username@remote-host`)
 ```
-distant$ python src/opentau/scripts/visualize_dataset.py \
-    --repo-id lerobot/pusht \
-    --episode-index 0 \
-    --mode distant \
-    --ws-port 9087
 
-local$ rerun ws://localhost:9087
+distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode distant --web-port 9090
 ```
 
 """
@@ -75,8 +62,34 @@ import torch
 import torch.utils.data
 import tqdm
 
+from opentau.configs.default import DatasetMixtureConfig, WandBConfig
+from opentau.configs.train import TrainPipelineConfig
 from opentau.datasets.lerobot_dataset import LeRobotDataset
-from opentau.scripts.visualize_dataset_html import create_mock_train_config
+
+
+def create_mock_train_config() -> TrainPipelineConfig:
+    """Create a mock TrainPipelineConfig for dataset visualization.
+
+    Returns:
+        TrainPipelineConfig: A mock config with default values.
+    """
+    return TrainPipelineConfig(
+        dataset_mixture=DatasetMixtureConfig(),  # Will be set by the dataset
+        resolution=(224, 224),
+        num_cams=2,
+        max_state_dim=32,
+        max_action_dim=32,
+        action_chunk=50,
+        loss_weighting={"MSE": 1, "CE": 1},
+        num_workers=4,
+        batch_size=8,
+        steps=100_000,
+        log_freq=200,
+        save_checkpoint=True,
+        save_freq=20_000,
+        use_policy_training_preset=True,
+        wandb=WandBConfig(),
+    )
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
@@ -108,7 +121,6 @@ def visualize_dataset(
     num_workers: int = 0,
     mode: str = "local",
     web_port: int = 9090,
-    ws_port: int = 9087,
     save: bool = False,
     output_dir: Path | None = None,
 ) -> Path | None:
@@ -142,7 +154,7 @@
     gc.collect()
 
     if mode == "distant":
-        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
+        rr.serve_web_viewer(open_browser=False, web_port=web_port)
 
     logging.info("Logging to Rerun")
 
@@ -194,7 +206,7 @@
         print("Ctrl-C received. Exiting.")
 
 
-def main():
+def parse_args() -> dict:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -250,12 +262,6 @@ def main():
         default=9090,
         help="Web port for rerun.io when `--mode distant` is set.",
     )
-    parser.add_argument(
-        "--ws-port",
-        type=int,
-        default=9087,
-        help="Web socket port for rerun.io when `--mode distant` is set.",
-    )
     parser.add_argument(
         "--save",
         type=int,
@@ -279,15 +285,25 @@
     )
 
     args = parser.parse_args()
-    kwargs = vars(args)
+    return vars(args)
+
+
+def main():
+    kwargs = parse_args()
     repo_id = kwargs.pop("repo_id")
     root = kwargs.pop("root")
     tolerance_s = kwargs.pop("tolerance_s")
 
     logging.info("Loading dataset")
-    dataset = LeRobotDataset(create_mock_train_config(), repo_id, root=root, tolerance_s=tolerance_s)
+    dataset = LeRobotDataset(
+        create_mock_train_config(),
+        repo_id,
+        root=root,
+        tolerance_s=tolerance_s,
+        standardize=False,
+    )
 
-    visualize_dataset(dataset, **vars(args))
+    visualize_dataset(dataset, **kwargs)
 
 
 if __name__ == "__main__":
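
With `main()` split into `parse_args()`/`main()` and the dataset now built with `standardize=False`, the visualizer can also be driven programmatically. A hedged sketch: the repo id comes from the docstring examples above, while the keyword names and the `tolerance_s` value are assumptions about `visualize_dataset`'s full signature, which the diff only shows in part:

```python
# Programmatic use of the refactored visualizer; episode_index/mode kwargs
# are assumed to mirror the CLI flags shown above.
from opentau.datasets.lerobot_dataset import LeRobotDataset
from opentau.scripts.visualize_dataset import create_mock_train_config, visualize_dataset

dataset = LeRobotDataset(
    create_mock_train_config(),
    "lerobot/pusht",
    root=None,
    tolerance_s=1e-4,  # assumed value; the CLI exposes --tolerance-s
    standardize=False,
)
visualize_dataset(dataset, episode_index=0, mode="local")
```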
opentau-0.1.2/src/opentau/utils/transformers_patch.py
@@ -0,0 +1,276 @@
+# Copyright 2026 Tensor Auto Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for patching transformers
+
+Most patches come from the branch fix/lerobot-openpi
+"""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.models.gemma import modeling_gemma
+from transformers.models.gemma.configuration_gemma import GemmaConfig
+from transformers.models.paligemma.modeling_paligemma import PaliGemmaModel
+
+# Monkey patch __init__ of GemmaConfig to fix or modify its behavior as needed.
+
+_original_gemma_config_init = GemmaConfig.__init__
+
+
+def patched_gemma_config_init(
+    self, *args, use_adarms: bool = False, adarms_cond_dim: Optional[int] = None, **kwargs
+):
+    """Initializes the GemmaConfig with added ADARMS support.
+
+    Args:
+        self: The GemmaConfig instance.
+        *args: Variable length argument list.
+        use_adarms: Whether to use Adaptive RMS normalization.
+        adarms_cond_dim: The dimension of the conditioning vector for ADARMS.
+        **kwargs: Arbitrary keyword arguments.
+    """
+    # Call the original init with all other arguments
+    _original_gemma_config_init(self, *args, **kwargs)
+
+    # Initialize custom attributes
+    self.use_adarms = use_adarms
+    self.adarms_cond_dim = adarms_cond_dim
+
+    # Set default for adarms_cond_dim if use_adarms is True
+    if self.use_adarms and self.adarms_cond_dim is None:
+        # hidden_size is set by _original_gemma_config_init
+        self.adarms_cond_dim = self.hidden_size
+
+
+GemmaConfig.__init__ = patched_gemma_config_init
+
+
+# --- Modeling Patches ---
+
+
+def _gated_residual(x, y, gate):
+    """
+    Applies gated residual connection with optional gate parameter.
+
+    Args:
+        x: Input tensor (residual)
+        y: Output tensor to be added
+        gate: Optional gate tensor to modulate the addition
+
+    Returns:
+        x + y if gate is None, otherwise x + y * gate
+    """
+    if x is None and y is None:
+        return None
+    if x is None or y is None:
+        return x if x is not None else y
+    if gate is None:
+        return x + y
+    return x + y * gate
+
+
+modeling_gemma._gated_residual = _gated_residual
+
+
+class PatchedGemmaRMSNorm(nn.Module):
+    """RMS normalization with optional adaptive support (ADARMS)."""
+
+    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
+        """Initializes the PatchedGemmaRMSNorm.
+
+        Args:
+            dim: The dimension of the input tensor.
+            eps: The epsilon value for numerical stability.
+            cond_dim: The dimension of the conditioning vector (if using ADARMS).
+        """
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.cond_dim = cond_dim
+
+        # Dense layer for adaptive normalization (if cond_dim is provided)
+        if cond_dim is not None:
+            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
+            # Initialize with zeros (matches source implementation)
+            nn.init.zeros_(self.dense.weight)
+        else:
+            self.weight = nn.Parameter(torch.zeros(dim))
+            self.dense = None
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies RMS normalization.
+
+        Args:
+            x: The input tensor.
+
+        Returns:
+            The normalized tensor.
+        """
+        # Compute variance in float32 (like the source implementation)
+        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
+        # Compute normalization in float32
+        normed_inputs = x * torch.rsqrt(var + self.eps)
+        return normed_inputs
+
+    def forward(
+        self, x: torch.Tensor, cond: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Forward pass of the normalization layer.
+
+        Args:
+            x: The input tensor.
+            cond: The conditioning tensor for adaptive normalization.
+
+        Returns:
+            A tuple containing the normalized tensor and the gate tensor (if applicable).
+            If cond is None, the gate tensor will be None.
+
+        Raises:
+            ValueError: If cond dimension does not match the configured cond_dim.
+        """
+        dtype = x.dtype  # original dtype, could be half-precision
+        normed_inputs = self._norm(x)
+
+        if cond is None or self.dense is None:
+            # regular RMSNorm
+            # scale by learned parameter in float32 (matches source implementation)
+            normed_inputs = normed_inputs * (1.0 + self.weight.float())
+            return normed_inputs.to(dtype), None  # return in original dtype with None gate
+
+        # adaptive RMSNorm (if cond is provided and dense layer exists)
+        if cond.shape[-1] != self.cond_dim:
+            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
+
+        modulation = self.dense(cond)
+        # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
+        if len(x.shape) == 3:  # [batch, seq, features]
+            modulation = modulation.unsqueeze(1)
+
+        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
+
+        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
+
+        return normed_inputs.to(dtype), gate.to(dtype)
+
+    def extra_repr(self) -> str:
+        """Returns the extra representation of the module."""
+        repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
+        if self.dense is not None:
+            repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
+        return repr_str
+
+
+# Apply patches
+modeling_gemma.GemmaRMSNorm = PatchedGemmaRMSNorm
+
+
+def patched_gemma_decoder_layer_init(self, config: GemmaConfig, layer_idx: int):
+    """Initializes a GemmaDecoderLayer with potential ADARMS support.
+
+    Args:
+        self: The GemmaDecoderLayer instance.
+        config: The configuration object.
+        layer_idx: The index of the layer.
+    """
+    modeling_gemma.GradientCheckpointingLayer.__init__(self)
+    self.hidden_size = config.hidden_size
+
+    self.self_attn = modeling_gemma.GemmaAttention(config=config, layer_idx=layer_idx)
+
+    self.mlp = modeling_gemma.GemmaMLP(config)
+
+    cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
+    self.input_layernorm = modeling_gemma.GemmaRMSNorm(
+        config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
+    )
+    self.post_attention_layernorm = modeling_gemma.GemmaRMSNorm(
+        config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
+    )
+
+
+modeling_gemma.GemmaDecoderLayer.__init__ = patched_gemma_decoder_layer_init
+
+
+def patched_gemma_model_init(self, config: GemmaConfig):
+    """Initializes the GemmaModel with potential ADARMS support.
+
+    Args:
+        self: The GemmaModel instance.
+        config: The configuration object.
+    """
+    modeling_gemma.GemmaPreTrainedModel.__init__(self, config)
+    self.padding_idx = config.pad_token_id
+    self.vocab_size = config.vocab_size
+
+    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+    self.layers = nn.ModuleList(
+        [modeling_gemma.GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+    )
+
+    cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
+    self.norm = modeling_gemma.GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
+    self.rotary_emb = modeling_gemma.GemmaRotaryEmbedding(config=config)
+    self.gradient_checkpointing = False
+
+    # Initialize weights and apply final processing
+    self.post_init()
+
+
+modeling_gemma.GemmaModel.__init__ = patched_gemma_model_init
+
+
+def patched_gemma_pretrained_model_init_weights(self, module: nn.Module):
+    """Initializes the weights of the GemmaPreTrainedModel.
+
+    Args:
+        self: The GemmaPreTrainedModel instance.
+        module: The module to initialize.
+    """
+    std = self.config.initializer_range
+    if isinstance(module, nn.Linear):
+        module.weight.data.normal_(mean=0.0, std=std)
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.Embedding):
+        module.weight.data.normal_(mean=0.0, std=std)
+        if module.padding_idx is not None:
+            module.weight.data[module.padding_idx].zero_()
+    elif isinstance(module, modeling_gemma.GemmaRMSNorm):
+        if hasattr(module, "weight"):
+            module.weight.data.fill_(1.0)
+
+
+modeling_gemma.GemmaPreTrainedModel._init_weights = patched_gemma_pretrained_model_init_weights
+
+
+def patched_paligemma_model_get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    """Obtains image last hidden states from the vision tower and apply multimodal projection.
+
+    Args:
+        self: The PaliGemmaModel instance.
+        pixel_values: The tensors corresponding to the input images.
+            Shape: (batch_size, channels, height, width).
+
+    Returns:
+        Image feature tensor of shape (num_images, image_length, embed_dim).
+    """
+    image_outputs = self.vision_tower(pixel_values)
+    selected_image_feature = image_outputs.last_hidden_state
+    image_features = self.multi_modal_projector(selected_image_feature)
+    return image_features
+
+
+PaliGemmaModel.get_image_features = patched_paligemma_model_get_image_features
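
The new module patches `transformers` in place when imported, which is why `src/opentau/__init__.py` now imports it (see the hunk above): `GemmaConfig` gains `use_adarms`/`adarms_cond_dim`, and `GemmaRMSNorm` becomes the adaptive variant that returns a `(normed, gate)` pair. A small self-contained check of that behavior, with toy dimensions (8 features, `cond_dim=4`):

```python
# Exercises the patched RMSNorm directly; the dimensions are toy values.
import torch
from opentau.utils import transformers_patch  # noqa: F401  (applies patches on import)
from transformers.models.gemma import modeling_gemma

adaptive = modeling_gemma.GemmaRMSNorm(dim=8, eps=1e-6, cond_dim=4)
x = torch.randn(2, 5, 8)   # [batch, seq, features]
cond = torch.randn(2, 4)   # per-sample conditioning vector

normed, gate = adaptive(x, cond)   # adaptive path: scale/shift/gate from cond
assert normed.shape == x.shape and gate.shape == (2, 1, 8)

plain = modeling_gemma.GemmaRMSNorm(dim=8)
normed, gate = plain(x)            # no cond: plain RMSNorm, gate is None
assert gate is None
```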