torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Modifications from original script.
|
|
7
|
+
|
|
8
|
+
Modifications include:
|
|
9
|
+
|
|
10
|
+
- TensorDict embedding
|
|
11
|
+
- Modification of key names
|
|
12
|
+
- make IfEvalScorer a TorchRL transform
|
|
13
|
+
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import importlib.util
|
|
19
|
+
import re
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
from tensordict import (
|
|
24
|
+
lazy_stack,
|
|
25
|
+
NestedKey,
|
|
26
|
+
NonTensorData,
|
|
27
|
+
TensorClass,
|
|
28
|
+
TensorDict,
|
|
29
|
+
TensorDictBase,
|
|
30
|
+
)
|
|
31
|
+
from tensordict.tensorclass import is_non_tensor
|
|
32
|
+
from torchrl._utils import logger as torchrl_logger
|
|
33
|
+
|
|
34
|
+
from torchrl.data.tensor_specs import Composite, Unbounded
|
|
35
|
+
from torchrl.envs import Transform
|
|
36
|
+
|
|
37
|
+
_has_langdetect = importlib.util.find_spec("langdetect") is not None
|
|
38
|
+
_has_nltk = importlib.util.find_spec("nltk") is not None
|
|
39
|
+
_has_immutabledict = importlib.util.find_spec("immutabledict") is not None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class IFEvalScoreData(TensorClass):
    """Tensorclass holding the four IFEval accuracy scores.

    Each field may be passed as ``None`` at construction time;
    ``__post_init__`` then fills it with zeros. Fields that have the same
    number of dimensions as the container receive a trailing singleton
    dimension.
    """

    # Strict-evaluation scores.
    prompt_level_strict_acc: torch.Tensor | None
    inst_level_strict_acc: torch.Tensor | None
    # Loose-evaluation scores.
    prompt_level_loose_acc: torch.Tensor | None
    inst_level_loose_acc: torch.Tensor | None

    @classmethod
    def default_spec(
        cls, shape: torch.Size, device: torch.device | None = None
    ) -> Composite:
        """Return a ``Composite`` spec describing an instance of this class.

        Args:
            shape: Shape given to every leaf spec.
            device: Device given to every leaf spec. Defaults to ``None``.

        Returns:
            A ``Composite`` with one boolean ``Unbounded`` leaf per score
            field, tagged with ``data_cls=cls`` and ``step_mdp_static=True``.
        """

        def _bool_leaf() -> Unbounded:
            # Every score field shares the same leaf definition.
            return Unbounded(shape=shape, dtype=torch.bool, device=device)

        return Composite(
            prompt_level_strict_acc=_bool_leaf(),
            inst_level_strict_acc=_bool_leaf(),
            prompt_level_loose_acc=_bool_leaf(),
            inst_level_loose_acc=_bool_leaf(),
            data_cls=cls,
            step_mdp_static=True,
        )

    def __post_init__(self):
        """Normalise the score fields right after construction.

        A field that is ``None`` becomes a zero tensor of shape
        ``batch_size + (1,)``; a field whose ``ndim`` equals the container's
        ``ndim`` is unsqueezed on its last dimension. Any other field is left
        untouched.
        """
        field_names = (
            "prompt_level_loose_acc",
            "inst_level_loose_acc",
            "prompt_level_strict_acc",
            "inst_level_strict_acc",
        )
        # Read every field up-front (as padded tensors) before writing any.
        current = {
            name: self.get(name, as_padded_tensor=True) for name in field_names
        }

        for name in field_names:
            value = current[name]
            if value is None:
                # torch.zeros is called without dtype/device: defaults apply.
                setattr(self, name, torch.zeros(self.batch_size + (1,)))
            elif value.ndim == self.ndim:
                setattr(self, name, value.unsqueeze(-1))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _process_results(
    data: TensorDict,
    response: str | NonTensorData,
    verbose: bool = False,
    prompt: str | None = None,
) -> IFEvalScoreData:
    """Score one response against the IFEval instructions stored in ``data``.

    Args:
        data: Container read through the ``"key"``, ``"instruction_id_list"``
            and ``"kwargs"`` entries (and ``"text"`` when ``prompt`` is not
            given). Its ``batch_size`` and ``device`` are forwarded to the
            result.
        response: The response to evaluate; passed as-is to the strict and
            loose instruction-following checkers.
        verbose: When ``True``, log the input example, the response and the
            resulting scores through the TorchRL logger.
        prompt: The prompt text. When ``None``, ``data["text"]`` is used.

    Returns:
        An :class:`IFEvalScoreData` with the prompt-level and
        instruction-level results of the strict and loose evaluations.

    Raises:
        ImportError: If ``langdetect`` or ``immutabledict`` is not installed.
    """
    if not _has_langdetect:
        raise ImportError("langdetect must be installed to use IFEvalScorer.")
    if not _has_immutabledict:
        raise ImportError("immutabledict must be installed to use IFEvalScorer.")

    # Imported only after the dependency checks above have passed
    # (presumably because this module needs those optional packages).
    from ._instructions_main import (
        _InputExample,
        _test_instruction_following_loose,
        _test_instruction_following_strict,
    )

    if prompt is None:
        prompt = data["text"]

    inp = _InputExample(
        key=data["key"],
        instruction_id_list=data["instruction_id_list"],
        # data["text"] may itself be None, hence the second guard.
        prompt=prompt if prompt is not None else "",
        kwargs=data["kwargs"],
    )

    if verbose:
        torchrl_logger.info(f"Processing {inp=} {response=}")
    out_strict = _test_instruction_following_strict(inp, response)
    out_loose = _test_instruction_following_loose(inp, response)

    result = IFEvalScoreData(
        prompt_level_strict_acc=out_strict.follow_all_instructions,
        inst_level_strict_acc=out_strict.follow_instruction_list,
        prompt_level_loose_acc=out_loose.follow_all_instructions,
        inst_level_loose_acc=out_loose.follow_instruction_list,
        batch_size=data.batch_size,
        device=data.device,
    )

    if verbose:
        torchrl_logger.info(f"Result: {result.to_dict()=}")
    return result
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class IfEvalScorer(Transform):
|
|
149
|
+
"""Scorer for the IF-Eval task.
|
|
150
|
+
|
|
151
|
+
For the IFEval dataset format, see https://huggingface.co/datasets/google/IFEval
|
|
152
|
+
|
|
153
|
+
The score data is written under the `score_key` using the :class:`~torchrl.envs.llm.IFEvalScoreData` data structure.
|
|
154
|
+
Scores can be aggregated on a single reward by using the `aggregate_reward` keyword argument in the constructor, which
|
|
155
|
+
can be a bool or a function.
|
|
156
|
+
|
|
157
|
+
Keyword Args:
|
|
158
|
+
instruction_ids_key (NestedKey, optional): The column name for the list of instruction ids. Defaults to "instruction_id_list".
|
|
159
|
+
prompt_key (NestedKey, optional): The column name for the prompt. Defaults to "text".
|
|
160
|
+
keyword_args_key (NestedKey, optional): The column name for the keyword arguments to the instruction builder. Defaults to "kwargs".
|
|
161
|
+
id_key (NestedKey, optional): The column name for the unique identifier. Defaults to "key".
|
|
162
|
+
response_column (NestedKey, optional): The column name for the response. Defaults to "text_response".
|
|
163
|
+
score_key (NestedKey, optional): The key to store the score. Defaults to "ifeval_score".
|
|
164
|
+
aggregate_reward (bool, callable, optional): Whether to aggregate the reward or not. If a Callable is passed,
|
|
165
|
+
it must take as input an :class:`~torchrl.envs.llm.IFEvalScoreData` instance, and optionally `think_blocks`, `answer_blocks` and `complete` keyword arguments
|
|
166
|
+
containing the list of think and answer blocks, respectively.
|
|
167
|
+
It must return a tensor with shape identical to the env batch-size with an additional trailing singleton dimension.
|
|
168
|
+
Defaults to `True`. The default aggregator is a simple sum over the fields of :class:`~torchrl.envs.llm.IFEvalScoreData`.
|
|
169
|
+
format_weights (list[float], optional): The weights for the format fields (`prompt_level_strict_acc`, `inst_level_strict_acc`,
|
|
170
|
+
`prompt_level_loose_acc`, `inst_level_loose_acc`, in that order). Defaults to `[0.4, 0.3, 0.2, 0.1]`.
|
|
171
|
+
This is only used if `aggregate_reward` is `True` and the default aggregator is used.
|
|
172
|
+
verbose (bool, optional): Whether to print verbose information. Defaults to `False`.
|
|
173
|
+
set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
|
|
174
|
+
|
|
175
|
+
.. note:: `IFEvalScorer` requires the following libraries to be installed: `langdetect`, `nltk` and `immutabledict`.
|
|
176
|
+
|
|
177
|
+
"""
|
|
178
|
+
|
|
179
|
+
def __init__(
    self,
    *,
    instruction_ids_key: NestedKey = "instruction_id_list",
    prompt_key: NestedKey = "text",
    keyword_args_key: NestedKey = "kwargs",
    id_key: NestedKey = "key",
    response_column: NestedKey = "text_response",
    score_key: NestedKey = "ifeval_score",
    aggregate_reward: bool
    | Callable[
        [IFEvalScoreData, list[str] | None, list[str] | None], torch.Tensor
    ] = True,
    format_weights: list[float] | None = None,
    verbose: bool = False,
    set_done_if_answer: bool = True,
):
    """Configure the scorer's keys, reward aggregation and dependencies.

    All arguments are keyword-only; see the class docstring for their
    meaning. The transform's input keys are the instruction-ids, prompt,
    keyword-args, id and response keys; its output keys are ``score_key``
    plus ``"reward"`` when ``aggregate_reward`` is truthy.

    Raises:
        ImportError: If ``langdetect``, ``nltk`` or ``immutabledict`` is not
            installed.
    """
    self.aggregate_reward = aggregate_reward
    self.score_key = score_key
    self.set_done_if_answer = set_done_if_answer
    out_keys = [self.score_key]
    if aggregate_reward:
        # Both ``True`` and a user-supplied callable enable the reward entry.
        out_keys.append("reward")
    super().__init__(
        in_keys=[
            instruction_ids_key,
            prompt_key,
            keyword_args_key,
            id_key,
            response_column,
        ],
        out_keys=out_keys,
    )
    if not _has_langdetect:
        raise ImportError("langdetect must be installed to use IFEvalScorer.")
    if not _has_nltk:
        raise ImportError("nltk must be installed to use IFEvalScorer.")
    # Scoring also needs immutabledict (checked again in _process_results);
    # fail at construction time rather than on the first scored step.
    if not _has_immutabledict:
        raise ImportError("immutabledict must be installed to use IFEvalScorer.")
    self.instruction_ids_key = instruction_ids_key
    self.response_key = response_column
    self.keyword_args_key = keyword_args_key
    self.prompt_key = prompt_key
    self.id_key = id_key
    # A fresh list is built per instance so the default is never shared.
    self.format_weights = (
        format_weights if format_weights is not None else [0.4, 0.3, 0.2, 0.1]
    )
    self.verbose = verbose
