torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/replay_buffers/scheduler.py
@@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import numpy as np
import torch
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler


class ParameterScheduler(ABC):
    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.

    Scheduler can for example be used to alter the alpha and beta values in the PrioritizedSampler.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer or sampler whose sampler to adjust
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted
            Defaults to `None`.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted
            Defaults to `None`.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        if not isinstance(obj, (ReplayBuffer, Sampler)):
            raise TypeError(
                f"ParameterScheduler only supports Sampler class. Pass either `ReplayBuffer` or `Sampler` object. Got {type(obj)} instead."
            )
        self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
        self.param_name = param_name
        self._min_val = min_value or float("-inf")
        self._max_val = max_value or float("inf")
        if not hasattr(self.sampler, self.param_name):
            raise ValueError(
                f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
            )
        initial_val = getattr(self.sampler, self.param_name)
        if isinstance(initial_val, torch.Tensor):
            initial_val = initial_val.clone()
            self.backend = torch
        else:
            self.backend = np
        self.initial_val = initial_val
        self._step_cnt = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the sampler.
        """
        sd = dict(self.__dict__)
        del sd["sampler"]
        return sd

    def load_state_dict(self, state_dict: dict[str, Any]):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
        self._step_cnt += 1
        # Apply the step function
        new_value = self._step()
        # clip value to specified range
        new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
        # Set the new value of the parameter dynamically
        setattr(self.sampler, self.param_name, new_value_clipped)

    @abstractmethod
    def _step(self):
        ...


class LambdaScheduler(ParameterScheduler):
    """Sets a parameter to its initial value times a given function.

    Similar to :class:`~torch.optim.LambdaLR`.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer
            parameter ``step_count``.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted
            Defaults to `None`.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted
            Defaults to `None`.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        lambda_fn: Callable[[int], float],
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.lambda_fn = lambda_fn

    def _step(self):
        return self.initial_val * self.lambda_fn(self._step_cnt)


class LinearScheduler(ParameterScheduler):
    """A linear scheduler for gradually altering a parameter in an object over a given number of steps.

    This scheduler linearly interpolates between the initial value of the parameter and a final target value.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        final_value (number): The final value that the parameter will reach after the
            specified number of steps.
        num_steps (number, optional): The total number of steps over which the parameter
            will be linearly altered.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.7 if step == 1
        >>> # beta = 0.8 if step == 2
        >>> # beta = 0.9 if step == 3
        >>> # beta = 1.0 if step >= 4
        >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        final_value: int | float,
        num_steps: int,
    ):
        super().__init__(obj, param_name)
        if isinstance(self.initial_val, torch.Tensor):
            # cast to same type as initial value
            final_value = torch.tensor(final_value).to(self.initial_val)
        self.final_val = final_value
        self.num_steps = num_steps
        self._delta = (self.final_val - self.initial_val) / self.num_steps

    def _step(self):
        # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
        # without graph breaks
        if self._step_cnt < self.num_steps:
            return self.initial_val + (self._delta * self._step_cnt)
        else:
            return self.final_val


class StepScheduler(ParameterScheduler):
    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

    The scheduler can apply:
    1. Multiplicative changes: `new_val = curr_val * gamma`
    2. Additive changes: `new_val = curr_val + gamma`

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        gamma (int or float, optional): The value by which to adjust the parameter,
            either in a multiplicative or additive way.
        n_steps (int, optional): The number of steps after which the parameter should be altered.
            Defaults to 1.
        mode (str, optional): The mode of scheduling. Can be either `'multiplicative'` or `'additive'`.
            Defaults to `'multiplicative'`.
        min_value (int or float, optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (int or float, optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.6 if 0 <= step < 10
        >>> # beta = 0.7 if 10 <= step < 20
        >>> # beta = 0.8 if 20 <= step < 30
        >>> # beta = 0.9 if 30 <= step < 40
        >>> # beta = 1.0 if 40 <= step
        >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, n_steps=10, mode='additive', max_value=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: int | float = 0.9,
        n_steps: int = 1,
        mode: str = "multiplicative",
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):

        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma
        self.n_steps = n_steps
        self.mode = mode
        if mode == "additive":
            operator = self.backend.add
        elif mode == "multiplicative":
            operator = self.backend.multiply
        else:
            raise ValueError(
                f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
            )
        self.operator = operator

    def _step(self):
        """Applies the scheduling logic to alter the parameter value every `n_steps`."""
        # Check if the current step count is a multiple of n_steps
        current_val = getattr(self.sampler, self.param_name)
        # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
        # without graph breaks
        if self._step_cnt % self.n_steps == 0:
            return self.operator(current_val, self.gamma)
        else:
            return current_val


class SchedulerList:
    """Simple container abstracting a list of schedulers."""

    def __init__(self, schedulers: list[ParameterScheduler]) -> None:
        if isinstance(schedulers, ParameterScheduler):
            schedulers = [schedulers]
        self.schedulers = schedulers

    def append(self, scheduler: ParameterScheduler):
        self.schedulers.append(scheduler)

    def step(self):
        for scheduler in self.schedulers:
            scheduler.step()
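
For orientation, below is a minimal usage sketch; it is not part of the wheel's contents. It assumes that PrioritizedSampler (from torchrl/data/replay_buffers/samplers.py, listed above) exposes writable alpha and beta attributes, as the docstrings in scheduler.py suggest, and that ListStorage and ReplayBuffer accept the arguments shown. It anneals the sampler's beta linearly while decaying alpha stepwise.

# Usage sketch only (not part of the package); assumes the constructors named above.
import torch

from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)
from torchrl.data.replay_buffers.storages import ListStorage

sampler = PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.4)
buffer = ReplayBuffer(storage=ListStorage(1000), sampler=sampler, batch_size=32)
buffer.extend([torch.full((4,), float(i)) for i in range(100)])

# Anneal beta linearly from 0.4 to 1.0 over 50 calls to step(), and decay
# alpha by 10% every 10 calls, never letting it drop below 0.1.
schedulers = SchedulerList(
    [
        LinearScheduler(buffer, param_name="beta", final_value=1.0, num_steps=50),
        StepScheduler(
            buffer,
            param_name="alpha",
            gamma=0.9,
            n_steps=10,
            mode="multiplicative",
            min_value=0.1,
        ),
    ]
)

for _ in range(60):
    batch = buffer.sample()  # a real training update would consume `batch` here
    schedulers.step()

print(sampler.alpha, sampler.beta)  # alpha decayed stepwise, beta annealed to 1.0

Grouping both schedulers in a SchedulerList keeps the training loop to a single schedulers.step() call per iteration, mirroring the per-epoch step() pattern shown in the docstring examples.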