torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
--- /dev/null
+++ b/torchrl/envs/llm/datasets/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
+from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
+
+__all__ = [
+    "make_gsm8k_env",
+    "GSM8KPrepareQuestion",
+    "GSM8KEnv",
+    "IFEvalEnv",
+    "IFEvalData",
+    "IfEvalScorer",
+]
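This module only re-exports the public entry points of the two dataset modules whose diffs follow; note that `IfEvalScorer` is surfaced here even though it is defined under `torchrl.envs.llm.reward.ifeval`. A minimal sketch (not part of the diff) of the resulting package-level imports:

# Hypothetical downstream usage of the re-exported names.
from torchrl.envs.llm.datasets import (
    GSM8KEnv,      # defined in gsm8k.py (diff below)
    IFEvalEnv,     # defined in ifeval.py (diff below)
    IfEvalScorer,  # re-exported from torchrl.envs.llm.reward.ifeval
)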
--- /dev/null
+++ b/torchrl/envs/llm/datasets/gsm8k.py
@@ -0,0 +1,353 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import warnings
+from collections.abc import Callable
+from typing import Any, Literal, TYPE_CHECKING
+
+import torch
+from tensordict import NestedKey, TensorDict, TensorDictBase
+from tensordict.tensorclass import NonTensorData, NonTensorStack
+from tensordict.utils import _zip_strict
+from torch.utils.data import DataLoader
+from torchrl.data import TensorSpec
+from torchrl.envs import StepCounter, Transform
+
+from torchrl.envs.llm.chat import DatasetChatEnv
+
+from torchrl.envs.llm.envs import LLMEnv
+from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
+
+if TYPE_CHECKING:
+    import transformers
+
+BASE_PROMPT = (
+    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
+    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
+    "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, "
+    "i.e., <think>reasoning process here</think> <answer>answer here</answer>. User: %s. Assistant: <think>"
+)
+
+
+class GSM8KPrepareQuestion(Transform):
+    """A transform to prepare the prompt when using GSM8k within an LLMEnv."""
+
+    def __init__(
+        self,
+        in_keys: list[NestedKey] | None = None,
+        out_keys: list[NestedKey] | None = None,
+    ):
+        if in_keys is None:
+            in_keys = ["text"]
+        if out_keys is None:
+            out_keys = list(in_keys)
+        super().__init__(in_keys, out_keys)
+
+    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+            string = tensordict.get(in_key)
+            tensordict.set(out_key, self._modify_str(string))
+        return tensordict
+
+    def _modify_str(
+        self, obs: str | list[str] | NonTensorData | NonTensorStack
+    ) -> NonTensorData | NonTensorStack:
+        if isinstance(obs, NonTensorData):
+            return self._modify_str(obs.data)
+        if isinstance(obs, NonTensorStack):
+            return self._modify_str(obs.tolist())
+        if isinstance(obs, list):
+            return NonTensorStack(*[BASE_PROMPT % obs for obs in obs])
+        return NonTensorData(BASE_PROMPT % obs)
+
+    def _apply_transform(self, obs: torch.Tensor) -> None:
+        return obs
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+            if out_key != in_key:
+                observation_spec[out_key] = observation_spec[in_key].clone()
+        return observation_spec
+
+
+def _collate_fn(batch):
+    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
+    batch.rename_key_("question", "query")
+    return batch
+
+
+def make_gsm8k_env(
+    dataset: str = "openai/gsm8k",
+    num_envs: int = 1,
+    repeats: int | None = None,
+    batch_size_dl: int = 1,
+    seed: int | None = None,
+    group_repeats: bool = False,
+    tokenizer: transformers.PretrainedTokenizer | None = None,  # noqa
+):
+    """A builder for an LLMEnv-based GSM8K environment.
+
+    .. note:: Prefer `torchrl.envs.llm.GSM8KEnv` to interact with this dataset.
+
+    """
+    warnings.warn("This constructor is to be deprecated. Use GSM8KEnv instead.")
+    from datasets import load_dataset
+
+    dataset = load_dataset(dataset, "main")
+    train_dataset = dataset["train"]
+
+    # Env
+    if seed is None:
+        seed = int(torch.empty((), dtype=torch.int64).random_().item())
+    generator = torch.Generator(device=torch.get_default_device())
+    generator.manual_seed(seed)
+
+    dataloader = DataLoader(  # noqa: TOR401
+        train_dataset,
+        batch_size=batch_size_dl,
+        shuffle=True,
+        collate_fn=_collate_fn,
+        generator=generator,
+    )
+    env = LLMEnv.from_dataloader(
+        dataloader=dataloader,
+        # tokenizer=tokenizer,
+        from_text=True,
+        batch_size=(num_envs,),
+        repeats=repeats,
+        group_repeats=group_repeats,
+        # assign_reward=True,
+    )
+    env.insert_transform(0, GSM8KPrepareQuestion())
+
+    # Finally, we want the env to stop after the first step
+    env.append_transform(StepCounter(max_steps=1))
+
+    if tokenizer is not None:
+        env.append_transform(
+            GSM8KRewardParser(
+                tokenizer=tokenizer,
+                input_mode="text",
+                in_keys=["text_response", "answer"],
+            )
+        )
+    else:
+        warnings.warn("No tokenizer specified - reward will not be assigned.")
+
+    return env
+
+
+class GSM8KEnv(DatasetChatEnv):
+    r"""GSM8K dataset environment.
+
+    Keyword Args:
+        dataset (str, optional): The name of the dataset. Defaults to `"gsm8k"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
+        num_envs (int, optional): The number of environments to create. Defaults to `1`.
+        repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+            based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+        batch_size_dl (int, optional): The batch size for data loading. Defaults to `1`.
+        seed (int | None, optional): The random seed for reproducibility. If `None`, a random seed is used. Defaults to `None`.
+        group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+        tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+
+            .. note:: It is recommended to pass a tokenizer to the environment. This is an easy way to ensure that the
+                template applied to the chat history is consistent with the format required by the model.
+
+        device (torch.device | None, optional): The device to use for computations. Defaults to None.
+        template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+        apply_template (bool | None, optional): Whether to apply the template to the text. Defaults to `False`.
+        compute_reward (bool, optional): Whether to compute rewards. Defaults to `True`.
+        collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
+            collate function is used. Defaults to `None`.
+        max_steps (int, optional): The maximum number of steps allowed in an episode. Defaults to `1`.
+        input_mode (Literal["history", "text", "tokens"], optional): The mode of input to use. Defaults to `"history"`.
+        ray_backend (bool, optional): Whether to use the Ray backend for data loading. Defaults to `False`.
+            Using this backend allows for explicit resource control and avoids serialization issues, as well as
+            sharing the same dataloader across multiple environments and actors.
+        dataloader_actor_name (str, optional): Name of the Ray actor to use for data loading.
+            Defaults to `"gsm8k_dataloader"`.
+
+    Examples:
+        >>> import transformers
+        >>> from torchrl.envs.llm.datasets.gsm8k import GSM8KEnv
+        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        >>> env = GSM8KEnv(tokenizer=tokenizer, apply_template=True)
+        >>> r = env.reset()
+        >>> assert "history" in r
+        >>> # We have an instruction step (role="system") and a question (role="user")
+        >>> assert r["history"].shape == (1, 2)
+        >>> assert "text" in r
+        >>> r = r.clone()
+        >>> print(r)
+        LazyStackedTensorDict(
+            fields={
+                answer: NonTensorStack(
+                    ['Adam bought 3 sandwiches, so he paid 3 * 3 = $<<...,
+                    batch_size=torch.Size([1]),
+                    device=None),
+                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                history: History(
+                    content=NonTensorStack(
+                        [['A conversation between User and Assistant. The ...,
+                        batch_size=torch.Size([1, 2]),
+                        device=None),
+                    role=NonTensorStack(
+                        [['system', 'user']],
+                        batch_size=torch.Size([1, 2]),
+                        device=None),
+                    batch_size=torch.Size([1, 2]),
+                    device=None,
+                    is_shared=False),
+                step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                text: NonTensorStack(
+                    ['<|im_start|>system\nA conversation between User ...,
+                    batch_size=torch.Size([1]),
+                    device=None),
+                truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([1]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+        >>> response = "<think>First, calculate the total number of snakes in the breeding balls. There are 3 breeding balls with 8 snakes each, so 3 * 8 = 24 snakes. Next, calculate the number of snakes in the additional pairs. There are 6 pairs of snakes, and each pair has 2 snakes, so 6 * 2 = 12 snakes. Finally, add the number of snakes from the breeding balls and the additional pairs: 24 + 12 = 36 snakes.</think> <answer>Mary saw a total of 36 snakes.</answer><|im_end|>"
+        >>> r["text_response"] = [response]
+        >>> s = env.step(r)
+        >>> print(s)
+        LazyStackedTensorDict(
+            fields={
+                answer: NonTensorStack(
+                    ['Adam bought 3 sandwiches, so he paid 3 * 3 = $<<...,
+                    batch_size=torch.Size([1]),
+                    device=None),
+                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                history: History(
+                    content=NonTensorStack(
+                        [['A conversation between User and Assistant. The ...,
+                        batch_size=torch.Size([1, 2]),
+                        device=None),
+                    role=NonTensorStack(
+                        [['system', 'user']],
+                        batch_size=torch.Size([1, 2]),
+                        device=None),
+                    batch_size=torch.Size([1, 2]),
+                    device=None,
+                    is_shared=False),
+                next: LazyStackedTensorDict(
+                    fields={
+                        answer: NonTensorStack(
+                            ['Adam bought 3 sandwiches, so he paid 3 * 3 = $<<...,
+                            batch_size=torch.Size([1]),
+                            device=None),
+                        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        history: History(
+                            content=NonTensorStack(
+                                [['A conversation between User and Assistant. The ...,
+                                batch_size=torch.Size([1, 3]),
+                                device=None),
+                            role=NonTensorStack(
+                                [['system', 'user', 'assistant']],
+                                batch_size=torch.Size([1, 3]),
+                                device=None),
+                            batch_size=torch.Size([1, 3]),
+                            device=None,
+                            is_shared=False),
+                        reward: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward_answer: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward_contained: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward_right: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward_think: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                        success: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        text: NonTensorStack(
+                            ['<|im_start|>system\nA conversation between User ...,
+                            batch_size=torch.Size([1]),
+                            device=None),
+                        truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    exclusive_fields={
+                    },
+                    batch_size=torch.Size([1]),
+                    device=None,
+                    is_shared=False,
+                    stack_dim=0),
+                step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                text: NonTensorStack(
+                    ['<|im_start|>system\nA conversation between User ...,
+                    batch_size=torch.Size([1]),
+                    device=None),
+                text_response: NonTensorStack(
+                    ['<think>First, calculate the total number of snak...,
+                    batch_size=torch.Size([1]),
+                    device=None),
+                truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([1]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+        >>> assert s["next", "reward"] >= 10
+        >>> assert s["next", "done"].all()
+
+    """
+
+    SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively,
+i.e., <think>reasoning process here</think> <answer>answer here</answer>. The answer should be a number."""
+
+    def __init__(
+        self,
+        *,
+        dataset: str = "openai/gsm8k",
+        shuffle: bool = True,
+        num_envs: int = 1,
+        repeats: int | None = None,
+        batch_size_dl: int = 1,
+        seed: int | None = None,
+        group_repeats: bool = False,
+        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
+        device: torch.device | None = None,
+        template_kwargs: dict[str, Any] | None = None,
+        apply_template: bool | None = False,
+        compute_reward: bool = True,
+        collate_fn: Callable | None = None,
+        max_steps: int = 1,
+        input_mode: Literal["history", "text", "tokens"] = "history",
+        ray_backend: bool = False,
+        dataloader_actor_name: str | None = None,
+    ):
+        if ray_backend and dataloader_actor_name is None:
+            dataloader_actor_name = "gsm8k_dataloader"
+        if collate_fn is None:
+            collate_fn = _collate_fn
+        super().__init__(
+            dataset=dataset,
+            shuffle=shuffle,
+            name="main",
+            num_envs=num_envs,
+            repeats=repeats,
+            batch_size_dl=batch_size_dl,
+            seed=seed,
+            group_repeats=group_repeats,
+            tokenizer=tokenizer,
+            device=device,
+            template_kwargs=template_kwargs,
+            apply_template=apply_template,
+            collate_fn=collate_fn,
+            input_mode=input_mode,
+            ray_backend=ray_backend,
+            dataloader_actor_name=dataloader_actor_name,
+        )
+        if max_steps:
+            self.append_transform(StepCounter(max_steps=max_steps))
+        if compute_reward:
+            t = GSM8KRewardParser(tokenizer=tokenizer)
+            self.append_transform(t)
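For orientation, a minimal sketch (not part of the diff) of driving GSM8KEnv for one turn, following the docstring example above. It assumes the optional `datasets` and `transformers` dependencies are installed; the response string is a stub that a real policy would generate from the prompt.

import transformers
from torchrl.envs.llm.datasets.gsm8k import GSM8KEnv

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
env = GSM8KEnv(tokenizer=tokenizer, apply_template=True)

td = env.reset()  # loads a question from openai/gsm8k into "history"/"text"
# Stub action; a real policy would generate this from td["text"].
td["text_response"] = ["<think>2 + 2 = 4</think> <answer>4</answer><|im_end|>"]
td = env.step(td)
print(td["next", "reward"])  # filled in by GSM8KRewardParser
print(td["next", "done"])    # True: StepCounter(max_steps=1) ends the episode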
--- /dev/null
+++ b/torchrl/envs/llm/datasets/ifeval.py
@@ -0,0 +1,274 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections.abc import Callable
+
+from typing import Any, Literal, TYPE_CHECKING
+
+import torch
+from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict
+from torchrl.data import Composite, NonTensor, Unbounded
+from torchrl.envs import StepCounter
+from torchrl.envs.llm.chat import DatasetChatEnv
+from torchrl.envs.llm.reward.ifeval import IfEvalScorer
+
+if TYPE_CHECKING:
+    import transformers
+
+
+class IFEvalData(TensorClass["nocast"]):
+    """A tensorclass for IFEval data."""
+
+    key: torch.Tensor
+    instruction_id_list: list[str]
+    kwargs: list[dict]
+    query: str
+
+    # Responses and additional fields
+    response: str | None = None
+    tokens: torch.Tensor | None = None
+    tokens_response: torch.Tensor | None = None
+    logits: torch.Tensor | None = None
+    reward: torch.Tensor | None = None
+
+    @classmethod
+    def default_spec(
+        cls, shape: torch.Size, device: torch.device | None = None
+    ) -> Composite:
+        return Composite(
+            key=Unbounded(shape=shape, dtype=torch.int64, device=device),
+            instruction_id_list=NonTensor(
+                shape=shape,
+                device=device,
+                feature_dims=0,
+                example_data=["punctuation:no_comma"],
+            ),
+            kwargs=NonTensor(
+                shape=shape,
+                device=device,
+                feature_dims=0,
+                example_data={
+                    "num_highlights": None,
+                    "relation": None,
+                    "num_placeholders": None,
+                },
+            ),
+            query=NonTensor(
+                shape=shape,
+                device=device,
+                example_data="Plan a 2 week Europe trip and visit London, Paris, and Rome. Answer in all caps. The response must contain at least 8 placeholders (i.e., [restaurant]).",
+            ),
+            shape=shape,
+            step_mdp_static=True,
+            data_cls=cls,
+        )
+
+
+def _collate_fn(batch):
+    batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
+    batch.rename_key_("prompt", "query")
+    # we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
+    instruction_id_list = batch["instruction_id_list"]
+    # instruction_id_list should be a list of lists
+    instruction_id_list = NonTensorStack(
+        *[
+            NonTensorData([item] if not isinstance(item, list) else item)
+            for item in instruction_id_list
+        ]
+    )
+    kwargs = batch["kwargs"]
+    kwargs = NonTensorStack(
+        *[
+            NonTensorData([item] if not isinstance(item, list) else item)
+            for item in kwargs
+        ]
+    )
+    batch.set("instruction_id_list", instruction_id_list)
+    batch.set("kwargs", kwargs)
+    # we don't need a tensorclass here
+    return batch
+    # return IFEvalData.from_tensordict(batch)
+
+
+class IFEvalEnv(DatasetChatEnv):
+    r"""A chat environment based on the IFEval dataset.
+
+    Keyword Args:
+        dataset (str, optional): The name of the dataset. Defaults to `"google/IFeval"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
+        num_envs (int, optional): The number of environments to create. Defaults to `1`.
+        repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+            based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+        batch_size_dl (int, optional): The batch size for data loading. Defaults to `1`.
+        seed (int | None, optional): The random seed for reproducibility. If `None`, a random seed is used. Defaults to `None`.
+        group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+        tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+
+            .. note:: It is recommended to pass a tokenizer to the environment. This is an easy way to ensure that the
+                template applied to the chat history is consistent with the format required by the model.
+
+        device (torch.device | None, optional): The device to use for computations. Defaults to None.
+        template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+        apply_template (bool | None, optional): Whether to apply the template to the text. Defaults to `False`.
+        compute_reward (bool, optional): Whether to compute rewards. Defaults to `True`.
+        collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
+            collate function is used. Defaults to `None`.
+        max_steps (int, optional): The maximum number of steps allowed in an episode. Defaults to `1`.
+        input_mode (Literal["history", "text", "tokens"], optional): The mode of input to use. Defaults to `"history"`.
+        ray_backend (bool, optional): Whether to use the Ray backend for data loading. Defaults to `False`.
+            Using this backend allows for explicit resource control and avoids serialization issues, as well as
+            sharing the same dataloader across multiple environments and actors.
+        dataloader_actor_name (str, optional): Name of the Ray actor to use for data loading.
+            Defaults to `"ifeval_dataloader"`.
+
+    Examples:
+        >>> import transformers
+        >>> from pprint import pprint
+        >>> from torchrl.envs.llm.datasets import IFEvalEnv
+        >>> from tensordict import set_list_to_stack
+        >>> set_list_to_stack(True).set()
+        >>>
+        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        >>> env = IFEvalEnv(tokenizer=tokenizer, apply_template=True)
+        >>> r = env.reset()
+        >>> print(r)
+        LazyStackedTensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                history: History(
+                    content=NonTensorStack(
+                        [['A conversation between User and Assistant.\nYou...,
+                        batch_size=torch.Size([1, 2]),
+                        device=None),
+                    role=NonTensorStack(
+                        [['system', 'user']],
+                        batch_size=torch.Size([1, 2]),
+                        device=None),
+                    batch_size=torch.Size([1, 2]),
+                    device=None,
+                    is_shared=False),
+                instruction_id_list: NonTensorStack(
+                    [['detectable_content:number_placeholders']],
+                    batch_size=torch.Size([1, 1]),
+                    device=None),
+                key: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
+                kwargs: NonTensorStack(
+                    [[{'num_highlights': None, 'relation': None, 'num_...,
+                    batch_size=torch.Size([1, 1]),
+                    device=None),
+                step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                text: NonTensorStack(
+                    ['<|im_start|>system\nA conversation between User ...,
+                    batch_size=torch.Size([1]),
+                    device=None),
+                truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([1]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+        >>> # Print content of conversation so far
+        >>> pprint(r["history", "content"])
+        [['A conversation between User and Assistant.\n'
+          'You are tasked with responding to user queries in a very specific format. \n'
+          'When given a task or question, first think through the problem and provide '
+          'your thought process between <think> and </think> tags.\n'
+          'Then, give your final answer or response between <answer> and </answer> '
+          'tags.\n'
+          'You will be assessed by the content of the answer block only, so make sure '
+          'it contains all the required information, and only that.',
+          'Plan a 2 week Europe trip and visit London, Paris, and Rome. Answer in all '
+          'caps. The response must contain at least 8 placeholders (i.e., '
+          '[restaurant]).']]
+        >>> # Action space: the environment expects an action with key "text_response" containing a (list of) strings
+        >>> print(env.action_spec)
+        Composite(
+            text_response: NonTensor(
+                shape=torch.Size([1]),
+                space=None,
+                device=None,
+                dtype=None,
+                domain=None,
+                example_data=a string),
+            device=None,
+            shape=torch.Size([1]))
+
+    """
+
+    SYSTEM_PROMPT = """You are a helpful AI assistant that follows instructions extremely well.
+
+IMPORTANT: You must respond in a specific format for every task:
+
+1. First, think through the problem step by step and write your reasoning between <think> and </think> tags
+2. Then, provide your final answer between <answer> and </answer> tags
+
+CRITICAL RULES:
+- ALWAYS use <think>...</think> and <answer>...</answer> tags exactly as shown
+- Do NOT use <thought>, <reasoning>, or any other tag variations
+- Your <answer> section will be evaluated, so make it complete and accurate
+- Follow ALL specific requirements in the user's request (formatting, content, etc.)
+- If the user asks for placeholders like [restaurant], include them exactly as requested
+- Pay attention to capitalization, punctuation, and other formatting requirements
+
+Example format:
+<think>
+I need to analyze what the user is asking for...
+[Your reasoning here]
+</think>
+<answer>
+[Your final answer here, following all user requirements]
+</answer>"""
+
+    def __init__(
+        self,
+        *,
+        dataset: str = "google/IFeval",
+        shuffle: bool = True,
+        num_envs: int = 1,
+        repeats: int | None = None,
+        batch_size_dl: int = 1,
+        seed: int | None = None,
+        group_repeats: bool = False,
+        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
+        device: torch.device | None = None,
+        template_kwargs: dict[str, Any] | None = None,
+        apply_template: bool | None = False,
+        compute_reward: bool = True,
+        collate_fn: Callable | None = None,
+        max_steps: int = 1,
+        input_mode: Literal["history", "text", "tokens"] = "history",
+        ray_backend: bool = False,
+        dataloader_actor_name: str | None = None,
+    ):
+        if ray_backend and dataloader_actor_name is None:
+            dataloader_actor_name = "ifeval_dataloader"
+        if collate_fn is None:
+            collate_fn = _collate_fn
+        super().__init__(
+            dataset=dataset,
+            shuffle=shuffle,
+            num_envs=num_envs,
+            repeats=repeats,
+            batch_size_dl=batch_size_dl,
+            seed=seed,
+            group_repeats=group_repeats,
+            tokenizer=tokenizer,
+            device=device,
+            template_kwargs=template_kwargs,
+            apply_template=apply_template,
+            collate_fn=collate_fn,
+            input_mode=input_mode,
+            data_key="query",
+            primers=IFEvalData.default_spec((num_envs,), device),
+            ray_backend=ray_backend,
+            dataloader_actor_name=dataloader_actor_name,
+        )
+        if max_steps:
+            self.append_transform(StepCounter(max_steps=max_steps))
+        if compute_reward:
+            self.append_transform(IfEvalScorer())
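Similarly, a minimal sketch (not part of the diff) of one IFEvalEnv turn, mirroring the docstring example above. The response text is a made-up placeholder, and `set_list_to_stack(True).set()` is called before construction as the docstring does.

import transformers
from tensordict import set_list_to_stack
from torchrl.envs.llm.datasets import IFEvalEnv

set_list_to_stack(True).set()
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
env = IFEvalEnv(tokenizer=tokenizer, apply_template=True)

td = env.reset()  # prompt plus instruction_id_list / kwargs metadata
# Stub action; the action spec expects a "text_response" string per env.
td["text_response"] = [
    "<think>ALL CAPS, at least 8 placeholders...</think>\n"
    "<answer>DAY 1: FLY TO [city], CHECK IN AT [hotel] ...</answer>"
]
td = env.step(td)  # IfEvalScorer computes the reward fields in td["next"]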