torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Browser automation transform for LLM agents."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
from typing import Any
|
|
12
|
+
from urllib.parse import urlparse
|
|
13
|
+
|
|
14
|
+
from tensordict import TensorDictBase
|
|
15
|
+
|
|
16
|
+
from torchrl.envs.llm.transforms.tools import SimpleToolTransform
|
|
17
|
+
|
|
18
|
+
# Schema for the browser tool.
#
# Describes the arguments accepted by ``BrowserTransform``'s ``browser`` tool.
# Only ``action`` is always required; the ``allOf`` clauses make the other
# fields conditionally required depending on the chosen action. The ``action``
# enum lists exactly the actions the transform dispatches (navigate, click,
# type, extract, scroll), so the model is never offered an action that would
# always fail with "Unknown action".
BROWSER_SCHEMA = {
    "name": "browser",
    "description": "Browse and interact with web pages",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "navigate",
                    "click",
                    "type",
                    "extract",
                    "scroll",
                ],
                "description": "The action to perform",
            },
            "url": {
                "type": "string",
                "description": "URL to navigate to (for navigate action)",
            },
            "selector": {
                "type": "string",
                "description": "CSS selector to target element (for click/type/extract actions)",
            },
            "text": {
                "type": "string",
                "description": "Text to type (for type action)",
            },
            "scroll_amount": {
                "type": "integer",
                "description": "Amount to scroll in pixels (for scroll action)",
            },
            "extract_type": {
                "type": "string",
                "enum": ["text", "html", "attribute"],
                "description": "What to extract from the element (for extract action)",
            },
            "attribute": {
                "type": "string",
                "description": "Attribute name to extract (for extract action with extract_type=attribute)",
            },
        },
        "required": ["action"],
        "allOf": [
            {
                "if": {"properties": {"action": {"const": "navigate"}}},
                "then": {"required": ["url"]},
            },
            {
                "if": {
                    "properties": {
                        "action": {"enum": ["click", "type", "extract"]},
                    }
                },
                "then": {"required": ["selector"]},
            },
            {
                "if": {"properties": {"action": {"const": "type"}}},
                "then": {"required": ["text"]},
            },
            {
                "if": {"properties": {"action": {"const": "scroll"}}},
                "then": {"required": ["scroll_amount"]},
            },
            {
                "if": {"properties": {"action": {"const": "extract"}}},
                "then": {"required": ["extract_type"]},
            },
            {
                "if": {
                    "properties": {
                        "action": {"const": "extract"},
                        "extract_type": {"const": "attribute"},
                    }
                },
                "then": {"required": ["attribute"]},
            },
        ],
    },
}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class BrowserTransform(SimpleToolTransform):
    """A transform that enables web browsing capabilities.

    This transform allows LLM agents to interact with web pages through a browser,
    supporting actions like navigation, clicking, typing, and extracting content.

    For a complete example of how to use this transform, see the LLM Tools tutorial in the documentation.

    A Chromium browser driven by Playwright is started lazily on the first action.
    It is shut down (together with the Playwright driver) by :meth:`close`, on
    reset, and when the transform is garbage collected. All Playwright coroutines
    run on an event loop created and owned by this transform. As a result, the
    synchronous tool entry point does not depend on the caller having a current
    event loop.

    Args:
        allowed_domains (list[str], optional): List of allowed domains. If None, all domains are allowed.
            A URL's host matches an entry when it is equal to it or is a subdomain
            of it (``"docs.example.com"`` matches ``"example.com"``;
            ``"notexample.com"`` does not).
        headless (bool): Whether to run browser in headless mode. Defaults to True.
        timeout (float): Timeout for browser operations in seconds. Defaults to 30.0.
        tokenizer: The tokenizer to use. Defaults to None.
        tool_name (str): The name of the tool in chat history. Defaults to "tool".
    """

    # Arguments each action needs before it can be dispatched. They are checked
    # up front so that a malformed tool call produces an error result (like every
    # other failure) instead of raising ``KeyError``.
    _REQUIRED_PARAMS = {
        "navigate": ("url",),
        "click": ("selector",),
        "type": ("selector", "text"),
        "scroll": ("scroll_amount",),
        "extract": ("selector", "extract_type"),
    }

    def __init__(
        self,
        allowed_domains: list[str] | None = None,
        headless: bool = True,
        timeout: float = 30.0,
        tokenizer=None,  # type: ignore
        tool_name: str = "tool",
    ):
        self.allowed_domains = allowed_domains
        self.headless = headless
        self.browser = None
        self.context = None
        self.page = None
        # Handle returned by ``async_playwright().start()``. It is kept so the
        # driver can be stopped in ``close`` instead of leaking.
        self._playwright = None
        # ``asyncio.get_event_loop()`` is deprecated when no loop is current and
        # raises RuntimeError on Python 3.14, so use a dedicated private loop.
        self.loop = asyncio.new_event_loop()

        super().__init__(
            tools={"browser": self._execute_browser_action},
            tool_schemas={"browser": BROWSER_SCHEMA},
            tool_name=tool_name,
            timeout=timeout,
        )
        self.tokenizer = tokenizer

    async def _init_browser(self):
        """Start Playwright, the browser, a context and a page if not already running.

        Each stage is created only if it is missing. After a partial failure
        (e.g. the launch raised), the next call therefore resumes where it
        stopped, rather than leaving ``page`` permanently ``None``.
        """
        from playwright.async_api import async_playwright

        if self.page is not None:
            return
        if self._playwright is None:
            self._playwright = await async_playwright().start()
        if self.browser is None:
            self.browser = await self._playwright.chromium.launch(
                headless=self.headless
            )
        if self.context is None:
            self.context = await self.browser.new_context()
        self.page = await self.context.new_page()

    def _validate_url(self, url: str) -> bool:
        """Validate if the URL is allowed based on domain restrictions.

        Returns True when no restriction is configured. Otherwise the URL's
        hostname (lower-cased, without port or credentials) must equal an
        allowed domain or be a subdomain of one. Unparseable URLs and URLs
        without a host are rejected.
        """
        if not self.allowed_domains:
            return True

        try:
            host = urlparse(url).hostname
        except Exception:
            # Fail closed: anything we cannot parse is not allowed.
            return False
        if not host:
            return False
        host = host.rstrip(".")

        for allowed in self.allowed_domains:
            allowed = allowed.lower().strip(".")
            # Compare on label boundaries: a plain ``endswith`` would let
            # "notexample.com" through for "example.com".
            if allowed and (host == allowed or host.endswith("." + allowed)):
                return True
        return False

    async def _navigate(self, url: str) -> dict[str, Any]:
        """Navigate to a URL and report the final URL and HTTP status."""
        if not self._validate_url(url):
            return {
                "success": False,
                "error": f"Domain not allowed. Must be one of: {self.allowed_domains}",
            }

        try:
            await self._init_browser()
            response = await self.page.goto(url, wait_until="networkidle")
            return {
                "success": True,
                "result": {
                    "url": self.page.url,
                    "status": response.status if response else None,
                },
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def _click(self, selector: str) -> dict[str, Any]:
        """Click an element on the page."""
        try:
            await self._init_browser()
            await self.page.click(selector)
            return {"success": True, "result": {"clicked": selector}}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def _type(self, selector: str, text: str) -> dict[str, Any]:
        """Type text into an element."""
        try:
            await self._init_browser()
            await self.page.fill(selector, text)
            return {"success": True, "result": {"typed": text, "into": selector}}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def _scroll(self, amount: int) -> dict[str, Any]:
        """Scroll the page vertically by ``amount`` pixels."""
        try:
            await self._init_browser()
            # ``amount`` comes from model output. It is passed as an argument to
            # the page function rather than formatted into the script source,
            # so a crafted value cannot inject JavaScript.
            await self.page.evaluate("(amount) => window.scrollBy(0, amount)", amount)
            return {"success": True, "result": {"scrolled": amount}}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def _extract(
        self, selector: str, extract_type: str, attribute: str | None = None
    ) -> dict[str, Any]:
        """Extract text, inner HTML or an attribute value from the first matching element.

        ``attribute`` must be provided when ``extract_type == "attribute"``.
        """
        try:
            await self._init_browser()
            element = await self.page.wait_for_selector(selector)
            if not element:
                return {"success": False, "error": f"Element not found: {selector}"}

            if extract_type == "text":
                content = await element.text_content()
            elif extract_type == "html":
                content = await element.inner_html()
            elif extract_type == "attribute" and attribute:
                content = await element.get_attribute(attribute)
            else:
                return {"success": False, "error": "Invalid extraction type"}

            return {
                "success": True,
                "result": {"content": content, "type": extract_type},
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def _execute_browser_action_async(self, **kwargs) -> dict[str, Any]:
        """Execute a browser action asynchronously.

        Always returns a dict with ``"success"`` plus either ``"result"`` or
        ``"error"``. Unknown actions and missing required arguments are reported
        as errors instead of being raised.
        """
        action = kwargs.pop("action", None)
        required = (
            self._REQUIRED_PARAMS.get(action) if isinstance(action, str) else None
        )
        if required is None:
            return {"success": False, "error": f"Unknown action: {action}"}

        missing = [name for name in required if name not in kwargs]
        if missing:
            return {
                "success": False,
                "error": f"Missing required parameter(s) for action '{action}': {', '.join(missing)}",
            }

        if action == "navigate":
            return await self._navigate(kwargs["url"])
        if action == "click":
            return await self._click(kwargs["selector"])
        if action == "type":
            return await self._type(kwargs["selector"], kwargs["text"])
        if action == "scroll":
            return await self._scroll(kwargs["scroll_amount"])
        # Only "extract" remains after the _REQUIRED_PARAMS lookup above.
        return await self._extract(
            kwargs["selector"],
            kwargs["extract_type"],
            kwargs.get("attribute"),
        )

    def _execute_browser_action(self, **kwargs) -> dict[str, Any]:
        """Execute a browser action synchronously on the transform's event loop."""
        return self.loop.run_until_complete(
            self._execute_browser_action_async(**kwargs)
        )

    async def _close_async(self) -> None:
        """Close the browser and stop the Playwright driver."""
        browser, playwright = self.browser, self._playwright
        # Drop the references first so the transform is left in a clean,
        # re-initialisable state even if a shutdown call below fails.
        self.browser = None
        self.context = None
        self.page = None
        self._playwright = None
        try:
            if browser is not None:
                await browser.close()
        finally:
            if playwright is not None:
                await playwright.stop()

    def close(self):
        """Close the browser and clean up resources.

        Safe to call repeatedly. The next action starts a fresh browser.
        """
        if (
            getattr(self, "browser", None) is None
            and getattr(self, "_playwright", None) is None
        ):
            return
        self.loop.run_until_complete(self._close_async())

    def __del__(self):
        """Ensure browser is closed on deletion."""
        # Best effort only. Attributes may be missing if ``__init__`` failed
        # part-way, and during interpreter shutdown the driver may already be
        # gone. Exceptions raised from ``__del__`` cannot reach any caller.
        try:
            self.close()
        except Exception:
            pass
        loop = getattr(self, "loop", None)
        if loop is not None and not loop.is_closed() and not loop.is_running():
            loop.close()

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Reset the browser state."""
        # Close the browser on reset; it is re-created lazily by the next action.
        self.close()
        return tensordict_reset

    def clone(self):
        """Clone the browser transform."""
        return self.__class__(
            allowed_domains=self.allowed_domains,
            headless=self.headless,
            timeout=self.timeout,
            tokenizer=self.tokenizer,
            tool_name=self.tool_name,
        )
|