torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import contextlib
|
|
8
|
+
import importlib
|
|
9
|
+
from collections.abc import Callable, Iterator
|
|
10
|
+
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from torchrl._utils import logger as torchrl_logger
|
|
15
|
+
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
|
|
16
|
+
from torchrl.envs.transforms.transforms import Transform
|
|
17
|
+
|
|
18
|
+
# ``importlib.util`` is a submodule: a bare ``import importlib`` does not
# guarantee it is loaded, so import it explicitly before using ``find_spec``.
import importlib.util

# ImportError describing why ray is unavailable; stays ``None`` when ray is
# installed. Used as the ``from`` cause when RayReplayBuffer cannot be built.
RAY_ERR = None
_has_ray = importlib.util.find_spec("ray") is not None
if _has_ray:
    import ray
else:
    RAY_ERR = ImportError(
        "ray is not installed. Please install it with `pip install ray`."
    )

    # Placeholder so that the name ``ray`` exists; calling it raises.
    def ray():  # noqa: D103
        raise ImportError(
            "ray is not installed. Please install it with `pip install ray`."
        )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@classmethod
def as_remote(cls, remote_config=None):
    """Creates a Ray remote class wrapping ``cls``.

    This function is attached to :class:`~torchrl.data.ReplayBuffer` as a
    classmethod right after its definition, so ``SomeBufferCls.as_remote(...)``
    is available on every replay buffer class.

    Args:
        cls (Python Class): class to be remotely instantiated.
        remote_config (dict, optional): keyword arguments forwarded to
            ``ray.remote`` (e.g. ``num_cpus``, ``num_gpus``, ``memory``).
            Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.

    Returns:
        The Ray remote class produced by ``ray.remote``; call ``.remote(...)``
        on it to create an actor instance. The returned class has its
        ``is_remote`` attribute set to ``True``.

    Raises:
        ImportError: if ray is not installed.
    """
    # Without this check, the module-level ``ray`` placeholder function would
    # surface as an opaque AttributeError on ``ray.remote``.
    if not _has_ray:
        raise ImportError(
            "ray is not installed. Please install it with `pip install ray`."
        )
    if remote_config is None:
        # Local import: the collectors package is only needed on this path.
        from torchrl.collectors.distributed.ray import DEFAULT_REMOTE_CLASS_CONFIG

        remote_config = DEFAULT_REMOTE_CLASS_CONFIG
    remote_cls = ray.remote(**remote_config)(cls)
    remote_cls.is_remote = True
    return remote_cls


