torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Assertions and validation utilities for TorchRL tests."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+ from tensordict import TensorDict
12
+
13
+ __all__ = [
14
+ "check_rollout_consistency_multikey_env",
15
+ "rand_reset",
16
+ "rollout_consistency_assertion",
17
+ ]
18
+
19
+
20
def rollout_consistency_assertion(
    rollout, *, done_key="done", observation_key="observation", done_strict=False
):
    """Assert that each step's ``"next"`` observation agrees with the following root entry.

    The last batch dimension of ``rollout`` is treated as the time dimension.
    For every step whose done flag is False, the observation stored under
    ``"next"`` must be close to the root observation of the following step.
    For steps whose done flag is True, at least one observation element must
    differ from the following root observation (indicating a reset occurred).

    Args:
        rollout: The rollout tensordict to validate.
        done_key: The key for the done signal (read under ``"next"``).
        observation_key: The key for observations.
        done_strict: If True, raise an error if no done is detected.

    Raises:
        AssertionError: If a not-done observation does not match, or if every
            observation following a done step is unchanged.
        RuntimeError: If ``done_strict`` is True and no done flag is set.
    """
    # Step outcomes for every time index but the last one.
    stepped = rollout[..., :-1]["next"]
    # Root data of the time index that follows each step. The ellipsis keeps
    # this slice on the last (time) dimension regardless of how many leading
    # batch dimensions the rollout has, so it always lines up with ``stepped``.
    following = rollout[..., 1:]
    done = stepped[done_key].squeeze(-1)

    # When a step is not done, the next root observation is the step's output.
    not_done = ~done
    torch.testing.assert_close(
        stepped[not_done][observation_key],
        following[not_done][observation_key],
        msg=f"Key {observation_key} did not match",
    )

    any_done = done.any()
    if done_strict and not any_done:
        raise RuntimeError("No done detected, test could not complete.")
    if any_done:
        # After a done step the env is reset: at least one observation element
        # after the reset must differ from the version before the reset.
        assert not torch.isclose(
            stepped[done][observation_key], following[done][observation_key]
        ).all(), f"Key {observation_key} is unchanged after every done step"
56
+
57
+
58
def rand_reset(env):
    """Build a random reset structure shaped after the environment's done spec.

    For each reset key, a value is sampled from the spec of the first done key
    in the matching done group; sampling is repeated until at least one entry
    of the value is set.

    Args:
        env: The environment to generate reset keys for. Its ``reset_keys``,
            ``done_keys_groups`` and ``full_done_spec`` attributes are read.

    Returns:
        The zeroed done-spec structure updated with the sampled reset entries,
        with the done-spec leaf keys excluded.
    """
    done_spec = env.full_done_spec
    resets = {}
    for reset_key, done_group in zip(env.reset_keys, env.done_keys_groups):
        first_done_key = done_group[0]
        sample = done_spec[first_done_key].rand()
        # Resample until the reset signal is set for at least one entry.
        while not sample.any():
            sample = done_spec[first_done_key].rand()
        resets[reset_key] = sample
    # Start from the zeroed spec so the batch size of nested specs is kept,
    # then drop the done leaves so that only the reset entries remain.
    skeleton = done_spec.zero().update(resets)
    return skeleton.exclude(*done_spec.keys(True, True))
