torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,319 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import tempfile
9
+ from contextlib import nullcontext
10
+
11
+ import torch
12
+ from tensordict.nn import TensorDictModule, TensorDictSequential
13
+
14
+ from torch import nn, optim
15
+ from torchrl.collectors import SyncDataCollector
16
+ from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
17
+ from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage
18
+ from torchrl.envs import (
19
+ CatTensors,
20
+ Compose,
21
+ DMControlEnv,
22
+ DoubleToFloat,
23
+ EnvCreator,
24
+ InitTracker,
25
+ ParallelEnv,
26
+ RewardSum,
27
+ StepCounter,
28
+ TransformedEnv,
29
+ )
30
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
31
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
32
+ from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator
33
+
34
+ from torchrl.objectives import SoftUpdate
35
+ from torchrl.objectives.td3 import TD3Loss
36
+ from torchrl.record import VideoRecorder
37
+
38
+
39
+ # ====================================================================
40
+ # Environment utils
41
+ # -----------------
42
+
43
+
44
def env_maker(cfg, device="cpu", from_pixels=False):
    """Build a single base environment for the library named in the config.

    Args:
        cfg: Configuration object; ``cfg.env.library`` selects the backend,
            ``cfg.env.name`` (and ``cfg.env.task`` for dm_control) select the
            environment.
        device: Device passed to the environment constructor.
        from_pixels: Whether the environment should also render pixels.
            State observations are always kept (``pixels_only=False``).

    Returns:
        A ``GymEnv`` for ``"gym"``/``"gymnasium"``, or a ``TransformedEnv``
        wrapping a ``DMControlEnv`` whose state entries are concatenated
        under the ``"observation"`` key.

    Raises:
        NotImplementedError: If ``cfg.env.library`` is not supported.
    """
    lib = cfg.env.library
    if lib in ("gym", "gymnasium"):
        with set_gym_backend(lib):
            return GymEnv(
                cfg.env.name,
                device=device,
                from_pixels=from_pixels,
                pixels_only=False,
            )
    elif lib == "dm_control":
        # Honour the requested device here too, as the gym branch does.
        env = DMControlEnv(
            cfg.env.name,
            cfg.env.task,
            from_pixels=from_pixels,
            pixels_only=False,
            device=device,
        )
        # Only state entries can be concatenated into a flat observation;
        # the rendered "pixels" entry must stay separate.
        state_keys = [key for key in env.observation_spec.keys() if key != "pixels"]
        return TransformedEnv(
            env, CatTensors(in_keys=state_keys, out_key="observation")
        )
    else:
        raise NotImplementedError(f"Unknown lib {lib}.")
63
+
64
+
65
def apply_env_transforms(env, max_episode_steps):
    """Wrap ``env`` with the transforms shared by training and evaluation.

    The transforms are applied in this order: ``StepCounter`` (limited to
    ``max_episode_steps``), ``InitTracker``, ``DoubleToFloat`` and
    ``RewardSum``.
    """
    transforms = Compose(
        StepCounter(max_steps=max_episode_steps),
        InitTracker(),
        DoubleToFloat(),
        RewardSum(),
    )
    return TransformedEnv(env, transforms)
76
+
77
+
78
def make_environment(cfg, logger, device):
    """Create the training and evaluation environments.

    Returns:
        A ``(train_env, eval_env)`` pair. The evaluation environment runs a
        single worker and reuses a clone of the training transforms, with a
        ``VideoRecorder`` prepended when ``cfg.logger.video`` is set.
    """
    # Training: a seeded batch of workers wrapped with the shared transforms.
    train_factory = functools.partial(env_maker, cfg=cfg)
    train_base = ParallelEnv(
        cfg.collector.env_per_collector,
        EnvCreator(train_factory),
        serial_for_single=True,
        device=device,
    )
    train_base.set_seed(cfg.env.seed)
    train_env = apply_env_transforms(train_base, cfg.env.max_episode_steps)

    # Evaluation: pixels are only rendered when video logging is requested.
    eval_factory = functools.partial(
        env_maker, cfg=cfg, from_pixels=cfg.logger.video
    )
    eval_transform = train_env.transform.clone()
    if cfg.logger.video:
        recorder = VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
        eval_transform.insert(0, recorder)
    eval_base = ParallelEnv(
        1,
        EnvCreator(eval_factory),
        serial_for_single=True,
        device=device,
    )
    eval_env = TransformedEnv(eval_base, eval_transform)
    return train_env, eval_env
