rlinf 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (411)
  1. rlinf/__init__.py +17 -0
  2. rlinf/agents/__init__.py +13 -0
  3. rlinf/agents/mas_search/__init__.py +18 -0
  4. rlinf/agents/mas_search/mas_search_agent_loop.py +222 -0
  5. rlinf/agents/rstar2/__init__.py +13 -0
  6. rlinf/agents/rstar2/http_code_judge_tool.py +331 -0
  7. rlinf/agents/rstar2/http_tool_worker.py +176 -0
  8. rlinf/agents/rstar2/rstar2_agent_loop.py +324 -0
  9. rlinf/agents/searchr1/__init__.py +13 -0
  10. rlinf/agents/searchr1/eval_runner.py +260 -0
  11. rlinf/agents/searchr1/search_tool_worker.py +157 -0
  12. rlinf/agents/searchr1/searchr1_agent_loop.py +189 -0
  13. rlinf/agents/wideseek_r1/__init__.py +13 -0
  14. rlinf/agents/wideseek_r1/eval_runner.py +609 -0
  15. rlinf/agents/wideseek_r1/tools.py +785 -0
  16. rlinf/agents/wideseek_r1/utils/__init__.py +13 -0
  17. rlinf/agents/wideseek_r1/utils/metrics.py +188 -0
  18. rlinf/agents/wideseek_r1/utils/prompt.py +529 -0
  19. rlinf/agents/wideseek_r1/utils/prompt_utils.py +261 -0
  20. rlinf/agents/wideseek_r1/utils/reward.py +671 -0
  21. rlinf/agents/wideseek_r1/utils/sglang_client.py +144 -0
  22. rlinf/agents/wideseek_r1/utils/tool_description.py +285 -0
  23. rlinf/agents/wideseek_r1/utils/webpage.py +192 -0
  24. rlinf/agents/wideseek_r1/wideseek_r1.py +843 -0
  25. rlinf/algorithms/__init__.py +14 -0
  26. rlinf/algorithms/advantages.py +350 -0
  27. rlinf/algorithms/loss_scales.py +182 -0
  28. rlinf/algorithms/losses.py +461 -0
  29. rlinf/algorithms/registry.py +156 -0
  30. rlinf/algorithms/rewards/__init__.py +38 -0
  31. rlinf/algorithms/rewards/code/__init__.py +33 -0
  32. rlinf/algorithms/rewards/code/code_verifier/__init__.py +13 -0
  33. rlinf/algorithms/rewards/code/code_verifier/verify.py +230 -0
  34. rlinf/algorithms/rewards/math/__init__.py +42 -0
  35. rlinf/algorithms/rewards/math/math_verifier/__init__.py +13 -0
  36. rlinf/algorithms/rewards/math/math_verifier/parser.py +441 -0
  37. rlinf/algorithms/rewards/math/math_verifier/verify.py +441 -0
  38. rlinf/algorithms/rewards/rstar2/__init__.py +182 -0
  39. rlinf/algorithms/rewards/rstar2/fused_compute_score/__init__.py +13 -0
  40. rlinf/algorithms/rewards/rstar2/fused_compute_score/compute_score.py +37 -0
  41. rlinf/algorithms/rewards/rstar2/fused_compute_score/math_verify.py +43 -0
  42. rlinf/algorithms/rewards/rstar2/fused_compute_score/prime_math/__init__.py +440 -0
  43. rlinf/algorithms/rewards/rstar2/fused_compute_score/prime_math/grader.py +545 -0
  44. rlinf/algorithms/rewards/rstar2/fused_compute_score/prime_math/math_normalize.py +192 -0
  45. rlinf/algorithms/rewards/searchr1/__init__.py +181 -0
  46. rlinf/algorithms/rewards/vqa/__init__.py +60 -0
  47. rlinf/algorithms/rewards/vqa/format_rewards.py +66 -0
  48. rlinf/algorithms/rewards/vqa/qa_rewards.py +109 -0
  49. rlinf/algorithms/toolcall_parsers.py +297 -0
  50. rlinf/algorithms/utils.py +398 -0
  51. rlinf/config.py +1358 -0
  52. rlinf/data/__init__.py +13 -0
  53. rlinf/data/datasets/__init__.py +267 -0
  54. rlinf/data/datasets/item.py +74 -0
  55. rlinf/data/datasets/reasoning.py +250 -0
  56. rlinf/data/datasets/rstar2.py +100 -0
  57. rlinf/data/datasets/vlm.py +684 -0
  58. rlinf/data/datasets/wideseek_r1.py +139 -0
  59. rlinf/data/datasets/world_model.py +375 -0
  60. rlinf/data/embodied_buffer_dataset.py +287 -0
  61. rlinf/data/embodied_io_struct.py +660 -0
  62. rlinf/data/io_struct.py +1827 -0
  63. rlinf/data/lerobot_writer.py +1065 -0
  64. rlinf/data/replay_buffer.py +1169 -0
  65. rlinf/data/tokenizers.py +71 -0
  66. rlinf/data/tool_call/__init__.py +13 -0
  67. rlinf/data/tool_call/tool_io_struct.py +122 -0
  68. rlinf/data/utils.py +53 -0
  69. rlinf/envs/__init__.py +117 -0
  70. rlinf/envs/action_utils.py +246 -0
  71. rlinf/envs/behavior/__init__.py +13 -0
  72. rlinf/envs/behavior/behavior_env.py +312 -0
  73. rlinf/envs/calvin/__init__.py +129 -0
  74. rlinf/envs/calvin/calvin_gym_env.py +486 -0
  75. rlinf/envs/calvin/utils.py +75 -0
  76. rlinf/envs/calvin/venv.py +264 -0
  77. rlinf/envs/frankasim/__init__.py +19 -0
  78. rlinf/envs/frankasim/frankasim_env.py +722 -0
  79. rlinf/envs/habitat/__init__.py +17 -0
  80. rlinf/envs/habitat/extensions/__init__.py +13 -0
  81. rlinf/envs/habitat/extensions/config/vlnce_r2r.yaml +68 -0
  82. rlinf/envs/habitat/extensions/maps.py +357 -0
  83. rlinf/envs/habitat/extensions/utils.py +711 -0
  84. rlinf/envs/habitat/habitat_env.py +348 -0
  85. rlinf/envs/habitat/venv.py +246 -0
  86. rlinf/envs/isaaclab/__init__.py +21 -0
  87. rlinf/envs/isaaclab/isaaclab_env.py +264 -0
  88. rlinf/envs/isaaclab/tasks/__init__.py +13 -0
  89. rlinf/envs/isaaclab/tasks/stack_cube.py +97 -0
  90. rlinf/envs/isaaclab/utils.py +62 -0
  91. rlinf/envs/isaaclab/venv.py +118 -0
  92. rlinf/envs/libero/__init__.py +13 -0
  93. rlinf/envs/libero/libero_env.py +462 -0
  94. rlinf/envs/libero/utils.py +125 -0
  95. rlinf/envs/libero/venv.py +173 -0
  96. rlinf/envs/maniskill/__init__.py +33 -0
  97. rlinf/envs/maniskill/maniskill_env.py +422 -0
  98. rlinf/envs/maniskill/maniskill_offload_env.py +449 -0
  99. rlinf/envs/maniskill/tasks/__init__.py +13 -0
  100. rlinf/envs/maniskill/tasks/panda_put_on_in_scene_multi.py +1380 -0
  101. rlinf/envs/maniskill/tasks/panda_table_agent.py +136 -0
  102. rlinf/envs/maniskill/tasks/pose_utils.py +311 -0
  103. rlinf/envs/maniskill/tasks/put_carrot_on_plate.py +151 -0
  104. rlinf/envs/maniskill/tasks/put_on_in_scene_multi.py +963 -0
  105. rlinf/envs/maniskill/tasks/variants/__init__.py +51 -0
  106. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_carrot.py +77 -0
  107. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_ee_pose.py +289 -0
  108. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_image.py +105 -0
  109. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_instruct.py +123 -0
  110. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_multi_carrot.py +352 -0
  111. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_multi_plate.py +399 -0
  112. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_plate.py +97 -0
  113. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_position.py +231 -0
  114. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_position_change.py +179 -0
  115. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_single.py +110 -0
  116. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_image.py +49 -0
  117. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_texture.py +325 -0
  118. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_whole.py +364 -0
  119. rlinf/envs/maniskill/tasks/variants/utils.py +31 -0
  120. rlinf/envs/maniskill/utils.py +66 -0
  121. rlinf/envs/metaworld/__init__.py +61 -0
  122. rlinf/envs/metaworld/metaworld_env.py +442 -0
  123. rlinf/envs/metaworld/utils.py +21 -0
  124. rlinf/envs/metaworld/venv.py +170 -0
  125. rlinf/envs/realworld/__init__.py +33 -0
  126. rlinf/envs/realworld/common/camera/__init__.py +17 -0
  127. rlinf/envs/realworld/common/camera/camera.py +143 -0
  128. rlinf/envs/realworld/common/keyboard/__init__.py +13 -0
  129. rlinf/envs/realworld/common/keyboard/keyboard_listener.py +42 -0
  130. rlinf/envs/realworld/common/ros/__init__.py +17 -0
  131. rlinf/envs/realworld/common/ros/ros_controller.py +129 -0
  132. rlinf/envs/realworld/common/spacemouse/__init__.py +13 -0
  133. rlinf/envs/realworld/common/spacemouse/spacemouse_expert.py +74 -0
  134. rlinf/envs/realworld/common/video_player/__init__.py +17 -0
  135. rlinf/envs/realworld/common/video_player/video_player.py +55 -0
  136. rlinf/envs/realworld/common/wrappers/__init__.py +31 -0
  137. rlinf/envs/realworld/common/wrappers/euler_obs.py +39 -0
  138. rlinf/envs/realworld/common/wrappers/gripper_close.py +41 -0
  139. rlinf/envs/realworld/common/wrappers/relative_frame.py +141 -0
  140. rlinf/envs/realworld/common/wrappers/reward_done_wrapper.py +105 -0
  141. rlinf/envs/realworld/common/wrappers/spacemouse_intervention.py +72 -0
  142. rlinf/envs/realworld/franka/__init__.py +18 -0
  143. rlinf/envs/realworld/franka/franka_controller.py +375 -0
  144. rlinf/envs/realworld/franka/franka_env.py +573 -0
  145. rlinf/envs/realworld/franka/franka_robot_state.py +48 -0
  146. rlinf/envs/realworld/franka/tasks/__init__.py +35 -0
  147. rlinf/envs/realworld/franka/tasks/bottle.py +136 -0
  148. rlinf/envs/realworld/franka/tasks/franka_bin_relocation.py +240 -0
  149. rlinf/envs/realworld/franka/tasks/peg_insertion_env.py +129 -0
  150. rlinf/envs/realworld/franka/utils.py +105 -0
  151. rlinf/envs/realworld/realworld_env.py +395 -0
  152. rlinf/envs/realworld/venv.py +319 -0
  153. rlinf/envs/realworld/xsquare/__init__.py +18 -0
  154. rlinf/envs/realworld/xsquare/tasks/__init__.py +24 -0
  155. rlinf/envs/realworld/xsquare/tasks/button_env.py +79 -0
  156. rlinf/envs/realworld/xsquare/turtle2_env.py +567 -0
  157. rlinf/envs/realworld/xsquare/turtle2_robot_state.py +49 -0
  158. rlinf/envs/realworld/xsquare/turtle2_smooth_controller.py +264 -0
  159. rlinf/envs/robocasa/__init__.py +17 -0
  160. rlinf/envs/robocasa/robocasa_env.py +509 -0
  161. rlinf/envs/robocasa/utils.py +178 -0
  162. rlinf/envs/robocasa/venv.py +163 -0
  163. rlinf/envs/robotwin/__init__.py +13 -0
  164. rlinf/envs/robotwin/robotwin_env.py +506 -0
  165. rlinf/envs/utils.py +317 -0
  166. rlinf/envs/venv/__init__.py +33 -0
  167. rlinf/envs/venv/venv.py +985 -0
  168. rlinf/envs/world_model/__init__.py +13 -0
  169. rlinf/envs/world_model/base_world_env.py +158 -0
  170. rlinf/envs/world_model/world_model_opensora_env.py +917 -0
  171. rlinf/envs/world_model/world_model_wan_env.py +833 -0
  172. rlinf/envs/wrappers/__init__.py +18 -0
  173. rlinf/envs/wrappers/collect_episode.py +642 -0
  174. rlinf/envs/wrappers/record_video.py +457 -0
  175. rlinf/hybrid_engines/__init__.py +13 -0
  176. rlinf/hybrid_engines/fsdp/__init__.py +49 -0
  177. rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +625 -0
  178. rlinf/hybrid_engines/fsdp/strategy/__init__.py +13 -0
  179. rlinf/hybrid_engines/fsdp/strategy/base.py +547 -0
  180. rlinf/hybrid_engines/fsdp/strategy/checkpoint.py +132 -0
  181. rlinf/hybrid_engines/fsdp/strategy/fsdp.py +351 -0
  182. rlinf/hybrid_engines/fsdp/strategy/fsdp2.py +203 -0
  183. rlinf/hybrid_engines/fsdp/utils.py +1014 -0
  184. rlinf/hybrid_engines/megatron/__init__.py +13 -0
  185. rlinf/hybrid_engines/megatron/megatron_model_manager.py +842 -0
  186. rlinf/hybrid_engines/megatron/token_dispatcher.py +600 -0
  187. rlinf/hybrid_engines/megatron/utils.py +240 -0
  188. rlinf/hybrid_engines/sglang/common/__init__.py +13 -0
  189. rlinf/hybrid_engines/sglang/common/detokenizer_manager.py +61 -0
  190. rlinf/hybrid_engines/sglang/common/io_struct.py +52 -0
  191. rlinf/hybrid_engines/sglang/common/sgl_engine.py +138 -0
  192. rlinf/hybrid_engines/sglang/common/sgl_scheduler.py +592 -0
  193. rlinf/hybrid_engines/sglang/common/tokenizer_manager.py +240 -0
  194. rlinf/hybrid_engines/vllm/vllm_0_8_5/__init__.py +13 -0
  195. rlinf/hybrid_engines/vllm/vllm_0_8_5/executor.py +313 -0
  196. rlinf/hybrid_engines/vllm/vllm_0_8_5/weight_loader.py +43 -0
  197. rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +161 -0
  198. rlinf/models/__init__.py +96 -0
  199. rlinf/models/embodiment/__init__.py +13 -0
  200. rlinf/models/embodiment/base_policy.py +93 -0
  201. rlinf/models/embodiment/cnn_policy/__init__.py +26 -0
  202. rlinf/models/embodiment/cnn_policy/cnn_policy.py +621 -0
  203. rlinf/models/embodiment/dexbotic_pi/__init__.py +120 -0
  204. rlinf/models/embodiment/dexbotic_pi/dexbotic_pi_policy.py +773 -0
  205. rlinf/models/embodiment/flow_policy/__init__.py +51 -0
  206. rlinf/models/embodiment/flow_policy/flow_policy.py +635 -0
  207. rlinf/models/embodiment/gr00t/__init__.py +82 -0
  208. rlinf/models/embodiment/gr00t/embodiment_tags.py +59 -0
  209. rlinf/models/embodiment/gr00t/gr00t_action_model.py +732 -0
  210. rlinf/models/embodiment/gr00t/modality_config.py +177 -0
  211. rlinf/models/embodiment/gr00t/simulation_io.py +198 -0
  212. rlinf/models/embodiment/gr00t/utils.py +127 -0
  213. rlinf/models/embodiment/mlp_policy/__init__.py +31 -0
  214. rlinf/models/embodiment/mlp_policy/mlp_policy.py +403 -0
  215. rlinf/models/embodiment/modules/__init__.py +13 -0
  216. rlinf/models/embodiment/modules/batch_renorm.py +132 -0
  217. rlinf/models/embodiment/modules/compact_encoders.py +356 -0
  218. rlinf/models/embodiment/modules/entropy_tunning.py +69 -0
  219. rlinf/models/embodiment/modules/explore_noise_net.py +168 -0
  220. rlinf/models/embodiment/modules/flow_actor.py +469 -0
  221. rlinf/models/embodiment/modules/gaussian_policy.py +317 -0
  222. rlinf/models/embodiment/modules/mlp.py +95 -0
  223. rlinf/models/embodiment/modules/q_head.py +328 -0
  224. rlinf/models/embodiment/modules/resnet_utils.py +158 -0
  225. rlinf/models/embodiment/modules/utils.py +65 -0
  226. rlinf/models/embodiment/modules/value_head.py +67 -0
  227. rlinf/models/embodiment/openpi/__init__.py +121 -0
  228. rlinf/models/embodiment/openpi/dataconfig/__init__.py +373 -0
  229. rlinf/models/embodiment/openpi/dataconfig/behavior_dataconfig.py +102 -0
  230. rlinf/models/embodiment/openpi/dataconfig/calvin_dataconfig.py +71 -0
  231. rlinf/models/embodiment/openpi/dataconfig/franka_co_training_dataconfig.py +102 -0
  232. rlinf/models/embodiment/openpi/dataconfig/franka_dataconfig.py +101 -0
  233. rlinf/models/embodiment/openpi/dataconfig/gsenv_dataconfig.py +58 -0
  234. rlinf/models/embodiment/openpi/dataconfig/libero_dataconfig.py +102 -0
  235. rlinf/models/embodiment/openpi/dataconfig/maniskill_dataconfig.py +103 -0
  236. rlinf/models/embodiment/openpi/dataconfig/metaworld_dataconfig.py +68 -0
  237. rlinf/models/embodiment/openpi/dataconfig/robocasa_dataconfig.py +75 -0
  238. rlinf/models/embodiment/openpi/dataconfig/robotwin_aloha_dataconfig.py +106 -0
  239. rlinf/models/embodiment/openpi/openpi_action_model.py +1185 -0
  240. rlinf/models/embodiment/openpi/policies/__init__.py +13 -0
  241. rlinf/models/embodiment/openpi/policies/aloha_policy.py +241 -0
  242. rlinf/models/embodiment/openpi/policies/behavior_policy.py +119 -0
  243. rlinf/models/embodiment/openpi/policies/calvin_policy.py +87 -0
  244. rlinf/models/embodiment/openpi/policies/franka_policy.py +137 -0
  245. rlinf/models/embodiment/openpi/policies/gsenv_policy.py +74 -0
  246. rlinf/models/embodiment/openpi/policies/libero_policy.py +118 -0
  247. rlinf/models/embodiment/openpi/policies/maniskill_policy.py +114 -0
  248. rlinf/models/embodiment/openpi/policies/metaworld_policy.py +75 -0
  249. rlinf/models/embodiment/openpi/policies/robocasa_policy.py +126 -0
  250. rlinf/models/embodiment/openvla/__init__.py +104 -0
  251. rlinf/models/embodiment/openvla/openvla_action_model.py +808 -0
  252. rlinf/models/embodiment/openvla_oft/__init__.py +32 -0
  253. rlinf/models/embodiment/openvla_oft/official/__init__.py +118 -0
  254. rlinf/models/embodiment/openvla_oft/official/openvla_oft_action_model.py +703 -0
  255. rlinf/models/embodiment/openvla_oft/openvla_utils.py +193 -0
  256. rlinf/models/embodiment/openvla_oft/rlinf/__init__.py +109 -0
  257. rlinf/models/embodiment/openvla_oft/rlinf/openvla_oft_action_model.py +574 -0
  258. rlinf/models/embodiment/prismatic/__init__.py +13 -0
  259. rlinf/models/embodiment/prismatic/processing_prismatic.py +243 -0
  260. rlinf/runners/__init__.py +13 -0
  261. rlinf/runners/agent_eval_runner.py +248 -0
  262. rlinf/runners/agent_runner.py +326 -0
  263. rlinf/runners/async_embodied_runner.py +274 -0
  264. rlinf/runners/async_ppo_embodied_runner.py +260 -0
  265. rlinf/runners/coding_online_rl_runner.py +308 -0
  266. rlinf/runners/embodied_eval_runner.py +80 -0
  267. rlinf/runners/embodied_runner.py +439 -0
  268. rlinf/runners/reasoning_eval_runner.py +179 -0
  269. rlinf/runners/reasoning_runner.py +645 -0
  270. rlinf/runners/sft_runner.py +168 -0
  271. rlinf/scheduler/__init__.py +56 -0
  272. rlinf/scheduler/channel/__init__.py +18 -0
  273. rlinf/scheduler/channel/channel.py +648 -0
  274. rlinf/scheduler/channel/channel_worker.py +536 -0
  275. rlinf/scheduler/cluster/__init__.py +30 -0
  276. rlinf/scheduler/cluster/cluster.py +525 -0
  277. rlinf/scheduler/cluster/config.py +442 -0
  278. rlinf/scheduler/cluster/node.py +554 -0
  279. rlinf/scheduler/cluster/utils.py +604 -0
  280. rlinf/scheduler/collective/__init__.py +36 -0
  281. rlinf/scheduler/collective/async_work.py +386 -0
  282. rlinf/scheduler/collective/collective.py +96 -0
  283. rlinf/scheduler/collective/collective_group.py +1827 -0
  284. rlinf/scheduler/collective/multi_channel_pg.py +927 -0
  285. rlinf/scheduler/dynamic_scheduler/__init__.py +13 -0
  286. rlinf/scheduler/dynamic_scheduler/manager.py +1069 -0
  287. rlinf/scheduler/dynamic_scheduler/scheduler_worker.py +129 -0
  288. rlinf/scheduler/dynamic_scheduler/utils.py +162 -0
  289. rlinf/scheduler/hardware/__init__.py +38 -0
  290. rlinf/scheduler/hardware/accelerators/__init__.py +31 -0
  291. rlinf/scheduler/hardware/accelerators/accelerator.py +300 -0
  292. rlinf/scheduler/hardware/accelerators/amd_gpu.py +143 -0
  293. rlinf/scheduler/hardware/accelerators/ascend_npu.py +113 -0
  294. rlinf/scheduler/hardware/accelerators/intel_gpu.py +113 -0
  295. rlinf/scheduler/hardware/accelerators/musa_gpu.py +147 -0
  296. rlinf/scheduler/hardware/accelerators/nvidia_gpu.py +203 -0
  297. rlinf/scheduler/hardware/hardware.py +180 -0
  298. rlinf/scheduler/hardware/robots/__init__.py +18 -0
  299. rlinf/scheduler/hardware/robots/franka.py +176 -0
  300. rlinf/scheduler/hardware/robots/xsquare.py +86 -0
  301. rlinf/scheduler/manager/__init__.py +32 -0
  302. rlinf/scheduler/manager/coll_manager.py +140 -0
  303. rlinf/scheduler/manager/lock_manager.py +187 -0
  304. rlinf/scheduler/manager/manager.py +123 -0
  305. rlinf/scheduler/manager/node_manager.py +46 -0
  306. rlinf/scheduler/manager/worker_manager.py +318 -0
  307. rlinf/scheduler/placement/__init__.py +27 -0
  308. rlinf/scheduler/placement/flexible.py +277 -0
  309. rlinf/scheduler/placement/node.py +205 -0
  310. rlinf/scheduler/placement/packed.py +335 -0
  311. rlinf/scheduler/placement/placement.py +674 -0
  312. rlinf/scheduler/worker/__init__.py +24 -0
  313. rlinf/scheduler/worker/lock.py +103 -0
  314. rlinf/scheduler/worker/worker.py +1250 -0
  315. rlinf/scheduler/worker/worker_group.py +556 -0
  316. rlinf/utils/__init__.py +13 -0
  317. rlinf/utils/ckpt_convertor/__init__.py +13 -0
  318. rlinf/utils/ckpt_convertor/convert_openpi_jax_to_python.py +706 -0
  319. rlinf/utils/ckpt_convertor/fsdp_convertor/__init__.py +13 -0
  320. rlinf/utils/ckpt_convertor/fsdp_convertor/config/fsdp_model_convertor.yaml +27 -0
  321. rlinf/utils/ckpt_convertor/fsdp_convertor/convert_dcp_to_pt.py +58 -0
  322. rlinf/utils/ckpt_convertor/fsdp_convertor/convert_pt_to_hf.py +81 -0
  323. rlinf/utils/ckpt_convertor/fsdp_convertor/utils.py +197 -0
  324. rlinf/utils/ckpt_convertor/megatron_convertor/__init__.py +13 -0
  325. rlinf/utils/ckpt_convertor/megatron_convertor/config.py +208 -0
  326. rlinf/utils/ckpt_convertor/megatron_convertor/convert_hf_to_mg.py +410 -0
  327. rlinf/utils/ckpt_convertor/megatron_convertor/convert_hf_to_middle_file.py +503 -0
  328. rlinf/utils/ckpt_convertor/megatron_convertor/convert_mg_to_middle_file.py +863 -0
  329. rlinf/utils/ckpt_convertor/megatron_convertor/convert_middle_file_to_hf.py +726 -0
  330. rlinf/utils/ckpt_convertor/megatron_convertor/convert_middle_file_to_mg.py +626 -0
  331. rlinf/utils/ckpt_convertor/megatron_convertor/default_args.yaml +151 -0
  332. rlinf/utils/ckpt_convertor/megatron_convertor/utils/__init__.py +31 -0
  333. rlinf/utils/ckpt_convertor/megatron_convertor/utils/fp8_utils.py +135 -0
  334. rlinf/utils/ckpt_convertor/megatron_convertor/utils/mg_loader.py +171 -0
  335. rlinf/utils/ckpt_convertor/megatron_convertor/utils/mg_moe_groupgemm.py +198 -0
  336. rlinf/utils/ckpt_convertor/megatron_convertor/utils/mp_utils.py +61 -0
  337. rlinf/utils/ckpt_convertor/megatron_convertor/utils/safetensors_loader.py +116 -0
  338. rlinf/utils/ckpt_convertor/megatron_convertor/utils/tensor_operations.py +402 -0
  339. rlinf/utils/comm_mapping.py +91 -0
  340. rlinf/utils/convertor/__init__.py +13 -0
  341. rlinf/utils/convertor/utils.py +637 -0
  342. rlinf/utils/cuda_graph.py +274 -0
  343. rlinf/utils/data_iter_utils.py +718 -0
  344. rlinf/utils/data_process.py +90 -0
  345. rlinf/utils/distributed.py +1316 -0
  346. rlinf/utils/drq.py +109 -0
  347. rlinf/utils/flops.py +240 -0
  348. rlinf/utils/initialize.py +333 -0
  349. rlinf/utils/logging.py +20 -0
  350. rlinf/utils/metric_logger.py +175 -0
  351. rlinf/utils/metric_utils.py +348 -0
  352. rlinf/utils/nested_dict_process.py +110 -0
  353. rlinf/utils/omega_resolver.py +36 -0
  354. rlinf/utils/patcher.py +217 -0
  355. rlinf/utils/placement.py +599 -0
  356. rlinf/utils/profiler.py +244 -0
  357. rlinf/utils/pytree.py +60 -0
  358. rlinf/utils/resharding/__init__.py +13 -0
  359. rlinf/utils/resharding/mcore_weight_reshard.py +335 -0
  360. rlinf/utils/resharding/reshard_config.py +93 -0
  361. rlinf/utils/resharding/utils.py +332 -0
  362. rlinf/utils/runner_utils.py +82 -0
  363. rlinf/utils/timers.py +196 -0
  364. rlinf/utils/torch_functionals.py +32 -0
  365. rlinf/utils/train_utils.py +84 -0
  366. rlinf/utils/utils.py +509 -0
  367. rlinf/workers/__init__.py +13 -0
  368. rlinf/workers/actor/__init__.py +30 -0
  369. rlinf/workers/actor/async_fsdp_sac_policy_worker.py +138 -0
  370. rlinf/workers/actor/async_ppo_fsdp_worker.py +366 -0
  371. rlinf/workers/actor/fsdp_actor_worker.py +1494 -0
  372. rlinf/workers/actor/fsdp_sac_policy_worker.py +839 -0
  373. rlinf/workers/actor/ma_megatron_actor_worker.py +710 -0
  374. rlinf/workers/actor/megatron_actor_worker.py +417 -0
  375. rlinf/workers/agent/__init__.py +13 -0
  376. rlinf/workers/agent/agent_loop.py +683 -0
  377. rlinf/workers/agent/tool_worker.py +43 -0
  378. rlinf/workers/critic/__init__.py +26 -0
  379. rlinf/workers/critic/megatron_critic_worker.py +112 -0
  380. rlinf/workers/env/__init__.py +13 -0
  381. rlinf/workers/env/async_env_worker.py +94 -0
  382. rlinf/workers/env/env_worker.py +612 -0
  383. rlinf/workers/inference/__init__.py +13 -0
  384. rlinf/workers/inference/fsdp_inference_worker.py +146 -0
  385. rlinf/workers/inference/megatron_inference_worker.py +129 -0
  386. rlinf/workers/inference/utils.py +73 -0
  387. rlinf/workers/megatron_worker.py +1340 -0
  388. rlinf/workers/reward/__init__.py +13 -0
  389. rlinf/workers/reward/reward_worker.py +339 -0
  390. rlinf/workers/rollout/__init__.py +13 -0
  391. rlinf/workers/rollout/hf/__init__.py +13 -0
  392. rlinf/workers/rollout/hf/async_huggingface_worker.py +176 -0
  393. rlinf/workers/rollout/hf/huggingface_worker.py +575 -0
  394. rlinf/workers/rollout/hf/utils.py +30 -0
  395. rlinf/workers/rollout/server/__init__.py +13 -0
  396. rlinf/workers/rollout/server/online_router_worker.py +259 -0
  397. rlinf/workers/rollout/server/server_rollout_worker.py +378 -0
  398. rlinf/workers/rollout/sglang/__init__.py +43 -0
  399. rlinf/workers/rollout/sglang/sglang_worker.py +510 -0
  400. rlinf/workers/rollout/utils.py +562 -0
  401. rlinf/workers/rollout/vllm/__init__.py +41 -0
  402. rlinf/workers/rollout/vllm/vllm_worker.py +508 -0
  403. rlinf/workers/sft/__init__.py +13 -0
  404. rlinf/workers/sft/fsdp_sft_worker.py +218 -0
  405. rlinf/workers/sft/fsdp_vla_sft_worker.py +74 -0
  406. rlinf/workers/sft/fsdp_vlm_sft_worker.py +317 -0
  407. rlinf-0.2.dist-info/METADATA +640 -0
  408. rlinf-0.2.dist-info/RECORD +411 -0
  409. rlinf-0.2.dist-info/WHEEL +5 -0
  410. rlinf-0.2.dist-info/licenses/LICENSE +201 -0
  411. rlinf-0.2.dist-info/top_level.txt +1 -0
