torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,240 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import numpy as np
+ import torch.nn
+ import torch.optim
+ from tensordict.nn import TensorDictModule
+ from torchrl.data.tensor_specs import CategoricalBox
+ from torchrl.envs import (
+     CatFrames,
+     DoubleToFloat,
+     EndOfLifeTransform,
+     EnvCreator,
+     ExplorationType,
+     GrayScale,
+     GymEnv,
+     NoopResetEnv,
+     ParallelEnv,
+     Resize,
+     RewardSum,
+     set_gym_backend,
+     SignTransform,
+     StepCounter,
+     ToTensorImage,
+     TransformedEnv,
+     VecNorm,
+ )
+ from torchrl.modules import (
+     ActorValueOperator,
+     ConvNet,
+     MLP,
+     OneHotCategorical,
+     ProbabilisticActor,
+     TanhNormal,
+     ValueOperator,
+ )
+ from torchrl.record import VideoRecorder
+
+
+ # ====================================================================
+ # Environment utils
+ # --------------------------------------------------------------------
+
+
+ def make_base_env(
+     env_name="BreakoutNoFrameskip-v4",
+     gym_backend="gymnasium",
+     frame_skip=4,
+     device="cpu",
+     is_test=False,
+ ):
+     with set_gym_backend(gym_backend):
+         env = GymEnv(
+             env_name,
+             frame_skip=frame_skip,
+             from_pixels=True,
+             pixels_only=False,
+             device=device,
+         )
+         env = TransformedEnv(env)
+         env.append_transform(NoopResetEnv(noops=30, random=True))
+         if not is_test:
+             env.append_transform(EndOfLifeTransform())
+     return env
+
+
+ def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
+     env = ParallelEnv(
+         num_envs,
+         EnvCreator(
+             lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test),
+         ),
+         serial_for_single=True,
+         device=device,
+     )
+     env = TransformedEnv(env)
+     env.append_transform(DoubleToFloat())
+     env.append_transform(ToTensorImage())
+     env.append_transform(GrayScale())
+     env.append_transform(Resize(84, 84))
+     env.append_transform(CatFrames(N=4, dim=-3))
+     env.append_transform(RewardSum())
+     env.append_transform(StepCounter(max_steps=4500))
+     if not is_test:
+         env.append_transform(SignTransform(in_keys=["reward"]))
+     env.append_transform(VecNorm(in_keys=["pixels"]))
+     return env
+
+
+ # ====================================================================
+ # Model utils
+ # --------------------------------------------------------------------
+
+
+ def make_ppo_modules_pixels(proof_environment, device):
+
+     # Define input shape
+     input_shape = proof_environment.observation_spec["pixels"].shape
+
+     # Define distribution class and kwargs
+     if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
+         num_outputs = proof_environment.action_spec_unbatched.space.n
+         distribution_class = OneHotCategorical
+         distribution_kwargs = {}
+     else:  # is ContinuousBox
+         num_outputs = proof_environment.action_spec_unbatched.shape
+         distribution_class = TanhNormal
+         distribution_kwargs = {
+             "low": proof_environment.action_spec_unbatched.space.low.to(device),
+             "high": proof_environment.action_spec_unbatched.space.high.to(device),
+         }
+
+     # Define input keys
+     in_keys = ["pixels"]
+
+     # Define a shared Module and TensorDictModule (CNN + MLP)
+     common_cnn = ConvNet(
+         activation_class=torch.nn.ReLU,
+         num_cells=[32, 64, 64],
+         kernel_sizes=[8, 4, 3],
+         strides=[4, 2, 1],
+         device=device,
+     )
+     common_cnn_output = common_cnn(torch.ones(input_shape, device=device))
+     common_mlp = MLP(
+         in_features=common_cnn_output.shape[-1],
+         activation_class=torch.nn.ReLU,
+         activate_last_layer=True,
+         out_features=512,
+         num_cells=[],
+         device=device,
+     )
+     common_mlp_output = common_mlp(common_cnn_output)
+
+     # Define shared net as TensorDictModule
+     common_module = TensorDictModule(
+         module=torch.nn.Sequential(common_cnn, common_mlp),
+         in_keys=in_keys,
+         out_keys=["common_features"],
+     )
+
+     # Define one head for the policy
+     policy_net = MLP(
+         in_features=common_mlp_output.shape[-1],
+         out_features=num_outputs,
+         activation_class=torch.nn.ReLU,
+         num_cells=[],
+         device=device,
+     )
+     policy_module = TensorDictModule(
+         module=policy_net,
+         in_keys=["common_features"],
+         out_keys=["logits"],
+     )
+
+     # Add probabilistic sampling of the actions
+     policy_module = ProbabilisticActor(
+         policy_module,
+         in_keys=["logits"],
+         spec=proof_environment.full_action_spec_unbatched.to(device),
+         distribution_class=distribution_class,
+         distribution_kwargs=distribution_kwargs,
+         return_log_prob=True,
+         default_interaction_type=ExplorationType.RANDOM,
+     )
+
+     # Define another head for the value
+     value_net = MLP(
+         activation_class=torch.nn.ReLU,
+         in_features=common_mlp_output.shape[-1],
+         out_features=1,
+         num_cells=[],
+         device=device,
+     )
+     value_module = ValueOperator(
+         value_net,
+         in_keys=["common_features"],
+     )
+
+     return common_module, policy_module, value_module
+
+
+ def make_ppo_models(env_name, device, gym_backend):
+
+     proof_environment = make_parallel_env(
+         env_name, num_envs=1, device="cpu", gym_backend=gym_backend
+     )
+     common_module, policy_module, value_module = make_ppo_modules_pixels(
+         proof_environment, device=device
+     )
+
+     # Wrap modules in a single ActorCritic operator
+     actor_critic = ActorValueOperator(
+         common_operator=common_module,
+         policy_operator=policy_module,
+         value_operator=value_module,
+     )
+
+     with torch.no_grad():
+         td = proof_environment.fake_tensordict().expand(1)
+         td = actor_critic(td.to(device))
+         del td
+
+     actor = actor_critic.get_policy_operator()
+     critic = actor_critic.get_value_operator()
+     critic_head = actor_critic.get_value_head()
+
+     del proof_environment
+
+     return actor, critic, critic_head
+
+
+ # ====================================================================
+ # Evaluation utils
+ # --------------------------------------------------------------------
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
+
+
+ def eval_model(actor, test_env, num_episodes=3):
+     test_rewards = []
+     for _ in range(num_episodes):
+         td_test = test_env.rollout(
+             policy=actor,
+             auto_reset=True,
+             auto_cast_to_device=True,
+             break_when_any_done=True,
+             max_steps=10_000_000,
+         )
+         reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+         test_rewards = np.append(test_rewards, reward.cpu().numpy())
+         test_env.apply(dump_video)
+         del td_test
+     return test_rewards.mean()
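The helpers above (judging from the +240 line count, this hunk appears to be sota-implementations/a2c/utils_atari.py) split the pixel-based pipeline into environment construction, shared-trunk actor/critic construction, and evaluation. A minimal usage sketch under that assumption; the env name, worker count, and frame budgets below are illustrative only and not taken from the package:

    import torch
    from torchrl.collectors import SyncDataCollector

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Actor/critic heads sharing a CNN+MLP trunk, built from a throwaway single-env "proof" environment.
    actor, critic, critic_head = make_ppo_models(
        "BreakoutNoFrameskip-v4", device=device, gym_backend="gymnasium"
    )

    # Training env: parallel workers, reward sign-clipping and pixel VecNorm enabled (is_test=False).
    train_env = make_parallel_env(
        "BreakoutNoFrameskip-v4", num_envs=8, device=device, gym_backend="gymnasium"
    )
    collector = SyncDataCollector(
        train_env, actor, frames_per_batch=4096, total_frames=40_960, device=device
    )
    for batch in collector:
        pass  # advantage estimation, A2C/PPO loss and optimizer step would go here

    # Test env keeps raw rewards (is_test=True) so episode returns stay comparable.
    test_env = make_parallel_env(
        "BreakoutNoFrameskip-v4", num_envs=1, device=device, gym_backend="gymnasium", is_test=True
    )
    print(eval_model(actor, test_env, num_episodes=3))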
@@ -0,0 +1,160 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import numpy as np
+ import torch.nn
+ import torch.optim
+
+ from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
+ from torchrl.envs import (
+     ClipTransform,
+     DoubleToFloat,
+     ExplorationType,
+     RewardSum,
+     StepCounter,
+     TransformedEnv,
+     VecNorm,
+ )
+ from torchrl.envs.libs.gym import GymEnv
+ from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
+ from torchrl.record import VideoRecorder
+
+
+ # ====================================================================
+ # Environment utils
+ # --------------------------------------------------------------------
+
+
+ def make_env(
+     env_name="HalfCheetah-v4", device="cpu", from_pixels=False, pixels_only=False
+ ):
+     env = GymEnv(
+         env_name, device=device, from_pixels=from_pixels, pixels_only=pixels_only
+     )
+     env = TransformedEnv(env)
+     env.append_transform(RewardSum())
+     env.append_transform(StepCounter())
+     env.append_transform(VecNorm(in_keys=["observation"]))
+     env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
+     env.append_transform(DoubleToFloat(in_keys=["observation"]))
+     return env
+
+
+ # ====================================================================
+ # Model utils
+ # --------------------------------------------------------------------
+
+
+ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
+
+     # Define input shape
+     input_shape = proof_environment.observation_spec["observation"].shape
+
+     # Define policy output distribution class
+     num_outputs = proof_environment.action_spec_unbatched.shape[-1]
+     distribution_class = TanhNormal
+     distribution_kwargs = {
+         "low": proof_environment.action_spec_unbatched.space.low.to(device),
+         "high": proof_environment.action_spec_unbatched.space.high.to(device),
+         "tanh_loc": False,
+         "safe_tanh": True,
+     }
+
+     # Define policy architecture
+     policy_mlp = MLP(
+         in_features=input_shape[-1],
+         activation_class=torch.nn.Tanh,
+         out_features=num_outputs,  # predict only loc
+         num_cells=[64, 64],
+         device=device,
+     )
+
+     # Initialize policy weights
+     for layer in policy_mlp.modules():
+         if isinstance(layer, torch.nn.Linear):
+             torch.nn.init.orthogonal_(layer.weight, 1.0)
+             layer.bias.data.zero_()
+
+     # Add state-independent normal scale
+     policy_mlp = torch.nn.Sequential(
+         policy_mlp,
+         AddStateIndependentNormalScale(
+             proof_environment.action_spec_unbatched.shape[-1], device=device
+         ),
+     )
+
+     # Add probabilistic sampling of the actions
+     policy_module = ProbabilisticActor(
+         TensorDictModule(
+             module=policy_mlp,
+             in_keys=["observation"],
+             out_keys=["loc", "scale"],
+         ),
+         in_keys=["loc", "scale"],
+         spec=proof_environment.full_action_spec_unbatched.to(device),
+         distribution_class=distribution_class,
+         distribution_kwargs=distribution_kwargs,
+         return_log_prob=True,
+         default_interaction_type=ExplorationType.RANDOM,
+     )
+
+     # Define value architecture
+     value_mlp = MLP(
+         in_features=input_shape[-1],
+         activation_class=torch.nn.Tanh,
+         out_features=1,
+         num_cells=[64, 64],
+         device=device,
+     )
+
+     # Initialize value weights
+     for layer in value_mlp.modules():
+         if isinstance(layer, torch.nn.Linear):
+             torch.nn.init.orthogonal_(layer.weight, 0.01)
+             layer.bias.data.zero_()
+
+     # Define value module
+     value_module = ValueOperator(
+         value_mlp,
+         in_keys=["observation"],
+     )
+
+     return policy_module, value_module
+
+
+ def make_ppo_models(env_name, device, *, compile: bool = False):
+     proof_environment = make_env(env_name, device="cpu")
+     actor, critic = make_ppo_models_state(
+         proof_environment, device=device, compile=compile
+     )
+     return actor, critic
+
+
+ # ====================================================================
+ # Evaluation utils
+ # --------------------------------------------------------------------
+
+
+ def dump_video(module):
+     if isinstance(module, VideoRecorder):
+         module.dump()
+
+
+ def eval_model(actor, test_env, num_episodes=3):
+     test_rewards = []
+     for _ in range(num_episodes):
+         td_test = test_env.rollout(
+             policy=actor,
+             auto_reset=True,
+             auto_cast_to_device=True,
+             break_when_any_done=True,
+             max_steps=10_000_000,
+         )
+         reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+         test_rewards = np.append(test_rewards, reward.cpu().numpy())
+         test_env.apply(dump_video)
+         del td_test
+     return test_rewards.mean()
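The state-based counterpart above (apparently sota-implementations/a2c/utils_mujoco.py, given the +160 line count) follows the same layout with a TanhNormal policy over vector observations. A short sketch of how the pieces fit together, again illustrative rather than taken from the package, reusing only the functions defined in this hunk:

    import torch

    # Build actor and critic on CPU; make_ppo_models internally creates a throwaway proof env.
    actor, critic = make_ppo_models("HalfCheetah-v4", device="cpu")

    # Sanity-check a short rollout before wiring up a collector, loss module and optimizer.
    env = make_env("HalfCheetah-v4", device="cpu")
    with torch.no_grad():
        rollout = env.rollout(max_steps=100, policy=actor)
    # RewardSum() in make_env exposes the running episode return under ("next", "episode_reward").
    print(rollout["next", "episode_reward"][-1])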
@@ -0,0 +1,7 @@
+ # Bandits example
+
+ ## Note:
+ This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it
+ in the benchmarking of future releases, to ensure that it runs successfully with the release code and that its
+ results are consistent. For now, be aware that this additional check has not been performed for this
+ specific example.
@@ -0,0 +1,126 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import argparse
+
+ import torch
+ import tqdm
+
+ from tensordict.nn import TensorDictSequential
+ from torch import nn
+ from torchrl.envs.libs.openml import OpenMLEnv
+ from torchrl.envs.utils import ExplorationType, set_exploration_type
+ from torchrl.modules import DistributionalQValueActor, EGreedyModule, MLP, QValueActor
+ from torchrl.objectives import DistributionalDQNLoss, DQNLoss
+
+ parser = argparse.ArgumentParser()
+
+ # Add arguments
+ parser.add_argument("--batch_size", type=int, default=256, help="batch size")
+ parser.add_argument("--n_steps", type=int, default=10000, help="number of steps")
+ parser.add_argument(
+     "--eps_greedy", type=float, default=0.1, help="epsilon-greedy parameter"
+ )
+ parser.add_argument("--lr", type=float, default=2e-4, help="learning rate")
+ parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
+ parser.add_argument("--n_cells", type=int, default=128, help="number of cells")
+ parser.add_argument(
+     "--distributional", action="store_true", help="enable distributional Q-learning"
+ )
+ parser.add_argument(
+     "--dataset",
+     default="adult_onehot",
+     choices=[
+         "adult_num",
+         "adult_onehot",
+         "mushroom_num",
+         "mushroom_onehot",
+         "covertype",
+         "shuttle",
+         "magic",
+     ],
+     help="OpenML dataset",
+ )
+
+ if __name__ == "__main__":
+     # Parse arguments
+     args = parser.parse_args()
+
+     # Access arguments
+     batch_size = args.batch_size
+     n_steps = args.n_steps
+     eps_greedy = args.eps_greedy
+     lr = args.lr
+     wd = args.wd
+     n_cells = args.n_cells
+     distributional = args.distributional
+     dataset = args.dataset
+
+     env = OpenMLEnv(dataset, batch_size=[batch_size])
+     n_actions = env.action_spec.space.n
+     if distributional:
+         # does not really make sense since the value is either 0 or 1 and hopefully we
+         # should always predict 1
+         nbins = 2
+         model = MLP(
+             out_features=(nbins, n_actions),
+             depth=3,
+             num_cells=n_cells,
+             activation_class=nn.Tanh,
+         )
+         actor = DistributionalQValueActor(
+             model, support=torch.arange(2), action_space="categorical"
+         )
+         actor(env.reset())
+         loss = DistributionalDQNLoss(
+             actor,
+         )
+         loss.make_value_estimator(gamma=0.9)
+     else:
+         model = MLP(
+             out_features=n_actions, depth=3, num_cells=n_cells, activation_class=nn.Tanh
+         )
+         actor = QValueActor(model, action_space="categorical")
+         actor(env.reset())
+         loss = DQNLoss(actor, loss_function="smooth_l1", action_space=env.action_spec)
+         loss.make_value_estimator(gamma=0.0)
+     policy = TensorDictSequential(
+         actor,
+         EGreedyModule(
+             eps_init=eps_greedy,
+             eps_end=0.0,
+             annealing_num_steps=n_steps,
+             spec=env.action_spec,
+         ),
+     )
+     optim = torch.optim.Adam(loss.parameters(), lr, weight_decay=wd)
+
+     pbar = tqdm.tqdm(range(n_steps))
+
+     init_r = None
+     init_loss = None
+     for i in pbar:
+         with set_exploration_type(ExplorationType.RANDOM):
+             data = env.step(policy(env.reset()))
+         loss_vals = loss(data)
+         loss_val = sum(
+             value for key, value in loss_vals.items() if key.startswith("loss")
+         )
+         loss_val.backward()
+         optim.step()
+         optim.zero_grad()
+         if i % 10 == 0:
+             test_data = env.step(policy(env.reset()))
+             if init_r is None:
+                 init_r = test_data["next", "reward"].sum() / env.numel()
+             if init_loss is None:
+                 init_loss = loss_val.detach().item()
+             pbar.set_description(
+                 f"reward: {test_data['next', 'reward'].sum() / env.numel(): 4.4f} (init={init_r: 4.4f}), "
+                 f"training reward {data['next', 'reward'].sum() / env.numel() : 4.4f}, "
+                 f"loss {loss_val: 4.4f} (init: {init_loss: 4.4f})"
+             )
+         policy[1].step()
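The training loop above (this hunk appears to be sota-implementations/bandits/dqn.py, given the +126 line count) anneals epsilon via `policy[1].step()` and monitors reward on exploratory rollouts. A hypothetical follow-up check that reuses `env` and `actor` from the script, shown only to illustrate how the batched OpenMLEnv reports rewards; it is not part of the package:

    import torch

    # Query the greedy Q-value policy (without the EGreedyModule wrapper) on a fresh batch of contexts.
    with torch.no_grad():
        td = env.step(actor(env.reset()))
    # One reward per row in the batch; average it the same way the training loop does.
    mean_reward = (td["next", "reward"].sum() / env.numel()).item()
    print(f"greedy batch reward: {mean_reward:.4f}")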