107
+
108
+
109
+ # ====================================================================
110
+ # Collector and replay buffer
111
+ # ---------------------------
112
+
113
+
114
def make_collector(cfg, train_env, actor_model_explore, compile_mode, device):
    """Build the seeded ``SyncDataCollector`` used for training.

    ``cfg.collector.device`` is used when set; an empty string or ``None``
    falls back to ``device``. Policy compilation is enabled when
    ``compile_mode`` is truthy, and CUDA graphs when
    ``cfg.compile.cudagraphs`` is truthy.
    """
    collector_cfg = cfg.collector

    collector_device = collector_cfg.device
    if collector_device in ("", None):
        collector_device = device

    if compile_mode:
        compile_policy = {"mode": compile_mode}
    else:
        compile_policy = False

    if cfg.compile.cudagraphs:
        cudagraph_policy = {"warmup": 10}
    else:
        cudagraph_policy = False

    collector = SyncDataCollector(
        train_env,
        actor_model_explore,
        init_random_frames=collector_cfg.init_random_frames,
        frames_per_batch=collector_cfg.frames_per_batch,
        total_frames=collector_cfg.total_frames,
        reset_at_each_iter=collector_cfg.reset_at_each_iter,
        device=collector_device,
        compile_policy=compile_policy,
        cudagraph_policy=cudagraph_policy,
    )
    collector.set_seed(cfg.env.seed)
    return collector
132
+
133
+
134
def make_replay_buffer(
    batch_size: int,
    prb: bool = False,
    buffer_size: int = 1000000,
    scratch_dir: str | None = None,
    device: torch.device = "cpu",
    prefetch: int = 3,
    compile: bool = False,
):
    """Create the (optionally prioritized) replay buffer.

    Args:
        batch_size: Batch size handed to the buffer.
        prb: Use ``TensorDictPrioritizedReplayBuffer`` (alpha=0.7, beta=0.5)
            instead of ``TensorDictReplayBuffer``.
        buffer_size: Capacity passed to the storage.
        scratch_dir: ``""``/``None`` selects an in-memory ``LazyTensorStorage``
            on ``device``; ``"temp"`` selects a ``LazyMemmapStorage`` in a
            temporary directory; any other value is used as the memmap
            directory.
        device: Device of the in-memory storage, or the device sampled data
            is moved to when a memmap storage is used.
        prefetch: Prefetch setting of the buffer; forced to 0 when ``compile``
            is set.
        compile: Forwarded as ``compilable`` to the buffer and to the
            in-memory storage.

    Returns:
        The configured replay buffer.
    """
    if compile:
        prefetch = 0

    if scratch_dir == "temp":
        dir_ctx = tempfile.TemporaryDirectory()
    elif scratch_dir in ("", None):
        dir_ctx = nullcontext(None)
    else:
        dir_ctx = nullcontext(scratch_dir)

    # NOTE(review): with "temp", the directory is removed when this block
    # exits, i.e. before the lazy storage is first written to — confirm that
    # the storage tolerates this.
    with dir_ctx as scratch_dir:
        if scratch_dir:
            storage = LazyMemmapStorage(
                buffer_size, device="cpu", scratch_dir=scratch_dir
            )
        else:
            storage = LazyTensorStorage(
                buffer_size, device=device, compilable=compile
            )

        if prb:
            replay_buffer = TensorDictPrioritizedReplayBuffer(
                alpha=0.7,
                beta=0.5,
                pin_memory=False,
                prefetch=prefetch,
                storage=storage,
                batch_size=batch_size,
                compilable=compile,
            )
        else:
            replay_buffer = TensorDictReplayBuffer(
                pin_memory=False,
                prefetch=prefetch,
                storage=storage,
                batch_size=batch_size,
                compilable=compile,
            )

        # Memmap storage is created on CPU, so move samples to ``device``.
        if scratch_dir:
            replay_buffer.append_transform(lambda td: td.to(device))
        return replay_buffer
181
+
182
+
183
+ # ====================================================================
184
+ # Model
185
+ # -----
186
+
187
+
188
def make_td3_agent(cfg, train_env, eval_env, device):
    """Build the TD3 actor/critic pair and the exploration policy.

    Returns:
        ``(model, actor_model_explore)`` where ``model`` is an
        ``nn.ModuleList`` holding ``[actor, qvalue]`` and
        ``actor_model_explore`` is the actor followed by additive Gaussian
        noise.
    """
    obs_keys = ["observation"]
    action_spec = train_env.action_spec_unbatched.to(device)

    # Actor: MLP producing "param", squashed into "action" by a tanh module.
    policy_mlp = MLP(
        num_cells=cfg.network.hidden_sizes,
        out_features=action_spec.shape[-1],
        activation_class=get_activation(cfg),
        device=device,
    )
    policy_head = TensorDictModule(
        policy_mlp,
        in_keys=obs_keys,
        out_keys=["param"],
    )
    squash = TanhModule(
        in_keys=["param"],
        out_keys=["action"],
        spec=action_spec,
    )
    actor = TensorDictSequential(policy_head, squash)

    # Critic: scalar Q-value computed from the action and the observation.
    critic_mlp = MLP(
        num_cells=cfg.network.hidden_sizes,
        out_features=1,
        activation_class=get_activation(cfg),
        device=device,
    )
    qvalue = ValueOperator(
        in_keys=["action", *obs_keys],
        module=critic_mlp,
    )

    model = nn.ModuleList([actor, qvalue])

    # Initialise both networks with one forward pass on a fake tensordict.
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        fake_td = eval_env.fake_tensordict().to(device)
        for net in model:
            net(fake_td)

    # Exploration policy: the actor followed by additive Gaussian noise.
    noise = AdditiveGaussianModule(
        sigma_init=1,
        sigma_end=1,
        mean=0,
        std=0.1,
        spec=action_spec,
        device=device,
    )
    actor_model_explore = TensorDictSequential(actor, noise)
    return model, actor_model_explore
