torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/utils.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import typing
9
+ from collections.abc import Callable
10
+ from typing import Any, Union
11
+
12
+ import cloudpickle
13
+ import numpy as np
14
+ import torch
15
+ from torch import Tensor
16
+ from torchrl.data.tensor_specs import (
17
+ Binary,
18
+ Categorical,
19
+ Composite,
20
+ MultiCategorical,
21
+ MultiOneHot,
22
+ OneHot,
23
+ Stacked,
24
+ StackedComposite,
25
+ TensorSpec,
26
+ )
27
+
28
# NumPy dtype -> torch dtype. Every supported dtype carries the same name in
# both libraries, so the table is derived from that single list of names.
numpy_to_torch_dtype_dict = {
    np.dtype(dtype_name): getattr(torch, dtype_name)
    for dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
# torch dtype -> NumPy dtype: the inverse of the table above.
torch_to_numpy_dtype_dict = {
    torch_dtype: numpy_dtype
    for numpy_dtype, torch_dtype in numpy_to_torch_dtype_dict.items()
}
44
# Anything accepted as a device designation: a torch.device, a string or an int.
DEVICE_TYPING = Union[torch.device, str, int]
# The member types of DEVICE_TYPING as a tuple, usable with isinstance().
# typing.get_args is always present here: the subscripted builtins used by
# INDEX_TYPING below already require Python 3.9+, so no fallback is needed.
DEVICE_TYPING_ARGS = typing.get_args(DEVICE_TYPING)

# Anything accepted as an index.
INDEX_TYPING = Union[None, int, slice, str, Tensor, list[Any], tuple[Any, ...]]
51
+
52
+
53
# Canonical action-space names. Each one is spelled once here so that the
# aliases below cannot drift apart through a typo.
_ONE_HOT = "one_hot"
_MULT_ONE_HOT = "mult_one_hot"
_BINARY = "binary"
_CATEGORICAL = "categorical"
_MULTI_CATEGORICAL = "multi_categorical"

# Maps a spec class or a string alias to its canonical action-space name.
ACTION_SPACE_MAP = {
    # spec classes
    OneHot: _ONE_HOT,
    MultiOneHot: _MULT_ONE_HOT,
    Binary: _BINARY,
    Categorical: _CATEGORICAL,
    # string aliases
    "one_hot": _ONE_HOT,
    "one-hot": _ONE_HOT,
    "mult_one_hot": _MULT_ONE_HOT,
    "mult-one-hot": _MULT_ONE_HOT,
    "multi_one_hot": _MULT_ONE_HOT,
    "multi-one-hot": _MULT_ONE_HOT,
    "binary": _BINARY,
    "categorical": _CATEGORICAL,
    # multi-categorical: spec class followed by its aliases
    MultiCategorical: _MULTI_CATEGORICAL,
    "multi_categorical": _MULTI_CATEGORICAL,
    "multi-categorical": _MULTI_CATEGORICAL,
    "multi_discrete": _MULTI_CATEGORICAL,
    "multi-discrete": _MULTI_CATEGORICAL,
}
72
+
73
+
74
def consolidate_spec(
    spec: Composite,
    recurse_through_entries: bool = True,
    recurse_through_stack: bool = True,
):
    """Return a clone of ``spec`` whose stacked specs have no exclusive keys.

    A key is exclusive when it belongs to at least one of the specs held by a
    ``StackedComposite`` but is not among the keys reported by the stack
    itself. Every stacked spec that lacks such a key receives a placeholder
    entry built by ``_empty_like_spec``.

    Args:
        spec (Composite): the spec to be consolidated. The function works on
            ``spec.clone()``.
        recurse_through_entries (bool): if True, the function is called
            recursively on every Composite / StackedComposite entry of the
            spec. Default is True.
        recurse_through_stack (bool): if True and the spec is a
            StackedComposite, the function is called recursively on every
            spec of its list before exclusive keys are gathered.
            Default is True.

    Returns:
        The consolidated clone. If ``spec`` is neither a Composite nor a
        StackedComposite, the clone is returned unchanged.
    """
    spec = spec.clone()

    composite_types = (Composite, StackedComposite)
    if not isinstance(spec, composite_types):
        return spec

    if isinstance(spec, StackedComposite):
        shared_keys = set(spec.keys())
        # One set per stacked spec: the exclusive keys that spec already owns.
        owned_exclusive = [set() for _ in spec._specs]
        # Exclusive key -> every value found for it across the stack.
        examples_by_key = {}

        # First pass: gather all exclusive keys.
        for index, owned in enumerate(owned_exclusive):
            member = spec._specs[index]
            if recurse_through_stack:
                member = consolidate_spec(
                    member, recurse_through_entries, recurse_through_stack
                )
                spec._specs[index] = member
            for member_key in member.keys():
                if member_key in shared_keys:
                    continue
                owned.add(member_key)
                examples_by_key.setdefault(member_key, []).append(
                    member[member_key]
                )

        # Second pass: give each stacked spec the exclusive entries it lacks.
        for member, owned in zip(spec._specs, owned_exclusive):
            for missing_key in set(examples_by_key.keys()).difference(owned):
                placeholder = _empty_like_spec(
                    examples_by_key[missing_key], member.shape
                )
                member.set(missing_key, placeholder)

    if recurse_through_entries:
        for entry_key, entry in spec.items():
            if not isinstance(entry, composite_types):
                continue
            consolidated_entry = consolidate_spec(
                entry, recurse_through_entries, recurse_through_stack
            )
            spec.set(entry_key, consolidated_entry)
    return spec
140
+
141
+
142
def _empty_like_spec(specs: list[TensorSpec], shape):
    """Build a placeholder spec for an exclusive key from its known values.

    Args:
        specs (list[TensorSpec]): the values found for one exclusive key. All
            of them must be instances of the same class; the first one is
            used as the template.
        shape: the shape of the spec that will receive the placeholder.

    Returns:
        A spec of the same class as ``specs[0]``.

    Raises:
        ValueError: if the entries of ``specs`` are not all of the same class.
    """
    template = specs[0]
    for other in specs[1:]:
        if other.__class__ != template.__class__:
            raise ValueError(
                "Found same key in lazy specs corresponding to entries with different classes"
            )

    if isinstance(template, (Composite, StackedComposite)):
        # Composite values: an empty composite spec with the same batch size.
        return template.empty()

    if isinstance(template, Stacked):
        # Stacked values: rebuild a Stacked spec with the same shape (i.e. the
        # same -1s) as the template, so stacking does not add any new -1s.
        stack_dim = template.stack_dim
        shape = list(shape[:stack_dim]) + list(shape[stack_dim + 1 :])
        placeholders = [
            _empty_like_spec(template._specs, shape) for _ in template._specs
        ]
        return Stacked(*placeholders, dim=template.stack_dim)

    # Plain tensor specs:
    # - dims on which the values disagree become 0;
    # - if the values agree on every dim, the leading ``len(shape)`` dims are
    #   kept and all following dims become 0 (same number of dims as the
    #   values).
    target_shape = list(template.shape)
    for axis in range(len(target_shape)):
        if any(
            candidate.shape[axis] != template.shape[axis] for candidate in specs
        ):
            target_shape[axis] = 0

    if 0 not in target_shape:  # the values all have the same shape
        leading = len(shape)
        target_shape = [
            size if position < leading else 0
            for position, size in enumerate(target_shape)
        ]

    # Index down to a single element, then expand to the target shape.
    single_element = template[(0,) * len(template.shape)]
    return single_element.expand(target_shape)
186
+
187
+
188
def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True):
    """Tell whether a spec is free of exclusive keys.

    Args:
        spec (TensorSpec): the spec to check.
        recurse (bool): if True, nested specs are checked too.
            Default is True.

    Returns:
        False if some spec of a StackedComposite has a key set that differs
        from the keys reported by the stack (or, with ``recurse``, if a nested
        spec fails the same check); True otherwise.
    """
    if isinstance(spec, StackedComposite):
        stack_keys = set(spec.keys())
        for member in spec._specs:
            if recurse and not check_no_exclusive_keys(member):
                return False
            if set(member.keys()) != stack_keys:
                return False
        return True

    if isinstance(spec, Composite) and recurse:
        return all(check_no_exclusive_keys(entry) for entry in spec.values())

    # Leaf specs, and composites when not recursing, have nothing to check.
    return True
