torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,361 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ from omegaconf import MISSING
12
+ from torchrl.envs.libs.gym import set_gym_backend
13
+ from torchrl.envs.transforms.transforms import DoubleToFloat
14
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
15
+
16
+
17
@dataclass
class EnvLibsConfig(ConfigBase):
    """Common parent for every environment-library configuration in this module.

    Subclasses add the constructor arguments of a specific environment and a
    ``_target_`` naming the object to build from the config.
    """

    # Hydra's ``_partial_`` flag; left False so the target is built directly
    # rather than returned as a partial callable.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Shared post-initialization hook; intentionally does nothing here."""
25
+
26
+
27
@dataclass
class GymEnvConfig(EnvLibsConfig):
    """Structured config that builds a Gym/Gymnasium environment.

    The ``_target_`` is :func:`make_gym_env` rather than ``GymEnv`` itself:
    ``make_gym_env`` consumes ``env_name``, ``backend`` and ``from_pixels`` as
    named parameters and forwards every remaining field to ``GymEnv`` through
    ``**kwargs``. Its ``double_to_float`` option is not a field here, so it
    keeps its ``False`` default unless supplied some other way.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    categorical_action_encoding: bool = False
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int = 1
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    convert_actions_to_numpy: bool = True
    missing_obs_value: Any = None
    disable_env_checker: bool | None = None
    render_mode: str | None = None
    num_envs: int = 0
    # Passed to ``set_gym_backend`` inside ``make_gym_env``.
    backend: str = "gymnasium"
    _target_: str = "torchrl.trainers.algorithms.configs.envs_libs.make_gym_env"

    def __post_init__(self) -> None:
        """Delegate to the parent environment-libs post-init hook."""
        super().__post_init__()
50
+
51
+
52
def make_gym_env(
    env_name: str,
    backend: str | None = "gymnasium",
    from_pixels: bool = False,
    double_to_float: bool = False,
    **kwargs,
):
    """Create a Gym/Gymnasium environment.

    Args:
        env_name: Name of the environment to create.
        backend: Backend to use (``"gym"`` or ``"gymnasium"``). When ``None``,
            ``GymEnv`` is constructed without entering ``set_gym_backend``.
        from_pixels: Whether to use pixel observations.
        double_to_float: Whether to append a ``DoubleToFloat`` transform
            acting on the ``"observation"`` entry.
        **kwargs: Extra keyword arguments forwarded unchanged to ``GymEnv``.

    Returns:
        The created environment instance. When ``double_to_float`` is True,
        this is the result of ``env.append_transform(DoubleToFloat(...))``.
    """
    from contextlib import nullcontext

    from torchrl.envs.libs.gym import GymEnv

    # Pin the backend only when one is requested; a single construction site
    # avoids the two GymEnv(...) calls drifting apart.
    backend_ctx = nullcontext() if backend is None else set_gym_backend(backend)
    with backend_ctx:
        env = GymEnv(env_name, from_pixels=from_pixels, **kwargs)

    if double_to_float:
        env = env.append_transform(DoubleToFloat(in_keys=["observation"]))

    return env
82
+
83
+
84
@dataclass
class MOGymEnvConfig(EnvLibsConfig):
    """Structured config that builds a multi-objective ``MOGymEnv``.

    Unlike the plain Gym config, ``_target_`` points straight at
    ``torchrl.envs.libs.gym.MOGymEnv``, so every field below (``backend``
    included) is handed to that class.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    categorical_action_encoding: bool = False
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    convert_actions_to_numpy: bool = True
    missing_obs_value: Any = None
    # NOTE(review): no make_gym_env indirection here, so ``backend`` reaches
    # MOGymEnv as a constructor kwarg — confirm MOGymEnv accepts it.
    backend: str | None = None
    disable_env_checker: bool | None = None
    render_mode: str | None = None
    num_envs: int = 0
    _target_: str = "torchrl.envs.libs.gym.MOGymEnv"

    def __post_init__(self) -> None:
        """Delegate to the parent environment-libs post-init hook."""
        super().__post_init__()
107
+
108
+
109
@dataclass
class BraxEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.brax.BraxEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    categorical_action_encoding: bool = False
    cache_clear_frequency: int | None = None
    from_pixels: bool = False
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    requires_grad: bool = False
    _target_: str = "torchrl.envs.libs.brax.BraxEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