249
+
250
+
251
+ # ====================================================================
252
+ # TD3 Loss
253
+ # ---------
254
+
255
+
256
def make_loss_module(cfg, model):
    """Create the TD3 loss and its soft target-network updater.

    ``model`` is indexed as ``[actor, qvalue]``; the action spec is read from
    the second module of the actor.

    Returns:
        ``(loss_module, target_net_updater)``.
    """
    actor = model[0]
    qvalue = model[1]
    optim_cfg = cfg.optim

    loss_module = TD3Loss(
        actor_network=actor,
        qvalue_network=qvalue,
        num_qvalue_nets=2,
        loss_function=optim_cfg.loss_function,
        delay_actor=True,
        delay_qvalue=True,
        action_spec=actor[1].spec,
        policy_noise=optim_cfg.policy_noise,
        noise_clip=optim_cfg.noise_clip,
    )
    loss_module.make_value_estimator(gamma=optim_cfg.gamma)

    # Target networks are updated with eps set to the configured polyak value.
    target_net_updater = SoftUpdate(
        loss_module, eps=optim_cfg.target_update_polyak
    )
    return loss_module, target_net_updater
275
+
276
+
277
def make_optimizer(cfg, loss_module):
    """Create one Adam optimizer for the actor and one for the critic.

    Both optimizers share the learning rate, weight decay and epsilon from
    ``cfg.optim``.

    Returns:
        ``(optimizer_actor, optimizer_critic)``.
    """
    adam_kwargs = {
        "lr": cfg.optim.lr,
        "weight_decay": cfg.optim.weight_decay,
        "eps": cfg.optim.adam_eps,
    }

    def _adam_for(params):
        # Flatten the nested parameter container into a plain list.
        return optim.Adam(list(params.flatten_keys().values()), **adam_kwargs)

    optimizer_actor = _adam_for(loss_module.actor_network_params)
    optimizer_critic = _adam_for(loss_module.qvalue_network_params)
    return optimizer_actor, optimizer_critic
294
+
295
+
296
+ # ====================================================================
297
+ # General utils
298
+ # ---------
299
+
300
+
301
def log_metrics(logger, metrics, step):
    """Log every entry of ``metrics`` as a scalar at the given ``step``."""
    for name in metrics:
        logger.log_scalar(name, metrics[name], step)
