opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,20 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utility functions and helpers for the OpenTau package.
16
+
17
+ This module provides various utility functions for logging, random number generation,
18
+ training utilities, device management, and other common operations used throughout
19
+ the OpenTau codebase.
20
+ """
@@ -0,0 +1,79 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utilities for managing Accelerate accelerator instances.
15
+
16
+ This module provides functions for setting and getting a global accelerator
17
+ instance, which is useful for accessing accelerator state throughout the codebase.
18
+ """
19
+
20
+ import warnings
21
+
22
+ from accelerate import Accelerator
23
+
24
+ _acc: Accelerator | None = None
25
+
26
+
27
def set_proc_accelerator(accelerator: Accelerator, allow_reset: bool = False) -> None:
    """Register ``accelerator`` as the global accelerator for this process.

    Args:
        accelerator: The Accelerator instance to store globally.
        allow_reset: When True, an already-registered accelerator may be
            replaced (a UserWarning is emitted). Defaults to False.

    Raises:
        AssertionError: If ``accelerator`` is not an `Accelerator` instance.
        RuntimeError: If an accelerator was already registered and
            ``allow_reset`` is False.
    """
    global _acc

    assert isinstance(accelerator, Accelerator), (
        f"Expected an `Accelerator` got {type(accelerator)} with value {accelerator}."
    )
    already_set = _acc is not None
    if already_set and not allow_reset:
        raise RuntimeError("Accelerator has already been set.")
    if already_set:
        # Resetting is allowed but discouraged; make it loud.
        warnings.warn(
            "Resetting the accelerator. This could have unintended side effects.",
            UserWarning,
            stacklevel=2,
        )
    _acc = accelerator
54
+
55
+
56
def get_proc_accelerator() -> Accelerator | None:
    """Get the global accelerator instance for the current process.

    Returns:
        The accelerator registered via ``set_proc_accelerator``, or ``None``
        if no accelerator has been set.
    """
    # Fix: the return annotation is `Accelerator | None`, not `Accelerator` —
    # before `set_proc_accelerator` is called, the module-level `_acc` is
    # still None, and callers (e.g. `acc_print`) rely on that None.
    return _acc
63
+
64
+
65
def acc_print(*args, **kwargs) -> None:
    """Print, prefixing output with the accelerate process index when available.

    Falls back to a plain ``print`` when no global accelerator has been set.

    Args:
        *args: Positional arguments forwarded to ``print``.
        **kwargs: Keyword arguments forwarded to ``print``.
    """
    acc = get_proc_accelerator()
    if acc is not None:
        print(f"Acc[{acc.process_index} of {acc.num_processes}]", *args, **kwargs)
        return
    print(*args, **kwargs)
