rlinf 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rlinf/__init__.py +17 -0
- rlinf/agents/__init__.py +13 -0
- rlinf/agents/mas_search/__init__.py +18 -0
- rlinf/agents/mas_search/mas_search_agent_loop.py +222 -0
- rlinf/agents/rstar2/__init__.py +13 -0
- rlinf/agents/rstar2/http_code_judge_tool.py +331 -0
- rlinf/agents/rstar2/http_tool_worker.py +176 -0
- rlinf/agents/rstar2/rstar2_agent_loop.py +324 -0
- rlinf/agents/searchr1/__init__.py +13 -0
- rlinf/agents/searchr1/eval_runner.py +260 -0
- rlinf/agents/searchr1/search_tool_worker.py +157 -0
- rlinf/agents/searchr1/searchr1_agent_loop.py +189 -0
- rlinf/agents/wideseek_r1/__init__.py +13 -0
- rlinf/agents/wideseek_r1/eval_runner.py +609 -0
- rlinf/agents/wideseek_r1/tools.py +785 -0
- rlinf/agents/wideseek_r1/utils/__init__.py +13 -0
- rlinf/agents/wideseek_r1/utils/metrics.py +188 -0
- rlinf/agents/wideseek_r1/utils/prompt.py +529 -0
- rlinf/agents/wideseek_r1/utils/prompt_utils.py +261 -0
- rlinf/agents/wideseek_r1/utils/reward.py +671 -0
- rlinf/agents/wideseek_r1/utils/sglang_client.py +144 -0
- rlinf/agents/wideseek_r1/utils/tool_description.py +285 -0
- rlinf/agents/wideseek_r1/utils/webpage.py +192 -0
- rlinf/agents/wideseek_r1/wideseek_r1.py +843 -0
- rlinf/algorithms/__init__.py +14 -0
- rlinf/algorithms/advantages.py +350 -0
- rlinf/algorithms/loss_scales.py +182 -0
- rlinf/algorithms/losses.py +461 -0
- rlinf/algorithms/registry.py +156 -0
- rlinf/algorithms/rewards/__init__.py +38 -0
- rlinf/algorithms/rewards/code/__init__.py +33 -0
- rlinf/algorithms/rewards/code/code_verifier/__init__.py +13 -0
- rlinf/algorithms/rewards/code/code_verifier/verify.py +230 -0
- rlinf/algorithms/rewards/math/__init__.py +42 -0
- rlinf/algorithms/rewards/math/math_verifier/__init__.py +13 -0
- rlinf/algorithms/rewards/math/math_verifier/parser.py +441 -0
- rlinf/algorithms/rewards/math/math_verifier/verify.py +441 -0
- rlinf/algorithms/rewards/rstar2/__init__.py +182 -0
- rlinf/algorithms/rewards/rstar2/fused_compute_score/__init__.py +13 -0
- rlinf/algorithms/rewards/rstar2/fused_compute_score/compute_score.py +37 -0
- rlinf/algorithms/rewards/rstar2/fused_compute_score/math_verify.py +43 -0
- rlinf/algorithms/rewards/rstar2/fused_compute_score/prime_math/__init__.py +440 -0
- rlinf/algorithms/rewards/rstar2/fused_compute_score/prime_math/grader.py +545 -0
- rlinf/algorithms/rewards/rstar2/fused_compute_score/prime_math/math_normalize.py +192 -0
- rlinf/algorithms/rewards/searchr1/__init__.py +181 -0
- rlinf/algorithms/rewards/vqa/__init__.py +60 -0
- rlinf/algorithms/rewards/vqa/format_rewards.py +66 -0
- rlinf/algorithms/rewards/vqa/qa_rewards.py +109 -0
- rlinf/algorithms/toolcall_parsers.py +297 -0
- rlinf/algorithms/utils.py +398 -0
- rlinf/config.py +1358 -0
- rlinf/data/__init__.py +13 -0
- rlinf/data/datasets/__init__.py +267 -0
- rlinf/data/datasets/item.py +74 -0
- rlinf/data/datasets/reasoning.py +250 -0
- rlinf/data/datasets/rstar2.py +100 -0
- rlinf/data/datasets/vlm.py +684 -0
- rlinf/data/datasets/wideseek_r1.py +139 -0
- rlinf/data/datasets/world_model.py +375 -0
- rlinf/data/embodied_buffer_dataset.py +287 -0
- rlinf/data/embodied_io_struct.py +660 -0
- rlinf/data/io_struct.py +1827 -0
- rlinf/data/lerobot_writer.py +1065 -0
- rlinf/data/replay_buffer.py +1169 -0
- rlinf/data/tokenizers.py +71 -0
- rlinf/data/tool_call/__init__.py +13 -0
- rlinf/data/tool_call/tool_io_struct.py +122 -0
- rlinf/data/utils.py +53 -0
- rlinf/envs/__init__.py +117 -0
- rlinf/envs/action_utils.py +246 -0
- rlinf/envs/behavior/__init__.py +13 -0
- rlinf/envs/behavior/behavior_env.py +312 -0
- rlinf/envs/calvin/__init__.py +129 -0
- rlinf/envs/calvin/calvin_gym_env.py +486 -0
- rlinf/envs/calvin/utils.py +75 -0
- rlinf/envs/calvin/venv.py +264 -0
- rlinf/envs/frankasim/__init__.py +19 -0
- rlinf/envs/frankasim/frankasim_env.py +722 -0
- rlinf/envs/habitat/__init__.py +17 -0
- rlinf/envs/habitat/extensions/__init__.py +13 -0
- rlinf/envs/habitat/extensions/config/vlnce_r2r.yaml +68 -0
- rlinf/envs/habitat/extensions/maps.py +357 -0
- rlinf/envs/habitat/extensions/utils.py +711 -0
- rlinf/envs/habitat/habitat_env.py +348 -0
- rlinf/envs/habitat/venv.py +246 -0
- rlinf/envs/isaaclab/__init__.py +21 -0
- rlinf/envs/isaaclab/isaaclab_env.py +264 -0
- rlinf/envs/isaaclab/tasks/__init__.py +13 -0
- rlinf/envs/isaaclab/tasks/stack_cube.py +97 -0
- rlinf/envs/isaaclab/utils.py +62 -0
- rlinf/envs/isaaclab/venv.py +118 -0
- rlinf/envs/libero/__init__.py +13 -0
- rlinf/envs/libero/libero_env.py +462 -0
- rlinf/envs/libero/utils.py +125 -0
- rlinf/envs/libero/venv.py +173 -0
- rlinf/envs/maniskill/__init__.py +33 -0
- rlinf/envs/maniskill/maniskill_env.py +422 -0
- rlinf/envs/maniskill/maniskill_offload_env.py +449 -0
- rlinf/envs/maniskill/tasks/__init__.py +13 -0
- rlinf/envs/maniskill/tasks/panda_put_on_in_scene_multi.py +1380 -0
- rlinf/envs/maniskill/tasks/panda_table_agent.py +136 -0
- rlinf/envs/maniskill/tasks/pose_utils.py +311 -0
- rlinf/envs/maniskill/tasks/put_carrot_on_plate.py +151 -0
- rlinf/envs/maniskill/tasks/put_on_in_scene_multi.py +963 -0
- rlinf/envs/maniskill/tasks/variants/__init__.py +51 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_carrot.py +77 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_ee_pose.py +289 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_image.py +105 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_instruct.py +123 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_multi_carrot.py +352 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_multi_plate.py +399 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_plate.py +97 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_position.py +231 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_position_change.py +179 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_single.py +110 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_image.py +49 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_texture.py +325 -0
- rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_whole.py +364 -0
- rlinf/envs/maniskill/tasks/variants/utils.py +31 -0
- rlinf/envs/maniskill/utils.py +66 -0
- rlinf/envs/metaworld/__init__.py +61 -0
- rlinf/envs/metaworld/metaworld_env.py +442 -0
- rlinf/envs/metaworld/utils.py +21 -0
- rlinf/envs/metaworld/venv.py +170 -0
- rlinf/envs/realworld/__init__.py +33 -0
- rlinf/envs/realworld/common/camera/__init__.py +17 -0
- rlinf/envs/realworld/common/camera/camera.py +143 -0
- rlinf/envs/realworld/common/keyboard/__init__.py +13 -0
- rlinf/envs/realworld/common/keyboard/keyboard_listener.py +42 -0
- rlinf/envs/realworld/common/ros/__init__.py +17 -0
- rlinf/envs/realworld/common/ros/ros_controller.py +129 -0
- rlinf/envs/realworld/common/spacemouse/__init__.py +13 -0
- rlinf/envs/realworld/common/spacemouse/spacemouse_expert.py +74 -0
- rlinf/envs/realworld/common/video_player/__init__.py +17 -0
- rlinf/envs/realworld/common/video_player/video_player.py +55 -0
- rlinf/envs/realworld/common/wrappers/__init__.py +31 -0
- rlinf/envs/realworld/common/wrappers/euler_obs.py +39 -0
- rlinf/envs/realworld/common/wrappers/gripper_close.py +41 -0
- rlinf/envs/realworld/common/wrappers/relative_frame.py +141 -0
- rlinf/envs/realworld/common/wrappers/reward_done_wrapper.py +105 -0
- rlinf/envs/realworld/common/wrappers/spacemouse_intervention.py +72 -0
- rlinf/envs/realworld/franka/__init__.py +18 -0
- rlinf/envs/realworld/franka/franka_controller.py +375 -0
- rlinf/envs/realworld/franka/franka_env.py +573 -0
- rlinf/envs/realworld/franka/franka_robot_state.py +48 -0
- rlinf/envs/realworld/franka/tasks/__init__.py +35 -0
- rlinf/envs/realworld/franka/tasks/bottle.py +136 -0
- rlinf/envs/realworld/franka/tasks/franka_bin_relocation.py +240 -0
- rlinf/envs/realworld/franka/tasks/peg_insertion_env.py +129 -0
- rlinf/envs/realworld/franka/utils.py +105 -0
- rlinf/envs/realworld/realworld_env.py +395 -0
- rlinf/envs/realworld/venv.py +319 -0
- rlinf/envs/realworld/xsquare/__init__.py +18 -0
- rlinf/envs/realworld/xsquare/tasks/__init__.py +24 -0
- rlinf/envs/realworld/xsquare/tasks/button_env.py +79 -0
- rlinf/envs/realworld/xsquare/turtle2_env.py +567 -0
- rlinf/envs/realworld/xsquare/turtle2_robot_state.py +49 -0
- rlinf/envs/realworld/xsquare/turtle2_smooth_controller.py +264 -0
- rlinf/envs/robocasa/__init__.py +17 -0
- rlinf/envs/robocasa/robocasa_env.py +509 -0
- rlinf/envs/robocasa/utils.py +178 -0
- rlinf/envs/robocasa/venv.py +163 -0
- rlinf/envs/robotwin/__init__.py +13 -0
- rlinf/envs/robotwin/robotwin_env.py +506 -0
- rlinf/envs/utils.py +317 -0
- rlinf/envs/venv/__init__.py +33 -0
- rlinf/envs/venv/venv.py +985 -0
- rlinf/envs/world_model/__init__.py +13 -0
- rlinf/envs/world_model/base_world_env.py +158 -0
- rlinf/envs/world_model/world_model_opensora_env.py +917 -0
- rlinf/envs/world_model/world_model_wan_env.py +833 -0
- rlinf/envs/wrappers/__init__.py +18 -0
- rlinf/envs/wrappers/collect_episode.py +642 -0
- rlinf/envs/wrappers/record_video.py +457 -0
- rlinf/hybrid_engines/__init__.py +13 -0
- rlinf/hybrid_engines/fsdp/__init__.py +49 -0
- rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +625 -0
- rlinf/hybrid_engines/fsdp/strategy/__init__.py +13 -0
- rlinf/hybrid_engines/fsdp/strategy/base.py +547 -0
- rlinf/hybrid_engines/fsdp/strategy/checkpoint.py +132 -0
- rlinf/hybrid_engines/fsdp/strategy/fsdp.py +351 -0
- rlinf/hybrid_engines/fsdp/strategy/fsdp2.py +203 -0
- rlinf/hybrid_engines/fsdp/utils.py +1014 -0
- rlinf/hybrid_engines/megatron/__init__.py +13 -0
- rlinf/hybrid_engines/megatron/megatron_model_manager.py +842 -0
- rlinf/hybrid_engines/megatron/token_dispatcher.py +600 -0
- rlinf/hybrid_engines/megatron/utils.py +240 -0
- rlinf/hybrid_engines/sglang/common/__init__.py +13 -0
- rlinf/hybrid_engines/sglang/common/detokenizer_manager.py +61 -0
- rlinf/hybrid_engines/sglang/common/io_struct.py +52 -0
- rlinf/hybrid_engines/sglang/common/sgl_engine.py +138 -0
- rlinf/hybrid_engines/sglang/common/sgl_scheduler.py +592 -0
- rlinf/hybrid_engines/sglang/common/tokenizer_manager.py +240 -0
- rlinf/hybrid_engines/vllm/vllm_0_8_5/__init__.py +13 -0
- rlinf/hybrid_engines/vllm/vllm_0_8_5/executor.py +313 -0
- rlinf/hybrid_engines/vllm/vllm_0_8_5/weight_loader.py +43 -0
- rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +161 -0
- rlinf/models/__init__.py +96 -0
- rlinf/models/embodiment/__init__.py +13 -0
- rlinf/models/embodiment/base_policy.py +93 -0
- rlinf/models/embodiment/cnn_policy/__init__.py +26 -0
- rlinf/models/embodiment/cnn_policy/cnn_policy.py +621 -0
- rlinf/models/embodiment/dexbotic_pi/__init__.py +120 -0
- rlinf/models/embodiment/dexbotic_pi/dexbotic_pi_policy.py +773 -0
- rlinf/models/embodiment/flow_policy/__init__.py +51 -0
- rlinf/models/embodiment/flow_policy/flow_policy.py +635 -0
- rlinf/models/embodiment/gr00t/__init__.py +82 -0
- rlinf/models/embodiment/gr00t/embodiment_tags.py +59 -0
- rlinf/models/embodiment/gr00t/gr00t_action_model.py +732 -0
- rlinf/models/embodiment/gr00t/modality_config.py +177 -0
- rlinf/models/embodiment/gr00t/simulation_io.py +198 -0
- rlinf/models/embodiment/gr00t/utils.py +127 -0
- rlinf/models/embodiment/mlp_policy/__init__.py +31 -0
- rlinf/models/embodiment/mlp_policy/mlp_policy.py +403 -0
- rlinf/models/embodiment/modules/__init__.py +13 -0
- rlinf/models/embodiment/modules/batch_renorm.py +132 -0
- rlinf/models/embodiment/modules/compact_encoders.py +356 -0
- rlinf/models/embodiment/modules/entropy_tunning.py +69 -0
- rlinf/models/embodiment/modules/explore_noise_net.py +168 -0
- rlinf/models/embodiment/modules/flow_actor.py +469 -0
- rlinf/models/embodiment/modules/gaussian_policy.py +317 -0
- rlinf/models/embodiment/modules/mlp.py +95 -0
- rlinf/models/embodiment/modules/q_head.py +328 -0
- rlinf/models/embodiment/modules/resnet_utils.py +158 -0
- rlinf/models/embodiment/modules/utils.py +65 -0
- rlinf/models/embodiment/modules/value_head.py +67 -0
- rlinf/models/embodiment/openpi/__init__.py +121 -0
- rlinf/models/embodiment/openpi/dataconfig/__init__.py +373 -0
- rlinf/models/embodiment/openpi/dataconfig/behavior_dataconfig.py +102 -0
- rlinf/models/embodiment/openpi/dataconfig/calvin_dataconfig.py +71 -0
- rlinf/models/embodiment/openpi/dataconfig/franka_co_training_dataconfig.py +102 -0
- rlinf/models/embodiment/openpi/dataconfig/franka_dataconfig.py +101 -0
- rlinf/models/embodiment/openpi/dataconfig/gsenv_dataconfig.py +58 -0
- rlinf/models/embodiment/openpi/dataconfig/libero_dataconfig.py +102 -0
- rlinf/models/embodiment/openpi/dataconfig/maniskill_dataconfig.py +103 -0
- rlinf/models/embodiment/openpi/dataconfig/metaworld_dataconfig.py +68 -0
- rlinf/models/embodiment/openpi/dataconfig/robocasa_dataconfig.py +75 -0
- rlinf/models/embodiment/openpi/dataconfig/robotwin_aloha_dataconfig.py +106 -0
- rlinf/models/embodiment/openpi/openpi_action_model.py +1185 -0
- rlinf/models/embodiment/openpi/policies/__init__.py +13 -0
- rlinf/models/embodiment/openpi/policies/aloha_policy.py +241 -0
- rlinf/models/embodiment/openpi/policies/behavior_policy.py +119 -0
- rlinf/models/embodiment/openpi/policies/calvin_policy.py +87 -0
- rlinf/models/embodiment/openpi/policies/franka_policy.py +137 -0
- rlinf/models/embodiment/openpi/policies/gsenv_policy.py +74 -0
- rlinf/models/embodiment/openpi/policies/libero_policy.py +118 -0
- rlinf/models/embodiment/openpi/policies/maniskill_policy.py +114 -0
- rlinf/models/embodiment/openpi/policies/metaworld_policy.py +75 -0
- rlinf/models/embodiment/openpi/policies/robocasa_policy.py +126 -0
- rlinf/models/embodiment/openvla/__init__.py +104 -0
- rlinf/models/embodiment/openvla/openvla_action_model.py +808 -0
- rlinf/models/embodiment/openvla_oft/__init__.py +32 -0
- rlinf/models/embodiment/openvla_oft/official/__init__.py +118 -0
- rlinf/models/embodiment/openvla_oft/official/openvla_oft_action_model.py +703 -0
- rlinf/models/embodiment/openvla_oft/openvla_utils.py +193 -0
- rlinf/models/embodiment/openvla_oft/rlinf/__init__.py +109 -0
- rlinf/models/embodiment/openvla_oft/rlinf/openvla_oft_action_model.py +574 -0
- rlinf/models/embodiment/prismatic/__init__.py +13 -0
- rlinf/models/embodiment/prismatic/processing_prismatic.py +243 -0
- rlinf/runners/__init__.py +13 -0
- rlinf/runners/agent_eval_runner.py +248 -0
- rlinf/runners/agent_runner.py +326 -0
- rlinf/runners/async_embodied_runner.py +274 -0
- rlinf/runners/async_ppo_embodied_runner.py +260 -0
- rlinf/runners/coding_online_rl_runner.py +308 -0
- rlinf/runners/embodied_eval_runner.py +80 -0
- rlinf/runners/embodied_runner.py +439 -0
- rlinf/runners/reasoning_eval_runner.py +179 -0
- rlinf/runners/reasoning_runner.py +645 -0
- rlinf/runners/sft_runner.py +168 -0
- rlinf/scheduler/__init__.py +56 -0
- rlinf/scheduler/channel/__init__.py +18 -0
- rlinf/scheduler/channel/channel.py +648 -0
- rlinf/scheduler/channel/channel_worker.py +536 -0
- rlinf/scheduler/cluster/__init__.py +30 -0
- rlinf/scheduler/cluster/cluster.py +525 -0
- rlinf/scheduler/cluster/config.py +442 -0
- rlinf/scheduler/cluster/node.py +554 -0
- rlinf/scheduler/cluster/utils.py +604 -0
- rlinf/scheduler/collective/__init__.py +36 -0
- rlinf/scheduler/collective/async_work.py +386 -0
- rlinf/scheduler/collective/collective.py +96 -0
- rlinf/scheduler/collective/collective_group.py +1827 -0
- rlinf/scheduler/collective/multi_channel_pg.py +927 -0
- rlinf/scheduler/dynamic_scheduler/__init__.py +13 -0
- rlinf/scheduler/dynamic_scheduler/manager.py +1069 -0
- rlinf/scheduler/dynamic_scheduler/scheduler_worker.py +129 -0
- rlinf/scheduler/dynamic_scheduler/utils.py +162 -0
- rlinf/scheduler/hardware/__init__.py +38 -0
- rlinf/scheduler/hardware/accelerators/__init__.py +31 -0
- rlinf/scheduler/hardware/accelerators/accelerator.py +300 -0
- rlinf/scheduler/hardware/accelerators/amd_gpu.py +143 -0
- rlinf/scheduler/hardware/accelerators/ascend_npu.py +113 -0
- rlinf/scheduler/hardware/accelerators/intel_gpu.py +113 -0
- rlinf/scheduler/hardware/accelerators/musa_gpu.py +147 -0
- rlinf/scheduler/hardware/accelerators/nvidia_gpu.py +203 -0
- rlinf/scheduler/hardware/hardware.py +180 -0
- rlinf/scheduler/hardware/robots/__init__.py +18 -0
- rlinf/scheduler/hardware/robots/franka.py +176 -0
- rlinf/scheduler/hardware/robots/xsquare.py +86 -0
- rlinf/scheduler/manager/__init__.py +32 -0
- rlinf/scheduler/manager/coll_manager.py +140 -0
- rlinf/scheduler/manager/lock_manager.py +187 -0
- rlinf/scheduler/manager/manager.py +123 -0
- rlinf/scheduler/manager/node_manager.py +46 -0
- rlinf/scheduler/manager/worker_manager.py +318 -0
- rlinf/scheduler/placement/__init__.py +27 -0
- rlinf/scheduler/placement/flexible.py +277 -0
- rlinf/scheduler/placement/node.py +205 -0
- rlinf/scheduler/placement/packed.py +335 -0
- rlinf/scheduler/placement/placement.py +674 -0
- rlinf/scheduler/worker/__init__.py +24 -0
- rlinf/scheduler/worker/lock.py +103 -0
- rlinf/scheduler/worker/worker.py +1250 -0
- rlinf/scheduler/worker/worker_group.py +556 -0
- rlinf/utils/__init__.py +13 -0
- rlinf/utils/ckpt_convertor/__init__.py +13 -0
- rlinf/utils/ckpt_convertor/convert_openpi_jax_to_python.py +706 -0
- rlinf/utils/ckpt_convertor/fsdp_convertor/__init__.py +13 -0
- rlinf/utils/ckpt_convertor/fsdp_convertor/config/fsdp_model_convertor.yaml +27 -0
- rlinf/utils/ckpt_convertor/fsdp_convertor/convert_dcp_to_pt.py +58 -0
- rlinf/utils/ckpt_convertor/fsdp_convertor/convert_pt_to_hf.py +81 -0
- rlinf/utils/ckpt_convertor/fsdp_convertor/utils.py +197 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/__init__.py +13 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/config.py +208 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/convert_hf_to_mg.py +410 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/convert_hf_to_middle_file.py +503 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/convert_mg_to_middle_file.py +863 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/convert_middle_file_to_hf.py +726 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/convert_middle_file_to_mg.py +626 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/default_args.yaml +151 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/__init__.py +31 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/fp8_utils.py +135 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/mg_loader.py +171 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/mg_moe_groupgemm.py +198 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/mp_utils.py +61 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/safetensors_loader.py +116 -0
- rlinf/utils/ckpt_convertor/megatron_convertor/utils/tensor_operations.py +402 -0
- rlinf/utils/comm_mapping.py +91 -0
- rlinf/utils/convertor/__init__.py +13 -0
- rlinf/utils/convertor/utils.py +637 -0
- rlinf/utils/cuda_graph.py +274 -0
- rlinf/utils/data_iter_utils.py +718 -0
- rlinf/utils/data_process.py +90 -0
- rlinf/utils/distributed.py +1316 -0
- rlinf/utils/drq.py +109 -0
- rlinf/utils/flops.py +240 -0
- rlinf/utils/initialize.py +333 -0
- rlinf/utils/logging.py +20 -0
- rlinf/utils/metric_logger.py +175 -0
- rlinf/utils/metric_utils.py +348 -0
- rlinf/utils/nested_dict_process.py +110 -0
- rlinf/utils/omega_resolver.py +36 -0
- rlinf/utils/patcher.py +217 -0
- rlinf/utils/placement.py +599 -0
- rlinf/utils/profiler.py +244 -0
- rlinf/utils/pytree.py +60 -0
- rlinf/utils/resharding/__init__.py +13 -0
- rlinf/utils/resharding/mcore_weight_reshard.py +335 -0
- rlinf/utils/resharding/reshard_config.py +93 -0
- rlinf/utils/resharding/utils.py +332 -0
- rlinf/utils/runner_utils.py +82 -0
- rlinf/utils/timers.py +196 -0
- rlinf/utils/torch_functionals.py +32 -0
- rlinf/utils/train_utils.py +84 -0
- rlinf/utils/utils.py +509 -0
- rlinf/workers/__init__.py +13 -0
- rlinf/workers/actor/__init__.py +30 -0
- rlinf/workers/actor/async_fsdp_sac_policy_worker.py +138 -0
- rlinf/workers/actor/async_ppo_fsdp_worker.py +366 -0
- rlinf/workers/actor/fsdp_actor_worker.py +1494 -0
- rlinf/workers/actor/fsdp_sac_policy_worker.py +839 -0
- rlinf/workers/actor/ma_megatron_actor_worker.py +710 -0
- rlinf/workers/actor/megatron_actor_worker.py +417 -0
- rlinf/workers/agent/__init__.py +13 -0
- rlinf/workers/agent/agent_loop.py +683 -0
- rlinf/workers/agent/tool_worker.py +43 -0
- rlinf/workers/critic/__init__.py +26 -0
- rlinf/workers/critic/megatron_critic_worker.py +112 -0
- rlinf/workers/env/__init__.py +13 -0
- rlinf/workers/env/async_env_worker.py +94 -0
- rlinf/workers/env/env_worker.py +612 -0
- rlinf/workers/inference/__init__.py +13 -0
- rlinf/workers/inference/fsdp_inference_worker.py +146 -0
- rlinf/workers/inference/megatron_inference_worker.py +129 -0
- rlinf/workers/inference/utils.py +73 -0
- rlinf/workers/megatron_worker.py +1340 -0
- rlinf/workers/reward/__init__.py +13 -0
- rlinf/workers/reward/reward_worker.py +339 -0
- rlinf/workers/rollout/__init__.py +13 -0
- rlinf/workers/rollout/hf/__init__.py +13 -0
- rlinf/workers/rollout/hf/async_huggingface_worker.py +176 -0
- rlinf/workers/rollout/hf/huggingface_worker.py +575 -0
- rlinf/workers/rollout/hf/utils.py +30 -0
- rlinf/workers/rollout/server/__init__.py +13 -0
- rlinf/workers/rollout/server/online_router_worker.py +259 -0
- rlinf/workers/rollout/server/server_rollout_worker.py +378 -0
- rlinf/workers/rollout/sglang/__init__.py +43 -0
- rlinf/workers/rollout/sglang/sglang_worker.py +510 -0
- rlinf/workers/rollout/utils.py +562 -0
- rlinf/workers/rollout/vllm/__init__.py +41 -0
- rlinf/workers/rollout/vllm/vllm_worker.py +508 -0
- rlinf/workers/sft/__init__.py +13 -0
- rlinf/workers/sft/fsdp_sft_worker.py +218 -0
- rlinf/workers/sft/fsdp_vla_sft_worker.py +74 -0
- rlinf/workers/sft/fsdp_vlm_sft_worker.py +317 -0
- rlinf-0.2.dist-info/METADATA +640 -0
- rlinf-0.2.dist-info/RECORD +411 -0
- rlinf-0.2.dist-info/WHEEL +5 -0
- rlinf-0.2.dist-info/licenses/LICENSE +201 -0
- rlinf-0.2.dist-info/top_level.txt +1 -0
rlinf/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 The RLinf Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .utils.omega_resolver import omegaconf_register
|
|
16
|
+
|
|
17
|
+
# Invoke the registration hook from rlinf.utils.omega_resolver once, as a
# side effect of importing the `rlinf` package (presumably registers custom
# OmegaConf resolvers used by config files — confirm in omega_resolver).
omegaconf_register()
|
rlinf/agents/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 The RLinf Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright 2026 The RLinf Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from rlinf.agents.mas_search.mas_search_agent_loop import MasSearchAgentLoopWorker
|
|
16
|
+
from rlinf.agents.searchr1.search_tool_worker import SearchToolWorker
|
|
17
|
+
|
|
18
|
+
# Public names re-exported by this package (what `from ... import *` exposes).
__all__ = ["MasSearchAgentLoopWorker", "SearchToolWorker"]
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
# Copyright 2026 The RLinf Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import copy
|
|
17
|
+
import json
|
|
18
|
+
import re
|
|
19
|
+
from uuid import uuid4
|
|
20
|
+
|
|
21
|
+
from omegaconf import DictConfig
|
|
22
|
+
|
|
23
|
+
from rlinf.data.tool_call.tool_io_struct import (
|
|
24
|
+
ToolChannelRequest,
|
|
25
|
+
ToolChannelResponse,
|
|
26
|
+
ToolRequest,
|
|
27
|
+
ToolResponse,
|
|
28
|
+
)
|
|
29
|
+
from rlinf.scheduler import Channel
|
|
30
|
+
from rlinf.utils.placement import ModelParallelComponentPlacement
|
|
31
|
+
from rlinf.workers.agent.agent_loop import (
|
|
32
|
+
AgentLoopOutput,
|
|
33
|
+
MultiAgentLoopOutput,
|
|
34
|
+
MultiAgentLoopWorker,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MasSearchAgentLoopWorker(MultiAgentLoopWorker):
    """
    Agent loop worker that combines search-r1's <search>keyword</search> extraction
    logic with multi agent system's component structure.

    Each generation turn is recorded as its own ``AgentLoopOutput``; after the
    loop ends, every recorded turn receives the same reward score computed on
    the concatenated response (generated tokens plus tool-returned tokens).
    """

    def __init__(
        self,
        cfg: DictConfig,
        placement: ModelParallelComponentPlacement,
    ):
        super().__init__(cfg, placement)
        self.max_prompt_len = int(self.cfg.data.max_prompt_length)
        max_total_len = int(self.cfg.actor.model.encoder_seq_length)
        # Token budget shared by all generated and tool-returned tokens of one
        # query; clamped to at least 1 so the first turn can always generate.
        self.max_resp_len = max(1, max_total_len - self.max_prompt_len)

        # Search tool call tokens
        self.tool_call_start_token: str = "<search>"
        self.tool_call_end_token: str = "</search>"
        self.tool_call_regex = re.compile(r"<search>(.*?)</search>", re.DOTALL)

        # Max turns for multi-turn interaction
        self.max_turns = self.cfg.agentloop.get("max_turns", 5)
        # Logprobs are only returned from generation when they will not be
        # recomputed later by the algorithm.
        self.return_logprobs = not cfg.algorithm.recompute_logprobs

    async def state_less_tool_call_with_channel(
        self,
        input_channel: Channel,
        output_channel: Channel,
        tool_name: str,
        tool_args: dict,
    ) -> ToolChannelResponse:
        """State-less tool call with channel, used for demo.

        Sends an ``execute`` request tagged with a fresh session id on
        ``input_channel`` and waits for the response carrying that same
        session id on ``output_channel``.

        Args:
            input_channel: Channel the tool worker reads requests from.
            output_channel: Channel the tool worker writes responses to.
            tool_name: Name of the tool to execute.
            tool_args: Arguments forwarded to the tool.

        Returns:
            The ``ToolChannelResponse`` received for this session.
        """
        session_id = uuid4().hex
        await input_channel.put(
            ToolChannelRequest(
                session_id=session_id,
                request_type="execute",
                tool_name=tool_name,
                tool_args=tool_args,
            ),
            async_op=True,
        ).async_wait()
        # The response is fetched by session id so concurrent calls sharing
        # the output channel do not pick up each other's results.
        return await output_channel.get(session_id, async_op=True).async_wait()

    async def tool_call(self, tool_request: ToolRequest) -> ToolResponse:
        """Execute one tool request via its tool worker and return its text.

        Args:
            tool_request: Parsed request; ``name`` selects the tool channel via
                ``tool_name_map`` / ``tool_channel_info_map``.

        Returns:
            A ``ToolResponse`` whose text is the JSON encoding of a list/dict
            result, or ``str(result)`` otherwise.

        Raises:
            AssertionError: if the tool worker reports failure.
        """
        tool_name, tool_args = tool_request.name, tool_request.arguments
        tool_channel_info = self.tool_channel_info_map[self.tool_name_map[tool_name]]
        channel_response = await self.state_less_tool_call_with_channel(
            tool_channel_info.input_channel,
            self.tool_worker_output_channel,
            tool_name,
            tool_args,
        )

        # no failure in this demo
        assert channel_response.success
        if isinstance(channel_response.result, (list, dict)):
            result_text = json.dumps(channel_response.result)
        else:
            result_text = str(channel_response.result)
        return ToolResponse(
            text=result_text,
        )

    async def extract_tool_calls(self, response_text) -> tuple[str, list[ToolRequest]]:
        """
        Extract tool calls from response text using <search>keyword</search> format.

        Only the LAST ``<search>...</search>`` span becomes a request (with its
        keyword stripped), so at most one ``search`` call is returned.

        Returns:
            ``(content, calls)`` where ``content`` is ``response_text`` with all
            search spans removed and ``calls`` is a list of 0 or 1 requests.
        """
        if (
            self.tool_call_start_token not in response_text
            or self.tool_call_end_token not in response_text
        ):
            return response_text, []
        matches = self.tool_call_regex.findall(response_text)
        function_calls = []
        if matches:
            match = matches[-1].strip()
            function_calls.append(
                ToolRequest(name="search", arguments={"keyword": match})
            )

        # remaining text exclude tool call tokens
        content = self.tool_call_regex.sub("", response_text)

        return content, function_calls

    async def run_one_query(
        self, prompt_ids: list[int], *, answer
    ) -> MultiAgentLoopOutput:
        """Run the multi-turn generate / search loop for a single prompt.

        Each turn generates from the running context, records the turn as an
        ``AgentLoopOutput``, and, if the response contains a ``<search>`` call,
        appends the tool's tokenized result to the context. The loop stops
        after ``max_turns`` turns, when no tool call is produced, or when the
        shared response token budget ``max_resp_len`` is exhausted.

        Args:
            prompt_ids: Token ids of the initial prompt. Not mutated.
            answer: Reference answer forwarded to searchr1 ``compute_score``.

        Returns:
            ``MultiAgentLoopOutput`` with one output per generation turn, all
            carrying the same ``reward_score``, plus the collected trace
            entries (non-empty only when ``print_outputs`` is enabled).
        """
        # Work on a private copy: the context grows in place below and must not
        # leak into the caller's list (which may be reused for other samples).
        prompt_ids = copy.deepcopy(prompt_ids)
        orig_prompt_len = len(prompt_ids)

        output_buffer = []
        trace_prints = []
        response_mask = []
        all_response_ids = []
        for _ in range(self.max_turns):
            # Remaining budget = total budget minus everything appended so far.
            max_resp_len = self.max_resp_len - (len(prompt_ids) - orig_prompt_len)
            if max_resp_len <= 0:
                # A previous tool response consumed the rest of the budget;
                # generating with max_new_tokens=0 would only add an empty turn.
                break

            # Generate response from LLM
            generate_result = await self.generate(
                prompt_ids, sampling_params={"max_new_tokens": max_resp_len}
            )
            generate_prompt_ids = copy.deepcopy(prompt_ids)
            response_ids = generate_result["output_ids"]
            response_logprobs = (
                generate_result["logprobs"] if self.return_logprobs else None
            )
            if len(response_ids) > max_resp_len:
                response_ids = response_ids[:max_resp_len]
                # Keep logprobs aligned one-to-one with the truncated ids.
                if response_logprobs is not None:
                    response_logprobs = response_logprobs[:max_resp_len]
            response_text = self.tokenizer.decode(response_ids)

            output_buffer.append(
                AgentLoopOutput(
                    prompt_ids=copy.deepcopy(prompt_ids),
                    response_ids=copy.deepcopy(response_ids),
                    prompt_text=self.tokenizer.decode(prompt_ids),
                    response_text=response_text,
                    # Snapshot: response_mask keeps growing in place below, so
                    # passing the live list would alias it across all turns.
                    response_mask=copy.deepcopy(response_mask),
                    response_logprobs=response_logprobs,
                )
            )

            prompt_ids += response_ids
            all_response_ids.extend(response_ids)
            response_mask += [1] * len(response_ids)

            # Generation used the whole remaining budget; no room for a tool turn.
            if len(response_ids) == max_resp_len:
                break

            # Extract tool calls from response text
            _, function_calls = await self.extract_tool_calls(response_text)
            if not function_calls:
                break

            # Execute tools concurrently
            tool_responses: list[ToolResponse] = await asyncio.gather(
                *(self.tool_call(tool_request) for tool_request in function_calls)
            )

            # Convert tool responses to messages and tokenize
            tool_messages = [
                {"role": "tool", "content": tool_response.text}
                for tool_response in tool_responses
            ]
            # extract_tool_calls yields at most one call, so only the first
            # tool message is appended to the context.
            tool_response_ids = self.tokenizer.encode(
                tool_messages[0]["content"], add_special_tokens=False
            )
            max_tool_resp_len = self.max_resp_len - (len(prompt_ids) - orig_prompt_len)
            if len(tool_response_ids) > max_tool_resp_len:
                # Tool output does not fit in the remaining budget; stop here.
                break
            prompt_ids += tool_response_ids
            all_response_ids.extend(tool_response_ids)
            # Tool tokens are context, not model output: mask them out.
            response_mask += [0] * len(tool_response_ids)

            if self.print_outputs:
                # add anything you want to print
                trace_prints.append(
                    {
                        "decode_prompt": self.tokenizer.decode(generate_prompt_ids),
                        "generate": response_text,
                        "tool_resp": tool_messages,
                    }
                )

        # Build complete response text from all turns
        complete_response_text = self.tokenizer.decode(all_response_ids)

        # Extract final answer from complete response using searchr1's extract_solution
        from rlinf.algorithms.rewards.searchr1 import compute_score

        reward_score = compute_score(complete_response_text, answer)

        for single_turn_output in output_buffer:
            single_turn_output.reward_score = reward_score

        return MultiAgentLoopOutput(
            single_turn_outputs=output_buffer,
            # Previously hard-coded to [], silently dropping the collected trace.
            trace_prints=trace_prints,
        )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2026 The RLinf Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
# Copyright 2026 The RLinf Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
import base64
import time
from typing import Awaitable, Callable, Optional
|
|
18
|
+
|
|
19
|
+
from rlinf.data.tool_call.tool_io_struct import ToolChannelRequest, ToolChannelResponse
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ToolBase:
    """Common interface shared by every tool an agent can call.

    A concrete tool sets the ``name`` class attribute and overrides
    ``execute``, ``tool_schema`` and ``validate``. Each base implementation
    here only raises ``NotImplementedError``.
    """

    name = None

    def __init__(self, cfg):
        # Stored untouched; subclasses read their own settings from it.
        self.cfg = cfg

    async def execute(
        self, request: ToolChannelRequest, **kwargs
    ) -> ToolChannelResponse:
        """Carry out the tool call described by ``request``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def tool_schema(self) -> dict:
        """Describe the tool as a JSON schema (name, description, argument types).

        Ref: https://huggingface.co/docs/transformers/en/chat_extras#json-schemas
        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def validate(self, request: ToolChannelRequest) -> Optional[ToolChannelResponse]:
        """Check that ``request`` targets this tool and carries well-formed arguments.

        Must be overridden by subclasses. The ``PythonTool`` override returns
        an error response for a bad request and ``None`` otherwise.
        """
        raise NotImplementedError
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class CodeJudgeToolBase(ToolBase):
    """Base for tools that run code on a code-judge HTTP service.

    The service endpoint is built from ``cfg.tools.codejudge.host_addr`` and
    ``cfg.tools.codejudge.host_port``.
    """

    name = None

    def __init__(self, cfg):
        super().__init__(cfg=cfg)
        judge_cfg = self.cfg.tools.codejudge
        self.url = f"http://{judge_cfg.host_addr}:{judge_cfg.host_port}/run/long-batch"

    def _postprocess(self, result):
        """Turn one code-judge result dict into a ``ToolChannelResponse``.

        Args:
            result: a single entry of the judge's ``results`` list. It is read
                for ``run_success``, ``success``, ``stdout``, ``stderr`` and
                ``cost``, plus ``reason`` when the run did not succeed.

        Returns:
            A response whose text lists the outcome, the reason (failures
            only), any non-empty stdout/stderr and the execution time.
        """
        succeeded = result["run_success"] and result["success"]

        report = ["Tool call success" if succeeded else "Tool call failure"]
        if not succeeded:
            report.append(f"reason: {result['reason']}")
        # Only mention the streams that actually produced output.
        for stream in ("stdout", "stderr"):
            if result[stream]:
                report.append(f"{stream}: {result[stream]}")
        report.append(f"execution time: {result['cost']:.2f}s")

        return ToolChannelResponse(
            success=True if succeeded else False, result="\n".join(report)
        )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Source of a helper program that PythonTool.execute prepends to every
# code-judge submission. It defines ``PersistentExecutor``, which runs a code
# string in a persistent globals dict and echoes the repr of a trailing
# expression (REPL style). In replay mode, output is captured and errors are
# ignored. Otherwise errors are printed to stderr and the process exits with
# status 1. This text is program source for the remote interpreter: each
# doubled backslash-n escape below becomes a newline escape in that program.
code_template_setup = '''
import os
import base64
import sys
import ast
import traceback
from typing import Optional, Any
import linecache
from types import CodeType
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO

class CodeExecutionError(Exception):
    """Custom exception for code execution errors with line information"""
    def __init__(self, original_error: Exception, code: str, line_offset: int = 0):
        self.original_error = original_error
        self.code = code
        self.line_offset = line_offset

        # Get the error line number relative to ``code``. Only a SyntaxError
        # carries a source line number; other errors use the traceback.
        if isinstance(original_error, SyntaxError):
            lineno = original_error.lineno
            if not isinstance(lineno, int):
                lineno = -1
        else:
            lineno = -1
            user_lineno = -1
            tb = getattr(original_error, '__traceback__', None)
            while tb is not None:
                if isinstance(tb.tb_lineno, int):
                    lineno = tb.tb_lineno
                    # Prefer the innermost frame of the user code (compiled
                    # as '<string>') over frames inside libraries it called.
                    if tb.tb_frame.f_code.co_filename == '<string>':
                        user_lineno = tb.tb_lineno
                tb = tb.tb_next
            if user_lineno != -1:
                lineno = user_lineno

        # Adjust line number for code segment
        self.lineno = lineno + line_offset if lineno != -1 else -1

        # Format error message
        error_type = type(original_error).__name__
        error_msg = str(original_error)

        if lineno != -1:
            # Get the problematic line; ``code`` holds only this segment, so
            # it is indexed with the segment-relative line number.
            lines = code.splitlines()
            if 0 <= lineno - 1 < len(lines):
                error_line = lines[lineno - 1]
                # Create error message with line information
                super().__init__(f"{error_type} at line {self.lineno}: {error_msg}\\n {error_line}")
                return

        super().__init__(f"{error_type}: {error_msg}")

class PersistentExecutor:
    def __init__(self):
        self.exec_globals = {
            '__name__': '__main__',
            '__file__': '<string>',
            '__builtins__': __builtins__
        }

    def split_code(self, code: str) -> tuple[str, Optional[str]]:
        """
        Intelligently split code into main body and last expression

        Args:
            code: The source code string

        Returns:
            tuple[str, Optional[str]]: (main code body, last expression if exists)

        Raises:
            CodeExecutionError: if the code cannot be parsed
        """
        try:
            # Parse code into AST
            tree = ast.parse(code)
        except SyntaxError as e:
            raise CodeExecutionError(e, code)
        if not tree.body:
            return code, None

        # Check whether the last top-level statement is a bare expression
        # (calls included) whose value should be echoed
        last_node = tree.body[-1]
        if isinstance(last_node, ast.Expr):
            # Get the line range of the last expression
            last_expr_start = last_node.lineno
            last_expr_end = last_node.end_lineno if hasattr(last_node, 'end_lineno') else last_node.lineno

            # Split the code
            lines = code.splitlines()
            main_code = '\\n'.join(lines[:last_expr_start-1])
            last_expr = '\\n'.join(lines[last_expr_start-1:last_expr_end])
            return main_code, last_expr
        return code, None

    def execute_code(self, code: str, replay_history_code: bool) -> None:
        """
        Execute code while maintaining persistent environment state.
        If the last line is an expression, its value will be printed to stdout.

        Args:
            code: The source code string to execute
            replay_history_code: If True, suppress stdout and stderr output
                and ignore errors
        """
        try:
            # Split code intelligently
            main_code, last_expr = self.split_code(code)

            # Set up output redirection if replay_history_code is True
            if replay_history_code:
                stdout_capture = StringIO()
                stderr_capture = StringIO()
                stdout_context = redirect_stdout(stdout_capture)
                stderr_context = redirect_stderr(stderr_capture)
            else:
                stdout_context = redirect_stdout(sys.stdout)
                stderr_context = redirect_stderr(sys.stderr)

            # Execute main code body
            if main_code:
                try:
                    # Compile code to get better error line numbers
                    compiled_code = compile(main_code, '<string>', 'exec')
                    with stdout_context, stderr_context:
                        exec(compiled_code, self.exec_globals)
                except Exception as e:
                    raise CodeExecutionError(e, main_code)

            # If there is a last expression, run it exactly once and print it
            if last_expr:
                # Calculate line offset for the last expression
                line_offset = len(main_code.splitlines()) if main_code else 0
                try:
                    try:
                        compiled_expr = compile(last_expr, '<string>', 'eval')
                    except SyntaxError:
                        # The span shares a line with another statement
                        # (e.g. "x = 1; x"), so run it as statements instead.
                        compiled_expr = None

                    if compiled_expr is not None:
                        with stdout_context, stderr_context:
                            last_value = eval(compiled_expr, self.exec_globals)
                        # Only print the result if not in replay mode
                        if last_value is not None and not replay_history_code:
                            print(repr(last_value), file=sys.stdout)
                    else:
                        compiled_stmt = compile(last_expr, '<string>', 'exec')
                        with stdout_context, stderr_context:
                            exec(compiled_stmt, self.exec_globals)
                except Exception as e:
                    raise CodeExecutionError(e, last_expr, line_offset)

        except Exception as e:
            if replay_history_code:
                return
            if isinstance(e, CodeExecutionError):
                print(str(e), file=sys.stderr)
            else:
                traceback.print_exc(file=sys.stderr)
            # os._exit skips interpreter shutdown, which would otherwise flush
            # buffered output; flush explicitly so earlier prints survive.
            sys.stdout.flush()
            sys.stderr.flush()
            os._exit(1)
        return

persistent_executor = PersistentExecutor()
'''
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
code_template_exec = """
|
|
237
|
+
code_to_execute = base64.b64decode("{}".encode()).decode()
|
|
238
|
+
persistent_executor.execute_code(code_to_execute, replay_history_code={})
|
|
239
|
+
"""
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class PythonTool(CodeJudgeToolBase):
    """Run model-written Python code, fed with stdin, on the code-judge service.

    The user code is base64-encoded and wrapped in ``code_template_setup`` and
    ``code_template_exec``, so it runs through ``PersistentExecutor`` remotely.
    """

    name = "python_code_with_standard_io"

    def __init__(self, cfg):
        super().__init__(cfg=cfg)

    async def execute(
        self,
        request: ToolChannelRequest,
        send_request_func: Callable[[str, dict], Awaitable[dict]],
    ) -> ToolChannelResponse:
        """Validate ``request``, submit its code to code-judge and format the result.

        Args:
            request: tool call whose ``tool_args`` may carry ``code`` and
                ``input`` strings.
            send_request_func: coroutine function posting ``data`` to ``url``.
                Its awaited result must contain a ``"results"`` list.

        Returns:
            The validation error response if ``request`` is malformed, else the
            post-processed judge result.

        Raises:
            RuntimeError: if every attempt to reach the service fails. The
                last underlying exception is chained as the cause.
        """
        err_msg = self.validate(request)
        if err_msg is not None:
            return err_msg

        # convert the code to the code exec on code-judge
        code_to_execute = base64.b64encode(
            request.tool_args.get("code", "").encode()
        ).decode()
        final_code = code_template_setup
        # TODO: add history code here
        final_code += code_template_exec.format(code_to_execute, "False")

        submission = {
            "type": "python",
            "solution": final_code,
            "input": request.tool_args.get("input", ""),
        }
        data = {"type": "batch", "submissions": [submission]}

        max_attempts = 4
        last_exc = None
        for retry_time in range(max_attempts):
            try:
                results = (await send_request_func(self.url, data))["results"]
                break
            except Exception as e:
                last_exc = e
                print(f"Tool retry time {retry_time}, exception: {e}")
                # Back off without blocking the event loop (time.sleep would
                # stall every other coroutine sharing it); no wait is needed
                # after the final attempt.
                if retry_time < max_attempts - 1:
                    await asyncio.sleep(1)
        else:
            raise RuntimeError("Tool call failed after retries") from last_exc
        assert len(results) == 1, f"{results}"
        return self._postprocess(results[0])

    def tool_schema(self) -> dict:
        """Return the function-calling JSON schema advertised to the model."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": "Execute Python code with standard input and capture standard output. This function takes a Python code string and an input string, provides the input string through standard input (stdin) to the code, and captures and returns any output produced through standard output (stdout). If the executed code raises an exception, the error message will be captured and returned instead.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "A string containing Python code to be executed. The code can read from standard input using the input() function.",
                        },
                        "input": {
                            "type": "string",
                            "description": "A string that will be provided as standard input to the code when it calls input().",
                        },
                    },
                    "required": ["code", "input"],
                },
            },
        }

    def validate(self, request) -> Optional[ToolChannelResponse]:
        """Check ``request`` before it is sent to code-judge.

        Raises:
            AssertionError: if the request names a different tool.

        Returns:
            A failed ``ToolChannelResponse`` when ``tool_args`` is not a dict,
            or when ``code``/``input`` are present but not strings; ``None``
            when the request is acceptable.
        """
        tool_args = request.tool_args

        assert request.tool_name == self.name, (
            f"Name mismatch, {self.name} != {request.tool_name}"
        )

        # NOTE(review): this fixed text does not describe the actual problem
        # (it reads like a transport failure). It may deliberately mirror
        # another tool runtime the model was trained against; confirm before
        # rewording, since the model sees it verbatim.
        invalid_args_msg = "Error when executing tool: run_tool_calls_on_server_async failed for1 tool calls after 4 attempts."

        if not isinstance(tool_args, dict) or any(
            key in tool_args and not isinstance(tool_args[key], str)
            for key in ("code", "input")
        ):
            return ToolChannelResponse(success=False, result=invalid_args_msg)
        return None
|