opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utility functions and helpers for the OpenTau package.
|
|
16
|
+
|
|
17
|
+
This module provides various utility functions for logging, random number generation,
|
|
18
|
+
training utilities, device management, and other common operations used throughout
|
|
19
|
+
the OpenTau codebase.
|
|
20
|
+
"""
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Utilities for managing Accelerate accelerator instances.
|
|
15
|
+
|
|
16
|
+
This module provides functions for setting and getting a global accelerator
|
|
17
|
+
instance, which is useful for accessing accelerator state throughout the codebase.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import warnings
|
|
21
|
+
|
|
22
|
+
from accelerate import Accelerator
|
|
23
|
+
|
|
24
|
+
# Process-global Accelerator singleton; populated via set_proc_accelerator()
# and read back via get_proc_accelerator(). None until explicitly set.
_acc: Accelerator | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def set_proc_accelerator(accelerator: Accelerator, allow_reset: bool = False) -> None:
    """Register the global accelerator instance for this process.

    Args:
        accelerator: Accelerator instance to install as the process-wide singleton.
        allow_reset: If True, permit replacing an accelerator that was already
            set (a warning is emitted). Defaults to False.

    Raises:
        AssertionError: If ``accelerator`` is not an ``Accelerator`` instance.
        RuntimeError: If an accelerator is already set and ``allow_reset`` is False.
    """
    global _acc

    assert isinstance(accelerator, Accelerator), (
        f"Expected an `Accelerator` got {type(accelerator)} with value {accelerator}."
    )
    # Guard against silent replacement of an already-installed accelerator.
    if _acc is not None and not allow_reset:
        raise RuntimeError("Accelerator has already been set.")
    if _acc is not None:
        warnings.warn(
            "Resetting the accelerator. This could have unintended side effects.",
            UserWarning,
            stacklevel=2,
        )
    _acc = accelerator
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_proc_accelerator() -> Accelerator | None:
    """Get the global accelerator instance for the current process.

    Returns:
        The accelerator instance previously installed with
        :func:`set_proc_accelerator`, or None if none has been set.
    """
    # Annotation fixed to `Accelerator | None`: the unset case legitimately
    # returns None (callers such as acc_print check for it).
    return _acc
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def acc_print(*args, **kwargs) -> None:
    """Print, prefixing the output with the accelerate process index when available.

    When a global accelerator has been registered via
    :func:`set_proc_accelerator`, the output is prefixed with the current
    process index out of the total process count; otherwise this behaves
    exactly like the builtin :func:`print`.

    Args:
        *args: Positional arguments forwarded to print.
        **kwargs: Keyword arguments forwarded to print.
    """
    acc = get_proc_accelerator()
    if acc is not None:
        print(f"Acc[{acc.process_index} of {acc.num_processes}]", *args, **kwargs)
    else:
        print(*args, **kwargs)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
4
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
"""Utilities for benchmarking and timing code execution.
|
|
18
|
+
|
|
19
|
+
This module provides the TimeBenchmark class for measuring execution time
|
|
20
|
+
using context managers or decorators in a thread-safe manner.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import threading
|
|
24
|
+
import time
|
|
25
|
+
from contextlib import ContextDecorator
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    This class supports both context manager and decorator usage, and is thread-safe
    for multithreaded environments: timings are stored in ``threading.local`` so
    concurrent measurements in different threads do not clobber each other.

    Args:
        print: If True, prints the elapsed time upon exiting the context or completing
            the function. Defaults to False.

    Examples:

        Using as a context manager::

            >>> benchmark = TimeBenchmark()
            >>> with benchmark:
            ...     time.sleep(1)
            >>> print(f"Block took {benchmark.result:.4f} seconds")
            Block took approximately 1.0000 seconds

        Using with multithreading::

            import threading

            benchmark = TimeBenchmark()

            def context_manager_example():
                with benchmark:
                    time.sleep(0.01)
                print(f"Block took {benchmark.result_ms:.2f} milliseconds")

            threads = [threading.Thread(target=context_manager_example) for _ in range(3)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
    """

    def __init__(self, print=False):
        # NOTE: the parameter name `print` intentionally matches the original
        # public keyword interface even though it shadows the builtin locally.
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        # Per-thread start timestamp; perf_counter is monotonic and high-resolution.
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        # Returning False never suppresses exceptions raised in the timed block.
        return False

    @property
    def result(self):
        """Elapsed seconds of this thread's last measurement, or None if none recorded."""
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Elapsed milliseconds of this thread's last measurement, or None if none recorded.

        Bug fix: previously this computed ``self.result * 1e3`` unconditionally,
        raising ``TypeError`` when accessed before any measurement. It now
        mirrors :attr:`result` and returns None in that case.
        """
        elapsed = self.result
        return None if elapsed is None else elapsed * 1e3
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Utilities for working with PyTorch FakeTensor.
|
|
15
|
+
|
|
16
|
+
This module provides a FakeTensorContext class and decorator for running code
|
|
17
|
+
with FakeTensor mode enabled, which is useful for shape inference and testing
|
|
18
|
+
without actual tensor computations.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import functools
|
|
22
|
+
|
|
23
|
+
from torch._subclasses import FakeTensorMode
|
|
24
|
+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
25
|
+
|
|
26
|
+
from opentau.utils.monkey_patch import (
|
|
27
|
+
torch_fake_tensor_beta_validate_args_patch,
|
|
28
|
+
torch_fake_tensor_is_inf_patch,
|
|
29
|
+
torch_fake_tensor_module_to_patch,
|
|
30
|
+
torch_fake_tensor_to_numpy_patch,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Share one ShapeEnv instance across all FakeTensorContext instances.
# Without this, each FakeTensor.item() call would start symbol numbering
# from 0 again, which is wrong.
_shared_shape_env = ShapeEnv()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class FakeTensorContext:
    """Context manager that enables FakeTensor mode with the required patches.

    On construction, all monkey patches needed for FakeTensor compatibility are
    applied (in a fixed order), and a ``FakeTensorMode`` sharing the module-wide
    ``ShapeEnv`` is created. Entering/exiting this object simply delegates to
    that mode.

    Args:
        allow_non_fake_inputs: If True, allow non-fake tensors as inputs.
            Defaults to True.
    """

    def __init__(self, allow_non_fake_inputs: bool = True):
        self.mode = FakeTensorMode(
            shape_env=_shared_shape_env,
            allow_non_fake_inputs=allow_non_fake_inputs,
        )
        # Apply every compatibility patch, preserving the original order.
        for apply_patch in (
            torch_fake_tensor_module_to_patch,
            torch_fake_tensor_to_numpy_patch,
            torch_fake_tensor_beta_validate_args_patch,
            torch_fake_tensor_is_inf_patch,
        ):
            apply_patch()

    def __enter__(self):
        return self.mode.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.mode.__exit__(exc_type, exc_val, exc_tb)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def run_with_fake_tensor(fn):
    """Decorator that executes ``fn`` inside a :class:`FakeTensorContext`.

    Args:
        fn: Function to wrap.

    Returns:
        A wrapper with the same metadata as ``fn`` that enters FakeTensor mode
        before delegating the call.
    """

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        ctx = FakeTensorContext()
        with ctx:
            return fn(*args, **kwargs)

    return inner
|
opentau/utils/hub.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""Utilities for interacting with the Hugging Face Hub.
|
|
16
|
+
|
|
17
|
+
This module provides the HubMixin class which enables objects to be saved to
|
|
18
|
+
and loaded from the Hugging Face Hub, similar to ModelHubMixin but with fewer
|
|
19
|
+
assumptions about the object type.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from tempfile import TemporaryDirectory
|
|
24
|
+
from typing import Any, Type, TypeVar
|
|
25
|
+
|
|
26
|
+
from huggingface_hub import HfApi
|
|
27
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
|
28
|
+
|
|
29
|
+
T = TypeVar("T", bound="HubMixin")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class HubMixin:
    """
    Mixin giving an object save/load/push functionality for the Hugging Face Hub.

    Comparable to ``huggingface_hub.ModelHubMixin`` but lighter-weight: it makes
    no assumption that the subclass is a model.

    Subclasses must implement ``_save_pretrained`` and ``from_pretrained``.
    """

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        card_kwargs: dict[str, Any] | None = None,
        **push_to_hub_kwargs,
    ) -> str | None:
        """
        Serialize the object into a local directory, optionally uploading it.

        Args:
            save_directory (`str` or `Path`):
                Destination directory; created (with parents) if missing.
            repo_id (`str`, *optional*):
                Hub repository ID, used only when `push_to_hub=True`. Falls back
                to the directory name when not provided.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                When True, also upload the saved object to the Hub.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Extra values forwarded to the card template.
            push_to_hub_kwargs:
                Forwarded to [`~HubMixin.push_to_hub`].
        Returns:
            `str` or `None`: commit URL on the Hub if `push_to_hub=True`, else `None`.
        """
        target = Path(save_directory)
        target.mkdir(parents=True, exist_ok=True)

        # Actual serialization is delegated to the subclass hook.
        self._save_pretrained(target)

        if not push_to_hub:
            return None
        # Default the repo name to the folder name when the caller gave none.
        effective_repo_id = target.name if repo_id is None else repo_id
        return self.push_to_hub(repo_id=effective_repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Subclass hook: write this object's files into ``save_directory``.

        Args:
            save_directory (`str` or `Path`):
                Directory that will receive the object files.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> T:
        """
        Subclass hook: download the object from the Hub (or load it from disk)
        and instantiate it.

        Args:
            pretrained_name_or_path (`str`, `Path`):
                - Either the `repo_id` (string) of the object hosted on the Hub, e.g. `lerobot/diffusion_pusht`.
                - Or a path to a `directory` containing the object files saved using `.save_pretrained`,
                  e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Branch name, tag, or commit id on the Hub; latest `main` commit by default.
            force_download (`bool`, *optional*, defaults to `False`):
                Re-download files even when a cached copy exists.
            proxies (`Dict[str, str]`, *optional*):
                Per-protocol/per-endpoint proxy servers applied to every request,
                e.g. `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`.
            token (`str` or `bool`, *optional*):
                HTTP bearer token for remote files; defaults to the token cached
                by `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, skip downloading and use only the local cache.
            kwargs (`Dict`, *optional*):
                Extra kwargs forwarded to the object during initialization.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        commit_message: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        delete_patterns: list[str] | str | None = None,
        card_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """
        Upload this object to the Hub as a single commit.

        `allow_patterns` / `ignore_patterns` filter which files are pushed, and
        `delete_patterns` removes matching remote files in the same commit; see
        [`upload_folder`] for details.

        Args:
            repo_id (`str`):
                Target repository ID (example: `"username/my-model"`).
            commit_message (`str`, *optional*):
                Commit message; a default is derived from the class name if omitted.
            private (`bool`, *optional*):
                Whether a newly created repository should be private. If `None`
                (default), visibility follows the organization's default.
            token (`str`, *optional*):
                HTTP bearer token; defaults to the token cached by
                `huggingface-cli login`.
            branch (`str`, *optional*):
                Git branch to push to; `"main"` by default.
            create_pr (`boolean`, *optional*):
                Open a Pull Request from `branch` with the commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                Only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                Files matching any pattern are skipped.
            delete_patterns (`List[str]` or `str`, *optional*):
                Remote files matching any pattern are deleted in the commit.
            card_kwargs (`Dict[str, Any]`, *optional*):
                Extra values forwarded to the card template.

        Returns:
            The url of the commit of your object in the given repository.
        """
        api = HfApi(token=token)
        # create_repo is idempotent (exist_ok=True) and normalizes the repo_id.
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        if commit_message is None:
            name = self.__class__.__name__
            commit_message = (
                "Upload policy"
                if "Policy" in name
                else "Upload config"
                if "Config" in name
                else f"Upload {self.__class__.__name__}"
            )

        # Save into a scratch directory, then upload everything in one commit.
        with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
            saved_path = Path(tmp) / repo_id
            self.save_pretrained(saved_path, card_kwargs=card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
4
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
"""Utilities for checking package availability and versions.
|
|
18
|
+
|
|
19
|
+
This module provides functions to check if packages are installed and optionally
|
|
20
|
+
retrieve their versions without importing them, which is useful for conditional
|
|
21
|
+
imports and dependency checking.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import importlib
import importlib.metadata
import importlib.util
import logging
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
|
29
|
+
"""Check if a package is available and optionally return its version.
|
|
30
|
+
|
|
31
|
+
Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
|
32
|
+
|
|
33
|
+
This function checks if the package spec exists and grabs its version to
|
|
34
|
+
avoid importing a local directory. Note: this doesn't work for all packages.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
pkg_name: Name of the package to check.
|
|
38
|
+
return_version: If True, return a tuple of (available, version).
|
|
39
|
+
If False, return only the availability boolean. Defaults to False.
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
If return_version is False, returns a boolean indicating availability.
|
|
43
|
+
If return_version is True, returns a tuple of (available, version).
|
|
44
|
+
"""
|
|
45
|
+
package_exists = importlib.util.find_spec(pkg_name) is not None
|
|
46
|
+
package_version = "N/A"
|
|
47
|
+
if package_exists:
|
|
48
|
+
try:
|
|
49
|
+
# Primary method to get the package version
|
|
50
|
+
package_version = importlib.metadata.version(pkg_name)
|
|
51
|
+
except importlib.metadata.PackageNotFoundError:
|
|
52
|
+
# Fallback method: Only for "torch" and versions containing "dev"
|
|
53
|
+
if pkg_name == "torch":
|
|
54
|
+
try:
|
|
55
|
+
package = importlib.import_module(pkg_name)
|
|
56
|
+
temp_version = getattr(package, "__version__", "N/A")
|
|
57
|
+
# Check if the version contains "dev"
|
|
58
|
+
if "dev" in temp_version:
|
|
59
|
+
package_version = temp_version
|
|
60
|
+
package_exists = True
|
|
61
|
+
else:
|
|
62
|
+
package_exists = False
|
|
63
|
+
except ImportError:
|
|
64
|
+
# If the package can't be imported, it's not available
|
|
65
|
+
package_exists = False
|
|
66
|
+
else:
|
|
67
|
+
# For packages other than "torch", don't attempt the fallback and set as not available
|
|
68
|
+
package_exists = False
|
|
69
|
+
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
|
70
|
+
if return_version:
|
|
71
|
+
return package_exists, package_version
|
|
72
|
+
else:
|
|
73
|
+
return package_exists
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# Probe optional dependencies once at import time so call sites can branch on
# cheap module-level flags instead of re-resolving package specs repeatedly.
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")
|