torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/objectives/dreamer.py
@@ -0,0 +1,488 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+ from tensordict import TensorDict
+ from tensordict.nn import TensorDictModule
+ from tensordict.utils import NestedKey
+
+ from torchrl._utils import _maybe_record_function_decorator, _maybe_timeit
+ from torchrl.envs.model_based.dreamer import DreamerEnv
+ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
+ from torchrl.objectives.common import LossModule
+ from torchrl.objectives.utils import (
+     _GAMMA_LMBDA_DEPREC_ERROR,
+     default_value_kwargs,
+     distance_loss,
+     hold_out_net,
+     ValueEstimators,
+ )
+ from torchrl.objectives.value import (
+     TD0Estimator,
+     TD1Estimator,
+     TDLambdaEstimator,
+     ValueEstimatorBase,
+ )
+
+
+ class DreamerModelLoss(LossModule):
+     """Dreamer Model Loss.
+
+     Computes the loss of the dreamer world model. The loss is composed of the
+     KL divergence between the prior and posterior of the RSSM, the
+     reconstruction loss over the reconstructed observation and the reward
+     loss over the predicted reward.
+
+     Reference: https://arxiv.org/abs/1912.01603.
+
+     Args:
+         world_model (TensorDictModule): the world model.
+         lambda_kl (:obj:`float`, optional): the weight of the KL divergence loss. Default: 1.0.
+         lambda_reco (:obj:`float`, optional): the weight of the reconstruction loss. Default: 1.0.
+         lambda_reward (:obj:`float`, optional): the weight of the reward loss. Default: 1.0.
+         reco_loss (str, optional): the reconstruction loss. Default: "l2".
+         reward_loss (str, optional): the reward loss. Default: "l2".
+         free_nats (int, optional): the free nats. Default: 3.
+         delayed_clamp (bool, optional): if ``True``, the KL clamping occurs after
+             averaging. If ``False`` (default), the KL divergence is clamped to the
+             free nats value first and then averaged.
+         global_average (bool, optional): if ``True``, the losses will be averaged
+             over all dimensions. Otherwise, a sum will be performed over all
+             non-batch/time dimensions and an average over batch and time.
+             Default: False.
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             reward (NestedKey): The reward is expected to be in the tensordict
+                 key ("next", reward). Defaults to ``"reward"``.
+             true_reward (NestedKey): The ``true_reward`` will be stored in the
+                 tensordict key ("next", true_reward). Defaults to ``"true_reward"``.
+             prior_mean (NestedKey): The prior mean is expected to be in the
+                 tensordict key ("next", prior_mean). Defaults to ``"prior_mean"``.
+             prior_std (NestedKey): The prior std is expected to be in the
+                 tensordict key ("next", prior_std). Defaults to ``"prior_std"``.
+             posterior_mean (NestedKey): The posterior mean is expected to be in
+                 the tensordict key ("next", posterior_mean). Defaults to ``"posterior_mean"``.
+             posterior_std (NestedKey): The posterior std is expected to be in
+                 the tensordict key ("next", posterior_std). Defaults to ``"posterior_std"``.
+             pixels (NestedKey): The pixels are expected to be in the tensordict key ("next", pixels).
+                 Defaults to ``"pixels"``.
+             reco_pixels (NestedKey): The reconstructed pixels are expected to be
+                 in the tensordict key ("next", reco_pixels). Defaults to ``"reco_pixels"``.
+         """
+
+         reward: NestedKey = "reward"
+         true_reward: NestedKey = "true_reward"
+         prior_mean: NestedKey = "prior_mean"
+         prior_std: NestedKey = "prior_std"
+         posterior_mean: NestedKey = "posterior_mean"
+         posterior_std: NestedKey = "posterior_std"
+         pixels: NestedKey = "pixels"
+         reco_pixels: NestedKey = "reco_pixels"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+
+     decoder: TensorDictModule
+     reward_model: TensorDictModule
+     world_model: TensorDictModule
+
+     def __init__(
+         self,
+         world_model: TensorDictModule,
+         *,
+         lambda_kl: float = 1.0,
+         lambda_reco: float = 1.0,
+         lambda_reward: float = 1.0,
+         reco_loss: str | None = None,
+         reward_loss: str | None = None,
+         free_nats: int = 3,
+         delayed_clamp: bool = False,
+         global_average: bool = False,
+     ):
+         super().__init__()
+         self.world_model = world_model
+         self.reco_loss = reco_loss if reco_loss is not None else "l2"
+         self.reward_loss = reward_loss if reward_loss is not None else "l2"
+         self.lambda_kl = lambda_kl
+         self.lambda_reco = lambda_reco
+         self.lambda_reward = lambda_reward
+         self.free_nats = free_nats
+         self.delayed_clamp = delayed_clamp
+         self.global_average = global_average
+         self.__dict__["decoder"] = self.world_model[0][-1]
+         self.__dict__["reward_model"] = self.world_model[1]
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         pass
+
+     @_maybe_record_function_decorator("world_model_loss/forward")
+     def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
+         tensordict = tensordict.copy()
+         tensordict.rename_key_(
+             ("next", self.tensor_keys.reward),
+             ("next", self.tensor_keys.true_reward),
+         )
+
+         tensordict = self.world_model(tensordict)
+
+         prior_mean = tensordict.get(("next", self.tensor_keys.prior_mean))
+         prior_std = tensordict.get(("next", self.tensor_keys.prior_std))
+         posterior_mean = tensordict.get(("next", self.tensor_keys.posterior_mean))
+         posterior_std = tensordict.get(("next", self.tensor_keys.posterior_std))
+
+         kl_loss = self.kl_loss(
+             prior_mean,
+             prior_std,
+             posterior_mean,
+             posterior_std,
+         ).unsqueeze(-1)
+
+         # Ensure contiguous layout for torch.compile compatibility:
+         # the gradient from distance_loss flows back through the decoder convolutions.
+         pixels = tensordict.get(("next", self.tensor_keys.pixels)).contiguous()
+         reco_pixels = tensordict.get(
+             ("next", self.tensor_keys.reco_pixels)
+         ).contiguous()
+         reco_loss = distance_loss(
+             pixels,
+             reco_pixels,
+             self.reco_loss,
+         )
+         if not self.global_average:
+             reco_loss = reco_loss.sum((-3, -2, -1))
+         reco_loss = reco_loss.mean().unsqueeze(-1)
+
+         true_reward = tensordict.get(("next", self.tensor_keys.true_reward))
+         pred_reward = tensordict.get(("next", self.tensor_keys.reward))
+         reward_loss = distance_loss(
+             true_reward,
+             pred_reward,
+             self.reward_loss,
+         )
+         if not self.global_average:
+             reward_loss = reward_loss.squeeze(-1)
+         reward_loss = reward_loss.mean().unsqueeze(-1)
+
+         td_out = TensorDict(
+             loss_model_kl=self.lambda_kl * kl_loss,
+             loss_model_reco=self.lambda_reco * reco_loss,
+             loss_model_reward=self.lambda_reward * reward_loss,
+         )
+         self._clear_weakrefs(tensordict, td_out)
+
+         return (td_out, tensordict.data)
+
+     @staticmethod
+     def normal_log_probability(x, mean, std):
+         return (
+             -0.5 * ((x.to(mean.dtype) - mean) / std).pow(2) - std.log()
+         )  # - 0.5 * math.log(2 * math.pi)
+
+     def kl_loss(
+         self,
+         prior_mean: torch.Tensor,
+         prior_std: torch.Tensor,
+         posterior_mean: torch.Tensor,
+         posterior_std: torch.Tensor,
+     ) -> torch.Tensor:
+         kl = (
+             torch.log(prior_std / posterior_std)
+             + (posterior_std**2 + (prior_mean - posterior_mean) ** 2)
+             / (2 * prior_std**2)
+             - 0.5
+         )
+         if not self.global_average:
+             kl = kl.sum(-1)
+         if self.delayed_clamp:
+             kl = kl.mean().clamp_min(self.free_nats)
+         else:
+             kl = kl.clamp_min(self.free_nats).mean()
+         return kl
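
The `kl_loss` above is the closed-form KL divergence between two diagonal Gaussians, KL(posterior ‖ prior), with a free-nats floor applied before or after averaging depending on `delayed_clamp`. A minimal sketch checking that closed form against `torch.distributions` (the shapes and names are illustrative only, not part of the package):

    import torch
    from torch.distributions import Normal, kl_divergence

    prior_mean, prior_std = torch.randn(2, 4), torch.rand(2, 4) + 0.1
    post_mean, post_std = torch.randn(2, 4), torch.rand(2, 4) + 0.1

    # Same expression as DreamerModelLoss.kl_loss, before summing/clamping
    kl_manual = (
        torch.log(prior_std / post_std)
        + (post_std**2 + (prior_mean - post_mean) ** 2) / (2 * prior_std**2)
        - 0.5
    )
    kl_ref = kl_divergence(Normal(post_mean, post_std), Normal(prior_mean, prior_std))
    assert torch.allclose(kl_manual, kl_ref, atol=1e-6)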
+
+
+ class DreamerActorLoss(LossModule):
+     """Dreamer Actor Loss.
+
+     Computes the loss of the dreamer actor. The actor loss is computed as the
+     negative average lambda return.
+
+     Reference: https://arxiv.org/abs/1912.01603.
+
+     Args:
+         actor_model (TensorDictModule): the actor model.
+         value_model (TensorDictModule): the value model.
+         model_based_env (DreamerEnv): the model based environment.
+         imagination_horizon (int, optional): The number of steps to unroll the
+             model. Defaults to ``15``.
+         discount_loss (bool, optional): if ``True``, the loss is discounted with a
+             gamma discount factor. Defaults to ``True`` for consistency with the paper.
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             belief (NestedKey): The input tensordict key where the belief is expected.
+                 Defaults to ``"belief"``.
+             reward (NestedKey): The reward is expected to be in the tensordict key ("next", reward).
+                 Defaults to ``"reward"``.
+             value (NestedKey): The value is expected to be in the tensordict key ("next", value).
+                 Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+             done (NestedKey): The input tensordict key where the flag if a
+                 trajectory is done is expected ("next", done). Defaults to ``"done"``.
+             terminated (NestedKey): The input tensordict key where the flag if a
+                 trajectory is terminated is expected ("next", terminated). Defaults to ``"terminated"``.
+         """
+
+         belief: NestedKey = "belief"
+         reward: NestedKey = "reward"
+         value: NestedKey = "state_value"
+         done: NestedKey = "done"
+         terminated: NestedKey = "terminated"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+     default_value_estimator = ValueEstimators.TDLambda
+
+     value_model: TensorDictModule
+     actor_model: TensorDictModule
+
+     def __init__(
+         self,
+         actor_model: TensorDictModule,
+         value_model: TensorDictModule,
+         model_based_env: DreamerEnv,
+         *,
+         imagination_horizon: int = 15,
+         discount_loss: bool = True,  # for consistency with paper
+         gamma: float | None = None,
+         lmbda: float | None = None,
+     ):
+         super().__init__()
+         self.actor_model = actor_model
+         self.__dict__["value_model"] = value_model
+         self.model_based_env = model_based_env
+         self.imagination_horizon = imagination_horizon
+         self.discount_loss = discount_loss
+         if gamma is not None:
+             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+         if lmbda is not None:
+             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         if self._value_estimator is not None:
+             self._value_estimator.set_keys(
+                 value=self._tensor_keys.value,
+             )
+
+     @_maybe_record_function_decorator("actor_loss/forward")
+     def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
+         tensordict = tensordict.select("state", self.tensor_keys.belief).data
+
+         with _maybe_timeit("actor_loss/time-rollout"), hold_out_net(
+             self.model_based_env
+         ), set_exploration_type(ExplorationType.RANDOM):
+             tensordict = self.model_based_env.reset(tensordict.copy())
+             fake_data = self.model_based_env.rollout(
+                 max_steps=self.imagination_horizon,
+                 policy=self.actor_model,
+                 auto_reset=False,
+                 tensordict=tensordict,
+             )
+             next_tensordict = step_mdp(fake_data, keep_other=True)
+             with hold_out_net(self.value_model):
+                 next_tensordict = self.value_model(next_tensordict)
+
+         reward = fake_data.get(("next", self.tensor_keys.reward))
+         next_value = next_tensordict.get(self.tensor_keys.value)
+         lambda_target = self.lambda_target(reward, next_value)
+         fake_data.set("lambda_target", lambda_target)
+
+         if self.discount_loss:
+             gamma = self.value_estimator.gamma.to(tensordict.device)
+             discount = gamma.expand(lambda_target.shape).clone()
+             discount[..., 0, :] = 1
+             discount = discount.cumprod(dim=-2)
+             actor_loss = -(lambda_target * discount).sum((-2, -1)).mean()
+         else:
+             actor_loss = -lambda_target.sum((-2, -1)).mean()
+         loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
+         self._clear_weakrefs(tensordict, loss_tensordict)
+
+         return loss_tensordict, fake_data.data
+
+     def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+         done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
+         terminated = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
+         input_tensordict = TensorDict(
+             {
+                 ("next", self.tensor_keys.reward): reward,
+                 ("next", self.tensor_keys.value): value,
+                 ("next", self.tensor_keys.done): done,
+                 ("next", self.tensor_keys.terminated): terminated,
+             },
+             [],
+         )
+         return self.value_estimator.value_estimate(input_tensordict)
+
+     def make_value_estimator(self, value_type: ValueEstimators | None = None, **hyperparams):
+         if value_type is None:
+             value_type = self.default_value_estimator
+
+         # Handle ValueEstimatorBase instance or class
+         if isinstance(value_type, ValueEstimatorBase) or (
+             isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
+         ):
+             return LossModule.make_value_estimator(self, value_type, **hyperparams)
+
+         self.value_type = value_type
+         value_net = None
+         hp = dict(default_value_kwargs(value_type))
+         if hasattr(self, "gamma"):
+             hp["gamma"] = self.gamma
+         hp.update(hyperparams)
+         if value_type is ValueEstimators.TD1:
+             self._value_estimator = TD1Estimator(
+                 **hp,
+                 value_network=value_net,
+             )
+         elif value_type is ValueEstimators.TD0:
+             self._value_estimator = TD0Estimator(
+                 **hp,
+                 value_network=value_net,
+             )
+         elif value_type is ValueEstimators.GAE:
+             if hasattr(self, "lmbda"):
+                 hp["lmbda"] = self.lmbda
+             raise NotImplementedError(
+                 f"Value type {value_type} is not implemented for loss {type(self)}."
+             )
+         elif value_type is ValueEstimators.TDLambda:
+             if hasattr(self, "lmbda"):
+                 hp["lmbda"] = self.lmbda
+             self._value_estimator = TDLambdaEstimator(
+                 **hp,
+                 value_network=value_net,
+                 vectorized=True,  # TODO: the vectorized version does not seem to match the non-vectorized one
+             )
+         else:
+             raise NotImplementedError(f"Unknown value type {value_type}")
+
+         tensor_keys = {
+             "value": self.tensor_keys.value,
+             "value_target": "value_target",
+         }
+         self._value_estimator.set_keys(**tensor_keys)
+
+
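When `discount_loss` is set, the actor loss above weights each imagined step by a cumulative discount built with `cumprod` over the time dimension: the first step stays undiscounted and each later step picks up another factor of gamma. A small standalone sketch of that weighting (the horizon and shapes are made up for illustration):

    import torch

    gamma = 0.99
    horizon = 5
    lambda_target = torch.ones(horizon, 1)  # stand-in for the imagined lambda returns
    discount = gamma * torch.ones_like(lambda_target)
    discount[..., 0, :] = 1  # first imagined step is undiscounted
    discount = discount.cumprod(dim=-2)  # cumulative product along the time dim
    print(discount.squeeze(-1))  # tensor([1.0000, 0.9900, 0.9801, 0.9703, 0.9606])
    actor_loss = -(lambda_target * discount).sum((-2, -1)).mean()
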
+ class DreamerValueLoss(LossModule):
+     """Dreamer Value Loss.
+
+     Computes the loss of the dreamer value model. The value loss is computed
+     between the predicted value and the lambda target.
+
+     Reference: https://arxiv.org/abs/1912.01603.
+
+     Args:
+         value_model (TensorDictModule): the value model.
+         value_loss (str, optional): the loss to use for the value loss.
+             Default: ``"l2"``.
+         discount_loss (bool, optional): if ``True``, the loss is discounted with a
+             gamma discount factor. Defaults to ``True`` for consistency with the paper.
+         gamma (:obj:`float`, optional): the gamma discount factor. Default: ``0.99``.
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             value (NestedKey): The input tensordict key where the state value is expected.
+                 Defaults to ``"state_value"``.
+         """
+
+         value: NestedKey = "state_value"
+
+     tensor_keys: _AcceptedKeys
+     default_keys = _AcceptedKeys
+
+     value_model: TensorDictModule
+
+     def __init__(
+         self,
+         value_model: TensorDictModule,
+         value_loss: str | None = None,
+         discount_loss: bool = True,  # for consistency with paper
+         gamma: float = 0.99,
+     ):
+         super().__init__()
+         self.value_model = value_model
+         self.value_loss = value_loss if value_loss is not None else "l2"
+         self.gamma = gamma
+         self.discount_loss = discount_loss
+
+     def _forward_value_estimator_keys(self, **kwargs) -> None:
+         pass
+
+     @_maybe_record_function_decorator("value_loss/forward")
+     def forward(self, fake_data) -> tuple[TensorDict, TensorDict]:
+         lambda_target = fake_data.get("lambda_target")
+
+         tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
+         self.value_model(tensordict_select)
+
+         if self.discount_loss:
+             discount = self.gamma * torch.ones_like(
+                 lambda_target, device=lambda_target.device
+             )
+             discount[..., 0, :] = 1
+             discount = discount.cumprod(dim=-2)
+             value_loss = (
+                 (
+                     discount
+                     * distance_loss(
+                         tensordict_select.get(self.tensor_keys.value),
+                         lambda_target,
+                         self.value_loss,
+                     )
+                 )
+                 .sum((-1, -2))
+                 .mean()
+             )
+         else:
+             value_loss = (
+                 distance_loss(
+                     tensordict_select.get(self.tensor_keys.value),
+                     lambda_target,
+                     self.value_loss,
+                 )
+                 .sum((-1, -2))
+                 .mean()
+             )
+
+         loss_tensordict = TensorDict({"loss_value": value_loss})
+         self._clear_weakrefs(fake_data, loss_tensordict)
+
+         return loss_tensordict, fake_data
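
As its trailing comment notes, `DreamerModelLoss.normal_log_probability` drops the constant `0.5 * log(2 * pi)` term of the Gaussian log-density. A quick sketch of that relationship, assuming the class is importable from `torchrl.objectives.dreamer` as laid out in this diff:

    import math

    import torch
    from torch.distributions import Normal
    from torchrl.objectives.dreamer import DreamerModelLoss

    x, mean, std = torch.randn(4), torch.zeros(4), torch.ones(4)
    partial = DreamerModelLoss.normal_log_probability(x, mean, std)
    full = Normal(mean, std).log_prob(x)
    # Only the additive constant separates the two
    assert torch.allclose(partial - 0.5 * math.log(2 * math.pi), full, atol=1e-6)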
torchrl/objectives/functional.py
@@ -0,0 +1,48 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch
+
+
+ def cross_entropy_loss(
+     log_policy: torch.Tensor, action: torch.Tensor, inplace: bool = False
+ ) -> torch.Tensor:
+     """Returns the cross-entropy loss defined as the log-softmax value indexed by the action index.
+
+     Supports discrete (integer) actions or one-hot encodings.
+
+     Args:
+         log_policy: Tensor of the log_softmax values of the policy.
+         action: Integer or one-hot representation of the actions undertaken. Must have the shape
+             log_policy.shape[:-1] (integer representation) or log_policy.shape (one-hot).
+         inplace: fills log_policy in-place with 0.0 at non-selected actions before summing along the last dimension.
+             This is usually faster, but it will change the value of log_policy in place, which may lead to unwanted
+             behaviors.
+
+     """
+     if action.shape == log_policy.shape:
+         if action.dtype not in (torch.bool, torch.long, torch.uint8):
+             raise TypeError(
+                 f"Cross-entropy loss with {action.dtype} dtype is not permitted"
+             )
+         if not ((action == 1).sum(-1) == 1).all():
+             raise RuntimeError(
+                 "Expected the action tensor to be a one-hot encoding of the actions taken, "
+                 "but got more/less than one non-null boolean index on the last dimension"
+             )
+         if inplace:
+             # Fill the non-selected entries with 0.0 (cf. docstring), then sum.
+             cross_entropy = log_policy.masked_fill_(~action.bool(), 0.0).sum(-1)
+         else:
+             cross_entropy = (log_policy * action).sum(-1)
+     elif action.shape == log_policy.shape[:-1]:
+         cross_entropy = torch.gather(log_policy, dim=-1, index=action[..., None])
+         cross_entropy.squeeze_(-1)
+     else:
+         raise RuntimeError(
+             f"unexpected action shape in cross_entropy_loss with log_policy.shape={log_policy.shape} and "
+             f"action.shape={action.shape}"
+         )
+     return cross_entropy
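
A short usage sketch of `cross_entropy_loss`, showing that the integer and one-hot paths agree and how the sign relates to `torch.nn.functional.nll_loss` (the import path follows this diff's file layout; the tensors are made up):

    import torch
    import torch.nn.functional as F
    from torchrl.objectives.functional import cross_entropy_loss

    log_policy = F.log_softmax(torch.randn(3, 4), dim=-1)  # 3 samples, 4 actions
    action = torch.tensor([0, 2, 1])

    ce_int = cross_entropy_loss(log_policy, action)                       # integer actions
    ce_hot = cross_entropy_loss(log_policy, F.one_hot(action, 4).bool())  # one-hot actions
    assert torch.allclose(ce_int, ce_hot)

    # The result is the (negative-valued) log-probability of the taken action,
    # i.e. the negation of the usual NLL loss.
    assert torch.allclose(ce_int, -F.nll_loss(log_policy, action, reduction="none"))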