@@ -0,0 +1,98 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Utilities for benchmarking and timing code execution.
18
+
19
+ This module provides the TimeBenchmark class for measuring execution time
20
+ using context managers or decorators in a thread-safe manner.
21
+ """
22
+
23
+ import threading
24
+ import time
25
+ from contextlib import ContextDecorator
26
+
27
+
28
class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    This class supports both context manager and decorator usage, and is thread-safe for multithreaded
    environments: timings are stored in thread-local state, so concurrent measurements on different
    threads do not clobber each other.

    Args:
        print: If True, prints the elapsed time upon exiting the context or completing the function.
            Defaults to False.

    Examples:

        Using as a context manager::

            >>> benchmark = TimeBenchmark()
            >>> with benchmark:
            ...     time.sleep(1)
            >>> print(f"Block took {benchmark.result:.4f} seconds")
            Block took approximately 1.0000 seconds

        Using with multithreading::

            import threading

            benchmark = TimeBenchmark()

            def context_manager_example():
                with benchmark:
                    time.sleep(0.01)
                print(f"Block took {benchmark.result_ms:.2f} milliseconds")

            threads = []
            for _ in range(3):
                t1 = threading.Thread(target=context_manager_example)
                threads.append(t1)

            for t in threads:
                t.start()

            for t in threads:
                t.join()

            # Expected output:
            # Block took approximately 10.00 milliseconds
            # Block took approximately 10.00 milliseconds
            # Block took approximately 10.00 milliseconds
    """

    def __init__(self, print=False):  # noqa: A002 -- `print` kept for backward compatibility with callers
        # Thread-local storage: each thread records its own start/end/elapsed.
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        # Returning False never suppresses exceptions raised in the block.
        return False

    @property
    def result(self):
        """Elapsed seconds of this thread's last measurement, or None if none exists."""
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Elapsed milliseconds of this thread's last measurement, or None if none exists.

        Bug fix: previously this raised ``TypeError`` (``None * 1e3``) when read
        before any measurement was recorded on the current thread, whereas
        ``result`` deliberately returns None in that case. The two properties
        now behave consistently.
        """
        elapsed = self.result
        return None if elapsed is None else elapsed * 1e3
@@ -0,0 +1,81 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utilities for working with PyTorch FakeTensor.
15
+
16
+ This module provides a FakeTensorContext class and decorator for running code
17
+ with FakeTensor mode enabled, which is useful for shape inference and testing
18
+ without actual tensor computations.
19
+ """
20
+
21
+ import functools
22
+
23
+ from torch._subclasses import FakeTensorMode
24
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
25
+
26
+ from opentau.utils.monkey_patch import (
27
+ torch_fake_tensor_beta_validate_args_patch,
28
+ torch_fake_tensor_is_inf_patch,
29
+ torch_fake_tensor_module_to_patch,
30
+ torch_fake_tensor_to_numpy_patch,
31
+ )
32
+
33
# Share one ShapeEnv instance across all FakeTensorContext instances.
# Without a shared ShapeEnv, each FakeTensor.item() call would restart
# symbol numbering from 0, which is wrong.
_shared_shape_env = ShapeEnv()
36
+
37
+
38
class FakeTensorContext:
    """Context manager that activates FakeTensor mode with the required patches.

    On construction, applies the monkey patches needed for FakeTensor
    compatibility and builds a ``FakeTensorMode`` backed by the shared
    ``ShapeEnv``. Entering/exiting the context delegates to that mode.

    Args:
        allow_non_fake_inputs: If True, allow non-fake tensors as inputs.
            Defaults to True.
    """

    def __init__(self, allow_non_fake_inputs: bool = True):
        self.mode = FakeTensorMode(
            shape_env=_shared_shape_env,
            allow_non_fake_inputs=allow_non_fake_inputs,
        )
        # Apply every FakeTensor compatibility patch up front, in order.
        for apply_patch in (
            torch_fake_tensor_module_to_patch,
            torch_fake_tensor_to_numpy_patch,
            torch_fake_tensor_beta_validate_args_patch,
            torch_fake_tensor_is_inf_patch,
        ):
            apply_patch()

    def __enter__(self):
        return self.mode.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.mode.__exit__(exc_type, exc_val, exc_tb)
64
+
65
+
66
def run_with_fake_tensor(fn):
    """Wrap ``fn`` so it always executes inside a ``FakeTensorContext``.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature that enters FakeTensor mode
        before delegating to ``fn``.
    """

    @functools.wraps(fn)
    def _inside_fake_mode(*args, **kwargs):
        ctx = FakeTensorContext()
        with ctx:
            return fn(*args, **kwargs)

    return _inside_fake_mode
