torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Versioned gym environment name helpers for TorchRL tests."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
from torchrl._utils import implement_for
|
|
13
|
+
from torchrl.envs.libs.gym import _has_gym, gym_backend
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"BREAKOUT_VERSIONED",
|
|
17
|
+
"CARTPOLE_VERSIONED",
|
|
18
|
+
"CLIFFWALKING_VERSIONED",
|
|
19
|
+
"HALFCHEETAH_VERSIONED",
|
|
20
|
+
"PENDULUM_VERSIONED",
|
|
21
|
+
"PONG_VERSIONED",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
# True on a Python 3 interpreter whose minor version is at most 9.
PYTHON_3_9 = (3, 0) <= sys.version_info[:2] <= (3, 9)

# Placeholders for the versioned environment names. They stay None until
# _set_gym_environments assigns them.
_BREAKOUT_VERSIONED = None
_CARTPOLE_VERSIONED = None
_CLIFFWALKING_VERSIONED = None
_HALFCHEETAH_VERSIONED = None
_PENDULUM_VERSIONED = None
_PONG_VERSIONED = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def CARTPOLE_VERSIONED():
    """Return the versioned CartPole environment name for the current gym backend.

    When ``gym_backend()`` returns something other than ``None``, the
    module-level names are refreshed through ``_set_gym_environments()``
    before the lookup. The result is whatever ``_CARTPOLE_VERSIONED``
    holds at that point, which is ``None`` if it has never been assigned.
    """
    backend = gym_backend()
    if backend is not None:
        _set_gym_environments()
    return _CARTPOLE_VERSIONED
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def HALFCHEETAH_VERSIONED():
    """Return the versioned HalfCheetah environment name for the current gym backend.

    When ``gym_backend()`` returns something other than ``None``, the
    module-level names are refreshed through ``_set_gym_environments()``
    before the lookup. The result is whatever ``_HALFCHEETAH_VERSIONED``
    holds at that point, which is ``None`` if it has never been assigned.
    """
    backend = gym_backend()
    if backend is not None:
        _set_gym_environments()
    return _HALFCHEETAH_VERSIONED
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def PONG_VERSIONED():
    """Return the versioned Pong environment name for the current gym backend.

    ``ale_py`` is imported first on a best-effort basis: a missing package
    is ignored. When ``gym_backend()`` returns something other than
    ``None``, the module-level names are then refreshed through
    ``_set_gym_environments()``. The result is whatever ``_PONG_VERSIONED``
    holds at that point, which is ``None`` if it has never been assigned.
    """
    # Gymnasium states that the ale_py behavior changes from 1.0 onwards,
    # but with python 3.12 this is already the case with 0.29.1, so the
    # import is attempted unconditionally.
    try:
        import ale_py  # noqa: F401
    except ImportError:
        # ale_py is optional here; carry on without it.
        pass

    backend = gym_backend()
    if backend is not None:
        _set_gym_environments()
    return _PONG_VERSIONED
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def CLIFFWALKING_VERSIONED():
    """Return the versioned CliffWalking environment name for the current gym backend.

    When ``gym_backend()`` returns something other than ``None``, the
    module-level names are refreshed through ``_set_gym_environments()``
    before the lookup. The result is whatever ``_CLIFFWALKING_VERSIONED``
    holds at that point, which is ``None`` if it has never been assigned.
    """
    backend = gym_backend()
    if backend is not None:
        _set_gym_environments()
    return _CLIFFWALKING_VERSIONED
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def BREAKOUT_VERSIONED():
    """Return the versioned Breakout environment name for the current gym backend.

    ``ale_py`` is imported first on a best-effort basis: a missing package
    is ignored. When ``gym_backend()`` returns something other than
    ``None``, the module-level names are then refreshed through
    ``_set_gym_environments()``. The result is whatever
    ``_BREAKOUT_VERSIONED`` holds at that point, which is ``None`` if it
    has never been assigned.
    """
    # Gymnasium states that the ale_py behavior changes from 1.0 onwards,
    # but with python 3.12 this is already the case with 0.29.1, so the
    # import is attempted unconditionally.
    try:
        import ale_py  # noqa: F401
    except ImportError:
        # ale_py is optional here; carry on without it.
        pass

    backend = gym_backend()
    if backend is not None:
        _set_gym_environments()
    return _BREAKOUT_VERSIONED
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def PENDULUM_VERSIONED():
    """Return the versioned Pendulum environment name for the current gym backend.

    When ``gym_backend()`` returns something other than ``None``, the
    module-level names are refreshed through ``_set_gym_environments()``
    before the lookup. The result is whatever ``_PENDULUM_VERSIONED``
    holds at that point, which is ``None`` if it has never been assigned.
    """
    backend = gym_backend()
    if backend is not None:
        _set_gym_environments()
    return _PENDULUM_VERSIONED
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _set_gym_environments():
    """Reset every versioned environment name to ``None``.

    This is the undecorated definition; the same name is redefined later
    in the file by the ``implement_for``-decorated variants.
    """
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    _CARTPOLE_VERSIONED = _HALFCHEETAH_VERSIONED = _PENDULUM_VERSIONED = None
    _PONG_VERSIONED = _BREAKOUT_VERSIONED = _CLIFFWALKING_VERSIONED = None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@implement_for("gym", None, "0.21.0")
