torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,251 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+
9
+ import torch
10
+ from tensordict.nn import TensorDictModule, TensorDictSequential
11
+
12
+ from torch import nn, optim
13
+ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
14
+ from torchrl.data.replay_buffers import SamplerWithoutReplacement
15
+ from torchrl.envs import (
16
+ CatTensors,
17
+ Compose,
18
+ DMControlEnv,
19
+ DoubleToFloat,
20
+ EnvCreator,
21
+ InitTracker,
22
+ ParallelEnv,
23
+ RewardSum,
24
+ StepCounter,
25
+ TransformedEnv,
26
+ )
27
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
28
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
29
+ from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator
30
+
31
+ from torchrl.objectives import SoftUpdate
32
+ from torchrl.objectives.td3_bc import TD3BCLoss
33
+ from torchrl.record import VideoRecorder
34
+
35
+
36
+ # ====================================================================
37
+ # Environment utils
38
+ # -----------------
39
+
40
+
41
def env_maker(cfg, device="cpu", from_pixels=False):
    """Instantiate a single (non-batched) environment described by ``cfg.env``.

    Args:
        cfg: config exposing ``cfg.env.library`` (``"gym"``, ``"gymnasium"`` or
            ``"dm_control"``), ``cfg.env.name`` and, for dm_control,
            ``cfg.env.task``.
        device: device forwarded to the environment constructor (for every
            supported library).
        from_pixels: forwarded to the environment constructor; ``pixels_only``
            is always ``False``.

    Returns:
        A ``GymEnv`` for gym/gymnasium. For dm_control, a ``TransformedEnv``
        that concatenates all observation entries into ``"observation"``.

    Raises:
        NotImplementedError: if ``cfg.env.library`` is not one of the above.
    """
    lib = cfg.env.library
    if lib in ("gym", "gymnasium"):
        with set_gym_backend(lib):
            return GymEnv(
                cfg.env.name,
                device=device,
                from_pixels=from_pixels,
                pixels_only=False,
            )
    if lib == "dm_control":
        # ``device`` is forwarded here as well; previously only the gym branch
        # honoured it and dm_control silently stayed on its default device.
        env = DMControlEnv(
            cfg.env.name,
            cfg.env.task,
            from_pixels=from_pixels,
            pixels_only=False,
            device=device,
        )
        # NOTE(review): every entry of observation_spec is concatenated, which
        # would include a pixel entry when from_pixels=True — confirm intended.
        return TransformedEnv(
            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
        )
    raise NotImplementedError(f"Unknown lib {lib}.")
60
+
61
+
62
def apply_env_transforms(env, max_episode_steps):
    """Wrap ``env`` in the standard transform stack used by this script.

    The transforms are applied in this order: ``StepCounter`` (with
    ``max_steps=max_episode_steps``), ``InitTracker``, ``DoubleToFloat`` and
    ``RewardSum``.

    Returns:
        The resulting ``TransformedEnv``.
    """
    transform_stack = Compose(
        StepCounter(max_steps=max_episode_steps),
        InitTracker(),
        DoubleToFloat(),
        RewardSum(),
    )
    return TransformedEnv(env, transform_stack)
73
+
74
+
75
def make_environment(cfg, logger=None):
    """Build the batched, transformed environment used by this script.

    ``cfg.logger.eval_envs`` copies of the env produced by ``env_maker`` are
    run in a ``ParallelEnv`` (created with ``serial_for_single=True``). The
    batch is seeded with ``cfg.env.seed`` and then wrapped by
    ``apply_env_transforms`` using ``cfg.env.max_episode_steps``.

    Args:
        cfg: experiment config.
        logger: not used by this function; accepted for call-site compatibility.

    Returns:
        The transformed batched environment.
    """
    env_factory = EnvCreator(functools.partial(env_maker, cfg=cfg))
    batched_env = ParallelEnv(
        cfg.logger.eval_envs,
        env_factory,
        serial_for_single=True,
    )
    batched_env.set_seed(cfg.env.seed)
    return apply_env_transforms(batched_env, cfg.env.max_episode_steps)
