torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import collections
|
|
4
|
+
import time
|
|
5
|
+
import warnings
|
|
6
|
+
from collections import OrderedDict
|
|
7
|
+
from collections.abc import Iterator, Sequence
|
|
8
|
+
from queue import Empty
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from tensordict import TensorDict, TensorDictBase
|
|
13
|
+
from tensordict.nn import TensorDictModuleBase
|
|
14
|
+
from torchrl import logger as torchrl_logger
|
|
15
|
+
from torchrl._utils import (
|
|
16
|
+
_check_for_faulty_process,
|
|
17
|
+
accept_remote_rref_udf_invocation,
|
|
18
|
+
RL_WARNINGS,
|
|
19
|
+
)
|
|
20
|
+
from torchrl.collectors._base import _make_legacy_metaclass
|
|
21
|
+
from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT
|
|
22
|
+
from torchrl.collectors._multi_base import _MultiCollectorMeta, MultiCollector
|
|
23
|
+
from torchrl.collectors.utils import split_trajectories
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@accept_remote_rref_udf_invocation
|
|
27
|
+
class MultiSyncCollector(MultiCollector):
|
|
28
|
+
"""Runs a given number of DataCollectors on separate processes synchronously.
|
|
29
|
+
|
|
30
|
+
.. aafig::
|
|
31
|
+
|
|
32
|
+
+----------------------------------------------------------------------+
|
|
33
|
+
| "MultiSyncCollector" | |
|
|
34
|
+
|~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| |
|
|
35
|
+
| "Collector 1" | "Collector 2" | "Collector 3" | Main |
|
|
36
|
+
|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
|
|
37
|
+
| "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | |
|
|
38
|
+
|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
|
|
39
|
+
|"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | |
|
|
40
|
+
| | | | | | | |
|
|
41
|
+
| "actor" | | | "actor" | |
|
|
42
|
+
| | | | | |
|
|
43
|
+
| "step" | "step" | "actor" | | |
|
|
44
|
+
| | | | | |
|
|
45
|
+
| | | | "step" | "step" | |
|
|
46
|
+
| | | | | | |
|
|
47
|
+
| "actor" | "step" | "step" | "actor" | |
|
|
48
|
+
| | | | | |
|
|
49
|
+
| | "actor" | | |
|
|
50
|
+
| | | | |
|
|
51
|
+
| "yield batch of traj 1"------->"collect, train"|
|
|
52
|
+
| | |
|
|
53
|
+
| "step" | "step" | "step" | "step" | "step" | "step" | |
|
|
54
|
+
| | | | | | | |
|
|
55
|
+
| "actor" | "actor" | | | |
|
|
56
|
+
| | "step" | "step" | "actor" | |
|
|
57
|
+
| | | | | |
|
|
58
|
+
| "step" | "step" | "actor" | "step" | "step" | |
|
|
59
|
+
| | | | | | |
|
|
60
|
+
| "actor" | | "actor" | |
|
|
61
|
+
| "yield batch of traj 2"------->"collect, train"|
|
|
62
|
+
| | |
|
|
63
|
+
+----------------------------------------------------------------------+
|
|
64
|
+
|
|
65
|
+
Envs can be identical or different.
|
|
66
|
+
|
|
67
|
+
The collection starts when the next item of the collector is queried,
|
|
68
|
+
and no environment step is computed in between the reception of a batch of
|
|
69
|
+
trajectory and the start of the next collection.
|
|
70
|
+
This class can be safely used with online RL sota-implementations.
|
|
71
|
+
|
|
72
|
+
.. note::
|
|
73
|
+
Python requires multiprocessed code to be instantiated within a main guard:
|
|
74
|
+
|
|
75
|
+
>>> from torchrl.collectors import MultiSyncCollector
|
|
76
|
+
>>> if __name__ == "__main__":
|
|
77
|
+
... # Create your collector here
|
|
78
|
+
... collector = MultiSyncCollector(...)
|
|
79
|
+
|
|
80
|
+
See https://docs.python.org/3/library/multiprocessing.html for more info.
|
|
81
|
+
|
|
82
|
+
Examples:
|
|
83
|
+
>>> from torchrl.envs.libs.gym import GymEnv
|
|
84
|
+
>>> from tensordict.nn import TensorDictModule
|
|
85
|
+
>>> from torch import nn
|
|
86
|
+
>>> from torchrl.collectors import MultiSyncCollector
|
|
87
|
+
>>> if __name__ == "__main__":
|
|
88
|
+
... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
|
|
89
|
+
... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
|
|
90
|
+
... collector = MultiSyncCollector(
|
|
91
|
+
... create_env_fn=[env_maker, env_maker],
|
|
92
|
+
... policy=policy,
|
|
93
|
+
... total_frames=2000,
|
|
94
|
+
... max_frames_per_traj=50,
|
|
95
|
+
... frames_per_batch=200,
|
|
96
|
+
... init_random_frames=-1,
|
|
97
|
+
... reset_at_each_iter=False,
|
|
98
|
+
... device="cpu",
|
|
99
|
+
... storing_device="cpu",
|
|
100
|
+
... cat_results="stack",
|
|
101
|
+
... )
|
|
102
|
+
... for i, data in enumerate(collector):
|
|
103
|
+
... if i == 2:
|
|
104
|
+
... print(data)
|
|
105
|
+
... break
|
|
106
|
+
... collector.shutdown()
|
|
107
|
+
... del collector
|
|
108
|
+
TensorDict(
|
|
109
|
+
fields={
|
|
110
|
+
action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
111
|
+
collector: TensorDict(
|
|
112
|
+
fields={
|
|
113
|
+
traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
|
|
114
|
+
batch_size=torch.Size([200]),
|
|
115
|
+
device=cpu,
|
|
116
|
+
is_shared=False),
|
|
117
|
+
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
118
|
+
next: TensorDict(
|
|
119
|
+
fields={
|
|
120
|
+
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
121
|
+
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
122
|
+
reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
123
|
+
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
124
|
+
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
125
|
+
batch_size=torch.Size([200]),
|
|
126
|
+
device=cpu,
|
|
127
|
+
is_shared=False),
|
|
128
|
+
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
129
|
+
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
130
|
+
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
131
|
+
batch_size=torch.Size([200]),
|
|
132
|
+
device=cpu,
|
|
133
|
+
is_shared=False)
|
|
134
|
+
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
__doc__ += MultiCollector.__doc__
|
|
138
|
+
|
|
139
|
+
# for RPC
def next(self):
    """Return the next batch of collected data.

    Behaviorally identical to the parent implementation; the override is
    kept explicit so the RPC remote-invocation machinery can target it.
    """
    batch = super().next()
    return batch
|
|
142
|
+
|
|
143
|
+
# for RPC
def shutdown(
    self,
    timeout: float | None = None,
    close_env: bool = True,
    raise_on_error: bool = True,
) -> None:
    """Shut down the worker processes and release the shared buffers.

    Args:
        timeout: seconds to wait for the workers to terminate; forwarded
            to the parent implementation.
        close_env: must be ``True`` for this collector — the worker
            environments cannot outlive their processes, so keeping them
            open is an error.
        raise_on_error: if ``False``, exceptions raised while shutting
            down are swallowed (best-effort cleanup).

    Raises:
        RuntimeError: if ``close_env=False`` is requested.
    """
    if not close_env:
        raise RuntimeError(
            f"Cannot shutdown {type(self).__name__} collector without environment being closed."
        )
    # Drop references to shared output buffers before tearing workers down.
    if hasattr(self, "out_buffer"):
        del self.out_buffer
    if hasattr(self, "buffers"):
        del self.buffers
    try:
        return super().shutdown(timeout=timeout)
    except Exception:
        # Bare `raise` preserves the original traceback without appending
        # an extra frame (idiomatic vs. `raise e`).
        if raise_on_error:
            raise
|
|
165
|
+
|
|
166
|
+
# for RPC
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Seed the workers and return the last seed used.

    Plain passthrough to the parent class, overridden explicitly so that
    RPC can invoke it remotely.
    """
    out_seed = super().set_seed(seed, static_seed)
    return out_seed