209
+
210
+
211
def contains_lazy_spec(spec: TensorSpec) -> bool:
    """Tell whether a spec is, or contains, a lazy stacked spec.

    Args:
        spec (TensorSpec): the spec to check.

    Returns:
        True if ``spec`` is a Stacked or StackedComposite, or if it is a
        Composite with at least one entry for which this holds recursively;
        False otherwise.
    """
    if isinstance(spec, (Stacked, StackedComposite)):
        return True
    if isinstance(spec, Composite):
        return any(contains_lazy_spec(entry) for entry in spec.values())
    return False
225
+
226
+
227
+ class _CloudpickleWrapperMeta(type):
228
+ def __call__(cls, obj):
229
+ if isinstance(obj, cls):
230
+ return obj
231
+ else:
232
+ return super().__call__(obj)
233
+
234
+
235
class CloudpickleWrapper(metaclass=_CloudpickleWrapperMeta):
    """Callable wrapper whose pickled state is produced with ``cloudpickle``.

    Intended for functions that must be serialized in multiprocessed settings.

    Args:
        fn (Callable): the callable to wrap.
        **kwargs: keyword arguments stored on the wrapper and passed to ``fn``
            on every call. They take precedence over call-time keyword
            arguments of the same name.

    Raises:
        RuntimeError: if the class of ``fn`` is named ``EnvCreator``.
    """

    def __init__(self, fn: Callable, **kwargs):
        wraps_env_creator = fn.__class__.__name__ == "EnvCreator"
        if wraps_env_creator:
            raise RuntimeError(
                "CloudpickleWrapper usage with EnvCreator class is "
                "prohibited as it breaks the transmission of shared tensors."
            )
        self.fn = fn
        self.kwargs = kwargs

        # Copy wrapper metadata from ``fn.forward`` when it exists, else from
        # ``fn`` itself.
        metadata_source = getattr(fn, "forward", fn)
        functools.update_wrapper(self, metadata_source)

    def __getstate__(self):
        # The whole state is one cloudpickle blob.
        state = (self.fn, self.kwargs)
        return cloudpickle.dumps(state)

    def __setstate__(self, ob: bytes):
        restored_fn, restored_kwargs = cloudpickle.loads(ob)
        self.fn = restored_fn
        self.kwargs = restored_kwargs
        # Re-apply the wrapper metadata, as done at construction time.
        functools.update_wrapper(
            self, getattr(restored_fn, "forward", restored_fn)
        )

    def __call__(self, *args, **kwargs) -> Any:
        # Stored keyword arguments override the ones given at call time.
        merged_kwargs = {**kwargs, **self.kwargs}
        return self.fn(*args, **merged_kwargs)