87
+
88
+
89
+ # ====================================================================
90
+ # Replay buffer
91
+ # ---------------------------
92
+
93
+
94
def make_offline_replay_buffer(rb_cfg, device):
    """Load a D4RL dataset as a replay buffer whose samples land on ``device``.

    Args:
        rb_cfg: config exposing ``dataset`` (the D4RL dataset id) and
            ``batch_size``.
        device: device each sampled batch is moved to.

    Returns:
        A ``D4RLExperienceReplay`` that samples without replacement (dropping
        the last incomplete batch) and applies ``DoubleToFloat`` and then a
        move to ``device``.
    """

    def _move_to_device(batch):
        return batch.to(device)

    replay_buffer = D4RLExperienceReplay(
        dataset_id=rb_cfg.dataset,
        split_trajs=False,
        batch_size=rb_cfg.batch_size,
        # drop_last for compile: presumably keeps every sampled batch the same
        # size so compiled code does not see a smaller final batch.
        sampler=SamplerWithoutReplacement(drop_last=True),
        prefetch=4,
        direct_download=True,
    )
    for transform in (DoubleToFloat(), _move_to_device):
        replay_buffer.append_transform(transform)
    return replay_buffer
109
+
110
+
111
+ # ====================================================================
112
+ # Model
113
+ # -----
114
+
115
+
116
def make_td3_agent(cfg, train_env, device):
    """Build the actor and critic networks plus an exploratory actor.

    Args:
        cfg: config providing ``cfg.network.hidden_sizes`` and the activation
            name read by ``get_activation``.
        train_env: environment whose ``action_spec_unbatched`` sizes the actor
            output and whose ``fake_tensordict()`` drives a dummy forward pass.
        device: device for the networks and the action spec.

    Returns:
        ``(model, actor_model_explore)``. ``model`` is
        ``nn.ModuleList([actor, qvalue])``. ``actor_model_explore`` runs the same
        actor followed by an ``AdditiveGaussianModule`` (mean 0, std 0.1,
        sigma_init = sigma_end = 1).
    """
    obs_keys = ["observation"]
    action_spec = train_env.action_spec_unbatched.to(device)
    activation_cls = get_activation(cfg)

    # Actor: "observation" -> MLP -> "param" -> TanhModule(spec=action_spec) -> "action".
    policy_mlp = MLP(
        num_cells=cfg.network.hidden_sizes,
        out_features=action_spec.shape[-1],
        activation_class=activation_cls,
        device=device,
    )
    actor = TensorDictSequential(
        TensorDictModule(policy_mlp, in_keys=obs_keys, out_keys=["param"]),
        TanhModule(in_keys=["param"], out_keys=["action"], spec=action_spec),
    )

    # Critic: ("action", "observation") -> MLP -> single output.
    critic_mlp = MLP(
        num_cells=cfg.network.hidden_sizes,
        out_features=1,
        activation_class=activation_cls,
        device=device,
    )
    qvalue = ValueOperator(in_keys=["action"] + obs_keys, module=critic_mlp)

    model = nn.ModuleList([actor, qvalue])

    # One no-grad forward pass on a fake tensordict. The MLPs are built without
    # ``in_features``, so this presumably materializes lazily-sized layers.
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        dummy = train_env.fake_tensordict().to(device)
        for network in model:
            network(dummy)
    del dummy

    # The exploration policy reuses the actor module itself (same parameters).
    actor_model_explore = TensorDictSequential(
        model[0],
        AdditiveGaussianModule(
            sigma_init=1,
            sigma_end=1,
            mean=0,
            std=0.1,
            spec=action_spec,
            device=device,
        ),
    )
    return model, actor_model_explore
180
+
181
+
182
+ # ====================================================================
183
+ # TD3 Loss
184
+ # ---------
185
+
186
+
187
def make_loss_module(cfg, model):
    """Create the TD3+BC loss module and its target-network updater.

    Args:
        cfg: config providing ``loss_function``, ``policy_noise``,
            ``noise_clip``, ``alpha``, ``gamma`` and ``target_update_polyak``.
        model: sequence ``[actor, qvalue]``. ``model[0][1]`` must expose a
            ``spec`` attribute (the ``TanhModule`` when the actor comes from
            ``make_td3_agent``).

    Returns:
        ``(loss_module, target_net_updater)``, where the updater is a
        ``SoftUpdate`` over ``loss_module``.
    """
    actor_network, qvalue_network = model[0], model[1]
    loss_module = TD3BCLoss(
        actor_network=actor_network,
        qvalue_network=qvalue_network,
        num_qvalue_nets=2,
        loss_function=cfg.loss_function,
        delay_actor=True,
        delay_qvalue=True,
        action_spec=actor_network[1].spec,
        policy_noise=cfg.policy_noise,
        noise_clip=cfg.noise_clip,
        alpha=cfg.alpha,
    )
    loss_module.make_value_estimator(gamma=cfg.gamma)

    # Soft target updates with eps taken from cfg.target_update_polyak.
    return loss_module, SoftUpdate(loss_module, eps=cfg.target_update_polyak)
