torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/transforms/tools.py
@@ -0,0 +1,1955 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import os
+ import queue
+ import re
+ import subprocess
+ import sys
+ import tempfile
+ import threading
+ import time
+ from collections.abc import Callable, Sequence
+ from dataclasses import dataclass
+ from typing import Any, Protocol, TextIO
+
+ import torch
+
+ from tensordict import lazy_stack, TensorDictBase
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.data.llm import History
+
+ from torchrl.envs import Transform
+ from typing_extensions import TypedDict
+
+
+ # --- Base Class for Tool Transforms ---
+
+
+ class ToolTransformBase(Transform):
+     """Base class for tool transforms that parse and execute tools from LLM output.
+
+     This class handles all the common boilerplate for tool transforms:
+     - History extraction and validation
+     - Batch dimension flattening
+     - Result collection and padding
+     - History extension with tool results
+
+     Subclasses only need to implement:
+     - :meth:`_process_batch_item`: Extract and execute tools from one response
+     - :meth:`_format_result`: Format one tool result as a string (optional)
+
+     Attributes:
+         use_step (bool): Whether to use _step() vs _call(). Defaults to True.
+         tool_role (str): Role name for results in history. Defaults to "tool".
+
+     Examples:
+         >>> class SimpleCalculator(ToolTransformBase):
+         ...     tool_role = "calculator"
+         ...
+         ...     def _process_batch_item(self, content: str, index: int):
+         ...         # Extract math expressions and evaluate
+         ...         if "2+2" in content:
+         ...             return ["2+2=4"]
+         ...         return None
+     """
+
+     use_step: bool = True  # Use _step() vs _call()
+     tool_role: str = "tool"  # Role name for results in history
+
+     def _validate_and_extract_history(
+         self, next_tensordict: TensorDictBase
+     ) -> tuple[History, History]:
+         """Validate environment and extract history.
+
+         Args:
+             next_tensordict: The tensordict containing history.
+
+         Returns:
+             tuple: (full_history, local_history) where local_history is the last message.
+
+         Raises:
+             RuntimeError: If parent env doesn't exist or isn't in history mode.
+         """
+         # Check that base_env is in history mode
+         parent = self.parent
+         if parent is None:
+             raise RuntimeError(f"{self.__class__.__name__} must be used with a ChatEnv")
+         base_env = parent.base_env
+         if base_env.input_mode != "history":
+             raise RuntimeError(
+                 f"{self.__class__.__name__} must be used with a ChatEnv in history mode"
+             )
+
+         # Get history and isolate last element (the LLM's response)
+         history = next_tensordict["history"].prompt
+         local_history = history[..., -1]
+
+         return history, local_history
+
+     def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+         """Process one item in the batch to extract and execute tools.
+
+         This is the main method subclasses must implement.
+
+         Args:
+             content: The text content from the LLM response.
+             index: The index of this item in the batch.
+
+         Returns:
+             list[str] or None: List of result strings for each tool executed,
+                 or None if no tools were found/executed.
+         """
+         raise NotImplementedError(
+             f"{self.__class__.__name__} must implement _process_batch_item()"
+         )
+
+     def _format_result(self, result: str) -> str:
+         """Format a single result string.
+
+         Override this to customize result formatting. Default is identity.
+
+         Args:
+             result: Raw result string from tool execution.
+
+         Returns:
+             str: Formatted result string.
+         """
+         return result
+
+     def _inject_results_to_history(
+         self,
+         history: History,
+         results: list[list[str] | None],
+         next_tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         """Inject tool results back into history with proper batching.
+
+         Args:
+             history: The full conversation history.
+             results: List of results per batch item (can contain None).
+             next_tensordict: The tensordict to update.
+
+         Returns:
+             TensorDictBase: Updated tensordict with results in history.
+         """
+         # Convert string results to History objects
+         procs = []
+         for batch_results in results:
+             if batch_results is None or len(batch_results) == 0:
+                 procs.append(None)
+             else:
+                 formatted_results = [self._format_result(r) for r in batch_results]
+                 procs.append(
+                     [
+                         History(role=self.tool_role, content=result)
+                         for result in formatted_results
+                     ]
+                 )
+
+         # If there are no tool responses, skip
+         if all(p is None for p in procs):
+             return next_tensordict
+
+         # Fill None entries with empty lists for consistent batching
+         if any(p is None for p in procs):
+             procs = [p if p is not None else [] for p in procs]
+
+         # Pad all results to same length (required for batching)
+         if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs):
+
+             def fill_procs(proc: list[History], max_len: int) -> list[History]:
+                 if len(proc) == max_len:
+                     return proc
+                 return proc + [History(role="<none>", content="")] * (
+                     max_len - len(proc)
+                 )
+
+             max_len = max(len(p) for p in procs)
+             procs = [fill_procs(p, max_len) for p in procs]
+
+         # Stack and extend history
+         procs = lazy_stack([lazy_stack(p) for p in procs])
+         history.extend(procs, dim=-1)
+         next_tensordict["history"].prompt = history
+
+         return next_tensordict
+
+     def _process_tensordict(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+         """Main processing logic for tool transforms.
+
+         Handles batch flattening, history extraction, tool processing, and result injection.
+
+         Args:
+             next_tensordict: The tensordict to process.
+
+         Returns:
+             TensorDictBase: Updated tensordict with tool results.
+         """
+         # Flatten batch dimensions if needed
+         if next_tensordict.batch_dims > 1:
+             with next_tensordict.view(-1) as next_tensordict_flat:
+                 next_tensordict_flat = self._process_tensordict(next_tensordict_flat)
+             return next_tensordict
+
+         # Extract and validate history
+         history, local_history = self._validate_and_extract_history(next_tensordict)
+
+         # Handle content as string or list
+         content = local_history.content
+         if isinstance(content, str):
+             content = [content]
+
+         # Process each batch item
+         results = []
+         for i, text in enumerate(content):
+             batch_results = self._process_batch_item(text, i)
+             results.append(batch_results)
+
+         # Inject results back into history
+         return self._inject_results_to_history(history, results, next_tensordict)
+
+     def _step(
+         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+     ) -> TensorDictBase:
+         """Handle step with tool processing.
+
+         Args:
+             tensordict: Input tensordict.
+             next_tensordict: Output tensordict.
+
+         Returns:
+             TensorDictBase: Updated next_tensordict.
+         """
+         if not self.use_step:
+             raise RuntimeError(
+                 f"{self.__class__.__name__} uses _call(), not _step(). Set use_step=False."
+             )
+         return self._process_tensordict(next_tensordict)
+
+     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+         """Handle call with tool processing.
+
+         Args:
+             next_tensordict: The tensordict to process.
+
+         Returns:
+             TensorDictBase: Updated tensordict.
+         """
+         if self.use_step:
+             raise RuntimeError(
+                 f"{self.__class__.__name__} uses _step(), not _call(). Set use_step=True."
+             )
+         return self._process_tensordict(next_tensordict)
+
+     def _reset(
+         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+     ) -> TensorDictBase:
+         """Handle reset (no-op for base class).
+
+         Args:
+             tensordict (TensorDictBase): Input tensordict.
+             tensordict_reset (TensorDictBase): Reset tensordict.
+
+         Returns:
+             TensorDictBase: Unchanged reset tensordict.
+         """
+         return tensordict_reset
+
+
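A minimal subclass sketch, assuming a `ChatEnv` in history mode (the `EchoTool` name and its `<echo>` tag format are invented for illustration, not part of the package):

class EchoTool(ToolTransformBase):
    # Toy transform: echo any <echo>...</echo> span back as a "tool" message.
    tool_role = "tool"

    def _process_batch_item(self, content: str, index: int) -> list[str] | None:
        spans = re.findall(r"<echo>(.*?)</echo>", content, re.DOTALL)
        return [f"echo: {s}" for s in spans] or None

# env = ChatEnv(batch_size=(1,)).append_transform(EchoTool())  # _step() injects results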
+ # --- Tool Service Library: Pluggable Services & Parsers ---
+
+
+ class ToolService(Protocol):
+     """Protocol for a side-effecting service callable with structured IO.
+
+     A tool service is a callable that can be invoked with keyword arguments
+     and returns a dictionary of results. It has a name and input/output schemas.
+
+     Attributes:
+         name (str): The name of the tool service.
+         schema_in (dict[str, Any]): Input schema describing expected parameters.
+         schema_out (dict[str, Any]): Output schema describing returned data.
+     """
+
+     name: str
+     schema_in: dict[str, Any]
+     schema_out: dict[str, Any]
+
+     def __call__(self, **kwargs) -> dict[str, Any]:
+         """Execute the tool service.
+
+         Args:
+             **kwargs: Keyword arguments matching the input schema.
+
+         Returns:
+             dict[str, Any]: Results matching the output schema.
+         """
+         ...
+
+
+ class ToolRegistry:
+     """Registry for managing available tool services.
+
+     This class maintains a collection of tool services that can be looked up
+     by name for execution.
+
+     Args:
+         services (Sequence[ToolService], optional): Initial services to register.
+             Defaults to an empty sequence.
+
+     Examples:
+         >>> class AddService:
+         ...     name = "add"
+         ...     schema_in = {"a": int, "b": int}
+         ...     schema_out = {"result": int}
+         ...     def __call__(self, a, b, **kwargs):
+         ...         return {"result": a + b}
+         >>> registry = ToolRegistry([AddService()])
+         >>> service = registry.get("add")
+         >>> result = service(a=1, b=2)
+         >>> print(result)
+         {'result': 3}
+     """
+
+     def __init__(self, services: Sequence[ToolService] = ()):
+         self._svc: dict[str, ToolService] = {s.name: s for s in services}
+
+     def register(self, service: ToolService) -> None:
+         """Register a new service.
+
+         Args:
+             service (ToolService): The service to register.
+         """
+         self._svc[service.name] = service
+
+     def get(self, name: str) -> ToolService:
+         """Retrieve a service by name.
+
+         Args:
+             name (str): The name of the service to retrieve.
+
+         Returns:
+             ToolService: The requested service.
+
+         Raises:
+             KeyError: If the service is not found.
+         """
+         if name not in self._svc:
+             raise KeyError(f"Unknown tool: {name}")
+         return self._svc[name]
+
+     def __contains__(self, name: str) -> bool:
+         """Check if a service is registered.
+
+         Args:
+             name (str): The name to check.
+
+         Returns:
+             bool: True if the service exists, False otherwise.
+         """
+         return name in self._svc
+
+
+ @dataclass
+ class ToolCall:
+     """Representation of a parsed tool call from LLM output.
+
+     Attributes:
+         tool (str): The name of the tool to call.
+         args (dict[str, Any]): Arguments to pass to the tool.
+         tag (str | None): Optional user-visible label or correlation ID.
+     """
+
+     tool: str
+     args: dict[str, Any]
+     tag: str | None = None
+
+
+ class ParseResult(TypedDict):
+     """Result of parsing an LLM response for tool calls.
+
+     A ``TypedDict`` with the following keys:
+         text (str): The final user-facing message (after tool blocks are removed).
+         calls (list[ToolCall]): Ordered tool calls as they appear.
+         meta (dict[str, Any]): Optional parser metadata.
+     """
+
+     text: str
+     calls: list[ToolCall]
+     meta: dict[str, Any]
+
+
+ class LLMToolParser(Protocol):
+     """Protocol for parsing LLM responses into ordered tool calls.
+
+     A tool parser takes the LLM's response (as string or structured data)
+     and extracts ordered tool calls, along with the cleaned user-facing text.
+     """
+
+     def __call__(self, response: str | dict[str, Any]) -> ParseResult:
+         """Parse an LLM response.
+
+         Args:
+             response (str | dict[str, Any]): The LLM's response to parse.
+
+         Returns:
+             ParseResult: Parsed result with text, calls, and metadata.
+         """
+         ...
+
+
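A sketch of a custom parser satisfying the `LLMToolParser` protocol (the `BangLineParser` class and its `shell` tool name are hypothetical):

class BangLineParser:
    # Treat each line starting with "!" as a call to a (hypothetical) "shell" tool.
    def __call__(self, response: str | dict[str, Any]) -> ParseResult:
        text = response if isinstance(response, str) else response.get("text", "")
        lines = text.splitlines()
        calls = [
            ToolCall(tool="shell", args={"cmd": line[1:].strip()})
            for line in lines
            if line.startswith("!")
        ]
        result = ParseResult()
        result["text"] = "\n".join(l for l in lines if not l.startswith("!")).strip()
        result["calls"] = calls
        result["meta"] = {"count": len(calls)}
        return result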
+ class XMLBlockParser:
+     r"""Parser for XML-style tool blocks in LLM responses.
+
+     Parses tool calls in the format:
+         <tool name="tool_name" tag="optional_tag">{"arg": "value"}</tool>
+
+     Examples:
+         >>> parser = XMLBlockParser()
+         >>> response = '<tool name="search" tag="A">{"query": "torchrl"}</tool>\\nSome text.'
+         >>> result = parser(response)
+         >>> print(result["text"])
+         Some text.
+         >>> print(result["calls"][0].tool)
+         search
+         >>> print(result["calls"][0].args)
+         {'query': 'torchrl'}
+     """
+
+     _re = re.compile(
+         r'<tool\s+name="(?P<name>[^"]+)"(?:\s+tag="(?P<tag>[^"]+)")?\s*>\s*(?P<body>.*?)\s*</tool>',
+         re.DOTALL,
+     )
+
+     def __call__(self, response: str | dict[str, Any]) -> ParseResult:
+         """Parse XML-style tool blocks from the response.
+
+         Args:
+             response (str | dict[str, Any]): The response to parse.
+
+         Returns:
+             ParseResult: Parsed result with cleaned text and tool calls.
+         """
+         text = response if isinstance(response, str) else response.get("text", "")
+         calls: list[ToolCall] = []
+
+         def repl(m: re.Match) -> str:
+             name = m.group("name")
+             tag = m.group("tag")
+             body = m.group("body")
+             try:
+                 args = json.loads(body) if body.strip() else {}
+             except json.JSONDecodeError:
+                 # If JSON parsing fails, pass the raw body as a "raw" argument
+                 args = {"raw": body}
+             calls.append(ToolCall(tool=name, args=args, tag=tag))
+             return ""  # Remove block from final user-visible message
+
+         cleaned = self._re.sub(repl, text).strip()
+         result = ParseResult()
+         result["text"] = cleaned
+         result["calls"] = calls
+         result["meta"] = {"count": len(calls)}
+         return result
+
+
+ class JSONCallParser:
+     """Parser for JSON-style function-calling responses.
+
+     Expects responses in the format::
+
+         {
+             "message": "...",
+             "tools": [
+                 {"tool": "search", "args": {"query": "..."}, "tag": "A"},
+                 {"tool": "summarize", "args": {"text": "..."}}
+             ]
+         }
+
+     Examples:
+         >>> parser = JSONCallParser()
+         >>> response = {
+         ...     "message": "Let me search for that.",
+         ...     "tools": [{"tool": "search", "args": {"query": "torchrl"}}]
+         ... }
+         >>> result = parser(response)
+         >>> print(result["text"])
+         Let me search for that.
+         >>> print(result["calls"][0].tool)
+         search
+     """
+
+     def __call__(self, response: str | dict[str, Any]) -> ParseResult:
+         """Parse JSON-style function calls from the response.
+
+         Args:
+             response (str | dict[str, Any]): The response to parse.
+
+         Returns:
+             ParseResult: Parsed result with message and tool calls.
+         """
+         if isinstance(response, str):
+             try:
+                 response = json.loads(response)
+             except json.JSONDecodeError:
+                 # If it's not valid JSON, treat as plain text with no tools
+                 result = ParseResult()
+                 result["text"] = response
+                 result["calls"] = []
+                 result["meta"] = {"count": 0}
+                 return result
+
+         tools_data = response.get("tools", [])
+         calls = [ToolCall(**c) for c in tools_data]
+
+         result = ParseResult()
+         result["text"] = response.get("message", "")
+         result["calls"] = calls
+         result["meta"] = {"count": len(calls)}
+         return result
+
+
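For reference, a short sketch of `JSONCallParser`'s string fallback path: a response that is not valid JSON comes back unchanged, with zero calls:

parser = JSONCallParser()
out = parser("just chatting, no tools here")
assert out["text"] == "just chatting, no tools here"
assert out["calls"] == [] and out["meta"]["count"] == 0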
+ class ExecuteToolsInOrder(ToolTransformBase):
+     """A Transform that executes tools in the order they appear in LLM output.
+
+     This transform reads the LLM response, parses ordered tool blocks using a
+     pluggable parser, and executes tools via a ToolRegistry strictly in the
+     order they appear in the response (independent of transform stacking order).
+
+     The transform integrates naturally with TorchRL's LLM environments and can
+     read/write conversation history alongside other transforms.
+
+     Args:
+         registry (ToolRegistry): Registry containing available tool services.
+         parser (LLMToolParser): Parser for extracting tool calls from LLM output.
+         stop_on_error (bool, optional): Whether to stop execution on the first error.
+             Defaults to ``False``.
+         pass_state_to_tools (bool, optional): Whether to pass TD state to tools.
+             Defaults to ``True``.
+
+     Examples:
+         >>> from torchrl.envs.llm import ChatEnv
+         >>> from torchrl.envs.transforms import TransformedEnv, Compose
+         >>> from torchrl.envs.llm.transforms import ExecuteToolsInOrder, ToolRegistry, XMLBlockParser
+         >>>
+         >>> # Define a simple service
+         >>> class WebSearch:
+         ...     name = "search"
+         ...     schema_in = {"query": str}
+         ...     schema_out = {"results": list}
+         ...     def __call__(self, query: str, **kwargs):
+         ...         return {"results": [{"title": "TorchRL docs", "url": "https://..."}]}
+         >>>
+         >>> # Create registry and parser
+         >>> registry = ToolRegistry([WebSearch()])
+         >>> parser = XMLBlockParser()
+         >>>
+         >>> # Create environment with transform
+         >>> env = ChatEnv(batch_size=(1,))
+         >>> env = TransformedEnv(
+         ...     env,
+         ...     ExecuteToolsInOrder(registry=registry, parser=parser)
+         ... )
+
+     .. note::
+         This transform operates in the forward direction only; inverse is a no-op.
+         Tool execution order is determined by appearance in the LLM output,
+         not by the order of transforms in the Compose stack.
+     """
+
+     use_step = True  # Use _step() method
+
+     def __init__(
+         self,
+         registry: ToolRegistry,
+         parser: LLMToolParser,
+         stop_on_error: bool = False,
+         pass_state_to_tools: bool = True,
+     ):
+         super().__init__()
+         self.registry = registry
+         self.parser = parser
+         self.stop_on_error = stop_on_error
+         self.pass_state_to_tools = pass_state_to_tools
+         self.tool_role = "tool"
+
+     def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+         """Process one batch item to extract and execute tools.
+
+         This is the main method required by ToolTransformBase.
+
+         Args:
+             content: The text content from the LLM response.
+             index: The index of this item in the batch.
+
+         Returns:
+             list[str] or None: List of result strings for each tool executed,
+                 or None if no tools were found.
+         """
+         # Parse the response for tool calls
+         parse: ParseResult = self.parser(content)
+         ordered_calls = parse["calls"]
+
+         if not ordered_calls:
+             return None
+
+         tool_outputs: list[dict[str, Any]] = []
+
+         # Execute tools IN ORDER OF APPEARANCE
+         for j, call in enumerate(ordered_calls):
+             try:
+                 service = self.registry.get(call.tool)
+                 kwargs = dict(call.args)
+                 if self.pass_state_to_tools:
+                     # Get tensordict from parent context if available
+                     # For now, pass empty state - can be enhanced later
+                     kwargs["_state"] = {}
+
+                 out = service(**kwargs)
+                 out["_tool"] = call.tool
+                 out["_index"] = j
+                 if call.tag:
+                     out["_tag"] = call.tag
+                 tool_outputs.append(out)
+             except Exception as e:
+                 err = {"_tool": call.tool, "_index": j, "error": str(e)}
+                 tool_outputs.append(err)
+                 if self.stop_on_error:
+                     break
+
+         # Format tool results as a single string
+         if tool_outputs:
+             results_text = self._format_tool_results(tool_outputs)
+             return [results_text] if results_text else None
+         return None
+
+     def _format_tool_results(self, tool_outputs: list[dict[str, Any]]) -> str:
+         """Format tool execution results as text.
+
+         Args:
+             tool_outputs (list[dict[str, Any]]): List of tool execution results.
+
+         Returns:
+             str: Formatted text representation of results.
+         """
+         if not tool_outputs:
+             return ""
+
+         lines = ["<tool_results>"]
+         for output in tool_outputs:
+             tool_name = output.pop("_tool", "unknown")
+             index = output.pop("_index", 0)
+             tag = output.pop("_tag", None)
+
+             if "error" in output:
+                 lines.append(f"Tool {tool_name} (call {index + 1}) failed:")
+                 lines.append(f"  Error: {output['error']}")
+             else:
+                 header = f"Tool {tool_name} (call {index + 1})"
+                 if tag:
+                     header += f" [tag: {tag}]"
+                 header += " succeeded:"
+                 lines.append(header)
+                 lines.append(f"  Result: {json.dumps(output, indent=2)}")
+
+         lines.append("</tool_results>")
+         return "\n".join(lines)
+
+
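A sketch of the ordering and error semantics (the `Upper` service is invented; `_process_batch_item` is exercised directly here, which needs no parent env):

class Upper:
    name = "upper"
    schema_in = {"s": str}
    schema_out = {"s": str}

    def __call__(self, s: str, **kwargs):
        return {"s": s.upper()}

transform = ExecuteToolsInOrder(registry=ToolRegistry([Upper()]), parser=XMLBlockParser())
text = '<tool name="upper">{"s": "hi"}</tool> and <tool name="missing">{}</tool>'
[report] = transform._process_batch_item(text, 0)
# report is one <tool_results> block: call 1 succeeded with {"s": "HI"};
# call 2 failed with a KeyError message, since "missing" is not registered.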
+ class PersistentPythonProcess:
+     """A persistent Python process that can execute code blocks."""
+
+     def __init__(self, timeout: float = 10.0):
+         self.timeout = timeout
+         self._output_queue = queue.Queue()
+         self._error_queue = queue.Queue()
+         self._accumulated_errors = []
+         self._init_script = None
+         self.process = None  # Initialize to None to avoid AttributeError in __del__
+
+         # Start the process
+         self._start_process()
+
+     def _start_process(self):
+         """Start the Python process with the initialization script."""
+         # Create a temporary file for initialization
+         init_file = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False)
+         self._init_script = init_file.name
+
+         # Write a script that creates a continuous execution environment
+         init_file.write(
+             """
+ import sys
+ import traceback
+
+ def run_code(code_str):
+     # Create a dictionary to store the local variables
+     locals_dict = {}
+     try:
+         # First try to compile the code to catch syntax errors
+         compiled = compile(code_str, '<string>', 'exec')
+         # Execute the code with the locals dictionary
+         exec(compiled, globals(), locals_dict)
+         # Ensure output is flushed
+         sys.stdout.flush()
+         sys.stderr.flush()
+         return locals_dict
+     except Exception as e:
+         print(f"Error: {str(e)}", file=sys.stderr)
+         print("Traceback:", file=sys.stderr)
+         traceback.print_exc(file=sys.stderr)
+         # Ensure error output is flushed immediately
+         sys.stdout.flush()
+         sys.stderr.flush()
+         return locals_dict
+
+ # Signal that we're ready to accept commands
+ print('---READY---')
+ sys.stdout.flush()
+
+ # Main loop to handle commands
+ while True:
+     try:
+         # Read a line that signals the start of a command
+         line = input()
+         if line.strip() == '---EXEC---':
+             # Read the code until we see the end marker
+             code_lines = []
+             while True:
+                 line = input()
+                 if line.strip() == '---END---':
+                     break
+                 code_lines.append(line)
+
+             # Execute the code
+             code_str = '\\n'.join(code_lines)
+             print('---START---')  # Signal start of execution
+             sys.stdout.flush()
+             locals_dict = run_code(code_str)
+             # Update globals with new locals for persistence
+             globals().update(locals_dict)
+             print('---END---')  # Signal end of execution
+             # Ensure all output is flushed
+             sys.stdout.flush()
+             sys.stderr.flush()
+     except (EOFError, KeyboardInterrupt):
+         break
+     except Exception as e:
+         print(f"Fatal error: {str(e)}", file=sys.stderr)
+         sys.stderr.flush()
+         break
+ """
+         )
+         init_file.close()
+
+         # Start the process
+         try:
+             self.process = subprocess.Popen(
+                 [sys.executable, "-u", self._init_script],  # -u for unbuffered output
+                 stdin=subprocess.PIPE,
+                 stdout=subprocess.PIPE,
+                 stderr=subprocess.PIPE,
+                 text=True,
+                 bufsize=1,
+             )
+
+             # Start output reading threads
+             self._stdout_thread = threading.Thread(
+                 target=self._read_output,
+                 args=(self.process.stdout, self._output_queue, "stdout"),
+                 daemon=True,
+             )
+             self._stderr_thread = threading.Thread(
+                 target=self._read_output,
+                 args=(self.process.stderr, self._error_queue, "stderr"),
+                 daemon=True,
+             )
+             self._stdout_thread.start()
+             self._stderr_thread.start()
+
+             # Wait for the process to be ready
+             ready = False
+             timeout = self.timeout
+             while timeout > 0 and not ready:
+                 if self.process.poll() is not None:
+                     raise RuntimeError(
+                         f"Process failed to start: {self.process.returncode}"
+                     )
+
+                 try:
+                     line = self._output_queue.get_nowait()
+                     torchrl_logger.info(f"Output: {line}")
+                     if "---READY---" in line:
+                         ready = True
+                         break
+                 except queue.Empty:
+                     timeout -= 0.1
+                     time.sleep(0.1)
+
+             if not ready:
+                 raise RuntimeError("Process failed to initialize within timeout")
+
+         except Exception:
+             # Clean up if process creation failed
+             if self._init_script:
+                 try:
+                     os.unlink(self._init_script)
+                     self._init_script = None
+                 except Exception:
+                     pass
+             raise
+
+     def _read_output(self, pipe: TextIO, q: queue.Queue, pipe_name: str) -> None:
+         """Read output from a pipe and put it in a queue."""
+         try:
+             for line in iter(pipe.readline, ""):
+                 if pipe_name == "stderr":
+                     self._accumulated_errors.append(line)
+                 q.put(line)
+         except (ValueError, OSError) as e:
+             # Pipe has been closed
+             torchrl_logger.info(f"{pipe_name} pipe closed: {str(e)}")
+         finally:
+             try:
+                 pipe.close()
+             except Exception:
+                 pass
+
+     def execute(self, prompt: str) -> dict[str, Any]:
+         """Execute code in the persistent process."""
+         if not self.process or self.process.poll() is not None:
+             # Get any accumulated errors
+             errors = "".join(self._accumulated_errors)
+             torchrl_logger.info(
+                 f"Process state: poll={self.process.poll() if self.process else 'No process'}, accumulated errors: {errors}"
+             )
+             return {
+                 "success": False,
+                 "stdout": "",
+                 "stderr": f"Process not initialized or terminated. Accumulated errors: {errors}",
+                 "returncode": self.process.returncode if self.process else -1,
+             }
+
+         if not self.process.stdin:
+             return {
+                 "success": False,
+                 "stdout": "",
+                 "stderr": "Process stdin not available",
+                 "returncode": -1,
+             }
+
+         try:
+             # Clear accumulated errors before new execution
+             self._accumulated_errors.clear()
+
+             # Send the execution markers and code
+             try:
+                 self.process.stdin.write("---EXEC---\n")
+                 torchrl_logger.info(f"Writing to stdin: {prompt}")
+                 self.process.stdin.write(f"{prompt}\n")
+                 self.process.stdin.write("---END---\n")
+                 self.process.stdin.flush()
+             except OSError as e:
+                 torchrl_logger.info(f"Failed to write to stdin: {str(e)}")
+                 return {
+                     "success": False,
+                     "stdout": "",
+                     "stderr": f"Failed to write to process: {str(e)}",
+                     "returncode": -1,
+                 }
+
+             # Collect output until we see the end marker
+             output = []
+             error = []
+             start_found = False
+             timeout_val = self.timeout
+
+             while timeout_val > 0:
+                 if self.process.poll() is not None:
+                     # Process has terminated - get accumulated errors
+                     errors = "".join(self._accumulated_errors)
+                     torchrl_logger.info(
+                         f"Process terminated with return code {self.process.returncode} - accumulated errors: {errors}"
+                     )
+                     error.append(
+                         f"Process terminated with return code {self.process.returncode} - {errors}"
+                     )
+                     break
+
+                 try:
+                     # Check for errors first
+                     try:
+                         while True:  # Drain all available error output
+                             line = self._error_queue.get_nowait()
+                             torchrl_logger.info(f"Error: {line}")
+                             error.append(line)
+                     except queue.Empty:
+                         pass
+
+                     # Then check for output
+                     try:
+                         line = self._output_queue.get_nowait()
+                         torchrl_logger.info(f"Output: {line}")
+                         if "---START---" in line:
+                             start_found = True
+                             continue
+                         if "---END---" in line:
+                             break
+                         if start_found:
+                             output.append(line)
+                     except queue.Empty:
+                         pass
+
+                     # Always sleep a bit to avoid busy-waiting and give subprocess time
+                     timeout_val -= 0.01
+                     time.sleep(0.01)
+
+                 except Exception as e:
+                     return {
+                         "success": False,
+                         "stdout": "",
+                         "stderr": f"Execution error: {str(e)}",
+                         "returncode": -1,
+                     }
+
+             if timeout_val <= 0:
+                 # Kill the process and create a new one
+                 self.cleanup()
+                 self.__init__(self.timeout)
+                 return {
+                     "success": False,
+                     "stdout": "",
+                     "stderr": "Code execution timed out - process restarted",
+                     "returncode": -1,
+                 }
+
+             return {
+                 "success": len(error) == 0,
+                 "stdout": "".join(output),
+                 "stderr": "".join(error),
+                 "returncode": 0 if len(error) == 0 else 1,
+             }
+
+         except Exception as e:
+             # If we encounter any error, restart the process
+             self.cleanup()
+             self.__init__(self.timeout)
+             return {
+                 "success": False,
+                 "stdout": "",
+                 "stderr": f"Execution error: {str(e)} - process restarted",
+                 "returncode": -1,
+             }
+
+     def cleanup(self):
+         """Clean up the persistent process."""
+         import signal
+
+         if self.process:
+             try:
+                 self.process.send_signal(signal.SIGTERM)
+                 self.process.wait(timeout=1.0)
+             except (subprocess.TimeoutExpired, OSError):
+                 self.process.kill()
+             self.process = None
+
+         # Clean up the init script
+         if self._init_script:
+             try:
+                 os.unlink(self._init_script)
+                 self._init_script = None
+             except Exception:
+                 pass
+
+     def __del__(self):
+         """Ensure cleanup on deletion."""
+         self.cleanup()
+
+
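A quick sketch of the persistence contract: because the worker merges `run_code`'s locals into its globals, state survives across `execute()` calls:

proc = PersistentPythonProcess(timeout=10.0)
proc.execute("x = 21")              # defines x in the worker's globals
out = proc.execute("print(x * 2)")  # x is still visible here
assert out["success"] and out["stdout"].strip() == "42"
proc.cleanup()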
+ class PythonExecutorService:
+     """Ray actor that manages a pool of persistent Python interpreters.
+
+     This service allows multiple environments to share a pool of Python
+     interpreters, reducing resource usage and improving efficiency.
+
+     Args:
+         pool_size (int): Number of Python interpreter processes to maintain.
+         timeout (float): Timeout for code execution in seconds.
+
+     Examples:
+         >>> # Register the service
+         >>> from torchrl.services import get_services
+         >>> services = get_services(backend="ray")
+         >>> services.register(
+         ...     "python_executor",
+         ...     PythonExecutorService,
+         ...     pool_size=32,
+         ...     timeout=10.0,
+         ...     num_cpus=32,
+         ...     max_concurrency=32
+         ... )
+         >>>
+         >>> # Use in transform
+         >>> env = env.append_transform(
+         ...     PythonInterpreter(services="ray")
+         ... )
+     """
+
+     def __init__(self, pool_size: int = 32, timeout: float = 10.0):
+         self.pool_size = pool_size
+         self.timeout = timeout
+         self.processes = [
+             PersistentPythonProcess(timeout=timeout) for _ in range(pool_size)
+         ]
+         # Create a lock for each process to prevent concurrent access
+         self.process_locks = [threading.Lock() for _ in range(pool_size)]
+         self.next_idx = 0
+         self._selection_lock = threading.Lock()
+
+     def execute(self, code: str) -> dict:
+         """Execute Python code using the next available process (round-robin).
+
+         Args:
+             code: Python code to execute.
+
+         Returns:
+             dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'.
+         """
+         # Select a process using round-robin
+         with self._selection_lock:
+             process_idx = self.next_idx
+             self.next_idx = (self.next_idx + 1) % self.pool_size
+
+         # Lock the selected process for the duration of execution
+         with self.process_locks[process_idx]:
+             return self.processes[process_idx].execute(code)
+
+     def cleanup(self):
+         """Cleanup all processes in the pool."""
+         if hasattr(self, "processes"):
+             for process in self.processes:
+                 if process:
+                     process.cleanup()
+             self.processes = []
+
+     def __del__(self):
+         """Ensure cleanup on deletion."""
+         try:
+             self.cleanup()
+         except Exception:
+             # Ignore errors during cleanup - we might be in Ray actor context
+             pass
+
+
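Although it is designed to run as a Ray actor (registered via `torchrl.services` as the docstring shows), the round-robin pool can be sketched in-process as well:

pool = PythonExecutorService(pool_size=2, timeout=10.0)
# Round-robin: calls 0 and 2 go to interpreter 0, calls 1 and 3 to interpreter 1.
results = [pool.execute(f"print({i})") for i in range(4)]
assert all(r["success"] for r in results)
pool.cleanup()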
1051
+ class PythonInterpreter(ToolTransformBase):
1052
+ r"""A transform that executes Python code in the LLM response.
1053
+
1054
+ This transform inherits from :class:`ToolTransformBase` and handles all the
1055
+ boilerplate for history extraction, batch processing, and result injection.
1056
+
1057
+ Args:
1058
+ tokenizer: The tokenizer to use. Defaults to `None` (no tokenizer).
1059
+ tool_name: The name of the tool in the chat history. Defaults to `"tool"`.
1060
+ persistent: Whether to use persistent processes. Defaults to `False`.
1061
+ timeout: The timeout for the persistent processes. Defaults to `10.0`.
1062
+ services: Backend for shared Python executor service. If `"ray"`, uses
1063
+ a shared Ray actor service for execution. If `None`, uses local
1064
+ processes. Defaults to `None`.
1065
+ service_name: Name of the service in the registry. Only used if
1066
+ `services="ray"`. Defaults to `"python_executor"`.
1067
+ namespace: Ray namespace for the service. Only used if `services="ray"`.
1068
+ If `None`, uses the default namespace. Defaults to `None`.
1069
+
1070
+ Examples:
1071
+ >>> from torchrl.envs.llm.transforms import PythonInterpreter
1072
+ >>> from transformers import AutoTokenizer
1073
+ >>> from tensordict import TensorDict, set_list_to_stack
1074
+ >>> from torchrl.envs.llm import ChatEnv
1075
+ >>> set_list_to_stack(True).set()
1076
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
1077
+ >>> env = ChatEnv(
1078
+ ... batch_size=(1,),
1079
+ ... system_prompt="I'm the system, do as I say",
1080
+ ... apply_template=True,
1081
+ ... tokenizer=tokenizer,
1082
+ ... )
1083
+ >>> env = env.append_transform(PythonInterpreter())
1084
+ >>> r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,)))
1085
+ >>> r["text_response"] = ["Here is a python code to execute:\n```python\na=1\nprint(f'{a=}')\n```<|im_end|>\n"]
1086
+ >>> s = env.step(r)
1087
+ >>> print(s['next', 'history'].apply_chat_template(tokenizer=tokenizer))
1088
+ ['<|im_start|>system\n'
1089
+ "I'm the system, do as I say<|im_end|>\n"
1090
+ '<|im_start|>user\n'
1091
+ 'This is the user prompt<|im_end|>\n'
1092
+ '<|im_start|>assistant\n'
1093
+ 'Here is a python code to execute:\n'
1094
+ '```python\n'
1095
+ 'a=1\n'
1096
+ "print(f'{a=}')\n"
1097
+ '```<|im_end|>\n'
1098
+ '<|im_start|>user\n'
1099
+ '<tool_response>\n'
1100
+ 'Code block 1 executed successfully:\n'
1101
+ 'a=1\n'
1102
+ '\n'
1103
+ '</tool_response><|im_end|>\n'
1104
+ '<|im_start|>assistant\n']
1105
+
1106
+ Using shared Ray service:
1107
+ >>> from torchrl.services import get_services
1108
+ >>>
1109
+ >>> # Register service once (e.g., in main process)
1110
+ >>> services = get_services(backend="ray")
1111
+ >>> if "python_executor" not in services:
1112
+ ... services.register(
1113
+ ... "python_executor",
1114
+ ... PythonExecutorService,
1115
+ ... pool_size=32,
1116
+ ... timeout=10.0,
1117
+ ... num_cpus=32,
1118
+ ... max_concurrency=32
1119
+ ... )
1120
+ >>>
1121
+ >>> # Use in transform (all 128 envs share the 32 interpreters)
1122
+ >>> env = env.append_transform(PythonInterpreter(services="ray"))
1123
+ """
1124
+
+     use_step = True  # Use _step() method
+
+     def __init__(
+         self,
+         tokenizer=None,  # type: ignore
+         tool_name: str = "tool",
+         persistent: bool = False,
+         timeout: float = 10.0,
+         services: str | None = None,
+         service_name: str = "python_executor",
+         namespace: str | None = None,
+     ):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.tool_role = tool_name  # Set the role for history entries
+         self.persistent = persistent
+         self.timeout = timeout
+         self.services = services
+         self.service_name = service_name
+         self.namespace = namespace
+
+         # Initialize attributes to avoid AttributeError in __del__
+         self.python_service = None
+         self.processes = None
+
+         # Initialize based on service mode
+         if services == "ray":
+             # Use shared Ray service
+             try:
+                 from torchrl.services import get_services
+
+                 service_registry = get_services(backend="ray", namespace=namespace)
+                 self.python_service = service_registry[service_name]
+                 self.processes = None
+                 torchrl_logger.info(
+                     f"PythonInterpreter using Ray service '{service_name}'"
+                 )
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to get Ray service '{service_name}'. "
+                     f"Make sure the service is registered. Error: {e}"
+                 ) from e
+         elif services is None:
+             # Use local processes; populated lazily with PersistentPythonProcess
+             # workers when persistent=True
+             self.python_service = None
+             self.processes = []
+         else:
+             raise ValueError(
+                 f"Invalid services backend: {services}. Must be 'ray' or None."
+             )
+
+     def close(self):
+         """Close the transform."""
+         if self.python_service is None and self.processes:
+             for process in self.processes:
+                 if process:
+                     process.cleanup()
+             self.processes = []
+
+     def clone(self):
+         """Clone the transform."""
+         return self.__class__(
+             tokenizer=self.tokenizer,
+             tool_name=self.tool_role,  # tool_role is the instance attribute
+             persistent=self.persistent,
+             timeout=self.timeout,
+             services=self.services,
+             service_name=self.service_name,
+             namespace=self.namespace,
+         )
+
+     def _ensure_processes(self, batch_size: int):
+         """Ensure we have the right number of persistent processes."""
+         if not self.persistent:
+             return
+
+         # Create new processes if needed
+         while len(self.processes) < batch_size:
+             self.processes.append(PersistentPythonProcess(timeout=self.timeout))
+
+         if any(p is None for p in self.processes):
+             self.processes = [
+                 p if p is not None else PersistentPythonProcess(timeout=self.timeout)
+                 for p in self.processes
+             ]
+
+         # The persistent process count should never exceed the batch size
+         if len(self.processes) > batch_size:
+             raise RuntimeError(
+                 f"Too many processes: {len(self.processes)} > {batch_size}"
+             )
+
+     def _execute_python_code(self, code: str, i: int) -> dict:
+         """Safely execute Python code and return results."""
+         if self.python_service is not None:
+             # Use shared Ray service
+             try:
+                 import ray
+
+                 result = ray.get(self.python_service.execute.remote(code))
+                 return result
+             except Exception as e:
+                 return {
+                     "success": False,
+                     "stdout": "",
+                     "stderr": f"Ray service execution failed: {str(e)}",
+                     "returncode": -1,
+                 }
+         elif self.persistent:
+             # Use a local persistent process; make sure enough of them exist
+             if i >= len(self.processes):
+                 self._ensure_processes(i + 1)
+             process = self.processes[i]
+             if process is None:
+                 return {
+                     "success": False,
+                     "stdout": "",
+                     "stderr": "Process not initialized",
+                     "returncode": -1,
+                 }
+             return process.execute(code)
+         else:
+             # Use temporary file approach
+             try:
+                 with tempfile.NamedTemporaryFile(
+                     mode="w", suffix=".py", delete=False
+                 ) as f:
+                     f.write(code)
+                     temp_file = f.name
+
+                 try:
+                     result = subprocess.run(
+                         [sys.executable, temp_file],
+                         capture_output=True,
+                         text=True,
+                         timeout=self.timeout,
+                     )
+                 finally:
+                     # Remove the temp file even if execution times out or fails
+                     os.unlink(temp_file)
+
+                 return {
+                     "success": result.returncode == 0,
+                     "stdout": result.stdout,
+                     "stderr": result.stderr,
+                     "returncode": result.returncode,
+                 }
+
+             except subprocess.TimeoutExpired:
+                 return {
+                     "success": False,
+                     "stdout": "",
+                     "stderr": "Code execution timed out",
+                     "returncode": -1,
+                 }
+             except Exception as e:
+                 return {
+                     "success": False,
+                     "stdout": "",
+                     "stderr": str(e),
+                     "returncode": -1,
+                 }
+
+     def _extract_python_code(self, text: str) -> list[str]:
+         """Extract Python code blocks from markdown-style formatting."""
+         # Pattern to match ```python ... ``` blocks
+         pattern = r"```python\n(.*?)\n```"
+         matches = re.findall(pattern, text, re.DOTALL)
+         return matches
+
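+     # Editor's illustrative sketch (hypothetical helper, not part of the
+     # package): only fenced blocks opening with ```python are extracted;
+     # untagged fences are ignored.
+     def _demo_extract_python_code(self):
+         text = (
+             "Some prose.\n"
+             "```python\na = 1\nprint(a)\n```\n"
+             "```\nno language tag, not extracted\n```\n"
+         )
+         blocks = self._extract_python_code(text)
+         assert blocks == ["a = 1\nprint(a)"]
+         return blocks
+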
+     def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+         """Process one batch item to extract and execute Python code.
+
+         This is the main method required by ToolTransformBase.
+
+         Args:
+             content: The text content from the LLM response.
+             index: The index of this item in the batch.
+
+         Returns:
+             list[str] or None: List of result strings for each code block executed,
+             or None if no code blocks were found.
+         """
+         # Ensure we have enough processes for persistent mode
+         if self.persistent:
+             if index >= len(self.processes):
+                 self._ensure_processes(index + 1)
+
+         # Extract code blocks
+         code_blocks = self._extract_python_code(content)
+         if not code_blocks:
+             return None
+
+         # Execute each code block
+         results = []
+         for block_idx, code in enumerate(code_blocks):
+             result = self._execute_python_code(code, index)
+
+             if result["success"]:
+                 results.append(
+                     f"Code block {block_idx + 1} executed successfully:\n{result['stdout']}"
+                 )
+             else:
+                 results.append(
+                     f"Code block {block_idx + 1} failed:\n{result['stderr']}"
+                 )
+
+         return results if results else None
+
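+     # Editor's illustrative sketch (hypothetical helper, not part of the
+     # package): in the default local, non-persistent mode, each fenced block
+     # runs in a fresh subprocess and yields one formatted string.
+     def _demo_process_batch_item(self):
+         content = "Sure:\n```python\nprint('hi')\n```"
+         results = self._process_batch_item(content, index=0)
+         # -> ["Code block 1 executed successfully:\nhi\n"]
+         return results
+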
+     def _step(
+         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+     ) -> TensorDictBase:
+         """Override to handle batch size management for persistent processes."""
+         # Ensure we have enough processes for the entire batch (only for local persistent mode)
+         if (
+             self.python_service is None
+             and self.persistent
+             and next_tensordict.batch_dims == 1
+         ):
+             self._ensure_processes(len(next_tensordict))
+
+         # Delegate to base class for all the heavy lifting
+         return super()._step(tensordict, next_tensordict)
+
+     def __del__(self):
+         """Ensure cleanup on deletion."""
+         try:
+             if hasattr(self, "python_service") and self.python_service is None:
+                 if hasattr(self, "processes") and self.processes:
+                     for process in self.processes:
+                         if process:
+                             process.cleanup()
+         except Exception:
+             # Ignore errors during cleanup
+             pass
+
+     def _reset(
+         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+     ) -> TensorDictBase:
+         # Get the '_reset' mask from the input tensordict; default to all-True
+         reset = tensordict.get("_reset")
+         if reset is not None:
+             reset = reset.view(tensordict.shape)
+         else:
+             reset = torch.ones(
+                 tensordict.shape, device=tensordict.device, dtype=torch.bool
+             )
+
+         # Only reset local persistent processes, not the shared service
+         if self.python_service is None and self.persistent:
+             for i, process in enumerate(self.processes):
+                 if reset[i] and process is not None:
+                     process.cleanup()
+             self.processes = [
+                 process
+                 if not reset[i]
+                 else PersistentPythonProcess(timeout=self.timeout)
+                 for i, process in enumerate(self.processes)
+             ]
+         return tensordict_reset
+
+
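+ # Editor's illustrative sketch (hypothetical helper, not part of the package):
+ # with persistent=True, blocks executed for the same batch index share one
+ # interpreter process, so state carries over until _reset() recycles it.
+ def _demo_persistent_state():
+     interp = PythonInterpreter(persistent=True, timeout=10.0)
+     try:
+         interp._process_batch_item("```python\na = 41\n```", index=0)
+         out = interp._process_batch_item("```python\nprint(a + 1)\n```", index=0)
+         # The second block still sees `a` from the first one.
+         return out
+     finally:
+         interp.close()  # terminate the persistent worker processes
+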
+ class SimpleToolTransform(ToolTransformBase):
+     r"""A simple transform that executes tools from a dictionary of callables.
+
+     This is a lightweight alternative to MCPToolTransform for simple use cases
+     where you don't need the full Model Context Protocol infrastructure.
+
+     Args:
+         tools (dict[str, Callable]): Dictionary mapping tool names to their implementation functions.
+             Each function should accept kwargs matching its expected parameters.
+         tool_schemas (dict[str, dict], optional): Dictionary mapping tool names to their schemas.
+             Used for documentation purposes only.
+         parser (LLMToolParser | None, optional): Parser for extracting tool calls. If None,
+             uses a simple XML-style parser. Defaults to None.
+         tool_call_pattern (str | None, optional): Regex pattern for extracting tool calls.
+             Only used if parser is None. The pattern should capture (tool_name, args_json).
+             Defaults to ``r"<tool>(.*?)\\n(.*?)</tool>"``.
+         tool_name (str, optional): Role name for tool results in history. Defaults to "tool".
+         timeout (float, optional): Timeout for tool execution in seconds. Defaults to 10.0.
+
+     Examples:
+         >>> from torchrl.envs.llm.transforms import SimpleToolTransform, XMLBlockParser
+         >>> from torchrl.envs.llm import ChatEnv
+         >>> from tensordict import TensorDict, set_list_to_stack
+         >>> set_list_to_stack(True).set()
+         >>>
+         >>> # Define a simple tool
+         >>> def calculator(operation: str, a: float, b: float):
+         ...     if operation == "add":
+         ...         return {"result": a + b}
+         ...     return {"error": "unknown operation"}
+         >>>
+         >>> tools = {"calculator": calculator}
+         >>> env = ChatEnv(batch_size=(1,))
+         >>>
+         >>> # Option 1: Use default XML-style pattern
+         >>> env = env.append_transform(SimpleToolTransform(tools=tools))
+         >>>
+         >>> # Option 2: Use XMLBlockParser for more features
+         >>> parser = XMLBlockParser()
+         >>> env = env.append_transform(SimpleToolTransform(tools=tools, parser=parser))
+         >>>
+         >>> # Option 3: Custom pattern
+         >>> env = env.append_transform(
+         ...     SimpleToolTransform(
+         ...         tools=tools,
+         ...         tool_call_pattern=r"CALL\[(.*?)\]\((.*?)\)"
+         ...     )
+         ... )
+     """
+
+     use_step = True
+
+     def __init__(
+         self,
+         tools: dict[str, Callable],
+         tool_schemas: dict[str, dict] | None = None,
+         parser: LLMToolParser | None = None,
+         tool_call_pattern: str | None = None,
+         tool_name: str = "tool",
+         timeout: float = 10.0,
+     ):
+         super().__init__()
+         self.tools = tools
+         self.tool_schemas = tool_schemas or {}
+         self.parser = parser
+         self.tool_call_pattern = tool_call_pattern or r"<tool>(.*?)\n(.*?)</tool>"
+         self.tool_role = tool_name
+         self.timeout = timeout
+
+     def _extract_tool_calls(self, text: str) -> list[tuple[str, str]]:
+         """Extract tool calls from text.
+
+         Uses parser if provided, otherwise falls back to regex pattern.
+         """
+         if self.parser is not None:
+             # Use the parser (e.g., XMLBlockParser)
+             result: ParseResult = self.parser(text)
+             calls = result.get("calls", [])
+             return [(call.tool, json.dumps(call.args)) for call in calls]
+         else:
+             # Use regex pattern
+             matches = re.findall(self.tool_call_pattern, text, re.DOTALL)
+             return matches
+
+     def _execute_tool(self, tool_name: str, args_json: str) -> dict:
+         """Execute a tool with the given arguments."""
+         try:
+             if tool_name not in self.tools:
+                 return {
+                     "success": False,
+                     "error": f"Tool {tool_name} not found",
+                 }
+
+             # Parse arguments
+             try:
+                 args = json.loads(args_json) if args_json.strip() else {}
+             except json.JSONDecodeError as e:
+                 return {
+                     "success": False,
+                     "error": f"Failed to parse tool arguments: {str(e)}",
+                 }
+
+             # Execute tool
+             result = self.tools[tool_name](**args)
+             return {
+                 "success": True,
+                 "result": result,
+             }
+         except Exception as e:
+             return {
+                 "success": False,
+                 "error": f"Tool execution failed: {str(e)}",
+             }
+
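+     # Editor's illustrative sketch (hypothetical helper, not part of the
+     # package): a round trip with the default <tool>NAME\nARGS_JSON</tool>
+     # pattern, assuming a "calculator" tool was registered as in the class
+     # docstring.
+     def _demo_round_trip(self):
+         text = '<tool>calculator\n{"operation": "add", "a": 1, "b": 2}</tool>'
+         ((name, args_json),) = self._extract_tool_calls(text)
+         # With the docstring's calculator, this returns
+         # {"success": True, "result": {"result": 3}}.
+         return self._execute_tool(name, args_json)
+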
+     def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+         """Process one batch item to extract and execute simple tools."""
+         tool_calls = self._extract_tool_calls(content)
+         if not tool_calls:
+             return None
+
+         results = []
+         for tool_name, args_json in tool_calls:
+             result = self._execute_tool(tool_name, args_json)
+
+             if result["success"]:
+                 results.append(
+                     f"Tool {tool_name} executed successfully:\n{result['result']}"
+                 )
+             else:
+                 results.append(f"Tool {tool_name} failed:\n{result['error']}")
+
+         return results if results else None
+
+
+ class MCPToolTransform(ToolTransformBase):
+     r"""A transform that executes tools via the Model Context Protocol (MCP).
+
+     This transform connects to MCP servers and executes tools through the official
+     MCP library. It runs async operations in a background thread to work with
+     TorchRL's synchronous transform API.
+
+     Args:
+         servers (dict[str, dict]): Dictionary mapping server names to their configurations.
+             Each config should have:
+             - "command" (str): Command to launch the server (e.g., "npx", "uvx")
+             - "args" (list[str]): Arguments for the command
+             Example: {"browser": {"command": "npx", "args": ["@browsermcp/mcp@latest"]}}
+         tool_call_pattern (str, optional): Regex pattern for extracting tool calls.
+             Should capture (tool_name_with_server, args_json).
+             Defaults to ``r"<tool>([\\w.]+)\\n(.*?)</tool>"``.
+         tool_name (str, optional): Role name for tool results in history. Defaults to "tool".
+         timeout (float, optional): Timeout for tool execution in seconds. Defaults to 10.0.
+
+     Examples:
+         >>> import os
+         >>> import json
+         >>> from torchrl.envs.llm import ChatEnv
+         >>> from torchrl.envs.llm.transforms import MCPToolTransform
+         >>> from torchrl.data.llm import History
+         >>> from tensordict import TensorDict, set_list_to_stack
+         >>> set_list_to_stack(True).set()
+         >>>
+         >>> # Add Deno to PATH (required for mcp-run-python)
+         >>> environ = os.environ.copy()
+         >>> deno_path = os.path.expanduser("~/.deno/bin")
+         >>> if deno_path not in os.environ.get('PATH', ''):
+         ...     environ['PATH'] = f"{deno_path}:{os.environ['PATH']}"
+         >>>
+         >>> # Define MCP servers
+         >>> servers = {
+         ...     "browser": {
+         ...         "command": "npx",
+         ...         "args": ["@browsermcp/mcp@latest"]
+         ...     },
+         ...     "python": {
+         ...         "command": "uvx",
+         ...         "args": ["mcp-run-python", "stdio"],
+         ...         "env": environ
+         ...     }
+         ... }
+         >>>
+         >>> # Create environment with MCP transform
+         >>> env = ChatEnv(batch_size=(1,))
+         >>> env = env.append_transform(MCPToolTransform(servers=servers))  # doctest: +SKIP
+         [torchrl][INFO] Connecting to MCP server 'browser' (npx @browsermcp/mcp@latest)
+         [torchrl][INFO] Connected to MCP server 'browser' with 12 tools
+         [torchrl][INFO] Connecting to MCP server 'python' (uvx mcp-run-python stdio)
+         [torchrl][INFO] Connected to MCP server 'python' with 1 tools
+         >>>
+         >>> # Execute Python code via MCP
+         >>> reset_data = TensorDict(query="You are a useful assistant", batch_size=(1,))
+         >>> td = env.reset(reset_data)
+         >>> history = td.get("history")
+         >>> code = '''
+         ... import math
+         ... result = math.sqrt(144) + math.pi
+         ... print(f"Result: {result}")
+         ... result
+         ... '''
+         >>> response = History(
+         ...     role="assistant",
+         ...     content=f'Let me calculate that.\n<tool>python.run_python_code\n{json.dumps({"python_code": code})}</tool>',
+         ... ).unsqueeze(0).unsqueeze(0)
+         >>> history.full = history.prompt.extend(response, inplace=True, dim=-1)
+         >>> history.response = response
+         >>> result = env.step(td.set("history", history))  # doctest: +SKIP
+         >>> print(result["next", "history", "prompt"][..., -1].content)  # doctest: +SKIP
+         LinkedList(LinkedList(["Tool python.run_python_code executed successfully:\n[TextContent(type='text', text='<status>success</status>\\n<output>\\nResult: 15.141592653589793\\n</output>\\n<return_value>\\n15.141592653589793\\n</return_value>', annotations=None, meta=None)]"]))
+
+     .. note::
+         This requires the `mcp` package to be installed: `pip install mcp`
+         The transform manages async MCP connections in a background thread.
+
+     .. note::
+         Some MCP servers have additional requirements:
+         - `mcp-run-python` requires Deno: `curl -fsSL https://deno.land/install.sh | sh`
+         - Server-specific dependencies should be installed before use
+     """
+
+     use_step = True  # Use _step() method
+
+     def __init__(
+         self,
+         servers: dict[str, dict],
+         tool_call_pattern: str | None = None,
+         tool_name: str = "tool",
+         timeout: float = 10.0,
+     ):
+         super().__init__()
+         self.server_configs = servers
+         self.tool_call_pattern = tool_call_pattern or r"<tool>([\w.]+)\n(.*?)</tool>"
+         self.tool_role = tool_name
+         self.timeout = timeout
+
+         # MCP session management
+         self._loop = None
+         self._thread = None
+         self._sessions = {}
+         self._tools_cache = {}
+         self._shutdown_event = threading.Event()
+         self._ready_event = threading.Event()
+         self._connection_error = None
+
+         # Start the async event loop in a background thread
+         self._start_mcp_thread()
+
+     def _start_mcp_thread(self):
+         """Start a background thread running an asyncio event loop for MCP, whose client API is coroutine-based."""
+
+         def run_loop():
+             try:
+                 import asyncio
+             except ImportError:
+                 self._connection_error = "asyncio not available for MCPToolTransform"
+                 torchrl_logger.error(self._connection_error)
+                 self._ready_event.set()
+                 return
+
+             try:
+                 self._loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(self._loop)
+
+                 # Connect to all MCP servers
+                 self._loop.run_until_complete(self._connect_servers())
+
+                 # Signal that initialization is complete
+                 self._ready_event.set()
+
+                 # Keep the loop running until shutdown
+                 while not self._shutdown_event.is_set():
+                     self._loop.run_until_complete(asyncio.sleep(0.1))
+
+                 # Cleanup
+                 self._loop.run_until_complete(self._disconnect_servers())
+                 self._loop.close()
+             except Exception as e:
+                 self._connection_error = f"MCP thread failed: {str(e)}"
+                 torchrl_logger.error(self._connection_error)
+                 self._ready_event.set()
+
+         self._thread = threading.Thread(target=run_loop, daemon=True)
+         self._thread.start()
+
+         # Wait for initialization to complete (with timeout)
+         if not self._ready_event.wait(timeout=10.0):
+             torchrl_logger.warning("MCP initialization timed out after 10 seconds")
+
+         if self._connection_error:
+             torchrl_logger.warning(
+                 f"MCP initialization had errors: {self._connection_error}"
+             )
+
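+     # Editor's illustrative sketch (hypothetical helper, not part of the
+     # package): the thread-plus-event-loop pattern above in miniature. One
+     # daemon thread owns an asyncio loop; synchronous callers submit
+     # coroutines to it, exactly as _execute_tool_sync() does below.
+     @staticmethod
+     def _demo_background_loop():
+         loop = asyncio.new_event_loop()
+         thread = threading.Thread(target=loop.run_forever, daemon=True)
+         thread.start()
+
+         async def work(x):
+             await asyncio.sleep(0.01)
+             return x * 2
+
+         # Bridge sync -> async: schedule on the loop's thread, block on the result.
+         future = asyncio.run_coroutine_threadsafe(work(21), loop)
+         assert future.result(timeout=5.0) == 42
+
+         # Mirror of the shutdown path: stop the loop, join the thread.
+         loop.call_soon_threadsafe(loop.stop)
+         thread.join(timeout=2.0)
+         loop.close()
+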
+     async def _connect_servers(self):
+         """Connect to all configured MCP servers."""
+         try:
+             from mcp import ClientSession, StdioServerParameters
+             from mcp.client.stdio import stdio_client
+         except ImportError as e:
+             torchrl_logger.error(
+                 f"MCP library not installed. Install with: pip install mcp\nError: {e}"
+             )
+             return
+
+         for server_name, config in self.server_configs.items():
+             try:
+                 # Create stdio transport
+                 server_params = StdioServerParameters(
+                     command=config["command"],
+                     args=config.get("args", []),
+                     env=config.get("env", None),
+                 )
+
+                 torchrl_logger.info(
+                     f"Connecting to MCP server '{server_name}' ({config['command']} {' '.join(config.get('args', []))})"
+                 )
+
+                 # Connect and initialize session
+                 stdio = stdio_client(server_params)
+                 try:
+                     read, write = await stdio.__aenter__()
+                 except Exception as e:
+                     error_msg = str(e).lower()
+                     if (
+                         "deno" in error_msg
+                         or "no such file or directory: 'deno'" in error_msg
+                     ):
+                         torchrl_logger.error(
+                             f"Failed to start stdio for '{server_name}': Deno is not installed.\n"
+                             f" Install Deno: curl -fsSL https://deno.land/install.sh | sh\n"
+                             f" After installing, restart your terminal/shell."
+                         )
+                     else:
+                         torchrl_logger.error(
+                             f"Failed to start stdio for '{server_name}': {type(e).__name__}: {e}"
+                         )
+                     raise
+
+                 session = ClientSession(read, write)
+                 try:
+                     await session.__aenter__()
+                 except Exception as e:
+                     error_msg = str(e).lower()
+                     if "connection closed" in error_msg:
+                         # Subprocess likely crashed - check for common issues
+                         torchrl_logger.error(
+                             f"Failed to initialize session for '{server_name}': Subprocess terminated.\n"
+                             f" The MCP server '{config['command']}' started but immediately crashed.\n"
+                             f" Common causes:\n"
+                             f" - Missing dependencies (e.g., Deno for mcp-run-python)\n"
+                             f" - Invalid server configuration\n"
+                             f" Try running manually: {config['command']} {' '.join(config.get('args', []))}\n"
+                             f" Error: {e}"
+                         )
+                     else:
+                         torchrl_logger.error(
+                             f"Failed to initialize session for '{server_name}': {type(e).__name__}: {e}"
+                         )
+                     # Try to close stdio
+                     try:
+                         await stdio.__aexit__(None, None, None)
+                     except Exception:
+                         pass
+                     raise
+
+                 self._sessions[server_name] = {
+                     "session": session,
+                     "stdio": stdio,
+                 }
+
+                 # Discover tools
+                 try:
+                     tools_response = await session.list_tools()
+                     tools = {tool.name: tool for tool in tools_response.tools}
+                     self._tools_cache[server_name] = tools
+                     torchrl_logger.info(
+                         f"Connected to MCP server '{server_name}' with {len(tools)} tools"
+                     )
+                 except Exception as e:
+                     error_msg = str(e).lower()
+                     if "connection closed" in error_msg:
+                         torchrl_logger.error(
+                             f"Could not list tools for server '{server_name}': Connection closed.\n"
+                             f" The MCP server started but crashed immediately.\n"
+                             f" This often means missing dependencies (e.g., Deno for mcp-run-python).\n"
+                             f" Test manually: {config['command']} {' '.join(config.get('args', []))}\n"
+                             f" For mcp-run-python, install Deno: curl -fsSL https://deno.land/install.sh | sh"
+                         )
+                     else:
+                         torchrl_logger.warning(
+                             f"Could not list tools for server '{server_name}': {e}"
+                         )
+                     self._tools_cache[server_name] = {}
+                     # Don't keep a session we can't list tools from
+                     try:
+                         await session.__aexit__(None, None, None)
+                         await stdio.__aexit__(None, None, None)
+                     except Exception:
+                         pass
+                     if server_name in self._sessions:
+                         del self._sessions[server_name]
+
+             except FileNotFoundError as e:
+                 # Check if it's a Deno dependency issue
+                 if "deno" in str(e).lower():
+                     torchrl_logger.error(
+                         f"Failed to connect to MCP server '{server_name}': Deno is not installed.\n"
+                         f" Install Deno: curl -fsSL https://deno.land/install.sh | sh\n"
+                         f" Or use a different MCP server that doesn't require Deno.\n"
+                         f" Error: {e}"
+                     )
+                 else:
+                     torchrl_logger.error(
+                         f"Failed to connect to MCP server '{server_name}': Command not found.\n"
+                         f" Make sure '{config['command']}' is installed and in your PATH.\n"
+                         f" Error: {e}"
+                     )
+             except Exception as e:
+                 torchrl_logger.error(
+                     f"Failed to connect to MCP server '{server_name}': {type(e).__name__}: {e}"
+                 )
+
+     async def _disconnect_servers(self):
+         """Disconnect from all MCP servers."""
+         for server_name, server_data in self._sessions.items():
+             try:
+                 session = server_data["session"]
+                 stdio = server_data["stdio"]
+                 await session.__aexit__(None, None, None)
+                 await stdio.__aexit__(None, None, None)
+             except Exception as e:
+                 torchrl_logger.warning(f"Error disconnecting from '{server_name}': {e}")
+
+         self._sessions.clear()
+         self._tools_cache.clear()
+
+     def _extract_tool_calls(self, text: str) -> list[tuple[str, str, str]]:
+         r"""Extract tool calls from text in the format <tool>server.tool_name\nargs_json</tool>."""
+         matches = re.findall(self.tool_call_pattern, text, re.DOTALL)
+
+         # Parse into (server_name, tool_name, args_json)
+         parsed = []
+         for full_name, args_json in matches:
+             if "." in full_name:
+                 server_name, tool_name = full_name.split(".", 1)
+             else:
+                 # Default to the first server if no prefix is given
+                 server_name = next(iter(self.server_configs.keys()), None)
+                 tool_name = full_name
+
+             if server_name:
+                 parsed.append((server_name, tool_name, args_json))
+
+         return parsed
+
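+     # Editor's illustrative sketch (hypothetical helper, not part of the
+     # package): "server.tool" names split on the first dot; unprefixed names
+     # fall back to the first configured server.
+     def _demo_extract_tool_calls(self):
+         text = '<tool>python.run_python_code\n{"python_code": "print(1)"}</tool>'
+         calls = self._extract_tool_calls(text)
+         # -> [("python", "run_python_code", '{"python_code": "print(1)"}')]
+         return calls
+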
+     def _execute_tool_sync(
+         self, server_name: str, tool_name: str, args_json: str
+     ) -> dict:
+         """Execute a tool via MCP (blocking call that schedules async work)."""
+         if not self._loop or not self._thread or not self._thread.is_alive():
+             return {
+                 "success": False,
+                 "error": "MCP thread not running",
+             }
+
+         # Schedule the async call in the background thread
+         future = asyncio.run_coroutine_threadsafe(
+             self._execute_tool_async(server_name, tool_name, args_json), self._loop
+         )
+
+         try:
+             result = future.result(timeout=self.timeout)
+             return result
+         except TimeoutError:
+             return {
+                 "success": False,
+                 "error": f"Tool execution timed out after {self.timeout}s",
+             }
+         except Exception as e:
+             return {
+                 "success": False,
+                 "error": f"Tool execution failed: {str(e)}",
+             }
+
+     async def _execute_tool_async(
+         self, server_name: str, tool_name: str, args_json: str
+     ) -> dict:
+         """Execute a tool via MCP (async implementation)."""
+         try:
+             # Check if server exists
+             if server_name not in self._sessions:
+                 return {
+                     "success": False,
+                     "error": f"MCP server '{server_name}' not connected",
+                 }
+
+             session = self._sessions[server_name]["session"]
+
+             # Parse arguments
+             try:
+                 args = json.loads(args_json) if args_json.strip() else {}
+             except json.JSONDecodeError as e:
+                 return {
+                     "success": False,
+                     "error": f"Failed to parse tool arguments: {str(e)}",
+                 }
+
+             # Call the tool via MCP
+             result = await session.call_tool(tool_name, arguments=args)
+
+             return {
+                 "success": True,
+                 "result": result.content if hasattr(result, "content") else str(result),
+             }
+
+         except Exception as e:
+             return {
+                 "success": False,
+                 "error": f"MCP tool call failed: {str(e)}",
+             }
+
+     def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+         """Process one batch item to extract and execute MCP tools.
+
+         This is the main method required by ToolTransformBase.
+
+         Args:
+             content: The text content from the LLM response.
+             index: The index of this item in the batch.
+
+         Returns:
+             list[str] or None: List of result strings for each tool executed,
+             or None if no tools were found.
+         """
+         # Extract tool calls
+         tool_calls = self._extract_tool_calls(content)
+         if not tool_calls:
+             return None
+
+         # Execute each tool via MCP
+         results = []
+         for server_name, tool_name, args_json in tool_calls:
+             result = self._execute_tool_sync(server_name, tool_name, args_json)
+
+             if result["success"]:
+                 results.append(
+                     f"Tool {server_name}.{tool_name} executed successfully:\n{result['result']}"
+                 )
+             else:
+                 results.append(
+                     f"Tool {server_name}.{tool_name} failed:\n{result['error']}"
+                 )
+
+         return results if results else None
+
+     def close(self):
+         """Shut down the MCP connections and the background thread."""
+         if self._thread and self._thread.is_alive():
+             self._shutdown_event.set()
+             self._thread.join(timeout=2.0)
+
+         self._loop = None
+         self._thread = None
+
+     def __del__(self):
+         """Ensure cleanup on deletion."""
+         try:
+             self.close()
+         except Exception:
+             pass