opentau 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opentau-0.1.0/src/opentau.egg-info → opentau-0.1.2}/PKG-INFO +8 -5
- {opentau-0.1.0 → opentau-0.1.2}/README.md +3 -3
- {opentau-0.1.0 → opentau-0.1.2}/pyproject.toml +9 -3
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/__init__.py +1 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/lerobot_dataset.py +7 -9
- opentau-0.1.2/src/opentau/scripts/launch.py +84 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/train.py +8 -8
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/visualize_dataset.py +46 -30
- opentau-0.1.2/src/opentau/utils/transformers_patch.py +276 -0
- {opentau-0.1.0 → opentau-0.1.2/src/opentau.egg-info}/PKG-INFO +8 -5
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/SOURCES.txt +2 -2
- opentau-0.1.2/src/opentau.egg-info/entry_points.txt +5 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/requires.txt +4 -1
- opentau-0.1.0/src/opentau/scripts/launch_train.py +0 -63
- opentau-0.1.0/src/opentau/scripts/visualize_dataset_html.py +0 -507
- opentau-0.1.0/src/opentau.egg-info/entry_points.txt +0 -2
- {opentau-0.1.0 → opentau-0.1.2}/LICENSE +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/setup.cfg +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/__version__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/default.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/libero.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/parser.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/policies.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/reward.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/train.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/configs/types.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/constants.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/backward_compatibility.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/compute_stats.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/dataset_mixture.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/factory.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/base.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/clevr.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/cocoqa.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/dummy.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/pixmo.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/grounding/vsr.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/image_writer.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/online_buffer.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/push_dataset_to_hub/utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/sampler.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/standard_data_format_mapping.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/transforms.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v2/convert_dataset_v1_to_v2.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/_remove_language_instruction.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/convert_dataset_v20_to_v21.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/v21/convert_stats.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/datasets/video_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/configs.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/factory.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/libero.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/envs/utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/factory.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/optimizers.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/optim/schedulers.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/high_level_planner.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/utils/memory.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/planner/utils/utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/factory.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/normalize.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/configuration_pi0.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/modeling_pi0.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi0/paligemma_with_expert.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/configuration_pi05.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/modeling_pi05.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pi05/paligemma_with_expert.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/pretrained.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/configuration_value.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/modeling_value.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/reward.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/policies/value/siglip_gemma.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/actions_mse_loss.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/bin_to_safetensors.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/compute_max_token_length.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/display_sys_info.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/download_libero_benchmarks.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/eval.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/export_to_onnx.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/fake_tensor_training.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/get_advantage_and_percentiles.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/high_level_planner_inference.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/inference.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/libero_simulation_parallel.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/libero_simulation_sequential.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/nav_high_level_planner_inference.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/scripts/zero_to_fp32.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/__init__.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/accelerate_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/benchmark.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/fake_tensor.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/hub.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/import_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/io_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/libero.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/libero_dataset_recorder.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/logging_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/monkey_patch.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/random_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/train_utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau/utils/utils.py +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/dependency_links.txt +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/src/opentau.egg-info/top_level.txt +0 -0
- {opentau-0.1.0 → opentau-0.1.2}/tests/test_available.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opentau
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch
|
|
5
5
|
Author-email: Shuheng Liu <wish1104@icloud.com>, William Yue <williamyue37@gmail.com>, Akshay Shah <akshayhitendrashah@gmail.com>, Xingrui Gu <xingrui_gu@berkeley.edu>
|
|
6
6
|
License: Apache-2.0
|
|
@@ -52,7 +52,7 @@ Requires-Dist: onnxruntime>=1.22.1; sys_platform == "darwin" or platform_machine
|
|
|
52
52
|
Requires-Dist: onnxruntime-gpu>=1.22.0; (sys_platform == "linux" and platform_machine == "x86_64") or (sys_platform == "win32" and (platform_machine == "AMD64" or platform_machine == "x86_64"))
|
|
53
53
|
Requires-Dist: onnxscript>=0.3.1
|
|
54
54
|
Requires-Dist: onnx-ir>=0.1.4
|
|
55
|
-
Requires-Dist:
|
|
55
|
+
Requires-Dist: transformers==4.53.3
|
|
56
56
|
Requires-Dist: scipy>=1.15.2
|
|
57
57
|
Requires-Dist: pytest>=8.1.0
|
|
58
58
|
Requires-Dist: pytest-cov>=5.0.0
|
|
@@ -94,6 +94,9 @@ Requires-Dist: numpy<2; extra == "libero"
|
|
|
94
94
|
Requires-Dist: gym<0.27,>=0.25; extra == "libero"
|
|
95
95
|
Requires-Dist: pyopengl-accelerate==3.1.7; sys_platform == "linux" and extra == "libero"
|
|
96
96
|
Requires-Dist: gymnasium[other]>=0.29; extra == "libero"
|
|
97
|
+
Requires-Dist: mujoco>=3.1.6; sys_platform == "linux" and extra == "libero"
|
|
98
|
+
Requires-Dist: pyopengl==3.1.7; sys_platform == "linux" and extra == "libero"
|
|
99
|
+
Requires-Dist: numpy==1.26.4; sys_platform == "linux" and extra == "libero"
|
|
97
100
|
Dynamic: license-file
|
|
98
101
|
|
|
99
102
|
<p align="center">
|
|
@@ -134,10 +137,10 @@ OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we
|
|
|
134
137
|
## Quick Start
|
|
135
138
|
If you are familiar with LeRobot, getting started with OpenTau is very easy.
|
|
136
139
|
Because OpenTau is a fork of the popular LeRobot repository, any LeRobot-compliant policy and dataset can be used directly with OpenTau.
|
|
137
|
-
Check out our documentation to get started quickly.
|
|
138
|
-
We provide a quick start guide to help you get started with OpenTau.
|
|
140
|
+
Check out our [documentation](https://opentau.readthedocs.io/) to get started quickly.
|
|
141
|
+
We provide a [quick start guide](https://opentau.readthedocs.io/en/latest/getting_started.html) to help you get started with OpenTau.
|
|
139
142
|
|
|
140
|
-
For using local notebooks to train and evaluate models, find the notebooks at
|
|
143
|
+
For using local notebooks to train and evaluate models, find the notebooks at [notebooks/pi05_training.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_training.ipynb) and [notebooks/pi05_evaluation_only.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_evaluation_only.ipynb).
|
|
141
144
|
|
|
142
145
|
For using the Google Colab notebooks to train and evaluate models, find the colab notebooks here: [pi05_training](https://colab.research.google.com/drive/1DeU0lNnEzs1KHo0Nkgh4YKBr-xu9moBM?usp=sharing) and [pi05_evaluation_only](https://colab.research.google.com/drive/1U_AyuH9WYMT4anEWvsOtIT7g01jA0WGm?usp=sharing) respectively.
|
|
143
146
|
|
|
@@ -36,10 +36,10 @@ OpenTau ($\tau$) is a tool developed by *[Tensor][1]* to bridge this gap, and we
|
|
|
36
36
|
## Quick Start
|
|
37
37
|
If you are familiar with LeRobot, getting started with OpenTau is very easy.
|
|
38
38
|
Because OpenTau is a fork of the popular LeRobot repository, any LeRobot-compliant policy and dataset can be used directly with OpenTau.
|
|
39
|
-
Check out our documentation to get started quickly.
|
|
40
|
-
We provide a quick start guide to help you get started with OpenTau.
|
|
39
|
+
Check out our [documentation](https://opentau.readthedocs.io/) to get started quickly.
|
|
40
|
+
We provide a [quick start guide](https://opentau.readthedocs.io/en/latest/getting_started.html) to help you get started with OpenTau.
|
|
41
41
|
|
|
42
|
-
For using local notebooks to train and evaluate models, find the notebooks at
|
|
42
|
+
For using local notebooks to train and evaluate models, find the notebooks at [notebooks/pi05_training.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_training.ipynb) and [notebooks/pi05_evaluation_only.ipynb](https://github.com/TensorAuto/OpenTau/blob/main/notebooks/pi05_evaluation_only.ipynb).
|
|
43
43
|
|
|
44
44
|
For using the Google Colab notebooks to train and evaluate models, find the colab notebooks here: [pi05_training](https://colab.research.google.com/drive/1DeU0lNnEzs1KHo0Nkgh4YKBr-xu9moBM?usp=sharing) and [pi05_evaluation_only](https://colab.research.google.com/drive/1U_AyuH9WYMT4anEWvsOtIT7g01jA0WGm?usp=sharing) respectively.
|
|
45
45
|
|
|
@@ -20,7 +20,7 @@ huggingface = "https://huggingface.co/TensorAuto"
|
|
|
20
20
|
|
|
21
21
|
[project]
|
|
22
22
|
name = "opentau"
|
|
23
|
-
version = "0.1.
|
|
23
|
+
version = "0.1.2"
|
|
24
24
|
description = "OpenTau: Tensor's VLA Training Infrastructure for Real-World Robotics in Pytorch"
|
|
25
25
|
authors = [
|
|
26
26
|
{ name = "Shuheng Liu", email = "wish1104@icloud.com" },
|
|
@@ -76,7 +76,7 @@ dependencies = [
|
|
|
76
76
|
"onnxruntime-gpu>=1.22.0 ; ((sys_platform == 'linux' and platform_machine == 'x86_64') or (sys_platform == 'win32' and (platform_machine == 'AMD64' or platform_machine == 'x86_64'))) ",
|
|
77
77
|
"onnxscript>=0.3.1",
|
|
78
78
|
"onnx-ir>=0.1.4",
|
|
79
|
-
"
|
|
79
|
+
"transformers==4.53.3",
|
|
80
80
|
"scipy>=1.15.2",
|
|
81
81
|
"pytest>=8.1.0",
|
|
82
82
|
"pytest-cov>=5.0.0",
|
|
@@ -89,7 +89,10 @@ dependencies = [
|
|
|
89
89
|
]
|
|
90
90
|
|
|
91
91
|
[project.scripts]
|
|
92
|
-
opentau-train = "opentau.scripts.
|
|
92
|
+
opentau-train = "opentau.scripts.launch:train"
|
|
93
|
+
opentau-eval = "opentau.scripts.launch:eval"
|
|
94
|
+
opentau-export = "opentau.scripts.launch:export"
|
|
95
|
+
opentau-dataset-viz = "opentau.scripts.launch:visualize"
|
|
93
96
|
|
|
94
97
|
[project.optional-dependencies]
|
|
95
98
|
dev = ["pre-commit>=3.7.0",
|
|
@@ -122,6 +125,9 @@ libero = [
|
|
|
122
125
|
"gym>=0.25,<0.27",
|
|
123
126
|
"pyopengl-accelerate==3.1.7 ; sys_platform == 'linux'",
|
|
124
127
|
"gymnasium[other]>=0.29",
|
|
128
|
+
"mujoco>=3.1.6 ; sys_platform == 'linux'",
|
|
129
|
+
"pyopengl==3.1.7 ; sys_platform == 'linux'",
|
|
130
|
+
"numpy==1.26.4 ; sys_platform == 'linux'",
|
|
125
131
|
]
|
|
126
132
|
|
|
127
133
|
[tool.uv.sources]
|
|
@@ -56,6 +56,7 @@ When implementing a new policy class (e.g., `DiffusionPolicy`), follow these ste
|
|
|
56
56
|
import itertools
|
|
57
57
|
|
|
58
58
|
from opentau.__version__ import __version__ # noqa: F401
|
|
59
|
+
from opentau.utils import transformers_patch # noqa: F401
|
|
59
60
|
|
|
60
61
|
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
|
|
61
62
|
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
|
|
@@ -633,7 +633,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|
|
633
633
|
For example, {"image_key": torch.zeros(2, 3, 224, 224), "image_key_is_pad": [False, True] } will become
|
|
634
634
|
{
|
|
635
635
|
"image_key": torch.zeros(3, 224, 224),
|
|
636
|
+
"image_key_local": torch.zeros(3, 224, 224),
|
|
636
637
|
"image_key_is_pad: False,
|
|
638
|
+
"image_key_local_is_pad": True,
|
|
637
639
|
}.
|
|
638
640
|
"""
|
|
639
641
|
raise NotImplementedError
|
|
@@ -1787,16 +1789,12 @@ class LeRobotDataset(BaseDataset):
|
|
|
1787
1789
|
cam_keys = {v for k, v in name_map.items() if k.startswith("camera")}
|
|
1788
1790
|
for k in cam_keys:
|
|
1789
1791
|
images = item.pop(k)
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
)
|
|
1793
|
-
item[k + "_local"], item[k] = images
|
|
1792
|
+
if len(images) == 2:
|
|
1793
|
+
item[k + "_local"], item[k] = images
|
|
1794
1794
|
|
|
1795
|
-
pads = item.
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
)
|
|
1799
|
-
item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
|
|
1795
|
+
pads = item.get(k + "_is_pad")
|
|
1796
|
+
if hasattr(pads, "__len__") and len(pads) == 2:
|
|
1797
|
+
item[k + "_local_is_pad"], item[k + "_is_pad"] = pads
|
|
1800
1798
|
|
|
1801
1799
|
@staticmethod
|
|
1802
1800
|
def compute_delta_params(
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import subprocess
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from types import ModuleType
|
|
20
|
+
|
|
21
|
+
import opentau.scripts.eval as eval_script
|
|
22
|
+
import opentau.scripts.export_to_onnx as export_script
|
|
23
|
+
import opentau.scripts.train as train_script
|
|
24
|
+
import opentau.scripts.visualize_dataset as visualize_script
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def launch(script_module: ModuleType, description: str, use_accelerate: bool = True):
    """Run an OpenTau script in a child process, optionally via ``accelerate launch``.

    The launcher itself only understands ``--accelerate-config`` (and only when
    ``use_accelerate`` is True). Every other command-line argument is forwarded
    unchanged to the target script.

    Args:
        script_module: Imported module whose ``__file__`` is the script to execute.
        description: Description shown in the launcher's ``--help`` output.
        use_accelerate: If True, run the script with ``accelerate launch``;
            otherwise run it with the current Python interpreter.

    Raises:
        SystemExit: With the child's exit code when it fails (``128 + N`` when the
            child was killed by signal ``N``), 130 on Ctrl-C, or 127 when the
            launcher executable (``accelerate`` / Python) cannot be found.
    """
    import shlex  # local import: only used to render the command for display

    parser = argparse.ArgumentParser(
        description=description,
        usage=f"{Path(sys.argv[0]).name} {'[--accelerate-config CONFIG] ' if use_accelerate else ''}[ARGS]",
        # Disable prefix matching so the launcher never swallows an abbreviated
        # option (e.g. ``--acc``) that was meant for the target script.
        allow_abbrev=False,
    )
    if use_accelerate:
        parser.add_argument(
            "--accelerate-config", type=str, help="Path to accelerate config file (yaml)", default=None
        )

    # parse_known_args collects everything the launcher does not recognise;
    # those arguments are passed through to the target script.
    args, unknown_args = parser.parse_known_args()

    # Base command
    if use_accelerate:
        cmd = ["accelerate", "launch"]
        if args.accelerate_config:
            cmd.extend(["--config_file", args.accelerate_config])
    else:
        cmd = [sys.executable]

    # Absolute path so the child does not depend on the current working directory.
    cmd.append(str(Path(script_module.__file__).resolve()))
    cmd.extend(unknown_args)

    # Print a copy-pasteable command; flush so it is not reordered after the
    # child's output when stdout is a pipe.
    print(f"Executing: {shlex.join(cmd)}", flush=True)

    # Run the command as a child process (the launcher is not replaced) and
    # mirror its exit status.
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        # A negative return code means "killed by signal N"; report 128 + N like a shell.
        sys.exit(e.returncode if e.returncode >= 0 else 128 - e.returncode)
    except FileNotFoundError:
        print(f"Error: could not execute '{cmd[0]}': executable not found.", file=sys.stderr)
        sys.exit(127)
    except KeyboardInterrupt:
        sys.exit(130)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def train():
    """Console entry point ``opentau-train``: run the training script under Accelerate."""
    launch(script_module=train_script, description="Launch OpenTau training with Accelerate")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# NOTE: this name shadows the builtin eval(), but it is referenced by the
# ``opentau-eval`` console-script entry point, so it must keep this name.
def eval():
    """Console entry point ``opentau-eval``: run the evaluation script under Accelerate."""
    launch(script_module=eval_script, description="Launch OpenTau evaluation with Accelerate")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def export():
    """Console entry point ``opentau-export``: run the ONNX export script with plain Python."""
    launch(
        script_module=export_script,
        description="Launch OpenTau ONNX export",
        use_accelerate=False,
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def visualize():
    """Console entry point ``opentau-dataset-viz``: run the dataset visualizer with plain Python."""
    launch(
        script_module=visualize_script,
        description="Launch OpenTau visualization",
        use_accelerate=False,
    )
|
|
@@ -73,16 +73,16 @@ def update_policy(
|
|
|
73
73
|
train_config.loss_weighting["MSE"] * losses["MSE"] + train_config.loss_weighting["CE"] * losses["CE"]
|
|
74
74
|
)
|
|
75
75
|
|
|
76
|
-
|
|
77
|
-
|
|
76
|
+
accelerator.backward(loss)
|
|
77
|
+
accelerator.unscale_gradients(optimizer=optimizer)
|
|
78
78
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
79
|
+
if accelerator.sync_gradients:
|
|
80
|
+
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
|
81
|
+
if accelerator.is_main_process:
|
|
82
|
+
train_metrics.grad_norm = grad_norm
|
|
83
83
|
|
|
84
|
-
|
|
85
|
-
|
|
84
|
+
optimizer.step()
|
|
85
|
+
optimizer.zero_grad()
|
|
86
86
|
|
|
87
87
|
# Step through pytorch scheduler at every batch instead of epoch
|
|
88
88
|
if lr_scheduler is not None:
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
15
|
# See the License for the specific language governing permissions and
|
|
16
16
|
# limitations under the License.
|
|
17
|
-
"""
|
|
17
|
+
"""Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
|
18
18
|
|
|
19
19
|
Note: The last frame of the episode doesn't always correspond to a final state.
|
|
20
20
|
That's because our datasets are composed of transitions from state to state up to
|
|
@@ -30,34 +30,21 @@ Examples:
|
|
|
30
30
|
|
|
31
31
|
- Visualize data stored on a local machine:
|
|
32
32
|
```
|
|
33
|
-
local$
|
|
34
|
-
--repo-id lerobot/pusht \
|
|
35
|
-
--episode-index 0
|
|
33
|
+
local$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0
|
|
36
34
|
```
|
|
37
35
|
|
|
38
36
|
- Visualize data stored on a distant machine with a local viewer:
|
|
39
37
|
```
|
|
40
|
-
distant$
|
|
41
|
-
--repo-id lerobot/pusht \
|
|
42
|
-
--episode-index 0 \
|
|
43
|
-
--save 1 \
|
|
44
|
-
--output-dir path/to/directory
|
|
38
|
+
distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --save 1 --output-dir path/to/directory
|
|
45
39
|
|
|
46
40
|
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
|
|
47
41
|
local$ rerun lerobot_pusht_episode_0.rrd
|
|
48
42
|
```
|
|
49
43
|
|
|
50
44
|
- Visualize data stored on a distant machine through streaming:
|
|
51
|
-
(You need to forward the websocket port to the distant machine, with
|
|
52
|
-
`ssh -L 9087:localhost:9087 username@remote-host`)
|
|
53
45
|
```
|
|
54
|
-
distant$ python src/opentau/scripts/visualize_dataset.py \
|
|
55
|
-
--repo-id lerobot/pusht \
|
|
56
|
-
--episode-index 0 \
|
|
57
|
-
--mode distant \
|
|
58
|
-
--ws-port 9087
|
|
59
46
|
|
|
60
|
-
|
|
47
|
+
distant$ opentau-dataset-viz --repo-id lerobot/pusht --episode-index 0 --mode distant --web-port 9090
|
|
61
48
|
```
|
|
62
49
|
|
|
63
50
|
"""
|
|
@@ -75,8 +62,34 @@ import torch
|
|
|
75
62
|
import torch.utils.data
|
|
76
63
|
import tqdm
|
|
77
64
|
|
|
65
|
+
from opentau.configs.default import DatasetMixtureConfig, WandBConfig
|
|
66
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
78
67
|
from opentau.datasets.lerobot_dataset import LeRobotDataset
|
|
79
|
-
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def create_mock_train_config() -> TrainPipelineConfig:
    """Build a placeholder TrainPipelineConfig used to construct a LeRobotDataset for visualization.

    Every field is a fixed default. Presumably only the fields LeRobotDataset reads
    matter for visualization (TODO: confirm which).

    Returns:
        TrainPipelineConfig: A config populated with the fixed values below.
    """
    settings = dict(
        # An empty mixture; the original comment notes this "will be set by the dataset".
        dataset_mixture=DatasetMixtureConfig(),
        resolution=(224, 224),
        num_cams=2,
        max_state_dim=32,
        max_action_dim=32,
        action_chunk=50,
        loss_weighting={"MSE": 1, "CE": 1},
        num_workers=4,
        batch_size=8,
        steps=100_000,
        log_freq=200,
        save_checkpoint=True,
        save_freq=20_000,
        use_policy_training_preset=True,
        wandb=WandBConfig(),
    )
    return TrainPipelineConfig(**settings)
|
|
80
93
|
|
|
81
94
|
|
|
82
95
|
class EpisodeSampler(torch.utils.data.Sampler):
|
|
@@ -108,7 +121,6 @@ def visualize_dataset(
|
|
|
108
121
|
num_workers: int = 0,
|
|
109
122
|
mode: str = "local",
|
|
110
123
|
web_port: int = 9090,
|
|
111
|
-
ws_port: int = 9087,
|
|
112
124
|
save: bool = False,
|
|
113
125
|
output_dir: Path | None = None,
|
|
114
126
|
) -> Path | None:
|
|
@@ -142,7 +154,7 @@ def visualize_dataset(
|
|
|
142
154
|
gc.collect()
|
|
143
155
|
|
|
144
156
|
if mode == "distant":
|
|
145
|
-
rr.
|
|
157
|
+
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
|
146
158
|
|
|
147
159
|
logging.info("Logging to Rerun")
|
|
148
160
|
|
|
@@ -194,7 +206,7 @@ def visualize_dataset(
|
|
|
194
206
|
print("Ctrl-C received. Exiting.")
|
|
195
207
|
|
|
196
208
|
|
|
197
|
-
def
|
|
209
|
+
def parse_args() -> dict:
|
|
198
210
|
parser = argparse.ArgumentParser()
|
|
199
211
|
|
|
200
212
|
parser.add_argument(
|
|
@@ -250,12 +262,6 @@ def main():
|
|
|
250
262
|
default=9090,
|
|
251
263
|
help="Web port for rerun.io when `--mode distant` is set.",
|
|
252
264
|
)
|
|
253
|
-
parser.add_argument(
|
|
254
|
-
"--ws-port",
|
|
255
|
-
type=int,
|
|
256
|
-
default=9087,
|
|
257
|
-
help="Web socket port for rerun.io when `--mode distant` is set.",
|
|
258
|
-
)
|
|
259
265
|
parser.add_argument(
|
|
260
266
|
"--save",
|
|
261
267
|
type=int,
|
|
@@ -279,15 +285,25 @@ def main():
|
|
|
279
285
|
)
|
|
280
286
|
|
|
281
287
|
args = parser.parse_args()
|
|
282
|
-
|
|
288
|
+
return vars(args)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def main():
    """Command-line entry point: parse arguments, load the dataset, and visualize it."""
    cli_kwargs = parse_args()
    # Dataset-construction options are consumed here; the remaining options are
    # forwarded to visualize_dataset().
    repo_id, root, tolerance_s = (cli_kwargs.pop(key) for key in ("repo_id", "root", "tolerance_s"))

    logging.info("Loading dataset")
    mock_config = create_mock_train_config()
    dataset = LeRobotDataset(mock_config, repo_id, root=root, tolerance_s=tolerance_s, standardize=False)

    visualize_dataset(dataset, **cli_kwargs)
|
|
291
307
|
|
|
292
308
|
|
|
293
309
|
if __name__ == "__main__":
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module for patching transformers
|
|
16
|
+
|
|
17
|
+
Most patches come from the branch fix/lerobot-openpi
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from typing import Optional, Tuple
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
from torch import nn
|
|
24
|
+
from transformers.models.gemma import modeling_gemma
|
|
25
|
+
from transformers.models.gemma.configuration_gemma import GemmaConfig
|
|
26
|
+
from transformers.models.paligemma.modeling_paligemma import PaliGemmaModel
|
|
27
|
+
|
|
28
|
+
# Monkey patch __init__ of GemmaConfig to fix or modify its behavior as needed.
|
|
29
|
+
|
|
30
|
+
# Keep a reference to the unpatched initializer so the wrapper can delegate to it.
_original_gemma_config_init = GemmaConfig.__init__


def patched_gemma_config_init(
    self, *args, use_adarms: bool = False, adarms_cond_dim: Optional[int] = None, **kwargs
):
    """Wrap ``GemmaConfig.__init__`` so it accepts two extra adaptive-RMSNorm (ADARMS) options.

    Args:
        self: The GemmaConfig instance being initialized.
        *args: Positional arguments forwarded unchanged to the original initializer.
        use_adarms: Whether Adaptive RMS normalization is enabled.
        adarms_cond_dim: Width of the ADARMS conditioning vector. When ``use_adarms``
            is true and this is None, it falls back to ``hidden_size``.
        **kwargs: Keyword arguments forwarded unchanged to the original initializer.
    """
    _original_gemma_config_init(self, *args, **kwargs)

    self.use_adarms = use_adarms
    # hidden_size only exists once the original initializer has run.
    needs_default = use_adarms and adarms_cond_dim is None
    self.adarms_cond_dim = self.hidden_size if needs_default else adarms_cond_dim


GemmaConfig.__init__ = patched_gemma_config_init
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# --- Modeling Patches ---
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _gated_residual(x, y, gate):
|
|
65
|
+
"""
|
|
66
|
+
Applies gated residual connection with optional gate parameter.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
x: Input tensor (residual)
|
|
70
|
+
y: Output tensor to be added
|
|
71
|
+
gate: Optional gate tensor to modulate the addition
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
x + y if gate is None, otherwise x + y * gate
|
|
75
|
+
"""
|
|
76
|
+
if x is None and y is None:
|
|
77
|
+
return None
|
|
78
|
+
if x is None or y is None:
|
|
79
|
+
return x if x is not None else y
|
|
80
|
+
if gate is None:
|
|
81
|
+
return x + y
|
|
82
|
+
return x + y * gate
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Install the helper on transformers' Gemma modeling module (presumably so patched
# Gemma layers can call modeling_gemma._gated_residual -- confirm against callers).
modeling_gemma._gated_residual = _gated_residual
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class PatchedGemmaRMSNorm(nn.Module):
    """RMS normalization with optional adaptive support (ADARMS).

    Without ``cond_dim`` this is Gemma-style RMSNorm: inputs are normalized and
    scaled by ``(1 + weight)``. With ``cond_dim`` there is no learned ``weight``.
    Instead a dense layer maps the conditioning tensor to ``scale``, ``shift``
    and ``gate``. ``scale``/``shift`` modulate the normalized input, and ``gate``
    is returned to the caller.
    """

    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
        """Initializes the PatchedGemmaRMSNorm.

        Args:
            dim: The dimension of the input tensor (size of the last axis).
            eps: The epsilon value for numerical stability.
            cond_dim: The dimension of the conditioning vector (if using ADARMS).
        """
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.cond_dim = cond_dim

        # Dense layer for adaptive normalization (if cond_dim is provided).
        if cond_dim is not None:
            # One projection yields scale, shift and gate (3 * dim features).
            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
            # Initialize with zeros (matches source implementation)
            nn.init.zeros_(self.dense.weight)
        else:
            # Zeros give an identity scale because forward multiplies by (1 + weight).
            self.weight = nn.Parameter(torch.zeros(dim))
            self.dense = None

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Applies RMS normalization over the last axis.

        Args:
            x: The input tensor.

        Returns:
            The normalized tensor. It is float32 when x is half precision,
            because the float32 variance promotes the product.
        """
        # Compute variance in float32 (like the source implementation)
        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
        # `var` is float32, so this product is promoted to float32.
        normed_inputs = x * torch.rsqrt(var + self.eps)
        return normed_inputs

    def forward(
        self, x: torch.Tensor, cond: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass of the normalization layer.

        Args:
            x: The input tensor; the last axis has size ``dim``.
            cond: The conditioning tensor for adaptive normalization. Its last
                axis must have size ``cond_dim``. Leading axes are aligned with
                the leading axes of ``x``; missing axes are inserted after the
                first (batch) axis so that e.g. a ``[batch, cond_dim]`` cond
                broadcasts over ``x`` of shape ``[batch, seq, dim]``. Ignored
                for non-adaptive norms.

        Returns:
            A tuple containing the normalized tensor (in x's dtype) and the gate
            tensor. The gate is None for non-adaptive norms.

        Raises:
            ValueError: If the norm is adaptive and cond is None, or if the cond
                dimension does not match the configured cond_dim.
        """
        dtype = x.dtype  # original dtype, could be half-precision
        normed_inputs = self._norm(x)

        if self.dense is None:
            # regular RMSNorm; scale by learned parameter in float32 (matches source implementation)
            normed_inputs = normed_inputs * (1.0 + self.weight.float())
            return normed_inputs.to(dtype), None  # return in original dtype with None gate

        # Adaptive RMSNorm: there is no learned `weight` to fall back on, so a
        # conditioning tensor is mandatory.
        if cond is None:
            raise ValueError(
                f"Adaptive RMSNorm (cond_dim={self.cond_dim}) requires a conditioning tensor, got cond=None"
            )
        if cond.shape[-1] != self.cond_dim:
            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")

        modulation = self.dense(cond)
        # Add singleton axes after the batch axis until the ranks match, e.g.
        # [batch, 3*dim] -> [batch, 1, 3*dim] for x of shape [batch, seq, dim].
        # A cond that already carries per-token axes is left as-is.
        while modulation.dim() < x.dim():
            modulation = modulation.unsqueeze(1)

        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)

        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)

        return normed_inputs.to(dtype), gate.to(dtype)

    def extra_repr(self) -> str:
        """Returns the extra representation of the module."""
        # Adaptive norms have no `weight`, so format the shape from `dim`.
        # This is the same text tuple(self.weight.shape) produced for plain norms.
        repr_str = f"{(self.dim,)}, eps={self.eps}"
        if self.dense is not None:
            repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
        return repr_str
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Apply patches: replace transformers' GemmaRMSNorm with the ADARMS-capable version.
# Code that looks up `modeling_gemma.GemmaRMSNorm` at call time (as the patched
# __init__ functions below do) gets the new class. Names imported directly from
# the module before this line still refer to the original class.
modeling_gemma.GemmaRMSNorm = PatchedGemmaRMSNorm
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def patched_gemma_decoder_layer_init(self, config: GemmaConfig, layer_idx: int):
    """Builds a GemmaDecoderLayer whose two norms may be adaptive (ADARMS).

    When ``config.use_adarms`` is truthy, both norms receive
    ``cond_dim=config.adarms_cond_dim`` (``None`` if that attribute is absent);
    otherwise they are created as plain RMSNorms.

    Args:
        self: The GemmaDecoderLayer instance being initialized.
        config: The Gemma configuration object.
        layer_idx: Index of this layer within the decoder stack.
    """
    modeling_gemma.GradientCheckpointingLayer.__init__(self)
    self.hidden_size = config.hidden_size

    # Submodules are created in the original order (attention, MLP, norms).
    self.self_attn = modeling_gemma.GemmaAttention(config=config, layer_idx=layer_idx)
    self.mlp = modeling_gemma.GemmaMLP(config)

    adaptive_dim = None
    if getattr(config, "use_adarms", False):
        adaptive_dim = getattr(config, "adarms_cond_dim", None)

    norm_kwargs = dict(eps=config.rms_norm_eps, cond_dim=adaptive_dim)
    self.input_layernorm = modeling_gemma.GemmaRMSNorm(config.hidden_size, **norm_kwargs)
    self.post_attention_layernorm = modeling_gemma.GemmaRMSNorm(config.hidden_size, **norm_kwargs)


modeling_gemma.GemmaDecoderLayer.__init__ = patched_gemma_decoder_layer_init
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def patched_gemma_model_init(self, config: GemmaConfig):
    """Builds a GemmaModel whose decoder layers and final norm may use ADARMS.

    The final norm is adaptive only when ``config.use_adarms`` is truthy, in
    which case it uses ``config.adarms_cond_dim`` (``None`` if absent).

    Args:
        self: The GemmaModel instance being initialized.
        config: The Gemma configuration object.
    """
    modeling_gemma.GemmaPreTrainedModel.__init__(self, config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    decoder_layers = [
        modeling_gemma.GemmaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
    ]
    self.layers = nn.ModuleList(decoder_layers)

    final_cond_dim = None
    if getattr(config, "use_adarms", False):
        final_cond_dim = getattr(config, "adarms_cond_dim", None)
    self.norm = modeling_gemma.GemmaRMSNorm(
        config.hidden_size, eps=config.rms_norm_eps, cond_dim=final_cond_dim
    )
    self.rotary_emb = modeling_gemma.GemmaRotaryEmbedding(config=config)
    self.gradient_checkpointing = False

    # Hand over to transformers for weight initialization and final processing.
    self.post_init()


modeling_gemma.GemmaModel.__init__ = patched_gemma_model_init
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def patched_gemma_pretrained_model_init_weights(self, module: nn.Module):
    """Initializes the weights of one submodule of a GemmaPreTrainedModel.

    Installed below as ``GemmaPreTrainedModel._init_weights``.

    Args:
        self: The GemmaPreTrainedModel instance (supplies ``config.initializer_range``).
        module: The module whose parameters are initialized in place.
    """
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, modeling_gemma.GemmaRMSNorm):
        # The norm scales by (1 + weight), so zeros are the identity init. This
        # is the same value the norm's constructor uses; filling with 1.0 would
        # start every norm at a 2x scale. Adaptive norms have no `weight`.
        # NOTE(review): an adaptive norm's `dense` is an nn.Linear, so if this
        # hook is applied to it, it gets the normal init above rather than the
        # zero weight set in the norm's constructor. Confirm that is intended.
        if hasattr(module, "weight"):
            module.weight.data.zero_()


modeling_gemma.GemmaPreTrainedModel._init_weights = patched_gemma_pretrained_model_init_weights
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def patched_paligemma_model_get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Encodes images with the vision tower and projects them with the multimodal projector.

    The projector output is returned directly; no extra scaling is applied here.

    Args:
        self: The PaliGemmaModel instance.
        pixel_values: The tensors corresponding to the input images.
            Shape: (batch_size, channels, height, width).

    Returns:
        Image feature tensor of shape (num_images, image_length, embed_dim).
    """
    vision_hidden = self.vision_tower(pixel_values).last_hidden_state
    return self.multi_modal_projector(vision_hidden)


PaliGemmaModel.get_image_features = patched_paligemma_model_get_image_features
|