81
+
82
+
83
def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
    """Check a rollout collected from an env with several observation/action keys.

    For the root entries and for the ``nested_1`` and ``nested_2`` groups, the
    following is asserted:
        - the group's done flag is set exactly where its next observation equals
          ``max_steps + 1``;
        - after a root done the observation is 0, otherwise it equals the
          previous next observation;
        - a "count" action increments the observation by 1 and yields a reward
          of 1, any other action leaves the observation unchanged and yields 0.

    Args:
        td: The rollout tensordict to validate.
        max_steps: The maximum steps before done in the environment.
    """
    # Index selecting the first element of every batch dim but the last one.
    lead = (0,) * (len(td.batch_size) - 1)

    # ---- root entries ----
    next_obs = td["next", "observation"]
    hit_max = next_obs[..., 0, 0, 0] == max_steps + 1
    root_done = td["next", "done"]
    # Root done flags of every step but the last one. The done at the root
    # always prevails, so the same mask drives the reset checks of all groups.
    step_done = root_done[lead][:-1].squeeze(-1)
    assert (root_done[hit_max]).all()
    assert (~root_done[~hit_max]).all()
    obs = td["observation"]
    later_obs = obs[lead][1:]
    # Obs after done is 0
    assert (later_obs[step_done] == 0).all()
    # Obs after not done is previous obs
    assert (later_obs[~step_done] == next_obs[lead][:-1][~step_done]).all()
    # A count action increments the observation and is rewarded
    counted = td["action"].long().argmax(-1).to(torch.bool)
    assert (next_obs[counted] == obs[counted] + 1).all()
    assert (td["next", "reward"][counted] == 1).all()
    # A no-count action leaves the observation untouched and is not rewarded
    assert (next_obs[~counted] == obs[~counted]).all()
    assert (td["next", "reward"][~counted] == 0).all()

    # ---- nested groups ----
    # (group name, action key, reward key, whether the action needs a squeeze)
    nested_groups = (
        ("nested_1", "action", "gift", False),
        ("nested_2", "azione", "reward", True),
    )
    for group, action_key, reward_key, squeeze_action in nested_groups:
        next_obs = td["next", group, "observation"]
        hit_max = next_obs[..., 0] == max_steps + 1
        group_done = td["next", group, "done"]
        assert (group_done[hit_max]).all()
        assert (~group_done[~hit_max]).all()
        obs = td[group, "observation"]
        later_obs = obs[lead][1:]
        # Obs after done is 0
        assert (later_obs[step_done] == 0).all()
        # Obs after not done is previous obs
        assert (later_obs[~step_done] == next_obs[lead][:-1][~step_done]).all()
        # A count action increments the observation and is rewarded
        action = td[group][action_key]
        if squeeze_action:
            action = action.squeeze(-1)
        counted = action.to(torch.bool)
        assert (next_obs[counted] == obs[counted] + 1).all()
        assert (td["next", group, reward_key][counted] == 1).all()
        # A no-count action leaves the observation untouched and is not rewarded
        assert (next_obs[~counted] == obs[~counted]).all()
        assert (td["next", group, reward_key][~counted] == 0).all()
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
8
+ import psutil
9
+
10
+ __all__ = [
11
+ "assert_no_new_python_processes",
12
+ "is_python_process",
13
+ "snapshot_python_processes",
14
+ ]
15
+
16
+
17
def is_python_process(comm: str | None, args: str | None) -> bool:
    """Tell whether a process looks like a python interpreter.

    Args:
        comm: The process name; ``None`` is treated as an empty name.
        args: The process command line joined into a single string.

    Returns:
        True if the lower-cased name starts with ``python`` or ``pypy``, or if
        the command line contains ``python`` (case-insensitive).
    """
    executable = ("" if comm is None else comm).lower()
    if executable.startswith(("python", "pypy")):
        return True
    # Fall back to the command line, e.g. a wrapper launching an interpreter.
    return bool(args) and "python" in args.lower()
