opentau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. opentau/__init__.py +179 -0
  2. opentau/__version__.py +24 -0
  3. opentau/configs/__init__.py +19 -0
  4. opentau/configs/default.py +297 -0
  5. opentau/configs/libero.py +113 -0
  6. opentau/configs/parser.py +393 -0
  7. opentau/configs/policies.py +297 -0
  8. opentau/configs/reward.py +42 -0
  9. opentau/configs/train.py +370 -0
  10. opentau/configs/types.py +76 -0
  11. opentau/constants.py +52 -0
  12. opentau/datasets/__init__.py +84 -0
  13. opentau/datasets/backward_compatibility.py +78 -0
  14. opentau/datasets/compute_stats.py +333 -0
  15. opentau/datasets/dataset_mixture.py +460 -0
  16. opentau/datasets/factory.py +232 -0
  17. opentau/datasets/grounding/__init__.py +67 -0
  18. opentau/datasets/grounding/base.py +154 -0
  19. opentau/datasets/grounding/clevr.py +110 -0
  20. opentau/datasets/grounding/cocoqa.py +130 -0
  21. opentau/datasets/grounding/dummy.py +101 -0
  22. opentau/datasets/grounding/pixmo.py +177 -0
  23. opentau/datasets/grounding/vsr.py +141 -0
  24. opentau/datasets/image_writer.py +304 -0
  25. opentau/datasets/lerobot_dataset.py +1910 -0
  26. opentau/datasets/online_buffer.py +442 -0
  27. opentau/datasets/push_dataset_to_hub/utils.py +132 -0
  28. opentau/datasets/sampler.py +99 -0
  29. opentau/datasets/standard_data_format_mapping.py +278 -0
  30. opentau/datasets/transforms.py +330 -0
  31. opentau/datasets/utils.py +1243 -0
  32. opentau/datasets/v2/batch_convert_dataset_v1_to_v2.py +887 -0
  33. opentau/datasets/v2/convert_dataset_v1_to_v2.py +829 -0
  34. opentau/datasets/v21/_remove_language_instruction.py +109 -0
  35. opentau/datasets/v21/batch_convert_dataset_v20_to_v21.py +60 -0
  36. opentau/datasets/v21/convert_dataset_v20_to_v21.py +183 -0
  37. opentau/datasets/v21/convert_stats.py +150 -0
  38. opentau/datasets/video_utils.py +597 -0
  39. opentau/envs/__init__.py +18 -0
  40. opentau/envs/configs.py +178 -0
  41. opentau/envs/factory.py +99 -0
  42. opentau/envs/libero.py +439 -0
  43. opentau/envs/utils.py +204 -0
  44. opentau/optim/__init__.py +16 -0
  45. opentau/optim/factory.py +43 -0
  46. opentau/optim/optimizers.py +121 -0
  47. opentau/optim/schedulers.py +140 -0
  48. opentau/planner/__init__.py +82 -0
  49. opentau/planner/high_level_planner.py +366 -0
  50. opentau/planner/utils/memory.py +64 -0
  51. opentau/planner/utils/utils.py +65 -0
  52. opentau/policies/__init__.py +24 -0
  53. opentau/policies/factory.py +172 -0
  54. opentau/policies/normalize.py +315 -0
  55. opentau/policies/pi0/__init__.py +19 -0
  56. opentau/policies/pi0/configuration_pi0.py +250 -0
  57. opentau/policies/pi0/modeling_pi0.py +994 -0
  58. opentau/policies/pi0/paligemma_with_expert.py +516 -0
  59. opentau/policies/pi05/__init__.py +20 -0
  60. opentau/policies/pi05/configuration_pi05.py +231 -0
  61. opentau/policies/pi05/modeling_pi05.py +1257 -0
  62. opentau/policies/pi05/paligemma_with_expert.py +572 -0
  63. opentau/policies/pretrained.py +315 -0
  64. opentau/policies/utils.py +123 -0
  65. opentau/policies/value/__init__.py +18 -0
  66. opentau/policies/value/configuration_value.py +170 -0
  67. opentau/policies/value/modeling_value.py +512 -0
  68. opentau/policies/value/reward.py +87 -0
  69. opentau/policies/value/siglip_gemma.py +221 -0
  70. opentau/scripts/actions_mse_loss.py +89 -0
  71. opentau/scripts/bin_to_safetensors.py +116 -0
  72. opentau/scripts/compute_max_token_length.py +111 -0
  73. opentau/scripts/display_sys_info.py +90 -0
  74. opentau/scripts/download_libero_benchmarks.py +54 -0
  75. opentau/scripts/eval.py +877 -0
  76. opentau/scripts/export_to_onnx.py +180 -0
  77. opentau/scripts/fake_tensor_training.py +87 -0
  78. opentau/scripts/get_advantage_and_percentiles.py +220 -0
  79. opentau/scripts/high_level_planner_inference.py +114 -0
  80. opentau/scripts/inference.py +70 -0
  81. opentau/scripts/launch_train.py +63 -0
  82. opentau/scripts/libero_simulation_parallel.py +356 -0
  83. opentau/scripts/libero_simulation_sequential.py +122 -0
  84. opentau/scripts/nav_high_level_planner_inference.py +61 -0
  85. opentau/scripts/train.py +379 -0
  86. opentau/scripts/visualize_dataset.py +294 -0
  87. opentau/scripts/visualize_dataset_html.py +507 -0
  88. opentau/scripts/zero_to_fp32.py +760 -0
  89. opentau/utils/__init__.py +20 -0
  90. opentau/utils/accelerate_utils.py +79 -0
  91. opentau/utils/benchmark.py +98 -0
  92. opentau/utils/fake_tensor.py +81 -0
  93. opentau/utils/hub.py +209 -0
  94. opentau/utils/import_utils.py +79 -0
  95. opentau/utils/io_utils.py +137 -0
  96. opentau/utils/libero.py +214 -0
  97. opentau/utils/libero_dataset_recorder.py +460 -0
  98. opentau/utils/logging_utils.py +180 -0
  99. opentau/utils/monkey_patch.py +278 -0
  100. opentau/utils/random_utils.py +244 -0
  101. opentau/utils/train_utils.py +198 -0
  102. opentau/utils/utils.py +471 -0
  103. opentau-0.1.0.dist-info/METADATA +161 -0
  104. opentau-0.1.0.dist-info/RECORD +108 -0
  105. opentau-0.1.0.dist-info/WHEEL +5 -0
  106. opentau-0.1.0.dist-info/entry_points.txt +2 -0
  107. opentau-0.1.0.dist-info/licenses/LICENSE +508 -0
  108. opentau-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,278 @@