259
+
260
+
261
def _process_action_space_spec(action_space, spec):
    """Reconcile an action-space designation with an action spec.

    Args:
        action_space: ``None``, a key of ``ACTION_SPACE_MAP`` or a
            non-composite TensorSpec.
        spec: ``None``, an action spec, or a Composite holding an ``"action"``
            entry (top-level, or under a tuple key ending in ``"action"``).

    Returns:
        A ``(action_space, spec)`` tuple: the action-space name resolved by
        ``_find_action_space`` and the spec to use. A Composite ``spec`` is
        returned as it was passed in.

    Raises:
        KeyError: if ``spec`` is a Composite without an ``"action"`` entry.
        ValueError: if ``action_space`` is a Composite, if ``action_space`` is
            a TensorSpec that is not the very object used as spec, if the
            spec and the action space resolve to different names, or if both
            arguments are ``None``.
    """
    given_spec = spec
    spec_is_composite = isinstance(spec, Composite)
    if spec_is_composite:
        # this will break whenever our action is more complex than a single tensor
        try:
            if "action" in spec.keys():
                action_key = "action"
            else:
                # otherwise, take the first tuple key whose last element is "action"
                action_key = next(
                    (
                        candidate
                        for candidate in spec.keys(True, True)
                        if isinstance(candidate, tuple)
                        and candidate[-1] == "action"
                    ),
                    None,
                )
                if action_key is None:
                    raise KeyError
            spec = spec[action_key]
        except KeyError:
            raise KeyError(
                "action could not be found in the spec. Make sure "
                "you pass a spec that is either a native action spec or a composite action spec "
                "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only."
            )

    if action_space is None:
        if spec is None:
            raise ValueError(
                "Neither action_space nor spec was defined. The action space cannot be inferred."
            )
        action_space = _find_action_space(spec)
    else:
        if isinstance(action_space, Composite):
            raise ValueError("action_space cannot be of type Composite.")
        space_is_tensor_spec = isinstance(action_space, TensorSpec)
        if spec is not None and space_is_tensor_spec and action_space is not spec:
            raise ValueError(
                "Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match."
            )
        if space_is_tensor_spec:
            spec = action_space
        action_space = _find_action_space(action_space)
        # check that the spec and action_space match
        if spec is not None and _find_action_space(spec) != action_space:
            raise ValueError(
                f"The action spec and the action space do not match: got action_space={action_space} and spec={spec}."
            )

    # A composite spec is handed back whole, not as its "action" leaf.
    return action_space, (given_spec if spec_is_composite else spec)
