rlinf 0.2.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (356)
  1. rlinf/__init__.py +17 -0
  2. rlinf/agents/__init__.py +13 -0
  3. rlinf/agents/multiturn_demo/__init__.py +13 -0
  4. rlinf/agents/multiturn_demo/fake_tool_worker.py +64 -0
  5. rlinf/agents/multiturn_demo/mcp_agent_loop.py +221 -0
  6. rlinf/agents/multiturn_demo/mcp_filesystem_worker.py +577 -0
  7. rlinf/agents/multiturn_demo/tool_agent_loop.py +163 -0
  8. rlinf/agents/searchr1/__init__.py +13 -0
  9. rlinf/agents/searchr1/search_tool_worker.py +152 -0
  10. rlinf/agents/searchr1/searchr1_agent_loop.py +193 -0
  11. rlinf/algorithms/__init__.py +14 -0
  12. rlinf/algorithms/advantages.py +186 -0
  13. rlinf/algorithms/losses.py +277 -0
  14. rlinf/algorithms/registry.py +118 -0
  15. rlinf/algorithms/rewards/__init__.py +36 -0
  16. rlinf/algorithms/rewards/code/__init__.py +33 -0
  17. rlinf/algorithms/rewards/code/code_verifier/__init__.py +13 -0
  18. rlinf/algorithms/rewards/code/code_verifier/verify.py +230 -0
  19. rlinf/algorithms/rewards/math/__init__.py +42 -0
  20. rlinf/algorithms/rewards/math/math_verifier/__init__.py +13 -0
  21. rlinf/algorithms/rewards/math/math_verifier/parser.py +441 -0
  22. rlinf/algorithms/rewards/math/math_verifier/verify.py +441 -0
  23. rlinf/algorithms/rewards/searchr1/__init__.py +161 -0
  24. rlinf/algorithms/rewards/vqa/__init__.py +60 -0
  25. rlinf/algorithms/rewards/vqa/format_rewards.py +66 -0
  26. rlinf/algorithms/rewards/vqa/qa_rewards.py +109 -0
  27. rlinf/algorithms/utils.py +359 -0
  28. rlinf/config.py +1255 -0
  29. rlinf/data/__init__.py +13 -0
  30. rlinf/data/datasets/__init__.py +135 -0
  31. rlinf/data/datasets/item.py +45 -0
  32. rlinf/data/datasets/math.py +228 -0
  33. rlinf/data/datasets/vlm.py +468 -0
  34. rlinf/data/datasets/world_model.py +365 -0
  35. rlinf/data/embodied_io_struct.py +521 -0
  36. rlinf/data/io_struct.py +1139 -0
  37. rlinf/data/replay_buffer.py +1169 -0
  38. rlinf/data/tokenizers.py +71 -0
  39. rlinf/data/tool_call/__init__.py +13 -0
  40. rlinf/data/tool_call/tool_io_struct.py +122 -0
  41. rlinf/data/utils.py +53 -0
  42. rlinf/envs/__init__.py +107 -0
  43. rlinf/envs/action_utils.py +220 -0
  44. rlinf/envs/behavior/__init__.py +13 -0
  45. rlinf/envs/behavior/behavior_env.py +312 -0
  46. rlinf/envs/calvin/__init__.py +129 -0
  47. rlinf/envs/calvin/calvin_gym_env.py +486 -0
  48. rlinf/envs/calvin/utils.py +75 -0
  49. rlinf/envs/calvin/venv.py +264 -0
  50. rlinf/envs/frankasim/__init__.py +19 -0
  51. rlinf/envs/frankasim/frankasim_env.py +722 -0
  52. rlinf/envs/habitat/__init__.py +17 -0
  53. rlinf/envs/habitat/extensions/__init__.py +13 -0
  54. rlinf/envs/habitat/extensions/config/vlnce_r2r.yaml +68 -0
  55. rlinf/envs/habitat/extensions/maps.py +357 -0
  56. rlinf/envs/habitat/extensions/utils.py +711 -0
  57. rlinf/envs/habitat/habitat_env.py +348 -0
  58. rlinf/envs/habitat/venv.py +246 -0
  59. rlinf/envs/isaaclab/__init__.py +21 -0
  60. rlinf/envs/isaaclab/isaaclab_env.py +264 -0
  61. rlinf/envs/isaaclab/tasks/__init__.py +13 -0
  62. rlinf/envs/isaaclab/tasks/stack_cube.py +97 -0
  63. rlinf/envs/isaaclab/utils.py +62 -0
  64. rlinf/envs/isaaclab/venv.py +118 -0
  65. rlinf/envs/libero/__init__.py +13 -0
  66. rlinf/envs/libero/libero_env.py +459 -0
  67. rlinf/envs/libero/utils.py +125 -0
  68. rlinf/envs/libero/venv.py +173 -0
  69. rlinf/envs/maniskill/__init__.py +33 -0
  70. rlinf/envs/maniskill/maniskill_env.py +406 -0
  71. rlinf/envs/maniskill/tasks/__init__.py +13 -0
  72. rlinf/envs/maniskill/tasks/put_carrot_on_plate.py +151 -0
  73. rlinf/envs/maniskill/tasks/put_on_in_scene_multi.py +963 -0
  74. rlinf/envs/maniskill/tasks/variants/__init__.py +51 -0
  75. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_carrot.py +77 -0
  76. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_ee_pose.py +289 -0
  77. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_image.py +105 -0
  78. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_instruct.py +123 -0
  79. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_multi_carrot.py +352 -0
  80. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_multi_plate.py +399 -0
  81. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_plate.py +97 -0
  82. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_position.py +231 -0
  83. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_position_change.py +179 -0
  84. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_single.py +110 -0
  85. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_image.py +49 -0
  86. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_texture.py +325 -0
  87. rlinf/envs/maniskill/tasks/variants/put_on_plate_25_vision_whole.py +364 -0
  88. rlinf/envs/maniskill/tasks/variants/utils.py +31 -0
  89. rlinf/envs/maniskill/utils.py +41 -0
  90. rlinf/envs/metaworld/__init__.py +61 -0
  91. rlinf/envs/metaworld/metaworld_env.py +442 -0
  92. rlinf/envs/metaworld/utils.py +21 -0
  93. rlinf/envs/metaworld/venv.py +170 -0
  94. rlinf/envs/realworld/__init__.py +26 -0
  95. rlinf/envs/realworld/common/camera/__init__.py +17 -0
  96. rlinf/envs/realworld/common/camera/camera.py +143 -0
  97. rlinf/envs/realworld/common/keyboard/__init__.py +13 -0
  98. rlinf/envs/realworld/common/keyboard/keyboard_listener.py +42 -0
  99. rlinf/envs/realworld/common/ros/__init__.py +17 -0
  100. rlinf/envs/realworld/common/ros/ros_controller.py +129 -0
  101. rlinf/envs/realworld/common/spacemouse/__init__.py +13 -0
  102. rlinf/envs/realworld/common/spacemouse/spacemouse_expert.py +74 -0
  103. rlinf/envs/realworld/common/video_player/__init__.py +17 -0
  104. rlinf/envs/realworld/common/video_player/video_player.py +55 -0
  105. rlinf/envs/realworld/common/wrappers/__init__.py +31 -0
  106. rlinf/envs/realworld/common/wrappers/euler_obs.py +39 -0
  107. rlinf/envs/realworld/common/wrappers/gripper_close.py +41 -0
  108. rlinf/envs/realworld/common/wrappers/relative_frame.py +140 -0
  109. rlinf/envs/realworld/common/wrappers/reward_done_wrapper.py +105 -0
  110. rlinf/envs/realworld/common/wrappers/spacemouse_intervention.py +72 -0
  111. rlinf/envs/realworld/franka/__init__.py +18 -0
  112. rlinf/envs/realworld/franka/franka_controller.py +375 -0
  113. rlinf/envs/realworld/franka/franka_env.py +544 -0
  114. rlinf/envs/realworld/franka/franka_robot_state.py +48 -0
  115. rlinf/envs/realworld/franka/tasks/__init__.py +35 -0
  116. rlinf/envs/realworld/franka/tasks/bottle.py +136 -0
  117. rlinf/envs/realworld/franka/tasks/franka_bin_relocation.py +240 -0
  118. rlinf/envs/realworld/franka/tasks/peg_insertion_env.py +129 -0
  119. rlinf/envs/realworld/franka/utils.py +105 -0
  120. rlinf/envs/realworld/realworld_env.py +395 -0
  121. rlinf/envs/realworld/venv.py +319 -0
  122. rlinf/envs/robocasa/__init__.py +17 -0
  123. rlinf/envs/robocasa/robocasa_env.py +509 -0
  124. rlinf/envs/robocasa/utils.py +178 -0
  125. rlinf/envs/robocasa/venv.py +163 -0
  126. rlinf/envs/robotwin/__init__.py +13 -0
  127. rlinf/envs/robotwin/robotwin_env.py +493 -0
  128. rlinf/envs/utils.py +294 -0
  129. rlinf/envs/venv/__init__.py +33 -0
  130. rlinf/envs/venv/venv.py +985 -0
  131. rlinf/envs/world_model/__init__.py +13 -0
  132. rlinf/envs/world_model/base_world_env.py +155 -0
  133. rlinf/envs/world_model/world_model_opensora_env.py +801 -0
  134. rlinf/envs/wrappers/__init__.py +17 -0
  135. rlinf/envs/wrappers/record_video.py +377 -0
  136. rlinf/hybrid_engines/__init__.py +13 -0
  137. rlinf/hybrid_engines/fsdp/__init__.py +49 -0
  138. rlinf/hybrid_engines/fsdp/fsdp_model_manager.py +605 -0
  139. rlinf/hybrid_engines/fsdp/strategy/__init__.py +13 -0
  140. rlinf/hybrid_engines/fsdp/strategy/base.py +515 -0
  141. rlinf/hybrid_engines/fsdp/strategy/checkpoint.py +132 -0
  142. rlinf/hybrid_engines/fsdp/strategy/fsdp.py +350 -0
  143. rlinf/hybrid_engines/fsdp/strategy/fsdp2.py +203 -0
  144. rlinf/hybrid_engines/fsdp/utils.py +829 -0
  145. rlinf/hybrid_engines/megatron/__init__.py +13 -0
  146. rlinf/hybrid_engines/megatron/megatron_model_manager.py +667 -0
  147. rlinf/hybrid_engines/megatron/utils.py +232 -0
  148. rlinf/hybrid_engines/sglang/common/__init__.py +13 -0
  149. rlinf/hybrid_engines/sglang/common/detokenizer_manager.py +57 -0
  150. rlinf/hybrid_engines/sglang/common/io_struct.py +52 -0
  151. rlinf/hybrid_engines/sglang/common/sgl_engine.py +138 -0
  152. rlinf/hybrid_engines/sglang/common/sgl_scheduler.py +585 -0
  153. rlinf/hybrid_engines/sglang/common/tokenizer_manager.py +236 -0
  154. rlinf/hybrid_engines/vllm/vllm_0_8_5/__init__.py +13 -0
  155. rlinf/hybrid_engines/vllm/vllm_0_8_5/executor.py +313 -0
  156. rlinf/hybrid_engines/vllm/vllm_0_8_5/weight_loader.py +43 -0
  157. rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py +161 -0
  158. rlinf/models/__init__.py +96 -0
  159. rlinf/models/embodiment/__init__.py +13 -0
  160. rlinf/models/embodiment/base_policy.py +70 -0
  161. rlinf/models/embodiment/cnn_policy/__init__.py +26 -0
  162. rlinf/models/embodiment/cnn_policy/cnn_policy.py +458 -0
  163. rlinf/models/embodiment/dexbotic_pi/__init__.py +120 -0
  164. rlinf/models/embodiment/dexbotic_pi/dexbotic_pi_policy.py +772 -0
  165. rlinf/models/embodiment/flow_policy/__init__.py +51 -0
  166. rlinf/models/embodiment/flow_policy/flow_policy.py +635 -0
  167. rlinf/models/embodiment/gr00t/__init__.py +82 -0
  168. rlinf/models/embodiment/gr00t/embodiment_tags.py +59 -0
  169. rlinf/models/embodiment/gr00t/gr00t_action_model.py +732 -0
  170. rlinf/models/embodiment/gr00t/modality_config.py +177 -0
  171. rlinf/models/embodiment/gr00t/simulation_io.py +198 -0
  172. rlinf/models/embodiment/gr00t/utils.py +127 -0
  173. rlinf/models/embodiment/mlp_policy/__init__.py +31 -0
  174. rlinf/models/embodiment/mlp_policy/mlp_policy.py +288 -0
  175. rlinf/models/embodiment/modules/__init__.py +13 -0
  176. rlinf/models/embodiment/modules/batch_renorm.py +132 -0
  177. rlinf/models/embodiment/modules/entropy_tunning.py +69 -0
  178. rlinf/models/embodiment/modules/explore_noise_net.py +168 -0
  179. rlinf/models/embodiment/modules/flow_actor.py +469 -0
  180. rlinf/models/embodiment/modules/mlp.py +95 -0
  181. rlinf/models/embodiment/modules/q_head.py +328 -0
  182. rlinf/models/embodiment/modules/resnet_utils.py +158 -0
  183. rlinf/models/embodiment/modules/utils.py +65 -0
  184. rlinf/models/embodiment/modules/value_head.py +67 -0
  185. rlinf/models/embodiment/openpi/__init__.py +96 -0
  186. rlinf/models/embodiment/openpi/dataconfig/__init__.py +327 -0
  187. rlinf/models/embodiment/openpi/dataconfig/behavior_dataconfig.py +102 -0
  188. rlinf/models/embodiment/openpi/dataconfig/calvin_dataconfig.py +71 -0
  189. rlinf/models/embodiment/openpi/dataconfig/franka_dataconfig.py +101 -0
  190. rlinf/models/embodiment/openpi/dataconfig/gsenv_dataconfig.py +58 -0
  191. rlinf/models/embodiment/openpi/dataconfig/libero_dataconfig.py +102 -0
  192. rlinf/models/embodiment/openpi/dataconfig/maniskill_dataconfig.py +103 -0
  193. rlinf/models/embodiment/openpi/dataconfig/metaworld_dataconfig.py +68 -0
  194. rlinf/models/embodiment/openpi/dataconfig/robocasa_dataconfig.py +75 -0
  195. rlinf/models/embodiment/openpi/dataconfig/robotwin_aloha_dataconfig.py +106 -0
  196. rlinf/models/embodiment/openpi/openpi_action_model.py +789 -0
  197. rlinf/models/embodiment/openpi/policies/__init__.py +13 -0
  198. rlinf/models/embodiment/openpi/policies/aloha_policy.py +227 -0
  199. rlinf/models/embodiment/openpi/policies/behavior_policy.py +119 -0
  200. rlinf/models/embodiment/openpi/policies/calvin_policy.py +87 -0
  201. rlinf/models/embodiment/openpi/policies/franka_policy.py +137 -0
  202. rlinf/models/embodiment/openpi/policies/gsenv_policy.py +74 -0
  203. rlinf/models/embodiment/openpi/policies/libero_policy.py +118 -0
  204. rlinf/models/embodiment/openpi/policies/maniskill_policy.py +114 -0
  205. rlinf/models/embodiment/openpi/policies/metaworld_policy.py +75 -0
  206. rlinf/models/embodiment/openpi/policies/robocasa_policy.py +126 -0
  207. rlinf/models/embodiment/openvla/__init__.py +104 -0
  208. rlinf/models/embodiment/openvla/openvla_action_model.py +808 -0
  209. rlinf/models/embodiment/openvla_oft/__init__.py +32 -0
  210. rlinf/models/embodiment/openvla_oft/official/__init__.py +118 -0
  211. rlinf/models/embodiment/openvla_oft/official/openvla_oft_action_model.py +703 -0
  212. rlinf/models/embodiment/openvla_oft/openvla_utils.py +193 -0
  213. rlinf/models/embodiment/openvla_oft/rlinf/__init__.py +109 -0
  214. rlinf/models/embodiment/openvla_oft/rlinf/openvla_oft_action_model.py +574 -0
  215. rlinf/models/embodiment/prismatic/__init__.py +13 -0
  216. rlinf/models/embodiment/prismatic/processing_prismatic.py +243 -0
  217. rlinf/runners/__init__.py +13 -0
  218. rlinf/runners/agent_eval_runner.py +379 -0
  219. rlinf/runners/agent_runner.py +319 -0
  220. rlinf/runners/async_embodied_runner.py +181 -0
  221. rlinf/runners/coding_online_rl_runner.py +308 -0
  222. rlinf/runners/embodied_eval_runner.py +80 -0
  223. rlinf/runners/embodied_runner.py +291 -0
  224. rlinf/runners/reasoning_eval_runner.py +194 -0
  225. rlinf/runners/reasoning_runner.py +498 -0
  226. rlinf/runners/sft_runner.py +153 -0
  227. rlinf/scheduler/__init__.py +49 -0
  228. rlinf/scheduler/channel/__init__.py +18 -0
  229. rlinf/scheduler/channel/channel.py +645 -0
  230. rlinf/scheduler/channel/channel_worker.py +536 -0
  231. rlinf/scheduler/cluster/__init__.py +29 -0
  232. rlinf/scheduler/cluster/cluster.py +486 -0
  233. rlinf/scheduler/cluster/config.py +442 -0
  234. rlinf/scheduler/cluster/node.py +546 -0
  235. rlinf/scheduler/cluster/utils.py +221 -0
  236. rlinf/scheduler/collective/__init__.py +36 -0
  237. rlinf/scheduler/collective/async_work.py +386 -0
  238. rlinf/scheduler/collective/collective.py +96 -0
  239. rlinf/scheduler/collective/collective_group.py +1761 -0
  240. rlinf/scheduler/collective/multi_channel_pg.py +894 -0
  241. rlinf/scheduler/dynamic_scheduler/__init__.py +13 -0
  242. rlinf/scheduler/dynamic_scheduler/manager.py +1069 -0
  243. rlinf/scheduler/dynamic_scheduler/scheduler_worker.py +129 -0
  244. rlinf/scheduler/dynamic_scheduler/utils.py +162 -0
  245. rlinf/scheduler/hardware/__init__.py +36 -0
  246. rlinf/scheduler/hardware/accelerators/__init__.py +31 -0
  247. rlinf/scheduler/hardware/accelerators/accelerator.py +284 -0
  248. rlinf/scheduler/hardware/accelerators/amd_gpu.py +141 -0
  249. rlinf/scheduler/hardware/accelerators/ascend_npu.py +113 -0
  250. rlinf/scheduler/hardware/accelerators/intel_gpu.py +113 -0
  251. rlinf/scheduler/hardware/accelerators/musa_gpu.py +147 -0
  252. rlinf/scheduler/hardware/accelerators/nvidia_gpu.py +203 -0
  253. rlinf/scheduler/hardware/hardware.py +180 -0
  254. rlinf/scheduler/hardware/robots/__init__.py +17 -0
  255. rlinf/scheduler/hardware/robots/franka.py +176 -0
  256. rlinf/scheduler/manager/__init__.py +32 -0
  257. rlinf/scheduler/manager/coll_manager.py +140 -0
  258. rlinf/scheduler/manager/lock_manager.py +187 -0
  259. rlinf/scheduler/manager/manager.py +123 -0
  260. rlinf/scheduler/manager/node_manager.py +46 -0
  261. rlinf/scheduler/manager/worker_manager.py +315 -0
  262. rlinf/scheduler/placement/__init__.py +27 -0
  263. rlinf/scheduler/placement/flexible.py +277 -0
  264. rlinf/scheduler/placement/node.py +205 -0
  265. rlinf/scheduler/placement/packed.py +335 -0
  266. rlinf/scheduler/placement/placement.py +674 -0
  267. rlinf/scheduler/worker/__init__.py +24 -0
  268. rlinf/scheduler/worker/lock.py +103 -0
  269. rlinf/scheduler/worker/worker.py +1262 -0
  270. rlinf/scheduler/worker/worker_group.py +527 -0
  271. rlinf/utils/__init__.py +13 -0
  272. rlinf/utils/ckpt_convertor/__init__.py +13 -0
  273. rlinf/utils/ckpt_convertor/convert_openpi_jax_to_python.py +706 -0
  274. rlinf/utils/ckpt_convertor/fsdp_convertor/__init__.py +13 -0
  275. rlinf/utils/ckpt_convertor/fsdp_convertor/config/fsdp_model_convertor.yaml +27 -0
  276. rlinf/utils/ckpt_convertor/fsdp_convertor/convert_dcp_to_pt.py +58 -0
  277. rlinf/utils/ckpt_convertor/fsdp_convertor/convert_pt_to_hf.py +81 -0
  278. rlinf/utils/ckpt_convertor/fsdp_convertor/utils.py +197 -0
  279. rlinf/utils/ckpt_convertor/megatron_convertor/__init__.py +13 -0
  280. rlinf/utils/ckpt_convertor/megatron_convertor/config.py +208 -0
  281. rlinf/utils/ckpt_convertor/megatron_convertor/convert_hf_to_mg.py +410 -0
  282. rlinf/utils/ckpt_convertor/megatron_convertor/convert_hf_to_middle_file.py +503 -0
  283. rlinf/utils/ckpt_convertor/megatron_convertor/convert_mg_to_middle_file.py +863 -0
  284. rlinf/utils/ckpt_convertor/megatron_convertor/convert_middle_file_to_hf.py +726 -0
  285. rlinf/utils/ckpt_convertor/megatron_convertor/convert_middle_file_to_mg.py +626 -0
  286. rlinf/utils/ckpt_convertor/megatron_convertor/default_args.yaml +144 -0
  287. rlinf/utils/ckpt_convertor/megatron_convertor/utils/__init__.py +31 -0
  288. rlinf/utils/ckpt_convertor/megatron_convertor/utils/fp8_utils.py +135 -0
  289. rlinf/utils/ckpt_convertor/megatron_convertor/utils/mg_loader.py +171 -0
  290. rlinf/utils/ckpt_convertor/megatron_convertor/utils/mg_moe_groupgemm.py +198 -0
  291. rlinf/utils/ckpt_convertor/megatron_convertor/utils/mp_utils.py +61 -0
  292. rlinf/utils/ckpt_convertor/megatron_convertor/utils/safetensors_loader.py +116 -0
  293. rlinf/utils/ckpt_convertor/megatron_convertor/utils/tensor_operations.py +402 -0
  294. rlinf/utils/convertor/__init__.py +13 -0
  295. rlinf/utils/convertor/utils.py +637 -0
  296. rlinf/utils/data_iter_utils.py +692 -0
  297. rlinf/utils/data_process.py +90 -0
  298. rlinf/utils/distributed.py +1014 -0
  299. rlinf/utils/drq.py +109 -0
  300. rlinf/utils/flops.py +240 -0
  301. rlinf/utils/initialize.py +330 -0
  302. rlinf/utils/logging.py +20 -0
  303. rlinf/utils/metric_logger.py +121 -0
  304. rlinf/utils/metric_utils.py +342 -0
  305. rlinf/utils/nested_dict_process.py +110 -0
  306. rlinf/utils/omega_resolver.py +36 -0
  307. rlinf/utils/patcher.py +217 -0
  308. rlinf/utils/placement.py +449 -0
  309. rlinf/utils/profiler.py +242 -0
  310. rlinf/utils/resharding/__init__.py +13 -0
  311. rlinf/utils/resharding/mcore_weight_reshard.py +335 -0
  312. rlinf/utils/resharding/reshard_config.py +93 -0
  313. rlinf/utils/resharding/utils.py +332 -0
  314. rlinf/utils/runner_utils.py +82 -0
  315. rlinf/utils/timers.py +195 -0
  316. rlinf/utils/torch_functionals.py +32 -0
  317. rlinf/utils/train_utils.py +84 -0
  318. rlinf/utils/utils.py +507 -0
  319. rlinf/workers/__init__.py +13 -0
  320. rlinf/workers/actor/__init__.py +30 -0
  321. rlinf/workers/actor/async_fsdp_sac_policy_worker.py +137 -0
  322. rlinf/workers/actor/fsdp_actor_worker.py +1377 -0
  323. rlinf/workers/actor/fsdp_sac_policy_worker.py +748 -0
  324. rlinf/workers/actor/megatron_actor_worker.py +1543 -0
  325. rlinf/workers/agent/__init__.py +13 -0
  326. rlinf/workers/agent/agent_loop.py +246 -0
  327. rlinf/workers/agent/tool_worker.py +43 -0
  328. rlinf/workers/env/__init__.py +13 -0
  329. rlinf/workers/env/async_env_worker.py +131 -0
  330. rlinf/workers/env/env_worker.py +441 -0
  331. rlinf/workers/inference/__init__.py +13 -0
  332. rlinf/workers/inference/fsdp_inference_worker.py +146 -0
  333. rlinf/workers/inference/megatron_inference_worker.py +105 -0
  334. rlinf/workers/inference/utils.py +48 -0
  335. rlinf/workers/reward/__init__.py +13 -0
  336. rlinf/workers/reward/reward_worker.py +113 -0
  337. rlinf/workers/rollout/__init__.py +13 -0
  338. rlinf/workers/rollout/hf/__init__.py +13 -0
  339. rlinf/workers/rollout/hf/async_huggingface_worker.py +52 -0
  340. rlinf/workers/rollout/hf/huggingface_worker.py +404 -0
  341. rlinf/workers/rollout/hf/utils.py +30 -0
  342. rlinf/workers/rollout/server/__init__.py +13 -0
  343. rlinf/workers/rollout/server/online_router_worker.py +259 -0
  344. rlinf/workers/rollout/server/server_rollout_worker.py +379 -0
  345. rlinf/workers/rollout/sglang/__init__.py +43 -0
  346. rlinf/workers/rollout/sglang/sglang_worker.py +505 -0
  347. rlinf/workers/rollout/utils.py +562 -0
  348. rlinf/workers/rollout/vllm/__init__.py +41 -0
  349. rlinf/workers/rollout/vllm/vllm_worker.py +508 -0
  350. rlinf/workers/sft/__init__.py +13 -0
  351. rlinf/workers/sft/fsdp_sft_worker.py +175 -0
  352. rlinf-0.2.0.dev1.dist-info/METADATA +575 -0
  353. rlinf-0.2.0.dev1.dist-info/RECORD +356 -0
  354. rlinf-0.2.0.dev1.dist-info/WHEEL +5 -0
  355. rlinf-0.2.0.dev1.dist-info/licenses/LICENSE +201 -0
  356. rlinf-0.2.0.dev1.dist-info/top_level.txt +1 -0