207
+
208
+
209
def make_optimizer(cfg, loss_module):
    """Create separate Adam optimizers for the actor and the critic.

    Both optimizers use ``cfg.lr``, ``cfg.weight_decay`` and ``cfg.adam_eps``.
    Their parameters are the tensors returned by ``values(True, True)`` on
    ``loss_module.actor_network_params`` and
    ``loss_module.qvalue_network_params`` respectively.

    Returns:
        ``(optimizer_actor, optimizer_critic)``.
    """

    def _adam(params):
        return optim.Adam(
            params, lr=cfg.lr, weight_decay=cfg.weight_decay, eps=cfg.adam_eps
        )

    critic_params = list(loss_module.qvalue_network_params.values(True, True))
    actor_params = list(loss_module.actor_network_params.values(True, True))
    return _adam(actor_params), _adam(critic_params)
226
+
227
+
228
+ # ====================================================================
229
+ # General utils
230
+ # ---------
231
+
232
+
233
def log_metrics(logger, metrics, step):
    """Send each ``name -> value`` pair of ``metrics`` to ``logger.log_scalar`` at ``step``."""
    for entry in metrics.items():
        logger.log_scalar(*entry, step)
236
+
237
+
238
def get_activation(cfg):
    """Map ``cfg.network.activation`` to an ``nn`` activation class.

    Supported names are ``"relu"``, ``"tanh"`` and ``"leaky_relu"``.

    Returns:
        The activation *class* (not an instance), e.g. ``nn.ReLU``.

    Raises:
        NotImplementedError: if the configured name is not supported. The
            message names the offending value and the supported choices.
    """
    name = cfg.network.activation
    choices = (
        ("relu", nn.ReLU),
        ("tanh", nn.Tanh),
        ("leaky_relu", nn.LeakyReLU),
    )
    for key, activation_cls in choices:
        if name == key:
            return activation_cls
    supported = ", ".join(key for key, _ in choices)
    # Previously a bare NotImplementedError with no hint about the bad value.
    raise NotImplementedError(
        f"Unsupported activation {name!r}; expected one of: {supported}."
    )
247
+
248
+
249
def dump_video(module):
    """Call ``module.dump()`` if ``module`` is a ``VideoRecorder``; otherwise do nothing."""
    if not isinstance(module, VideoRecorder):
        return
    module.dump()
Binary file
torchrl/__init__.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import warnings
6
+ import weakref
7
+ from warnings import warn
8
+
9
+ import torch
10
+
11
# Silence noisy dependency warning triggered at import time on older torch stacks.
# (Emitted by tensordict when registering pytree nodes.)
# The filter is installed before ``tensordict`` is imported below, so it is
# already active when that import runs. ``message`` is a regex matched against
# the start of the warning text.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message=r"torch\.utils\._pytree\._register_pytree_node is deprecated\.",
)
18
+
19
+ from tensordict import set_lazy_legacy # noqa: E402
20
+
21
+ from torch import multiprocessing as mp # noqa: E402
22
+ from torch.distributions.transforms import ( # noqa: E402
23
+ _InverseTransform,
24
+ ComposeTransform,
25
+ )
26
+
27
# Record a one-time PyTorch API-usage event tagged "torchrl".
torch._C._log_api_usage_once("torchrl")

# Turn off tensordict's legacy lazy behaviour. ``.set()`` applies the setting
# directly rather than through a ``with`` block (presumably for the rest of
# the process — confirm against tensordict's ``set_lazy_legacy``).
set_lazy_legacy(False).set()
30
+
31
+ from ._extension import _init_extension # noqa: E402
32
+
33
# Resolve ``torchrl.__version__``, trying in order:
#   1. the installed distribution metadata (stdlib ``importlib.metadata``, or
#      the ``importlib_metadata`` backport if the stdlib module is missing);
#   2. a ``._version`` module, if present;
#   3. a ``.version`` module, if present;
# and falling back to ``None``. Every failure is swallowed, so a failed
# version lookup never aborts ``import torchrl``.
__version__ = None  # type: ignore
try:
    try:
        from importlib.metadata import version as _dist_version
    except ImportError:  # pragma: no cover
        from importlib_metadata import version as _dist_version  # type: ignore

    __version__ = _dist_version("torchrl")