312
+
313
+
314
def _find_action_space(action_space) -> str:
    """Resolve an action-space designation to its ``ACTION_SPACE_MAP`` value.

    Args:
        action_space: a key of ``ACTION_SPACE_MAP``, or a TensorSpec. A
            TensorSpec is replaced by its class before the lookup; a
            Composite is first replaced by its ``"action"`` entry (top-level,
            or under a tuple key ending in ``"action"``).

    Returns:
        The value stored in ``ACTION_SPACE_MAP`` for the resolved key.

    Raises:
        KeyError: if a Composite is given that has no ``"action"`` entry.
        ValueError: if the resolved key is not in ``ACTION_SPACE_MAP``.
    """
    if isinstance(action_space, TensorSpec):
        if isinstance(action_space, Composite):
            if "action" in action_space.keys():
                action_key = "action"
            else:
                # otherwise, take the first tuple key whose last element is "action"
                action_key = next(
                    (
                        candidate
                        for candidate in action_space.keys(True, True)
                        if isinstance(candidate, tuple)
                        and candidate[-1] == "action"
                    ),
                    None,
                )
                if action_key is None:
                    raise KeyError
            action_space = action_space[action_key]
        # Specs are looked up by class.
        action_space = type(action_space)

    try:
        return ACTION_SPACE_MAP[action_space]
    except KeyError:
        raise ValueError(
            f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
        )