1
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ r"""Monkey patches to inject behaviour into computational graph construction. This eliminates the need to trace and
16
+ modify source code of 3rd party libraries, such as transformers or even PyTorch itself.
17
+
18
+ Where necessary, we can do
19
+
20
+ >>> from opentau.utils.monkey_patch import torch_cumsum_patch, torch_pow_patch
21
+ >>> torch_cumsum_patch() # Apply the patch to handle bool tensors in cumsum
22
+ >>> torch_pow_patch() # Apply the patch to handle mixed number-tensor exponentiation in pow
23
+
24
+ to apply the patches.
25
+
26
+ Running the same monkey patch twice will not have any effect, as the 2nd call will return immediately.
27
+
28
+ Note: currently, there is no way to undo a monkey patch once it has been applied, which can be a future to-do.
29
+ For now, only apply the patch when its implications are understood.
30
+ """
31
+
32
+ import importlib
33
+ import logging
34
+ import sys
35
+ from functools import wraps
36
+
37
+ import numpy as np
38
+ import torch
39
+
40
# Module-level registry of the patch functions that have already run.
# `_run_once_only` consults and updates this set so each patch is applied at
# most once per process; `patches_applied()` reports its contents by name.
__patches_applied = set()
42
+
43
+
44
def _run_once_only(func):
    """Decorator guaranteeing that ``func``'s body executes at most once.

    The first call runs normally and records ``func`` in the module-level
    ``__patches_applied`` registry; every later call returns ``None``
    immediately without executing the body again.

    Args:
        func: Function to be wrapped.

    Returns:
        Wrapped function that executes only once.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if func in __patches_applied:
            return None

        __patches_applied.add(func)
        logging.debug("Applying monkey patch: %s in %s", func.__name__, func.__module__)
        return func(*args, **kwargs)

    return wrapper
67
+
68
+
69
def patches_applied():
    """Get a list of all patches that have been applied.

    Returns:
        List of patch function names that have been applied.
    """
    names = []
    for patch in __patches_applied:
        names.append(patch.__name__)
    return names
76
+
77
+
78
@_run_once_only
def torch_cumsum_patch():
    """Make ``torch.cumsum`` safe for bool tensors under ONNX Runtime.

    PyTorch itself accepts cumsum over bool tensors, but ONNX Runtime does
    not, so bool inputs are promoted to int64 before the real cumsum runs.
    """
    _original = torch.cumsum

    def _cumsum_with_bool_support(tensor, *args, **kwargs):
        # Promote bool inputs; every other dtype passes through untouched.
        if tensor.dtype is torch.bool:
            tensor = tensor.to(torch.int64)
        return _original(tensor, *args, **kwargs)

    torch.cumsum = _cumsum_with_bool_support
93
+
94
+
95
@_run_once_only
def torch_pow_patch():
    """Make ``torch.pow`` tolerate a Python scalar on either side.

    ONNX export needs both the base and the exponent to be tensors, so any
    scalar operand is wrapped in a tensor matching the other operand's dtype
    and device before the real pow runs.
    """
    _original = torch.pow

    def _tensorized_pow(base, exponent, *args, **kwargs):
        if not isinstance(base, torch.Tensor):
            base = torch.tensor(base, dtype=exponent.dtype, device=exponent.device)
        if not isinstance(exponent, torch.Tensor):
            exponent = torch.tensor(exponent, dtype=base.dtype, device=base.device)
        return _original(base, exponent, *args, **kwargs)

    torch.pow = _tensorized_pow
    torch.Tensor.pow = _tensorized_pow
    torch.Tensor.__pow__ = _tensorized_pow
    # At least for torch 2.7, `__rpow__` already delegates like this; we set it
    # explicitly to stay robust against future changes.
    torch.Tensor.__rpow__ = lambda tensor, scalar: _tensorized_pow(scalar, tensor)
116
+
117
+
118
@_run_once_only
def torch_full_patch():
    """Coerce bool fill values passed to ``torch.full`` into ints.

    True/False become 1/0 at the Python level instead of reaching the C++
    implementation, which improves compatibility with certain backends.
    """
    _original = torch.full

    def _int_fill_full(size, fill_value, *args, **kwargs):
        coerced = int(fill_value) if isinstance(fill_value, bool) else fill_value
        return _original(size, coerced, *args, **kwargs)

    torch.full = _int_fill_full
133
+
134
+
135
@_run_once_only
def torch_fake_tensor_module_to_patch():
    """Fix torch.nn.Module.to(device) behavior in FakeTensorMode.

    Without this setting, Module.to(device) is effectively a no-op under
    FakeTensorMode and later surfaces as device mismatch errors; overwriting
    module params on conversion makes the move take effect.

    See https://github.com/pytorch/pytorch/issues/119665 for more details.
    """
    torch.__future__.set_overwrite_module_params_on_conversion(True)
145
+
146
+
147
@_run_once_only
def torch_fake_tensor_to_numpy_patch():
    """Enable .numpy() calls on FakeTensor to return random numpy arrays.

    ``.numpy()`` on a FakeTensor fabricates an ndarray with the recorded
    shape and dtype, filled with random values — so repeated calls on the
    same FakeTensor may disagree with each other.
    """
    dtype_map = {
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.float16: np.float16,
        # torch.bfloat16 is intentionally excluded as it is not supported by numpy
        torch.int64: np.int64,
        torch.int32: np.int32,
        torch.int16: np.int16,
        torch.int8: np.int8,
        torch.uint8: np.uint8,
        torch.bool: np.bool_,
    }

    def _fake_numpy(self: torch._subclasses.fake_tensor.FakeTensor, /):
        if self.device.type != "cpu":
            raise RuntimeError(
                f"FakeTensor.numpy() can only be called on CPU tensors. This tensor is on {self.device}"
            )
        if self.requires_grad:
            raise RuntimeError(
                ".numpy() cannot be called on tensors that require gradients. Call tensor.detach().numpy() instead."
            )
        if self.dtype not in dtype_map:
            raise RuntimeError(f"Unsupported dtype {self.dtype} for FakeTensor.numpy()")

        # `np.random.rand()` with no dims returns a plain float, so wrap in
        # np.array() to guarantee an ndarray of the recorded shape.
        data = np.random.rand(*self.shape)
        return np.array(data).astype(dtype_map[self.dtype])

    torch._subclasses.fake_tensor.FakeTensor.numpy = _fake_numpy
185
+
186
+
187
@_run_once_only
def torch_fake_tensor_beta_validate_args_patch():
    """Fix torch.distributions.Beta to work in FakeTensorMode.

    Unless a caller explicitly asks otherwise, ``validate_args`` defaults to
    False, which is required for FakeTensorMode compatibility.
    """
    _original_init = torch.distributions.Beta.__init__

    def _no_validate_init(self, *args, **kwargs):
        if "validate_args" not in kwargs:
            kwargs["validate_args"] = False
        _original_init(self, *args, **kwargs)

    torch.distributions.Beta.__init__ = _no_validate_init
201
+
202
+
203
@_run_once_only
def torch_fake_tensor_is_inf_patch():
    """Patch torch.isinf to work with FakeTensor.

    Replaces torch.isinf with a stub that fabricates a Mock standing in for
    the boolean result tensor; its any()/all() always report False.
    """
    from unittest.mock import Mock

    def _mock_isinf(x):
        result = Mock()
        result.dtype = torch.bool
        result.shape = x.shape
        result.any.return_value = False
        result.all.return_value = False
        return result

    torch.isinf = _mock_isinf
221
+
222
+
223
@_run_once_only
def gym_is_gymnasium_patch():
    """Monkey patch to make `import gym` equivalent to `import gymnasium as gym`.

    The original gym package is incompatible with numpy >= 2.0, so gym
    imports are rerouted to gymnasium — both the top-level package and a set
    of commonly used submodules.
    """
    gymnasium_pkg = importlib.import_module("gymnasium")
    sys.modules.setdefault("gym", gymnasium_pkg)

    # This is a non-exhaustive list. More submodules may be added in the future as needed.
    # A more comprehensive solution would involve a lower-level approach using `finder`s and `loader`s.
    # See https://docs.python.org/3/reference/import.html#finders-and-loaders
    for suffix in (
        "spaces",
        "envs",
        "envs.classic_control",
        "envs.mujoco",
        "envs.toy_text",
        "wrappers",
        "vector",
        "vector.utils",
        "utils",
    ):
        gym_name = f"gym.{suffix}"
        gymnasium_name = f"gymnasium.{suffix}"
        try:
            if gym_name in sys.modules:
                print(f"Module {gym_name} already exists in sys.modules, skipping import of {gymnasium_name}")
                continue
            # Assuming importing the submodule has no side effects, which should be true for gymnasium
            sys.modules[gym_name] = importlib.import_module(gymnasium_name)
        except (ImportError, ModuleNotFoundError):
            print("Failed to import gymnasium submodule:", suffix, file=sys.stderr)
259
+
260
+
261
@_run_once_only
def torch_load_patch():
    """Override torch.load to restore the pre-2.6 ``weights_only`` default.

    PyTorch >= 2.6 flipped the default of ``torch.load(..., weights_only=...)``
    to True, which breaks loading checkpoints containing arbitrary Python
    objects. When the caller does not pass ``weights_only`` explicitly, this
    patch forces it to False; an explicit value is forwarded unchanged.
    On torch < 2.6 the patch is a no-op.
    """

    def _parse_major_minor(version_string):
        # "2.6.0+cu124" -> (2, 6). Comparing ints avoids the lexicographic
        # trap of plain string comparison, where "2.10" < "2.6" is True and
        # the patch would wrongly be skipped on torch 2.10+.
        parts = version_string.split("+", 1)[0].split(".")
        try:
            return int(parts[0]), int(parts[1])
        except (IndexError, ValueError):
            # Unparseable version: be conservative and apply the patch.
            return (999, 0)

    if _parse_major_minor(str(torch.__version__)) < (2, 6):
        return

    original_load = torch.load

    def _patched_load(*args, weights_only=..., **kwargs):
        # `...` is used as a sentinel to distinguish "not passed" from an
        # explicit caller-supplied value (including an explicit None).
        kwargs["weights_only"] = False if weights_only is ... else weights_only
        return original_load(*args, **kwargs)

    torch.load = _patched_load
@@ -0,0 +1,244 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ # Copyright 2026 Tensor Auto Inc. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Utilities for managing random number generator states.
18
+
19
+ This module provides functions for serializing, deserializing, saving, and loading
20
+ random number generator states for Python's random module, NumPy, and PyTorch.
21
+ This is essential for reproducibility in training and evaluation.
22
+ """
23
+
24
+ import random
25
+ from contextlib import contextmanager
26
+ from pathlib import Path
27
+ from typing import Any, Generator
28
+
29
+ import accelerate
30
+ import numpy as np
31
+ import torch
32
+ from safetensors.torch import load_file, save_file
33
+
34
+ from opentau.constants import RNG_STATE
35
+ from opentau.datasets.utils import flatten_dict, unflatten_dict
36
+
37
+
38
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """Capture the state of Python's ``random`` module as tensors.

    Returns:
        Dictionary containing the RNG state as torch.Tensor values, suitable
        for saving with safetensors.save_file() or torch.save().
    """
    version, internal_state, _gauss_next = random.getstate()
    return {
        "py_rng_version": torch.tensor([version], dtype=torch.int64),
        "py_rng_state": torch.tensor(internal_state, dtype=torch.int64),
    }