rlinf/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .utils.omega_resolver import omegaconf_register
16
+
17
+ omegaconf_register()
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,18 @@
1
+ # Copyright 2026 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from rlinf.agents.mas_search.mas_search_agent_loop import MasSearchAgentLoopWorker
16
+ from rlinf.agents.searchr1.search_tool_worker import SearchToolWorker
17
+
18
# Public names re-exported by this package; both are imported just above.
__all__ = ["MasSearchAgentLoopWorker", "SearchToolWorker"]
@@ -0,0 +1,222 @@
1
+ # Copyright 2026 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import copy
17
+ import json
18
+ import re
19
+ from uuid import uuid4
20
+
21
+ from omegaconf import DictConfig
22
+
23
+ from rlinf.data.tool_call.tool_io_struct import (
24
+ ToolChannelRequest,
25
+ ToolChannelResponse,
26
+ ToolRequest,
27
+ ToolResponse,
28
+ )
29
+ from rlinf.scheduler import Channel
30
+ from rlinf.utils.placement import ModelParallelComponentPlacement
31
+ from rlinf.workers.agent.agent_loop import (
32
+ AgentLoopOutput,
33
+ MultiAgentLoopOutput,
34
+ MultiAgentLoopWorker,
35
+ )
36
+
37
+
38
class MasSearchAgentLoopWorker(MultiAgentLoopWorker):
    """Agent loop worker for multi-agent search.

    Combines Search-R1's ``<search>keyword</search>`` tool-call extraction with
    the multi-agent-system component structure of ``MultiAgentLoopWorker``.
    Each query runs up to ``max_turns`` rounds of
    "generate -> extract search call -> run search tool -> append tool output".
    """

    def __init__(
        self,
        cfg: DictConfig,
        placement: ModelParallelComponentPlacement,
    ):
        super().__init__(cfg, placement)
        self.max_prompt_len = int(self.cfg.data.max_prompt_length)
        max_total_len = int(self.cfg.actor.model.encoder_seq_length)
        # Token budget shared by all generated and tool-returned tokens of one
        # query; clamped to at least 1 so the first turn can always generate.
        self.max_resp_len = max(1, max_total_len - self.max_prompt_len)

        # Search tool-call delimiters.
        self.tool_call_start_token: str = "<search>"
        self.tool_call_end_token: str = "</search>"
        self.tool_call_regex = re.compile(r"<search>(.*?)</search>", re.DOTALL)

        # Maximum number of generate/tool rounds per query.
        self.max_turns = self.cfg.agentloop.get("max_turns", 5)
        # Generation logprobs are kept only when
        # ``algorithm.recompute_logprobs`` is disabled.
        self.return_logprobs = not cfg.algorithm.recompute_logprobs

    async def state_less_tool_call_with_channel(
        self,
        input_channel: Channel,
        output_channel: Channel,
        tool_name: str,
        tool_args: dict,
    ) -> ToolChannelResponse:
        """Issue one stateless tool call over channels and await its reply.

        A fresh session id is generated for every call. The "execute" request
        is put on ``input_channel``, and the reply is fetched from
        ``output_channel`` with that session id.
        """
        session_id = uuid4().hex
        await input_channel.put(
            ToolChannelRequest(
                session_id=session_id,
                request_type="execute",
                tool_name=tool_name,
                tool_args=tool_args,
            ),
            async_op=True,
        ).async_wait()
        return await output_channel.get(session_id, async_op=True).async_wait()

    async def tool_call(self, tool_request: ToolRequest) -> ToolResponse:
        """Run ``tool_request`` on its tool worker and return the result as text.

        List/dict results are JSON-encoded; any other result is passed through
        ``str``.

        Raises:
            RuntimeError: if the tool worker reports ``success=False``.
        """
        tool_name, tool_args = tool_request.name, tool_request.arguments
        tool_channel_info = self.tool_channel_info_map[self.tool_name_map[tool_name]]
        channel_response = await self.state_less_tool_call_with_channel(
            tool_channel_info.input_channel,
            self.tool_worker_output_channel,
            tool_name,
            tool_args,
        )

        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would let a failed call pass silently.
        if not channel_response.success:
            raise RuntimeError(f"Tool call {tool_name!r} failed.")
        if isinstance(channel_response.result, (list, dict)):
            result_text = json.dumps(channel_response.result)
        else:
            result_text = str(channel_response.result)
        return ToolResponse(text=result_text)

    async def extract_tool_calls(self, response_text) -> tuple[str, list[ToolRequest]]:
        """Extract a search call in ``<search>keyword</search>`` format.

        Only the last ``<search>...</search>`` span becomes a
        ``ToolRequest(name="search", arguments={"keyword": ...})``, with the
        keyword stripped of surrounding whitespace.

        Returns:
            ``(content, calls)``. ``content`` is the text with every
            ``<search>...</search>`` span removed, and ``calls`` holds at most
            one request. If either delimiter is missing, the text is returned
            unchanged with an empty list.
        """
        if (
            self.tool_call_start_token not in response_text
            or self.tool_call_end_token not in response_text
        ):
            return response_text, []
        matches = self.tool_call_regex.findall(response_text)
        function_calls = []
        if matches:
            keyword = matches[-1].strip()
            function_calls.append(
                ToolRequest(name="search", arguments={"keyword": keyword})
            )

        # Remaining text with the tool-call spans stripped out.
        content = self.tool_call_regex.sub("", response_text)

        return content, function_calls

    async def run_one_query(self, prompt_ids: list[int], *, answer) -> AgentLoopOutput:
        """Run the multi-turn search loop for one prompt and score it.

        Each turn generates within the remaining token budget and records a
        per-turn ``AgentLoopOutput``. If the response contains a ``<search>``
        call, the tool is run and its tokenized output is appended to the
        context. The loop stops when any of these happens:

        - the response has no tool call;
        - generation used up its budget;
        - a tool output would not fit in the budget;
        - the budget is exhausted;
        - ``max_turns`` rounds have run.

        The searchr1 ``compute_score`` of all generated and tool tokens, decoded
        together, is assigned as ``reward_score`` to every per-turn output.

        Args:
            prompt_ids: Token ids of the initial prompt. The list is not modified.
            answer: Reference answer forwarded to ``compute_score``.

        Returns:
            A ``MultiAgentLoopOutput`` with one entry per generation turn, plus
            the collected trace entries when ``self.print_outputs`` is set.
        """
        # Work on a private copy; the ``+=`` below would otherwise extend the
        # caller's list in place.
        prompt_ids = copy.deepcopy(prompt_ids)
        orig_prompt_len = len(prompt_ids)
        output_buffer = []
        trace_prints = []
        response_mask = []
        all_response_ids = []
        for _ in range(self.max_turns):
            # Remaining token budget for this query.
            max_resp_len = self.max_resp_len - (len(prompt_ids) - orig_prompt_len)
            if max_resp_len <= 0:
                # A previous tool output consumed the whole budget. Generating
                # with max_new_tokens=0 would only record an empty turn.
                break

            generate_result = await self.generate(
                prompt_ids, sampling_params={"max_new_tokens": max_resp_len}
            )
            generate_prompt_ids = copy.deepcopy(prompt_ids)
            response_ids = generate_result["output_ids"]
            if len(response_ids) > max_resp_len:
                response_ids = response_ids[:max_resp_len]
            response_text = self.tokenizer.decode(response_ids)

            output_buffer.append(
                AgentLoopOutput(
                    prompt_ids=copy.deepcopy(prompt_ids),
                    response_ids=copy.deepcopy(response_ids),
                    prompt_text=self.tokenizer.decode(prompt_ids),
                    response_text=response_text,
                    # NOTE(review): this passes the shared ``response_mask``
                    # list, which keeps growing in later turns. Every per-turn
                    # output therefore ends up referencing the final full mask.
                    # Confirm whether a per-turn snapshot was intended.
                    response_mask=response_mask,
                    response_logprobs=generate_result["logprobs"]
                    if self.return_logprobs
                    else None,
                )
            )

            prompt_ids += response_ids
            all_response_ids.extend(response_ids)
            # Generated tokens get mask 1; tool tokens below get mask 0.
            response_mask += [1] * len(response_ids)

            # Generation used up the whole remaining budget.
            if len(response_ids) == max_resp_len:
                break

            _, function_calls = await self.extract_tool_calls(response_text)
            if not function_calls:
                break

            # Execute all requested tools concurrently.
            tool_responses: list[ToolResponse] = await asyncio.gather(
                *(self.tool_call(tool_request) for tool_request in function_calls)
            )

            tool_messages = [
                {"role": "tool", "content": tool_response.text}
                for tool_response in tool_responses
            ]
            # extract_tool_calls yields at most one call, so only the first
            # tool message is appended to the context.
            tool_response_ids = self.tokenizer.encode(
                tool_messages[0]["content"], add_special_tokens=False
            )
            max_tool_resp_len = self.max_resp_len - (len(prompt_ids) - orig_prompt_len)
            if len(tool_response_ids) > max_tool_resp_len:
                # The tool output does not fit in the remaining budget.
                break
            prompt_ids += tool_response_ids
            all_response_ids.extend(tool_response_ids)
            response_mask += [0] * len(tool_response_ids)

            if self.print_outputs:
                # Add anything else worth printing here.
                trace_prints.append(
                    {
                        "decode_prompt": self.tokenizer.decode(generate_prompt_ids),
                        "generate": response_text,
                        "tool_resp": tool_messages,
                    }
                )

        # Build the complete response text from all turns.
        complete_response_text = self.tokenizer.decode(all_response_ids)

        # Score the full trajectory with searchr1's compute_score (imported
        # locally, as in the original code).
        from rlinf.algorithms.rewards.searchr1 import compute_score

        reward_score = compute_score(complete_response_text, answer)

        for single_turn_output in output_buffer:
            single_turn_output.reward_score = reward_score

        return MultiAgentLoopOutput(
            single_turn_outputs=output_buffer,
            trace_prints=trace_prints,
        )
