torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,1058 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import contextlib
|
|
5
|
+
import functools
|
|
6
|
+
import typing
|
|
7
|
+
import warnings
|
|
8
|
+
from collections import OrderedDict
|
|
9
|
+
from collections.abc import Callable, Iterator
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, overload
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from tensordict import TensorDict, TensorDictBase
|
|
17
|
+
from tensordict.base import NO_DEFAULT
|
|
18
|
+
from tensordict.nn import TensorDictModule, TensorDictModuleBase
|
|
19
|
+
from torch import nn as nn
|
|
20
|
+
from torch.utils.data import IterableDataset
|
|
21
|
+
from torchrl.collectors.utils import _map_weight
|
|
22
|
+
|
|
23
|
+
from torchrl.collectors.weight_update import WeightUpdaterBase
|
|
24
|
+
from torchrl.weight_update.utils import _resolve_attr
|
|
25
|
+
from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class ProfileConfig:
    """Configuration for profiling collector workers.

    This class holds all the settings for profiling collector rollouts
    using PyTorch's profiler. It's designed to work across all collector types.

    Attributes:
        workers: List of worker indices to profile. For single-process collectors
            (like Collector), this is ignored. For multi-process collectors
            (like MultiSyncCollector, MultiAsyncCollector), only the specified
            workers will be profiled. Defaults to [0].
        num_rollouts: Total number of rollouts to profile (including warmup).
            After this many rollouts, profiling stops. Defaults to 3.
        warmup_rollouts: Number of rollouts to skip before starting actual
            profiling. This allows JIT/compile warmup. Defaults to 1.
        save_path: Path to save the profiling trace. If None, traces are saved
            to "./collector_profile_{worker_idx}.json". Supports {worker_idx}
            placeholder for worker-specific files.
        activities: List of profiler activities ("cpu", "cuda", matched
            case-insensitively). Defaults to CPU and CUDA. Unrecognised
            names are skipped with a warning.
        record_shapes: Whether to record tensor shapes. Defaults to True.
        profile_memory: Whether to profile memory usage. Defaults to False.
        with_stack: Whether to record stack traces. Defaults to True.
        with_flops: Whether to compute FLOPS. Defaults to False.
        on_trace_ready: Optional callback when trace is ready. If None,
            traces are exported to Chrome trace format at save_path.

    Example:
        >>> from torchrl.collectors import MultiSyncCollector, ProfileConfig
        >>> collector = MultiSyncCollector(...)
        >>> collector.enable_profile(
        ...     workers=[0],
        ...     num_rollouts=5,
        ...     warmup_rollouts=2,
        ...     save_path="./traces/worker_{worker_idx}.json",
        ... )
        >>> for data in collector:
        ...     # First worker will be profiled for rollouts 2-4
        ...     process(data)
    """

    workers: list[int] = field(default_factory=lambda: [0])
    num_rollouts: int = 3
    warmup_rollouts: int = 1
    save_path: str | Path | None = None
    activities: list[str] = field(default_factory=lambda: ["cpu", "cuda"])
    record_shapes: bool = True
    profile_memory: bool = False
    with_stack: bool = True
    with_flops: bool = False
    on_trace_ready: Callable | None = None

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: If ``warmup_rollouts`` is negative, or if
                ``num_rollouts`` is not strictly greater than
                ``warmup_rollouts``.
        """
        # Check the sign first so a negative warmup is always reported as
        # such, rather than through the relative comparison below.
        if self.warmup_rollouts < 0:
            raise ValueError(
                f"warmup_rollouts must be >= 0, got {self.warmup_rollouts}"
            )
        if self.num_rollouts <= self.warmup_rollouts:
            raise ValueError(
                f"num_rollouts ({self.num_rollouts}) must be greater than "
                f"warmup_rollouts ({self.warmup_rollouts})"
            )

    def get_save_path(self, worker_idx: int) -> Path:
        """Get the save path for a specific worker.

        Args:
            worker_idx: The worker index.

        Returns:
            Path object for the trace file. When ``save_path`` is None this is
            "./collector_profile_{worker_idx}.json"; otherwise ``save_path``
            with its ``{worker_idx}`` placeholder substituted.
        """
        if self.save_path is None:
            return Path(f"./collector_profile_{worker_idx}.json")
        # str() lets save_path be either a str or a Path before formatting.
        path_str = str(self.save_path).format(worker_idx=worker_idx)
        return Path(path_str)

    def get_activities(self) -> list:
        """Get PyTorch profiler activity list.

        Names in ``activities`` are matched case-insensitively against
        "cpu" and "cuda". "cuda" is dropped when CUDA is unavailable.
        Unrecognised names are skipped and reported through a warning
        instead of being discarded silently.

        Returns:
            List of torch.profiler.ProfilerActivity values.
        """
        import torch.profiler

        activity_map = {
            "cpu": torch.profiler.ProfilerActivity.CPU,
            "cuda": torch.profiler.ProfilerActivity.CUDA,
        }
        result = []
        for activity in self.activities:
            activity_lower = activity.lower()
            if activity_lower not in activity_map:
                # A typo here would otherwise yield a trace that silently
                # lacks the requested activity.
                warnings.warn(
                    f"Unknown profiler activity {activity!r} ignored; "
                    f"supported activities are {sorted(activity_map)}.",
                    stacklevel=2,
                )
                continue
            # Only add CUDA if CUDA is available
            if activity_lower == "cuda" and not torch.cuda.is_available():
                continue
            result.append(activity_map[activity_lower])
        return result

    def should_profile_worker(self, worker_idx: int) -> bool:
        """Check if a specific worker should be profiled.

        Args:
            worker_idx: The worker index to check.

        Returns:
            True if this worker should be profiled.
        """
        return worker_idx in self.workers
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
|
|
141
|
+
"""Base class for data collectors."""
|
|
142
|
+
|
|
143
|
+
_task = None
|
|
144
|
+
_iterator = None
|
|
145
|
+
_iteration_started = False
|
|
146
|
+
total_frames: int
|
|
147
|
+
requested_frames_per_batch: int
|
|
148
|
+
frames_per_batch: int
|
|
149
|
+
trust_policy: bool
|
|
150
|
+
compiled_policy: bool
|
|
151
|
+
cudagraphed_policy: bool
|
|
152
|
+
_weight_updater: WeightUpdaterBase | None = None
|
|
153
|
+
_weight_sync_schemes: dict[str, WeightSyncScheme] | None = None
|
|
154
|
+
verbose: bool = False
|
|
155
|
+
_profile_config: ProfileConfig | None = None
|
|
156
|
+
|
|
157
|
+
def enable_profile(
    self,
    *,
    workers: list[int] | None = None,
    num_rollouts: int = 3,
    warmup_rollouts: int = 1,
    save_path: str | Path | None = None,
    activities: list[str] | None = None,
    record_shapes: bool = True,
    profile_memory: bool = False,
    with_stack: bool = True,
    with_flops: bool = False,
    on_trace_ready: Callable | None = None,
) -> None:
    """Turn on PyTorch-profiler tracing of this collector's rollouts.

    The keyword arguments are bundled into a :class:`ProfileConfig` that is
    stored on the collector as ``_profile_config``. The call is rejected once
    iteration has started.

    Args:
        workers: Indices of the workers to profile. ``None`` is replaced
            by ``[0]``. Ignored by single-process collectors.
        num_rollouts: Total number of rollouts the profiler runs for,
            warmup included. Defaults to 3.
        warmup_rollouts: Number of initial rollouts skipped before actual
            profiling starts (useful for JIT/compile warmup). Defaults to 1.
        save_path: Destination of the trace; may contain a ``{worker_idx}``
            placeholder. ``None`` falls back to
            "./collector_profile_{worker_idx}.json".
        activities: Profiler activities ("cpu", "cuda"). ``None`` is
            replaced by ``["cpu", "cuda"]``.
        record_shapes: Whether tensor shapes are recorded. Defaults to True.
        profile_memory: Whether memory usage is profiled. Defaults to False.
        with_stack: Whether Python stack traces are recorded.
            Defaults to True.
        with_flops: Whether FLOPS are computed. Defaults to False.
        on_trace_ready: Optional callback invoked when a trace is ready.
            If None, traces are exported to Chrome trace format at
            ``save_path``.

    Raises:
        RuntimeError: If called after iteration has started.
        ValueError: If ``num_rollouts <= warmup_rollouts`` (raised by the
            ``ProfileConfig`` validation).

    Example:
        >>> from torchrl.collectors import MultiSyncCollector
        >>> collector = MultiSyncCollector(
        ...     create_env_fn=[make_env] * 4,
        ...     policy=policy,
        ...     frames_per_batch=1000,
        ...     total_frames=100000,
        ... )
        >>> collector.enable_profile(
        ...     workers=[0],
        ...     num_rollouts=5,
        ...     warmup_rollouts=2,
        ...     save_path="./traces/worker_{worker_idx}.json",
        ... )
        >>> # Worker 0 will be profiled for rollouts 2, 3, 4
        >>> for data in collector:
        ...     train(data)
        >>> collector.shutdown()

    Note:
        - Profiling adds overhead, so only profile specific workers
        - The trace file can be viewed in Chrome's trace viewer
          (chrome://tracing) or with PyTorch's TensorBoard plugin
        - For multi-process collectors, this must be called BEFORE
          iteration starts as it needs to configure workers
    """
    # Workers can no longer be (re)configured once data collection runs.
    if self._iteration_started:
        raise RuntimeError(
            "Cannot enable profiling after iteration has started. "
            "Call enable_profile() before iterating over the collector."
        )

    # The None defaults are resolved inline so every call gets fresh lists.
    self._profile_config = ProfileConfig(
        workers=[0] if workers is None else workers,
        num_rollouts=num_rollouts,
        warmup_rollouts=warmup_rollouts,
        save_path=save_path,
        activities=["cpu", "cuda"] if activities is None else activities,
        record_shapes=record_shapes,
        profile_memory=profile_memory,
        with_stack=with_stack,
        with_flops=with_flops,
        on_trace_ready=on_trace_ready,
    )
|
|
252
|
+
|
|
253
|
+
@property
def profile_config(self) -> ProfileConfig | None:
    """Profiling settings currently attached to this collector.

    Returns:
        The ProfileConfig stored in ``_profile_config`` (set by
        ``enable_profile``), or None if profiling was never enabled.
    """
    current_config = self._profile_config
    return current_config
|
|
261
|
+
|
|
262
|
+
@property
def weight_updater(self) -> WeightUpdaterBase:
    """The weight updater stored on this collector (``_weight_updater``)."""
    return self._weight_updater


@weight_updater.setter
def weight_updater(self, value: WeightUpdaterBase | None):
    """Attach a weight updater, registering this collector with it.

    Args:
        value: A ``WeightUpdaterBase`` instance, a callable that is invoked
            with no arguments to build one, or None to clear the updater.

    Raises:
        RuntimeError: If, after ``register_collector(self)``, the updater's
            ``collector`` attribute is not this collector.
    """
    # Clearing the updater needs no registration.
    if value is None:
        self._weight_updater = value
        return

    # Anything callable that is not already an updater is treated as a
    # constructor and invoked with the default (empty) arguments.
    needs_construction = not isinstance(value, WeightUpdaterBase) and callable(value)
    if needs_construction:
        value = value()

    value.register_collector(self)
    if value.collector is not self:
        raise RuntimeError("Failed to register collector.")
    self._weight_updater = value
|
|
277
|
+
|
|
278
|
+
@property
def worker_idx(self) -> int | None:
    """Index of this collector when it runs as a worker.

    Returns:
        The worker index (0-indexed) previously assigned through the setter.

    Raises:
        RuntimeError: If no worker index was ever assigned.
    """
    if hasattr(self, "_worker_idx"):
        return self._worker_idx
    raise RuntimeError(
        "worker_idx has not been set. This collector may not have been "
        "initialized as a worker in a distributed setup."
    )


@worker_idx.setter
def worker_idx(self, value: int | None) -> None:
    """Record the worker index of this collector.

    Args:
        value: The worker index (0-indexed) or None.
    """
    self._worker_idx = value
|
|
303
|
+
|
|
304
|
+
def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any:
    """Resolve a nested attribute of this collector and call it if possible.

    ``attr_path`` is handed to ``_resolve_attr`` together with ``self``. A
    callable result is invoked with the supplied arguments; any other result
    is returned untouched, in which case no arguments may be given.

    Args:
        attr_path: Full path to the target, e.g.
            "_receiver_schemes['model_id']._set_dist_connection_info"
        *args: Positional arguments forwarded to the callable.
        **kwargs: Keyword arguments forwarded to the callable.

    Returns:
        The callable's return value, or the attribute itself when it is not
        callable.

    Raises:
        ValueError: If arguments are supplied for a non-callable attribute.

    Examples:
        >>> collector.cascade_execute(
        ...     "_receiver_schemes['policy']._set_dist_connection_info",
        ...     connection_info_ref,
        ...     worker_idx=0
        ... )
    """
    target = _resolve_attr(self, attr_path)
    if not callable(target):
        if args or kwargs:
            raise ValueError(
                f"Arguments and keyword arguments are not supported for non-callable attributes. Got {args} and {kwargs} for {attr_path}"
            )
        return target
    return target(*args, **kwargs)
|
|
336
|
+
|
|
337
|
+
def _get_policy_and_device(
    self,
    policy: Callable[[Any], Any] | None = None,
    policy_device: Any = NO_DEFAULT,
    env_maker: Any | None = None,
    env_maker_kwargs: dict[str, Any] | None = None,
) -> tuple[TensorDictModule, None | Callable[[], dict]]:
    """Build the policy to run on ``policy_device`` from the collector inputs.

    The policy is copied and the copy is populated with mapped weights; the
    original module is never moved with ``policy.to(device)``.

    Args:
        policy (TensorDictModule, optional): a policy to be used
        policy_device (torch.device, optional): the device where the policy should be placed.
            Defaults to self.policy_device
        env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair.
            Not read by this method.
        env_maker_kwargs (a dict, optional): the env_maker function kwargs.
            Not read by this method.

    Returns:
        ``(policy, None)`` when ``policy_device`` is falsy or when no
        parameter/buffer sits on a different device; otherwise
        ``(policy_copy, get_original_weights)`` where the callable returns the
        ``.data`` view of the original policy's parameters and buffers.

    Raises:
        RuntimeError: If the copied policy's weight keys differ from the
            original's.
    """
    if policy_device is NO_DEFAULT:
        policy_device = self.policy_device

    if not policy_device:
        return policy, None

    # A non-module policy gets an empty container so that the "no parameter"
    # warning below can still be reached.
    param_and_buf = (
        TensorDict.from_module(policy, as_module=True)
        if isinstance(policy, nn.Module)
        else TensorDict()
    )

    found_tensor = False
    needs_cast = False
    for tensor in param_and_buf.values(True, True):
        found_tensor = True
        if tensor.device != policy_device:
            needs_cast = True
            break

    if not needs_cast:
        if not found_tensor and not self.trust_policy:
            # Nothing to inspect: the devices are trusted to match.
            warnings.warn(
                "A policy device was provided but no parameter/buffer could be found in "
                "the policy. Casting to policy_device is therefore impossible. "
                "The collector will trust that the devices match. To suppress this "
                "warning, set `trust_policy=True` when building the collector."
            )
        return policy, None

    # Create a stateless policy, then populate this copy with params on device
    def get_original_weights(policy=policy):
        return TensorDict.from_module(policy).data

    # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function
    meta_weights = param_and_buf.data.to("meta")
    with meta_weights.to_module(policy):
        policy_new_device = deepcopy(policy)

    # `_map_weight` is defined elsewhere; presumably it places each tensor on
    # `policy_device` — confirm against its definition.
    mapper = functools.partial(_map_weight, policy_device=policy_device)
    param_and_buf_new_device = param_and_buf.apply(mapper, filter_empty=False)
    param_and_buf_new_device.to_module(policy_new_device)

    # Sanity check: the copy must expose exactly the same weight keys.
    copied_keys = set(TensorDict.from_module(policy_new_device).keys(True, True))
    original_keys = set(get_original_weights().keys(True, True))
    if copied_keys != original_keys:
        raise RuntimeError("Failed to map weights. The weight sets mismatch.")
    return policy_new_device, get_original_weights
|
|
405
|
+
|
|
406
|
+
def start(self):
    """Start the collector for asynchronous data collection.

    Intended to launch collection in the background so that data collection
    and training are decoupled, the collected data typically going to a replay
    buffer passed when the collector was built. This base implementation
    provides no such behaviour and always raises.

    .. note:: Once started, shut the collector down with :meth:`~.async_shutdown`
        when you're done with it to free up resources.

    .. warning:: Asynchronous data collection can significantly impact training performance
        due to its decoupled nature. Ensure you understand the implications for your
        specific algorithm before using this mode.

    Raises:
        NotImplementedError: If not implemented by a subclass.
    """
    message = f"Collector start() is not implemented for {type(self).__name__}."
    raise NotImplementedError(message)
|
|
425
|
+
|
|
426
|
+
@contextlib.contextmanager
def pause(self):
    """Context manager that pauses the collector if it is running free.

    The base implementation always raises; subclasses must override it to
    provide pausing.

    Raises:
        NotImplementedError: If not implemented by a subclass.
    """
    message = f"Collector pause() is not implemented for {type(self).__name__}."
    raise NotImplementedError(message)
|
|
432
|
+
|
|
433
|
+
def async_shutdown(
    self, timeout: float | None = None, close_env: bool = True
) -> None:
    """Shut down a collector that was started asynchronously with `start`.

    This is a thin wrapper that forwards both arguments to :meth:`shutdown`
    and returns its result.

    Args:
        timeout (float, optional): The maximum time to wait for the collector to shutdown.
        close_env (bool, optional): If True, the collector will close the contained environment.
            Defaults to `True`.

    .. seealso:: :meth:`~.start`

    """
    shutdown_options = {"timeout": timeout, "close_env": close_env}
    return self.shutdown(**shutdown_options)
|
|
447
|
+
|
|
448
|
+
def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
    """Return the weights to send, resolving them only on the legacy path.

    When weight sync schemes are registered (``self._weight_sync_schemes`` is
    truthy) the input is passed through unchanged, since preparation is left
    to the scheme. Otherwise the legacy extraction routine is used.

    Args:
        weights: Either already-extracted weights or a model to extract from.
        model_id: The model identifier forwarded to the legacy extraction.

    Returns:
        ``weights`` unchanged, or the result of ``_legacy_extract_weights``.
    """
    if not self._weight_sync_schemes:
        # Legacy weight updater path
        return self._legacy_extract_weights(weights, model_id)
    # Schemes handle preparation themselves: just pass through.
    return weights
|
|
469
|
+
|
|
470
|
+
def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any:
    """Resolve weights for the old weight updater system.

    Explicit weights are returned as is. When ``weights`` is ``None`` and
    ``model_id`` is ``"policy"``, the collector's own ``policy_weights`` is
    used if present; failing that, ``_policy_weights_dict`` is looked up with
    the policy device (its first entry when it is a list or tuple).

    Args:
        weights: Either already-extracted weights or a model to extract from.
        model_id: The model identifier.

    Returns:
        The resolved weights, or ``None`` when nothing could be found.
    """
    if weights is not None:
        return weights
    if model_id != "policy":
        return None
    if hasattr(self, "policy_weights"):
        return self.policy_weights
    if hasattr(self, "_policy_weights_dict"):
        device = self.policy_device
        if isinstance(device, (list, tuple)):
            device = device[0]
        return self._policy_weights_dict.get(device)
    return None
|
|
493
|
+
|
|
494
|
+
@property
def _legacy_weight_updater(self) -> bool:
    """Whether a weight updater object is set (``self._weight_updater`` is not ``None``)."""
    updater = self._weight_updater
    return updater is not None
|
|
497
|
+
|
|
498
|
+
# Overloads for update_policy_weights_ to support multiple calling conventions
@overload
def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
    /,
) -> None:
    ...


@overload
def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
    /,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    model_id: str | None = None,
) -> None:
    ...


@overload
def update_policy_weights_(
    self,
    *,
    weights: TensorDictBase | dict,
    model_id: str | None = None,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
) -> None:
    ...


@overload
def update_policy_weights_(
    self,
    *,
    policy: TensorDictModuleBase | nn.Module,
    model_id: str | None = None,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
) -> None:
    ...


@overload
def update_policy_weights_(
    self,
    *,
    weights_dict: dict[
        str, TensorDictBase | TensorDictModuleBase | nn.Module | dict
    ],
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
) -> None:
    ...


def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase
    | TensorDictModuleBase
    | nn.Module
    | dict
    | None = None,
    *,
    weights: TensorDictBase | dict | None = None,
    policy: TensorDictModuleBase | nn.Module | None = None,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    model_id: str | None = None,
    weights_dict: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    """Update policy weights for the data collector.

    Brings the weights used by the collector in line with the latest trained
    weights. The source of the weights may be given positionally or through
    exactly one of the ``weights`` / ``policy`` keywords; the three forms are
    merged into a single value which is then handed to the legacy weight
    updater implementation when a weight updater is set, or to the weight
    sync scheme implementation otherwise.

    Examples:
        >>> # Pass policy module as positional argument
        >>> collector.update_policy_weights_(policy_module)
        >>>
        >>> # Pass TensorDict weights as positional argument
        >>> collector.update_policy_weights_(weights_tensordict)
        >>>
        >>> # Use keyword arguments for clarity
        >>> collector.update_policy_weights_(weights=weights_td, model_id="actor")
        >>> collector.update_policy_weights_(policy=actor_module, model_id="actor")
        >>>
        >>> # Update multiple models atomically
        >>> collector.update_policy_weights_(weights_dict={
        ...     "actor": actor_weights,
        ...     "critic": critic_weights,
        ... })

    Args:
        policy_or_weights: The weights to update with. Can be:

            - ``nn.Module``: A policy module whose weights will be extracted
            - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
            - ``TensorDictBase``: A TensorDict containing weights
            - ``dict``: A regular dict containing weights
            - ``None``: No weights supplied here; what happens next is decided
              by the implementation the call is delegated to.

    Keyword Args:
        weights: Alternative to positional argument. A TensorDict or dict containing
            weights to update. Cannot be used together with ``policy_or_weights`` or ``policy``.
        policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase``
            whose weights will be extracted. Cannot be used together with ``policy_or_weights``
            or ``weights``.
        worker_ids: Identifiers for the workers to update. Relevant when the collector
            has multiple workers. Can be int, list of ints, device, or list of devices.
        model_id: The model identifier to update. Cannot be used together with
            ``weights_dict``.
        weights_dict: Dictionary mapping model_id to weights for updating
            multiple models atomically. Keys should match model_ids registered in
            ``weight_sync_schemes``. Cannot be used together with ``model_id``,
            ``policy_or_weights``, ``weights``, or ``policy``.
        **kwargs: Forwarded untouched to the selected implementation.

    Raises:
        ValueError: If conflicting parameters are provided. Further errors may
            be raised by the implementation the call is delegated to.

    .. note:: Users should extend the ``WeightUpdaterBase`` classes to customize
        the weight update logic for specific use cases.

    .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
        :class:`~torchrl.collectors.RemoteWeightsUpdaterBase`.

    """
    # Fold the three ways of passing the weight source into one value.
    if weights is not None:
        if policy_or_weights is not None:
            raise ValueError(
                "Cannot specify both positional 'policy_or_weights' and keyword 'weights'"
            )
        if policy is not None:
            raise ValueError("Cannot specify both 'weights' and 'policy'")
        source = weights
    elif policy is not None:
        if policy_or_weights is not None:
            raise ValueError(
                "Cannot specify both positional 'policy_or_weights' and keyword 'policy'"
            )
        source = policy
    else:
        source = policy_or_weights

    implementation = (
        self._legacy_weight_update_impl
        if self._legacy_weight_updater
        else self._weight_update_impl
    )
    return implementation(
        policy_or_weights=source,
        worker_ids=worker_ids,
        model_id=model_id,
        weights_dict=weights_dict,
        **kwargs,
    )
|
|
656
|
+
|
|
657
|
+
def _legacy_weight_update_impl(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    model_id: str | None = None,
    weights_dict: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    """Update weights through the old weight updater object.

    Args:
        policy_or_weights: Weight source forwarded to the weight updater.

    Keyword Args:
        worker_ids: Worker identifiers forwarded to the weight updater.
        model_id: Must be ``None``; not supported on this path.
        weights_dict: Must be ``None``; not supported on this path.
        **kwargs: Forwarded to the weight updater.

    Raises:
        ValueError: If ``weights_dict`` or ``model_id`` is provided.
    """
    unsupported = (("weights_dict", weights_dict), ("model_id", model_id))
    for option_name, option_value in unsupported:
        if option_value is not None:
            raise ValueError(
                f"{option_name} is not supported with legacy weight updater"
            )
    # Fall back to old weight updater system
    updater = self.weight_updater
    updater(policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs)
|
|
674
|
+
|
|
675
|
+
def _weight_update_impl(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    model_id: str | None = None,
    weights_dict: dict[str, Any] | None = None,
    **kwargs,
) -> None:
    """Update weights through the registered weight sync schemes.

    When ``self._weight_sync_schemes`` is truthy, every ``(model_id, weights)``
    pair is passed through ``_extract_weights_if_needed`` and sent with
    ``_send_weights_scheme``. All targets are validated before anything is
    sent, so a bad entry in a multi-model update does not leave earlier models
    already updated. When no scheme is registered, the call falls back to
    ``_maybe_fallback_update``.

    Args:
        policy_or_weights: Weight source for a single model.

    Keyword Args:
        worker_ids: Worker identifiers forwarded to the scheme.
        model_id: Target model for ``policy_or_weights``. On the scheme path
            ``None`` means ``"policy"``.
        weights_dict: Mapping of model id to weight source, for several models
            at once. Mutually exclusive with ``model_id`` and
            ``policy_or_weights``.
        **kwargs: Only the deprecated ``policy_weights`` key is read (as an
            alias of ``policy_or_weights``); other keys are ignored here.

    Raises:
        ValueError: If ``weights_dict`` is combined with ``model_id`` or
            ``policy_or_weights``.
        KeyError: If a target model id has no registered scheme.
        TypeError: If a registered scheme is not a ``WeightSyncScheme``.
        RuntimeError: If no scheme is registered but a weight updater is set;
            the caller is expected to route that case to the legacy path.
    """
    if "policy_weights" in kwargs:
        warnings.warn(
            "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
            DeprecationWarning,
        )
        policy_or_weights = kwargs.pop("policy_weights")

    if weights_dict is not None and model_id is not None:
        raise ValueError("Cannot specify both 'weights_dict' and 'model_id'")

    if weights_dict is not None and policy_or_weights is not None:
        raise ValueError(
            "Cannot specify both 'weights_dict' and 'policy_or_weights'"
        )

    schemes = self._weight_sync_schemes
    if not schemes:
        if self._weight_updater is not None:
            # Should be unreachable: update_policy_weights_ sends collectors
            # with a weight updater to the legacy implementation.
            raise RuntimeError(
                "A weight updater is configured but no weight sync scheme is "
                "registered; this update should have used the legacy weight "
                "updater path."
            )
        # No weight updater configured, try fallback
        self._maybe_fallback_update(policy_or_weights, model_id=model_id)
        return

    if weights_dict is None:
        # Single-model update: key the weights by the requested model id.
        target = "policy" if model_id is None else model_id
        weights_dict = {target: policy_or_weights}

    # Validate every target first so that nothing is sent if one is invalid.
    for target_model_id in weights_dict:
        if target_model_id not in schemes:
            raise KeyError(
                f"Model '{target_model_id}' not found in registered weight sync schemes. "
                f"Available models: {list(schemes.keys())}"
            )
        scheme = schemes[target_model_id]
        if not isinstance(scheme, WeightSyncScheme):
            raise TypeError(
                f"Expected WeightSyncScheme for model '{target_model_id}', "
                f"got {type(scheme).__name__}"
            )

    for target_model_id, weights in weights_dict.items():
        processed_weights = self._extract_weights_if_needed(weights, target_model_id)
        self._send_weights_scheme(
            scheme=schemes[target_model_id],
            processed_weights=processed_weights,
            worker_ids=worker_ids,
            model_id=target_model_id,
        )
|
|
732
|
+
|
|
733
|
+
def _maybe_fallback_update(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
    *,
    model_id: str | None = None,
) -> None:
    """Weight update hook used when no scheme is configured.

    The default implementation does nothing. Subclasses may override it to
    provide custom fallback behavior.

    Args:
        policy_or_weights: Weight source offered to the fallback (unused here).

    Keyword Args:
        model_id: Target model identifier (unused here).
    """
    return None
|
|
744
|
+
|
|
745
|
+
def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids):
    """Send ``processed_weights`` through ``scheme`` to ``worker_ids``.

    ``model_id`` is not used by this default implementation. Override this
    method if the scheme requires an RPC call to receive the weights.
    """
    send_kwargs = {"weights": processed_weights, "worker_ids": worker_ids}
    scheme.send(**send_kwargs)
|
|
748
|
+
|
|
749
|
+
def _receive_weights_scheme(self):
    """Call ``receive()`` on every registered receiver scheme.

    Per the schemes' contract (not visible here), ``scheme.receive()`` handles
    both applying weights locally and cascading to sub-collectors via
    ``context.update_policy_weights_()``.

    Raises:
        RuntimeError: If the collector has no ``_receiver_schemes`` attribute.
    """
    absent = object()
    registry = getattr(self, "_receiver_schemes", absent)
    if registry is absent:
        raise RuntimeError("No receiver schemes registered.")

    for receiver in registry.values():
        receiver.receive()
|
|
760
|
+
|
|
761
|
+
# Overloads for receive_weights to support multiple calling conventions
@overload
def receive_weights(self) -> None:
    ...


@overload
def receive_weights(
    self,
    policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
    /,
) -> None:
    ...


@overload
def receive_weights(
    self,
    *,
    weights: TensorDictBase | dict,
) -> None:
    ...


@overload
def receive_weights(
    self,
    *,
    policy: TensorDictModuleBase | nn.Module,
) -> None:
    ...


def receive_weights(
    self,
    policy_or_weights: TensorDictBase
    | TensorDictModuleBase
    | nn.Module
    | dict
    | None = None,
    *,
    weights: TensorDictBase | dict | None = None,
    policy: TensorDictModuleBase | nn.Module | None = None,
) -> None:
    """Receive and apply weights to the collector's policy.

    When the collector has receiver schemes (``_receiver_schemes`` is not
    ``None``), the work is delegated to them and no weights may be passed.
    Otherwise, explicitly provided weights are applied to ``self.policy`` if
    that is an ``nn.Module``; with no weights provided, nothing is done.

    Examples:
        >>> # Receive from registered schemes (distributed collectors)
        >>> collector.receive_weights()
        >>>
        >>> # Apply weights from a policy module (positional)
        >>> collector.receive_weights(trained_policy)
        >>>
        >>> # Apply weights from a TensorDict (positional)
        >>> collector.receive_weights(weights_tensordict)
        >>>
        >>> # Use keyword arguments for clarity
        >>> collector.receive_weights(weights=weights_td)
        >>> collector.receive_weights(policy=trained_policy)

    Args:
        policy_or_weights: The weights to apply. Can be:

            - ``nn.Module``: A policy module whose weights will be extracted and applied
            - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
            - ``TensorDictBase``: A TensorDict containing weights
            - ``dict``: A regular dict containing weights
            - ``None``: Receive from registered schemes if any, else do nothing

    Keyword Args:
        weights: Alternative to positional argument. A TensorDict or dict containing
            weights to apply. Cannot be used together with ``policy_or_weights`` or ``policy``.
        policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase``
            whose weights will be extracted. Cannot be used together with ``policy_or_weights``
            or ``weights``.

    Raises:
        ValueError: If conflicting parameters are provided or if arguments are passed
            when receiver schemes are registered.

    """
    # Fold the three ways of passing the weight source into one value.
    if weights is not None:
        if policy_or_weights is not None:
            raise ValueError(
                "Cannot specify both positional 'policy_or_weights' and keyword 'weights'"
            )
        if policy is not None:
            raise ValueError("Cannot specify both 'weights' and 'policy'")
        source = weights
    elif policy is not None:
        if policy_or_weights is not None:
            raise ValueError(
                "Cannot specify both positional 'policy_or_weights' and keyword 'policy'"
            )
        source = policy
    else:
        source = policy_or_weights

    if getattr(self, "_receiver_schemes", None) is not None:
        if source is not None:
            raise ValueError(
                "Cannot specify 'policy_or_weights' when using 'receiver_schemes'. Schemes should know how to get the weights."
            )
        self._receive_weights_scheme()
        return

    # No weight updater configured
    # For single-process collectors, apply weights locally if explicitly provided
    if source is None:
        return

    from torchrl.weight_update.weight_sync_schemes import WeightStrategy

    # Use WeightStrategy to apply weights properly
    strategy = WeightStrategy(extract_as="tensordict")

    # A module carries its weights inside; anything else is used as given.
    extracted = (
        strategy.extract_weights(source) if isinstance(source, nn.Module) else source
    )

    # Apply to local policy
    if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
        strategy.apply_weights(self.policy, extracted)
    # Otherwise, no action needed - policy is local and changes are immediately visible
|
|
887
|
+
|
|
888
|
+
def register_scheme_receiver(
    self,
    weight_recv_schemes: dict[str, WeightSyncScheme],
    *,
    synchronize_weights: bool = True,
):  # noqa: D417
    """Set up receiver schemes for this collector to receive weights from parent collectors.

    Each scheme that is not yet initialized on the receiver side is initialized
    with this collector as context, then stored in ``_receiver_schemes`` for
    later use by ``_receive_weights_scheme()`` and ``receive_weights()``.

    Receiver schemes enable cascading weight updates across collector hierarchies:
    - Parent collector sends weights via its weight_sync_schemes (senders)
    - Child collector receives weights via its weight_recv_schemes (receivers)
    - If child is also a parent (intermediate node), it can propagate to its own children

    Args:
        weight_recv_schemes (dict[str, WeightSyncScheme]): Dictionary of {model_id: WeightSyncScheme} to set up as receivers.
            These schemes will receive weights from parent collectors.

    Keyword Args:
        synchronize_weights (bool, optional): If True, call ``connect`` on every passed
            scheme that is not yet synchronized on the receiver side, right after registration.
            Defaults to `True`.

    Raises:
        RuntimeError: If a scheme that still needs receiver-side initialization
            is already initialized on the sender side.
    """
    # Create the registry on first use.
    if not hasattr(self, "_receiver_schemes"):
        self._receiver_schemes = {}

    # Initialize each scheme on the receiver side, then record it.
    for model_id, scheme in weight_recv_schemes.items():
        needs_init = not scheme.initialized_on_receiver
        if needs_init:
            if scheme.initialized_on_sender:
                raise RuntimeError(
                    "Weight sync scheme cannot be initialized on both sender and receiver."
                )
            init_kwargs = {
                "model_id": model_id,
                "context": self,
                "worker_idx": self.worker_idx,
            }
            scheme.init_on_receiver(**init_kwargs)

        # Store the scheme for later use in receive_weights()
        self._receiver_schemes[model_id] = scheme

    if not synchronize_weights:
        return

    # Perform initial synchronization
    pending = (
        scheme
        for scheme in weight_recv_schemes.values()
        if not scheme.synchronized_on_receiver
    )
    for scheme in pending:
        scheme.connect(worker_idx=self.worker_idx)
|
|
937
|
+
|
|
938
|
+
def __iter__(self) -> Iterator[TensorDictBase]:
    """Iterate over the batches produced by :meth:`iterator`.

    Sets ``_iteration_started`` (used by enable_profile check) and, if an
    exception escapes the underlying iterator, shuts the collector down
    before re-raising.
    """
    # Mark that iteration has started
    setattr(self, "_iteration_started", True)
    try:
        batches = self.iterator()
        yield from batches
    except Exception:
        self.shutdown()
        raise
|
|
946
|
+
|
|
947
|
+
def next(self):
    """Return the next batch from the collector, or ``None`` when exhausted.

    The underlying iterator is created from ``iter(self)`` on first use and
    kept in ``self._iterator``. A ``StopIteration`` is turned into a ``None``
    return value.
    """
    try:
        iterator = self._iterator
        if iterator is None:
            iterator = self._iterator = iter(self)
        batch = next(iterator)
        # if any, we don't want the device ref to be passed in distributed settings
        # NOTE(review): the device is compared against the string "cpu";
        # confirm this matches how `batch.device` is represented.
        strip_device = batch is not None and (batch.device != "cpu")
        return batch.copy().clear_device_() if strip_device else batch
    except StopIteration:
        return None
|
|
958
|
+
|
|
959
|
+
@abc.abstractmethod
def shutdown(
    self,
    timeout: float | None = None,
    close_env: bool = True,
    raise_on_error: bool = True,
) -> None:
    """Shut the collector down (abstract).

    Concrete collectors must override this method; the base implementation
    only raises.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
|
|
967
|
+
|
|
968
|
+
@abc.abstractmethod
def iterator(self) -> Iterator[TensorDictBase]:
    """Produce the collector's batches (abstract).

    Concrete collectors must override this method; the base implementation
    only raises.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
|
|
971
|
+
|
|
972
|
+
@abc.abstractmethod
def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Seed the collector (abstract).

    Concrete collectors must override this method; the base implementation
    only raises.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
|
|
975
|
+
|
|
976
|
+
@abc.abstractmethod
def state_dict(self) -> OrderedDict:
    """Return the collector's state (abstract).

    Concrete collectors must override this method; the base implementation
    only raises.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
|
|
979
|
+
|
|
980
|
+
@abc.abstractmethod
def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Load a previously saved collector state (abstract).

    Concrete collectors must override this method; the base implementation
    only raises.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
|
|
983
|
+
|
|
984
|
+
def _read_compile_kwargs(self, compile_policy, cudagraph_policy):
    """Record the compile / cudagraph options on the collector.

    For each option, a flag is stored telling whether it is enabled (anything
    other than ``False`` or ``None``), along with a kwargs mapping: the option
    itself when it is a mapping, an empty dict otherwise.

    Args:
        compile_policy: Compile option; sets ``compiled_policy`` and
            ``compiled_policy_kwargs``.
        cudagraph_policy: Cudagraph option; sets ``cudagraphed_policy`` and
            ``cudagraphed_policy_kwargs``.
    """
    disabled_values = (False, None)
    self.compiled_policy = compile_policy not in disabled_values
    self.cudagraphed_policy = cudagraph_policy not in disabled_values

    if isinstance(compile_policy, typing.Mapping):
        self.compiled_policy_kwargs = compile_policy
    else:
        self.compiled_policy_kwargs = {}

    if isinstance(cudagraph_policy, typing.Mapping):
        self.cudagraphed_policy_kwargs = cudagraph_policy
    else:
        self.cudagraphed_policy_kwargs = {}
|
|
993
|
+
|
|
994
|
+
def __repr__(self) -> str:
    """Return the class name followed by empty parentheses."""
    class_name = self.__class__.__name__
    return class_name + "()"
|
|
997
|
+
|
|
998
|
+
def __class_getitem__(self, index):
    """Reject subscripting of the class: always raises ``NotImplementedError``."""
    raise NotImplementedError()
|
|
1000
|
+
|
|
1001
|
+
def __len__(self) -> int:
    """Number of batches the collector yields.

    Returns:
        ``total_frames`` divided by ``requested_frames_per_batch``, rounded up.

    Raises:
        RuntimeError: If ``total_frames`` is not positive, i.e. the collector
            never terminates.
    """
    total = self.total_frames
    if total > 0:
        per_batch = self.requested_frames_per_batch
        # Ceiling division using floor division on the negated divisor.
        return -(total // -per_batch)
    raise RuntimeError("Non-terminating collectors do not have a length")
|
|
1005
|
+
|
|
1006
|
+
def init_updater(self, *args, **kwargs):
    """Initialize the weight updater with custom arguments.

    All arguments are forwarded to the weight updater's ``init`` method.
    Nothing happens when no weight updater is set.

    Args:
        *args: Positional arguments for weight updater initialization
        **kwargs: Keyword arguments for weight updater initialization
    """
    updater = self.weight_updater
    if updater is None:
        return
    updater.init(*args, **kwargs)
|
|
1018
|
+
|
|
1019
|
+
|
|
1020
|
+
def _make_legacy_metaclass(parent_metaclass):
    """Create a legacy metaclass for deprecated collector names.

    The returned metaclass inherits from ``parent_metaclass`` so that it can
    be combined with classes already using that metaclass without a metaclass
    conflict.

    Args:
        parent_metaclass: The metaclass to derive the legacy metaclass from.

    Returns:
        A metaclass that warns on instantiation of the deprecated class and
        makes ``isinstance`` checks against it accept instances of its first
        base class.
    """

    class _LegacyMeta(parent_metaclass):
        """Metaclass for deprecated collector class names.

        Raises a deprecation warning when the old class name is instantiated,
        and ensures isinstance() checks work for both old and new names.
        """

        def __call__(cls, *args, **kwargs):
            # stacklevel=2 attributes the warning to the code instantiating
            # the deprecated class rather than to this metaclass, so that the
            # user sees where the old name is used (and so that Python's
            # default filters, which only show DeprecationWarning triggered
            # from ``__main__``, can display it).
            warnings.warn(
                f"{cls.__name__} has been deprecated and will be removed in v0.13. "
                f"Please use {cls.__bases__[0].__name__} instead.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return super().__call__(*args, **kwargs)

        def __instancecheck__(cls, instance):
            if super().__instancecheck__(instance):
                return True
            # Instances of the replacement class (the first base) also count
            # as instances of the deprecated name.
            parent_cls = cls.__bases__[0]
            return isinstance(instance, parent_cls)

    return _LegacyMeta
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
# Legacy metaclass for deprecated aliases of classes whose metaclass is abc.ABCMeta
_LegacyCollectorMeta = _make_legacy_metaclass(abc.ABCMeta)


class DataCollectorBase(BaseCollector, metaclass=_LegacyCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.BaseCollector`."""
|