torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,371 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from tensordict import TensorDict, TensorDictBase, TensorDictParams
12
+ from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
13
+ from tensordict.utils import NestedKey
14
+ from torch import distributions as d
15
+
16
+ from torchrl.objectives.common import LossModule
17
+ from torchrl.objectives.utils import _reduce, distance_loss
18
+
19
+
20
class OnlineDTLoss(LossModule):
    r"""TorchRL implementation of the Online Decision Transformer loss.

    Presented in `"Online Decision Transformer" <https://arxiv.org/abs/2202.05607>`

    Args:
        actor_network (ProbabilisticTensorDictSequential): stochastic actor

    Keyword Args:
        alpha_init (:obj:`float`, optional): initial entropy multiplier.
            Default is 1.0.
        min_alpha (:obj:`float`, optional): min value of alpha.
            Default is None (no minimum value).
        max_alpha (:obj:`float`, optional): max value of alpha.
            Default is None (no maximum value).
        fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its
            initial value. Otherwise, alpha will be optimized to
            match the 'target_entropy' value.
            Default is ``False``.
        target_entropy (:obj:`float` or str, optional): Target entropy for the
            stochastic policy. Default is "auto", where target entropy is
            computed as :obj:`-prod(n_actions)`.
        samples_mc_entropy (int): number of samples to estimate the entropy
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            action_target (NestedKey): The input tensordict key where the action is expected.
                Defaults to ``"action"``.
            action_pred (NestedKey): The tensordict key where the output action (from the model) is expected.
                Used to compute the target entropy.
                Defaults to ``"action"``.

        """

        # the "action" contained in the dataset
        action_target: NestedKey = "action"
        # the "action" output from the model
        action_pred: NestedKey = "action"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys

    actor_network: TensorDictModule
    actor_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: ProbabilisticTensorDictSequential,
        *,
        alpha_init: float = 1.0,
        min_alpha: float | None = None,
        max_alpha: float | None = None,
        fixed_alpha: bool = False,
        target_entropy: str | float = "auto",
        samples_mc_entropy: int = 1,
        reduction: str | None = None,
    ) -> None:
        # Key caches are lazily (re)built through the ``in_keys``/``out_keys``
        # properties below.
        self._in_keys = None
        self._out_keys = None
        if reduction is None:
            reduction = "mean"
        super().__init__()

        # Actor Network: extract its parameters into ``actor_network_params``
        # so the module can be run functionally; no target network is needed.
        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=False,
        )
        try:
            device = next(self.parameters()).device
        except AttributeError:
            # Fallback device; ``torch.get_default_device`` only exists in
            # recent torch versions, hence the getattr guard.
            # NOTE(review): an exhausted ``parameters()`` iterator raises
            # StopIteration, not AttributeError — confirm the intended trigger.
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

        self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
        # If exactly one bound was supplied (xor on truthiness), synthesize the
        # other so the clamp in ``alpha`` has both ends defined.
        # NOTE(review): a bound of 0.0 is falsy and thus treated as "not set".
        if bool(min_alpha) ^ bool(max_alpha):
            min_alpha = min_alpha if min_alpha else 0.0
            if max_alpha == 0:
                raise ValueError("max_alpha must be either None or greater than 0.")
            max_alpha = max_alpha if max_alpha else 1e9
        # Bounds are stored in log-space, consistent with ``log_alpha``.
        if min_alpha:
            self.register_buffer(
                "min_log_alpha", torch.tensor(min_alpha, device=device).log()
            )
        else:
            self.min_log_alpha = None
        if max_alpha:
            self.register_buffer(
                "max_log_alpha", torch.tensor(max_alpha, device=device).log()
            )
        else:
            self.max_log_alpha = None
        self.fixed_alpha = fixed_alpha
        # ``log_alpha`` is a non-trainable buffer when fixed, otherwise a
        # learnable parameter optimized against ``target_entropy``.
        if fixed_alpha:
            self.register_buffer(
                "log_alpha", torch.tensor(math.log(alpha_init), device=device)
            )
        else:
            self.register_parameter(
                "log_alpha",
                torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
            )

        if target_entropy == "auto":
            # Infer the target entropy from the actor's action spec as minus
            # the number of per-action elements.
            if actor_network.spec is None:
                raise RuntimeError(
                    "Cannot infer the dimensionality of the action. Consider providing "
                    "the target entropy explicitly or provide the spec of the "
                    "action tensor in the actor network."
                )
            if isinstance(self.tensor_keys.action_pred, tuple):
                # Nested key: the leading entries address the container spec.
                action_container_shape = actor_network.spec[
                    self.tensor_keys.action_pred[:-1]
                ].shape
            else:
                action_container_shape = actor_network.spec.shape
            # Count only trailing (per-action) dims, excluding container/batch dims.
            target_entropy = -float(
                actor_network.spec[self.tensor_keys.action_pred]
                .shape[len(action_container_shape) :]
                .numel()
            )
        self.register_buffer(
            "target_entropy", torch.tensor(target_entropy, device=device)
        )

        self.samples_mc_entropy = samples_mc_entropy
        self._set_in_keys()
        self.reduction = reduction

    def _set_in_keys(self):
        # Inputs are the actor's own inputs plus the dataset action key;
        # sorted for a deterministic ordering.
        keys = self.actor_network.in_keys
        keys = set(keys)
        keys.add(self.tensor_keys.action_target)
        self._in_keys = sorted(keys, key=str)

    def _forward_value_estimator_keys(self, **kwargs):
        # This loss uses no value estimator; nothing to refresh on key change.
        pass

    @property
    def alpha(self):
        """Current entropy coefficient (detached), clamped in log-space."""
        if self.min_log_alpha is not None or self.max_log_alpha is not None:
            # In-place clamp on ``.data`` keeps log_alpha within bounds without
            # recording the operation in the autograd graph.
            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
        with torch.no_grad():
            alpha = self.log_alpha.exp()
        return alpha

    @property
    def in_keys(self):
        if self._in_keys is None:
            self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values

    @property
    def out_keys(self):
        if self._out_keys is None:
            keys = [
                "loss_log_likelihood",
                "loss_entropy",
                "loss_alpha",
                "alpha",
                "entropy",
            ]
            self._out_keys = keys
        return self._out_keys

    @out_keys.setter
    def out_keys(self, values):
        self._out_keys = values

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        """Monte-Carlo entropy estimate: mean of ``-log_prob`` over
        ``samples_mc_entropy`` reparameterized samples."""
        x = dist.rsample((self.samples_mc_entropy,))
        log_p = dist.log_prob(x)
        # log_p: (batch_size, context_len) — presumably; confirm against the
        # actor distribution's event/batch shapes.
        return -log_p.mean(axis=0)

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Compute the loss for the Online Decision Transformer.

        Returns a tensordict with ``loss_log_likelihood``, ``loss_entropy`` and
        ``loss_alpha`` (reduced per ``self.reduction``) plus detached
        ``entropy`` and ``alpha`` diagnostics.
        """
        # extract action targets
        tensordict = tensordict.copy()
        target_actions = tensordict.get(self.tensor_keys.action_target)
        if target_actions.requires_grad:
            raise RuntimeError("target action cannot be part of a graph.")

        # Temporarily bind the functional parameters onto the actor module.
        with self.actor_network_params.to_module(self.actor_network):
            action_dist = self.actor_network.get_dist(tensordict)

        log_likelihood = action_dist.log_prob(target_actions)
        entropy = self.get_entropy_bonus(action_dist)
        # ``alpha`` is detached so the entropy bonus does not backprop into it.
        entropy_bonus = self.alpha.detach() * entropy

        # Temperature loss: only ``log_alpha`` receives gradients here, since
        # the entropy residual is detached.
        loss_alpha = self.log_alpha.exp() * (entropy - self.target_entropy).detach()

        out = {
            "loss_log_likelihood": -log_likelihood,
            "loss_entropy": -entropy_bonus,
            "loss_alpha": loss_alpha,
            "entropy": entropy.detach().mean(),
            "alpha": self.alpha.detach(),
        }
        td_out = TensorDict(out, [])
        # Reduce only the ``loss_*`` entries; diagnostics pass through as-is.
        td_out = td_out.named_apply(
            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
            if name.startswith("loss_")
            else value,
        )
        self._clear_weakrefs(
            tensordict,
            td_out,
            "actor_network_params",
            "target_actor_network_params",
        )
        return td_out
249
+
250
+
251
class DTLoss(LossModule):
    r"""TorchRL implementation of the Decision Transformer loss.

    Presented in `"Decision Transformer: Reinforcement Learning via Sequence Modeling" <https://arxiv.org/abs/2106.01345>`

    The loss is a plain regression (``loss_function``) between the actions the
    actor predicts and the (detached) actions stored in the data.

    Args:
        actor_network (ProbabilisticTensorDictSequential): stochastic actor

    Keyword Args:
        loss_function (str): loss function to use. Defaults to ``"l2"``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
        device (torch.device, optional): accepted for API compatibility;
            currently unused by this loss.
    """

    @dataclass
    class _AcceptedKeys:
        """Default tensordict key names, overridable via '.set_keys(key_name=key_value)'.

        Attributes:
            action_target (NestedKey): The input tensordict key where the action is expected.
                Defaults to ``"action"``.
            action_pred (NestedKey): The tensordict key where the output action (from the model) is expected.
                Defaults to ``"action"``.
        """

        # action as stored in the dataset (regression target)
        action_target: NestedKey = "action"
        # action as produced by the actor network
        action_pred: NestedKey = "action"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys

    actor_network: TensorDictModule
    actor_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: ProbabilisticTensorDictSequential,
        *,
        loss_function: str = "l2",
        reduction: str | None = None,
        device: torch.device | None = None,
    ) -> None:
        # Lazy key caches, filled on first property access.
        self._in_keys = None
        self._out_keys = None
        reduction = "mean" if reduction is None else reduction
        super().__init__()

        # Make the actor functional; this loss needs no target network.
        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=False,
        )
        self.loss_function = loss_function
        self.reduction = reduction

    def _set_in_keys(self):
        # The actor's inputs plus both action keys, deterministically ordered.
        required = set(self.actor_network.in_keys)
        required.update((self.tensor_keys.action_pred, self.tensor_keys.action_target))
        self._in_keys = sorted(required, key=str)

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # No value estimator is attached to this loss; nothing to do.
        pass

    @property
    def in_keys(self):
        if self._in_keys is not None:
            return self._in_keys
        self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values

    @property
    def out_keys(self):
        if self._out_keys is None:
            self._out_keys = ["loss"]
        return self._out_keys

    @out_keys.setter
    def out_keys(self, values):
        self._out_keys = values

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Compute the loss for the Online Decision Transformer."""
        # Work on a shallow copy so the caller's tensordict is left untouched.
        td = tensordict.copy()
        # The dataset action is a fixed regression target, hence the detach.
        target = td.get(self.tensor_keys.action_target).detach()

        # Run the actor with its functional parameters bound onto the module.
        with self.actor_network_params.to_module(self.actor_network):
            prediction = self.actor_network(td).get(self.tensor_keys.action_pred)

        loss = distance_loss(
            prediction,
            target,
            loss_function=self.loss_function,
        )
        loss = _reduce(loss, reduction=self.reduction)
        td_out = TensorDict(loss=loss)
        self._clear_weakrefs(
            td,
            td_out,
            "actor_network_params",
            "target_actor_network_params",
        )
        return td_out