torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from .gsm8k import GSM8KRewardParser
|
|
8
|
+
from .ifeval import IFEvalScoreData, IfEvalScorer
|
|
9
|
+
|
|
10
|
+
# Public API of this package: the names bound by `from <package> import *`.
__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"]
|
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
|
|
11
|
+
from tensordict.utils import _zip_strict, is_non_tensor
|
|
12
|
+
from torchrl.data import Composite, Unbounded
|
|
13
|
+
from torchrl.envs import Transform
|
|
14
|
+
from torchrl.envs.common import EnvBase
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GSM8KRewardParser(Transform):
    """Reward parser for GSM8KEnv or make_gsm8k_env.

    At each step the parser reads the model completion and the reference
    answer, extracts the ``<think>`` and ``<answer>`` tags from the completion
    and writes a decomposed, shaped reward into the next tensordict.

    Unless ``in_keys`` is provided, the key of the completion is derived from the
    ``input_mode`` of the parent's base environment:

    - "history" mode: the completion is read from ("history", "full") and is a History object
    - "text" mode: the completion is read from ("text", "full") and is text
    - "tokens" mode: the completion is read from ("tokens", "full") and is tokens

    In all three cases the reference answer is read from ``"answer"``.

    Args:
        tokenizer (AutoTokenizer from transformers): the tokenizer associated with the model.
        in_keys (list of NestedKey): the input keys. If None, will be automatically determined based on parent's input_mode.
        out_keys (list of NestedKey): the output keys. Defaults to `[ "reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"]`.
        eos_token (str): the end of sentence token. Defaults to `tokenizer.eos_token` if not provided.
        set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
        input_mode (Literal["history", "text", "tokens"]): the input mode of the parent environment.
            Defaults to `None` (will be automatically determined based on parent's input_mode).
    """

    # Class-level fallback so that `input_mode` can be queried even if
    # `__init__` has not assigned the instance attribute.
    _input_mode = None

    def __init__(
        self,
        tokenizer,
        in_keys: list[NestedKey] | None = None,
        out_keys: list[NestedKey] | None = None,
        eos_token: str | None = None,
        set_done_if_answer: bool = True,
        input_mode: Literal["history", "text", "tokens"] | None = None,
    ):
        # The parent must be initialised exactly once and *before* any
        # attribute is assigned: a second call would reset the module
        # registries and silently drop anything registered in between.
        super().__init__()
        self.tokenizer = tokenizer
        if eos_token is None and tokenizer is not None:
            eos_token = tokenizer.eos_token
        self.eos_token = eos_token
        self.set_done_if_answer = set_done_if_answer
        self._input_mode = input_mode

        if out_keys is None:
            out_keys = [
                "reward_answer",
                "reward_think",
                "reward_right",
                "reward_contained",
                "reward",
                "success",
            ]
        if in_keys is not None:
            self.in_keys = in_keys
        self.out_keys = out_keys

    def _maybe_get_in_keys(self):
        """Fill ``in_keys`` from the parent's base env if they are still unset.

        Does nothing when ``in_keys`` is already populated or when the
        transform has no parent yet. ``in_keys`` is left untouched if the base
        env's ``input_mode`` is not one of "history", "text" or "tokens".

        Raises:
            ValueError: if the parent exposes no ``base_env``.
        """
        if self.in_keys:
            return
        parent = getattr(self, "parent", None)
        if parent is None:
            return
        base_env = getattr(parent, "base_env", None)
        if base_env is None:
            raise ValueError(
                f"No base env found for {self} with container {self.container}"
            )
        mode = getattr(base_env, "input_mode", None)
        if mode in ("history", "text", "tokens"):
            self.in_keys = [(mode, "full"), "answer"]

    def set_container(self, container: Transform | EnvBase) -> None:
        """Attach the transform to ``container`` and try to resolve ``in_keys``."""
        result = super().set_container(container)
        self._maybe_get_in_keys()
        return result

    @property
    def input_mode(self):
        """The input mode used to interpret the completion.

        If no mode was passed to the constructor, it is read from the parent's
        ``input_mode`` attribute, falling back to "history" when there is no
        parent or the parent has no such attribute. The resolved value is
        cached on first access.
        """
        if self._input_mode is None:
            input_mode = (
                getattr(self.parent, "input_mode", "history")
                if hasattr(self, "parent") and self.parent is not None
                else "history"
            )
            self._input_mode = input_mode
        return self._input_mode

    def _responses_to_text(self, responses, next_tensordict: TensorDictBase):
        """Convert the raw completion entry into a list of text completions.

        Args:
            responses: the entry read from the first in-key; its expected type
                depends on :attr:`input_mode`.
            next_tensordict: only used for its leading batch size, when a
                single string has to be repeated for every batch element.

        Returns:
            The list of completions, one item per sample.

        Raises:
            ValueError: if :attr:`input_mode` is not a supported mode.
        """
        input_mode = self.input_mode
        if input_mode == "history":
            # Keep only the last item along the last dim of every element.
            responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
            if hasattr(responses, "content"):
                # If it's a History object with content attribute
                text_completion = responses.content
                if is_non_tensor(text_completion):
                    text_completion = text_completion.tolist()
                if not isinstance(text_completion, list):
                    text_completion = [text_completion]
            elif hasattr(responses, "apply_chat_template"):
                # If it's a History object, apply chat template to get text
                text_completion = responses.apply_chat_template(
                    tokenizer=self.tokenizer, add_generation_prompt=False
                )
                if not isinstance(text_completion, list):
                    text_completion = [text_completion]
            else:
                # Fallback: try to convert to string
                text_completion = [str(responses)]
            return text_completion

        if input_mode == "text":
            # responses is already text
            if isinstance(responses, str):
                return [responses for _ in range(next_tensordict.batch_size[0])]
            if not isinstance(responses, list):
                return [responses]
            return responses

        if input_mode == "tokens":
            # responses is tokens, need to decode
            if isinstance(responses, torch.Tensor):
                text_completion = self.tokenizer.decode(
                    responses.flatten(0, 1).tolist()
                )
                if not isinstance(text_completion, list):
                    text_completion = [
                        text_completion for _ in range(next_tensordict.batch_size[0])
                    ]
                return text_completion
            # Assume it's already a list of token sequences
            return [
                self.tokenizer.decode(token_seq.tolist())
                if isinstance(token_seq, torch.Tensor)
                else str(token_seq)
                for token_seq in responses
            ]

        raise ValueError(f"Unknown input_mode: {input_mode}")

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Compute the shaped rewards and write them into ``next_tensordict``.

        The completion is read from ``tensordict`` (first in-key) and the
        reference answer from ``next_tensordict`` (second in-key). Rewards
        already present under the same keys are added to the new ones. When
        ``set_done_if_answer`` is set, ``"done"`` and ``"terminated"`` (if
        present) are OR-ed with ``reward_answer > 0``.

        Returns:
            The updated ``next_tensordict``.
        """
        from xml.etree import ElementTree as ET

        if next_tensordict.batch_dims > 1:
            # Flatten the batch dims, process, and rely on the in-place update.
            with tensordict.view(-1) as td_view, next_tensordict.view(
                -1
            ) as next_td_view:
                self._step(td_view, next_td_view)
            # did update in place
            return next_tensordict

        # Get the completion based on input_mode
        self._maybe_get_in_keys()
        responses = tensordict[self.in_keys[0]]  # batch_size, grpo_size, L
        text_completion = self._responses_to_text(responses, next_tensordict)

        if self.eos_token is not None:
            text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
        answers = next_tensordict[self.in_keys[1]]  # batch_size, grpo_size

        # Decomposed reward
        tds = []
        for answer, compl in _zip_strict(answers, text_completion):
            try:
                if not compl.startswith("<think>"):
                    compl = "<think>" + compl
                compl = compl.removesuffix("<|im_end|>")
                cot, potential_answer = self.extract_tags(compl)
            except ET.ParseError:
                cot, potential_answer = ("", "")
            # An empty (or missing) tag must not count as a found tag: wrapping
            # "" in a one-element list would always grant the format rewards
            # and make `reward_answer > 0` (hence done) true for every sample.
            found_answers = [potential_answer] if potential_answer else []
            found_cots = [cot] if cot else []
            # TODO: in tune, the answer is parsed during dataloading
            #  we could create a similar dataclass for both proposed and real answer
            #  With tensorclass comparison should be easy
            # The reference answer follows the "#### " marker.
            _, answer = answer.split("#### ")
            tds.append(
                self._single_shaped_correctness_reward(
                    answer, found_answers, found_cots
                )
            )
        tds = torch.stack(tds)
        if isinstance(responses, torch.Tensor) and responses.ndim == 3:
            batch_size, grpo_size, _ = responses.shape
            tds = tds.reshape(batch_size, grpo_size)
        # Rewards need to have shape broadcastable to [batch x tokens x 1]
        tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
        # Add the rewards, in case some have already been written
        next_td_exist = next_tensordict.select(*tds.keys(True, True), strict=False)
        if not next_td_exist.is_empty():
            tds = tds.add(
                next_td_exist, default=torch.zeros((), device=next_tensordict.device)
            )
        next_tensordict = next_tensordict.update(tds)
        if (
            self.set_done_if_answer
            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
        ):
            done = next_tensordict.get("done")
            if done is not None:
                next_tensordict.set("done", reward_answer.view_as(done) | done)
            terminated = next_tensordict.get("terminated")
            if terminated is not None:
                next_tensordict.set(
                    "terminated", reward_answer.view_as(terminated) | terminated
                )
        return next_tensordict

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        """Register the reward entries written by :meth:`_step` in ``reward_spec``.

        Every entry has the shape of ``reward_spec`` extended by two trailing
        singleton dims; ``success`` is boolean.
        """
        shape = reward_spec.shape + (1, 1)
        reward_spec.update(
            Composite(
                reward_answer=Unbounded(shape),
                reward_think=Unbounded(shape),
                reward_right=Unbounded(shape),
                reward_contained=Unbounded(shape),
                reward=Unbounded(shape),
                success=Unbounded(shape, dtype=torch.bool),
            )
        )
        return reward_spec

    @classmethod
    def _single_shaped_correctness_reward(
        cls, true_answer: str, potential_answer: list[str], cot: list[str]
    ) -> TensorDict:
        """Compute the decomposed reward of a single completion.

        Args:
            true_answer: the reference answer.
            potential_answer: the answers extracted from the completion
                (a single string is treated as a one-element list).
            cot: the reasoning blocks extracted from the completion
                (a single string is treated as a one-element list).

        Returns:
            A TensorDict with entries ``reward_answer`` and ``reward_think``
            (5.0 when exactly one answer / reasoning block was given),
            ``reward_right`` (20.0 when an answer equals the reference),
            ``reward_contained`` (10.0 when an answer contains the reference),
            ``success`` (the last answer equals the reference) and ``reward``
            (the sum of the above plus 60.0 on success).
        """
        # TODO: In tune, these end up being lists
        if isinstance(potential_answer, str):
            potential_answer = [potential_answer]
        if isinstance(cot, str):
            cot = [cot]

        # Format quality rewards (always applied)
        reward_answer = 5.0 * (len(potential_answer) == 1)
        reward_think = 5.0 * (len(cot) == 1)

        # Answer correctness rewards
        reward_right = 20.0 * (
            any(attempt == true_answer for attempt in potential_answer)
        )
        reward_contained = 10.0 * (
            any((true_answer in attempt) for attempt in potential_answer)
        )

        success = len(potential_answer) > 0 and potential_answer[-1] == true_answer

        # Base success reward (lower than before to make format quality more important)
        base_success_reward = 60.0 if success else 0.0

        # Compose the rewards - always include format quality, even when successful
        reward = (
            base_success_reward
            + reward_answer
            + reward_think
            + reward_contained
            + reward_right
        )

        rewards = TensorDict(
            reward_answer=reward_answer,
            reward_think=reward_think,
            reward_right=reward_right,
            reward_contained=reward_contained,
            reward=reward,
            success=success,
        )
        return rewards

    @staticmethod
    def extract_tags(text: str) -> tuple[str, str]:
        """Parse the XML-like ``<think>`` and ``<answer>`` tags from text.

        The text is wrapped in a ``<root>`` element and parsed as XML; only
        the first ``<think>`` and the first ``<answer>`` that are direct
        children of that root are considered.

        Returns: a tuple ``(think, answer)`` of strings. An item is the empty
            string when its tag is missing or has no text, and both items are
            empty when the text cannot be parsed as XML.

        """
        from xml.etree import ElementTree as ET

        xml_string = f"<root>{text}</root>"
        try:
            root = ET.fromstring(xml_string)
        except ET.ParseError:
            return ("", "")

        think_elem = root.find("think")
        answer_elem = root.find("answer")
        return (
            think_elem.text
            if think_elem is not None and think_elem.text is not None
            else "",
            answer_elem.text
            if answer_elem is not None and answer_elem.text is not None
            else "",
        )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Adapted Code from SkyThought
|
|
2
|
+
|
|
3
|
+
This project includes code adapted from [SkyThought](https://github.com/NovaSky-AI/SkyThought), specifically the file
|
|
4
|
+
[`ifeval_scorer.py`](https://github.com/NovaSky-AI/SkyThought/blob/2e5db2b26be63c5545d93be4ad08f5ca46449776/skythought/evals/scoring/ifeval/ifeval_scorer.py).
|
|
5
|
+
|
|
6
|
+
Parts of these files are themselves copied from other sources with a similar license.
|
|
7
|
+
|
|
8
|
+
The original code is distributed under the Apache 2.0 license, which can be found in the SkyThought repository: [Apache 2.0 License](https://github.com/NovaSky-AI/SkyThought/blob/main/LICENSE).
|
|
9
|
+
|
|
10
|
+
### Modifications
|
|
11
|
+
|
|
12
|
+
Modifications were made to the original code according to the terms of the Apache 2.0 license. The changes include
|
|
13
|
+
TorchRL formatting of the data using TensorDict and TorchRL's transforms.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from ._scorer import IFEvalScoreData, IfEvalScorer
|
|
9
|
+
|
|
10
|
+
# Public API of this package: the names bound by `from <package> import *`.
__all__ = ["IfEvalScorer", "IFEvalScoreData"]
|