27
+
28
+
29
def snapshot_python_processes(
    root: psutil.Process | None = None,
) -> dict[tuple[int, float], dict[str, Any]]:
    """Snapshot python processes belonging to the given process tree.

    Args:
        root: Root of the process tree to inspect. Defaults to the current
            process.

    Returns:
        A dict keyed by ``(pid, start_time)`` whose values hold the ``pid``,
        ``start_time``, ``comm`` (process name) and ``args`` (joined command
        line) of each python process found in the tree, root included.
    """
    if root is None:
        root = psutil.Process(os.getpid())

    # ``os.getuid`` and psutil's ``uids`` attribute are POSIX-only. Where they
    # are unavailable (e.g. Windows) the ownership filter is skipped instead
    # of making the whole snapshot fail.
    uid = os.getuid() if hasattr(os, "getuid") else None
    attrs = ["pid", "name", "cmdline", "create_time"]
    if uid is not None:
        attrs.append("uids")

    # Snapshot descendant PIDs first, then query process info via process_iter.
    # This avoids race conditions where a child exits between `children()` and
    # attribute access on a stale Process handle (common with Ray helpers).
    descendant_pids = {root.pid}
    descendant_pids.update(p.pid for p in root.children(recursive=True))

    out: dict[tuple[int, float], dict[str, Any]] = {}
    for proc in psutil.process_iter(attrs=attrs, ad_value=None):
        info = proc.info
        pid = info.get("pid")
        if pid is None or pid not in descendant_pids:
            continue
        if uid is not None:
            # Only keep processes owned by the current user.
            uids = info.get("uids")
            if uids is None or uids.real != uid:
                continue

        name = info.get("name") or ""
        cmdline = info.get("cmdline") or []
        args = " ".join(cmdline) if isinstance(cmdline, (list, tuple)) else str(cmdline)
        if not is_python_process(name, args):
            continue

        # A missing creation time is recorded as 0.0.
        start_time = float(info.get("create_time") or 0.0)
        out[(int(pid), start_time)] = {
            "pid": int(pid),
            "start_time": start_time,
            "comm": name,
            "args": args,
        }
    return out
74
+
75
+
76
def assert_no_new_python_processes(
    *,
    baseline: dict[tuple[int, float], dict[str, Any]],
    baseline_time: float,
    timeout: float = 20.0,
    ignore_info_fn: Callable[[dict[str, Any]], bool] | None = None,
) -> None:
    """Assert that no python process started after baseline_time remains alive.

    The check is limited to the current process tree (pytest process + descendants).
    The process tree is polled every 0.25 seconds until no new process is left
    or ``timeout`` has elapsed. At least one check is always performed, even
    when ``timeout`` is zero or negative.

    Args:
        baseline: Snapshot (as returned by ``snapshot_python_processes``) of
            the processes that are allowed to be alive.
        baseline_time: Wall-clock time (seconds since the epoch) at which the
            baseline was taken; compared against process creation times.
        timeout: Maximum time in seconds to wait for new processes to exit.
        ignore_info_fn: Optional predicate receiving a process info dict and
            returning True if that process must be ignored.

    Raises:
        AssertionError: If new python processes are still alive once the
            timeout has elapsed.
    """
    own_pid = os.getpid()
    # The deadline uses a monotonic clock so that wall-clock adjustments cannot
    # shorten or extend the wait. ``baseline_time`` stays on the wall clock as
    # it is compared with process creation times.
    deadline = time.monotonic() + timeout
    while True:
        leaked: dict[tuple[int, float], dict[str, Any]] = {}
        for key, info in snapshot_python_processes().items():
            pid, start_time = key
            if pid == own_pid:
                continue
            if ignore_info_fn is not None and ignore_info_fn(info):
                continue
            # Guard against pid reuse: only consider processes started after the baseline.
            if start_time and start_time < baseline_time - 1.0:
                continue
            if key in baseline:
                continue
            leaked[key] = info
        if not leaked:
            return
        if time.monotonic() >= deadline:
            break
        time.sleep(0.25)

    details = "\n".join(
        f"- pid={v['pid']} comm={v.get('comm')} args={v.get('args')}"
        for v in leaked.values()
    )
    raise AssertionError(
        "Leaked python processes detected after collector.shutdown().\n"
        f"Processes still alive:\n{details}"
    )
