torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/ppo/utils_atari.py
@@ -0,0 +1,238 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch.nn
+import torch.optim
+from tensordict.nn import TensorDictModule
+from torchrl.data.tensor_specs import CategoricalBox
+from torchrl.envs import (
+    CatFrames,
+    DoubleToFloat,
+    EndOfLifeTransform,
+    EnvCreator,
+    ExplorationType,
+    GrayScale,
+    GymEnv,
+    NoopResetEnv,
+    ParallelEnv,
+    RenameTransform,
+    Resize,
+    RewardSum,
+    set_gym_backend,
+    SignTransform,
+    StepCounter,
+    ToTensorImage,
+    TransformedEnv,
+    VecNorm,
+)
+from torchrl.modules import (
+    ActorValueOperator,
+    ConvNet,
+    MLP,
+    ProbabilisticActor,
+    TanhNormal,
+    ValueOperator,
+)
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# --------------------------------------------------------------------
+
+
+def make_base_env(
+    env_name="BreakoutNoFrameskip-v4",
+    frame_skip=4,
+    gym_backend="gymnasium",
+    is_test=False,
+):
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name,
+            frame_skip=frame_skip,
+            from_pixels=True,
+            pixels_only=False,
+            device="cpu",
+            categorical_action_encoding=True,
+        )
+        env = TransformedEnv(env)
+        env.append_transform(NoopResetEnv(noops=30, random=True))
+        if not is_test:
+            env.append_transform(EndOfLifeTransform())
+        return env
+
+
+def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
+    env = ParallelEnv(
+        num_envs,
+        EnvCreator(lambda: make_base_env(env_name, gym_backend=gym_backend)),
+        serial_for_single=True,
+        device=device,
+    )
+    env = TransformedEnv(env)
+    env.append_transform(RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"]))
+    env.append_transform(ToTensorImage(in_keys=["pixels_int"], out_keys=["pixels"]))
+    env.append_transform(GrayScale())
+    env.append_transform(Resize(84, 84))
+    env.append_transform(CatFrames(N=4, dim=-3))
+    env.append_transform(RewardSum())
+    env.append_transform(StepCounter(max_steps=4500))
+    if not is_test:
+        env.append_transform(SignTransform(in_keys=["reward"]))
+    env.append_transform(DoubleToFloat())
+    env.append_transform(VecNorm(in_keys=["pixels"]))
+    return env
+
+
+# ====================================================================
+# Model utils
+# --------------------------------------------------------------------
+
+
+def make_ppo_modules_pixels(proof_environment, device):
+
+    # Define input shape
+    input_shape = proof_environment.observation_spec["pixels"].shape
+
+    # Define distribution class and kwargs
+    if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
+        num_outputs = proof_environment.action_spec_unbatched.space.n
+        distribution_class = torch.distributions.Categorical
+        distribution_kwargs = {}
+    else:  # is ContinuousBox
+        num_outputs = proof_environment.action_spec_unbatched.shape
+        distribution_class = TanhNormal
+        distribution_kwargs = {
+            "low": proof_environment.action_spec_unbatched.space.low.to(device),
+            "high": proof_environment.action_spec_unbatched.space.high.to(device),
+        }
+
+    # Define input keys
+    in_keys = ["pixels"]
+
+    # Define a shared Module and TensorDictModule (CNN + MLP)
+    common_cnn = ConvNet(
+        activation_class=torch.nn.ReLU,
+        num_cells=[32, 64, 64],
+        kernel_sizes=[8, 4, 3],
+        strides=[4, 2, 1],
+        device=device,
+    )
+    common_cnn_output = common_cnn(torch.ones(input_shape, device=device))
+    common_mlp = MLP(
+        in_features=common_cnn_output.shape[-1],
+        activation_class=torch.nn.ReLU,
+        activate_last_layer=True,
+        out_features=512,
+        num_cells=[],
+        device=device,
+    )
+    common_mlp_output = common_mlp(common_cnn_output)
+
+    # Define shared net as TensorDictModule
+    common_module = TensorDictModule(
+        module=torch.nn.Sequential(common_cnn, common_mlp),
+        in_keys=in_keys,
+        out_keys=["common_features"],
+    )
+
+    # Define one head for the policy
+    policy_net = MLP(
+        in_features=common_mlp_output.shape[-1],
+        out_features=num_outputs,
+        activation_class=torch.nn.ReLU,
+        num_cells=[],
+        device=device,
+    )
+    policy_module = TensorDictModule(
+        module=policy_net,
+        in_keys=["common_features"],
+        out_keys=["logits"],
+    )
+
+    # Add probabilistic sampling of the actions
+    policy_module = ProbabilisticActor(
+        policy_module,
+        in_keys=["logits"],
+        spec=proof_environment.full_action_spec_unbatched.to(device),
+        distribution_class=distribution_class,
+        distribution_kwargs=distribution_kwargs,
+        return_log_prob=True,
+        default_interaction_type=ExplorationType.RANDOM,
+    )
+
+    # Define another head for the value
+    value_net = MLP(
+        activation_class=torch.nn.ReLU,
+        in_features=common_mlp_output.shape[-1],
+        out_features=1,
+        num_cells=[],
+        device=device,
+    )
+    value_module = ValueOperator(
+        value_net,
+        in_keys=["common_features"],
+    )
+
+    return common_module, policy_module, value_module
+
+
+def make_ppo_models(env_name, device, gym_backend):
+
+    proof_environment = make_parallel_env(
+        env_name, 1, device=device, gym_backend=gym_backend
+    )
+    common_module, policy_module, value_module = make_ppo_modules_pixels(
+        proof_environment,
+        device=device,
+    )
+
+    # Wrap modules in a single ActorCritic operator
+    actor_critic = ActorValueOperator(
+        common_operator=common_module,
+        policy_operator=policy_module,
+        value_operator=value_module,
+    )
+
+    with torch.no_grad():
+        td = proof_environment.fake_tensordict().expand(10)
+        actor_critic(td)
+        del td
+
+    actor = actor_critic.get_policy_operator()
+    critic = actor_critic.get_value_operator()
+
+    del proof_environment
+
+    return actor, critic
+
+
+# ====================================================================
+# Evaluation utils
+# --------------------------------------------------------------------
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
+
+
+def eval_model(actor, test_env, num_episodes=3):
+    test_rewards = []
+    for _ in range(num_episodes):
+        td_test = test_env.rollout(
+            policy=actor,
+            auto_reset=True,
+            auto_cast_to_device=True,
+            break_when_any_done=True,
+            max_steps=10_000_000,
+        )
+        test_env.apply(dump_video)
+        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+        test_rewards.append(reward.cpu())
+        del td_test
+    return torch.cat(test_rewards, 0).mean()
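The helpers above come from sota-implementations/ppo/utils_atari.py. As a rough, hypothetical smoke test (not part of the package), they could be wired together as follows, assuming the snippet lives next to utils_atari.py and that gymnasium plus the ALE Atari ROMs are installed:

import torch

from utils_atari import eval_model, make_parallel_env, make_ppo_models  # local file shown above

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the actor/critic pair from a throwaway proof environment.
actor, critic = make_ppo_models(
    "BreakoutNoFrameskip-v4", device=device, gym_backend="gymnasium"
)

# Roll out the (untrained) policy on a single test environment and average episode returns.
test_env = make_parallel_env(
    "BreakoutNoFrameskip-v4", 1, device=device, gym_backend="gymnasium", is_test=True
)
with torch.no_grad():
    mean_reward = eval_model(actor, test_env, num_episodes=1)
print(f"mean episode reward: {mean_reward.item():.2f}")

The corresponding training script is ppo_atari.py, listed above.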
sota-implementations/ppo/utils_mujoco.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch.nn
+import torch.optim
+
+from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
+from torchrl.envs import (
+    ClipTransform,
+    DoubleToFloat,
+    ExplorationType,
+    RewardSum,
+    StepCounter,
+    TransformedEnv,
+    VecNorm,
+)
+from torchrl.envs.libs.gym import GymEnv
+from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# --------------------------------------------------------------------
+
+
+def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False):
+    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
+    env = TransformedEnv(env)
+    env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
+    env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
+    env.append_transform(RewardSum())
+    env.append_transform(StepCounter())
+    env.append_transform(DoubleToFloat(in_keys=["observation"]))
+    return env
+
+
+# ====================================================================
+# Model utils
+# --------------------------------------------------------------------
+
+
+def make_ppo_models_state(proof_environment, device):
+
+    # Define input shape
+    input_shape = proof_environment.observation_spec["observation"].shape
+
+    # Define policy output distribution class
+    num_outputs = proof_environment.action_spec_unbatched.shape[-1]
+    distribution_class = TanhNormal
+    distribution_kwargs = {
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
+        "tanh_loc": False,
+    }
+
+    # Define policy architecture
+    policy_mlp = MLP(
+        in_features=input_shape[-1],
+        activation_class=torch.nn.Tanh,
+        out_features=num_outputs,  # predict only loc
+        num_cells=[64, 64],
+        device=device,
+    )
+
+    # Initialize policy weights
+    for layer in policy_mlp.modules():
+        if isinstance(layer, torch.nn.Linear):
+            torch.nn.init.orthogonal_(layer.weight, 1.0)
+            layer.bias.data.zero_()
+
+    # Add state-independent normal scale
+    policy_mlp = torch.nn.Sequential(
+        policy_mlp,
+        AddStateIndependentNormalScale(
+            proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8
+        ).to(device),
+    )
+
+    # Add probabilistic sampling of the actions
+    policy_module = ProbabilisticActor(
+        TensorDictModule(
+            module=policy_mlp,
+            in_keys=["observation"],
+            out_keys=["loc", "scale"],
+        ),
+        in_keys=["loc", "scale"],
+        spec=proof_environment.full_action_spec_unbatched.to(device),
+        distribution_class=distribution_class,
+        distribution_kwargs=distribution_kwargs,
+        return_log_prob=True,
+        default_interaction_type=ExplorationType.RANDOM,
+    )
+
+    # Define value architecture
+    value_mlp = MLP(
+        in_features=input_shape[-1],
+        activation_class=torch.nn.Tanh,
+        out_features=1,
+        num_cells=[64, 64],
+        device=device,
+    )
+
+    # Initialize value weights
+    for layer in value_mlp.modules():
+        if isinstance(layer, torch.nn.Linear):
+            torch.nn.init.orthogonal_(layer.weight, 0.01)
+            layer.bias.data.zero_()
+
+    # Define value module
+    value_module = ValueOperator(
+        value_mlp,
+        in_keys=["observation"],
+    )
+
+    return policy_module, value_module
+
+
+def make_ppo_models(env_name, device):
+    proof_environment = make_env(env_name, device=device)
+    actor, critic = make_ppo_models_state(proof_environment, device=device)
+    return actor, critic
+
+
+# ====================================================================
+# Evaluation utils
+# --------------------------------------------------------------------
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
+
+
+def eval_model(actor, test_env, num_episodes=3):
+    test_rewards = []
+    for _ in range(num_episodes):
+        td_test = test_env.rollout(
+            policy=actor,
+            auto_reset=True,
+            auto_cast_to_device=True,
+            break_when_any_done=True,
+            max_steps=10_000_000,
+        )
+        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+        test_rewards.append(reward.cpu())
+        test_env.apply(dump_video)
+        del td_test
+    return torch.cat(test_rewards, 0).mean()
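These state-based helpers mirror the pixel-based ones in utils_atari.py. A minimal, hypothetical usage sketch (not shipped with the package), assuming gymnasium with the MuJoCo bindings is installed and the snippet sits next to utils_mujoco.py:

import torch

from utils_mujoco import eval_model, make_env, make_ppo_models  # local file shown above

device = torch.device("cpu")

# Build actor and critic from a proof environment.
actor, critic = make_ppo_models("HalfCheetah-v4", device=device)

# Roll out the untrained policy on a fresh test environment.
test_env = make_env("HalfCheetah-v4", device=device)
with torch.no_grad():
    mean_reward = eval_model(actor, test_env, num_episodes=1)
print(f"mean episode reward: {mean_reward.item():.2f}")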
sota-implementations/ppo_trainer/train.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import hydra
+import torchrl
+from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403
+
+
+@hydra.main(config_path="config", config_name="config", version_base="1.1")
+def main(cfg):
+    def print_reward(td):
+        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")
+
+    trainer = hydra.utils.instantiate(cfg.trainer)
+    trainer.register_op(dest="batch_process", op=print_reward)
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()
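train.py delegates trainer construction entirely to the Hydra config and shows how callbacks are attached with register_op. As a hypothetical extension (the log_done_fraction hook and its key lookups are illustrative only, not part of the package), a second hook could be registered on the same "batch_process" stage:

import hydra
import torchrl
from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    def print_reward(td):
        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

    def log_done_fraction(td):
        # Fraction of transitions in the collected batch that ended an episode (illustrative hook).
        torchrl.logger.info(f"done fraction: {td['next', 'done'].float().mean(): 4.4f}")

    trainer = hydra.utils.instantiate(cfg.trainer)
    # "batch_process" ops are called on each collected batch.
    trainer.register_op(dest="batch_process", op=print_reward)
    trainer.register_op(dest="batch_process", op=log_done_fraction)
    trainer.train()


if __name__ == "__main__":
    main()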
sota-implementations/redq/README.md
@@ -0,0 +1,7 @@
+# REDQ example
+
+## Note:
+This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the
+benchmarking of future releases, to ensure that it can be successfully run with the release code and that the
+results are consistent. For now, be aware that this additional check has not been performed in the case of this
+specific example.
sota-implementations/redq/redq.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+
+import hydra
+import torch.cuda
+from tensordict.nn import TensorDictSequential
+from torchrl.envs import EnvCreator, ParallelEnv
+from torchrl.envs.transforms import RewardScaling, TransformedEnv
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import OrnsteinUhlenbeckProcessModule
+from torchrl.record import VideoRecorder
+from torchrl.record.loggers import get_logger
+from utils import (
+    correct_for_frame_skip,
+    get_norm_state_dict,
+    initialize_observation_norm_transforms,
+    make_collector_offpolicy,
+    make_redq_loss,
+    make_redq_model,
+    make_replay_buffer,
+    make_trainer,
+    parallel_env_constructor,
+    retrieve_observation_norms_state_dict,
+    transformed_env_constructor,
+)
+
+DEFAULT_REWARD_SCALING = {
+    "Hopper-v1": 5,
+    "Walker2d-v1": 5,
+    "HalfCheetah-v1": 5,
+    "cheetah": 5,
+    "Ant-v2": 5,
+    "Humanoid-v2": 20,
+    "humanoid": 100,
+}
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+
+    cfg = correct_for_frame_skip(cfg)
+
+    if not isinstance(cfg.env.reward_scaling, float):
+        cfg.env.reward_scaling = DEFAULT_REWARD_SCALING.get(cfg.env.name, 5.0)
+        cfg.env.reward_loc = 0.0
+
+    device = (
+        torch.device("cpu")
+        if torch.cuda.device_count() == 0
+        else torch.device("cuda:0")
+    )
+
+    exp_name = "_".join(
+        [
+            "REDQ",
+            cfg.logger.exp_name,
+            str(uuid.uuid4())[:8],
+            datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
+        ]
+    )
+
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="redq_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+    else:
+        logger = None
+
+    key, init_env_steps, stats = None, None, None
+    if not cfg.env.vecnorm and cfg.env.norm_stats:
+        key = (
+            ("next", "pixels")
+            if cfg.env.from_pixels
+            else ("next", "observation_vector")
+        )
+        init_env_steps = cfg.env.init_env_steps
+        stats = {"loc": None, "scale": None}
+    elif cfg.env.from_pixels:
+        stats = {"loc": 0.5, "scale": 0.5}
+
+    proof_env = transformed_env_constructor(
+        cfg=cfg,
+        use_env_creator=False,
+        stats=stats,
+    )()
+    initialize_observation_norm_transforms(
+        proof_environment=proof_env, num_iter=init_env_steps, key=key
+    )
+    _, obs_norm_state_dict = retrieve_observation_norms_state_dict(proof_env)[0]
+
+    model = make_redq_model(
+        proof_env,
+        cfg=cfg,
+        device=device,
+    )
+    loss_module, target_net_updater = make_redq_loss(model, cfg)
+
+    actor_model_explore = model[0]
+    if cfg.exploration.ou_exploration:
+        if cfg.exploration.gSDE:
+            raise RuntimeError("gSDE and ou_exploration are incompatible")
+        actor_model_explore = TensorDictSequential(
+            actor_model_explore,
+            OrnsteinUhlenbeckProcessModule(
+                spec=actor_model_explore.spec,
+                annealing_num_steps=cfg.exploration.annealing_frames,
+                sigma=cfg.exploration.ou_sigma,
+                theta=cfg.exploration.ou_theta,
+                device=device,
+            ),
+        )
+    if device == torch.device("cpu"):
+        # mostly for debugging
+        actor_model_explore.share_memory()
+
+    if cfg.exploration.gSDE:
+        with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+            # get dimensions to build the parallel env
+            proof_td = actor_model_explore(proof_env.reset().to(device))
+        action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
+        del proof_td
+    else:
+        action_dim_gsde, state_dim_gsde = None, None
+
+    proof_env.close()
+    create_env_fn = parallel_env_constructor(
+        cfg=cfg,
+        obs_norm_state_dict=obs_norm_state_dict,
+        action_dim_gsde=action_dim_gsde,
+        state_dim_gsde=state_dim_gsde,
+    )
+
+    collector = make_collector_offpolicy(
+        make_env=create_env_fn,
+        actor_model_explore=actor_model_explore,
+        cfg=cfg,
+    )
+
+    replay_buffer = make_replay_buffer("cpu", cfg)
+
+    recorder = transformed_env_constructor(
+        cfg,
+        video_tag="rendering/test",
+        norm_obs_only=True,
+        obs_norm_state_dict=obs_norm_state_dict,
+        logger=logger,
+        use_env_creator=False,
+    )()
+    if isinstance(create_env_fn, ParallelEnv):
+        raise NotImplementedError("This behavior is deprecated")
+    elif isinstance(create_env_fn, EnvCreator):
+        recorder.transform[1:].load_state_dict(
+            get_norm_state_dict(create_env_fn()), strict=False
+        )
+    elif isinstance(create_env_fn, TransformedEnv):
+        recorder.transform = create_env_fn.transform.clone()
+    else:
+        raise NotImplementedError(f"Unsupported env type {type(create_env_fn)}")
+    if logger is not None and cfg.logger.video:
+        recorder.insert_transform(0, VideoRecorder(logger=logger, tag="rendering/test"))
+
+    # reset reward scaling
+    for t in recorder.transform:
+        if isinstance(t, RewardScaling):
+            t.scale.fill_(1.0)
+            t.loc.fill_(0.0)
+
+    trainer = make_trainer(
+        collector=collector,
+        loss_module=loss_module,
+        recorder=recorder,
+        target_net_updater=target_net_updater,
+        policy_exploration=actor_model_explore,
+        replay_buffer=replay_buffer,
+        logger=logger,
+        cfg=cfg,
+    )
+
+    trainer.train()
+    if logger is not None:
+        return (logger.log_dir, trainer._log_dict)
+
+
+if __name__ == "__main__":
+    main()