50
+
51
+
52
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore Python's random module RNG state from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_python_rng_state().

    Note:
        The cached Gaussian value is not round-tripped; it is restored as None.
    """
    version = rng_state_dict["py_rng_version"].item()
    internal_state = tuple(rng_state_dict["py_rng_state"].tolist())
    random.setstate((version, internal_state, None))
60
+
61
+
62
+ def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
63
+ """Serialize NumPy's RNG state to a dictionary.
64
+
65
+ Returns:
66
+ Dictionary containing the RNG state as torch.Tensor values, suitable
67
+ for saving with safetensors.save_file() or torch.save().
68
+ """
69
+ np_state = np.random.get_state()
70
+ # Ensure no breaking changes from numpy
71
+ assert np_state[0] == "MT19937"
72
+ return {
73
+ "np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
74
+ "np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
75
+ "np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
76
+ "np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
77
+ }
78
+
79
+
80
+ def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
81
+ """Restore NumPy's RNG state from a dictionary.
82
+
83
+ Args:
84
+ rng_state_dict: Dictionary produced by serialize_numpy_rng_state().
85
+ """
86
+ np_state = (
87
+ "MT19937",
88
+ rng_state_dict["np_rng_state_values"].numpy(),
89
+ rng_state_dict["np_rng_state_index"].item(),
90
+ rng_state_dict["np_rng_has_gauss"].item(),
91
+ rng_state_dict["np_rng_cached_gaussian"].item(),
92
+ )
93
+ np.random.set_state(np_state)
94
+
95
+
96
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """Serialize PyTorch's RNG state to a dictionary.

    Returns:
        Dictionary with the CPU generator state under "torch_rng_state" and,
        when CUDA is available, the CUDA generator state under
        "torch_cuda_rng_state". Suitable for safetensors.save_file() or
        torch.save().
    """
    state = {"torch_rng_state": torch.get_rng_state()}
    if not torch.cuda.is_available():
        return state
    state["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
    return state
108
+
109
+
110
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore PyTorch's RNG state from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_torch_rng_state().
            The CUDA state is restored only when CUDA is available *and* the
            dictionary actually contains one.
    """
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    restore_cuda = torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict
    if restore_cuda:
        torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])
