torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/libs/mlgym.py

@@ -0,0 +1,869 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import functools

import json
import os
import re
from contextlib import contextmanager
from dataclasses import asdict
from pathlib import Path
from typing import Any, Literal, TYPE_CHECKING

import numpy as np

import torch
from tensordict import NestedKey, NonTensorData, TensorDict, TensorDictBase
from tensordict.tensorclass import is_non_tensor

from torchrl._utils import logger as torchrl_logger
from torchrl.data import Choice, Composite, NonTensor
from torchrl.data.llm import History
from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv

if TYPE_CHECKING:
    import mlgym
    import transformers

# Inv transforms:
# Transforms to apply prior to pass the model output to the env


@contextmanager
def _temp_cwd_mlgym():
    """Temporarily change the current working directory to mlgym."""
    import mlgym

    path = Path(mlgym.__spec__.submodule_search_locations[0]).parent
    old_pwd = os.getcwd()
    os.chdir(str(path))
    # sys.path.insert(-1, "mlgym")
    try:
        yield
    finally:
        # sys.path.pop()
        os.chdir(old_pwd)


class MLGymBaseTransform(Transform):
    """Base class for all MLGym transforms."""

    @property
    def config(self):
        return self.parent.base_env.config

    @property
    def system_args(self):
        return {
            "command_docs": self.config.tools_handler.command_docs,
            **self.config.tools_handler.env_variables,
        }

    @property
    def task_args(self):
        # Placeholder
        task_args = getattr(self, "_task_args", None)
        if task_args is None:
            return self.parent.base_env.task.args
        return task_args

    @task_args.setter
    def task_args(self, task_args):
        self._task_args = task_args

    @property
    def name(self):
        return "torchrl"

    @property
    def state_command(self):
        return self.config.state_command.name

    @property
    def agent_args(self):
        return self.parent.base_env.agent_args

    @property
    def model_name(self) -> Literal["human", "human_thought"]:
        return self.agent_args.model.model_name


#######################################################
# Forward transforms: Format the env output


# Transform #0: Resets the env
class ResetModule(MLGymBaseTransform):
    """Runs setup pipeline and enables multi-resets.

    The reset method reads the 'system' initial input from the config and parses it to a History
    object.

    """

    response_key: NestedKey = "text_response"

    def __init__(self):
        super().__init__(in_keys=[], out_keys=["history"])

    @_temp_cwd_mlgym()
    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
        base_env = self.parent.base_env._env
        if tensordict is not None and "task" in tensordict:
            import gymnasium as gym

            task = tensordict["task"]
            torchrl_logger.info(f"Resetting with {task=}")
            if is_non_tensor(task):
                task = task.data
            task_id, agent_args = _TASK_IDS[task]
            try:
                base_env.close()
            except Exception:
                torchrl_logger.info(f"Failed to close {base_env=}")
            base_env = gym.make(
                f"mlgym/{task}",
                devices=["cpu_0"],
            ).unwrapped
            base_env.config = agent_args.config
            self.parent.base_env.set_env(base_env)
        base_env.reset_container()
        base_env.communicate(f"cd {Path(base_env.task_workspace).parent}")
        return tensordict

    @_temp_cwd_mlgym()
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # TODO: what to do with this?
        # reset model stats
        # self.model.reset_stats(init_model_stats)
        # env = self.parent.base_env._env

        env = self.parent.base_env._env
        self.set_environment_vars(env, self.config.env_variables)

        system_msg = self.config.system_template.format(
            **self.system_args, **asdict(self.task_args)
        )
        # self.logger.log(self._default_logging_level, f"SYSTEM ({self.name})\n{system_msg}")
        history = History(
            role="system",
            content=system_msg,  # agent=self.name,
            batch_size=(1,),
            device=self.parent.device,
        )
        tensordict_reset["history"] = history

        return tensordict_reset

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        # Placeholder
        if "history" not in next_tensordict:
            if "local_history" in tensordict:
                local_history = tensordict["local_history"]
            else:
                local_history = None
            history = tensordict["history"]
            if local_history is not None:
                history = history.append(local_history, inplace=False)
                tensordict["history"] = history
            next_tensordict["history"] = history
        return next_tensordict

    def set_environment_vars(
        self, env: MLGymWrapper, env_variables: dict[str, Any]
    ) -> None:
        commands_to_execute = (
            [self.config.state_command.code]
            +  # [code for code in self.config.util_functions] +
            # [command.code for command in self.config._commands] +
            [f"{k}={v}" for k, v in env_variables.items()]
        )
        commands = "\n".join(commands_to_execute)
        try:
            output = env.communicate(commands)
            if env.returncode != 0:
                msg = f"Nonzero return code: {env.returncode}\nOutput: {output}"
                raise RuntimeError(msg)
        except KeyboardInterrupt:
            raise
        except Exception as e:
            raise e
        command_files = []
        for file in self.config.command_files:
            datum = {}
            with open(file) as f:
                contents = f.read()
            datum["contents"] = contents
            filename = Path(file).name
            if not contents.strip().startswith("#!"):
                if filename.endswith(".sh"):
                    # files are sourced, so they are not executable
                    datum["name"] = Path(file).name
                    datum["type"] = "source_file"
                elif filename.startswith("_"):
                    # files are sourced, so they are not executable
                    datum["name"] = Path(file).name
                    datum["type"] = "utility"
                else:
                    msg = (
                        f"Non-shell script file {file} does not start with shebang.\n"
                        "Either add a shebang (#!) or change the file extension to .sh if you want to source it.\n"
                        "You can override this behavior by adding an underscore to the file name (e.g. _utils.py)."
                    )
                    raise ValueError(msg)
            else:
                # scripts are made executable
                datum["name"] = Path(file).name.rsplit(".", 1)[0]
                datum["type"] = "script"
            command_files.append(datum)
        # TODO: implement add commands method in environment
        env.add_commands(command_files)

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        observation_spec["history"] = History.default_spec()
        return observation_spec

    def transform_action_spec(self, action_spec: Composite) -> Composite:
        if isinstance(action_spec, Composite):
            action_spec[self.response_key] = self.transform_action_spec(
                action_spec[self.response_key]
            )
            return action_spec
        # make the "random" action just a choice between innocuous bash commands
        return Choice(
            [
                NonTensor(example_data="ls -rtlh", shape=action_spec.shape),
                NonTensor(example_data="pwd", shape=action_spec.shape),
            ]
        )

    def transform_state_spec(self, state_spec: Composite) -> Composite:
        state_spec["history"] = History.default_spec()
        return state_spec


class TaskSampler(Transform):
    """A sampler for tasks in a certain task set."""

    def __init__(self, tasks: list[str]):
        super().__init__()
        self.tasks = tasks

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        observation_spec["task"] = NonTensor(example_data="<a task>", shape=())
        return observation_spec

    @_temp_cwd_mlgym()
    def _reset_env_preprocess(
        self, tensordict: TensorDictBase | None
    ) -> TensorDictBase:
        if tensordict is None:
            tensordict = TensorDict(batch_size=self.parent.batch_size)
        # Sample a task
        task = np.random.choice(self.tasks)
        tensordict["task"] = NonTensorData(task)
        self._current_task = task
        return tensordict

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        next_tensordict["task"] = self._current_task
        return next_tensordict


# Transform #1: env -> state
class ReadState(MLGymBaseTransform):
    """Reads current state and writes it as a parsable str in the tensordict."""

    # from mlgym/agent/base.py:BaseAgent:forward_model
    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        base_mlgym_env = self.parent.base_env  # getattr is forwarded

        command = self.state_command
        state = base_mlgym_env.communicate(command) if self.state_command else None

        next_tensordict["state"] = state
        return next_tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # tensordict_reset.setdefault("message", NonTensorData(""))
        # tensordict_reset.setdefault("state", NonTensorData(""))
        return self._step(tensordict_reset, tensordict_reset)

    def transform_observation_spec(self, observation_spec):
        observation_spec.set(
            "state",
            NonTensor(
                example_data="a string",
                device=observation_spec.device,
                shape=observation_spec.shape,
            ),
        )
        return observation_spec


# Transform #2: state -> message
class StateToMessage(MLGymBaseTransform):
    """Parses the string using json to a given template.

    Requires:
    - a 'state' key from the ReadState transform
    - an 'observation' key from the base environment

    """

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        base_mlgym_env = self.parent.base_env  # getattr is forwarded
        observation = tensordict["observation"]
        state = tensordict["state"]
        config = self.config

        current_step = base_mlgym_env.current_step
        max_steps = base_mlgym_env.max_steps
        try:
            state_vars = json.loads(state)
        except json.JSONDecodeError as e:
            msg = f"State {state!r} is not valid json. This is an internal error, please report it."
            raise ValueError(msg) from e
        # add step information to state_vars
        state_vars["current_step"] = current_step
        state_vars["remaining_steps"] = max_steps - current_step

        # FIXME: we don't need to do this, we have our own observation space
        # Determine observation template based on what prior observation was

        history: History = tensordict["history"]
        if history[..., -1].role == "system":
            # Show task template if prev. obs. was initial system message
            templates = [config.task_template]
            if config.strategy_template is not None:
                templates.append(config.strategy_template)
        elif observation is None or observation.strip() == "":
            # Show no output template if observation content was empty
            assert config.next_step_no_output_template is not None  # linting
            templates = [config.next_step_no_output_template]
        else:
            # Show standard output template if there is observation content
            assert config.next_step_template is not None  # linting
            templates = [config.next_step_template]

        # Format selected template(s) with information
        messages = []
        assert self.task_args is not None
        for template in templates:
            messages.append(
                template.format(
                    **asdict(self.task_args),
                    **self.system_args,
                    **state_vars,
                    observation=(observation if observation is not None else ""),
                    # missing forwarded_vars because no attempts
                ),
            )

        message = "\n".join(messages)
        next_tensordict["message"] = message
        # model query hooks here
        return next_tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # tensordict_reset.setdefault("message", NonTensorData(""))
        # tensordict_reset.setdefault("state", NonTensorData(""))
        return self._step(tensordict_reset, tensordict_reset)

    def transform_observation_spec(self, observation_spec):
        observation_spec.set(
            "message",
            NonTensor(
                example_data="a string",
                device=observation_spec.device,
                shape=observation_spec.shape,
            ),
        )
        return observation_spec


# Transform #3: Append message to history
class MessageToHistory(MLGymBaseTransform):
    """Parses the message string to a History object, then reparses the history to a complete message.

    .. seealso:: HistoryToMessage

    """

    def __init__(self):
        super().__init__(in_keys=["message", "history"], out_keys=["history", "chat"])

    # from mlgym/agent/base.py:BaseAgent:local_history
    # from mlgym/agent/base.py:BaseAgent:_append_history
    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        # From PrepareDataForModel
        message: str = next_tensordict["message"]
        # from mlgym/agent/base.py:BaseAgent:forward_model
        history = tensordict["history"]
        cur_history = History(
            role="user", content=message, batch_size=(), device=self.parent.device
        )
        # This is the basic thing our transform does: append the history to the existing one.
        # (We should be able to extend the lazy stack directly)
        history = history.append(cur_history, inplace=False)

        next_tensordict["history"] = history
        return next_tensordict

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._step(tensordict_reset, tensordict_reset)


# Inverse transforms:
# Format the action from the model for the env


class TemplateTransform(MLGymBaseTransform):
    """A transform to apply the chat template to the History."""

    response_key: NestedKey = "text_response"
    prompt_key: NestedKey = "text"

    # alternative to DummyFormat, wip
    def __init__(
        self,
        in_keys=None,
        out_keys=None,
        in_keys_inv=None,
        out_keys_inv=None,
        tokenizer=None,
        chat_template_name: Literal["chatml_format"] | None = None,
        continue_final_message: bool = False,
        tokenize: bool = False,
        return_tensors: str = "pt",
        return_dict: bool = False,
        padding: bool | str = False,
        truncation: bool | str = False,
    ):
        super().__init__(
            in_keys=["history"] if in_keys is None else in_keys,
            out_keys=[self.prompt_key] if out_keys is None else out_keys,
            in_keys_inv=[self.prompt_key, self.response_key]
            if in_keys_inv is None
            else in_keys_inv,
            # TODO: we should not use the response key here but another dedicated entry, like "action_parsed"
            out_keys_inv=[self.response_key] if out_keys_inv is None else out_keys_inv,
        )
        self.chat_template_name = chat_template_name
        self.tokenizer = tokenizer
        self.tokenize = tokenize
        self.continue_final_message = continue_final_message
        self.return_tensors = return_tensors
        self.return_dict = return_dict
        self.padding = padding
        self.truncation = truncation

    def transform_observation_spec(self, observation_spec: Composite):
        observation_spec[self.prompt_key] = NonTensor(
            example_data="<some chat string>",
            shape=observation_spec.shape,
            device=observation_spec.device,
        )
        return observation_spec

    @property
    def _chat_template(self):
        chat_template = None
        if self.chat_template_name:
            from torchrl.data.llm.datatypes.chat import _CHAT_TEMPLATES

            chat_template = _CHAT_TEMPLATES[self.chat_template_name]
        elif self.tokenizer.chat_template is not None:
            chat_template = self.tokenizer.chat_template
        elif chat_template is None:
            raise ValueError("Failed to determine chat template.")
        return chat_template

    def _apply_transform(self, history: History) -> NonTensorData:
        if self.tokenizer is None:
            raise RuntimeError("Cannot apply chat template without a tokenizer.")
        result = history.apply_chat_template(
            tokenizer=self.tokenizer,
            add_generation_prompt=True,
            chat_template=self._chat_template,
            continue_final_message=self.continue_final_message,
            tokenize=self.tokenize,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors=self.return_tensors,
        )
        return result

    def _reset(self, tensordict, tensordict_reset):
        return self._call(tensordict_reset)

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.in_keys_inv:
            prompt = tensordict[self.prompt_key]
            response = tensordict[self.response_key]
            if isinstance(prompt, list):
                action = [
                    prompt + response for prompt, response in zip(prompt, response)
                ]
            else:
                action = prompt + response
            try:
                history, action = self._inv_apply_transform(action)
                tensordict["local_history"] = history
                tensordict[self.response_key] = action
            except RuntimeError as e:
                if "Expected assistant role" in str(e):
                    tensordict["local_history"] = History(role="assistant", content="")
                    tensordict[self.response_key] = ""
        return tensordict

    def _inv_apply_transform(self, action):
        if self.tokenize:
            action = self.tokenizer.decode(action)

        if not isinstance(action, (str, list)):
            action = action.data
            history, action = self._inv_apply_transform(action)
            action = NonTensorData(
                action, batch_size=action.batch_size, device=action.device
            )
            return history, action

        history = History.from_text(
            action,
            # chat_template=self._chat_template,
        )[..., -1]
        if history.role != "assistant":
            raise RuntimeError(f"Expected assistant role, got {history.role=}")
        action = history.get("content")
        return history, action


class IsolateCodeBlock(MLGymBaseTransform):
    """A transform that isolates the code block in the action generated by the LLM.

    Optionally, wrongly formatted actions are assigned a negative reward.
    """

    response_key: NestedKey = "text_response"

    def __init__(self, reward_wrong_format: float | None = None):
        super().__init__(
            in_keys_inv=[self.response_key], out_keys_inv=[self.response_key]
        )
        from mlgym.agent.parsing import ThoughtActionParser

        self.parser = ThoughtActionParser()
        self.reward_wrong_format = reward_wrong_format
        self._assign_reward = False

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        torchrl_logger.info("inv call with IsolateCodeBlock")
        action = tensordict[self.response_key]
        # if we didn't find an action, the action is empty
        if not action:
            torchrl_logger.info(
                "Did not find a suitable action, skipping the call to step."
            )
            tensordict["retry"] = torch.ones(tensordict.shape, dtype=torch.bool)
            self._assign_reward = True
        else:
            from mlgym.exceptions import FormatError

            try:
                action = self._inv_apply_transform(action)
                tensordict[self.response_key] = action
                torchrl_logger.info(f"Code block: {action}")
                tensordict["retry"] = torch.zeros(tensordict.shape, dtype=torch.bool)
                self._assign_reward = False
            except FormatError:
                tensordict["retry"] = torch.ones(tensordict.shape, dtype=torch.bool)
                self._assign_reward = True
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self._assign_reward:
            torchrl_logger.info(
                f"Assigning penalty for unsuitable action: {self.reward_wrong_format}"
            )
            if self.reward_wrong_format is not None:
                tensordict[self.parent.reward_key] += self.reward_wrong_format
        return tensordict

    def _inv_apply_transform(self, action):
        if not isinstance(action, (str, list)):
            return NonTensorData(
                self._inv_apply_transform(action.tolist()),
                batch_size=action.batch_size,
                device=action.device,
            )
        if isinstance(action, list):
            return [self._inv_apply_transform(action) for action in action]
        thought, action = self.parser(action, None)
        return action


class EvaluationOutputParser:
    """Parser for the reward transform in MLGym.

    .. seealso:: :class:`~torchrl.envs.llm.libs.mlgym.MLGymRewardAssignment`

    """

    def __init__(self):
        # Regular expressions to match the required fields
        self.patterns = {
            "submission_artefact_path": r"valid submission artefact at (.*)\.",
            "baseline_score": r"Baseline Score: \{'Score': (.*)\}",
            "evaluation_score": r"Evaluation Score: \{'Score': (.*)\}",
            "current_step": r"\(Current Step: (\d+),",
            "remaining_steps": r"Remaining Steps: (\d+)\)",
            "open_file": r"\(Open file: (.*)\)",
            "current_directory": r"\(Current directory: (.*)\)",
        }

    def __call__(self, output_string):

        parsed_data = {}

        for key, pattern in self.patterns.items():
            match = re.search(pattern, output_string)
            if match:
                parsed_data[key] = match.group(1).strip()
        if "baseline_score" in parsed_data:
            parsed_data["baseline_score"] = float(parsed_data["baseline_score"])

        if "evaluation_score" in parsed_data:
            parsed_data["evaluation_score"] = float(parsed_data["evaluation_score"])
        if "current_step" in parsed_data:
            parsed_data["current_step"] = int(parsed_data["current_step"])
        if "remaining_steps" in parsed_data:
            parsed_data["remaining_steps"] = int(parsed_data["remaining_steps"])

        return parsed_data


class MLGymRewardAssignment(MLGymBaseTransform):
    """Reward assignment through parsing of the last item in history.

    By default, the :class:`~torchrl.envs.llm.libs.mlgym.EvaluationOutputParser` class is used as parser.

    """

    def __init__(self):
        super().__init__(in_keys=["reward", "history"], out_keys=["reward"])
        self.parser = EvaluationOutputParser()

    def _call(self, tensordict):
        history = tensordict.get("history")
        if history is None:
            raise KeyError(f"History is missing in tensordict {tensordict}")
        if history.ndim != 1:
            raise ValueError(f"History shape must be 1D, got {history.shape}")
        content = history[-1].content
        torchrl_logger.info(f"Parsing reward from: {content}")
        parsed = self.parser(content)
        reward = parsed.get("evaluation_score", 0.0) - parsed.get("baseline_score", 0.0)
        torchrl_logger.info(f"Parsed reward: {reward}")
        tensordict["reward"] = tensordict["reward"] + reward
        return tensordict


class _add_info_to_reset:
    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs), {}


class _add_truncated_to_step:
    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    @_temp_cwd_mlgym()
    def __call__(self, *args, **kwargs):
        obs, r, done, info = self.func(*args, **kwargs)
        return obs, r, done, False, info


class MLGymWrapper(GymWrapper):
    """A thin wrapper for MLGym environments.

    This specialized :class:`~torchrl.envs.GymWrapper` subclass defines the observation space with `observation=NonTensor()`
    and the action space with `text_response=NonTensor()`, according to the :class:`~torchrl.envs.llm.ChatEnv` API.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.full_action_spec = Composite(
            text_response=NonTensor(example_data="<a string>", shape=())
        )
        self.full_observation_spec = Composite(
            observation=NonTensor(example_data="<a string>", shape=())
        )
        self.set_env()

    def set_env(self, env: Any = None):
        if env is not None:
            self._env = env
        self._patch_reset()
        self._patch_step()

    def _patch_reset(self):
        if not isinstance(self._env.reset, _add_info_to_reset):
            self._env.reset = _add_info_to_reset(self._env.reset)

    def _patch_step(self):
        if not isinstance(self._env.reset, _add_truncated_to_step):
            self._env.step = _add_truncated_to_step(self._env.step)

    @_temp_cwd_mlgym()
    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        return super()._reset(tensordict=tensordict, **kwargs)


_TASK_IDS = {}


def get_args(
    task: Literal["prisonersDilemma"] = "prisonersDilemma",
) -> tuple[
    mlgym.environment.env.EnvironmentArguments,  # noqa
    mlgym.agent.base.AgentArguments,  # noqa
]:  # noqa
    """Parse command line arguments and return a ScriptArguments object.

    Args:
        args: Optional list of arguments to parse. If not provided, uses sys.argv.
    """
    import mlgym.environment.registration  # noqa
    from mlgym import CONFIG_DIR
    from mlgym.agent.base import AgentArguments
    from mlgym.backend.base import ModelArguments
    from mlgym.environment.env import EnvironmentArguments
    from mlgym.environment.registration import register_task

    environment_args = EnvironmentArguments(
        task_config_path=f"tasks/{task}.yaml",
        max_steps=10,
        seed=42,
        container_type="podman",
        verbose=False,
        aliases_file="docker/aliases.sh",
    )

    agent_args = AgentArguments(
        # placeholder
        model=ModelArguments(""),
        # Despite using torchrl as an agent, we still need the agent config - see StateToMessage parser
        agent_config_path=CONFIG_DIR / "agents" / "default.yaml",
    )

    register_task(environment_args)

    _TASK_IDS[task] = (environment_args.task.id, agent_args)

    return environment_args, agent_args


def make_mlgym(
    *,
    task: Literal["prisonersDilemma"] | None = None,
    tasks: list[Literal["prisonersDilemma"]] | None = None,
    tokenizer: transformers.AutoTokenizer | str | None = None,  # noqa
    device="cpu",
    reward_wrong_format: float | None = None,
) -> TransformedEnv:
    """Wraps an MLGymEnv in a TorchRL Environment.

    The appended transforms will make sure that the data is formatted for the LLM during (for the outputs of `env.step`)
    and for the MLGym API (for inputs to `env.step`).

    Keyword Args:
        task (str): The task to wrap. Exclusive with `tasks` argument.

            .. note:: The correct format is simply the task name, e.g., `"prisonersDilemma"`.

        tasks (List[str]): The tasks available for the env. Exclusive with `task` argument.

            .. note:: The correct format is simply the task name, e.g., `"prisonersDilemma"`.

        tokenizer (transformers.AutoTokenizer or str, optional): A transformer that tokenizes the data.
            If a string is passed, it will be converted to a `transformers.AutoTokenizer`.
        device (str, optional): The device to set to the env. Defaults to "cpu".
        reward_wrong_format (float, optional): The reward (negative penalty) for wrongly formatted actions.
            Defaults to `None` (no penalty).

    """
    import gymnasium as gym

    if isinstance(tokenizer, str):
        import transformers

        tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)

    with _temp_cwd_mlgym():

        if task and not tasks:
            environment_args, agent_args = get_args(task=task)
        elif tasks and not task:
            for task in tasks:
                environment_args, agent_args = get_args(task=task)
        else:
            raise ValueError(
                f"Either task or tasks should be provided, not both and not none. Got {task=} and {tasks=}."
            )

        base_env = gym.make(
            f"mlgym/{_TASK_IDS[task][0]}",
            devices=["cpu_0"],
        ).unwrapped
        # we need the env to have access to the config
        base_env.config = agent_args.config
        env = TransformedEnv(
            MLGymWrapper(base_env, auto_reset=False, device=device), auto_unwrap=False
        )

        env.append_transform(ConditionalSkip(lambda td: td["retry"]))
        env.append_transform(IsolateCodeBlock(reward_wrong_format=reward_wrong_format))

        env.append_transform(ResetModule())
        if tasks:
            # Add a task sampler
            env.append_transform(TaskSampler(tasks))
        env.append_transform(ReadState())
        env.append_transform(StateToMessage())
        env.append_transform(MessageToHistory())
        env.append_transform(TemplateTransform(tokenizer=tokenizer))
        env.append_transform(MLGymRewardAssignment())
        # # We want the env to have a batch-size of (1,) because it will be easier to interact with
        # # LLMs
        # env.append_transform(BatchSizeTransform(batch_size=(1,)))
        return env
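
For orientation, here is a minimal usage sketch of the `make_mlgym` entry point defined in the file above. It is an editorial illustration, not part of the wheel: the tokenizer name, the action string, and the reward penalty value are assumptions, and running it requires the external `mlgym` package, `gymnasium`, `transformers`, and a podman container runtime.

# Hypothetical usage sketch (not shipped in the wheel); assumes mlgym, gymnasium,
# transformers and a podman runtime are installed and configured.
from torchrl.envs.llm.libs.mlgym import make_mlgym

# Build the wrapped, transformed environment for a single task.
# The tokenizer name below is an assumption; any chat-capable HF tokenizer should work.
env = make_mlgym(
    task="prisonersDilemma",
    tokenizer="Qwen/Qwen2.5-7B-Instruct",
    device="cpu",
    reward_wrong_format=-1.0,
)

# reset() returns a TensorDict carrying "history", "message" and the templated "text" prompt.
td = env.reset()

# An LLM policy would fill in "text_response"; the thought/code-block layout below is an
# assumption about the format expected by mlgym's ThoughtActionParser.
td["text_response"] = "Let me inspect the workspace first.\n```\nls -rtlh\n```"
td = env.step(td)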