@@ -0,0 +1,13 @@
1
+ # Copyright 2026 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,331 @@
1
+ # Copyright 2026 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
import asyncio
import base64
import time
from typing import Awaitable, Callable, Optional

from rlinf.data.tool_call.tool_io_struct import ToolChannelRequest, ToolChannelResponse
20
+
21
+
22
class ToolBase:
    """Common interface shared by all tool implementations.

    Concrete tools set ``name`` and override ``execute``, ``tool_schema`` and
    ``validate``; every method on this base simply raises
    ``NotImplementedError``.
    """

    # Identifier of the tool; subclasses replace this with their own name.
    name = None

    def __init__(self, cfg):
        # Stored so subclasses can read their settings from it.
        self.cfg = cfg

    async def execute(
        self, request: ToolChannelRequest, **kwargs
    ) -> ToolChannelResponse:
        """Run the tool for ``request``. Subclasses must override this."""
        raise NotImplementedError

    def tool_schema(self) -> dict:
        """Describe the tool as a JSON Schema (name, description, argument types).

        Ref: https://huggingface.co/docs/transformers/en/chat_extras#json-schemas
        """
        raise NotImplementedError

    def validate(self, request: ToolChannelRequest) -> Optional[ToolChannelResponse]:
        """Check that ``request`` targets this tool and its arguments fit the schema.

        Implementations return ``None`` for a valid request, otherwise an
        error response to hand back to the caller.
        """
        raise NotImplementedError