119
+
120
+
121
def serialize_rng_state() -> dict[str, torch.Tensor]:
    """Serialize RNG states for random, numpy, and torch into one dictionary.

    Returns:
        Dictionary merging the per-library state dicts (their "py_", "np_",
        and "torch_" key prefixes keep them disjoint), suitable for saving
        with safetensors.save_file() or torch.save().
    """
    merged: dict[str, torch.Tensor] = {}
    merged.update(serialize_python_rng_state())
    merged.update(serialize_numpy_rng_state())
    merged.update(serialize_torch_rng_state())
    return merged
137
+
138
+
139
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """Restore RNG states for random, numpy, and torch from a dictionary.

    Args:
        rng_state_dict: Dictionary produced by serialize_rng_state(); entries
            are routed to each library's deserializer by key prefix.
    """
    by_prefix: dict[str, dict[str, torch.Tensor]] = {"py": {}, "np": {}, "torch": {}}
    for key, value in rng_state_dict.items():
        for prefix in by_prefix:
            if key.startswith(prefix):
                by_prefix[prefix][key] = value

    deserialize_python_rng_state(by_prefix["py"])
    deserialize_numpy_rng_state(by_prefix["np"])
    deserialize_torch_rng_state(by_prefix["torch"])
152
+
153
+
154
def save_rng_state(save_dir: Path) -> None:
    """Save the combined RNG state to a file in the specified directory.

    Args:
        save_dir: Directory path where the RNG state file (named by the
            RNG_STATE constant) will be saved.
    """
    flat_state = flatten_dict(serialize_rng_state())
    save_file(flat_state, save_dir / RNG_STATE)