|
|
225
|
+
|
|
226
|
+
def default_reward_aggregator(
    self,
    score: IFEvalScoreData,
    think_blocks: list[str] | None = None,
    answer_blocks: list[str] | None = None,
    complete: bool | torch.Tensor | None = None,
) -> torch.Tensor:
    r"""Improved reward aggregation function with tiered multiplicative scoring.

    Args:
        score (IFEvalScoreData): The score data.
        think_blocks (list[str], optional): The list of think blocks.
        answer_blocks (list[str], optional): The list of answer blocks.
        complete (bool, optional): Whether the response is complete (ends with a eos token).

    Returns:
        torch.Tensor: the reward in the default floating-point dtype. When
        there is no answer block, a zero tensor of shape
        ``score.batch_size + (1,)`` is returned. Otherwise the shape follows
        from the keep-dim reductions over the score components.

    The reward uses a tiered multiplicative system:

    1. Critical failure check: No answer blocks = 0 reward
    2. Base format score (0-1): Weighted average of format metrics
    3. Structure multiplier (0.21-1.0): Penalties for missing/multiple blocks
    4. Quality bonus (0-0.5): Rewards for high quality and completion
    5. Task complexity scaling: More requirements = higher potential rewards

    The final formula is:
        reward = (format_score + quality_bonus) * structure_multiplier * complexity_scale

    This provides better learning signals by:
    - Requiring critical elements (answer tags) for meaningful rewards
    - Using multiplicative scaling to reward doing everything well
    - Scaling rewards based on task complexity
    - Providing clear failure modes and success incentives

    Reward range: 0.0 to ~1.5-2.7 depending on task complexity (more instructions = higher max reward).
    """
    default_dtype = torch.get_default_dtype()
    score = score.to(default_dtype)

    def _zeros() -> torch.Tensor:
        # Placeholder for a metric that is missing from the score data.
        return torch.zeros(
            score.batch_size + (1,), device=score.device, dtype=default_dtype
        )

    # Critical failure check - no answer = no reward
    if not answer_blocks:
        return _zeros()

    # Base format score calculation (0-1)
    format_components = torch.stack(
        [
            # Prompt-level metrics: single value
            score.prompt_level_strict_acc.sum(-1, keepdim=True)
            if score.prompt_level_strict_acc is not None
            else _zeros(),
            # Instruction-level metrics: average across instructions
            score.inst_level_strict_acc.mean(-1, keepdim=True)
            if score.inst_level_strict_acc is not None
            else _zeros(),
            score.prompt_level_loose_acc.sum(-1, keepdim=True)
            if score.prompt_level_loose_acc is not None
            else _zeros(),
            score.inst_level_loose_acc.mean(-1, keepdim=True)
            if score.inst_level_loose_acc is not None
            else _zeros(),
        ],
        -1,
    )
    weights = torch.tensor(
        self.format_weights,
        device=format_components.device,
        dtype=default_dtype,
    )
    format_score = (format_components * weights).sum(dim=-1, keepdim=True)

    # Structure multiplier (0.21-1.0)
    structure_multiplier = 1.0

    # Heavy penalty for missing think blocks (but not zero)
    if not think_blocks:
        structure_multiplier *= 0.3
    elif len(think_blocks) > 1:
        structure_multiplier *= 0.7  # Penalty for multiple think blocks

    # Penalty for multiple answer blocks
    if len(answer_blocks) > 1:
        structure_multiplier *= 0.7

    # Quality bonus (0-0.5)
    quality_bonus = torch.zeros_like(format_score)

    # Bonus for high quality responses. This is applied element-wise:
    # `if format_score > 0.8` raises as soon as the score holds more than
    # one element, because the truth value of such a tensor is ambiguous.
    quality_bonus = torch.where(
        format_score > 0.8, quality_bonus + 0.3, quality_bonus
    )

    # Completion bonus
    if complete is not None:
        if isinstance(complete, torch.Tensor):
            completion_bonus = complete.to(default_dtype) * 0.2
        else:
            completion_bonus = float(complete) * 0.2
        quality_bonus += completion_bonus

    # Task complexity scaling based on number of instructions
    # More instructions = higher potential rewards
    if (
        score.inst_level_strict_acc is not None
        and score.inst_level_strict_acc.numel() > 0
    ):
        num_instructions = score.inst_level_strict_acc.shape[-1]
    else:
        num_instructions = 1
    complexity_scale = (
        1.0 + (num_instructions - 1) * 0.2
    )  # 1.0 for 1 instruction, 1.2 for 2, etc.

    # Final reward: (format + quality) * structure_multiplier * complexity_scale
    final_reward = (
        (format_score + quality_bonus) * structure_multiplier * complexity_scale
    )
    final_reward = final_reward.to(default_dtype)

    return final_reward
