torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .brax import BraxEnv, BraxWrapper
7
+ from .dm_control import DMControlEnv, DMControlWrapper
8
+ from .envpool import MultiThreadedEnv, MultiThreadedEnvWrapper
9
+ from .gym import (
10
+ gym_backend,
11
+ GymEnv,
12
+ GymWrapper,
13
+ MOGymEnv,
14
+ MOGymWrapper,
15
+ register_gym_spec_conversion,
16
+ set_gym_backend,
17
+ )
18
+ from .habitat import HabitatEnv
19
+ from .isaac_lab import IsaacLabWrapper
20
+ from .isaacgym import IsaacGymEnv, IsaacGymWrapper
21
+ from .jumanji import JumanjiEnv, JumanjiWrapper
22
+ from .meltingpot import MeltingpotEnv, MeltingpotWrapper
23
+ from .openml import OpenMLEnv
24
+ from .openspiel import OpenSpielEnv, OpenSpielWrapper
25
+ from .pettingzoo import PettingZooEnv, PettingZooWrapper
26
+ from .procgen import ProcgenEnv, ProcgenWrapper
27
+ from .robohive import RoboHiveEnv
28
+ from .smacv2 import SMACv2Env, SMACv2Wrapper
29
+ from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper
30
+ from .vmas import VmasEnv, VmasWrapper
31
+
32
# Public API of ``torchrl.envs.libs``: every environment/wrapper class imported
# above plus the gym backend helpers. Entries are kept in ASCII-sorted order
# (upper-case names first) and written as a plain list literal so that linters
# and type checkers recognise the re-exports.
__all__ = [
    "BraxEnv",
    "BraxWrapper",
    "DMControlEnv",
    "DMControlWrapper",
    "GymEnv",
    "GymWrapper",
    "HabitatEnv",
    "IsaacGymEnv",
    "IsaacGymWrapper",
    "IsaacLabWrapper",
    "JumanjiEnv",
    "JumanjiWrapper",
    "MOGymEnv",
    "MOGymWrapper",
    "MeltingpotEnv",
    "MeltingpotWrapper",
    "MultiThreadedEnv",
    "MultiThreadedEnvWrapper",
    "OpenMLEnv",
    "OpenSpielEnv",
    "OpenSpielWrapper",
    "PettingZooEnv",
    "PettingZooWrapper",
    "ProcgenEnv",
    "ProcgenWrapper",
    "RoboHiveEnv",
    "SMACv2Env",
    "SMACv2Wrapper",
    "UnityMLAgentsEnv",
    "UnityMLAgentsWrapper",
    "VmasEnv",
    "VmasWrapper",
    "gym_backend",
    "register_gym_spec_conversion",
    "set_gym_backend",
]
@@ -0,0 +1,326 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+
9
+ import torch
10
+ from tensordict.utils import unravel_key
11
+
12
+ from torch.utils._pytree import tree_map
13
+
14
+ from torchrl._utils import implement_for
15
+ from torchrl.data.tensor_specs import Composite
16
+ from torchrl.envs import step_mdp, TransformedEnv
17
+ from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR
18
+
19
# Availability flags for the two optional backends. ``find_spec`` only locates
# each package; nothing is imported here.
_has_gym, _has_gymnasium = (
    importlib.util.find_spec(package_name, None) is not None
    for package_name in ("gym", "gymnasium")
)
21
+
22
+
23
class _BaseGymWrapper:
    """Mixin adapting a TorchRL environment to a gym/gymnasium-style interface.

    The TorchRL environment is created with ``entry_point(**kwargs)`` and, when
    ``transform`` is given, wrapped in a :class:`TransformedEnv`. ``action_space``
    and ``observation_space`` are derived from the TorchRL specs. Keys listed in
    ``info_keys`` are excluded from the observation keys (and hence from
    ``observation_space``); the gym/gymnasium subclasses in this module report
    them through the ``info`` dict instead.

    Args:
        entry_point: callable returning the TorchRL environment.
        to_numpy: stored on the instance; the subclasses in this module convert
            their outputs to numpy arrays when it is ``True``.
        transform: optional transform applied on top of the TorchRL environment.
        info_keys: str or nested-tuple keys routed to ``info`` instead of the
            observation. ``None`` means no info keys.
        **kwargs: forwarded to ``entry_point``.
    """

    def __init__(
        self, *, entry_point, to_numpy=False, transform=None, info_keys=None, **kwargs
    ):
        super().__init__()
        torchrl_env = entry_point(**kwargs)
        if transform is not None:
            torchrl_env = TransformedEnv(torchrl_env, transform)
        self.torchrl_env = torchrl_env
        # Must be set before ``_observation_keys`` is first computed below.
        self.info_keys = info_keys
        # Looked up in the instance ``__dict__`` only; defaults to True.
        categorical_action_encoding = self.torchrl_env.__dict__.get(
            "categorical_action_encoding", True
        )
        self.action_space = _torchrl_to_gym_spec_transform(
            self.torchrl_env.action_spec,
            categorical_action_encoding=categorical_action_encoding,
        )
        self.observation_space = _torchrl_to_gym_spec_transform(
            Composite(
                {
                    key: self.torchrl_env.full_observation_spec[key].clone()
                    for key in self._observation_keys
                }
            ),
            categorical_action_encoding=categorical_action_encoding,
        )
        self.to_numpy = to_numpy

    def seed(self, seed: int):
        """Seed the underlying TorchRL environment and return ``set_seed``'s result."""
        return self.torchrl_env.set_seed(seed)

    @property
    def info_keys(self):
        """List of (unravelled) keys reported through ``info``."""
        return self._info_keys

    @info_keys.setter
    def info_keys(self, value):
        if value is None:
            value = []
        self._info_keys = [unravel_key(v) for v in value]

    @staticmethod
    def _shares_info_path(key, info_paths):
        """Return True if ``key`` is an info key, lies under one, or is a parent of one.

        ``info_paths`` holds the info keys normalised to tuples. Comparing both
        paths truncated to the shorter length is equivalent to "one is a prefix
        of the other".
        """
        path = key if isinstance(key, tuple) else (key,)
        return any(path[: len(info)] == info[: len(path)] for info in info_paths)

    @property
    def _observation_keys(self):
        """Sorted observation keys exposed to gym, with all info entries removed.

        The result is computed once and cached in the instance ``__dict__``.
        Besides the info keys and everything nested under them, parents of a
        nested info key are dropped too. Otherwise the parent's sub-spec and
        sub-tensordict would re-introduce the info entry into the observation.
        Siblings of the info entry are still listed individually as nested keys.
        """
        obs_keys = self.__dict__.get("_observation_keys", None)
        if obs_keys is None:
            keys = self.torchrl_env.observation_spec.keys(True)
            info_paths = [
                info_key if isinstance(info_key, tuple) else (info_key,)
                for info_key in self.info_keys
            ]
            if info_paths:
                keys = [
                    key for key in keys if not self._shares_info_path(key, info_paths)
                ]
            obs_keys = self.__dict__["_observation_keys"] = sorted(
                keys,
                key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
            )
        return obs_keys

    @property
    def _input_keys(self):
        """Sorted state-spec keys kept between steps (cached in ``__dict__``)."""
        input_keys = self.__dict__.get("_inp_keys", None)
        if input_keys is None:
            input_keys = self.__dict__["_inp_keys"] = sorted(
                set(self.torchrl_env.state_spec.keys(True)),
                key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
            )
        return input_keys

    @property
    def _action_keys(self):
        """Sorted action-spec keys (cached in ``__dict__``)."""
        action_keys = self.__dict__.get("_act_keys", None)
        if action_keys is None:
            action_keys = self.__dict__["_act_keys"] = sorted(
                set(self.torchrl_env.full_action_spec.keys(True)),
                key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
            )
        return action_keys