163
+
164
+
165
def load_rng_state(save_dir: Path) -> None:
    """Load and apply the combined RNG state from the specified directory.

    Args:
        save_dir: Directory path containing the RNG state file (named by the
            RNG_STATE constant).
    """
    nested_state = unflatten_dict(load_file(save_dir / RNG_STATE))
    deserialize_rng_state(nested_state)
174
+
175
+
176
def get_rng_state() -> dict[str, Any]:
    """Snapshot the current random state of random, numpy, and torch.

    Returns:
        Dictionary holding each generator's native state object; includes the
        CUDA generator state when CUDA is available.
    """
    snapshot: dict[str, Any] = {
        "random_state": random.getstate(),
        "numpy_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
    }
    if not torch.cuda.is_available():
        return snapshot
    snapshot["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return snapshot
190
+
191
+
192
def set_rng_state(random_state_dict: dict[str, Any]) -> None:
    """Restore the random state for random, numpy, and torch.

    Args:
        random_state_dict: Dictionary of the form returned by get_rng_state().

    Note:
        The CUDA generator state is restored only when CUDA is available AND
        the snapshot actually contains a "torch_cuda_random_state" entry.
        Without the key-presence guard, restoring a snapshot taken on a
        CPU-only machine would raise KeyError on a GPU machine (the guard
        mirrors the one in deserialize_torch_rng_state()).
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    if torch.cuda.is_available() and "torch_cuda_random_state" in random_state_dict:
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
203
+
204
+
205
# TODO: only use accelerate set_seed instead of this function. accelerate set_seed already handles the random, numpy, and torch seeds.
def set_seed(seed, accelerator: accelerate.Accelerator | None = None) -> None:
    """Set seed for reproducibility across random, numpy, and torch.

    Args:
        seed: Seed value to use. If None, no seeding is performed.
        accelerator: Optional Accelerator instance. If provided, each process
            gets a different seed offset to ensure reproducibility in
            distributed settings.
    """
    # Honor the documented contract: a None seed is a no-op. Previously None
    # fell through to torch.manual_seed(None), which raises TypeError.
    if seed is None:
        return

    # Offset the seed per process so distributed workers don't share a stream.
    if accelerator is not None:
        magic_number = 12345  # arbitrary constant to offset the seed per process
        seed += accelerator.process_index * magic_number

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if accelerator:
        # Aliased import so it does not shadow this function's own name.
        from accelerate.utils import set_seed as accelerate_set_seed

        accelerate_set_seed(seed)
228
+
229
+
230
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    The prior state is restored in a ``finally`` block, so it is recovered
    even if the body raises — without this, an exception inside the block
    would leak the seeded stream into subsequent code.

    Example usage::

        a = random.random()  # produces some random number
        with seeded_context(1337):
            b = random.random()  # produces some other random number
        c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    """
    random_state_dict = get_rng_state()
    try:
        set_seed(seed)
        yield None
    finally:
        set_rng_state(random_state_dict)