except Exception:
    try:
        from ._version import __version__
    except Exception:
        try:
            from .version import __version__
        except Exception:
            __version__ = None  # type: ignore
49
+
50
# ``is_dynamo_compiling`` is used by the ``inv`` patches below to skip weakref
# caching while compiling. Prefer the public ``torch.compiler`` helper and fall
# back to the private ``torch._dynamo`` one when the former cannot be imported.
try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

# Check for the compiled C++ extension (warns if ``torchrl._torchrl`` is missing).
_init_extension()
56
+
57
+ from torchrl._utils import ( # noqa: E402
58
+ _get_default_mp_start_method,
59
+ auto_unwrap_transformed_env,
60
+ compile_with_warmup,
61
+ get_ray_default_runtime_env,
62
+ implement_for,
63
+ logger,
64
+ merge_ray_runtime_env,
65
+ set_auto_unwrap_transformed_env,
66
+ set_profiling_enabled,
67
+ timeit,
68
+ )
69
+
70
# Self-assignment of the imported ``logger``; presumably kept so linters treat
# the import as used (``logger`` is also part of ``__all__``).
logger = logger
71
+
72
# TorchRL's multiprocessing default.
# When ``_get_default_mp_start_method()`` returns "spawn", try to make it the
# global torch.multiprocessing start method. If a context was already set,
# warn only when it differs from "spawn". Any other RuntimeError raised by
# ``set_start_method`` is ignored.
_preferred_start_method = _get_default_mp_start_method()
if _preferred_start_method == "spawn":
    try:
        mp.set_start_method("spawn")
    except RuntimeError as err:
        if str(err).startswith("context has already been set"):
            mp_start_method = mp.get_start_method()
            if mp_start_method != "spawn":
                warn(
                    f"failed to set start method to spawn, "
                    f"and current start method for mp is {mp_start_method}."
                )

# Filter warnings in subprocesses: True by default given the multiple optional
# deps of the library. Filtering can be turned off by setting
# `torchrl.filter_warnings_subprocess = False`.
filter_warnings_subprocess = True

# Number of torch intra-op threads at import time; presumably recorded so it
# can be restored or compared later (consumers are not visible here).
_THREAD_POOL_INIT = torch.get_num_threads()
91
+
92
+
93
# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
@property
def _inv(self):
    """Patched version of Transform.inv.

    Returns the inverse :class:`Transform` of this transform.

    This should satisfy ``t.inv.inv is t``.

    A previously built inverse is reused if ``self._inv`` still resolves to a
    live object. Otherwise a new ``_InverseTransform`` is created. It is cached
    on ``self._inv`` as a weak reference only when dynamo is not compiling
    (presumably because weakref creation is not traceable — confirm).
    """
    inv = None
    if self._inv is not None:
        # ``self._inv`` is a callable (normally a weakref); it returns None once
        # the referenced inverse has been garbage-collected.
        inv = self._inv()
    if inv is None:
        inv = _InverseTransform(self)
        if not is_dynamo_compiling():
            self._inv = weakref.ref(inv)
    return inv


# Install the patched property on the base ``Transform`` class.
torch.distributions.transforms.Transform.inv = _inv
113
+
114
+
115
@property
def _inv(self):
    """Patched version of ComposeTransform.inv.

    A previously built inverse is reused if ``self._inv`` still resolves to a
    live object. Otherwise a ``ComposeTransform`` of the parts' inverses, in
    reverse order, is built. Outside dynamo compilation both directions are
    cached as weak references. While compiling, nothing is cached on ``self``
    and ``inv._inv`` becomes a closure returning ``self``, so ``inv.inv is self``
    still holds.
    """
    inv = None
    if self._inv is not None:
        inv = self._inv()
    if inv is None:
        inv = ComposeTransform([p.inv for p in reversed(self.parts)])
        if not is_dynamo_compiling():
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        else:
            # We need inv.inv to be equal to self, but weakref can cause a graph break
            inv._inv = lambda out=self: out

    return inv