44
+
45
+
46
class CodeJudgeToolBase(ToolBase):
    """Base for tools that submit code to the code-judge HTTP service.

    The endpoint is assembled from ``cfg.tools.codejudge.host_addr`` and
    ``cfg.tools.codejudge.host_port``.
    """

    name = None

    def __init__(self, cfg):
        super().__init__(cfg=cfg)
        judge_cfg = self.cfg.tools.codejudge
        self.url = f"http://{judge_cfg.host_addr}:{judge_cfg.host_port}/run/long-batch"

    def _postprocess(self, result):
        """Render one code-judge result dict as a ``ToolChannelResponse``.

        Reads ``run_success``/``success`` to decide the outcome, ``reason``
        only on failure, then ``stdout``/``stderr`` (included when non-empty)
        and ``cost`` (seconds, formatted with two decimals).
        """
        succeeded = bool(result["run_success"] and result["success"])

        lines = ["Tool call success" if succeeded else "Tool call failure"]
        if not succeeded:
            lines.append(f"reason: {result['reason']}")
        for stream in ("stdout", "stderr"):
            # Empty streams are omitted from the report.
            if result[stream]:
                lines.append(f"{stream}: {result[stream]}")
        lines.append(f"execution time: {result['cost']:.2f}s")

        return ToolChannelResponse(success=succeeded, result="\n".join(lines))