304
+
305
+
306
def get_activation(cfg):
    """Return the activation class named by ``cfg.network.activation``.

    Supported names are ``"relu"``, ``"tanh"`` and ``"leaky_relu"``.

    Returns:
        The ``torch.nn`` activation class (not an instance).

    Raises:
        NotImplementedError: If the configured name is not supported; the
            message names the offending value and the accepted ones.
    """
    activations = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
    }
    name = cfg.network.activation
    try:
        return activations[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable config values, so callers still see
        # the same exception type as for any other unsupported value.
        raise NotImplementedError(
            f"Unknown activation {name!r}; expected one of {sorted(activations)}."
        ) from None
315
+
316
+
317
def dump_video(module):
    """Call ``dump()`` on ``module`` if it is a ``VideoRecorder``; otherwise do nothing."""
    if not isinstance(module, VideoRecorder):
        return
    module.dump()
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """TD3+BC Example.
6
+
7
+ This is a self-contained example of an offline RL TD3+BC training script.
8
+
9
+ The helper functions are coded in the utils.py associated with this script.
10
+
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import warnings
15
+
16
+ import hydra
17
+ import numpy as np
18
+ import torch
19
+ import tqdm
20
+ from tensordict import TensorDict
21
+ from tensordict.nn import CudaGraphModule
22
+ from torchrl._utils import compile_with_warmup, get_available_device, timeit
23
+ from torchrl.envs import set_gym_backend
24
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
25
+ from torchrl.record.loggers import generate_exp_name, get_logger
26
+ from utils import (
27
+ dump_video,
28
+ log_metrics,
29
+ make_environment,
30
+ make_loss_module,
31
+ make_offline_replay_buffer,
32
+ make_optimizer,
33
+ make_td3_agent,
34
+ )
35
+
36
+
37
@hydra.main(config_path="", config_name="config")
def main(cfg: DictConfig):  # noqa: F821
    """Run offline TD3+BC training with periodic evaluation.

    Builds the logger, evaluation environment, offline replay buffer, agent,
    loss module and optimizers from ``cfg``, then performs
    ``cfg.optim.gradient_steps`` updates, evaluating every
    ``cfg.logger.eval_iter`` steps.

    Args:
        cfg: Hydra configuration; the ``env``, ``logger``, ``network``,
            ``replay_buffer``, ``compile`` and ``optim`` groups are read here.
    """
    set_gym_backend(cfg.env.library).set()

    # Create logger
    exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="td3bc_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    # Use the configured device when given, otherwise pick an available one.
    device = (
        torch.device(cfg.network.device)
        if cfg.network.device
        else get_available_device()
    )

    # Create env
    eval_env = make_environment(
        cfg,
        logger=logger,
    )

    # Create replay buffer
    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, device=device)

    # Resolve the compile mode: when compilation is enabled but no mode is
    # configured, use "default" with CUDA graphs and "reduce-overhead" without.
    compile_mode = None
    if cfg.compile.compile:
        compile_mode = cfg.compile.compile_mode
        if compile_mode in ("", None):
            if cfg.compile.cudagraphs:
                compile_mode = "default"
            else:
                compile_mode = "reduce-overhead"

    # Create agent
    model, _ = make_td3_agent(cfg, eval_env, device)

    # Create loss
    loss_module, target_net_updater = make_loss_module(cfg.optim, model)

    # Create optimizer
    optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

    def update(sampled_tensordict, update_actor):
        """Run one critic step and, if ``update_actor``, one actor/target step.

        Returns a ``TensorDict`` with the actor-loss metadata plus the
        detached ``q_loss`` and ``actor_loss`` values.
        """
        # Compute loss
        q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)

        # Update critic
        q_loss.backward()
        optimizer_critic.step()
        optimizer_critic.zero_grad(set_to_none=True)

        # Update actor
        if update_actor:
            actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
            actor_loss.backward()
            optimizer_actor.step()
            optimizer_actor.zero_grad(set_to_none=True)

            # Update target params
            target_net_updater.step()
        else:
            # Placeholder values so the returned metadata always carries an
            # "actor_loss" entry, even on critic-only steps.
            actorloss_metadata = {}
            actor_loss = q_loss.new_zeros(())
        metadata = TensorDict(actorloss_metadata)
        metadata.set("q_loss", q_loss.detach())
        metadata.set("actor_loss", actor_loss.detach())
        return metadata

    if cfg.compile.compile:
        update = compile_with_warmup(update, mode=compile_mode, warmup=1)

    if cfg.compile.cudagraphs:
        warnings.warn(
            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

    gradient_steps = cfg.optim.gradient_steps
    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps
    delayed_updates = cfg.optim.policy_update_delay
    pbar = tqdm.tqdm(range(gradient_steps))
    # Training loop
    for update_counter in pbar:
        timeit.printevery(num_prints=1000, total_count=gradient_steps, erase=True)

        # Update actor every delayed_updates
        update_actor = update_counter % delayed_updates == 0

        with timeit("rb - sample"):
            # Sample from replay buffer
            sampled_tensordict = replay_buffer.sample()

        with timeit("update"):
            torch.compiler.cudagraph_mark_step_begin()
            # NOTE(review): the clone presumably protects the result from
            # being overwritten by the next compiled/CUDA-graph call — confirm.
            metadata = update(sampled_tensordict, update_actor).clone()

        metrics_to_log = {}
        if update_actor:
            metrics_to_log.update(metadata.to_dict())
        else:
            # The actor loss is only a placeholder on critic-only steps.
            metrics_to_log.update(metadata.exclude("actor_loss").to_dict())

        # evaluation
        if update_counter % evaluation_interval == 0:
            with set_exploration_type(
                ExplorationType.DETERMINISTIC
            ), torch.no_grad(), timeit("eval"):
                eval_td = eval_env.rollout(
                    max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
                )
                eval_env.apply(dump_video)
            # Sum the rewards along dim 1 of the rollout, then average.
            eval_reward = eval_td["next", "reward"].sum(1).mean().item()
            metrics_to_log["evaluation_reward"] = eval_reward
        if logger is not None:
            metrics_to_log.update(timeit.todict(prefix="time"))
            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
            log_metrics(logger, metrics_to_log, update_counter)

    if not eval_env.is_closed:
        eval_env.close()
    pbar.close()
174
+
175
+
176
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()