opentau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/__init__.py +179 -0
- opentau/__version__.py +24 -0
- opentau/configs/__init__.py +19 -0
- opentau/configs/default.py +297 -0
- opentau/configs/libero.py +113 -0
- opentau/configs/parser.py +393 -0
- opentau/configs/policies.py +297 -0
- opentau/configs/reward.py +42 -0
- opentau/configs/train.py +370 -0
- opentau/configs/types.py +76 -0
- opentau/constants.py +52 -0
- opentau/datasets/__init__.py +84 -0
- opentau/datasets/backward_compatibility.py +78 -0
- opentau/datasets/compute_stats.py +333 -0
- opentau/datasets/dataset_mixture.py +460 -0
- opentau/datasets/factory.py +232 -0
- opentau/datasets/grounding/__init__.py +67 -0
- opentau/datasets/grounding/base.py +154 -0
- opentau/datasets/grounding/clevr.py +110 -0
- opentau/datasets/grounding/cocoqa.py +130 -0
- opentau/datasets/grounding/dummy.py +101 -0
- opentau/datasets/grounding/pixmo.py +177 -0
- opentau/datasets/grounding/vsr.py +141 -0
- opentau/datasets/image_writer.py +304 -0
- opentau/datasets/lerobot_dataset.py +1910 -0
- opentau/datasets/online_buffer.py +442 -0
- opentau/datasets/push_dataset_to_hub/utils.py +132 -0
- opentau/datasets/sampler.py +99 -0
- opentau/datasets/standard_data_format_mapping.py +278 -0
- opentau/datasets/transforms.py +330 -0
- opentau/datasets/utils.py +1243 -0
- opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
- opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
- opentau/datasets/v21/_remove_language_instruction.py +109 -0
- opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
- opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
- opentau/datasets/v21/convert_stats.py +150 -0
- opentau/datasets/video_utils.py +597 -0
- opentau/envs/__init__.py +18 -0
- opentau/envs/configs.py +178 -0
- opentau/envs/factory.py +99 -0
- opentau/envs/libero.py +439 -0
- opentau/envs/utils.py +204 -0
- opentau/optim/__init__.py +16 -0
- opentau/optim/factory.py +43 -0
- opentau/optim/optimizers.py +121 -0
- opentau/optim/schedulers.py +140 -0
- opentau/planner/__init__.py +82 -0
- opentau/planner/high_level_planner.py +366 -0
- opentau/planner/utils/memory.py +64 -0
- opentau/planner/utils/utils.py +65 -0
- opentau/policies/__init__.py +24 -0
- opentau/policies/factory.py +172 -0
- opentau/policies/normalize.py +315 -0
- opentau/policies/pi0/__init__.py +19 -0
- opentau/policies/pi0/configuration_pi0.py +250 -0
- opentau/policies/pi0/modeling_pi0.py +994 -0
- opentau/policies/pi0/paligemma_with_expert.py +516 -0
- opentau/policies/pi05/__init__.py +20 -0
- opentau/policies/pi05/configuration_pi05.py +231 -0
- opentau/policies/pi05/modeling_pi05.py +1257 -0
- opentau/policies/pi05/paligemma_with_expert.py +572 -0
- opentau/policies/pretrained.py +315 -0
- opentau/policies/utils.py +123 -0
- opentau/policies/value/__init__.py +18 -0
- opentau/policies/value/configuration_value.py +170 -0
- opentau/policies/value/modeling_value.py +512 -0
- opentau/policies/value/reward.py +87 -0
- opentau/policies/value/siglip_gemma.py +221 -0
- opentau/scripts/actions_mse_loss.py +89 -0
- opentau/scripts/bin_to_safetensors.py +116 -0
- opentau/scripts/compute_max_token_length.py +111 -0
- opentau/scripts/display_sys_info.py +90 -0
- opentau/scripts/download_libero_benchmarks.py +54 -0
- opentau/scripts/eval.py +877 -0
- opentau/scripts/export_to_onnx.py +180 -0
- opentau/scripts/fake_tensor_training.py +87 -0
- opentau/scripts/get_advantage_and_percentiles.py +220 -0
- opentau/scripts/high_level_planner_inference.py +114 -0
- opentau/scripts/inference.py +70 -0
- opentau/scripts/launch_train.py +63 -0
- opentau/scripts/libero_simulation_parallel.py +356 -0
- opentau/scripts/libero_simulation_sequential.py +122 -0
- opentau/scripts/nav_high_level_planner_inference.py +61 -0
- opentau/scripts/train.py +379 -0
- opentau/scripts/visualize_dataset.py +294 -0
- opentau/scripts/visualize_dataset_html.py +507 -0
- opentau/scripts/zero_to_fp32.py +760 -0
- opentau/utils/__init__.py +20 -0
- opentau/utils/accelerate_utils.py +79 -0
- opentau/utils/benchmark.py +98 -0
- opentau/utils/fake_tensor.py +81 -0
- opentau/utils/hub.py +209 -0
- opentau/utils/import_utils.py +79 -0
- opentau/utils/io_utils.py +137 -0
- opentau/utils/libero.py +214 -0
- opentau/utils/libero_dataset_recorder.py +460 -0
- opentau/utils/logging_utils.py +180 -0
- opentau/utils/monkey_patch.py +278 -0
- opentau/utils/random_utils.py +244 -0
- opentau/utils/train_utils.py +198 -0
- opentau/utils/utils.py +471 -0
- opentau-0.1.0.dist-info/METADATA +161 -0
- opentau-0.1.0.dist-info/RECORD +108 -0
- opentau-0.1.0.dist-info/WHEEL +5 -0
- opentau-0.1.0.dist-info/entry_points.txt +2 -0
- opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
- opentau-0.1.0.dist-info/top_level.txt +1 -0
opentau/utils/monkey_patch.py
@@ -0,0 +1,278 @@

# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Monkey patches that inject behaviour into computational graph construction. This eliminates the need to trace
and modify the source code of third-party libraries such as transformers, or even PyTorch itself.

Where necessary, we can do

>>> from opentau.utils.monkey_patch import torch_cumsum_patch, torch_pow_patch
>>> torch_cumsum_patch()  # Apply the patch to handle bool tensors in cumsum
>>> torch_pow_patch()  # Apply the patch to handle mixed number-tensor exponentiation in pow

to apply the patches.

Applying the same monkey patch twice has no effect: the second call returns immediately.

Note: currently there is no way to undo a monkey patch once it has been applied; this could be a future to-do.
For now, only apply a patch when its implications are understood.
"""

import importlib
import logging
import sys
from functools import wraps

import numpy as np
import torch

# global singleton to track which patches have been applied
__patches_applied = set()


def _run_once_only(func):
    """Decorator that ensures the function is run once only.

    Subsequent calls to the function will return immediately without executing
    the function body again.

    Args:
        func: Function to be wrapped.

    Returns:
        Wrapped function that executes only once.
    """

    @wraps(func)
    def inner(*args, **kwargs):
        if func in __patches_applied:
            return

        __patches_applied.add(func)
        logging.debug(f"Applying monkey patch: {func.__name__} in {func.__module__}")
        return func(*args, **kwargs)

    return inner


def patches_applied():
    """Get a list of all patches that have been applied.

    Returns:
        List of patch function names that have been applied.
    """
    return [func.__name__ for func in __patches_applied]


@_run_once_only
def torch_cumsum_patch():
    """Override torch.cumsum to handle bool tensors correctly.

    PyTorch allows cumsum on bool tensors, but ONNX Runtime does not.
    This patch converts bool tensors to int64 before calling cumsum.
    """
    original_cumsum = torch.cumsum

    def _patched_cumsum(tensor, *args, **kwargs):
        if tensor.dtype == torch.bool:
            tensor = tensor.to(torch.int64)
        return original_cumsum(tensor, *args, **kwargs)

    torch.cumsum = _patched_cumsum


@_run_once_only
def torch_pow_patch():
    """Override torch.pow to ensure both base and exponent are tensors.

    This patch converts scalar arguments to tensors before calling pow,
    ensuring compatibility with ONNX export.
    """
    original_pow = torch.pow

    def _patched_pow(base, exponent, *args, **kwargs):
        if not isinstance(base, torch.Tensor):
            base = torch.tensor(base, dtype=exponent.dtype, device=exponent.device)
        if not isinstance(exponent, torch.Tensor):
            exponent = torch.tensor(exponent, dtype=base.dtype, device=base.device)
        return original_pow(base, exponent, *args, **kwargs)

    torch.pow = _patched_pow
    torch.Tensor.pow = _patched_pow
    torch.Tensor.__pow__ = _patched_pow
    # At least for torch 2.7, `__rpow__` is already defined like this, but we ensure it for future compatibility.
    torch.Tensor.__rpow__ = lambda x, y: _patched_pow(y, x)


@_run_once_only
def torch_full_patch():
    """Override torch.full to convert bool fill values to int.

    This patch ensures that True/False are converted to 1/0 before reaching
    the C++ level, improving compatibility with certain backends.
    """
    original_full = torch.full

    def _patched_full(size, fill_value, *args, **kwargs):
        if isinstance(fill_value, bool):
            fill_value = int(fill_value)
        return original_full(size, fill_value, *args, **kwargs)

    torch.full = _patched_full


@_run_once_only
def torch_fake_tensor_module_to_patch():
    """Fix torch.nn.Module.to(device) behavior in FakeTensorMode.

    Without this patch, Module.to(device) is a no-op in FakeTensorMode, leading
    to device mismatch errors. This patch enables proper device conversion.

    See https://github.com/pytorch/pytorch/issues/119665 for more details.
    """
    torch.__future__.set_overwrite_module_params_on_conversion(True)


@_run_once_only
def torch_fake_tensor_to_numpy_patch():
    """Enable .numpy() calls on FakeTensor to return random numpy arrays.

    This patch allows .numpy() to be called on FakeTensor instances, returning
    numpy arrays with random values. Note that calling .numpy() multiple times
    on the same FakeTensor may return different values.
    """
    _torch2np = {
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.float16: np.float16,
        # torch.bfloat16 is intentionally excluded as it is not supported by numpy
        torch.int64: np.int64,
        torch.int32: np.int32,
        torch.int16: np.int16,
        torch.int8: np.int8,
        torch.uint8: np.uint8,
        torch.bool: np.bool_,
    }

    def _patched_numpy(self: torch._subclasses.fake_tensor.FakeTensor, /):
        if self.device.type != "cpu":
            raise RuntimeError(
                f"FakeTensor.numpy() can only be called on CPU tensors. This tensor is on {self.device}"
            )
        if self.requires_grad:
            raise RuntimeError(
                ".numpy() cannot be called on tensors that require gradients. Call tensor.detach().numpy() instead."
            )
        if self.dtype not in _torch2np:
            raise RuntimeError(f"Unsupported dtype {self.dtype} for FakeTensor.numpy()")

        # `np.random.rand()` returns a float instead of a 0-dim array,
        # so we wrap it in np.array() to ensure the shape is preserved.
        return np.array(np.random.rand(*self.shape)).astype(_torch2np[self.dtype])

    torch._subclasses.fake_tensor.FakeTensor.numpy = _patched_numpy


@_run_once_only
def torch_fake_tensor_beta_validate_args_patch():
    """Fix torch.distributions.Beta to work in FakeTensorMode.

    This patch sets validate_args=False by default for Beta distributions,
    which is required for FakeTensorMode compatibility.
    """
    original_beta_init = torch.distributions.Beta.__init__

    def _patched_beta_init(self, *args, **kwargs):
        kwargs.setdefault("validate_args", False)
        original_beta_init(self, *args, **kwargs)

    torch.distributions.Beta.__init__ = _patched_beta_init


@_run_once_only
def torch_fake_tensor_is_inf_patch():
    """Patch torch.isinf to work with FakeTensor.

    This patch provides a mock implementation of torch.isinf that returns
    a mock object compatible with FakeTensor operations.
    """
    from unittest.mock import Mock

    def _patched_isinf(x):
        obj = Mock()
        obj.dtype = torch.bool
        obj.shape = x.shape
        obj.any.return_value = False
        obj.all.return_value = False
        return obj

    torch.isinf = _patched_isinf


@_run_once_only
def gym_is_gymnasium_patch():
    """Monkey patch to make `import gym` equivalent to `import gymnasium as gym`.

    This patch is necessary because the original gym package is incompatible
    with numpy >= 2.0. It redirects gym imports to use gymnasium instead.
    """
    _g = importlib.import_module("gymnasium")
    sys.modules.setdefault("gym", _g)

    # This is a non-exhaustive list. More submodules may be added in the future as needed.
    # A more comprehensive solution would involve a lower-level approach using `finder`s and `loader`s.
    # See https://docs.python.org/3/reference/import.html#finders-and-loaders
    subpackages = [
        "spaces",
        "envs",
        "envs.classic_control",
        "envs.mujoco",
        "envs.toy_text",
        "wrappers",
        "vector",
        "vector.utils",
        "utils",
    ]

    for sub in subpackages:
        try:
            old_name = f"gym.{sub}"
            new_name = f"gymnasium.{sub}"
            if old_name in sys.modules:
                print(f"Module {old_name} already exists in sys.modules, skipping import of {new_name}")
            else:
                # Assuming importing the submodule has no side effects, which should be true for gymnasium
                sys.modules[old_name] = importlib.import_module(new_name)
        except (ImportError, ModuleNotFoundError):
            print("Failed to import gymnasium submodule:", sub, file=sys.stderr)


@_run_once_only
def torch_load_patch():
    """Override torch.load to handle the weights_only argument.

    This patch ensures that torch.load properly handles the weights_only
    argument for PyTorch versions >= 2.6, setting it to False by default
    if not explicitly provided.
    """
    # Compare version components numerically; a plain string comparison
    # would misorder versions such as "2.10" against "2.6".
    if tuple(int(part) for part in torch.__version__.split(".")[:2]) < (2, 6):
        return

    original_load = torch.load

    def _patched_load(*args, weights_only=..., **kwargs):
        kwargs["weights_only"] = False if weights_only is ... else weights_only
        return original_load(*args, **kwargs)

    torch.load = _patched_load
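The FakeTensor patches serve shape-only "fake" runs (see opentau/utils/fake_tensor.py and opentau/scripts/fake_tensor_training.py in the file list). A small sketch, assuming a CPU-only fake run; the variable names are illustrative:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from opentau.utils.monkey_patch import (
    torch_fake_tensor_module_to_patch,
    torch_fake_tensor_to_numpy_patch,
)

torch_fake_tensor_module_to_patch()  # Module.to(device) swaps parameters even under fake mode
torch_fake_tensor_to_numpy_patch()   # FakeTensor.numpy() returns a random array of the right shape

with FakeTensorMode():
    x = torch.empty(2, 3)  # a FakeTensor: carries shape and dtype, but no real storage
    arr = x.numpy()        # would normally fail; patched to return random float32 values
    print(arr.shape, arr.dtype)  # (2, 3) float32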
opentau/utils/random_utils.py
@@ -0,0 +1,244 @@

#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 Tensor Auto Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for managing random number generator states.

This module provides functions for serializing, deserializing, saving, and loading
random number generator states for Python's random module, NumPy, and PyTorch.
This is essential for reproducibility in training and evaluation.
"""

import random
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator

import accelerate
import numpy as np
import torch
from safetensors.torch import load_file, save_file

from opentau.constants import RNG_STATE
from opentau.datasets.utils import flatten_dict, unflatten_dict


def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """Serialize Python's random module RNG state to a dictionary.

    Returns:
        Dictionary containing the RNG state as torch.Tensor values, suitable
        for saving with safetensors.save_file() or torch.save().
    """
    py_state = random.getstate()
    return {
        "py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64),
        "py_rng_state": torch.tensor(py_state[1], dtype=torch.int64),
    }


def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore Python's random module RNG state from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_python_rng_state().
    """
    py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
    random.setstate(py_state)


def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
    """Serialize NumPy's RNG state to a dictionary.

    Returns:
        Dictionary containing the RNG state as torch.Tensor values, suitable
        for saving with safetensors.save_file() or torch.save().
    """
    np_state = np.random.get_state()
    # Ensure no breaking changes from numpy
    assert np_state[0] == "MT19937"
    return {
        "np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
        "np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
        "np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
        "np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
    }


def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore NumPy's RNG state from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_numpy_rng_state().
    """
    np_state = (
        "MT19937",
        rng_state_dict["np_rng_state_values"].numpy(),
        rng_state_dict["np_rng_state_index"].item(),
        rng_state_dict["np_rng_has_gauss"].item(),
        rng_state_dict["np_rng_cached_gaussian"].item(),
    )
    np.random.set_state(np_state)


def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """Serialize PyTorch's RNG state to a dictionary.

    Returns:
        Dictionary containing the RNG state as torch.Tensor values, including
        the CUDA RNG state if available. Suitable for saving with
        safetensors.save_file() or torch.save().
    """
    torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
    return torch_rng_state_dict


def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore PyTorch's RNG state from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_torch_rng_state().
    """
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict:
        torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])


def serialize_rng_state() -> dict[str, torch.Tensor]:
    """Serialize RNG states for random, numpy, and torch.

    Returns:
        Dictionary containing all RNG states as torch.Tensor values, suitable
        for saving with safetensors.save_file() or torch.save().
    """
    py_rng_state_dict = serialize_python_rng_state()
    np_rng_state_dict = serialize_numpy_rng_state()
    torch_rng_state_dict = serialize_torch_rng_state()

    return {
        **py_rng_state_dict,
        **np_rng_state_dict,
        **torch_rng_state_dict,
    }


def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore RNG states for random, numpy, and torch from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_rng_state().
    """
    py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
    np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
    torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}

    deserialize_python_rng_state(py_rng_state_dict)
    deserialize_numpy_rng_state(np_rng_state_dict)
    deserialize_torch_rng_state(torch_rng_state_dict)


def save_rng_state(save_dir: Path) -> None:
    """Save the RNG state to a file in the specified directory.

    Args:
        save_dir: Directory path where the RNG state file will be saved.
    """
    rng_state_dict = serialize_rng_state()
    flat_rng_state_dict = flatten_dict(rng_state_dict)
    save_file(flat_rng_state_dict, save_dir / RNG_STATE)


def load_rng_state(save_dir: Path) -> None:
    """Load the RNG state from a file in the specified directory.

    Args:
        save_dir: Directory path containing the RNG state file.
    """
    flat_rng_state_dict = load_file(save_dir / RNG_STATE)
    rng_state_dict = unflatten_dict(flat_rng_state_dict)
    deserialize_rng_state(rng_state_dict)


def get_rng_state() -> dict[str, Any]:
    """Get the current random state for random, numpy, and torch.

    Returns:
        Dictionary containing the current RNG states for all three generators.
    """
    random_state_dict = {
        "random_state": random.getstate(),
        "numpy_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return random_state_dict


def set_rng_state(random_state_dict: dict[str, Any]) -> None:
    """Set the random state for random, numpy, and torch.

    Args:
        random_state_dict: Dictionary of the form returned by get_rng_state().
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    if torch.cuda.is_available() and "torch_cuda_random_state" in random_state_dict:
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])


# TODO: only use accelerate's set_seed instead of this function; accelerate's set_seed
# already handles the random, numpy, and torch seeds.
def set_seed(seed, accelerator: accelerate.Accelerator = None) -> None:
    """Set the seed for reproducibility across random, numpy, and torch.

    Args:
        seed: Seed value to use. If None, no seeding is performed.
        accelerator: Optional Accelerator instance. If provided, each process
            gets a different seed offset to ensure reproducibility in distributed
            settings.
    """
    if seed is None:
        return
    # Before setting the seed, ensure every process gets a different seed when
    # running under an accelerator.
    if accelerator is not None:
        magic_number = 12345  # arbitrary constant to offset the seed per process
        seed += accelerator.process_index * magic_number
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if accelerator:
        from accelerate.utils import set_seed

        set_seed(seed)


@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    Example usage::

        a = random.random()  # produces some random number
        with seeded_context(1337):
            b = random.random()  # produces some other random number
        c = random.random()  # produces yet another random number, the same as if we had never drawn `b`
    """
    random_state_dict = get_rng_state()
    set_seed(seed)
    try:
        yield None
    finally:
        # Restore the prior state even if the body raises.
        set_rng_state(random_state_dict)
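
For illustration, a minimal sketch (not part of the package) of the save/load round trip; the checkpoint directory name is hypothetical, while RNG_STATE is the file-name constant this module imports from opentau.constants.

import random
from pathlib import Path

import torch

from opentau.utils.random_utils import load_rng_state, save_rng_state, set_seed

set_seed(42)
ckpt_dir = Path("outputs/checkpoint_0001")  # hypothetical checkpoint directory
ckpt_dir.mkdir(parents=True, exist_ok=True)
save_rng_state(ckpt_dir)                    # writes the RNG_STATE safetensors file

expected = torch.rand(3)                    # draws a resumed run should reproduce
_ = random.random()                         # perturb the Python RNG as well

load_rng_state(ckpt_dir)                    # rewinds random, numpy, and torch
assert torch.equal(torch.rand(3), expected)  # identical draws after the rewind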
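And a short sketch (again not from the package; variable names are illustrative) of seeded_context, checking that the surrounding RNG stream is unaffected by draws made inside the block:

import random

from opentau.utils.random_utils import seeded_context

random.seed(0)
a = random.random()
with seeded_context(1337):
    b = random.random()  # deterministic: always the same value for seed 1337
c = random.random()      # same value as if the block had never run

random.seed(0)
_ = random.random()
assert random.random() == c  # the outer stream is unaffected by the seeded block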