@@ -0,0 +1,227 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Environment creation utilities for TorchRL tests."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+
12
+ from torchrl.envs import MultiThreadedEnv, ObservationNorm
13
+ from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
14
+ from torchrl.envs.libs.envpool import _has_envpool
15
+ from torchrl.envs.libs.gym import GymEnv
16
+ from torchrl.envs.transforms import (
17
+ Compose,
18
+ RewardClipping,
19
+ ToTensorImage,
20
+ TransformedEnv,
21
+ )
22
+ from torchrl.testing.gym_helpers import HALFCHEETAH_VERSIONED, PONG_VERSIONED
23
+ from torchrl.testing.utils import mp_ctx
24
+
25
+ __all__ = [
26
+ "get_transform_out",
27
+ "make_envs",
28
+ "make_multithreaded_env",
29
+ ]
30
+
31
+
32
def make_envs(
    env_name,
    frame_skip,
    transformed_in,
    transformed_out,
    N,
    device="cpu",
    kwargs=None,
    local_mp_ctx=mp_ctx,
):
    """Build matching parallel, serial, multithreaded and single environments.

    All returned environments share the same base-env factory so that tests
    can compare batched implementations against each other.

    Args:
        env_name: The gym environment name.
        frame_skip: Number of frames to skip.
        transformed_in: Whether each base env is wrapped in transforms before
            being batched.
        transformed_out: Whether transforms are applied on top of the batched
            (and single) environments.
        N: Number of environments in the batch.
        device: Device for the gym environments.
        kwargs: Forwarded as ``create_env_kwargs`` to the parallel and serial
            batched environments.
        local_mp_ctx: Forwarded as ``mp_start_method`` to the parallel env.

    Returns:
        Tuple of (env_parallel, env_serial, env_multithread, env0).
        ``env_multithread`` is ``None`` when envpool is not available; when it
        is built, it is always created with ``device="cpu"`` and ``kwargs=None``.
    """
    torch.manual_seed(0)

    # Select the base-env factory. The Pong name is only resolved when inner
    # transforms are requested.
    if transformed_in and env_name == PONG_VERSIONED():

        def create_env_fn():
            gym_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
            first_keys = list(gym_env.observation_spec.keys(True, True))[:1]
            inner = Compose(ToTensorImage(in_keys=first_keys), RewardClipping(0, 0.1))
            return TransformedEnv(gym_env, inner)

    elif transformed_in:

        def create_env_fn():
            gym_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
            first_keys = list(gym_env.observation_spec.keys(True, True))[:1]
            inner = Compose(
                ObservationNorm(in_keys=first_keys, loc=0.5, scale=1.1),
                RewardClipping(0, 0.1),
            )
            return TransformedEnv(gym_env, inner)

    else:

        def create_env_fn():
            return GymEnv(env_name, frame_skip=frame_skip, device=device)

    env0 = create_env_fn()
    env_parallel = ParallelEnv(
        N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx
    )
    env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)

    # First leaf key of the observation spec, or None if the spec is empty.
    obs_key = next(iter(env0.observation_spec.keys(True, True)), None)

    t_out = None
    if transformed_out:
        t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key)
        # Each environment gets its own freshly built transform instance.
        env0 = TransformedEnv(env0, t_out())
        env_parallel = TransformedEnv(env_parallel, t_out())
        env_serial = TransformedEnv(env_serial, t_out())

    env_multithread = (
        make_multithreaded_env(
            env_name,
            frame_skip,
            t_out,
            N,
            device="cpu",
            kwargs=None,
        )
        if _has_envpool
        else None
    )

    return env_parallel, env_serial, env_multithread, env0
134
+
135
+
136
def make_multithreaded_env(
    env_name,
    frame_skip,
    transformed_out,
    N,
    device="cpu",
    kwargs=None,
):
    """Build an envpool-backed multithreaded environment.

    Args:
        env_name: The gym environment name.
        frame_skip: Number of frames to skip; only forwarded to the
            environment when ``env_name`` is the versioned Pong name.
        transformed_out: Transform factory to apply, or None. Only its
            truthiness is used here: the transform itself is rebuilt through
            ``get_transform_out``.
        N: Number of environments in the batch.
        device: Device for the environment.
        kwargs: Additional keyword arguments (unused, for API compatibility).

    Returns:
        A MultiThreadedEnv instance, optionally wrapped with transforms.
    """
    torch.manual_seed(0)

    if env_name == PONG_VERSIONED():
        pool_kwargs = {"frame_skip": frame_skip}
    else:
        pool_kwargs = {}

    pool = MultiThreadedEnv(
        N,
        env_name,
        create_env_kwargs=pool_kwargs,
        device=device,
    )
    if not transformed_out:
        return pool

    # First leaf key of the observation spec, or None if the spec is empty.
    first_key = next(iter(pool.observation_spec.keys(True, True)), None)
    build_transform = get_transform_out(
        env_name, transformed_in=False, obs_key=first_key
    )
    return TransformedEnv(pool, build_transform())
