torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
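
The listing above spans the package's main entry points (envs, collectors, replay-buffer data structures, objectives, trainers). As a minimal sketch, not taken from this diff, the snippet below wires a few of those entry points together; it assumes this wheel plus torch are installed and that gymnasium is available for GymEnv.

# Minimal sketch (assumption: torch, this torchrl wheel, and gymnasium are
# installed). Touches a few public entry points listed above: torchrl.envs,
# torchrl.collectors, and torchrl.data.
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")
collector = SyncDataCollector(
    env,
    policy=None,  # a random policy is used when none is provided
    frames_per_batch=64,
    total_frames=128,
)
buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=1_000))
for batch in collector:
    buffer.extend(batch.reshape(-1))
print(len(buffer))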

benchmarks/test_compressed_storage_benchmark.py ADDED

@@ -0,0 +1,145 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import io
import pickle

import pytest
import torch


try:
    from safetensors.torch import save
except ImportError:
    save = None


class TestCompressedStorageBenchmark:
    """Benchmark tests for CompressedListStorage."""

    @staticmethod
    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
        """Easily compressible data for testing."""
        if device is None:
            device = torch.device("cpu")

        return {
            "observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "actions": torch.zeros((num_experiences,), device=device),
            "rewards": torch.zeros((num_experiences,), device=device),
            "next_observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "terminations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "truncations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "batch_size": [num_experiences],
        }

    @staticmethod
    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
        """Uncompressible data for testing."""
        if device is None:
            device = torch.device("cpu")
        return {
            "observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "actions": torch.randint(0, 10, (num_experiences,), device=device),
            "rewards": torch.randn(
                (num_experiences,), dtype=torch.float32, device=device
            ),
            "next_observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "terminations": torch.rand((num_experiences,), device=device)
            < 0.2,  # ~20% True
            "truncations": torch.rand((num_experiences,), device=device)
            < 0.1,  # ~10% True
            "batch_size": [num_experiences],
        }

    @pytest.mark.benchmark(
        group="tensor_serialization_speed",
        min_time=0.1,
        max_time=0.5,
        min_rounds=5,
        disable_gc=True,
        warmup=False,
    )
    @pytest.mark.parametrize(
        "serialization_method",
        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
    )
    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
        """Benchmark the speed of different tensor serialization methods.

        TODO: we might need to also test which methods work on the gpu.
        pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'

        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
        Name (time in us)                                   Mean (smaller is better)   OPS (bigger is better)
        --------------------------------------------------------------------------------------------------
        test_tensor_serialization_speed[numpy]                    2.3520 (1.0)          425,162.1779 (1.0)
        test_tensor_serialization_speed[safetensors]              14.7170 (6.26)         67,948.7129 (0.16)
        test_tensor_serialization_speed[pickle]                   19.0711 (8.11)         52,435.3333 (0.12)
        test_tensor_serialization_speed[torch.save]               32.0648 (13.63)        31,186.8261 (0.07)
        test_tensor_serialization_speed[untyped_storage]      38,227.0224 (>1000.0)          26.1595 (0.00)
        --------------------------------------------------------------------------------------------------
        """

        def serialize_with_pickle(data: torch.Tensor) -> bytes:
            """Serialize tensor using pickle."""
            buffer = io.BytesIO()
            pickle.dump(data, buffer)
            return buffer.getvalue()

        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
            """Serialize tensor using torch's built-in method."""
            return bytes(data.untyped_storage())

        def serialize_with_numpy(data: torch.Tensor) -> bytes:
            """Serialize tensor using numpy."""
            return data.numpy().tobytes()

        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
            return save({"0": data})

        def serialize_with_torch(data: torch.Tensor) -> bytes:
            """Serialize tensor using torch's built-in method."""
            buffer = io.BytesIO()
            torch.save(data, buffer)
            return buffer.getvalue()

        # Benchmark each serialization method
        if serialization_method == "pickle":
            serialize_fn = serialize_with_pickle
        elif serialization_method == "torch.save":
            serialize_fn = serialize_with_torch
        elif serialization_method == "untyped_storage":
            serialize_fn = serialize_with_untyped_storage
        elif serialization_method == "numpy":
            serialize_fn = serialize_with_numpy
        elif serialization_method == "safetensors":
            serialize_fn = serialize_with_safetensors
        else:
            raise ValueError(f"Unknown serialization method: {serialization_method}")

        data = self.make_compressible_mock_data(1).get("observations")

        # Run the actual benchmark
        benchmark(serialize_fn, data)
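
The test docstring above quotes the command-line invocation; as a minimal sketch (assuming pytest and pytest-benchmark are installed alongside this wheel), the same selection can also be run programmatically:

# Minimal sketch, assuming pytest and pytest-benchmark are available in the
# environment; mirrors the command quoted in the test docstring above.
import pytest

pytest.main(
    [
        "benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed",
        "-v",
        "--benchmark-only",
        "--benchmark-sort=mean",
        "--benchmark-columns=mean, ops",
    ]
)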

benchmarks/test_envs_benchmark.py ADDED

@@ -0,0 +1,133 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

import pytest
import torch

from tensordict import TensorDict
from torchrl.envs import ParallelEnv, SerialEnv, step_mdp, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv


def make_simple_env():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = DMControlEnv("cheetah", "run", device=device)
    env.rollout(3)
    return ((env,), {})


def make_transformed_env():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = TransformedEnv(DMControlEnv("cheetah", "run", device=device), StepCounter(50))
    env.rollout(3)
    return ((env,), {})


def make_serial_env():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = SerialEnv(3, lambda: DMControlEnv("cheetah", "run", device=device))
    env.rollout(3)
    return ((env,), {})


def make_parallel_env():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = ParallelEnv(3, lambda: DMControlEnv("cheetah", "run", device=device))
    env.rollout(3)
    return ((env,), {})


def make_nested_td():
    return TensorDict(
        {
            ("agent", "action"): 0,
            ("agent", "done"): 0,
            ("agent", "obs"): 0,
            ("agent", "other"): 0,
            ("next", "agent", "action"): 1,
            ("next", "agent", "reward"): 1,
            ("next", "agent", "done"): 1,
            ("next", "agent", "obs"): 1,
        },
        [],
    )


def make_flat_td():
    return TensorDict(
        {
            "action": 0,
            "done": 0,
            "obs": 0,
            "other": 0,
            ("next", "action"): 1,
            ("next", "reward"): 1,
            ("next", "done"): 1,
            ("next", "obs"): 1,
        },
        [],
    )


def execute_env(env):
    env.rollout(1000, break_when_any_done=False)


def test_simple(benchmark):
    (c,), _ = make_simple_env()
    benchmark(execute_env, c)


def test_transformed(benchmark):
    (c,), _ = make_transformed_env()
    benchmark(execute_env, c)


def test_serial(benchmark):
    (c,), _ = make_serial_env()
    benchmark(execute_env, c)


def test_parallel(benchmark):
    (c,), _ = make_parallel_env()
    benchmark(execute_env, c)


@pytest.mark.parametrize("nested", [True, False])
@pytest.mark.parametrize("keep_other", [True, False])
@pytest.mark.parametrize("exclude_reward", [True, False])
@pytest.mark.parametrize("exclude_done", [True, False])
@pytest.mark.parametrize("exclude_action", [True, False])
def test_step_mdp_speed(
    benchmark, nested, keep_other, exclude_reward, exclude_done, exclude_action
):
    if nested:
        td = make_nested_td()
        reward_key = ("agent", "reward")
        done_key = ("agent", "done")
        action_key = ("agent", "action")
    else:
        td = make_flat_td()
        reward_key = "reward"
        done_key = "done"
        action_key = "action"

    benchmark(
        step_mdp,
        td,
        action_keys=action_key,
        reward_keys=reward_key,
        done_keys=done_key,
        keep_other=keep_other,
        exclude_reward=exclude_reward,
        exclude_done=exclude_done,
        exclude_action=exclude_action,
    )


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
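
The `test_step_mdp_speed` benchmark above times `step_mdp` over flat and nested key layouts. A minimal sketch of the operation being timed (assuming torchrl and tensordict are installed; the keys follow the flat layout built by `make_flat_td`):

# Minimal sketch of the operation timed by test_step_mdp_speed above
# (assumption: torchrl and tensordict are installed). step_mdp takes the
# post-step tensordict and builds the root tensordict for the next step,
# promoting the entries stored under "next".
import torch
from tensordict import TensorDict
from torchrl.envs import step_mdp

td = TensorDict(
    {
        "obs": torch.zeros(()),
        "action": torch.zeros(()),
        "done": torch.zeros((), dtype=torch.bool),
        ("next", "obs"): torch.ones(()),
        ("next", "reward"): torch.ones(()),
        ("next", "done"): torch.zeros((), dtype=torch.bool),
    },
    [],
)
next_td = step_mdp(td, exclude_reward=True, exclude_done=False, exclude_action=True)
# next_td["obs"] now holds the value previously stored under ("next", "obs").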

benchmarks/test_llm.py ADDED

@@ -0,0 +1,101 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib.util

import pytest
import torch
from tensordict import set_list_to_stack, TensorDict
from torchrl.data.llm import History
from torchrl.modules.llm.policies.common import ChatHistory
from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper

_has_transformers = importlib.import_module("transformers") is not None

# Skip all these tests if gpu is not available
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="GPU not available"
)


@pytest.fixture(scope="module")
def transformers_wrapper():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.device(device):
        model = TransformersWrapper(
            model="Qwen/Qwen2.5-0.5B",
            tokenizer="Qwen/Qwen2.5-0.5B",
            pad_model_input=False,
            generate=False,
        )
    return model


@pytest.mark.skipif(not _has_transformers, reason="transformers not installed")
class TestWrappers:
    @pytest.mark.parametrize("packing", [True, False])
    @set_list_to_stack(True)
    def test_packing(self, benchmark, transformers_wrapper, packing: bool):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.device(device):
            transformers_wrapper = TransformersWrapper(
                model=transformers_wrapper.model,
                tokenizer=transformers_wrapper.tokenizer,
                pad_model_input=not packing,
                generate=False,
                pad_output=False,
            )
        data = TensorDict(
            {
                "history": ChatHistory(
                    full=History(
                        role=[
                            ["user", "assistant"],
                            ["user", "assistant"],
                            ["user", "assistant"],
                            ["user", "assistant"],
                        ],
                        content=[
                            [
                                "Lorem ipsum dolor sit amet",
                                "consectetur adipiscing elit",
                            ],
                            [
                                "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
                                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
                            ],
                            [
                                "Lorem ipsum dolor sit amet",
                                "consectetur adipiscing elit",
                            ],
                            [
                                "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
                                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
                            ],
                        ],
                        batch_size=(4, 2),
                        device=device,
                    ),
                    batch_size=(4,),
                    device=device,
                )
            },
            batch_size=(4,),
            device=device,
        ).to_lazystack()

        def setup():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        benchmark.pedantic(
            transformers_wrapper,
            (data,),
            rounds=10,
            warmup_rounds=3,
            setup=setup,
        )

benchmarks/test_non_tensor_env_benchmark.py ADDED

@@ -0,0 +1,70 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import gc
import time

import pytest
from tensordict import set_capture_non_tensor_stack
from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.testing.mocking_classes import EnvWithMetadata


def _rollout(env, n_steps: int, break_when_any_done: bool) -> None:
    env.rollout(n_steps, break_when_any_done=break_when_any_done)


@pytest.mark.parametrize("break_when_any_done", [True, False])
@pytest.mark.parametrize(
    "kind,use_buffers",
    [
        pytest.param("single", None, id="single"),
        pytest.param("serial", False, id="serial-no-buffers"),
        pytest.param("serial", True, id="serial-buffers"),
        pytest.param("parallel", False, id="parallel-no-buffers"),
        pytest.param("parallel", True, id="parallel-buffers"),
    ],
)
@pytest.mark.parametrize("n_steps", [1000])
def test_non_tensor_env_rollout_speed(
    benchmark,
    break_when_any_done: bool,
    kind: str,
    use_buffers: bool | None,
    n_steps: int,
):
    """Benchmarks a single rollout, after a warmup rollout, for non-tensor stacking envs.

    Mirrors `test/test_envs.py::TestNonTensorEnv`'s option matrix (single/serial/parallel,
    break_when_any_done, use_buffers).
    """
    with set_capture_non_tensor_stack(False):
        if kind == "single":
            env = EnvWithMetadata()
        elif kind == "serial":
            env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers)
        elif kind == "parallel":
            env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
        else:
            raise RuntimeError(f"Unknown kind={kind}")

        env.set_seed(0)
        env.reset()

        try:
            # Warmup run (not timed)
            _rollout(env, n_steps=n_steps, break_when_any_done=break_when_any_done)

            # Timed run(s)
            benchmark(
                _rollout, env, n_steps=n_steps, break_when_any_done=break_when_any_done
            )
        finally:
            env.close(raise_if_closed=False)
            del env
            # Give multiprocessing envs a brief chance to terminate cleanly.
            time.sleep(0.05)
            gc.collect()