def _set_gym_environments():  # noqa: F811
    """Assign the environment names used with the ``gym`` backend up to 0.21.0.

    The applicable version range is decided by ``implement_for("gym", None,
    "0.21.0")``; presumably this covers gym releases older than 0.21.0 --
    confirm the bound semantics against ``implement_for``.
    """
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    # Atari names
    _BREAKOUT_VERSIONED = "Breakout-v4"
    _PONG_VERSIONED = "Pong-v4"
    # Remaining names
    _CARTPOLE_VERSIONED = "CartPole-v0"
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v2"
    _PENDULUM_VERSIONED = "Pendulum-v0"
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@implement_for("gym", "0.21.0", None)
def _set_gym_environments():  # noqa: F811
    """Assign the environment names used with the ``gym`` backend from 0.21.0.

    The applicable version range is decided by ``implement_for("gym",
    "0.21.0", None)``; presumably this covers gym 0.21.0 and newer --
    confirm the bound semantics against ``implement_for``.
    """
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    # Atari names
    _BREAKOUT_VERSIONED = "ALE/Breakout-v5"
    _PONG_VERSIONED = "ALE/Pong-v5"
    # Remaining names
    _CARTPOLE_VERSIONED = "CartPole-v1"
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
    _PENDULUM_VERSIONED = "Pendulum-v1"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@implement_for("gymnasium", None, "1.0.0")
def _set_gym_environments():  # noqa: F811
    """Assign the environment names used with the ``gymnasium`` backend up to 1.0.0.

    The applicable version range is decided by ``implement_for("gymnasium",
    None, "1.0.0")``; presumably this covers gymnasium releases older than
    1.0.0 -- confirm the bound semantics against ``implement_for``.
    """
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    # Atari names
    _BREAKOUT_VERSIONED = "ALE/Breakout-v5"
    _PONG_VERSIONED = "ALE/Pong-v5"
    # Remaining names
    _CARTPOLE_VERSIONED = "CartPole-v1"
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v4"
    _PENDULUM_VERSIONED = "Pendulum-v1"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@implement_for("gymnasium", "1.0.0", "1.1.0")