179
+
180
+
181
def get_transform_out(env_name, transformed_in, obs_key=None):
    """Create a transform factory for output transforms based on environment type.

    Three cases are distinguished by ``env_name``:

    * versioned Pong: image conversion plus reward clipping, or an identity-like
      ``ObservationNorm`` (loc=0, scale=1) when ``transformed_in`` is true;
    * versioned HalfCheetah: ``ObservationNorm`` plus reward clipping,
      regardless of ``transformed_in``;
    * anything else: ``ObservationNorm`` plus reward clipping, or a single
      ``ObservationNorm`` (loc=1.0, scale=1.0) when ``transformed_in`` is true.

    Args:
        env_name: The gym environment name.
        transformed_in: Whether transforms were already applied inside.
        obs_key: The observation key to transform. When ``None``, a per-case
            default is used (``"pixels"``, ``("observation", "velocity")`` or
            ``"observation"``).

    Returns:
        A zero-argument callable that returns a new Compose transform on each call.
    """

    def _resolve(name_or_factory):
        # The versioned-name helpers are used as factories elsewhere in this
        # module (``PONG_VERSIONED()``). Comparing ``env_name`` against the
        # factory object itself can never match, which would make the
        # HalfCheetah branch unreachable, so call it when it is callable.
        return name_or_factory() if callable(name_or_factory) else name_or_factory

    if env_name == PONG_VERSIONED():
        if obs_key is None:
            obs_key = "pixels"

        def t_out():
            if transformed_in:
                return Compose(ObservationNorm(in_keys=[obs_key], loc=0, scale=1))
            return Compose(ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1))

    elif env_name == _resolve(HALFCHEETAH_VERSIONED):
        if obs_key is None:
            obs_key = ("observation", "velocity")

        def t_out():
            return Compose(
                ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
                RewardClipping(0, 0.1),
            )

    else:
        if obs_key is None:
            obs_key = "observation"

        def t_out():
            if transformed_in:
                return Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0))
            return Compose(
                ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
                RewardClipping(0, 0.1),
            )

    return t_out
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+
9
def make_isaac_env(env_name: str = "Isaac-Ant-v0"):
    """Create an IsaacLab environment wrapped for TorchRL.

    The Isaac app is launched in headless mode first; the IsaacLab task
    imports and the environment creation happen only afterwards, so the
    statement order in this function must be preserved.

    Args:
        env_name: Gymnasium id of the IsaacLab task to create.

    Returns:
        The gymnasium environment wrapped in ``IsaacLabWrapper``.
    """
    import torch

    torch.manual_seed(0)
    import argparse

    # Start the Isaac app in headless mode before anything else touches IsaacLab.
    from isaaclab.app import AppLauncher
    from torchrl import logger as torchrl_logger

    cli_parser = argparse.ArgumentParser(
        description="Train an RL agent with TorchRL."
    )
    AppLauncher.add_app_launcher_args(cli_parser)
    # Arguments not recognised by the launcher are discarded.
    launcher_args, _unparsed = cli_parser.parse_known_args(["--headless"])
    AppLauncher(launcher_args)

    # These imports are deliberately placed after the app launch.
    import gymnasium as gym
    import isaaclab_tasks  # noqa: F401
    from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
    from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

    torchrl_logger.info("Making IsaacLab env...")
    # NOTE(review): the Ant config is used whatever ``env_name`` is -- confirm
    # this is intended for non-Ant tasks.
    raw_env = gym.make(env_name, cfg=AntEnvCfg())
    torchrl_logger.info("Wrapping IsaacLab env...")
    return IsaacLabWrapper(raw_env)