@@ -0,0 +1,265 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .async_envs import AsyncEnvPool, ProcessorAsyncEnvPool, ThreadingAsyncEnvPool
7
+ from .batched_envs import ParallelEnv, SerialEnv
8
+ from .common import EnvBase, EnvMetaData, make_tensordict
9
+ from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
10
+ from .env_creator import env_creator, EnvCreator, get_env_metadata
11
+ from .gym_like import default_info_dict_reader, GymLikeEnv
12
+ from .libs import (
13
+ BraxEnv,
14
+ BraxWrapper,
15
+ DMControlEnv,
16
+ DMControlWrapper,
17
+ gym_backend,
18
+ GymEnv,
19
+ GymWrapper,
20
+ HabitatEnv,
21
+ IsaacGymEnv,
22
+ IsaacGymWrapper,
23
+ IsaacLabWrapper,
24
+ JumanjiEnv,
25
+ JumanjiWrapper,
26
+ MeltingpotEnv,
27
+ MeltingpotWrapper,
28
+ MOGymEnv,
29
+ MOGymWrapper,
30
+ MultiThreadedEnv,
31
+ MultiThreadedEnvWrapper,
32
+ OpenMLEnv,
33
+ OpenSpielEnv,
34
+ OpenSpielWrapper,
35
+ PettingZooEnv,
36
+ PettingZooWrapper,
37
+ ProcgenEnv,
38
+ ProcgenWrapper,
39
+ register_gym_spec_conversion,
40
+ RoboHiveEnv,
41
+ set_gym_backend,
42
+ SMACv2Env,
43
+ SMACv2Wrapper,
44
+ UnityMLAgentsEnv,
45
+ UnityMLAgentsWrapper,
46
+ VmasEnv,
47
+ VmasWrapper,
48
+ )
49
+ from .model_based import DreamerDecoder, DreamerEnv, ModelBasedEnvBase
50
+ from .transforms import (
51
+ ActionDiscretizer,
52
+ ActionMask,
53
+ AutoResetEnv,
54
+ AutoResetTransform,
55
+ BatchSizeTransform,
56
+ BinarizeReward,
57
+ BurnInTransform,
58
+ CatFrames,
59
+ CatTensors,
60
+ CenterCrop,
61
+ ClipTransform,
62
+ Compose,
63
+ ConditionalPolicySwitch,
64
+ ConditionalSkip,
65
+ Crop,
66
+ DeviceCastTransform,
67
+ DiscreteActionProjection,
68
+ DoubleToFloat,
69
+ DTypeCastTransform,
70
+ EndOfLifeTransform,
71
+ ExcludeTransform,
72
+ FiniteTensorDictCheck,
73
+ FlattenObservation,
74
+ FrameSkipTransform,
75
+ GrayScale,
76
+ gSDENoise,
77
+ Hash,
78
+ InitTracker,
79
+ LineariseRewards,
80
+ MultiAction,
81
+ MultiStepTransform,
82
+ NoopResetEnv,
83
+ ObservationNorm,
84
+ ObservationTransform,
85
+ PermuteTransform,
86
+ PinMemoryTransform,
87
+ R3MTransform,
88
+ RandomCropTensorDict,
89
+ RemoveEmptySpecs,
90
+ RenameTransform,
91
+ Resize,
92
+ Reward2GoTransform,
93
+ RewardClipping,
94
+ RewardScaling,
95
+ RewardSum,
96
+ SelectTransform,
97
+ SignTransform,
98
+ SqueezeTransform,
99
+ Stack,
100
+ StepCounter,
101
+ TargetReturn,
102
+ TensorDictPrimer,
103
+ TimeMaxPool,
104
+ Timer,
105
+ Tokenizer,
106
+ ToTensorImage,
107
+ TrajCounter,
108
+ Transform,
109
+ TransformedEnv,
110
+ UnaryTransform,
111
+ UnsqueezeTransform,
112
+ VC1Transform,
113
+ VecGymEnvTransform,
114
+ VecNorm,
115
+ VecNormV2,
116
+ VIPRewardTransform,
117
+ VIPTransform,
118
+ )
119
+ from .utils import (
120
+ check_env_specs,
121
+ check_marl_grouping,
122
+ exploration_type,
123
+ ExplorationType,
124
+ get_available_libraries,
125
+ make_composite_from_td,
126
+ MarlGroupMapType,
127
+ set_exploration_type,
128
+ step_mdp,
129
+ terminated_or_truncated,
130
+ )
131
+
132
# Public names re-exported by this package.
# Kept in plain ASCII order (uppercase-initial names sort before lowercase
# ones) so that additions have a single obvious position.
__all__ = [
    "ActionDiscretizer",
    "ActionMask",
    "AsyncEnvPool",
    "AutoResetEnv",
    "AutoResetTransform",
    "BatchSizeTransform",
    "BinarizeReward",
    "BraxEnv",
    "BraxWrapper",
    "BurnInTransform",
    "CatFrames",
    "CatTensors",
    "CenterCrop",
    "ChessEnv",
    "ClipTransform",
    "Compose",
    "ConditionalPolicySwitch",
    "ConditionalSkip",
    "Crop",
    "DMControlEnv",
    "DMControlWrapper",
    "DTypeCastTransform",
    "DeviceCastTransform",
    "DiscreteActionProjection",
    "DoubleToFloat",
    "DreamerDecoder",
    "DreamerEnv",
    "EndOfLifeTransform",
    "EnvBase",
    "EnvCreator",
    "EnvMetaData",
    "ExcludeTransform",
    "ExplorationType",
    "FiniteTensorDictCheck",
    "FlattenObservation",
    "FrameSkipTransform",
    "GrayScale",
    "GymEnv",
    "GymLikeEnv",
    "GymWrapper",
    "HabitatEnv",
    "Hash",
    "InitTracker",
    "IsaacGymEnv",
    "IsaacGymWrapper",
    "IsaacLabWrapper",
    "JumanjiEnv",
    "JumanjiWrapper",
    "LLMHashingEnv",
    "LineariseRewards",
    "MOGymEnv",
    "MOGymWrapper",
    "MarlGroupMapType",
    "MeltingpotEnv",
    "MeltingpotWrapper",
    "ModelBasedEnvBase",
    "MultiAction",
    "MultiStepTransform",
    "MultiThreadedEnv",
    "MultiThreadedEnvWrapper",
    "NoopResetEnv",
    "ObservationNorm",
    "ObservationTransform",
    "OpenMLEnv",
    "OpenSpielEnv",
    "OpenSpielWrapper",
    "ParallelEnv",
    "PendulumEnv",
    "PermuteTransform",
    "PettingZooEnv",
    "PettingZooWrapper",
    "PinMemoryTransform",
    "ProcessorAsyncEnvPool",
    "ProcgenEnv",
    "ProcgenWrapper",
    "R3MTransform",
    "RandomCropTensorDict",
    "RemoveEmptySpecs",
    "RenameTransform",
    "Resize",
    "Reward2GoTransform",
    "RewardClipping",
    "RewardScaling",
    "RewardSum",
    "RoboHiveEnv",
    "SMACv2Env",
    "SMACv2Wrapper",
    "SelectTransform",
    "SerialEnv",
    "SignTransform",
    "SqueezeTransform",
    "Stack",
    "StepCounter",
    "TargetReturn",
    "TensorDictPrimer",
    "ThreadingAsyncEnvPool",
    "TicTacToeEnv",
    "TimeMaxPool",
    "Timer",
    "ToTensorImage",
    "Tokenizer",
    "TrajCounter",
    "Transform",
    "TransformedEnv",
    "UnaryTransform",
    "UnityMLAgentsEnv",
    "UnityMLAgentsWrapper",
    "UnsqueezeTransform",
    "VC1Transform",
    "VIPRewardTransform",
    "VIPTransform",
    "VecGymEnvTransform",
    "VecNorm",
    "VecNormV2",
    "VmasEnv",
    "VmasWrapper",
    "check_env_specs",
    "check_marl_grouping",
    "default_info_dict_reader",
    "env_creator",
    "exploration_type",
    "gSDENoise",
    "get_available_libraries",
    "get_env_metadata",
    "gym_backend",
    "make_composite_from_td",
    "make_tensordict",
    "register_gym_spec_conversion",
    "set_exploration_type",
    "set_gym_backend",
    "step_mdp",
    "terminated_or_truncated",
]