opentau/utils/hub.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities for interacting with the Hugging Face Hub.
16
+
17
+ This module provides the HubMixin class which enables objects to be saved to
18
+ and loaded from the Hugging Face Hub, similar to ModelHubMixin but with fewer
19
+ assumptions about the object type.
20
+ """
21
+
22
+ from pathlib import Path
23
+ from tempfile import TemporaryDirectory
24
+ from typing import Any, Type, TypeVar
25
+
26
+ from huggingface_hub import HfApi
27
+ from huggingface_hub.utils import validate_hf_hub_args
28
+
29
+ T = TypeVar("T", bound="HubMixin")
30
+
31
+
32
class HubMixin:
    """
    A Mixin adding the functionality to push an object to the Hugging Face Hub.

    Lighter-weight analogue of ``huggingface_hub.ModelHubMixin`` that makes
    fewer assumptions about its subclasses (in particular, the object is not
    necessarily a model).

    The inheriting classes must implement '_save_pretrained' and 'from_pretrained'.
    """

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        card_kwargs: dict[str, Any] | None = None,
        **push_to_hub_kwargs,
    ) -> str | None:
        """
        Save the object in a local directory, optionally pushing it to the Hub.

        Args:
            save_directory (`str` or `Path`):
                Path to the directory in which the object will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether to push the saved object to the Hugging Face Hub.
            repo_id (`str`, *optional*):
                ID of the Hub repository. Used only if `push_to_hub=True`.
                Defaults to the folder name when not provided.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the card template to customize the card.
            push_to_hub_kwargs:
                Additional keyword arguments forwarded to [`~HubMixin.push_to_hub`].
        Returns:
            `str` or `None`: URL of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        target = Path(save_directory)
        target.mkdir(parents=True, exist_ok=True)

        # Delegate the actual serialization (weights, files, etc.) to the subclass.
        self._save_pretrained(target)

        if not push_to_hub:
            return None
        # Default the repo id to the directory name when not given.
        effective_repo_id = repo_id if repo_id is not None else target.name
        return self.push_to_hub(repo_id=effective_repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Serialize the object into ``save_directory``.

        Overwrite this method in a subclass to define how your object is saved.

        Args:
            save_directory (`str` or `Path`):
                Path to the directory in which the object files will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> T:
        """
        Download the object from the Hugging Face Hub and instantiate it.

        Args:
            pretrained_name_or_path (`str`, `Path`):
                - Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
                - Or a path to a `directory` containing the object files saved using `.save_pretrained`,
                  e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on the `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the files from the Hub, overriding the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the object during initialization.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        commit_message: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        delete_patterns: list[str] | str | None = None,
        card_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """
        Upload the object to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub.
        Use `delete_patterns` to delete existing remote files in the same commit. See the [`upload_folder`]
        reference for more details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether the repository created should be private.
                If `None` (default), the repo will be public unless the organization's default is private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the card template to customize the card.

        Returns:
            The URL of the commit of your object in the given repository.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        if commit_message is None:
            cls_name = self.__class__.__name__
            if "Policy" in cls_name:
                commit_message = "Upload policy"
            elif "Config" in cls_name:
                commit_message = "Upload config"
            else:
                commit_message = f"Upload {cls_name}"

        # Stage the files in a temp dir, then push them in a single commit.
        with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
            staged = Path(tmp) / repo_id
            self.save_pretrained(staged, card_kwargs=card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=staged,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )
@@ -0,0 +1,79 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Utilities for checking package availability and versions.
18
+
19
+ This module provides functions to check if packages are installed and optionally
20
+ retrieve their versions without importing them, which is useful for conditional
21
+ imports and dependency checking.
22
+ """
23
+
24
import importlib
import importlib.metadata
import importlib.util
import logging
26
+
27
+
28
+ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
29
+ """Check if a package is available and optionally return its version.
30
+
31
+ Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
32
+
33
+ This function checks if the package spec exists and grabs its version to
34
+ avoid importing a local directory. Note: this doesn't work for all packages.
35
+
36
+ Args:
37
+ pkg_name: Name of the package to check.
38
+ return_version: If True, return a tuple of (available, version).
39
+ If False, return only the availability boolean. Defaults to False.
40
+
41
+ Returns:
42
+ If return_version is False, returns a boolean indicating availability.
43
+ If return_version is True, returns a tuple of (available, version).
44
+ """
45
+ package_exists = importlib.util.find_spec(pkg_name) is not None
46
+ package_version = "N/A"
47
+ if package_exists:
48
+ try:
49
+ # Primary method to get the package version
50
+ package_version = importlib.metadata.version(pkg_name)
51
+ except importlib.metadata.PackageNotFoundError:
52
+ # Fallback method: Only for "torch" and versions containing "dev"
53
+ if pkg_name == "torch":
54
+ try:
55
+ package = importlib.import_module(pkg_name)
56
+ temp_version = getattr(package, "__version__", "N/A")
57
+ # Check if the version contains "dev"
58
+ if "dev" in temp_version:
59
+ package_version = temp_version
60
+ package_exists = True
61
+ else:
62
+ package_exists = False
63
+ except ImportError:
64
+ # If the package can't be imported, it's not available
65
+ package_exists = False
66
+ else:
67
+ # For packages other than "torch", don't attempt the fallback and set as not available
68
+ package_exists = False
69
+ logging.debug(f"Detected {pkg_name} version: {package_version}")
70
+ if return_version:
71
+ return package_exists, package_version
72
+ else:
73
+ return package_exists
74
+
75
+
76
# Availability flags resolved once at import time so callers can do cheap
# boolean checks instead of re-probing package metadata on every use.
_torch_available, _torch_version = is_package_available("torch", return_version=True)
# Optional simulation environments; each flag gates a conditional import elsewhere.
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")