|
|
350
|
+
|
|
351
|
+
def _step(
    self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
    """Score the last response in the history and write the results to ``next_tensordict``.

    Batched inputs are split along their first dimension, processed one
    sample at a time and lazily stacked back together. For a single sample
    the ``<think>``/``<answer>`` blocks are extracted from the response and
    the score is stored under ``self.score_key``. When reward aggregation is
    enabled, a ``"reward"`` entry is written as well. When
    ``set_done_if_answer`` is set and an answer block is present, ``"done"``
    and ``"terminated"`` are set to ``True``.

    Raises:
        ValueError: if the base environment's ``input_mode`` is not ``"history"``.
    """
    input_mode = getattr(self.parent.base_env, "input_mode", "history")
    if input_mode != "history":
        raise ValueError("IFEvalScorer only supports history input mode")

    if tensordict.ndim:
        # Recurse sample by sample; the code below only handles unbatched data.
        per_sample = [
            self._step(td, next_td)
            for td, next_td in zip(tensordict.unbind(0), next_tensordict.unbind(0))
        ]
        return lazy_stack(per_sample)

    last_turn = tensordict["history", "full"][..., -1]
    prompt = tensordict["history", "prompt"][..., -1].content
    response = last_turn.content
    complete = last_turn.is_complete
    if is_non_tensor(response):
        response = response.data

    # TODO: This should be a distinct module
    # DOTALL lets a block span several lines.
    think_blocks = re.findall(r"<think>(.*?)</think>", response, re.DOTALL)
    answer_blocks = re.findall(r"<answer>(.*?)</answer>", response, re.DOTALL)

    # Only the first answer block is scored; no answer scores the empty string.
    first_answer = answer_blocks[0] if answer_blocks else ""
    score = _process_results(
        tensordict.copy().auto_device_(),
        first_answer,
        verbose=self.verbose,
        prompt=prompt,
    )
    next_tensordict.set(self.score_key, score)

    if self.aggregate_reward:
        # A user-supplied callable takes precedence over the built-in aggregator.
        reward_func = (
            self.aggregate_reward
            if callable(self.aggregate_reward)
            else self.default_reward_aggregator
        )
        reward = reward_func(
            score,
            think_blocks=think_blocks,
            answer_blocks=answer_blocks,
            complete=complete,
        )
        reward_shape = next_tensordict.batch_size + (1, 1)
        next_tensordict.set("reward", reward.view(reward_shape))

    if self.set_done_if_answer and answer_blocks:
        flag_shape = next_tensordict.batch_size + (1,)
        for flag_key in ("done", "terminated"):
            next_tensordict.set(
                flag_key,
                torch.ones(
                    flag_shape,
                    device=next_tensordict.device,
                    dtype=torch.bool,
                ),
            )
    return next_tensordict
|
|
431
|
+
|
|
432
|
+
@property
def expected_keys(self) -> list[str]:
    """Return the configured input keys.

    The order is: instruction ids, prompt, keyword arguments, id, response.
    """
    attribute_names = (
        "instruction_ids_key",
        "prompt_key",
        "keyword_args_key",
        "id_key",
        "response_key",
    )
    return [getattr(self, attribute) for attribute in attribute_names]
|
|
441
|
+
|
|
442
|
+
def transform_reward_spec(self, reward_spec: Composite) -> Composite:
    """Add an unbounded ``"reward"`` entry to ``reward_spec`` and return it.

    The entry's shape is the spec's shape followed by two singleton
    dimensions. It uses the default floating-point dtype and the spec's device.
    """
    reward_shape = reward_spec.shape + (1, 1)
    reward_spec["reward"] = Unbounded(
        reward_shape,
        dtype=torch.get_default_dtype(),
        device=reward_spec.device,
    )
    return reward_spec
|
|
449
|
+
|
|
450
|
+
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
    """Register the score entry under ``self.score_key`` and return the spec.

    The entry is the default ``IFEvalScoreData`` spec, built with the
    observation spec's shape and device.
    """
    score_spec = IFEvalScoreData.default_spec(
        observation_spec.shape, device=observation_spec.device
    )
    observation_spec[self.score_key] = score_spec
    return observation_spec
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .browser import BrowserTransform
|
|
7
|
+
from .dataloading import (
|
|
8
|
+
as_nested_tensor,
|
|
9
|
+
as_padded_tensor,
|
|
10
|
+
DataLoadingPrimer,
|
|
11
|
+
RayDataLoadingPrimer,
|
|
12
|
+
)
|
|
13
|
+
from .format import TemplateTransform
|
|
14
|
+
from .kl import KLComputation, KLRewardTransform, RetrieveKL, RetrieveLogProb
|
|
15
|
+
from .policy_version import PolicyVersion
|
|
16
|
+
from .reason import AddThinkingPrompt
|
|
17
|
+
from .tokenizer import Tokenizer
|
|
18
|
+
from .tools import (
|
|
19
|
+
ExecuteToolsInOrder,
|
|
20
|
+
JSONCallParser,
|
|
21
|
+
MCPToolTransform,
|
|
22
|
+
PythonExecutorService,
|
|
23
|
+
PythonInterpreter,
|
|
24
|
+
SimpleToolTransform,
|
|
25
|
+
ToolCall,
|
|
26
|
+
ToolRegistry,
|
|
27
|
+
ToolService,
|
|
28
|
+
XMLBlockParser,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# Names exported by `from <package> import *`.
# Entries are kept in ASCII order: capitalized names first, then lowercase helpers.
__all__ = [
    "AddThinkingPrompt",
    "BrowserTransform",
    "DataLoadingPrimer",
    "ExecuteToolsInOrder",
    "JSONCallParser",
    "KLComputation",
    "KLRewardTransform",
    "MCPToolTransform",
    "PolicyVersion",
    "PythonExecutorService",
    "PythonInterpreter",
    "RayDataLoadingPrimer",
    "RetrieveKL",
    "RetrieveLogProb",
    "SimpleToolTransform",
    "TemplateTransform",
    "Tokenizer",
    "ToolCall",
    "ToolRegistry",
    "ToolService",
    "XMLBlockParser",
    "as_nested_tensor",
    "as_padded_tensor",
]
|