75
+
76
+
77
# Python source prepended to every submission sent to the code-judge service.
# It defines a PersistentExecutor that runs user code in one shared globals
# dict, echoes the repr of a trailing expression (REPL-style) and, outside
# replay mode, reports failures on stderr and exits the process with status 1.
code_template_setup = '''
import os
import base64
import sys
import ast
import traceback
from typing import Optional, Any
import linecache
from types import CodeType
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO

class CodeExecutionError(Exception):
    """Custom exception for code execution errors with line information"""
    def __init__(self, original_error: Exception, code: str, line_offset: int = 0):
        self.original_error = original_error
        self.code = code
        self.line_offset = line_offset

        # Get error line number relative to `code`. A SyntaxError may carry
        # lineno=None, in which case fall back to the innermost traceback frame.
        local_lineno = getattr(original_error, 'lineno', None)
        if local_lineno is None:
            tb = getattr(original_error, '__traceback__', None)
            if tb:
                while tb.tb_next:
                    tb = tb.tb_next
                local_lineno = tb.tb_lineno
        if local_lineno is None:
            local_lineno = -1

        # The reported line number is shifted by the segment offset, but the
        # offending line must be looked up in `code` by its local number.
        self.lineno = local_lineno
        if self.lineno != -1:
            self.lineno += line_offset

        # Format error message
        error_type = type(original_error).__name__
        error_msg = str(original_error)

        if self.lineno != -1:
            # Get the problematic line
            lines = code.splitlines()
            if 0 <= local_lineno - 1 < len(lines):
                error_line = lines[local_lineno - 1]
                # Create error message with line information
                super().__init__(f"{error_type} at line {self.lineno}: {error_msg}\\n {error_line}")
                return

        super().__init__(f"{error_type}: {error_msg}")

class PersistentExecutor:
    def __init__(self):
        self.exec_globals = {
            '__name__': '__main__',
            '__file__': '<string>',
            '__builtins__': __builtins__
        }

    def split_code(self, code: str) -> tuple[str, Optional[str]]:
        """
        Intelligently split code into main body and last expression

        Args:
            code: The source code string

        Returns:
            tuple[str, Optional[str]]: (main code body, last expression if exists)
        """
        try:
            # Parse code into AST
            tree = ast.parse(code)
            if not tree.body:
                return code, None

            # Check if the last node is an expression statement (calls included)
            last_node = tree.body[-1]
            if isinstance(last_node, ast.Expr):
                # Get the line range of the last expression
                last_expr_start = last_node.lineno
                last_expr_end = last_node.end_lineno if hasattr(last_node, 'end_lineno') else last_node.lineno

                # Split the code
                lines = code.splitlines()
                main_code = '\\n'.join(lines[:last_expr_start-1])
                last_expr = '\\n'.join(lines[last_expr_start-1:last_expr_end])
                return main_code, last_expr
        except SyntaxError as e:
            raise CodeExecutionError(e, code)
        return code, None

    def execute_code(self, code: str, replay_history_code: bool) -> None:
        """
        Execute code while maintaining persistent environment state.
        If the last line is an expression, its value will be printed to stdout.

        Args:
            code: The source code string to execute
            replay_history_code: If True, suppress stdout and stderr output
        """
        try:
            # Split code intelligently
            main_code, last_expr = self.split_code(code)

            # Set up output redirection if replay_history_code is True
            if replay_history_code:
                stdout_capture = StringIO()
                stderr_capture = StringIO()
                stdout_context = redirect_stdout(stdout_capture)
                stderr_context = redirect_stderr(stderr_capture)
            else:
                stdout_context = redirect_stdout(sys.stdout)
                stderr_context = redirect_stderr(sys.stderr)

            # Execute main code body
            if main_code:
                try:
                    # Compile code to get better error line numbers
                    compiled_code = compile(main_code, '<string>', 'exec')
                    with stdout_context, stderr_context:
                        exec(compiled_code, self.exec_globals)
                except Exception as e:
                    raise CodeExecutionError(e, main_code)

            # If there's a last expression, evaluate it and print its value
            if last_expr:
                line_offset = len(main_code.splitlines()) if main_code else 0

                # Choose eval vs exec at compile time only: falling back after a
                # runtime error would run the expression (and its side effects)
                # a second time.
                try:
                    compiled_last = compile(last_expr, '<string>', 'eval')
                    is_expression = True
                except SyntaxError:
                    try:
                        compiled_last = compile(last_expr, '<string>', 'exec')
                    except SyntaxError as e:
                        raise CodeExecutionError(e, last_expr, line_offset)
                    is_expression = False

                try:
                    with stdout_context, stderr_context:
                        if is_expression:
                            last_value = eval(compiled_last, self.exec_globals)
                        else:
                            exec(compiled_last, self.exec_globals)
                            last_value = None
                except Exception as e:
                    raise CodeExecutionError(e, last_expr, line_offset)

                # Only print the result if not in replay mode
                if last_value is not None and not replay_history_code:
                    print(repr(last_value), file=sys.stdout)

        except Exception as e:
            if replay_history_code:
                return
            if isinstance(e, CodeExecutionError):
                print(str(e), file=sys.stderr)
            else:
                traceback.print_exc(file=sys.stderr)
            # os._exit skips interpreter shutdown (no stdio flushing), so flush
            # explicitly or output the user code already printed may be lost.
            sys.stdout.flush()
            sys.stderr.flush()
            os._exit(1)
        return

persistent_executor = PersistentExecutor()
'''