# Expose ``as_remote`` on ReplayBuffer (and thus on all its subclasses).
ReplayBuffer.as_remote = as_remote
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class RayReplayBuffer(ReplayBuffer):
    """A Ray implementation of the Replay Buffer that can be extended and sampled remotely.

    The actual buffer lives inside a Ray actor created from ``replay_buffer_cls``;
    this object forwards each call to that actor and blocks on the result with
    ``ray.get``.

    Keyword Args:
        replay_buffer_cls (type[ReplayBuffer], optional): the class to use for the replay buffer.
            ``None`` is treated as the default.
            Defaults to :class:`~torchrl.data.ReplayBuffer`.
        ray_init_config (dict[str, Any], optional): keyword arguments to pass to `ray.init()`.
            Only used if Ray is not initialized yet.
            Defaults to `torchrl.collectors.distributed.ray.DEFAULT_RAY_INIT_CONFIG`.
        remote_config (dict[str, Any], optional): keyword arguments to pass to `cls.as_remote()`.
            Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
        delayed_init (bool, optional): forwarded to the remote replay buffer constructor.
            Defaults to ``False``.
        **kwargs: keyword arguments to pass to the replay buffer class.

    .. seealso:: :class:`~torchrl.data.ReplayBuffer` for a list of other keyword arguments.

    The writer, sampler and storage should be passed as constructors to prevent serialization issues.
    Transforms constructors should be passed through the `transform_factory` argument.

    Example:
        >>> import asyncio
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> from torchrl.collectors.distributed.ray import RayCollector
        >>> from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
        >>> from torchrl.envs.libs.gym import GymEnv
        >>>
        >>> async def main():
        ...     # 1. Create environment factory
        ...     def env_maker():
        ...         return GymEnv("Pendulum-v1", device="cpu")
        ...
        ...     policy = TensorDictModule(
        ...         nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        ...     )
        ...
        ...     buffer = RayReplayBuffer()
        ...
        ...     # 2. Define distributed collector
        ...     remote_config = {
        ...         "num_cpus": 1,
        ...         "num_gpus": 0,
        ...         "memory": 5 * 1024**3,
        ...         "object_store_memory": 2 * 1024**3,
        ...     }
        ...     distributed_collector = RayCollector(
        ...         [env_maker],
        ...         policy,
        ...         total_frames=600,
        ...         frames_per_batch=200,
        ...         remote_configs=remote_config,
        ...         replay_buffer=buffer,
        ...     )
        ...
        ...     print("start")
        ...     distributed_collector.start()
        ...
        ...     while True:
        ...         while not len(buffer):
        ...             print("waiting")
        ...             await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
        ...         print("sample", buffer.sample(32))
        ...         # break at some point
        ...         break
        ...
        ...     await distributed_collector.async_shutdown(shutdown_ray=False)
        ...     buffer.close()  # Close buffer after collector
        >>>
        >>> if __name__ == "__main__":
        ...     asyncio.run(main())

    """

    def __init__(
        self,
        *args,
        replay_buffer_cls: type[ReplayBuffer] | None = ReplayBuffer,
        ray_init_config: dict[str, Any] | None = None,
        remote_config: dict[str, Any] | None = None,
        delayed_init: bool = False,
        **kwargs,
    ) -> None:
        if not _has_ray:
            raise RuntimeError(
                "ray library not found, unable to create a RayReplayBuffer. "
            ) from RAY_ERR
        if not ray.is_initialized():
            if ray_init_config is None:
                from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG

                ray_init_config = DEFAULT_RAY_INIT_CONFIG
            ray.init(**ray_init_config)

        # The annotation admits ``None``: fall back to the default class rather
        # than failing on ``None.as_remote``.
        if replay_buffer_cls is None:
            replay_buffer_cls = ReplayBuffer
        remote_cls = replay_buffer_cls.as_remote(remote_config).remote
        # We can detect if the buffer has a GPU allocated (explicit
        # ``num_gpus > 0`` in remote_config); if not, written data is moved to
        # CPU before being shipped to the actor.
        if remote_config is not None:
            self.has_gpu = remote_config.get("num_gpus", 0) > 0
        else:
            self.has_gpu = False
        self._rb = remote_cls(*args, delayed_init=delayed_init, **kwargs)
        # Initialization (delayed or not) is delegated to the remote actor; the
        # local proxy itself is never lazily initialized.
        self._delayed_init = False

    def close(self):
        """Terminates the Ray actor associated with this replay buffer."""
        if hasattr(self, "_rb"):
            try:
                torchrl_logger.info("Killing Ray actor.")
                ray.kill(self._rb)  # Forcefully terminate the actor
                torchrl_logger.info("Ray actor killed.")
            except (ValueError, RuntimeError) as e:
                # Actor may already be dead if ray.shutdown() was called
                torchrl_logger.debug(
                    f"Failed to kill Ray actor (may already be terminated): {e}"
                )
            finally:
                delattr(self, "_rb")  # Remove the reference to the terminated actor

    @property
    def _replay_lock(self):
        """Placeholder for the replay lock.

        Replay-lock is not supported yet by RayReplayBuffer.
        """
        return contextlib.nullcontext()

    @property
    def batch_size(self):
        return ray.get(self._rb._getattr.remote("_batch_size"))

    def _maybe_to_cpu(self, args, kwargs):
        """Moves ``.to``-capable arguments to CPU when the remote actor has no GPU.

        Arguments without a ``.to`` method (e.g. flags, lists) are passed
        through unchanged rather than dropped.
        """
        if self.has_gpu:
            return args, kwargs
        args = tuple(arg.to("cpu") if hasattr(arg, "to") else arg for arg in args)
        kwargs = {
            k: v.to("cpu") if hasattr(v, "to") else v for k, v in kwargs.items()
        }
        return args, kwargs

    def sample(self, *args, **kwargs):
        pending_task = self._rb.sample.remote(*args, **kwargs)
        return ray.get(pending_task)

    def extend(self, *args, **kwargs):
        # Move the data to CPU if the actor has no GPU (non-tensor args are kept).
        args, kwargs = self._maybe_to_cpu(args, kwargs)
        pending_task = self._rb.extend.remote(*args, **kwargs)
        return ray.get(pending_task)

    def add(self, *args, **kwargs):
        # Same device handling as ``extend``.
        args, kwargs = self._maybe_to_cpu(args, kwargs)
        return ray.get(self._rb.add.remote(*args, **kwargs))

    def update_priority(self, *args, **kwargs):
        return ray.get(self._rb.update_priority.remote(*args, **kwargs))

    def append_transform(self, *args, **kwargs):
        return ray.get(self._rb.append_transform.remote(*args, **kwargs))

    def dumps(self, path):
        return ray.get(self._rb.dumps.remote(path))

    def dump(self, path):
        return ray.get(self._rb.dump.remote(path))

    def loads(self, path):
        return ray.get(self._rb.loads.remote(path))

    def load(self, *args, **kwargs):
        return ray.get(self._rb.load.remote(*args, **kwargs))

    def empty(self, empty_write_count: bool = True):
        return ray.get(self._rb.empty.remote(empty_write_count=empty_write_count))

    def __getitem__(self, index):
        return ray.get(self._rb.__getitem__.remote(index))

    def next(self):
        return ray.get(self._rb.next.remote())

    def __iter__(self) -> Iterator[Any]:
        """Yields the results of repeated remote ``next()`` calls until one returns ``None``."""
        while (data := self.next()) is not None:
            yield data

    def insert_transform(
        self,
        index: int,
        transform: Transform,  # noqa-F821
        *,
        invert: bool = False,
    ) -> ReplayBuffer:
        return ray.get(
            self._rb.insert_transform.remote(index, transform, invert=invert)
        )

    def mark_update(self, index: int | torch.Tensor) -> None:
        return ray.get(self._rb.mark_update.remote(index))

    def register_load_hook(self, hook: Callable[[Any], Any]):
        return ray.get(self._rb.register_load_hook.remote(hook))

    def register_save_hook(self, hook: Callable[[Any], Any]):
        return ray.get(self._rb.register_save_hook.remote(hook))

    def save(self, path: str):
        return ray.get(self._rb.save.remote(path))

    def set_rng(self, generator):
        return ray.get(self._rb.set_rng.remote(generator))

    def set_sampler(self, sampler):
        return ray.get(self._rb.set_sampler.remote(sampler))

    def set_storage(self, storage):
        return ray.get(self._rb.set_storage.remote(storage))

    def set_writer(self, writer):
        return ray.get(self._rb.set_writer.remote(writer))

    def share(self, shared: bool = True):
        return ray.get(self._rb.share.remote(shared))

    def state_dict(self):
        return ray.get(self._rb.state_dict.remote())

    def __len__(self):
        return ray.get(self._rb.__len__.remote())

    @property
    def write_count(self):
        return ray.get(self._rb._getattr.remote("write_count"))

    @property
    def dim_extend(self):
        return ray.get(self._rb._getattr.remote("dim_extend"))

    @dim_extend.setter
    def dim_extend(self, value):
        return ray.get(self._rb._setattr.remote("dim_extend", value))

    def __setitem__(self, index, value) -> None:
        return ray.get(self._rb.__setitem__.remote(index, value))

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        return ray.get(self._rb.load_state_dict.remote(state_dict))
|