122
+
123
+
124
if _has_gymnasium:
    import gymnasium

    class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper):
        """``gymnasium.Env`` facade over a TorchRL environment.

        ``step`` returns ``(observation, reward, terminated, truncated, info)``
        and ``reset`` returns ``(observation, info)``. The implementation is
        chosen per installed gymnasium version through ``implement_for``.
        Versions in ``[1.0.0, 1.1.0)`` raise ``ImportError(GYMNASIUM_1_ERROR)``.
        """

        def _maybe_to_numpy(self, out):
            """Convert every leaf of ``out`` to a numpy array when ``to_numpy`` is set."""
            if self.to_numpy:
                out = tree_map(lambda x: x.detach().cpu().numpy(), out)
            return out

        def _step_impl(self, action):
            """Shared step: write ``action``, step the TorchRL env, split the output."""
            action_keys = self._action_keys
            if len(action_keys) != 1:
                raise RuntimeError(
                    "Wrapping environments with more than one action key is not supported yet."
                )
            self._tensordict.set(action_keys[0], action)
            self.torchrl_env.step(self._tensordict)
            next_root = step_mdp(self._tensordict)
            next_data = self._tensordict.get("next")
            info = next_data.select(*self.info_keys).to_dict() if self.info_keys else {}
            observation = next_data.select(*self._observation_keys).to_dict()
            reward = self._tensordict.get(("next", "reward"))
            terminated = self._tensordict.get(("next", "terminated"))
            # Envs without a "truncated" entry are reported as never truncated.
            truncated = self._tensordict.get(
                ("next", "truncated"), torch.zeros_like(terminated)
            )
            # Only the state entries are carried over to the next step.
            self._tensordict = next_root.select(*self._input_keys)
            return self._maybe_to_numpy(
                (observation, reward, terminated, truncated, info)
            )

        def _reset_impl(self, seed, options):
            """Shared reset: optionally seed, reset with ``options``, split the output."""
            if seed is not None:
                self.torchrl_env.set_seed(seed)
            if options is None:
                options = {}
            self._tensordict = self.torchrl_env.reset(**options)
            root = self._tensordict
            info = root.select(*self.info_keys).to_dict() if self.info_keys else {}
            observation = root.select(*self._observation_keys).to_dict()
            return self._maybe_to_numpy((observation, info))

        @implement_for("gymnasium", "1.0.0", "1.1.0")
        def step(self, action):  # noqa: F811
            raise ImportError(GYMNASIUM_1_ERROR)

        @implement_for("gymnasium", None, "1.0.0")
        def step(self, action):  # noqa: F811
            return self._step_impl(action)

        @implement_for("gymnasium", "1.1.0")
        def step(self, action):  # noqa: F811
            return self._step_impl(action)

        @implement_for("gymnasium", None, "1.0.0")
        def reset(  # noqa: F811
            self, seed: int | None = None, options: dict | None = None
        ):
            return self._reset_impl(seed, options)

        @implement_for("gymnasium", "1.0.0", "1.1.0")
        def reset(self, *args, **kwargs):  # noqa: F811
            # Accept any call signature (e.g. ``reset(seed=0)``) so callers get the
            # informative GYMNASIUM_1_ERROR rather than a TypeError.
            raise ImportError(GYMNASIUM_1_ERROR)

        @implement_for("gymnasium", "1.1.0")
        def reset(  # noqa: F811
            self, seed: int | None = None, options: dict | None = None
        ):
            return self._reset_impl(seed, options)