rlinf/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .utils.omega_resolver import omegaconf_register
16
+
17
+ omegaconf_register()
rlinf/agents/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
rlinf/agents/multiturn_demo/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
rlinf/agents/multiturn_demo/fake_tool_worker.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import asyncio
17
+
18
+ from omegaconf import DictConfig
19
+
20
+ from rlinf.data.tool_call.tool_io_struct import ToolChannelRequest, ToolChannelResponse
21
+ from rlinf.scheduler import Channel
22
+ from rlinf.workers.agent.tool_worker import ToolWorker
23
+
24
+
25
class FakeToolWorker(ToolWorker):
    """Demo tool worker that answers every ``fake_tool`` request with a canned result.

    Requests are read from ``input_channel``; each response is put on
    ``output_channel`` keyed by the request's ``session_id``. Responses are
    produced by independent tasks, so they may be sent in a different order
    than the requests arrived.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.request_processor_task = None
        # Strong references to in-flight response tasks. The event loop only
        # keeps weak references to tasks, so a task nobody references can be
        # garbage-collected before it finishes sending its response.
        self._pending_tasks: set[asyncio.Task] = set()

    def init_worker(self, input_channel: Channel, output_channel: Channel):
        """Attach the channels requests are read from and responses written to."""
        self.input_channel = input_channel
        self.output_channel = output_channel

    def start_server(self):
        """Start the request-processing loop on the currently running event loop.

        Must be called from inside a running event loop:
        ``asyncio.get_running_loop`` raises ``RuntimeError`` otherwise.
        """
        loop = asyncio.get_running_loop()
        self.request_processor_task = loop.create_task(self._process_requests())

    def stop_server(self):
        """Cancel the request-processing loop if it is still running."""
        # Cancel request processor task
        if self.request_processor_task and not self.request_processor_task.done():
            self.request_processor_task.cancel()

    async def _process_requests(self):
        """Read requests forever and answer each one in a background task."""

        async def generate_and_send(session_id: str, tool_args: dict):
            # tool_args is ignored: this fake tool always returns the same text.
            response = ToolChannelResponse(
                success=True,
                result="fake_tool_response",
            )
            await self.output_channel.put(
                response, key=session_id, async_op=True
            ).async_wait()
            self.logger.info("FakeToolWorker._process_requests: sent response")

        while True:
            request: ToolChannelRequest = await self.input_channel.get(
                async_op=True
            ).async_wait()
            self.logger.info("FakeToolWorker._process_requests: got request")
            assert request.request_type == "execute"
            assert request.tool_name == "fake_tool"
            task = asyncio.create_task(
                generate_and_send(request.session_id, request.tool_args)
            )
            # Keep the task alive until it completes, then forget it.
            self._pending_tasks.add(task)
            task.add_done_callback(self._pending_tasks.discard)
rlinf/agents/multiturn_demo/mcp_agent_loop.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright 2025 The RLinf Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import copy
17
+ import json
18
+ import random
19
+ import time
20
+ from dataclasses import dataclass, field
21
+ from typing import Any
22
+ from uuid import uuid4
23
+
24
+ from omegaconf import DictConfig
25
+
26
+ from rlinf.data.tool_call.tool_io_struct import (
27
+ ToolChannelRequest,
28
+ ToolChannelResponse,
29
+ ToolRequest,
30
+ ToolResponse,
31
+ )
32
+ from rlinf.utils.placement import ModelParallelComponentPlacement
33
+ from rlinf.workers.agent.agent_loop import AgentLoopOutput, AgentLoopWorker
34
+
35
+
36
@dataclass
class GenerateContext:
    """Mutable state shared by all tool calls made while answering one query.

    Instances are created empty by ``MCPAgentLoopWorker.generate_context_create``
    and filled lazily by ``MCPAgentLoopWorker.tool_session_get``.
    """

    # Tool worker name -> session id allocated for that worker during this
    # query. default_factory gives every instance its own fresh dict.
    tool_session_ids: dict[str, str] = field(default_factory=dict)
39
+
40
+
41
class MCPAgentLoopWorker(AgentLoopWorker):
    """
    An agent loop worker that can interact with mcp tools with session.

    Each query gets its own ``GenerateContext``. The first call to a tool
    worker within a query allocates a session id for that worker. When the
    worker's channel info reports ``has_session``, that first call also sends
    a ``session_start`` request, and the session is closed with
    ``session_end`` when the query finishes.

    Tool responses are read from ``self.tool_worker_output_channel`` keyed
    only by session id. A tool worker may answer concurrent requests out of
    order, so every request/response exchange on a session is serialized by
    a per-session ``asyncio.Lock``. This keeps each request paired with its
    own response.
    """

    def __init__(
        self,
        cfg: DictConfig,
        placement: ModelParallelComponentPlacement,
    ):
        super().__init__(cfg, placement)
        self.max_prompt_len = int(self.cfg.data.max_prompt_length)
        max_total_len = int(self.cfg.actor.model.encoder_seq_length)
        # Token budget for everything appended after the prompt (LLM output
        # plus tool responses); clamped to at least 1.
        self.max_resp_len = max(1, max_total_len - self.max_prompt_len)

        # 5 is a magic number in this demo.
        self.max_turns = self.cfg.agentloop.get("max_turns", 5)

        # session id -> lock serializing request/response pairs on that session.
        self._session_locks: dict[str, asyncio.Lock] = {}

    def _session_lock(self, session_id: str) -> asyncio.Lock:
        """Return the lock guarding ``session_id``, creating it on first use."""
        lock = self._session_locks.get(session_id)
        if lock is None:
            lock = self._session_locks[session_id] = asyncio.Lock()
        return lock

    def generate_context_create(self) -> GenerateContext:
        """Create the empty per-query context."""
        return GenerateContext()

    async def generate_context_release(
        self, generate_context: GenerateContext
    ) -> None:
        """Close every session opened for this query and drop its locks."""
        try:
            for (
                tool_worker_name,
                session_id,
            ) in generate_context.tool_session_ids.items():
                if self.tool_channel_info_map[tool_worker_name].has_session:
                    # tool need session
                    await self.tool_session_release(tool_worker_name, session_id)
        finally:
            for session_id in generate_context.tool_session_ids.values():
                self._session_locks.pop(session_id, None)

    async def tool_session_get(
        self, generate_context: GenerateContext, tool_name: str
    ) -> str:
        """Return this query's session id for the worker serving ``tool_name``.

        The id is allocated on first use. If the worker needs sessions, a
        ``session_start`` request is sent and must succeed.
        """
        tool_worker_name = self.tool_name_map[tool_name]
        session_id = generate_context.tool_session_ids.get(tool_worker_name)
        if session_id is not None:
            return session_id
        tool_channel_info = self.tool_channel_info_map[tool_worker_name]
        session_id = uuid4().hex
        # Record the id before any await so concurrent calls in this query
        # reuse it instead of opening a second session.
        generate_context.tool_session_ids[tool_worker_name] = session_id
        if tool_channel_info.has_session:
            # tool need session. Acquiring an uncontended asyncio.Lock does not
            # yield, so concurrent calls queue behind session_start.
            async with self._session_lock(session_id):
                await tool_channel_info.input_channel.put(
                    ToolChannelRequest(
                        session_id=session_id, request_type="session_start"
                    ),
                    async_op=True,
                ).async_wait()
                response: ToolChannelResponse = (
                    await self.tool_worker_output_channel.get(
                        session_id, async_op=True
                    ).async_wait()
                )
            assert response.success
        return session_id

    async def tool_session_release(self, tool_worker_name, session_id) -> None:
        """Send ``session_end`` for ``session_id`` and require it to succeed."""
        tool_channel_info = self.tool_channel_info_map[tool_worker_name]
        # Wait for any in-flight request on this session so its response is
        # not confused with the session_end response.
        async with self._session_lock(session_id):
            await tool_channel_info.input_channel.put(
                ToolChannelRequest(session_id=session_id, request_type="session_end"),
                async_op=True,
            ).async_wait()
            response: ToolChannelResponse = await self.tool_worker_output_channel.get(
                session_id, async_op=True
            ).async_wait()
        assert response.success

    async def tool_call(
        self, generate_context: GenerateContext, tool_request: ToolRequest
    ) -> ToolResponse:
        """Execute one tool request on its session and return the result as text.

        ``list``/``dict`` results are JSON-encoded; anything else goes
        through ``str``.
        """
        tool_name, tool_args = tool_request.name, tool_request.arguments
        tool_channel_info = self.tool_channel_info_map[self.tool_name_map[tool_name]]
        tool_input_channel = tool_channel_info.input_channel
        session_id = await self.tool_session_get(generate_context, tool_name)
        # Serialize with other requests on the same session: responses are
        # matched by session id only.
        async with self._session_lock(session_id):
            await tool_input_channel.put(
                ToolChannelRequest(
                    session_id=session_id,
                    request_type="execute",
                    tool_name=tool_name,
                    tool_args=tool_args,
                ),
                async_op=True,
            ).async_wait()
            response: ToolChannelResponse = await self.tool_worker_output_channel.get(
                session_id, async_op=True
            ).async_wait()
        assert response.success
        if isinstance(response.result, (list, dict)):
            result_text = json.dumps(response.result)
        else:
            result_text = str(response.result)
        return ToolResponse(
            text=result_text,
        )

    async def extract_tool_calls(self, response_text) -> tuple[str, list[ToolRequest]]:
        """Demo stub: ignore the model output and pick one canned tool call at random."""
        # random tool call
        return_function_calls = random.choice(
            [
                [
                    ToolRequest(
                        name="write_file",
                        arguments={
                            "path": "/projects/test/mcp_written.txt",
                            "content": f"Written by mcp at {time.strftime('%Y-%m-%d %H:%M:%S')}",
                        },
                    )
                ],
                [
                    ToolRequest(
                        name="list_directory", arguments={"path": "/projects/test"}
                    )
                ],
            ]
        )

        return response_text, return_function_calls

    async def run_one_query(self, prompt_ids: list[int]) -> AgentLoopOutput:
        """Alternate LLM generation and tool execution for up to ``max_turns`` turns.

        LLM tokens get mask 1 and tool-response tokens get mask 0. The total
        appended length never exceeds ``self.max_resp_len``. Sessions opened
        during the query are always released.
        """
        generate_context: GenerateContext = self.generate_context_create()
        # Slicing copies, so the caller's list is never extended in place.
        prompt_ids = prompt_ids[: self.max_prompt_len]
        orig_prompt_ids = list(prompt_ids)
        trace_prints = []
        response_mask = []
        try:
            for _ in range(self.max_turns):
                # Tokens still available for this query's response.
                max_resp_len = self.max_resp_len - (
                    len(prompt_ids) - len(orig_prompt_ids)
                )
                if max_resp_len <= 0:
                    # Budget already used up by the previous tool response;
                    # any generation would be truncated to nothing.
                    break

                # Generate response from LLM
                generate_result = await self.generate(prompt_ids)
                response_ids = generate_result["output_ids"]
                if len(response_ids) > max_resp_len:
                    response_ids = response_ids[:max_resp_len]
                response_text = self.tokenizer.decode(response_ids)
                prompt_ids += response_ids
                response_mask += [1] * len(response_ids)  # 1 for LLM generated tokens
                if self.print_outputs:
                    # add anything you want to print
                    trace_prints.append({"generate": response_text})
                if len(response_ids) == max_resp_len:
                    break

                # Extract tool calls from response
                _, tool_requests = await self.extract_tool_calls(response_text)
                if not tool_requests:
                    # No tool call: nothing to feed back, the episode is done.
                    break

                # Execute tools in parallel (same-session calls are serialized
                # inside tool_call).
                tool_responses: list[ToolResponse] = await asyncio.gather(
                    *(
                        self.tool_call(generate_context, tool_request)
                        for tool_request in tool_requests
                    )
                )

                # Convert tool responses to messages and tokenize
                tool_messages = [
                    {"role": "tool", "content": tool_response.text}
                    for tool_response in tool_responses
                ]
                tool_response_ids = self.get_tool_response_ids(tool_messages)
                max_tool_resp_len = self.max_resp_len - (
                    len(prompt_ids) - len(orig_prompt_ids)
                )
                if len(tool_response_ids) > max_tool_resp_len:
                    break

                prompt_ids += tool_response_ids
                # 0 for tool response tokens
                response_mask += [0] * len(tool_response_ids)
                if self.print_outputs:
                    # add anything you want to print
                    trace_prints[-1]["tool_resp"] = tool_messages

            # Separate prompt and response
            response_ids = prompt_ids[len(orig_prompt_ids) :]

            return AgentLoopOutput(
                prompt_ids=orig_prompt_ids,
                prompt_text=self.tokenizer.decode(orig_prompt_ids),
                response_ids=response_ids,
                response_text=self.tokenizer.decode(response_ids),
                response_mask=response_mask,
                trace_prints=trace_prints,
            )
        finally:
            await self.generate_context_release(generate_context)