def _set_gym_environments():  # noqa: F811
    """Reject the ``gymnasium`` versions selected by this variant.

    No environment names are defined for the range chosen by
    ``implement_for("gymnasium", "1.0.0", "1.1.0")``; the sibling variants
    in this file only cover the ranges below 1.0.0 and from 1.1.0.

    Raises:
        ImportError: always, with a message naming the supported ranges.
    """
    # Same exception type as before; the message is new so that the failure
    # tells the user which versions to install instead.
    raise ImportError(
        "No versioned gym environment names are defined for "
        "gymnasium>=1.0.0,<1.1.0; install gymnasium<1.0.0 or gymnasium>=1.1.0."
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@implement_for("gymnasium", "1.1.0")
def _set_gym_environments():  # noqa: F811
    """Assign the environment names used with the ``gymnasium`` backend from 1.1.0.

    The applicable version range is decided by ``implement_for("gymnasium",
    "1.1.0")``; presumably this covers gymnasium 1.1.0 and newer -- confirm
    the bound semantics against ``implement_for``.
    """
    global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED
    global _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED

    # Atari names
    _BREAKOUT_VERSIONED = "ALE/Breakout-v5"
    _PONG_VERSIONED = "ALE/Pong-v5"
    # Remaining names
    _CARTPOLE_VERSIONED = "CartPole-v1"
    _HALFCHEETAH_VERSIONED = "HalfCheetah-v5"
    _PENDULUM_VERSIONED = "Pendulum-v1"
    # The older CliffWalking name is kept when PYTHON_3_9 is true.
    _CLIFFWALKING_VERSIONED = "CliffWalking-v0" if PYTHON_3_9 else "CliffWalking-v1"
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Populate the versioned names once at import time when `_has_gym` is truthy.
if _has_gym:
    _set_gym_environments()
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""Shared test fixtures and mock infrastructure for LLM tests."""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import random
|
|
9
|
+
import string
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MockTransformerConfig:
    """Minimal stand-in for a transformers model config.

    Args:
        vocab_size: stored as ``vocab_size`` and reused as ``hidden_size``.
        max_position_embeddings: stored unchanged; defaults to 2048.
    """

    def __init__(self, vocab_size: int, max_position_embeddings: int = 2048):
        # The hidden size mirrors the vocabulary size to keep the mock simple.
        self.hidden_size = vocab_size
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MockTransformerOutput:
    """Stand-in for a transformers model output.

    The logits are reachable both as the ``logits`` attribute and through
    ``output["logits"]``.
    """

    def __init__(self, logits):
        self.logits = logits

    def __getitem__(self, key):
        """Return the logits for ``"logits"``; raise ``KeyError`` for any other key."""
        if key != "logits":
            raise KeyError(f"Key {key} not found in model output")
        return self.logits
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MockTransformerModel(torch.nn.Module):
    """Small module mimicking the structure of a HuggingFace model.

    It exposes a ``config``, a ``device`` and a single embedding table whose
    rows are returned as logits.

    Args:
        vocab_size: size of the embedding table in both dimensions.
        device: anything accepted by ``torch.device``; defaults to ``"cpu"``.
    """

    def __init__(self, vocab_size: int, device: torch.device | str | int = "cpu"):
        super().__init__()
        self.device = torch.device(device)
        self.config = MockTransformerConfig(vocab_size)
        # A vocab_size x vocab_size table: looking up a token id directly
        # yields one value per vocabulary entry.
        self.embedding = torch.nn.Embedding(
            vocab_size, vocab_size, device=self.device
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return a ``MockTransformerOutput`` whose logits are the looked-up embeddings.

        ``attention_mask`` and any extra keyword arguments are accepted but
        not used.
        """
        vocab_size = self.config.vocab_size
        # Wrap ids into the table's range after moving them to the model device.
        token_ids = input_ids.to(self.device) % vocab_size
        return MockTransformerOutput(self.embedding(token_ids))

    def get_tokenizer(self):
        """Load and return the ``Qwen/Qwen2.5-0.5B`` tokenizer via ``AutoTokenizer``."""
        # Imported here so transformers is only needed when this is called.
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DummyStrDataLoader:
    """Endless iterator yielding random lowercase strings for LLM testing.

    Args:
        batch_size: ``0`` (default) yields one string per step under the
            ``"text"`` key; any other value yields a list of that many
            strings under the ``"query"`` key. A tuple is collapsed to its
            number of elements.
    """

    def __init__(self, batch_size=0):
        self.batch_size = (
            torch.Size(batch_size).numel()
            if isinstance(batch_size, tuple)
            else batch_size
        )

    def generate_random_string(self, length=10):
        """Return ``length`` random ASCII lowercase letters joined into one string."""
        alphabet = string.ascii_lowercase
        letters = [random.choice(alphabet) for _ in range(length)]
        return "".join(letters)

    def __iter__(self):
        return self

    def __next__(self):
        if self.batch_size == 0:
            return {"text": self.generate_random_string()}
        queries = [self.generate_random_string() for _ in range(self.batch_size)]
        return {"query": queries}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class DummyTensorDataLoader:
    """Endless iterator yielding random token tensors for LLM testing.

    Args:
        batch_size: ``0`` (default) yields a single 1-D tensor per step; any
            other value yields that many sequences. A tuple is collapsed to
            its number of elements.
        max_length: inclusive upper bound on the generated sequence length,
            and the target length when padding.
        padding: when true, sequences are left-padded with zeros to
            ``max_length`` and batches are stacked into one 2-D tensor;
            otherwise batches are returned as a list of 1-D tensors.
    """

    def __init__(self, batch_size=0, max_length=10, padding=False):
        if isinstance(batch_size, tuple):
            batch_size = torch.Size(batch_size).numel()
        self.batch_size = batch_size
        self.max_length = max_length
        self.padding = padding

    def generate_random_tensor(self):
        """Return a 1-D int64 tensor of random length in ``[1, max_length]``.

        Values are drawn from ``[1, 10000)``, so a generated token is never
        zero; zeros only ever come from padding.
        """
        length = random.randint(1, self.max_length)
        return torch.randint(1, 10000, (length,))

    def pad_tensor(self, tensor):
        """Left-pad ``tensor`` with int64 zeros up to ``max_length``."""
        padding_length = self.max_length - len(tensor)
        return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))

    def __iter__(self):
        return self

    def __next__(self):
        """Return a dict with ``"tokens"`` and a matching ``"attention_mask"``.

        The mask is true wherever a token is non-zero. For an unpadded batch
        the tokens are a list of tensors, and the mask is a list of the same
        length with one boolean tensor per sequence.
        """
        if self.batch_size == 0:
            tensor = self.generate_random_tensor()
            tokens = self.pad_tensor(tensor) if self.padding else tensor
        else:
            tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
            if self.padding:
                tokens = torch.stack([self.pad_tensor(seq) for seq in tensors])
            else:
                tokens = tensors

        if isinstance(tokens, list):
            # Comparing a list with 0 yields the scalar True, not a mask, so
            # the mask has to be built per sequence here.
            attention_mask = [seq != 0 for seq in tokens]
        else:
            attention_mask = tokens != 0
        return {"tokens": tokens, "attention_mask": attention_mask}
|