# Install the patched property on ``ComposeTransform``, replacing its own ``inv``.
ComposeTransform.inv = _inv
133
+
134
# Public names re-exported by ``torchrl``. Each entry appears once; "logger"
# was previously listed twice.
__all__ = [
    "auto_unwrap_transformed_env",
    "compile_with_warmup",
    "get_ray_default_runtime_env",
    "implement_for",
    "merge_ray_runtime_env",
    "set_auto_unwrap_transformed_env",
    "timeit",
    "logger",
]
torchrl/_extension.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import warnings
9
+
10
+ from packaging.version import parse
11
+
12
# Installed distribution version of torchrl, or ``None`` if it cannot be read
# for any reason. ``_is_nightly`` below uses it to pick the warning text.
__version__ = None  # type: ignore
try:
    try:
        from importlib.metadata import version as _dist_version
    except ImportError:  # pragma: no cover
        from importlib_metadata import version as _dist_version  # type: ignore

    __version__ = _dist_version("torchrl")
except Exception:
    __version__ = None  # type: ignore

# PyTorch version read from the ``.version`` module if present, else "unknown".
# The warning messages below describe it as the PyTorch version this package
# was built with.
try:
    from .version import pytorch_version
except ImportError:
    pytorch_version = "unknown"
27
+
28
+
29
def is_module_available(*modules: str) -> bool:
    """Return ``True`` if every module named in ``modules`` can be found.

    Uses :func:`importlib.util.find_spec`, so a top-level module is located
    without being imported. This is generally safer than a try/except block
    around ``import X``. It avoids third-party libraries breaking assumptions
    of some of our tests, e.g. by setting the multiprocessing start method
    when imported (see librosa/#747, torchvision/#544).

    For a dotted name such as ``"pkg.sub"``, ``find_spec`` imports the parent
    package ``pkg``. If that parent does not exist, the module is reported as
    unavailable instead of raising.

    Args:
        *modules: fully qualified module names. With no names, returns ``True``.

    Returns:
        ``True`` if all modules are found, ``False`` as soon as one is missing.
    """
    for name in modules:
        try:
            spec = importlib.util.find_spec(name)
        except ModuleNotFoundError:
            # find_spec raises (rather than returning None) when the parent
            # package of a dotted name is missing.
            return False
        if spec is None:
            return False
    return True
38
+
39
+
40
def _init_extension():
    """Warn if the compiled ``torchrl._torchrl`` extension cannot be found.

    Only checks for availability (via ``is_module_available``); returns ``None``
    in every case.
    """
    if is_module_available("torchrl._torchrl"):
        return
    warnings.warn("torchrl C++ extension is not available.")
44
+
45
+
46
+ def _is_nightly(version):
47
+ if version is None:
48
+ return True
49
+ parsed_version = parse(version)
50
+ return parsed_version.local is not None
51
+
52
+
53
# Message describing a failed import of the C++ binaries. It is not emitted in
# this module (presumably imported by modules that need the extension). The
# text is chosen at import time: nightly/local builds get build-debugging tips,
# release builds get a PyTorch-version-mismatch explanation.
if _is_nightly(__version__):
    EXTENSION_WARNING = (
        "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
        "You seem to be using the nightly version of TorchRL. If this is a local install, there might be an issue with "
        "the local installation. Here are some tips to debug this:\n"
        " - make sure ninja and cmake were installed\n"
        " - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised\n"
        " - make sure the version of PyTorch you are using matches the one that was present in your virtual env during "
        f"setup. This package was built with PyTorch {pytorch_version}. You can deactivate this warning by setting the environment variable `RL_WARNINGS=0`."
    )

else:
    EXTENSION_WARNING = (
        "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
        "This is likely due to a discrepancy between your package version and the PyTorch version. "
        "TorchRL does not tightly pin PyTorch versions to give users freedom, but the trade-off is that C++ extensions like "
        "prioritized replay buffers can only be used with the PyTorch version they were built against. "
        f"This package was built with PyTorch {pytorch_version}. "
        "Workarounds include: (1) upgrading/downgrading PyTorch or TorchRL to compatible versions, "
        "or (2) making a local install using `pip install git+https://github.com/pytorch/rl.git@<version>`. "
        "You can deactivate this warning by setting the environment variable `RL_WARNINGS=0`."
    )
Binary file