torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
--- /dev/null
+++ torchrl/modules/distributions/truncated_normal.py
@@ -0,0 +1,187 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# from https://github.com/toshas/torch_truncnorm
+from __future__ import annotations
+
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import constraints, Distribution
+from torch.distributions.utils import broadcast_all
+
+CONST_SQRT_2 = math.sqrt(2)
+CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
+CONST_INV_SQRT_2 = 1 / math.sqrt(2)
+CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
+CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
+
+
+class TruncatedStandardNormal(Distribution):
+    """Truncated Standard Normal distribution.
+
+    Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    """
+
+    arg_constraints = {
+        "a": constraints.real,
+        "b": constraints.real,
+    }
+    has_rsample = True
+    eps = 1e-6
+
+    def __init__(self, a, b, validate_args=None, device=None):
+        self.a, self.b = broadcast_all(a, b)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        self.a = self.a.to(device, non_blocking=_non_blocking)
+        self.b = self.b.to(device, non_blocking=_non_blocking)
+        if isinstance(a, Number) and isinstance(b, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.a.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+        if self.a.dtype != self.b.dtype:
+            raise ValueError("Truncation bounds types are different")
+        if any(
+            (self.a >= self.b)
+            .view(
+                -1,
+            )
+            .tolist()
+        ):
+            raise ValueError("Incorrect truncation range")
+        eps = self.eps
+        self._dtype_min_gt_0 = eps
+        self._dtype_max_lt_1 = 1 - eps
+        self._little_phi_a = self._little_phi(self.a)
+        self._little_phi_b = self._little_phi(self.b)
+        self._big_phi_a = self._big_phi(self.a)
+        self._big_phi_b = self._big_phi(self.b)
+        self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
+        self._log_Z = self._Z.log()
+        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
+        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
+        self._lpbb_m_lpaa_d_Z = (
+            self._little_phi_b * little_phi_coeff_b
+            - self._little_phi_a * little_phi_coeff_a
+        ) / self._Z
+        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
+        self._variance = (
+            1
+            - self._lpbb_m_lpaa_d_Z
+            - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
+        )
+        self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
+
+    @constraints.dependent_property
+    def support(self):
+        return constraints.interval(self.a, self.b)
+
+    @property
+    def mean(self):
+        return self._mean
+
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
+    @property
+    def variance(self):
+        return self._variance
+
+    def entropy(self):
+        return self._entropy
+
+    @property
+    def auc(self):
+        return self._Z
+
+    @staticmethod
+    def _little_phi(x):
+        return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI
+
+    def _big_phi(self, x):
+        phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
+        return phi.clamp(self.eps, 1 - self.eps)
+
+    @staticmethod
+    def _inv_big_phi(x):
+        return CONST_SQRT_2 * (2 * x - 1).erfinv()
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
+
+    def icdf(self, value):
+        y = self._big_phi_a + value * self._Z
+        y = y.clamp(self.eps, 1 - self.eps)
+        return self._inv_big_phi(y)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
+
+    def rsample(self, sample_shape=None):
+        if sample_shape is None:
+            sample_shape = torch.Size([])
+        shape = self._extended_shape(sample_shape)
+        p = torch.empty(shape, device=self.a.device).uniform_(
+            self._dtype_min_gt_0, self._dtype_max_lt_1
+        )
+        return self.icdf(p)
+
+
+class TruncatedNormal(TruncatedStandardNormal):
+    """Truncated Normal distribution.
+
+    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    """
+
+    has_rsample = True
+
+    def __init__(self, loc, scale, a, b, validate_args=None, device=None):
+        scale = scale.clamp_min(self.eps)
+        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        a = a.to(device, non_blocking=_non_blocking)
+        b = b.to(device, non_blocking=_non_blocking)
+        self._non_std_a = a
+        self._non_std_b = b
+        a = (a - self.loc) / self.scale
+        b = (b - self.loc) / self.scale
+        super().__init__(a, b, validate_args=validate_args)
+        self._log_scale = self.scale.log()
+        self._mean = self._mean * self.scale + self.loc
+        self._variance = self._variance * self.scale**2
+        self._entropy += self._log_scale
+
+    def _to_std_rv(self, value):
+        return (value - self.loc) / self.scale
+
+    def _from_std_rv(self, value):
+        return value * self.scale + self.loc
+
+    def cdf(self, value):
+        return super().cdf(self._to_std_rv(value))
+
+    def icdf(self, value):
+        sample = self._from_std_rv(super().icdf(value))
+
+        # clamp data but keep gradients
+        sample_clip = torch.stack(
+            [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
+        ).max(0)[0]
+        sample_clip = torch.stack(
+            [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
+        ).min(0)[0]
+        sample.data.copy_(sample_clip)
+        return sample
+
+    def log_prob(self, value):
+        value = self._to_std_rv(value)
+        return super().log_prob(value) - self._log_scale
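For orientation, here is a minimal usage sketch (illustrative, not part of the released files): it draws reparameterized samples from the TruncatedNormal class added above, checks that they respect the truncation bounds, and compares the closed-form mean and variance against empirical estimates.

import torch

from torchrl.modules.distributions.truncated_normal import TruncatedNormal

loc = torch.zeros(3)
scale = torch.ones(3)
low = torch.full((3,), -1.0)
high = torch.full((3,), 2.0)

dist = TruncatedNormal(loc, scale, low, high)
samples = dist.rsample(torch.Size([100_000]))  # differentiable w.r.t. loc/scale

# Samples live in [low, high], and the analytic moments match the empirical ones.
assert samples.min() >= -1.0 and samples.max() <= 2.0
print(dist.mean, samples.mean(0))      # analytic vs empirical mean
print(dist.variance, samples.var(0))   # analytic vs empirical variance
print(dist.log_prob(samples).shape)    # broadcasts over the sample dimension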
--- /dev/null
+++ torchrl/modules/distributions/utils.py
@@ -0,0 +1,233 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+from torch import autograd, distributions as d
+from torch.distributions import Independent, Transform, TransformedDistribution
+
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
+
+def _cast_device(elt: torch.Tensor | float, device) -> torch.Tensor | float:
+    if isinstance(elt, torch.Tensor):
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        return elt.to(device, non_blocking=_non_blocking)
+    return elt
+
+
+def _cast_transform_device(transform, device):
+    if transform is None:
+        return transform
+    _non_blocking = device is not None and torch.device(device).type == "cuda"
+    if isinstance(transform, d.ComposeTransform):
+        for i, t in enumerate(transform.parts):
+            transform.parts[i] = _cast_transform_device(t, device)
+    elif isinstance(transform, d.Transform):
+        for attribute in dir(transform):
+            value = getattr(transform, attribute)
+            if isinstance(value, torch.Tensor):
+                setattr(
+                    transform, attribute, value.to(device, non_blocking=_non_blocking)
+                )
+        return transform
+    else:
+        raise TypeError(
+            f"Cannot perform device casting for transform of type {type(transform)}"
+        )
+
+
+class FasterTransformedDistribution(TransformedDistribution):
+    """A faster implementation of TransformedDistribution."""
+
+    __doc__ = __doc__ + TransformedDistribution.__doc__
+
+    def __init__(self, base_distribution, transforms, validate_args=None):
+        if is_dynamo_compiling():
+            return super().__init__(
+                base_distribution, transforms, validate_args=validate_args
+            )
+        if isinstance(transforms, Transform):
+            self.transforms = [transforms]
+        elif isinstance(transforms, list):
+            raise ValueError("Make a ComposeTransform first.")
+        else:
+            raise ValueError(
+                f"transforms must be a Transform or list, but was {transforms}"
+            )
+        transform = self.transforms[0]
+        # Reshape base_distribution according to transforms.
+        base_shape = base_distribution.batch_shape + base_distribution.event_shape
+        base_event_dim = len(base_distribution.event_shape)
+        # transform = ComposeTransform(self.transforms)
+        # if len(base_shape) < transform.domain.event_dim:
+        #     raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
+        #                      .format(transform.domain.event_dim, base_shape))
+        transform_codomain_event_dim = transform.codomain.event_dim
+        transform_domain_event_dim = transform.domain.event_dim
+
+        forward_shape = transform.forward_shape(base_shape)
+        expanded_base_shape = transform.inverse_shape(forward_shape)
+        if base_shape != expanded_base_shape:
+            base_batch_shape = expanded_base_shape[
+                : len(expanded_base_shape) - base_event_dim
+            ]
+            base_distribution = base_distribution.expand(base_batch_shape)
+        reinterpreted_batch_ndims = transform_domain_event_dim - base_event_dim
+        if reinterpreted_batch_ndims > 0:
+            base_distribution = Independent(
+                base_distribution, reinterpreted_batch_ndims
+            )
+        self.base_dist = base_distribution
+
+        # Compute shapes.
+        transform_change_in_event_dim = (
+            transform_codomain_event_dim - transform_domain_event_dim
+        )
+        event_dim = max(
+            transform_codomain_event_dim,  # the transform is coupled
+            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
+        )
+        cut = len(forward_shape) - event_dim
+        batch_shape = forward_shape[:cut]
+        event_shape = forward_shape[cut:]
+        super(TransformedDistribution, self).__init__(
+            batch_shape, event_shape, validate_args=validate_args
+        )
+
+
+def _safetanh(x, eps):  # noqa: D103
+    lim = 1.0 - eps
+    y = x.tanh()
+    return y.clamp(-lim, lim)
+
+
+def _safeatanh(y, eps):  # noqa: D103
+    lim = 1.0 - eps
+    return y.clamp(-lim, lim).atanh()
+
+
+class _SafeTanh(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(input, eps):
+        output = input.tanh()
+        lim = 1.0 - eps
+        output = output.clamp(-lim, lim)
+        # ctx.save_for_backward(output)
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        # input, eps = inputs
+        # ctx.mark_non_differentiable(ind, ind_inv)
+        # # Tensors must be saved via ctx.save_for_backward. Please do not
+        # # assign them directly onto the ctx object.
+        ctx.save_for_backward(output)
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (output,) = ctx.saved_tensors
+        return (grad * (1 - output.pow(2)), None)
+
+
+class _SafeTanhNoEps(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(input):
+        output = input.tanh()
+        eps = torch.finfo(input.dtype).resolution
+        lim = 1.0 - eps
+        output = output.clamp(-lim, lim)
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        ctx.save_for_backward(output)
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (output,) = ctx.saved_tensors
+        return (grad * (1 - output.pow(2)),)
+
+
+class _SafeaTanh(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(tanh_val, eps):
+        if eps is None:
+            eps = torch.finfo(tanh_val.dtype).resolution
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        # ctx.save_for_backward(output)
+        output = output.atanh()
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        tanh_val, eps = inputs
+
+        # ctx.mark_non_differentiable(ind, ind_inv)
+        # # Tensors must be saved via ctx.save_for_backward. Please do not
+        # # assign them directly onto the ctx object.
+        ctx.save_for_backward(tanh_val)
+        ctx.eps = eps
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (tanh_val,) = ctx.saved_tensors
+        eps = ctx.eps
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        return (grad / (1 - output.pow(2)), None)
+
+
+class _SafeaTanhNoEps(autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(tanh_val):
+        eps = torch.finfo(tanh_val.dtype).resolution
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        # ctx.save_for_backward(output)
+        output = output.atanh()
+        return output
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        tanh_val = inputs[0]
+        eps = torch.finfo(tanh_val.dtype).resolution
+
+        # ctx.mark_non_differentiable(ind, ind_inv)
+        # # Tensors must be saved via ctx.save_for_backward. Please do not
+        # # assign them directly onto the ctx object.
+        ctx.save_for_backward(tanh_val)
+        ctx.eps = eps
+
+    @staticmethod
+    def backward(ctx, *grad):
+        grad = grad[0]
+        (tanh_val,) = ctx.saved_tensors
+        eps = ctx.eps
+        lim = 1.0 - eps
+        output = tanh_val.clamp(-lim, lim)
+        return (grad / (1 - output.pow(2)),)
+
+
+safetanh = _SafeTanh.apply
+safeatanh = _SafeaTanh.apply
+
+safetanh_noeps = _SafeTanhNoEps.apply
+safeatanh_noeps = _SafeaTanhNoEps.apply
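The safe tanh/atanh pair above exists because float32 tanh saturates to exactly 1.0 for moderately large inputs, so a naive atanh(tanh(x)) round trip overflows to inf and kills gradients in tanh-squashed policies. A small sketch (illustrative, not part of the released files) using the module's own safetanh/safeatanh exports:

import torch

from torchrl.modules.distributions.utils import safeatanh, safetanh

x = torch.tensor([20.0], requires_grad=True)

# Naive round trip: tanh(20) rounds to exactly 1.0 in float32, atanh(1.0) = inf.
print(torch.atanh(torch.tanh(x)))  # tensor([inf], ...)

# Clamped round trip: values stay strictly inside (-1, 1), so both the output
# and the gradient remain finite.
eps = torch.finfo(x.dtype).resolution
y = safetanh(x, eps)        # clamped to +/- (1 - eps)
x_back = safeatanh(y, eps)  # finite inverse
x_back.sum().backward()
print(x_back, x.grad)       # finite value, finite gradient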
--- /dev/null
+++ torchrl/modules/llm/__init__.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""LLM utilities for TorchRL.
+
+Note:
+    This package contains optional integrations (e.g. vLLM) that may rely on native
+    extensions. To keep `import torchrl` / `import torchrl.envs` lightweight and
+    robust, we **avoid importing optional backends at module import time** and
+    instead only import those backends on demand.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from .policies.common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens
+from .policies.transformers_wrapper import (
+    RemoteTransformersWrapper,
+    TransformersWrapper,
+)
+from .policies.vllm_wrapper import vLLMWrapper
+
+__all__ = [
+    # Data structures
+    "ChatHistory",
+    "LogProbs",
+    "Masks",
+    "Text",
+    "Tokens",
+    # Wrapper base class
+    "LLMWrapperBase",
+    # Local wrappers
+    "TransformersWrapper",
+    "vLLMWrapper",
+    # Remote wrappers
+    "RemoteTransformersWrapper",
+    # Async vLLM (recommended)
+    "AsyncVLLM",
+    "make_async_vllm_engine",
+    "stateless_init_process_group_async",
+    # Sync vLLM utilities
+    "make_vllm_worker",
+    "stateless_init_process_group",
+]
+
+
+def __getattr__(name: str) -> Any:  # noqa: ANN401
+    # Keep backends optional and on-demand to avoid importing vLLM native extensions
+    # as a side-effect of importing torchrl.
+    if name in {
+        "AsyncVLLM",
+        "make_async_vllm_engine",
+        "make_vllm_worker",
+        "stateless_init_process_group",
+        "stateless_init_process_group_async",
+    }:
+        from . import backends  # local import is intentional / required
+
+        return getattr(backends, name)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
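The module-level __getattr__ above is the PEP 562 lazy-attribute hook: importing torchrl.modules.llm stays cheap, and the optional backends subpackage is only imported when one of the vLLM names is first touched. A sketch of the observable behavior (assuming vLLM is installed; the exact import side effects of the wrapper modules may vary):

import sys

import torchrl.modules.llm as llm_mod

# Importing the package should not, by itself, pull in the heavy backend.
print("torchrl.modules.llm.backends" in sys.modules)  # expected: False

engine_cls = llm_mod.AsyncVLLM  # __getattr__ imports the backend on demand

print("torchrl.modules.llm.backends" in sys.modules)  # expected: True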
--- /dev/null
+++ torchrl/modules/llm/backends/__init__.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""LLM backends.
+
+These backends can be optional and may rely on native extensions. We avoid
+importing them at module import time and lazily load on attribute access.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+__all__ = [
+    # Base classes
+    "RLvLLMEngine",
+    # Sync vLLM
+    "make_vllm_worker",
+    "RayLLMWorker",
+    "LocalLLMWrapper",
+    # Async vLLM
+    "_AsyncvLLMWorker",
+    "_AsyncLLMEngine",
+    "AsyncVLLM",
+    "make_async_vllm_engine",
+    # Utilities
+    "stateless_init_process_group",
+    "stateless_init_process_group_async",
+]
+
+_LAZY_ATTRS: dict[str, tuple[str, str]] = {
+    # Base classes and interfaces
+    "RLvLLMEngine": ("torchrl.modules.llm.backends.vllm", "RLvLLMEngine"),
+    # Sync vLLM
+    "make_vllm_worker": ("torchrl.modules.llm.backends.vllm", "make_vllm_worker"),
+    "RayLLMWorker": ("torchrl.modules.llm.backends.vllm", "RayLLMWorker"),
+    "LocalLLMWrapper": ("torchrl.modules.llm.backends.vllm", "LocalLLMWrapper"),
+    # Async vLLM
+    "_AsyncvLLMWorker": ("torchrl.modules.llm.backends.vllm", "_AsyncvLLMWorker"),
+    "_AsyncLLMEngine": ("torchrl.modules.llm.backends.vllm", "_AsyncLLMEngine"),
+    "AsyncVLLM": ("torchrl.modules.llm.backends.vllm", "AsyncVLLM"),
+    "make_async_vllm_engine": (
+        "torchrl.modules.llm.backends.vllm",
+        "make_async_vllm_engine",
+    ),
+    # Utilities
+    "stateless_init_process_group": (
+        "torchrl.modules.llm.backends.vllm",
+        "stateless_init_process_group",
+    ),
+    "stateless_init_process_group_async": (
+        "torchrl.modules.llm.backends.vllm",
+        "stateless_init_process_group_async",
+    ),
+}
+
+
+def __getattr__(name: str) -> Any:  # noqa: ANN401
+    target = _LAZY_ATTRS.get(name)
+    if target is None:
+        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+    module_name, attr_name = target
+    module = __import__(module_name, fromlist=[attr_name])
+    return getattr(module, attr_name)
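The _LAZY_ATTRS table plus module-level __getattr__ is a reusable recipe for deferring imports. A self-contained sketch of the same pattern, spelled with importlib.import_module, which is the more common equivalent of the `__import__(module, fromlist=[attr])` idiom used above (both return the leaf submodule):

from importlib import import_module
from typing import Any

# Map public name -> (module path, attribute name); nothing is imported yet.
_LAZY_ATTRS: dict[str, tuple[str, str]] = {
    "AsyncVLLM": ("torchrl.modules.llm.backends.vllm", "AsyncVLLM"),
}


def __getattr__(name: str) -> Any:
    # Called only when normal attribute lookup on the module fails (PEP 562).
    try:
        module_name, attr_name = _LAZY_ATTRS[name]
    except KeyError:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None
    return getattr(import_module(module_name), attr_name)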
--- /dev/null
+++ torchrl/modules/llm/backends/vllm/__init__.py
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""vLLM backends for TorchRL.
+
+This module provides comprehensive vLLM integration including:
+- Base classes and interfaces
+- Synchronous vLLM workers
+- Asynchronous vLLM services
+- Shared utilities
+
+Examples:
+    >>> # Create an async vLLM service (recommended)
+    >>> from torchrl.modules.llm.backends.vllm import AsyncVLLM
+    >>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
+
+    >>> # Create a sync Ray worker
+    >>> from torchrl.modules.llm.backends.vllm import make_vllm_worker
+    >>> worker = make_vllm_worker("Qwen/Qwen2.5-3B", make_ray_worker=True)
+
+    >>> # All engines implement the same interface
+    >>> from torchrl.modules.llm.backends.vllm import RLvLLMEngine
+    >>> updater = vLLMUpdaterV2(any_engine)  # Works with any RLvLLMEngine
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+__all__ = [
+    # Base classes and interfaces
+    "RLvLLMEngine",
+    # Synchronous vLLM
+    "make_vllm_worker",
+    "RayLLMWorker",
+    "LocalLLMWrapper",
+    # Asynchronous vLLM
+    "AsyncVLLM",
+    "make_async_vllm_engine",
+    "_AsyncLLMEngine",
+    "_AsyncvLLMWorker",
+    # Utilities
+    "stateless_init_process_group",
+    "stateless_init_process_group_async",
+]
+
+_LAZY_ATTRS: dict[str, tuple[str, str]] = {
+    # Base
+    "RLvLLMEngine": ("torchrl.modules.llm.backends.vllm.base", "RLvLLMEngine"),
+    # Sync
+    "make_vllm_worker": (
+        "torchrl.modules.llm.backends.vllm.vllm_sync",
+        "make_vllm_worker",
+    ),
+    "RayLLMWorker": ("torchrl.modules.llm.backends.vllm.vllm_sync", "RayLLMWorker"),
+    "LocalLLMWrapper": (
+        "torchrl.modules.llm.backends.vllm.vllm_sync",
+        "LocalLLMWrapper",
+    ),
+    # Async
+    "_AsyncLLMEngine": (
+        "torchrl.modules.llm.backends.vllm.vllm_async",
+        "_AsyncLLMEngine",
+    ),
+    "_AsyncvLLMWorker": (
+        "torchrl.modules.llm.backends.vllm.vllm_async",
+        "_AsyncvLLMWorker",
+    ),
+    "AsyncVLLM": ("torchrl.modules.llm.backends.vllm.vllm_async", "AsyncVLLM"),
+    "make_async_vllm_engine": (
+        "torchrl.modules.llm.backends.vllm.vllm_async",
+        "make_async_vllm_engine",
+    ),
+    # Utils
+    "stateless_init_process_group": (
+        "torchrl.modules.llm.backends.vllm.vllm_utils",
+        "stateless_init_process_group",
+    ),
+    "stateless_init_process_group_async": (
+        "torchrl.modules.llm.backends.vllm.vllm_utils",
+        "stateless_init_process_group_async",
+    ),
+}
+
+
+def __getattr__(name: str) -> Any:  # noqa: ANN401
+    target = _LAZY_ATTRS.get(name)
+    if target is None:
+        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+    module_name, attr_name = target
+    module = __import__(module_name, fromlist=[attr_name])
+    return getattr(module, attr_name)
--- /dev/null
+++ torchrl/modules/llm/backends/vllm/_models.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Override the last layers of your models here."""
+
+from __future__ import annotations
+
+import os
+
+import torch
+
+try:
+    from vllm.config import VllmConfig
+    from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
+except ImportError:
+
+    class VllmConfig:
+        """Placeholder for VllmConfig class when vLLM is not installed."""
+
+    class Qwen3ForCausalLM:
+        """Placeholder for Qwen3ForCausalLM class when vLLM is not installed."""
+
+
+def is_fp32_output_enabled() -> bool:
+    """Check if FP32 output is enabled."""
+    return os.getenv("VLLM_ENABLE_FP32_OUTPUT", "0") == "1"
+
+
+class Qwen3ForCausalLMFP32(Qwen3ForCausalLM):
+    """Qwen3ForCausalLM with FP32 output."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        if is_fp32_output_enabled():
+            self.lm_head.float()
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        if is_fp32_output_enabled():
+            hidden_states = hidden_states.float()
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
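Qwen3ForCausalLMFP32 only takes effect once vLLM is told to use it for the Qwen3 architecture; torchrl ships a vllm_plugin.py for that purpose (see the file list above). A hedged sketch of how such an override is typically registered through vLLM's ModelRegistry; the architecture name and string path below are assumptions for illustration, not taken from the plugin:

import os

from vllm import ModelRegistry

# Gate the FP32 lm_head / logits path; checked both at model construction
# and inside compute_logits above.
os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"

# Register the override under the architecture name vLLM resolves from the
# model config. The "module:Class" string form defers the actual import.
ModelRegistry.register_model(
    "Qwen3ForCausalLM",  # assumed architecture key
    "torchrl.modules.llm.backends.vllm._models:Qwen3ForCausalLMFP32",
)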