torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/_utils.py
ADDED
@@ -0,0 +1,1431 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import collections
import functools
import inspect
import logging
import math
import os
import pickle
import sys
import threading
import time
import traceback
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from functools import wraps
from textwrap import indent
from typing import Any, cast, TypeVar

import numpy as np
import torch

from pyvers import implement_for  # noqa: F401
from tensordict import unravel_key
from tensordict.utils import NestedKey
from torch import multiprocessing as mp, Tensor
from torch.autograd.profiler import record_function

try:
    from torch.compiler import is_compiling
except ImportError:
    from torch._dynamo import is_compiling


def _get_default_mp_start_method() -> str:
    """Returns TorchRL's preferred multiprocessing start method.

    If the user has explicitly set a global start method via ``mp.set_start_method()``,
    that method is returned. Otherwise, defaults to ``"spawn"`` for improved safety
    across backends and to avoid known issues with ``fork`` in multi-threaded programs.
    """
    # Check if user has explicitly set a global start method
    try:
        current = mp.get_start_method(allow_none=True)
        if current is not None:
            return current
    except (TypeError, RuntimeError):
        pass
    return "spawn"


def _get_mp_ctx(start_method: str | None = None):
    """Return a multiprocessing context with TorchRL's preferred start method.

    This is intentionally context-based (instead of relying on global
    ``mp.set_start_method``) so that TorchRL components can consistently allocate
    primitives (Queue/Pipe/Lock/Process) with a matching context.
    """
    if start_method is None:
        start_method = _get_default_mp_start_method()
    try:
        return mp.get_context(start_method)
    except ValueError:
        # Best effort fallback if a start method isn't supported on this platform.
        return mp.get_context("spawn")


def _set_mp_start_method_if_unset(start_method: str | None = None) -> str | None:
    """Set the global start method only if it hasn't been set yet.

    Returns the (possibly pre-existing) start method, or ``None`` if it cannot be
    determined.
    """
    if start_method is None:
        start_method = _get_default_mp_start_method()

    current = None
    try:
        current = mp.get_start_method(allow_none=True)
    except TypeError:
        # Older python/torch wrappers may not accept allow_none.
        try:
            current = mp.get_start_method()
        except Exception:
            current = None
    except Exception:
        current = None

    if current is None:
        try:
            mp.set_start_method(start_method, force=False)
            current = start_method
        except Exception:
            # If another library already touched the context, we should not
            # override it here.
            pass
    return current


@implement_for("torch", None, "2.8")
def _mp_sharing_strategy_for_spawn() -> str | None:
    # On older torch stacks, pickling Process objects for "spawn" can end up
    # passing file descriptors for shared storages; using "file_system" reduces
    # FD passing and avoids spawn-time failures on some old Python versions.
    return "file_system"


@implement_for("torch", "2.8")
def _mp_sharing_strategy_for_spawn() -> str | None:  # noqa: F811
    return None
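The helpers above encode one rule: multiprocessing primitives and the processes that use them must come from the same context. A minimal sketch of that pattern, using only the private helper defined above (`worker` is a hypothetical function, not part of the file):

# Illustrative sketch, not part of torchrl/_utils.py.
from torchrl._utils import _get_mp_ctx

def worker(q):
    q.put("done")

if __name__ == "__main__":
    ctx = _get_mp_ctx()  # honors a user-set start method, else "spawn"
    q = ctx.Queue()      # queue allocated from the same context...
    p = ctx.Process(target=worker, args=(q,))  # ...as the process
    p.start()
    print(q.get())       # -> "done"
    p.join()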
def strtobool(val: Any) -> bool:
    """Convert a string representation of truth to a boolean.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    Raises ValueError if 'val' is anything else.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"Invalid truth value {val!r}")


LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
logger = logging.getLogger("torchrl")
logger.setLevel(LOGGING_LEVEL)
logger.propagate = False
# Clear existing handlers
while logger.hasHandlers():
    logger.removeHandler(logger.handlers[0])
stream_handlers = {
    "stdout": sys.stdout,
    "stderr": sys.stderr,
}
TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
stream_handler = stream_handlers.get(TORCHRL_CONSOLE_STREAM, sys.stdout)


# Create colored handler
class _CustomFormatter(logging.Formatter):
    def format(self, record):
        # Format the initial part in green
        green_format = "\033[92m%(asctime)s [%(name)s][%(levelname)s]\033[0m"
        # Format the message part
        message_format = "%(message)s"
        # End marker in green
        end_marker = "\033[92m [END]\033[0m"
        # Combine all parts
        formatted_message = logging.Formatter(
            green_format + indent(message_format, " " * 4) + end_marker
        ).format(record)

        return formatted_message


console_handler = logging.StreamHandler(stream=stream_handler)
console_handler.setFormatter(_CustomFormatter())
logger.addHandler(console_handler)

console_handler.setLevel(LOGGING_LEVEL)
logger.debug(f"Logging level: {logger.getEffectiveLevel()}")

VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
_os_is_windows = sys.platform == "win32"
RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
if RL_WARNINGS:
    warnings.filterwarnings("once", category=DeprecationWarning, module="torchrl")

BATCHED_PIPE_TIMEOUT = float(os.environ.get("BATCHED_PIPE_TIMEOUT", "10000.0"))
WEIGHT_SYNC_TIMEOUT = float(os.environ.get("WEIGHT_SYNC_TIMEOUT", "60.0"))

_TORCH_DTYPES = (
    torch.bfloat16,
    torch.bool,
    torch.complex128,
    torch.complex32,
    torch.complex64,
    torch.float16,
    torch.float32,
    torch.float64,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.int8,
    torch.qint32,
    torch.qint8,
    torch.quint4x2,
    torch.quint8,
    torch.uint8,
)
if hasattr(torch, "uint16"):
    _TORCH_DTYPES = _TORCH_DTYPES + (torch.uint16,)
if hasattr(torch, "uint32"):
    _TORCH_DTYPES = _TORCH_DTYPES + (torch.uint32,)
if hasattr(torch, "uint64"):
    _TORCH_DTYPES = _TORCH_DTYPES + (torch.uint64,)
_STR_DTYPE_TO_DTYPE = {str(dtype): dtype for dtype in _TORCH_DTYPES}
_STRDTYPE2DTYPE = _STR_DTYPE_TO_DTYPE
_DTYPE_TO_STR_DTYPE = {
    dtype: str_dtype for str_dtype, dtype in _STR_DTYPE_TO_DTYPE.items()
}
_DTYPE2STRDTYPE = _STR_DTYPE_TO_DTYPE
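The dtype tables above give a string round-trip, useful for serialization such as checkpoint metadata. A short sketch of the round-trip, using only the names defined above:

# Illustrative sketch, not part of torchrl/_utils.py.
import torch
from torchrl._utils import _DTYPE_TO_STR_DTYPE, _STR_DTYPE_TO_DTYPE

s = _DTYPE_TO_STR_DTYPE[torch.float32]   # "torch.float32"
assert _STR_DTYPE_TO_DTYPE[s] is torch.float32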
213
|
+
class timeit:
|
|
214
|
+
"""A dirty but easy to use decorator for profiling code.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
name (str): The name of the timer.
|
|
218
|
+
|
|
219
|
+
Examples:
|
|
220
|
+
>>> from torchrl import timeit
|
|
221
|
+
>>> @timeit("my_function")
|
|
222
|
+
>>> def my_function():
|
|
223
|
+
...
|
|
224
|
+
>>> my_function()
|
|
225
|
+
>>> with timeit("my_other_function"):
|
|
226
|
+
... my_other_function()
|
|
227
|
+
>>> timeit.print() # prints the state of the timer for each function
|
|
228
|
+
|
|
229
|
+
The timer can also be queried mid-execution using the :meth:`elapsed` method:
|
|
230
|
+
|
|
231
|
+
>>> with timeit("my_function") as timer:
|
|
232
|
+
... # do some work
|
|
233
|
+
... print(f"Elapsed so far: {timer.elapsed():.3f}s")
|
|
234
|
+
... # do more work
|
|
235
|
+
|
|
236
|
+
For long-running processes where a context manager isn't practical,
|
|
237
|
+
use the :meth:`start` method:
|
|
238
|
+
|
|
239
|
+
>>> timer = timeit("long_process").start()
|
|
240
|
+
>>> for i in range(100):
|
|
241
|
+
... # do work
|
|
242
|
+
... print(f"Elapsed: {timer.elapsed():.3f}s")
|
|
243
|
+
"""
|
|
244
|
+
|
|
245
|
+
_REG = {}
|
|
246
|
+
|
|
247
|
+
def __init__(self, name):
|
|
248
|
+
self.name = name
|
|
249
|
+
|
|
250
|
+
def __call__(self, fn: Callable) -> Callable:
|
|
251
|
+
@wraps(fn)
|
|
252
|
+
def decorated_fn(*args, **kwargs):
|
|
253
|
+
with self:
|
|
254
|
+
out = fn(*args, **kwargs)
|
|
255
|
+
return out
|
|
256
|
+
|
|
257
|
+
return decorated_fn
|
|
258
|
+
|
|
259
|
+
def __enter__(self) -> timeit:
|
|
260
|
+
self.t0 = time.time()
|
|
261
|
+
return self
|
|
262
|
+
|
|
263
|
+
def start(self) -> timeit:
|
|
264
|
+
"""Starts the timer without using a context manager.
|
|
265
|
+
|
|
266
|
+
This is useful when you need to track elapsed time over a long-running
|
|
267
|
+
loop or process where a context manager isn't practical.
|
|
268
|
+
|
|
269
|
+
Returns:
|
|
270
|
+
timeit: Returns self for method chaining.
|
|
271
|
+
|
|
272
|
+
Examples:
|
|
273
|
+
>>> timer = timeit("my_long_process").start()
|
|
274
|
+
>>> for i in range(100):
|
|
275
|
+
... # do work
|
|
276
|
+
... if i % 10 == 0:
|
|
277
|
+
... print(f"Elapsed: {timer.elapsed():.3f}s")
|
|
278
|
+
"""
|
|
279
|
+
self.t0 = time.time()
|
|
280
|
+
return self
|
|
281
|
+
|
|
282
|
+
def elapsed(self) -> float:
|
|
283
|
+
"""Returns the elapsed time in seconds since the timer was started.
|
|
284
|
+
|
|
285
|
+
This can be called during execution to query the current elapsed time.
|
|
286
|
+
|
|
287
|
+
Returns:
|
|
288
|
+
float: Elapsed time in seconds.
|
|
289
|
+
|
|
290
|
+
Examples:
|
|
291
|
+
>>> with timeit("my_function") as timer:
|
|
292
|
+
... # do some work
|
|
293
|
+
... print(f"Elapsed so far: {timer.elapsed():.3f}s")
|
|
294
|
+
... # do more work
|
|
295
|
+
"""
|
|
296
|
+
return time.time() - self.t0
|
|
297
|
+
|
|
298
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
299
|
+
t = self.elapsed()
|
|
300
|
+
val = self._REG.setdefault(self.name, [0.0, 0.0, 0])
|
|
301
|
+
|
|
302
|
+
count = val[2]
|
|
303
|
+
N = count + 1
|
|
304
|
+
val[0] = val[0] * (count / N) + t / N
|
|
305
|
+
val[1] += t
|
|
306
|
+
val[2] = N
|
|
307
|
+
|
|
308
|
+
@staticmethod
|
|
309
|
+
def print(prefix: str | None = None) -> str: # noqa: T202
|
|
310
|
+
"""Prints the state of the timer.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
prefix (str): The prefix to add to the keys. If `None`, no prefix is added.
|
|
314
|
+
|
|
315
|
+
Returns:
|
|
316
|
+
the string printed using the logger.
|
|
317
|
+
"""
|
|
318
|
+
keys = list(timeit._REG)
|
|
319
|
+
keys.sort()
|
|
320
|
+
string = []
|
|
321
|
+
for name in keys:
|
|
322
|
+
strings = []
|
|
323
|
+
if prefix:
|
|
324
|
+
strings.append(prefix)
|
|
325
|
+
strings.append(
|
|
326
|
+
f"{name} took {timeit._REG[name][0] * 1000:4.4f} msec (total = {timeit._REG[name][1]: 4.4f} sec since last reset)."
|
|
327
|
+
)
|
|
328
|
+
string.append(" -- ".join(strings))
|
|
329
|
+
logger.info(string[-1])
|
|
330
|
+
return "\n".join(string)
|
|
331
|
+
|
|
332
|
+
_printevery_count = 0
|
|
333
|
+
|
|
334
|
+
@classmethod
|
|
335
|
+
def printevery(
|
|
336
|
+
cls,
|
|
337
|
+
num_prints: int,
|
|
338
|
+
total_count: int,
|
|
339
|
+
*,
|
|
340
|
+
prefix: str | None = None,
|
|
341
|
+
erase: bool = False,
|
|
342
|
+
) -> None:
|
|
343
|
+
"""Prints the state of the timer at regular intervals.
|
|
344
|
+
|
|
345
|
+
Args:
|
|
346
|
+
num_prints (int): The number of times to print the state of the timer, given the total_count.
|
|
347
|
+
total_count (int): The total number of times to print the state of the timer.
|
|
348
|
+
prefix (str): The prefix to add to the keys. If `None`, no prefix is added.
|
|
349
|
+
erase (bool): If True, erase the timer after printing. Default is `False`.
|
|
350
|
+
|
|
351
|
+
"""
|
|
352
|
+
interval = max(1, total_count // num_prints)
|
|
353
|
+
if cls._printevery_count % interval == 0:
|
|
354
|
+
cls.print(prefix=prefix)
|
|
355
|
+
if erase:
|
|
356
|
+
cls.erase()
|
|
357
|
+
cls._printevery_count += 1
|
|
358
|
+
|
|
359
|
+
@classmethod
|
|
360
|
+
def todict(
|
|
361
|
+
cls, percall: bool = True, prefix: str | None = None
|
|
362
|
+
) -> dict[str, float]:
|
|
363
|
+
"""Convert the timer to a dictionary.
|
|
364
|
+
|
|
365
|
+
Args:
|
|
366
|
+
percall (bool): If True, return the average time per call.
|
|
367
|
+
prefix (str): The prefix to add to the keys.
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
def _make_key(key):
|
|
371
|
+
if prefix:
|
|
372
|
+
return f"{prefix}/{key}"
|
|
373
|
+
return key
|
|
374
|
+
|
|
375
|
+
if percall:
|
|
376
|
+
return {_make_key(key): val[0] for key, val in cls._REG.items()}
|
|
377
|
+
return {_make_key(key): val[1] for key, val in cls._REG.items()}
|
|
378
|
+
|
|
379
|
+
@staticmethod
|
|
380
|
+
def erase():
|
|
381
|
+
"""Erase the timer.
|
|
382
|
+
|
|
383
|
+
.. seealso:: :meth:`reset`
|
|
384
|
+
"""
|
|
385
|
+
for k in timeit._REG:
|
|
386
|
+
timeit._REG[k] = [0.0, 0.0, 0]
|
|
387
|
+
|
|
388
|
+
@classmethod
|
|
389
|
+
def reset(cls):
|
|
390
|
+
"""Reset the timer.
|
|
391
|
+
|
|
392
|
+
.. seealso:: :meth:`erase`
|
|
393
|
+
"""
|
|
394
|
+
cls.erase()
|
|
395
|
+
|
|
396
|
+
|
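Beyond the doctest examples, the class-level registry makes it easy to export aggregate timings, e.g. to a metrics logger. A small usage sketch of the class above (the sleeps stand in for real work):

# Illustrative sketch, not part of torchrl/_utils.py.
import time
from torchrl._utils import timeit

for _ in range(5):
    with timeit("train/step"):
        time.sleep(0.01)   # stand-in for a training step
    with timeit("train/collect"):
        time.sleep(0.02)   # stand-in for data collection

# Average seconds per call, with a "time/" prefix on every key:
metrics = timeit.todict(percall=True, prefix="time")
print(metrics)   # e.g. {"time/train/step": 0.01..., "time/train/collect": 0.02...}
timeit.erase()   # zero the running averages for the next window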
# Global flag to enable detailed profiling instrumentation.
# When False (default), _maybe_record_function returns nullcontext() immediately
# to avoid overhead in hot code paths.
_PROFILING_ENABLED = False

# Singleton nullcontext to avoid repeated object creation
_NULL_CONTEXT = nullcontext()


def set_profiling_enabled(enabled: bool) -> None:
    """Enable or disable detailed profiling instrumentation.

    When disabled (default), `_maybe_record_function` and `_maybe_timeit`
    return immediately with minimal overhead. Enable only when actively
    profiling to avoid performance regression.

    Args:
        enabled: If True, enable profiling instrumentation.
    """
    global _PROFILING_ENABLED
    _PROFILING_ENABLED = enabled


def _maybe_timeit(name):
    """Return timeit context if not compiling, nullcontext otherwise.

    torch.compiler.is_compiling() returns True when inside a compiled region,
    and timeit uses time.time() which dynamo cannot trace.
    """
    if is_compiling():
        return _NULL_CONTEXT
    return timeit(name)


def _maybe_record_function(name):
    """Return record_function context if profiling enabled and not compiling.

    When _PROFILING_ENABLED is False (default), returns immediately with
    minimal overhead to avoid performance regression in hot code paths.
    """
    if not _PROFILING_ENABLED:
        return _NULL_CONTEXT
    if is_compiling():
        return _NULL_CONTEXT

    return record_function(name)


def _maybe_record_function_decorator(name: str) -> Callable[[Callable], Callable]:
    """Decorator version of :func:`_maybe_record_function`.

    This is preferred over sprinkling many context managers in hot code paths,
    as it reduces Python overhead while keeping a useful profiler structure.

    When _PROFILING_ENABLED is False (default), the decorator is a no-op.
    """

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapped(*args, **kwargs):
            if not _PROFILING_ENABLED:
                return fn(*args, **kwargs)
            with _maybe_record_function(name):
                return fn(*args, **kwargs)

        return wrapped

    return decorator
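Taken together, these gates let instrumentation live permanently in hot paths at near-zero cost. A sketch of how they compose with `torch.profiler` (`heavy_step` is a hypothetical function):

# Illustrative sketch, not part of torchrl/_utils.py.
import torch
from torchrl._utils import _maybe_record_function, set_profiling_enabled

def heavy_step(x):
    # A no-op context unless profiling was switched on and we are not compiling.
    with _maybe_record_function("heavy_step"):
        return x @ x

set_profiling_enabled(True)
with torch.profiler.profile() as prof:
    heavy_step(torch.randn(64, 64))
set_profiling_enabled(False)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))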
def _check_for_faulty_process(processes):
    terminate = False
    for p in processes:
        if not p._closed and not p.is_alive():
            terminate = True
            for _p in processes:
                _p: mp.Process
                if not _p._closed and _p.is_alive():
                    try:
                        _p.terminate()
                    except Exception:
                        _p.kill()
                    finally:
                        time.sleep(0.1)
                        _p.close()
        if terminate:
            break
    if terminate:
        raise RuntimeError(
            "At least one process failed. Check for more info in the log."
        )


def seed_generator(seed):
    """A seed generator function.

    Given a seeding integer, generates a deterministic next seed to be used in a
    seeding sequence.

    Args:
        seed (int): initial seed.

    Returns: Next seed of the chain.

    """
    max_seed_val = (
        2**32 - 1
    )  # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
    rng = np.random.default_rng(seed)
    seed = int.from_bytes(rng.bytes(8), "big")
    return seed % max_seed_val


class KeyDependentDefaultDict(collections.defaultdict):
    """A key-dependent default dict.

    Examples:
        >>> my_dict = KeyDependentDefaultDict(lambda key: "foo_" + key)
        >>> print(my_dict["bar"])
        foo_bar
    """

    def __init__(self, fun):
        self.fun = fun
        super().__init__()

    def __missing__(self, key):
        value = self.fun(key)
        self[key] = value
        return value
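`seed_generator` is typically used to derive per-worker seeds from a single root seed. A short sketch of the chaining pattern, assuming nothing beyond the function above:

# Illustrative sketch, not part of torchrl/_utils.py.
from torchrl._utils import seed_generator

root_seed = 42
seeds = []
seed = root_seed
for _ in range(4):              # e.g. one seed per worker
    seed = seed_generator(seed)
    seeds.append(seed)
print(seeds)                    # deterministic: same root seed -> same chain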
def prod(sequence):
    """General prod function that generalizes usage across math and np.

    Created for compatibility across multiple Python versions.

    """
    if hasattr(math, "prod"):
        return math.prod(sequence)
    else:
        return int(np.prod(sequence))


def get_binary_env_var(key):
    """Parses and returns the binary environment variable value.

    If not present in the environment, it is considered `False`.

    Args:
        key (str): name of the environment variable.
    """
    val = os.environ.get(key, "False")
    if val in ("0", "False", "false"):
        val = False
    elif val in ("1", "True", "true"):
        val = True
    else:
        raise ValueError(
            f"Environment variable {key} should be in 'True', 'False', '0' or '1'. "
            f"Got {val} instead."
        )
    return val


class _Dynamic_CKPT_BACKEND:
    """Allows CKPT_BACKEND to be changed on-the-fly."""

    backends = ["torch", "torchsnapshot"]

    def _get_backend(self):
        backend = os.environ.get("CKPT_BACKEND", "torch")
        if backend == "torchsnapshot":
            try:
                import torchsnapshot  # noqa: F401
            except ImportError as err:
                raise ImportError(
                    f"torchsnapshot not found, but the backend points to this library. "
                    f"Consider installing torchsnapshot or choose another backend (available backends: {self.backends})"
                ) from err
        return backend

    def __getattr__(self, item):
        return getattr(self._get_backend(), item)

    def __eq__(self, other):
        return self._get_backend() == other

    def __ne__(self, other):
        return self._get_backend() != other

    def __repr__(self):
        return self._get_backend()


_CKPT_BACKEND = _Dynamic_CKPT_BACKEND()


def accept_remote_rref_invocation(func):
    """Decorator that allows a method to be invoked remotely.

    Passes the `rpc.RRef` associated with the remote object construction as first argument in place of the object reference.

    """

    @wraps(func)
    def unpack_rref_and_invoke_function(self, *args, **kwargs):
        # windows does not know torch._C._distributed_rpc.PyRRef
        if not _os_is_windows and isinstance(self, torch._C._distributed_rpc.PyRRef):
            self = self.local_value()
        return func(self, *args, **kwargs)

    return unpack_rref_and_invoke_function


def accept_remote_rref_udf_invocation(decorated_class):
    """Class decorator that applies `accept_remote_rref_invocation` to all public methods."""
    # ignores private methods
    for name in dir(decorated_class):
        method = getattr(decorated_class, name, None)
        if method is None:
            continue
        if callable(method) and not name.startswith("_"):
            setattr(decorated_class, name, accept_remote_rref_invocation(method))
    return decorated_class


# We copy this from torch as older versions do not have it
# see torch.utils._contextlib

# Extra utilities for working with context managers that should have been
# in the standard library but are not

# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)


def _wrap_generator(ctx_factory, func):
    """Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.
    """

    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the grad mode is properly set every time the execution
        # flow returns into the wrapped generator and restored when it
        # returns through our `yield` to our caller (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException:
                    # Propagate the exception thrown at us by the caller
                    with ctx_factory():
                        response = gen.throw(*sys.exc_info())

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context


def context_decorator(ctx, func):
    """Context decorator.

    Like contextlib.ContextDecorator, but:

    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.
    """
    if callable(ctx) and hasattr(ctx, "__enter__"):
        raise RuntimeError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use. If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    if not callable(ctx):

        def ctx_factory():
            return ctx

    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type. "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context


class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator."""

    def __call__(self, orig_func: F) -> F:
        if inspect.isclass(orig_func):
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else."
            )
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

    def clone(self):
        # override this method if your child class takes __init__ parameters
        return self.__class__()
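`_rng_decorator` further down in this file is the one subclass of `_DecoratorContextManager` here; the contract is simply to implement `__enter__`/`__exit__` (and `clone` when `__init__` takes parameters). A sketch with a hypothetical no-argument subclass:

# Illustrative sketch, not part of torchrl/_utils.py.
from torchrl._utils import _DecoratorContextManager

class announce(_DecoratorContextManager):
    def __enter__(self):
        print("entering")

    def __exit__(self, exc_type, exc_value, traceback):
        print("leaving")

@announce()   # usable as a decorator (a fresh clone wraps each call)...
def step():
    print("working")

step()
with announce():   # ...and as an ordinary context manager
    print("working inline")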
def get_trace():
    """A simple debugging util to spot where a function is being called."""
    traceback.print_stack()


def _make_process_no_warn_cls(ctx=None):
    """Create a _ProcessNoWarn class that inherits from the appropriate Process class.

    When using multiprocessing contexts (e.g., fork or spawn), the Process class
    used must match the context to ensure synchronization primitives like locks
    work correctly. This factory function creates a _ProcessNoWarn class that
    inherits from the context's Process class.

    Args:
        ctx: A multiprocessing context (e.g., from mp.get_context('fork')).
            If None, uses the default mp.Process.

    Returns:
        A _ProcessNoWarn class that inherits from the appropriate Process base.

    .. note::
        For the "spawn" start method, this returns pre-defined module-level classes
        to ensure they can be pickled correctly.
    """
    if ctx is None:
        return _ProcessNoWarn

    start_method = ctx.get_start_method()
    if start_method == "fork":
        return _ProcessNoWarnFork
    elif start_method == "spawn":
        return _ProcessNoWarnSpawn
    elif start_method == "forkserver":
        return _ProcessNoWarnForkserver
    else:
        # For unknown start methods, fall back to default
        return _ProcessNoWarn


# Keep the old class name as a default for backwards compatibility
class _ProcessNoWarn(mp.Process):
    """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess.

    .. note::
        When using multiprocessing contexts with synchronization primitives (locks, etc.),
        use :func:`_make_process_no_warn_cls` with the context to ensure compatibility.
    """

    @wraps(mp.Process.__init__)
    def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
        import torchrl

        self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
        self.num_threads = num_threads
        if _start_method is not None:
            self._start_method = _start_method
        super().__init__(*args, **kwargs)

    def run(self, *args, **kwargs):
        if self.num_threads is not None:
            torch.set_num_threads(self.num_threads)
        if self.filter_warnings_subprocess:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return mp.Process.run(self, *args, **kwargs)
        return mp.Process.run(self, *args, **kwargs)


# Pre-defined _ProcessNoWarn classes for different multiprocessing start methods.
# These must be defined at module level to be picklable with the "spawn" start method.
#
# We use a mixin pattern to avoid code duplication while still having
# distinct module-level classes that can be pickled.


class _ProcessNoWarnMixin:
    """Mixin class providing the common functionality for _ProcessNoWarn variants."""

    def _init_process_no_warn(self, num_threads=None, _start_method=None):
        import torchrl

        self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
        self.num_threads = num_threads
        if _start_method is not None:
            self._start_method = _start_method

    def run(self, *args, **kwargs):
        if self.num_threads is not None:
            torch.set_num_threads(self.num_threads)
        if self.filter_warnings_subprocess:
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return super().run(*args, **kwargs)
        return super().run(*args, **kwargs)


# Spawn-specific class (for macOS default and Windows)
try:
    _spawn_ctx = mp.get_context("spawn")

    class _ProcessNoWarnSpawn(_ProcessNoWarnMixin, _spawn_ctx.Process):
        """_ProcessNoWarn for the 'spawn' multiprocessing context."""

        def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
            self._init_process_no_warn(num_threads, _start_method)
            super().__init__(*args, **kwargs)

except ValueError:
    _ProcessNoWarnSpawn = _ProcessNoWarn


# Fork-specific class (for Linux default, not available on Windows)
try:
    _fork_ctx = mp.get_context("fork")

    class _ProcessNoWarnFork(_ProcessNoWarnMixin, _fork_ctx.Process):
        """_ProcessNoWarn for the 'fork' multiprocessing context."""

        def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
            self._init_process_no_warn(num_threads, _start_method)
            super().__init__(*args, **kwargs)

except ValueError:
    _ProcessNoWarnFork = _ProcessNoWarn


# Forkserver-specific class (not available on Windows)
try:
    _forkserver_ctx = mp.get_context("forkserver")

    class _ProcessNoWarnForkserver(_ProcessNoWarnMixin, _forkserver_ctx.Process):
        """_ProcessNoWarn for the 'forkserver' multiprocessing context."""

        def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
            self._init_process_no_warn(num_threads, _start_method)
            super().__init__(*args, **kwargs)

except ValueError:
    _ProcessNoWarnForkserver = _ProcessNoWarn
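The factory and the per-context classes fit together as follows: pick a context, then ask for the matching Process subclass. A sketch using only the helpers defined in this file (`work` is a hypothetical target):

# Illustrative sketch, not part of torchrl/_utils.py.
from torchrl._utils import _get_mp_ctx, _make_process_no_warn_cls

def work():
    pass  # warnings raised here are silenced when torchrl.filter_warnings_subprocess is set

if __name__ == "__main__":
    ctx = _get_mp_ctx("spawn")
    Process = _make_process_no_warn_cls(ctx)  # -> _ProcessNoWarnSpawn
    p = Process(target=work, num_threads=1)   # child calls torch.set_num_threads(1)
    p.start()
    p.join()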
|
916
|
+
def print_directory_tree(path, indent="", display_metadata=True):
|
|
917
|
+
"""Prints the directory tree starting from the specified path.
|
|
918
|
+
|
|
919
|
+
Args:
|
|
920
|
+
path (str): The path of the directory to print.
|
|
921
|
+
indent (str): The current indentation level for formatting.
|
|
922
|
+
display_metadata (bool): if ``True``, metadata of the dir will be
|
|
923
|
+
displayed too.
|
|
924
|
+
|
|
925
|
+
"""
|
|
926
|
+
if display_metadata:
|
|
927
|
+
|
|
928
|
+
def get_directory_size(path="."):
|
|
929
|
+
total_size = 0
|
|
930
|
+
|
|
931
|
+
for dirpath, _, filenames in os.walk(path):
|
|
932
|
+
for filename in filenames:
|
|
933
|
+
file_path = os.path.join(dirpath, filename)
|
|
934
|
+
total_size += os.path.getsize(file_path)
|
|
935
|
+
|
|
936
|
+
return total_size
|
|
937
|
+
|
|
938
|
+
def format_size(size):
|
|
939
|
+
# Convert size to a human-readable format
|
|
940
|
+
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
|
941
|
+
if size < 1024.0:
|
|
942
|
+
return f"{size:.2f} {unit}"
|
|
943
|
+
size /= 1024.0
|
|
944
|
+
|
|
945
|
+
total_size_bytes = get_directory_size(path)
|
|
946
|
+
formatted_size = format_size(total_size_bytes)
|
|
947
|
+
logger.info(f"Directory size: {formatted_size}")
|
|
948
|
+
|
|
949
|
+
if os.path.isdir(path):
|
|
950
|
+
logger.info(indent + os.path.basename(path) + "/")
|
|
951
|
+
indent += " "
|
|
952
|
+
for item in os.listdir(path):
|
|
953
|
+
print_directory_tree(
|
|
954
|
+
os.path.join(path, item), indent=indent, display_metadata=False
|
|
955
|
+
)
|
|
956
|
+
else:
|
|
957
|
+
logger.info(indent + os.path.basename(path))
|
|
958
|
+
|
|
959
|
+
|
|
960
|
+
def _ends_with(key, match):
|
|
961
|
+
if isinstance(key, str):
|
|
962
|
+
return key == match
|
|
963
|
+
return key[-1] == match
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
|
|
967
|
+
if isinstance(key, str):
|
|
968
|
+
return new_ending
|
|
969
|
+
else:
|
|
970
|
+
return key[:-1] + (new_ending,)
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def _append_last(key: NestedKey, new_suffix: str) -> NestedKey:
|
|
974
|
+
key = unravel_key(key)
|
|
975
|
+
if isinstance(key, str):
|
|
976
|
+
return key + new_suffix
|
|
977
|
+
else:
|
|
978
|
+
return key[:-1] + (key[-1] + new_suffix,)
|
|
979
|
+
|
|
980
|
+
|
|
981
|
+
+class _rng_decorator(_DecoratorContextManager):
+    """Temporarily sets the seed and sets back the rng state when exiting."""
+
+    def __init__(self, seed, device=None):
+        self.seed = seed
+        self.device = device
+        self.has_cuda = torch.cuda.is_available()
+
+    def __enter__(self):
+        self._get_state()
+        torch.manual_seed(self.seed)
+
+    def _get_state(self):
+        if self.has_cuda:
+            if self.device is None:
+                self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
+            else:
+                self._state = (
+                    torch.random.get_rng_state(),
+                    torch.cuda.get_rng_state(self.device),
+                )
+
+        else:
+            self._state = torch.random.get_rng_state()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.has_cuda:
+            torch.random.set_rng_state(self._state[0])
+            if self.device is not None:
+                torch.cuda.set_rng_state(self._state[1], device=self.device)
+            else:
+                torch.cuda.set_rng_state(self._state[1])
+
+        else:
+            torch.random.set_rng_state(self._state)
+
+
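Illustration (not part of the diff): because it subclasses `_DecoratorContextManager`, `_rng_decorator` works as a context manager or a decorator. A sketch:

    # Draw reproducible values without disturbing the surrounding RNG stream.
    with _rng_decorator(0):
        a = torch.randn(3)
    with _rng_decorator(0):
        b = torch.randn(3)
    assert (a == b).all()  # same seed -> same draw; global RNG state restored on exit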
+def _can_be_pickled(obj):
+    try:
+        pickle.dumps(obj)
+        return True
+    except (pickle.PickleError, AttributeError, TypeError):
+        return False
+
+
+def _make_ordinal_device(device: torch.device):
+    if device is None:
+        return device
+    device = torch.device(device)
+    if device.type == "cuda" and device.index is None:
+        return torch.device("cuda", index=torch.cuda.current_device())
+    if device.type == "mps" and device.index is None:
+        return torch.device("mps", index=0)
+    return device
+
+
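Illustration (not part of the diff): `_make_ordinal_device` pins an explicit index onto index-less accelerator devices, which makes device equality checks reliable. A sketch:

    assert _make_ordinal_device(None) is None
    assert _make_ordinal_device("cpu") == torch.device("cpu")  # untouched
    # On a CUDA machine: torch.device("cuda") -> torch.device("cuda", torch.cuda.current_device())
    # On Apple silicon:  torch.device("mps")  -> torch.device("mps", 0)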
+def get_available_device(return_str: bool = False) -> torch.device | str:
+    """Return the available accelerator device, or CPU if none is found.
+
+    Checks for accelerator availability in the following order: CUDA, NPU, MPS.
+    Returns the first available accelerator, or CPU if none are present.
+
+    .. note::
+        PyTorch generally assumes a single accelerator type per system.
+        Running with multiple accelerator types (e.g., both CUDA and NPU)
+        is not officially supported. This function simply returns the first
+        available accelerator it finds.
+
+    Args:
+        return_str: If ``True``, returns a string representation of the device
+            instead of a :class:`~torch.device` object. Defaults to ``False``.
+
+    Returns:
+        The available accelerator device as a :class:`~torch.device` object,
+        or as a string if ``return_str`` is ``True``. Falls back to CPU if
+        no accelerator is available.
+
+    Examples:
+        >>> from torchrl._utils import get_available_device
+        >>> device = get_available_device()
+        >>> # Use with config fallback:
+        >>> device = cfg.device or get_available_device()
+    """
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        device = "npu:0"
+    elif torch.backends.mps.is_available():
+        device = "mps:0"
+    else:
+        device = "cpu"
+    if return_str:
+        return device
+    return torch.device(device)
+
+
+class _ContextManager:
+    def __init__(self):
+        self._mode: Any | None = None
+        self._lock = threading.Lock()
+
+    def get_mode(self) -> Any | None:
+        cm = self._lock if not is_compiling() else nullcontext()
+        with cm:
+            return self._mode
+
+    def set_mode(self, type: Any | None) -> None:
+        cm = self._lock if not is_compiling() else nullcontext()
+        with cm:
+            self._mode = type
+
+
+def _standardize(
+    input: Tensor,
+    exclude_dims: tuple[int] = (),
+    mean: Tensor | None = None,
+    std: Tensor | None = None,
+    eps: float | None = None,
+):
+    """Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
+
+    Useful when processing multi-agent data to keep the agent dimensions independent.
+
+    Args:
+        input (Tensor): the input tensor to be standardized.
+        exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
+        mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
+        std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
+        eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution.
+
+    """
+    if eps is None:
+        if input.dtype.is_floating_point:
+            eps = torch.finfo(torch.float).resolution
+        else:
+            eps = 1e-6
+
+    len_exclude_dims = len(exclude_dims)
+    if not len_exclude_dims:
+        if mean is None:
+            mean = input.mean()
+        else:
+            # Assume dtypes are compatible
+            mean = torch.as_tensor(mean, device=input.device)
+        if std is None:
+            std = input.std()
+        else:
+            # Assume dtypes are compatible
+            std = torch.as_tensor(std, device=input.device)
+        return (input - mean) / std.clamp_min(eps)
+
+    input_shape = input.shape
+    exclude_dims = [
+        d if d >= 0 else d + len(input_shape) for d in exclude_dims
+    ]  # Make negative dims positive
+
+    if len(set(exclude_dims)) != len_exclude_dims:
+        raise ValueError("Exclude dims has repeating elements")
+    if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
+        raise ValueError(
+            f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
+        )
+    if len_exclude_dims == len(input_shape):
+        warnings.warn(
+            "_standardize called but all dims were excluded from the statistics, returning unprocessed input"
+        )
+        return input
+
+    included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
+    if mean is None:
+        mean = torch.mean(input, keepdim=True, dim=included_dims)
+    if std is None:
+        std = torch.std(input, keepdim=True, dim=included_dims)
+    return (input - mean) / std.clamp_min(eps)
+
+
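Illustration (not part of the diff): a sketch of the multi-agent case the docstring mentions, where excluding the agent dimension keeps each agent's statistics independent:

    # Advantages of shape [batch, n_agents]; standardize each agent separately.
    adv = torch.randn(128, 3)
    adv_norm = _standardize(adv, exclude_dims=(1,))
    # mean/std are computed over dim 0 only (keepdim=True), so each of the 3
    # agent columns is centered and scaled with its own statistics.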
+@wraps(torch.compile)
+def compile_with_warmup(*args, warmup: int = 1, **kwargs):
+    """Compile a model with warm-up.
+
+    This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase,
+    the original model is used. After the warm-up phase, the model is compiled using
+    `torch.compile`.
+
+    Args:
+        *args: Arguments to be passed to `torch.compile`.
+        warmup (int): Number of calls to the model before compiling it. Defaults to 1.
+        **kwargs: Keyword arguments to be passed to `torch.compile`.
+
+    Returns:
+        A callable that wraps the original model. If no model is provided, returns a
+        lambda function that takes a model as input and returns the wrapped model.
+
+    Notes:
+        If no model is provided, this function returns a lambda function that can be
+        used to wrap a model later. This allows for delayed compilation of the model.
+
+    Example:
+        >>> model = torch.nn.Linear(5, 3)
+        >>> compiled_model = compile_with_warmup(model, warmup=10)
+        >>> # First 10 calls use the original model
+        >>> # After 10 calls, the model is compiled and used
+    """
+    if len(args):
+        model = args[0]
+        args = ()
+    else:
+        model = kwargs.pop("model", None)
+    if model is None:
+        return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs)
+    else:
+        count = -1
+        compiled_model = model
+
+        @wraps(model)
+        def count_and_compile(*model_args, **model_kwargs):
+            nonlocal count
+            nonlocal compiled_model
+            count += 1
+            if count == warmup:
+                compiled_model = torch.compile(model, *args, **kwargs)
+            return compiled_model(*model_args, **model_kwargs)
+
+        return count_and_compile
+
+
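Illustration (not part of the diff): calling `compile_with_warmup` without a model returns a wrapper factory, so it also works as a decorator; extra kwargs such as `mode` are forwarded to `torch.compile`. A sketch:

    @compile_with_warmup(warmup=2, mode="reduce-overhead")
    def policy(x):
        return 2 * x

    # The first 2 calls run the original function; compilation happens on call 3.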
+# auto unwrap control
+_DEFAULT_AUTO_UNWRAP = True
+_AUTO_UNWRAP = os.environ.get("AUTO_UNWRAP_TRANSFORMED_ENV")
+
+
+class set_auto_unwrap_transformed_env(_DecoratorContextManager):
+    """A context manager or decorator to control whether TransformedEnv should automatically unwrap nested TransformedEnv instances.
+
+    Args:
+        mode (bool): Whether to automatically unwrap nested :class:`~torchrl.envs.TransformedEnv`
+            instances. If ``False``, :class:`~torchrl.envs.TransformedEnv` will not unwrap nested instances.
+            Defaults to ``True``.
+
+    .. note:: Until v0.9, this will raise a warning if :class:`~torchrl.envs.TransformedEnv` are nested
+        and the value is not set explicitly (`auto_unwrap=True` default behavior).
+        You can set the value of :func:`~torchrl.envs.auto_unwrap_transformed_env`
+        through:
+
+        - The ``AUTO_UNWRAP_TRANSFORMED_ENV`` environment variable;
+        - By setting ``torchrl.set_auto_unwrap_transformed_env(val: bool).set()`` at the
+          beginning of your script;
+        - By using ``torchrl.set_auto_unwrap_transformed_env(val: bool)`` as a context
+          manager or a decorator.
+
+    .. seealso:: :class:`~torchrl.envs.TransformedEnv`
+
+    Examples:
+        >>> with set_auto_unwrap_transformed_env(False):
+        ...     env = TransformedEnv(TransformedEnv(env))
+        ...     assert not isinstance(env.base_env, TransformedEnv)
+        >>> @set_auto_unwrap_transformed_env(False)
+        ... def my_function():
+        ...     env = TransformedEnv(TransformedEnv(env))
+        ...     assert not isinstance(env.base_env, TransformedEnv)
+        ...     return env
+
+    """
+
+    def __init__(self, mode: bool) -> None:
+        super().__init__()
+        self.mode = mode
+
+    def clone(self) -> set_auto_unwrap_transformed_env:
+        # override this method if your child class takes __init__ parameters
+        return type(self)(self.mode)
+
+    def __enter__(self) -> None:
+        self.set()
+
+    def set(self) -> None:
+        global _AUTO_UNWRAP
+        self._old_mode = _AUTO_UNWRAP
+        _AUTO_UNWRAP = bool(self.mode)
+        # we do this so that sub-processes see the same setting as the main process
+        os.environ["AUTO_UNWRAP_TRANSFORMED_ENV"] = str(_AUTO_UNWRAP)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        global _AUTO_UNWRAP
+        _AUTO_UNWRAP = self._old_mode
+        os.environ["AUTO_UNWRAP_TRANSFORMED_ENV"] = str(_AUTO_UNWRAP)
+
+
+def auto_unwrap_transformed_env(allow_none=False):
+    """Get the current setting for automatically unwrapping TransformedEnv instances.
+
+    Args:
+        allow_none (bool, optional): If True, returns ``None`` if no setting has been
+            specified. Otherwise, returns the default setting. Defaults to ``False``.
+
+    .. seealso:: :func:`~torchrl.set_auto_unwrap_transformed_env`
+
+    Returns:
+        bool or None: The current setting for automatically unwrapping TransformedEnv
+        instances.
+    """
+    global _AUTO_UNWRAP  # noqa: F824
+    if _AUTO_UNWRAP is None and allow_none:
+        return None
+    elif _AUTO_UNWRAP is None:
+        return _DEFAULT_AUTO_UNWRAP
+    return strtobool(_AUTO_UNWRAP) if isinstance(_AUTO_UNWRAP, str) else _AUTO_UNWRAP
+
+
+def safe_is_current_stream_capturing():
+    """A safe proxy to torch.cuda.is_current_stream_capturing."""
+    if not torch.cuda.is_available():
+        return False
+    try:
+        return torch.cuda.is_current_stream_capturing()
+    except Exception as error:
+        warnings.warn(
+            f"torch.cuda.is_current_stream_capturing() exited unexpectedly with the error message {error=}. "
+            f"Returning False by default."
+        )
+        return False
+
+
+@classmethod
+def as_remote(cls, remote_config: dict[str, Any] | None = None):
+    """Creates a ray remote class that can be used to instantiate remote instances of ``cls``.
+
+    Args:
+        cls (Python Class): class to be remotely instantiated.
+        remote_config (dict): keyword arguments forwarded to :func:`ray.remote`,
+            e.g. the resources (CPU cores, GPUs) to reserve for this class.
+
+    Returns:
+        A function that creates ray remote class instances.
+    """
+    import ray
+
+    if remote_config is None:
+        remote_config = {}
+
+    remote_collector = ray.remote(**remote_config)(cls)
+    remote_collector.is_remote = True
+    return remote_collector
+
+
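Illustration (not part of the diff): a sketch of how a class exposing this helper is typically used; `MyCollector` is a hypothetical class with `as_remote` attached, and ray must be installed:

    # Reserve resources for the actor, then instantiate it remotely.
    RemoteCollector = MyCollector.as_remote({"num_cpus": 2})
    collector = RemoteCollector.remote()  # a ray actor handle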
+def get_ray_default_runtime_env() -> dict[str, Any]:
+    """Get the default Ray runtime environment configuration for TorchRL.
+
+    This function returns a runtime environment configuration that excludes
+    large directories and files that should not be uploaded to Ray workers.
+    This helps prevent issues with Ray's working_dir size limits (512MB default).
+
+    Returns:
+        dict: A dictionary containing the default runtime_env configuration with
+        excludes for common large directories.
+
+    Examples:
+        >>> import ray
+        >>> from torchrl._utils import get_ray_default_runtime_env
+        >>> ray_init_config = {"num_cpus": 4}
+        >>> ray_init_config["runtime_env"] = get_ray_default_runtime_env()
+        >>> ray.init(**ray_init_config)
+
+    Note:
+        The excludes list includes:
+        - Virtual environments (.venv/, venv/, etc.)
+        - Test files and caches
+        - Documentation builds
+        - Benchmarks
+        - Examples and tutorials
+        - CI/CD configurations
+        - IDE configurations
+
+    """
+    return {
+        "excludes": [
+            ".venv/",
+            "venv/",
+            "env/",
+            "ENV/",
+            "env.bak/",
+            "venv.bak/",
+            "test/",
+            "tests/",
+            "docs/",
+            "benchmarks/",
+            "tutorials/",
+            "examples/",
+            ".github/",
+            ".pytest_cache/",
+            ".mypy_cache/",
+            ".ruff_cache/",
+            "__pycache__/",
+            "*.pyc",
+            "*.pyo",
+            "*.egg-info/",
+            ".idea/",
+            ".vscode/",
+            "dev/",
+            "main/",
+            "*.html",
+        ]
+    }
+
+
+def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]:
+    """Merge user-provided ray_init_config with default runtime_env excludes.
+
+    This function ensures that the default TorchRL runtime_env excludes are applied
+    to prevent large directories from being uploaded to Ray workers, while preserving
+    any user-provided configuration.
+
+    Args:
+        ray_init_config (dict): The ray init configuration dictionary to merge.
+
+    Returns:
+        dict: The merged configuration with default runtime_env excludes applied.
+
+    Examples:
+        >>> from torchrl._utils import merge_ray_runtime_env
+        >>> ray_init_config = {"num_cpus": 4}
+        >>> ray_init_config = merge_ray_runtime_env(ray_init_config)
+        >>> ray.init(**ray_init_config)
+
+    """
+    default_runtime_env = get_ray_default_runtime_env()
+    runtime_env = ray_init_config.get("runtime_env")
+
+    # Handle None or missing runtime_env
+    if runtime_env is None:
+        runtime_env = {}
+        ray_init_config["runtime_env"] = runtime_env
+    elif not isinstance(runtime_env, dict):
+        runtime_env = dict(runtime_env)
+        ray_init_config["runtime_env"] = runtime_env
+
+    # Merge excludes lists
+    excludes = runtime_env.get("excludes", [])
+    runtime_env["excludes"] = list(set(default_runtime_env["excludes"] + excludes))
+
+    # Ensure env_vars exists
+    if "env_vars" not in runtime_env:
+        runtime_env["env_vars"] = {}
+    elif not isinstance(runtime_env["env_vars"], dict):
+        runtime_env["env_vars"] = dict(runtime_env["env_vars"])
+
+    return ray_init_config
+
+
+def rl_warnings():
+    """Checks the status of the RL_WARNINGS env variable."""
+    return RL_WARNINGS
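Illustration (not part of the diff): a sketch of the merge semantics above, where user-provided excludes are preserved and deduplicated against the defaults, and `env_vars` is created if missing:

    cfg = {"num_cpus": 4, "runtime_env": {"excludes": ["data/"]}}
    cfg = merge_ray_runtime_env(cfg)
    assert "data/" in cfg["runtime_env"]["excludes"]   # user entry preserved
    assert ".venv/" in cfg["runtime_env"]["excludes"]  # default entry added
    assert cfg["runtime_env"]["env_vars"] == {}        # created if missing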