|
|
169
|
+
|
|
170
|
+
# for RPC
def state_dict(self) -> OrderedDict:
    """Return the collector's state dict (delegates to the parent class).

    Explicit override kept for RPC remote invocation.
    """
    sd = super().state_dict()
    return sd
|
|
173
|
+
|
|
174
|
+
# for RPC
def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Restore collector state from ``state_dict`` (delegates to the parent).

    Explicit override kept for RPC remote invocation.
    """
    result = super().load_state_dict(state_dict)
    return result
|
|
177
|
+
|
|
178
|
+
# for RPC
def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    **kwargs,
) -> None:
    """Update the policy weights held by the worker processes.

    Args:
        policy_or_weights: a policy module, a TensorDict of weights, or a
            plain dict of weights to push to the workers; ``None`` lets the
            parent implementation decide where to fetch the weights from.
        worker_ids: restrict the update to the given worker index/device
            (or list thereof); ``None`` targets every worker.
        **kwargs: forwarded to the parent implementation. The legacy
            ``policy_weights`` keyword is still accepted but deprecated.
    """
    if "policy_weights" in kwargs:
        # Legacy keyword: remap it onto the new name before delegating.
        warnings.warn(
            "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not this shim
        )
        policy_or_weights = kwargs.pop("policy_weights")

    super().update_policy_weights_(
        policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
    )
|
|
196
|
+
|
|
197
|
+
def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
    """Return the number of frames a single worker delivers per batch.

    Args:
        worker_idx: index of the worker. When per-worker batch sizes were
            configured as a sequence, the exact per-worker value is
            returned; otherwise (or when ``None``) the even split below
            applies.

    Returns:
        The requested frames-per-batch divided by the number of workers,
        rounded up. When the division is not exact, each worker collects
        the ceiling, so slightly more frames than requested are produced
        per iteration (a warning is emitted unless RL_WARNINGS is unset).
    """
    if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
        return self._frames_per_batch[worker_idx]
    if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
        warnings.warn(
            f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
            f" this results in more frames_per_batch per iteration than requested. "
            "To silence this message, set the environment variable RL_WARNINGS to False."
        )
    # Ceiling division: -(-a // b).
    return -(-self.requested_frames_per_batch // self.num_workers)
|
|
210
|
+
|
|
211
|
+
@property
def _queue_len(self) -> int:
    # One output-queue slot per worker: the sync iterator (see `iterator`)
    # gathers exactly one result from every worker per collection round.
    return self.num_workers
|
|
214
|
+
|
|
215
|
+
def iterator(self) -> Iterator[TensorDictBase]:
|
|
216
|
+
cat_results = self.cat_results
|
|
217
|
+
if cat_results is None:
|
|
218
|
+
cat_results = "stack"
|
|
219
|
+
|
|
220
|
+
self.buffers = [None for _ in range(self.num_workers)]
|
|
221
|
+
dones = [False for _ in range(self.num_workers)]
|
|
222
|
+
workers_frames = [0 for _ in range(self.num_workers)]
|
|
223
|
+
same_device = None
|
|
224
|
+
self.out_buffer = None
|
|
225
|
+
preempt = self.interruptor is not None and self.preemptive_threshold < 1.0
|
|
226
|
+
|
|
227
|
+
while not all(dones) and self._frames < self.total_frames:
|
|
228
|
+
_check_for_faulty_process(self.procs)
|
|
229
|
+
if self.update_at_each_batch:
|
|
230
|
+
self.update_policy_weights_()
|
|
231
|
+
|
|
232
|
+
for idx in range(self.num_workers):
|
|
233
|
+
if self._should_use_random_frames():
|
|
234
|
+
msg = "continue_random"
|
|
235
|
+
else:
|
|
236
|
+
msg = "continue"
|
|
237
|
+
self.pipes[idx].send((None, msg))
|
|
238
|
+
|
|
239
|
+
self._iter += 1
|
|
240
|
+
|
|
241
|
+
if preempt:
|
|
242
|
+
self.interruptor.start_collection()
|
|
243
|
+
while self.queue_out.qsize() < int(
|
|
244
|
+
self.num_workers * self.preemptive_threshold
|
|
245
|
+
):
|
|
246
|
+
continue
|
|
247
|
+
self.interruptor.stop_collection()
|
|
248
|
+
# Now wait for stragglers to return
|
|
249
|
+
while self.queue_out.qsize() < int(self.num_workers):
|
|
250
|
+
continue
|
|
251
|
+
|
|
252
|
+
recv = collections.deque()
|
|
253
|
+
t0 = time.time()
|
|
254
|
+
while len(recv) < self.num_workers and (
|
|
255
|
+
(time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT)
|
|
256
|
+
):
|
|
257
|
+
for _ in range(self.num_workers):
|
|
258
|
+
try:
|
|
259
|
+
new_data, j = self.queue_out.get(timeout=_TIMEOUT)
|
|
260
|
+
recv.append((new_data, j))
|
|
261
|
+
except (TimeoutError, Empty):
|
|
262
|
+
_check_for_faulty_process(self.procs)
|
|
263
|
+
if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT):
|
|
264
|
+
try:
|
|
265
|
+
self.shutdown()
|
|
266
|
+
finally:
|
|
267
|
+
raise RuntimeError(
|
|
268
|
+
f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
|
|
269
|
+
f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
for _ in range(self.num_workers):
|
|
273
|
+
new_data, j = recv.popleft()
|
|
274
|
+
use_buffers = self._use_buffers
|
|
275
|
+
if self.replay_buffer is not None:
|
|
276
|
+
idx = new_data
|
|
277
|
+
workers_frames[idx] = workers_frames[
|
|
278
|
+
idx
|
|
279
|
+
] + self.frames_per_batch_worker(worker_idx=idx)
|
|
280
|
+
continue
|
|
281
|
+
elif j == 0 or not use_buffers:
|
|
282
|
+
try:
|
|
283
|
+
data, idx = new_data
|
|
284
|
+
self.buffers[idx] = data
|
|
285
|
+
if use_buffers is None and j > 0:
|
|
286
|
+
self._use_buffers = False
|
|
287
|
+
except TypeError:
|
|
288
|
+
if use_buffers is None:
|
|
289
|
+
self._use_buffers = True
|
|
290
|
+
idx = new_data
|
|
291
|
+
else:
|
|
292
|
+
raise
|
|
293
|
+
else:
|
|
294
|
+
idx = new_data
|
|
295
|
+
|
|
296
|
+
if preempt:
|
|
297
|
+
# mask buffers if cat, and create a mask if stack
|
|
298
|
+
if cat_results != "stack":
|
|
299
|
+
buffers = [None] * self.num_workers
|
|
300
|
+
for worker_idx, buffer in enumerate(self.buffers):
|
|
301
|
+
# Skip pre-empted envs:
|
|
302
|
+
if buffer is None:
|
|
303
|
+
continue
|
|
304
|
+
valid = buffer.get(("collector", "traj_ids")) != -1
|
|
305
|
+
if valid.ndim > 2:
|
|
306
|
+
valid = valid.flatten(0, -2)
|
|
307
|
+
if valid.ndim == 2:
|
|
308
|
+
valid = valid.any(0)
|
|
309
|
+
buffers[worker_idx] = buffer[..., valid]
|
|
310
|
+
else:
|
|
311
|
+
for buffer in filter(lambda x: x is not None, self.buffers):
|
|
312
|
+
with buffer.unlock_():
|
|
313
|
+
buffer.set(
|
|
314
|
+
("collector", "mask"),
|
|
315
|
+
buffer.get(("collector", "traj_ids")) != -1,
|
|
316
|
+
)
|
|
317
|
+
buffers = self.buffers
|
|
318
|
+
else:
|
|
319
|
+
buffers = self.buffers
|
|
320
|
+
|
|
321
|
+
# Skip frame counting if this worker didn't send data this iteration
|
|
322
|
+
# (happens when reusing buffers or on first iteration with some workers)
|
|
323
|
+
if self.buffers[idx] is None:
|
|
324
|
+
continue
|
|
325
|
+
|
|
326
|
+
workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
|
|
327
|
+
|
|
328
|
+
if workers_frames[idx] >= self.total_frames:
|
|
329
|
+
dones[idx] = True
|
|
330
|
+
|
|
331
|
+
if self.replay_buffer is not None:
|
|
332
|
+
yield
|
|
333
|
+
self._frames += sum(
|
|
334
|
+
self.frames_per_batch_worker(worker_idx=worker_idx)
|
|
335
|
+
for worker_idx in range(self.num_workers)
|
|
336
|
+
)
|
|
337
|
+
continue
|
|
338
|
+
|
|
339
|
+
# we have to correct the traj_ids to make sure that they don't overlap
|
|
340
|
+
# We can count the number of frames collected for free in this loop
|
|
341
|
+
n_collected = 0
|
|
342
|
+
for idx in range(self.num_workers):
|
|
343
|
+
buffer = buffers[idx]
|
|
344
|
+
if buffer is None:
|
|
345
|
+
continue
|
|
346
|
+
traj_ids = buffer.get(("collector", "traj_ids"))
|
|
347
|
+
if preempt:
|
|
348
|
+
if cat_results == "stack":
|
|
349
|
+
mask_frames = buffer.get(("collector", "traj_ids")) != -1
|
|
350
|
+
n_collected += mask_frames.sum().cpu()
|
|
351
|
+
else:
|
|
352
|
+
n_collected += traj_ids.numel()
|
|
353
|
+
else:
|
|
354
|
+
n_collected += traj_ids.numel()
|
|
355
|
+
|
|
356
|
+
if same_device is None:
|
|
357
|
+
prev_device = None
|
|
358
|
+
same_device = True
|
|
359
|
+
for item in filter(lambda x: x is not None, self.buffers):
|
|
360
|
+
if prev_device is None:
|
|
361
|
+
prev_device = item.device
|
|
362
|
+
else:
|
|
363
|
+
same_device = same_device and (item.device == prev_device)
|
|
364
|
+
|
|
365
|
+
if self.split_trajs:
|
|
366
|
+
max_traj_id = -1
|
|
367
|
+
for idx in range(self.num_workers):
|
|
368
|
+
if buffers[idx] is not None:
|
|
369
|
+
traj_ids = buffers[idx].get(("collector", "traj_ids"))
|
|
370
|
+
if traj_ids is not None:
|
|
371
|
+
buffers[idx].set_(
|
|
372
|
+
("collector", "traj_ids"), traj_ids + max_traj_id + 1
|
|
373
|
+
)
|
|
374
|
+
max_traj_id = (
|
|
375
|
+
buffers[idx].get(("collector", "traj_ids")).max()
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
if cat_results == "stack":
|
|
379
|
+
stack = (
|
|
380
|
+
torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
|
|
381
|
+
)
|
|
382
|
+
if same_device:
|
|
383
|
+
self.out_buffer = stack(
|
|
384
|
+
[item for item in buffers if item is not None], 0
|
|
385
|
+
)
|
|
386
|
+
else:
|
|
387
|
+
self.out_buffer = stack(
|
|
388
|
+
[item.cpu() for item in buffers if item is not None], 0
|
|
389
|
+
)
|
|
390
|
+
else:
|
|
391
|
+
if self._use_buffers is None:
|
|
392
|
+
torchrl_logger.warning(
|
|
393
|
+
"use_buffer not specified and not yet inferred from data, assuming `True`."
|
|
394
|
+
)
|
|
395
|
+
elif not self._use_buffers:
|
|
396
|
+
raise RuntimeError(
|
|
397
|
+
"Cannot concatenate results with use_buffers=False"
|
|
398
|
+
)
|
|
399
|
+
try:
|
|
400
|
+
if same_device:
|
|
401
|
+
self.out_buffer = torch.cat(
|
|
402
|
+
[item for item in buffers if item is not None], cat_results
|
|
403
|
+
)
|
|
404
|
+
else:
|
|
405
|
+
self.out_buffer = torch.cat(
|
|
406
|
+
[item.cpu() for item in buffers if item is not None],
|
|
407
|
+
cat_results,
|
|
408
|
+
)
|
|
409
|
+
except RuntimeError as err:
|
|
410
|
+
if (
|
|
411
|
+
preempt
|
|
412
|
+
and cat_results != -1
|
|
413
|
+
and "Sizes of tensors must match" in str(err)
|
|
414
|
+
):
|
|
415
|
+
raise RuntimeError(
|
|
416
|
+
"The value provided to cat_results isn't compatible with the collectors outputs. "
|
|
417
|
+
"Consider using `cat_results=-1`."
|
|
418
|
+
)
|
|
419
|
+
raise
|
|
420
|
+
|
|
421
|
+
# TODO: why do we need to do cat inplace and clone?
|
|
422
|
+
if self.split_trajs:
|
|
423
|
+
out = split_trajectories(self.out_buffer, prefix="collector")
|
|
424
|
+
else:
|
|
425
|
+
out = self.out_buffer
|
|
426
|
+
if cat_results in (-1, "stack"):
|
|
427
|
+
out.refine_names(*[None] * (out.ndim - 1) + ["time"])
|
|
428
|
+
|
|
429
|
+
self._frames += n_collected
|
|
430
|
+
|
|
431
|
+
if self.postprocs:
|
|
432
|
+
self.postprocs = (
|
|
433
|
+
self.postprocs.to(out.device)
|
|
434
|
+
if hasattr(self.postprocs, "to")
|
|
435
|
+
else self.postprocs
|
|
436
|
+
)
|
|
437
|
+
out = self.postprocs(out)
|
|
438
|
+
if self._exclude_private_keys:
|
|
439
|
+
excluded_keys = [key for key in out.keys() if key.startswith("_")]
|
|
440
|
+
if excluded_keys:
|
|
441
|
+
out = out.exclude(*excluded_keys)
|
|
442
|
+
yield out
|
|
443
|
+
|
|
444
|
+
del self.buffers
|
|
445
|
+
self.out_buffer = None
|
|
446
|
+
# We shall not call shutdown just yet as user may want to retrieve state_dict
|
|
447
|
+
# self._shutdown_main()
|
|
448
|
+
|
|
449
|
+
# for RPC
|
|
450
|
+
def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
    """Receive a policy (or its weights) and apply them to this collector.

    Pure delegation to the parent-class implementation; the override exists
    only so the method is directly addressable on this class for remote
    (RPC) invocation, per the ``# for RPC`` marker above.

    Args:
        policy_or_weights (TensorDictBase | None): the weights to apply,
            or ``None`` — semantics are defined by the parent class.
    """
    return super().receive_weights(policy_or_weights)
|
|
452
|
+
|
|
453
|
+
# for RPC
|
|
454
|
+
def _receive_weights_scheme(self):
    """Return the weight-reception scheme from the parent class.

    Pure delegation; redefined here only so the method is directly
    addressable on this class for remote (RPC) invocation, per the
    ``# for RPC`` marker above.
    """
    return super()._receive_weights_scheme()
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
# Metaclass for the legacy MultiSyncDataCollector alias: built by wrapping
# the multi-collector metaclass with _make_legacy_metaclass — presumably to
# emit a deprecation signal on use (TODO confirm in _make_legacy_metaclass).
_LegacyMultiSyncMeta = _make_legacy_metaclass(_MultiCollectorMeta)
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
class MultiSyncDataCollector(MultiSyncCollector, metaclass=_LegacyMultiSyncMeta):
    """Deprecated version of :class:`~torchrl.collectors.MultiSyncCollector`.

    Backward-compatibility alias: it adds no behavior of its own and exists
    only so code importing the old name keeps working. Deprecation handling,
    if any, comes from the ``_LegacyMultiSyncMeta`` metaclass (NOTE(review):
    confirm in ``_make_legacy_metaclass``). Prefer
    :class:`~torchrl.collectors.MultiSyncCollector` in new code.
    """

    ...
|