127
+
128
+
129
@dataclass
class DMControlEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.dm_control.DMControlEnv``.

    Both ``env_name`` and ``task_name`` are mandatory.
    """

    # Both identifiers start as OmegaConf's MISSING marker and must be set.
    env_name: str = MISSING
    task_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.dm_control.DMControlEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
146
+
147
+
148
@dataclass
class HabitatEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.habitat.HabitatEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.habitat.HabitatEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
164
+
165
+
166
@dataclass
class IsaacGymEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.isaacgym.IsaacGymEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.isaacgym.IsaacGymEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
182
+
183
+
184
@dataclass
class JumanjiEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.jumanji.JumanjiEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.jumanji.JumanjiEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
200
+
201
+
202
@dataclass
class MeltingpotEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.meltingpot.MeltingpotEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.meltingpot.MeltingpotEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
218
+
219
+
220
@dataclass
class OpenMLEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.openml.OpenMLEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.openml.OpenMLEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
236
+
237
+
238
@dataclass
class OpenSpielEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.openspiel.OpenSpielEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.openspiel.OpenSpielEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
254
+
255
+
256
@dataclass
class PettingZooEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.pettingzoo.PettingZooEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.pettingzoo.PettingZooEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
272
+
273
+
274
@dataclass
class RoboHiveEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.robohive.RoboHiveEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.robohive.RoboHiveEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
290
+
291
+
292
@dataclass
class SMACv2EnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.smacv2.SMACv2Env``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.smacv2.SMACv2Env"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
308
+
309
+
310
@dataclass
class UnityMLAgentsEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.unity_mlagents.UnityMLAgentsEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.unity_mlagents.UnityMLAgentsEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
326
+
327
+
328
@dataclass
class VmasEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.vmas.VmasEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.vmas.VmasEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
344
+
345
+
346
@dataclass
class MultiThreadedEnvConfig(EnvLibsConfig):
    """Structured config that builds a ``torchrl.envs.libs.envpool.MultiThreadedEnv``.

    Only ``env_name`` is mandatory; all other fields carry defaults.
    """

    # Mandatory: still OmegaConf's MISSING marker until the user sets it.
    env_name: str = MISSING
    from_pixels: bool = False
    pixels_only: bool = True
    frame_skip: int | None = None
    device: str = "cpu"
    batch_size: list[int] | None = None
    allow_done_after_reset: bool = False
    _target_: str = "torchrl.envs.libs.envpool.MultiThreadedEnv"

    def __post_init__(self) -> None:
        """Run the inherited environment-libs post-init hook."""
        super().__post_init__()
@@ -0,0 +1,80 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+
10
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
11
+
12
+
13
+ @dataclass
14
+ class LoggerConfig(ConfigBase):
15
+ """A class to configure a logger.
16
+
17
+ Args:
18
+ logger: The logger to use.
19
+ """
20
+
21
+ def __post_init__(self) -> None:
22
+ pass
23
+
24
+
25
+ @dataclass
26
+ class WandbLoggerConfig(LoggerConfig):
27
+ """A class to configure a Wandb logger.
28
+
29
+ .. seealso::
30
+ :class:`~torchrl.record.loggers.wandb.WandbLogger`
31
+ """
32
+
33
+ exp_name: str
34
+ offline: bool = False
35
+ save_dir: str | None = None
36
+ id: str | None = None
37
+ project: str | None = None
38
+ video_fps: int = 32
39
+ log_dir: str | None = None
40
+
41
+ _target_: str = "torchrl.record.loggers.wandb.WandbLogger"
42
+
43
+ def __post_init__(self) -> None:
44
+ pass
45
+
46
+
47
+ @dataclass
48
+ class TensorboardLoggerConfig(LoggerConfig):
49
+ """A class to configure a Tensorboard logger.
50
+
51
+ .. seealso::
52
+ :class:`~torchrl.record.loggers.tensorboard.TensorboardLogger`
53
+ """
54
+
55
+ exp_name: str
56
+ log_dir: str = "tb_logs"
57
+
58
+ _target_: str = "torchrl.record.loggers.tensorboard.TensorboardLogger"
59
+
60
+ def __post_init__(self) -> None:
61
+ pass
62
+
63
+
64
+ @dataclass
65
+ class CSVLoggerConfig(LoggerConfig):
66
+ """A class to configure a CSV logger.
67
+
68
+ .. seealso::
69
+ :class:`~torchrl.record.loggers.csv.CSVLogger`
70
+ """
71
+
72
+ exp_name: str
73
+ log_dir: str | None = None
74
+ video_format: str = "pt"
75
+ video_fps: int = 30
76
+
77
+ _target_: str = "torchrl.record.loggers.csv.CSVLogger"
78
+
79
+ def __post_init__(self) -> None:
80
+ pass