torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
+++ torchrl/collectors/distributed/utils.py
@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import subprocess
+import time
+
+from torchrl._utils import logger as torchrl_logger, VERBOSE
+from torchrl.collectors.distributed.default_configs import (
+    DEFAULT_SLURM_CONF,
+    DEFAULT_SLURM_CONF_MAIN,
+    TCP_PORT,
+)
+from torchrl.collectors.distributed.generic import _distributed_init_delayed
+from torchrl.collectors.distributed.rpc import _rpc_init_collection_node
+
+try:
+    import submitit
+
+    _has_submitit = True
+except ModuleNotFoundError as err:
+    _has_submitit = False
+    SUBMITIT_ERR = err
+
+
+class submitit_delayed_launcher:
+    """Delayed launcher for submitit.
+
+    In some cases, launched jobs cannot spawn other jobs on their own and this
+    can only be done at the jump-host level.
+
+    In these cases, the :func:`submitit_delayed_launcher` can be used to
+    pre-launch collector nodes that will wait for the main worker to provide
+    the launching instruction.
+
+    Args:
+        num_jobs (int): the number of collection jobs to be launched.
+        framework (str, optional): the framework to use. Can be either ``"distributed"``
+            or ``"rpc"``. ``"distributed"`` requires a :class:`~.DistributedDataCollector`
+            collector whereas ``"rpc"`` requires a :class:`RPCDataCollector`.
+            Defaults to ``"distributed"``.
+        backend (str, optional): torch.distributed backend in case ``framework``
+            points to ``"distributed"``. This value must match the one passed to
+            the collector, otherwise main and satellite nodes will fail to
+            reach the rendezvous and hang forever (i.e. no exception will be raised!).
+            Defaults to ``'gloo'``.
+        tcpport (int or str, optional): the TCP port to use.
+            Defaults to :obj:`torchrl.collectors.distributed.default_configs.TCP_PORT`.
+        submitit_main_conf (dict, optional): the main node configuration to be passed to submitit.
+            Defaults to :obj:`torchrl.collectors.distributed.default_configs.DEFAULT_SLURM_CONF_MAIN`.
+        submitit_collection_conf (dict, optional): the configuration to be passed to submitit.
+            Defaults to :obj:`torchrl.collectors.distributed.default_configs.DEFAULT_SLURM_CONF`.
+
+    Examples:
+        >>> num_jobs=2
+        >>> @submitit_delayed_launcher(num_jobs=num_jobs)
+        ... def main():
+        ...     from torchrl.modules.utils.utils import RandomPolicy
+        ...     from torchrl.envs.libs.gym import GymEnv
+        ...     from torchrl.envs import EnvCreator
+        ...     from torchrl.collectors.distributed import DistributedDataCollector
+        ...     from torchrl.data import BoundedContinuous
+        ...     collector = DistributedDataCollector(
+        ...         [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs,
+        ...         policy=RandomPolicy(BoundedContinuous(-1, 1, shape=(1,))),
+        ...         launcher="submitit_delayed",
+        ...     )
+        ...     for data in collector:
+        ...         print(data)
+        ...
+        >>> if __name__ == "__main__":
+        ...     main()
+        ...
+    """
+
+    _VERBOSE = VERBOSE  # for debugging
+
+    def __init__(
+        self,
+        num_jobs,
+        framework="distributed",
+        backend="gloo",
+        tcpport=TCP_PORT,
+        submitit_main_conf: dict = DEFAULT_SLURM_CONF_MAIN,
+        submitit_collection_conf: dict = DEFAULT_SLURM_CONF,
+    ):
+        self.num_jobs = num_jobs
+        self.backend = backend
+        self.framework = framework
+        self.submitit_collection_conf = submitit_collection_conf
+        self.submitit_main_conf = submitit_main_conf
+        self.tcpport = tcpport
+
+    def __call__(self, main_func):
+        def exec_fun():
+            if not _has_submitit:
+                raise ModuleNotFoundError(
+                    "Failed to import submitit. Check installation of the library."
+                ) from SUBMITIT_ERR
+            # submit main
+            executor = submitit.AutoExecutor(folder="log_test")
+            executor.update_parameters(**self.submitit_main_conf)
+            main_job = executor.submit(main_func)
+            # listen to output file looking for IP address
+            torchrl_logger.debug(f"job id: {main_job.job_id}")
+            time.sleep(2.0)
+            node = None
+            while not node:
+                cmd = f"squeue -j {main_job.job_id} -o %N | tail -1"
+                node = subprocess.check_output(cmd, shell=True, text=True).strip()
+                try:
+                    node = int(node)
+                except ValueError:
+                    time.sleep(0.5)
+                    continue
+            torchrl_logger.debug(f"node: {node}")
+            # by default, sinfo will truncate the node name at char 20, we increase this to 200
+            cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1"
+            rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip()
+            torchrl_logger.debug(f"IP: {rank0_ip}")
+            world_size = self.num_jobs + 1
+
+            # submit jobs
+            executor = submitit.AutoExecutor(folder="log_test")
+            executor.update_parameters(**self.submitit_collection_conf)
+            jobs = []
+            if self.framework == "rpc":
+                from .rpc import DEFAULT_TENSORPIPE_OPTIONS
+
+                tensorpipe_options = DEFAULT_TENSORPIPE_OPTIONS
+            for i in range(self.num_jobs):
+                rank = i + 1
+                if self.framework == "distributed":
+                    job = executor.submit(
+                        _distributed_init_delayed,
+                        rank,
+                        self.backend,
+                        rank0_ip,
+                        self.tcpport,
+                        world_size,
+                        self._VERBOSE,
+                    )
+                elif self.framework == "rpc":
+                    job = executor.submit(
+                        _rpc_init_collection_node,
+                        rank,
+                        rank0_ip,
+                        self.tcpport,
+                        world_size,
+                        None,
+                        tensorpipe_options,
+                    )
+                else:
+                    raise NotImplementedError(f"Unknown framework {self.framework}.")
+                jobs.append(job)
+            for job in jobs:
+                job.result()
+            main_job.result()
+
+        return exec_fun
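The two SLURM configuration dicts shown above are forwarded verbatim to submitit.AutoExecutor.update_parameters(), so they can be adapted per cluster before decorating the main function. Below is a minimal sketch of that, not part of the package; the parameter names and values used (timeout_min, slurm_partition, slurm_gpus_per_node) are illustrative submitit settings and are not claimed to be the exact keys of DEFAULT_SLURM_CONF.

from torchrl.collectors.distributed.default_configs import (
    DEFAULT_SLURM_CONF,
    DEFAULT_SLURM_CONF_MAIN,
)
from torchrl.collectors.distributed.utils import submitit_delayed_launcher

# Collection nodes: copy the default config and override a few fields (hypothetical values).
collection_conf = dict(DEFAULT_SLURM_CONF)
collection_conf.update(timeout_min=120, slurm_partition="cpu")

# Main (training) node: request a GPU (hypothetical values).
main_conf = dict(DEFAULT_SLURM_CONF_MAIN)
main_conf.update(timeout_min=120, slurm_partition="gpu", slurm_gpus_per_node=1)

@submitit_delayed_launcher(
    num_jobs=2,
    framework="distributed",  # pair with a DistributedDataCollector, per the docstring
    backend="gloo",           # must match the backend passed to the collector
    submitit_main_conf=main_conf,
    submitit_collection_conf=collection_conf,
)
def main():
    ...  # build the DistributedDataCollector with launcher="submitit_delayed" here

if __name__ == "__main__":
    main()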
+++ torchrl/collectors/llm/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .base import LLMCollector
+from .ray_collector import RayLLMCollector
+from .weight_update import vLLMUpdater, vLLMUpdaterV2
+
+__all__ = ["vLLMUpdater", "vLLMUpdaterV2", "LLMCollector", "RayLLMCollector"]
+++ torchrl/collectors/llm/base.py
@@ -0,0 +1,494 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections import deque
+from collections.abc import Callable
+from typing import Any
+
+import torch
+
+from tensordict import lazy_stack, TensorDictBase
+
+from torchrl._utils import as_remote, logger as torchrl_logger
+
+from torchrl.collectors._single import Collector
+from torchrl.collectors.llm.utils import _QueueAsRB
+from torchrl.collectors.weight_update import WeightUpdaterBase
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.envs import AsyncEnvPool
+from torchrl.envs.common import EnvBase
+from torchrl.envs.llm.transforms.policy_version import PolicyVersion
+
+
+class LLMCollector(Collector):
+    """A simplified version of Collector for LLM inference.
+
+    Args:
+        env (EnvBase or EnvBase constructor): the environment to be used for data collection.
+
+    Keyword Args:
+        policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is mutually exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+        dialog_turns_per_batch (int, optional): A keyword-only argument representing the total
+            number of elements in a batch. It is always required except when `yield_completed_trajectories=True`.
+        total_dialog_turns (int): A keyword-only argument representing the total
+            number of steps returned by the collector during its lifespan. -1 is never ending (until shutdown).
+            Defaults to -1.
+        yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
+            (`yield_completed_trajectories=False`, default) or single, completed trajectories
+            (`yield_completed_trajectories=True`).
+            Defaults to `False` unless `yield_only_last_steps=True`, where it cannot be `False`.
+
+            .. warning:: If the `done` state of the environment is not properly set, this may lead to a collector
+                that never yields any data.
+
+        yield_only_last_steps (bool, optional): whether to yield every step of a trajectory, or only the
+            last (done) steps.
+            If `True`, a single trajectory is yielded (or written in the buffer) at a time.
+
+            .. warning:: If the `done` state of the environment is not properly set, this may lead to a collector
+                that never yields any data.
+
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+            Defaults to ``None``.
+        async_envs (bool, optional): if ``True``, the environment will be run asynchronously. Defaults to `True` if the
+            environment is a :class:`~torchrl.envs.AsyncEnvPool` instance.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
+        reset_at_each_iter (bool, optional): if ``True``, the environment will be reset at each iteration.
+        flatten_data (bool, optional): if ``True``, the collector will flatten the collected data
+            before returning it. In practice, this means that if an environment of batch-size `(B,)` is used
+            and run for `T` steps, `flatten_data=True` will present data of shape `(B*T,)`, whereas
+            `flatten_data=False` will present data of shape `(B, T)`.
+            Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
+        weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
+            or its subclass, responsible for updating the policy weights on remote inference workers.
+            This is typically not used in :class:`~torchrl.collectors.Collector` as it operates in a single-process environment.
+            Consider using a constructor if the updater needs to be serialized.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+            Defaults to `False`.
+        verbose (bool, optional): if ``True``, the collector will print progress information.
+            Defaults to `False`.
+
+    Examples:
+        >>> import vllm
+        >>> from torchrl.modules import vLLMWrapper
+        >>> from torchrl.testing.mocking_classes import DummyStrDataLoader
+        >>> from torchrl.envs import LLMEnv
+        >>> llm_model = vllm.LLM("gpt2")
+        >>> tokenizer = llm_model.get_tokenizer()
+        >>> tokenizer.pad_token = tokenizer.eos_token
+        >>> policy = vLLMWrapper(llm_model)
+        >>> dataloader = DummyStrDataLoader(1)
+        >>> env = LLMEnv.from_dataloader(
+        ...     dataloader=dataloader,
+        ...     tokenizer=tokenizer,
+        ...     from_text=True,
+        ...     batch_size=1,
+        ...     group_repeats=True,
+        ... )
+        >>> collector = LLMCollector(
+        ...     env=env,
+        ...     policy_factory=lambda: policy,
+        ...     dialog_turns_per_batch=env.batch_size[0],
+        ...     total_dialog_turns=3,
+        ... )
+        >>> for i, data in enumerate(collector):
+        ...     if i == 2:
+        ...         print(data)
+        ...         break
+        LazyStackedTensorDict(
+            fields={
+                attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
+                collector: LazyStackedTensorDict(
+                    fields={
+                        traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    exclusive_fields={
+                    },
+                    batch_size=torch.Size([1, 1]),
+                    device=None,
+                    is_shared=False,
+                    stack_dim=1),
+                done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                text: NonTensorStack(
+                    [['plsgqejeyd']],
+                    batch_size=torch.Size([1, 1]),
+                    device=None),
+                text_response: NonTensorStack(
+                    [['ec.n.n.n.tjbjz3perwhz']],
+                    batch_size=torch.Size([1, 1]),
+                    device=None),
+                tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
+                tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([1, 1]),
+            device=None,
+            is_shared=False,
+            stack_dim=1)
+        >>> del collector
+
+    """
+
+    def __init__(
+        self,
+        env: EnvBase | Callable[[], EnvBase],
+        *,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
+        | None = None,
+        dialog_turns_per_batch: int | None = None,
+        yield_only_last_steps: bool | None = None,
+        yield_completed_trajectories: bool | None = None,
+        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        total_dialog_turns: int = -1,
+        async_envs: bool | None = None,
+        replay_buffer: ReplayBuffer | None = None,
+        reset_at_each_iter: bool = False,
+        flatten_data: bool | None = None,
+        weight_updater: WeightUpdaterBase
+        | Callable[[], WeightUpdaterBase]
+        | None = None,
+        queue: Any | None = None,
+        track_policy_version: bool | PolicyVersion = False,
+        verbose: bool = False,
+    ):
+        if queue is not None and replay_buffer is not None:
+            raise RuntimeError(
+                "Handling both a buffer and a queue is not possible at the moment."
+            )
+        elif queue is not None:
+            # disguise the queue as a replay buffer
+            replay_buffer = _QueueAsRB(queue)
+        if dialog_turns_per_batch is None and yield_completed_trajectories:
+            dialog_turns_per_batch = 1
+        super().__init__(
+            create_env_fn=env,
+            policy=policy,
+            policy_factory=policy_factory,
+            frames_per_batch=dialog_turns_per_batch,
+            replay_buffer=replay_buffer,
+            total_frames=total_dialog_turns,
+            weight_updater=weight_updater,
+            reset_at_each_iter=reset_at_each_iter,
+            trust_policy=True,
+            use_buffers=False,
+            no_cuda_sync=True,
+            extend_buffer=True,
+            postproc=postproc,
+        )
+        if hasattr(self.policy, "register_collector"):
+            self.policy.register_collector(self)
+
+        if yield_only_last_steps is None:
+            yield_only_last_steps = False
+
+        if yield_completed_trajectories is None:
+            yield_completed_trajectories = yield_only_last_steps
+        elif yield_only_last_steps and not yield_completed_trajectories:
+            raise TypeError(
+                "yield_only_last_steps=True requires yield_completed_trajectories=True (or None)"
+            )
+
+        if yield_only_last_steps:
+            if flatten_data is not None:
+                raise TypeError(
+                    "`yield_only_last_steps` cannot be `True` when `flatten_data` is passed."
+                )
+            if self.reset_at_each_iter:
+                raise TypeError(
+                    "`yield_only_last_steps` cannot be `True` when `reset_at_each_iter=True`."
+                )
+        if flatten_data is None:
+            flatten_data = replay_buffer is not None
+        self.flatten_data = flatten_data
+        self.yield_completed_trajectories = yield_completed_trajectories
+        self.yield_only_last_steps = yield_only_last_steps
+        self.verbose = verbose
+        self._shuttle = None  # Initialize shuttle for rollout
+        if self.yield_completed_trajectories:
+            # For async envs, we route by env_id so we only care about batch_size[0].
+            # For non-async envs, we need exactly one batch dimension.
+            if not isinstance(self.env, AsyncEnvPool) and len(self.env.batch_size) != 1:
+                raise ValueError(
+                    "`yield_completed_trajectories` only works with envs that have a single batch dimension. Got "
+                    f"env.batch_size={self.env.batch_size}."
+                )
+            self._yield_queues = [deque() for _ in range(self.env.batch_size[0])]
+            self._trajectory_queue = deque()
+        self.async_envs = bool(async_envs) | isinstance(self.env, AsyncEnvPool)
+        if self.async_envs and not isinstance(self.env, AsyncEnvPool):
+            # This basically means that `async_envs` is automatically set and passing it is useless as of today,
+            # except for the following error.
+            raise RuntimeError(
+                "async_envs requires the environment to be an AsyncEnvPool instance."
+            )
+        self.policy_version_tracker = track_policy_version
+        if isinstance(track_policy_version, bool) and track_policy_version:
+            if isinstance(self.env, AsyncEnvPool):
+                raise RuntimeError(
+                    "AsyncEnvPool is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
+                    "and pass that transform to the collector."
+                )
+            self.policy_version_tracker = PolicyVersion()
+            self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
+        elif isinstance(track_policy_version, PolicyVersion):
+            self.policy_version_tracker = track_policy_version
+            self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
+        else:
+            self.policy_version_tracker = None
+
+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        if self.postproc is not None:
+            raise RuntimeError("Postproc already set")
+        self.postproc = postproc
+
+    def increment_version(self):
+        """Increment the policy version."""
+        if self.policy_version_tracker is not None:
+            if not isinstance(self.policy_version_tracker, PolicyVersion):
+                raise RuntimeError(
+                    "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
+                )
+            self.policy_version_tracker.increment_version()
+
+    @property
+    def policy_version(self) -> str | int | None:
+        """The current policy version."""
+        if not isinstance(self.policy_version_tracker, PolicyVersion):
+            return None
+        return self.policy_version_tracker.version
+
+    def get_policy_version(self) -> str | int | None:
+        """Get the current policy version.
+
+        This method exists to support remote calls in Ray actors, since properties
+        cannot be accessed directly through Ray's RPC mechanism.
+
+        Returns:
+            The current version number (int) or UUID (str), or None if version tracking is disabled.
+        """
+        return self.policy_version
+
+    @property
+    def total_dialog_turns(self):
+        return self.total_frames
+
+    @property
+    def dialog_turns_per_batch(self) -> int:
+        """Alias to `frames_per_batch`."""
+        return self.requested_frames_per_batch
+
+    @property
+    def rollout(self) -> Callable[[], TensorDictBase]:
+        if self.yield_completed_trajectories:
+            if self.async_envs:
+                return self._rollout_yield_trajs_async
+            else:
+                return self._rollout_yield_trajs
+        else:
+            return self._rollout_all
+
+    def _rollout_all(self) -> TensorDictBase:  # A simplified version of rollout
+        if self.reset_at_each_iter or self._shuttle is None:
+            self._shuttle = self.env.reset()
+
+        trajectory = []
+        collected_steps = 0
+        policy_input = self._shuttle
+        while collected_steps < self.dialog_turns_per_batch:
+            if self.verbose:
+                torchrl_logger.debug(
+                    f"LLMCollector: Collected {collected_steps} steps over {self.dialog_turns_per_batch} requested."
+                )
+            env_input = self.policy(policy_input)
+            env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
+
+            # carry over collector data without messing up devices
+            collector_data = env_output.get("collector", default=None)
+            if collector_data is not None:
+                env_next_output.set("collector", collector_data.copy())
+            self._update_traj_ids(env_output)
+            trajectory.append(env_output.clone())
+            collected_steps += env_output.numel()
+            policy_input = self._shuttle = env_next_output
+        trajectory = lazy_stack(trajectory, -1)
+        if self.flatten_data:
+            return trajectory.view(-1)
+        return trajectory
+
+    _result_numel = 0
+
+    def _rollout_yield_trajs(self) -> TensorDictBase:  # A simplified version of rollout
+        if self._shuttle is None:
+            self._shuttle = self.env.reset()
+        next_output = self._shuttle
+
+        collected_steps = 0
+        dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
+        while True:
+            if self._result_numel >= self.dialog_turns_per_batch:
+                break
+            elif self.verbose:
+                torchrl_logger.debug(
+                    f"LLMCollector: Collected {collected_steps} steps with {self._result_numel} elements in the resulting batch, over {self.dialog_turns_per_batch} requested."
+                )
+            env_input = self.policy(next_output)
+            cur_output, next_output = self.env.step_and_maybe_reset(env_input)
+            # for i in range(cur_output.numel()):
+            #     print(len(cur_output[i]["text"]) < len(cur_output[i]["next", "text"]))
+
+            # carry over collector data without messing up devices
+            collector_data = cur_output.get("collector", default=None)
+            if collector_data is not None:
+                self._update_traj_ids(cur_output)
+                next_output.set("collector", collector_data.copy())
+
+            # if the loop is interrupted
+            self._shuttle = next_output
+            collected_steps += next_output.numel()
+            for i, (_data, queue) in enumerate(
+                zip(cur_output.unbind(0), self._yield_queues)
+            ):
+                queue.append(_data)
+                dones[i] = _data["next", "done"].any()
+            if dones.any():
+                for idx in dones.nonzero(as_tuple=True)[0].tolist():
+                    if not self.yield_only_last_steps:
+                        _result = lazy_stack(self._yield_queues[idx], -1)
+                        self._trajectory_queue.append(_result)
+                    else:
+                        # FIXME: We need to increment the step count here because iterator() won't
+                        #  see the extra steps
+                        # We use lazy-stack because unsqueeze doesn't nest the strings in lists
+                        _result = lazy_stack([self._yield_queues[idx][-1]])
+                        self._trajectory_queue.append(_result)
+                    self._result_numel += _result.numel()
+                    self._yield_queues[idx].clear()
+        result = [self._trajectory_queue.popleft()]
+        elt = result[0].numel()
+        self._result_numel -= result[0].numel()
+        while elt < self.dialog_turns_per_batch:
+            result.append(self._trajectory_queue.popleft())
+            elt += result[-1].numel()
+            self._result_numel -= result[-1].numel()
+        result = torch.cat(result, -1)
+        if self.verbose:
+            torchrl_logger.debug(
+                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+            )
+        return result
+
+    started = False
+
+    def _rollout_yield_trajs_async(
+        self,
+    ) -> TensorDictBase:  # A simplified version of rollout
+        if not self.started:
+            if self._shuttle is None:
+                self._shuttle = self.env.reset()
+            next_output = self._shuttle
+            env_input = self.policy(next_output)
+            self.env.async_step_and_maybe_reset_send(env_input)
+            self.started = True
+
+        collected_steps = 0
+        # Use only the first dimension (num_envs) for done tracking, since we route by env_id
+        dones = torch.zeros(self.env.batch_size[0], dtype=torch.bool)
+        while True:
+            if self._trajectory_queue:
+                break
+
+            cur_output, next_output = self.env.async_step_and_maybe_reset_recv()
+
+            # Get the env ids - flatten to handle multi-dimensional batch sizes
+            # (e.g., AsyncEnvPool with batch_size=[4, 1] gives [[0], [1], [2], [3]])
+            env_ids_raw = cur_output.get(self.env._env_idx_key).tolist()
+            # Flatten nested lists to get scalar env indices
+            env_ids = []
+            for eid in env_ids_raw:
+                while isinstance(eid, list) and len(eid) == 1:
+                    eid = eid[0]
+                env_ids.append(eid)
+
+            # carry over collector data without messing up devices
+            collector_data = cur_output.get("collector", default=None)
+            if collector_data is not None:
+                self._update_traj_ids(cur_output)
+                next_output.set("collector", collector_data.copy())
+
+            collected_steps += next_output.numel()
+            dones.fill_(False)
+            for i, _data in zip(env_ids, cur_output.unbind(0)):
+                queue = self._yield_queues[i]
+                queue.append(_data)
+                dones[i] = _data["next", "done"].any()
+            if dones.any():
+                for idx in dones.nonzero(as_tuple=True)[0].tolist():
+                    if not self.yield_only_last_steps:
+                        self._trajectory_queue.append(
+                            lazy_stack(self._yield_queues[idx], -1)
+                        )
+                    else:
+                        # FIXME: We need to increment the step count here because iterator() won't
+                        #  see the extra steps
+                        # We use lazy-stack because unsqueeze doesn't nest the strings in lists
+                        self._trajectory_queue.append(
+                            lazy_stack([self._yield_queues[idx][-1]])
+                        )
+                    self._yield_queues[idx].clear()
+
+            # Launch the next batch:
+            # FIXME: Add a condition RE number of frames here
+            if True:
+                env_input = self.policy(next_output)
+                self.env.async_step_and_maybe_reset_send(env_input)
+
+        result = self._trajectory_queue.popleft()
+        # Flatten the result - AsyncEnvPool child envs with batch_size=(1,) produce
+        # trajectories with shape [1, T] but we want [T] for consistency
+        result = result.view(-1)
+        if self.verbose:
+            torchrl_logger.debug(
+                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+            )
+        return result
+
+    as_remote = as_remote
+
+    def get_policy_model(self):
+        """Get the policy model.
+
+        This method is used by RayLLMCollector to get the remote LLM instance
+        for weight updates.
+
+        Returns:
+            The policy model instance
+        """
+        return self.policy.model
+
+    def is_initialized(self) -> bool:
+        """Check if the collector is initialized and ready.
+
+        Returns:
+            bool: True if the collector is initialized and ready to collect data.
+        """
+        # The collector is initialized if it has a valid environment and policy
+        return hasattr(self, "_env") and hasattr(self, "_policy")
+
+    def set_weight_updater(self, weight_updater: WeightUpdaterBase):
+        self.weight_updater = weight_updater
+        return True
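As a complement to the docstring example above, the sketch below exercises the buffer-backed path the docstring describes: with `replay_buffer` set, the collector populates the buffer instead of yielding tensordicts, and `yield_completed_trajectories=True` makes it write one completed trajectory at a time; `track_policy_version=True` uses the PolicyVersion tracking shown in the source. This is not code from the package; the model name, buffer choice and turn counts are illustrative assumptions.

import vllm
from torchrl.collectors.llm import LLMCollector
from torchrl.data import ReplayBuffer
from torchrl.envs import LLMEnv
from torchrl.modules import vLLMWrapper
from torchrl.testing.mocking_classes import DummyStrDataLoader

# Same building blocks as the docstring example (gpt2 is an arbitrary small model).
llm_model = vllm.LLM("gpt2")
tokenizer = llm_model.get_tokenizer()
tokenizer.pad_token = tokenizer.eos_token
policy = vLLMWrapper(llm_model)

env = LLMEnv.from_dataloader(
    dataloader=DummyStrDataLoader(1),
    tokenizer=tokenizer,
    from_text=True,
    batch_size=1,
    group_repeats=True,
)

rb = ReplayBuffer()  # default storage; swap in any torchrl storage as needed
collector = LLMCollector(
    env=env,
    policy_factory=lambda: policy,
    replay_buffer=rb,                   # completed trajectories are written here
    yield_completed_trajectories=True,  # dialog_turns_per_batch defaults to 1 in this mode
    total_dialog_turns=8,               # stop after 8 dialog turns (illustrative)
    track_policy_version=True,          # appends a PolicyVersion transform to the env
)

# Iterating drives collection; per the docstring, data is written to the buffer
# rather than yielded, so the loop body only inspects progress.
for _ in collector:
    print("buffer size:", len(rb), "policy version:", collector.policy_version)
    collector.increment_version()  # e.g., call after pushing fresh weights to the workers

collector.shutdown()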