torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/distributions/truncated_normal.py
@@ -0,0 +1,187 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# from https://github.com/toshas/torch_truncnorm
+from __future__ import annotations
+
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import constraints, Distribution
+from torch.distributions.utils import broadcast_all
+
+CONST_SQRT_2 = math.sqrt(2)
+CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
+CONST_INV_SQRT_2 = 1 / math.sqrt(2)
+CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
+CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
+
+
+class TruncatedStandardNormal(Distribution):
+    """Truncated Standard Normal distribution.
+
+    Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    """
+
+    arg_constraints = {
+        "a": constraints.real,
+        "b": constraints.real,
+    }
+    has_rsample = True
+    eps = 1e-6
+
+    def __init__(self, a, b, validate_args=None, device=None):
+        self.a, self.b = broadcast_all(a, b)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        self.a = self.a.to(device, non_blocking=_non_blocking)
+        self.b = self.b.to(device, non_blocking=_non_blocking)
+        if isinstance(a, Number) and isinstance(b, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.a.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+        if self.a.dtype != self.b.dtype:
+            raise ValueError("Truncation bounds types are different")
+        if any(
+            (self.a >= self.b)
+            .view(
+                -1,
+            )
+            .tolist()
+        ):
+            raise ValueError("Incorrect truncation range")
+        eps = self.eps
+        self._dtype_min_gt_0 = eps
+        self._dtype_max_lt_1 = 1 - eps
+        self._little_phi_a = self._little_phi(self.a)
+        self._little_phi_b = self._little_phi(self.b)
+        self._big_phi_a = self._big_phi(self.a)
+        self._big_phi_b = self._big_phi(self.b)
+        self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
+        self._log_Z = self._Z.log()
+        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
+        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
+        self._lpbb_m_lpaa_d_Z = (
+            self._little_phi_b * little_phi_coeff_b
+            - self._little_phi_a * little_phi_coeff_a
+        ) / self._Z
+        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
+        self._variance = (
+            1
+            - self._lpbb_m_lpaa_d_Z
+            - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
+        )
+        self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
+
+    @constraints.dependent_property
+    def support(self):
+        return constraints.interval(self.a, self.b)
+
+    @property
+    def mean(self):
+        return self._mean
+
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
+    @property
+    def variance(self):
+        return self._variance
+
+    def entropy(self):
+        return self._entropy
+
+    @property
+    def auc(self):
+        return self._Z
+
+    @staticmethod
+    def _little_phi(x):
+        return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI
+
+    def _big_phi(self, x):
+        phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
+        return phi.clamp(self.eps, 1 - self.eps)
+
+    @staticmethod
+    def _inv_big_phi(x):
+        return CONST_SQRT_2 * (2 * x - 1).erfinv()
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
+
+    def icdf(self, value):
+        y = self._big_phi_a + value * self._Z
+        y = y.clamp(self.eps, 1 - self.eps)
+        return self._inv_big_phi(y)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
+
+    def rsample(self, sample_shape=None):
+        if sample_shape is None:
+            sample_shape = torch.Size([])
+        shape = self._extended_shape(sample_shape)
+        p = torch.empty(shape, device=self.a.device).uniform_(
+            self._dtype_min_gt_0, self._dtype_max_lt_1
+        )
+        return self.icdf(p)
+
+
+class TruncatedNormal(TruncatedStandardNormal):
+    """Truncated Normal distribution.
+
+    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    """
+
+    has_rsample = True
+
+    def __init__(self, loc, scale, a, b, validate_args=None, device=None):
+        scale = scale.clamp_min(self.eps)
+        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        a = a.to(device, non_blocking=_non_blocking)
+        b = b.to(device, non_blocking=_non_blocking)
+        self._non_std_a = a
+        self._non_std_b = b
+        a = (a - self.loc) / self.scale
+        b = (b - self.loc) / self.scale
+        super().__init__(a, b, validate_args=validate_args)
+        self._log_scale = self.scale.log()
+        self._mean = self._mean * self.scale + self.loc
+        self._variance = self._variance * self.scale**2
+        self._entropy += self._log_scale
+
+    def _to_std_rv(self, value):
+        return (value - self.loc) / self.scale
+
+    def _from_std_rv(self, value):
+        return value * self.scale + self.loc
+
+    def cdf(self, value):
+        return super().cdf(self._to_std_rv(value))
+
+    def icdf(self, value):
+        sample = self._from_std_rv(super().icdf(value))
+
+        # clamp data but keep gradients
+        sample_clip = torch.stack(
+            [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
+        ).max(0)[0]
+        sample_clip = torch.stack(
+            [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
+        ).min(0)[0]
+        sample.data.copy_(sample_clip)
+        return sample
+
+    def log_prob(self, value):
+        value = self._to_std_rv(value)
+        return super().log_prob(value) - self._log_scale
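The `rsample` above is reparameterized: it draws `p ~ Uniform(eps, 1 - eps)` and pushes it through the truncated inverse CDF, `icdf(p) = Phi^-1(Phi(a) + p * Z)` with `Z = Phi(b) - Phi(a)`, so samples land inside `[a, b]` and remain differentiable with respect to `loc` and `scale`. A minimal usage sketch (the import path follows the file list above; the bounds and shapes are illustrative):

import torch

from torchrl.modules.distributions.truncated_normal import TruncatedNormal

loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3)
low = torch.full((3,), -1.0)   # lower truncation bound `a`
high = torch.full((3,), 2.0)   # upper truncation bound `b`

dist = TruncatedNormal(loc, scale, low, high)
x = dist.rsample(torch.Size([10000]))    # inverse-CDF (reparameterized) sampling
assert ((x >= low) & (x <= high)).all()  # samples respect the truncation bounds
x.mean().backward()                      # gradients flow back to `loc` via icdf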
torchrl/modules/distributions/utils.py
@@ -0,0 +1,233 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+from torch import autograd, distributions as d
+from torch.distributions import Independent, Transform, TransformedDistribution
+
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
+
+def _cast_device(elt: torch.Tensor | float, device) -> torch.Tensor | float:
+    if isinstance(elt, torch.Tensor):
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        return elt.to(device, non_blocking=_non_blocking)
+    return elt
+
+
+def _cast_transform_device(transform, device):
+    if transform is None:
+        return transform
+    _non_blocking = device is not None and torch.device(device).type == "cuda"
+    if isinstance(transform, d.ComposeTransform):
+        for i, t in enumerate(transform.parts):
+            transform.parts[i] = _cast_transform_device(t, device)
+    elif isinstance(transform, d.Transform):
+        for attribute in dir(transform):
+            value = getattr(transform, attribute)
+            if isinstance(value, torch.Tensor):
+                setattr(
+                    transform, attribute, value.to(device, non_blocking=_non_blocking)
+                )
+        return transform
+    else:
+        raise TypeError(
+            f"Cannot perform device casting for transform of type {type(transform)}"
+        )
+
+
+class FasterTransformedDistribution(TransformedDistribution):
+    """A faster implementation of TransformedDistribution."""
+
+    __doc__ = __doc__ + TransformedDistribution.__doc__
+
+    def __init__(self, base_distribution, transforms, validate_args=None):
+        if is_dynamo_compiling():
+            return super().__init__(
+                base_distribution, transforms, validate_args=validate_args
+            )
+        if isinstance(transforms, Transform):
+            self.transforms = [transforms]
+        elif isinstance(transforms, list):
+            raise ValueError("Make a ComposeTransform first.")
+        else:
+            raise ValueError(
+                f"transforms must be a Transform or list, but was {transforms}"
+            )
+        transform = self.transforms[0]
+        # Reshape base_distribution according to transforms.
+        base_shape = base_distribution.batch_shape + base_distribution.event_shape
+        base_event_dim = len(base_distribution.event_shape)
+        # transform = ComposeTransform(self.transforms)
+        # if len(base_shape) < transform.domain.event_dim:
+        #     raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
+        #                      .format(transform.domain.event_dim, base_shape))
+        transform_codomain_event_dim = transform.codomain.event_dim
+        transform_domain_event_dim = transform.domain.event_dim
+
+        forward_shape = transform.forward_shape(base_shape)
+        expanded_base_shape = transform.inverse_shape(forward_shape)
+        if base_shape != expanded_base_shape:
+            base_batch_shape = expanded_base_shape[
+                : len(expanded_base_shape) - base_event_dim
+            ]
+            base_distribution = base_distribution.expand(base_batch_shape)
+        reinterpreted_batch_ndims = transform_domain_event_dim - base_event_dim
+        if reinterpreted_batch_ndims > 0:
+            base_distribution = Independent(
+                base_distribution, reinterpreted_batch_ndims
+            )
+        self.base_dist = base_distribution
+
+        # Compute shapes.
+        transform_change_in_event_dim = (
+            transform_codomain_event_dim - transform_domain_event_dim
+        )
+        event_dim = max(
+            transform_codomain_event_dim,  # the transform is coupled
+            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
+        )
+        cut = len(forward_shape) - event_dim
+        batch_shape = forward_shape[:cut]
+        event_shape = forward_shape[cut:]
+        super(TransformedDistribution, self).__init__(
+            batch_shape, event_shape, validate_args=validate_args
+        )
+
+
+def _safetanh(x, eps):  # noqa: D103
+    lim = 1.0 - eps
+    y = x.tanh()
+    return y.clamp(-lim, lim)
+
+
+def _safeatanh(y, eps):  # noqa: D103
+    lim = 1.0 - eps
+    return y.clamp(-lim, lim).atanh()
+
+
+class _SafeTanh(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(input, eps):
+        output = input.tanh()
+        lim = 1.0 - eps
+        output = output.clamp(-lim, lim)
+        # ctx.save_for_backward(output)
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        # input, eps = inputs
+        # ctx.mark_non_differentiable(ind, ind_inv)
+        # # Tensors must be saved via ctx.save_for_backward. Please do not
+        # # assign them directly onto the ctx object.
+        ctx.save_for_backward(output)
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (output,) = ctx.saved_tensors
+        return (grad * (1 - output.pow(2)), None)
+
+
+class _SafeTanhNoEps(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(input):
+        output = input.tanh()
+        eps = torch.finfo(input.dtype).resolution
+        lim = 1.0 - eps
+        output = output.clamp(-lim, lim)
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        ctx.save_for_backward(output)
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (output,) = ctx.saved_tensors
+        return (grad * (1 - output.pow(2)),)
+
+
+class _SafeaTanh(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(tanh_val, eps):
+        if eps is None:
+            eps = torch.finfo(tanh_val.dtype).resolution
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        # ctx.save_for_backward(output)
+        output = output.atanh()
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        tanh_val, eps = inputs
+
+        # ctx.mark_non_differentiable(ind, ind_inv)
+        # # Tensors must be saved via ctx.save_for_backward. Please do not
+        # # assign them directly onto the ctx object.
+        ctx.save_for_backward(tanh_val)
+        ctx.eps = eps
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (tanh_val,) = ctx.saved_tensors
+        eps = ctx.eps
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        return (grad / (1 - output.pow(2)), None)
+
+
+class _SafeaTanhNoEps(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(tanh_val):
+        eps = torch.finfo(tanh_val.dtype).resolution
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        # ctx.save_for_backward(output)
+        output = output.atanh()
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        tanh_val = inputs[0]
+        eps = torch.finfo(tanh_val.dtype).resolution
+
+        # ctx.mark_non_differentiable(ind, ind_inv)
+        # # Tensors must be saved via ctx.save_for_backward. Please do not
+        # # assign them directly onto the ctx object.
+        ctx.save_for_backward(tanh_val)
+        ctx.eps = eps
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (tanh_val,) = ctx.saved_tensors
+        eps = ctx.eps
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        return (grad / (1 - output.pow(2)),)
+
+
+safetanh = _SafeTanh.apply
+safeatanh = _SafeaTanh.apply
+
+safetanh_noeps = _SafeTanhNoEps.apply
+safeatanh_noeps = _SafeaTanhNoEps.apply
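The `_Safe*` autograd Functions above make tanh squashing numerically invertible: the forward pass clamps outputs to `+/-(1 - eps)` so a later `atanh` (e.g. in the log-prob of a tanh-transformed distribution) never sees `+/-1`, while the hand-written backward passes apply the analytic derivatives. A small round-trip sketch (illustrative, not part of the diff):

import torch

from torchrl.modules.distributions.utils import safeatanh, safetanh

x = torch.tensor([0.0, 20.0, -20.0], requires_grad=True)
y = safetanh(x, 1e-6)   # |y| <= 1 - 1e-6 even for saturating inputs
z = safeatanh(y, 1e-6)  # finite inverse, no inf/nan at the clamp boundary
z.sum().backward()      # both custom backward passes fire
assert torch.isfinite(x.grad).all()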
torchrl/modules/llm/__init__.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""LLM utilities for TorchRL.
+
+Note:
+    This package contains optional integrations (e.g. vLLM) that may rely on native
+    extensions. To keep `import torchrl` / `import torchrl.envs` lightweight and
+    robust, we **avoid importing optional backends at module import time** and
+    instead only import those backends on demand.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from .policies.common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens
+from .policies.transformers_wrapper import (
+    RemoteTransformersWrapper,
+    TransformersWrapper,
+)
+from .policies.vllm_wrapper import vLLMWrapper
+
+__all__ = [
+    # Data structures
+    "ChatHistory",
+    "LogProbs",
+    "Masks",
+    "Text",
+    "Tokens",
+    # Wrapper base class
+    "LLMWrapperBase",
+    # Local wrappers
+    "TransformersWrapper",
+    "vLLMWrapper",
+    # Remote wrappers
+    "RemoteTransformersWrapper",
+    # Async vLLM (recommended)
+    "AsyncVLLM",
+    "make_async_vllm_engine",
+    "stateless_init_process_group_async",
+    # Sync vLLM utilities
+    "make_vllm_worker",
+    "stateless_init_process_group",
+]
+
+
+def __getattr__(name: str) -> Any:  # noqa: ANN401
+    # Keep backends optional and on-demand to avoid importing vLLM native extensions
+    # as a side-effect of importing torchrl.
+    if name in {
+        "AsyncVLLM",
+        "make_async_vllm_engine",
+        "make_vllm_worker",
+        "stateless_init_process_group",
+        "stateless_init_process_group_async",
+    }:
+        from . import backends  # local import is intentional / required
+
+        return getattr(backends, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
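A sketch of the resulting import behavior (assuming vLLM is installed; the names come from `__all__` above): the policy wrappers are bound eagerly, while the vLLM backend names resolve through the module-level `__getattr__` hook (PEP 562) on first access.

import torchrl.modules.llm as llm

# Bound eagerly by the explicit imports at the top of the module:
wrapper_cls = llm.TransformersWrapper

# Not bound at import time: this attribute access invokes __getattr__,
# which runs `from . import backends` and forwards the lookup.
engine_cls = llm.AsyncVLLM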
torchrl/modules/llm/backends/__init__.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""LLM backends.
+
+These backends can be optional and may rely on native extensions. We avoid
+importing them at module import time and lazily load on attribute access.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+__all__ = [
+    # Base classes
+    "RLvLLMEngine",
+    # Sync vLLM
+    "make_vllm_worker",
+    "RayLLMWorker",
+    "LocalLLMWrapper",
+    # Async vLLM
+    "_AsyncvLLMWorker",
+    "_AsyncLLMEngine",
+    "AsyncVLLM",
+    "make_async_vllm_engine",
+    # Utilities
+    "stateless_init_process_group",
+    "stateless_init_process_group_async",
+]
+
+_LAZY_ATTRS: dict[str, tuple[str, str]] = {
+    # Base classes and interfaces
+    "RLvLLMEngine": ("torchrl.modules.llm.backends.vllm", "RLvLLMEngine"),
+    # Sync vLLM
+    "make_vllm_worker": ("torchrl.modules.llm.backends.vllm", "make_vllm_worker"),
+    "RayLLMWorker": ("torchrl.modules.llm.backends.vllm", "RayLLMWorker"),
+    "LocalLLMWrapper": ("torchrl.modules.llm.backends.vllm", "LocalLLMWrapper"),
+    # Async vLLM
+    "_AsyncvLLMWorker": ("torchrl.modules.llm.backends.vllm", "_AsyncvLLMWorker"),
+    "_AsyncLLMEngine": ("torchrl.modules.llm.backends.vllm", "_AsyncLLMEngine"),
+    "AsyncVLLM": ("torchrl.modules.llm.backends.vllm", "AsyncVLLM"),
+    "make_async_vllm_engine": (
+        "torchrl.modules.llm.backends.vllm",
+        "make_async_vllm_engine",
+    ),
+    # Utilities
+    "stateless_init_process_group": (
+        "torchrl.modules.llm.backends.vllm",
+        "stateless_init_process_group",
+    ),
+    "stateless_init_process_group_async": (
+        "torchrl.modules.llm.backends.vllm",
+        "stateless_init_process_group_async",
+    ),
+}
+
+
+def __getattr__(name: str) -> Any:  # noqa: ANN401
+    target = _LAZY_ATTRS.get(name)
+    if target is None:
+        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+    module_name, attr_name = target
+    module = __import__(module_name, fromlist=[attr_name])
+    return getattr(module, attr_name)
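A side note on the `__import__(module_name, fromlist=[attr_name])` idiom in `__getattr__` above: passing a non-empty `fromlist` makes `__import__` return the leaf submodule rather than the top-level `torchrl` package. An equivalent spelling via `importlib`, shown as a sketch with a hypothetical helper name:

import importlib
from typing import Any


def _resolve_lazy(module_name: str, attr_name: str) -> Any:
    # importlib.import_module always returns the named (leaf) module,
    # so no fromlist trick is needed.
    return getattr(importlib.import_module(module_name), attr_name)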
torchrl/modules/llm/backends/vllm/__init__.py
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""vLLM backends for TorchRL.
+
+This module provides comprehensive vLLM integration including:
+- Base classes and interfaces
+- Synchronous vLLM workers
+- Asynchronous vLLM services
+- Shared utilities
+
+Examples:
+    >>> # Create an async vLLM service (recommended)
+    >>> from torchrl.modules.llm.backends.vllm import AsyncVLLM
+    >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
+
+    >>> # Create a sync Ray worker
+    >>> from torchrl.modules.llm.backends.vllm import make_vllm_worker
+    >>> worker = make_vllm_worker("Qwen/Qwen2.5-3B", make_ray_worker=True)
+
+    >>> # All engines implement the same interface
+    >>> from torchrl.modules.llm.backends.vllm import RLvLLMEngine
+    >>> updater = vLLMUpdaterV2(any_engine)  # Works with any RLvLLMEngine
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+__all__ = [
+    # Base classes and interfaces
+    "RLvLLMEngine",
+    # Synchronous vLLM
+    "make_vllm_worker",
+    "RayLLMWorker",
+    "LocalLLMWrapper",
+    # Asynchronous vLLM
+    "AsyncVLLM",
+    "make_async_vllm_engine",
+    "_AsyncLLMEngine",
+    "_AsyncvLLMWorker",
+    # Utilities
+    "stateless_init_process_group",
+    "stateless_init_process_group_async",
+]
+
+_LAZY_ATTRS: dict[str, tuple[str, str]] = {
+    # Base
+    "RLvLLMEngine": ("torchrl.modules.llm.backends.vllm.base", "RLvLLMEngine"),
+    # Sync
+    "make_vllm_worker": (
+        "torchrl.modules.llm.backends.vllm.vllm_sync",
+        "make_vllm_worker",
+    ),
+    "RayLLMWorker": ("torchrl.modules.llm.backends.vllm.vllm_sync", "RayLLMWorker"),
+    "LocalLLMWrapper": (
+        "torchrl.modules.llm.backends.vllm.vllm_sync",
+        "LocalLLMWrapper",
+    ),
+    # Async
+    "_AsyncLLMEngine": (
+        "torchrl.modules.llm.backends.vllm.vllm_async",
+        "_AsyncLLMEngine",
+    ),
+    "_AsyncvLLMWorker": (
+        "torchrl.modules.llm.backends.vllm.vllm_async",
+        "_AsyncvLLMWorker",
+    ),
+    "AsyncVLLM": ("torchrl.modules.llm.backends.vllm.vllm_async", "AsyncVLLM"),
+    "make_async_vllm_engine": (
+        "torchrl.modules.llm.backends.vllm.vllm_async",
+        "make_async_vllm_engine",
+    ),
+    # Utils
+    "stateless_init_process_group": (
+        "torchrl.modules.llm.backends.vllm.vllm_utils",
+        "stateless_init_process_group",
+    ),
+    "stateless_init_process_group_async": (
+        "torchrl.modules.llm.backends.vllm.vllm_utils",
+        "stateless_init_process_group_async",
+    ),
+}
+
+
+def __getattr__(name: str) -> Any:  # noqa: ANN401
+    target = _LAZY_ATTRS.get(name)
+    if target is None:
+        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+    module_name, attr_name = target
+    module = __import__(module_name, fromlist=[attr_name])
+    return getattr(module, attr_name)
torchrl/modules/llm/backends/vllm/_models.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Override the last layers of your models here."""
+
+from __future__ import annotations
+
+import os
+
+import torch
+
+try:
+    from vllm.config import VllmConfig
+    from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
+except ImportError:
+
+    class VllmConfig:
+        """Placeholder for VllmConfig class when vLLM is not installed."""
+
+    class Qwen3ForCausalLM:
+        """Placeholder for Qwen3ForCausalLM class when vLLM is not installed."""
+
+
+def is_fp32_output_enabled() -> bool:
+    """Check if FP32 output is enabled."""
+    return os.getenv("VLLM_ENABLE_FP32_OUTPUT", "0") == "1"
+
+
+class Qwen3ForCausalLMFP32(Qwen3ForCausalLM):
+    """Qwen3ForCausalLM with FP32 output."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        if is_fp32_output_enabled():
+            self.lm_head.float()
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        if is_fp32_output_enabled():
+            hidden_states = hidden_states.float()
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
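Both the constructor (which casts `lm_head`) and `compute_logits` consult `is_fp32_output_enabled()` at call time, so the FP32 path is gated purely by the environment variable and must be enabled before vLLM instantiates the model class. A minimal sketch (illustrative):

import os

# Set before the vLLM engine builds Qwen3ForCausalLMFP32 so that lm_head
# is cast to float32 and logits are computed in full precision.
os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"

from torchrl.modules.llm.backends.vllm._models import is_fp32_output_enabled

assert is_fp32_output_enabled()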