# Appended after code_template_setup: decodes the base64 user code (first
# placeholder) and runs it with replay_history_code set to the second
# placeholder ("True"/"False").
code_template_exec = """
code_to_execute = base64.b64decode("{}".encode()).decode()
persistent_executor.execute_code(code_to_execute, replay_history_code={})
"""
240
+
241
+
242
class PythonTool(CodeJudgeToolBase):
    """Tool that runs model-written Python code on the code-judge service.

    The submitted program is ``code_template_setup`` followed by one
    ``code_template_exec`` call carrying the base64-encoded user code; the
    ``input`` argument is forwarded as the submission's ``input`` field.
    """

    name = "python_code_with_standard_io"

    # Attempts made to reach the code-judge service before giving up.
    _MAX_SEND_ATTEMPTS = 4
    # Seconds to wait between failed attempts.
    _RETRY_DELAY_S = 1

    def __init__(self, cfg):
        super().__init__(cfg=cfg)

    async def execute(
        self,
        request: ToolChannelRequest,
        send_request_func: Callable[[str, dict], Awaitable[dict]],
    ) -> ToolChannelResponse:
        """Validate ``request``, run its code remotely and format the result.

        Args:
            request: Tool call whose ``tool_args`` may hold ``code`` and ``input``.
            send_request_func: Coroutine function ``(url, payload)`` whose awaited
                result is a mapping containing ``"results"``.

        Returns:
            The validation error response, or the post-processed judge result.

        Raises:
            RuntimeError: If every attempt to reach the service fails; the
                last underlying exception is chained as ``__cause__``.
        """
        err_msg = self.validate(request)
        if err_msg is not None:
            return err_msg

        # Base64 keeps arbitrary user code (quotes, backslashes, newlines) safe
        # to embed in the string literal inside code_template_exec.
        code_to_execute = base64.b64encode(
            request.tool_args.get("code", "").encode()
        ).decode()
        final_code = code_template_setup
        # TODO: add history code here
        final_code += code_template_exec.format(code_to_execute, "False")

        submission = {
            "type": "python",
            "solution": final_code,
            "input": request.tool_args.get("input", ""),
        }

        data = {"type": "batch", "submissions": [submission]}

        last_error = None
        for retry_time in range(self._MAX_SEND_ATTEMPTS):
            try:
                results = (await send_request_func(self.url, data))["results"]
                break
            except Exception as e:
                last_error = e
                print(f"Tool retry time {retry_time}, exception: {e}")
                if retry_time < self._MAX_SEND_ATTEMPTS - 1:
                    # Non-blocking wait: time.sleep would stall the event loop
                    # and every other in-flight tool call with it.
                    await asyncio.sleep(self._RETRY_DELAY_S)
        else:
            raise RuntimeError("Tool call failed after retries") from last_error
        assert len(results) == 1, f"{results}"
        return self._postprocess(results[0])

    def tool_schema(self) -> dict:
        """Return the function-calling schema advertised to the model."""
        return {
            "type": "function",
            "function": {
                "name": "python_code_with_standard_io",
                "description": "Execute Python code with standard input and capture standard output. This function takes a Python code string and an input string, provides the input string through standard input (stdin) to the code, and captures and returns any output produced through standard output (stdout). If the executed code raises an exception, the error message will be captured and returned instead.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "A string containing Python code to be executed. The code can read from standard input using the input() function.",
                        },
                        "input": {
                            "type": "string",
                            "description": "A string that will be provided as standard input to the code when it calls input().",
                        },
                    },
                    "required": ["code", "input"],
                },
            },
        }

    def validate(self, request) -> Optional[ToolChannelResponse]:
        """Return an error response for malformed arguments, else ``None``.

        Raises:
            AssertionError: If ``request`` is addressed to a different tool.
        """
        tool_args = request.tool_args

        assert request.tool_name == self.name, (
            f"Name mismatch, {self.name} != {request.tool_name}"
        )

        # NOTE(review): these messages do not describe the validation failure;
        # they look like a deliberate copy of a server retry-failure message —
        # confirm before rewording.
        if not isinstance(tool_args, dict):
            return ToolChannelResponse(
                success=False,
                result="Error when executing tool: run_tool_calls_on_server_async failed for1 tool calls after 4 attempts.",
            )

        if "code" in tool_args and not isinstance(tool_args["code"], str):
            return ToolChannelResponse(
                success=False,
                result="Error when executing tool: run_tool_calls_on_server_async failed for1 tool calls after 4 attempts.",
            )

        if "input" in tool_args and not isinstance(tool_args["input"], str):
            return ToolChannelResponse(
                success=False,
                result="Error when executing tool: run_tool_calls_on_server_async failed for1 tool calls after 4 attempts.",
            )