else:

    class _TorchRLGymnasiumWrapper:
        # placeholder used when gymnasium is not installed
        def __init__(self, *args, **kwargs):
            raise ImportError("Gymnasium could not be found.")
238
+
239
+
240
if _has_gym:
    import gym

    class _TorchRLGymWrapper(gym.Env, _BaseGymWrapper):
        """``gym.Env`` facade over a TorchRL environment.

        For gym>=0.26, ``step`` returns ``(observation, reward, terminated,
        truncated, info)`` and ``reset`` returns ``(observation, info)``. For
        older gym, ``step`` returns ``(observation, reward, done, info)`` and
        ``reset`` returns only the observation. The implementation is chosen per
        installed gym version through ``implement_for``.
        """

        def _maybe_to_numpy(self, out):
            """Convert every leaf of ``out`` to a numpy array when ``to_numpy`` is set."""
            if self.to_numpy:
                out = tree_map(lambda x: x.detach().cpu().numpy(), out)
            return out

        def _step_impl(self, action, *, split_done):
            """Shared step logic.

            Args:
                action: value written under the single action key.
                split_done: if True, report ``terminated`` and ``truncated``
                    (gym>=0.26); otherwise report the single ``done`` flag.
            """
            action_keys = self._action_keys
            if len(action_keys) != 1:
                raise RuntimeError(
                    "Wrapping environments with more than one action key is not supported yet."
                )
            self._tensordict.set(action_keys[0], action)
            self.torchrl_env.step(self._tensordict)
            next_root = step_mdp(self._tensordict)
            next_data = self._tensordict.get("next")
            info = next_data.select(*self.info_keys).to_dict() if self.info_keys else {}
            observation = next_data.select(*self._observation_keys).to_dict()
            reward = self._tensordict.get(("next", "reward"))
            if split_done:
                terminated = self._tensordict.get(("next", "terminated"))
                # Envs without a "truncated" entry are reported as never truncated.
                truncated = self._tensordict.get(
                    ("next", "truncated"), torch.zeros_like(terminated)
                )
                flags = (terminated, truncated)
            else:
                flags = (self._tensordict.get(("next", "done")),)
            # Only the state entries are carried over to the next step.
            self._tensordict = next_root.select(*self._input_keys)
            return self._maybe_to_numpy((observation, reward, *flags, info))

        @implement_for("gym", "0.26", None)
        def step(self, action):  # noqa: F811
            return self._step_impl(action, split_done=True)

        @implement_for("gym", None, "0.26")
        def step(self, action):  # noqa: F811
            return self._step_impl(action, split_done=False)

        @implement_for("gym", None, "0.26")
        def reset(self):  # noqa: F811
            self._tensordict = self.torchrl_env.reset()
            observation = self._tensordict.select(*self._observation_keys).to_dict()
            return self._maybe_to_numpy(observation)

        @implement_for("gym", "0.26", None)
        def reset(  # noqa: F811
            self, seed: int | None = None, options: dict | None = None
        ):
            # gym>=0.26 defines ``reset(*, seed=None, options=None)``. Both are
            # accepted here, matching the gymnasium wrapper.
            if seed is not None:
                self.torchrl_env.set_seed(seed)
            if options is None:
                options = {}
            self._tensordict = self.torchrl_env.reset(**options)
            root = self._tensordict
            info = root.select(*self.info_keys).to_dict() if self.info_keys else {}
            observation = root.select(*self._observation_keys).to_dict()
            return self._maybe_to_numpy((observation, info))

else:

    class _TorchRLGymWrapper:
        # placeholder used when gym is not installed
        def __init__(self, *args, **kwargs):
            raise ImportError("Gym could not be found.")