torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,351 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import warnings
5
+
6
+ import torch
7
+ from tensordict import TensorDict
8
+
9
+ from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
10
+ from torchrl.envs.common import _EnvWrapper
11
+ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
12
+ from torchrl.envs.utils import _classproperty
13
+
14
__all__ = ["ProcgenWrapper", "ProcgenEnv"]

# ``import importlib`` alone does not guarantee that the ``util`` submodule is
# loaded, so import it explicitly before probing for the optional dependency.
import importlib.util

# Probe for procgen without importing it, so this module stays importable
# when the optional dependency is absent.
_has_procgen = importlib.util.find_spec("procgen") is not None

if _has_procgen:
    import procgen  # type: ignore
else:
    # Placeholder so that module-level references to ``procgen`` resolve.
    procgen = None  # type: ignore
22
+
23
+
24
def _get_procgen_envs() -> list[str]:
    """Return the game names advertised by the installed procgen package.

    The names are read from ``procgen.ENV_NAMES`` when that attribute is
    present and non-empty, otherwise from ``procgen.env.ENV_NAMES``.

    Returns:
        The list of environment names, or an empty list if neither location
        provides them.

    Raises:
        ImportError: if procgen is not installed.
    """
    if not _has_procgen:
        raise ImportError("procgen is not installed.")
    env_names = getattr(procgen, "ENV_NAMES", None)
    if env_names:
        return list(env_names)
    try:
        env_mod = importlib.import_module("procgen.env")
    except Exception:
        # Best-effort lookup: any import-time failure of the submodule simply
        # means no names could be discovered. The top-level attribute was
        # already found to be missing or empty, so there is nothing else to read.
        return []
    return list(getattr(env_mod, "ENV_NAMES", None) or [])
35
+
36
+
37
+ def _get_num_envs(env) -> int | None:
38
+ """Get the number of parallel environments from a procgen env."""
39
+ # procgen.ProcgenEnv returns a ToGymEnv wrapper; the num attribute
40
+ # may be on the wrapper, the inner env (.env), or as num_envs
41
+ return (
42
+ getattr(env, "num", None)
43
+ or getattr(env, "nenvs", None)
44
+ or getattr(env, "num_envs", None)
45
+ or getattr(getattr(env, "env", None), "num", None)
46
+ )
47
+
48
+
49
class ProcgenWrapper(_EnvWrapper):
    """OpenAI Procgen environment wrapper.

    Wraps an existing :class:`procgen.ProcgenEnv` instance and exposes it
    under the TorchRL environment API.

    This wrapper is responsible for:
        - Converting Procgen observations (``{"rgb": np.ndarray}``) to Torch tensors
        - Handling vectorized Procgen semantics
        - Producing TorchRL-compliant ``TensorDict`` outputs

    Args:
        env (procgen.ProcgenEnv): an already constructed Procgen environment.

    Keyword Args:
        device (torch.device | str, optional): device on which tensors are placed.
        batch_size (torch.Size, optional): expected batch size. When omitted and
            the number of parallel environments can be read from ``env``, it is
            set to ``torch.Size([num_envs])``.
        allow_done_after_reset (bool, optional): tolerate done right after reset.

    Attributes:
        available_envs (List[str]): list of Procgen environment ids.

    Examples:
        >>> import procgen
        >>> from torchrl.envs.libs.procgen import ProcgenWrapper
        >>> env = procgen.ProcgenEnv(4, "coinrun")
        >>> env = ProcgenWrapper(env=env)
        >>> td = env.reset()
        >>> print(td["observation"].shape)
        torch.Size([4, 3, 64, 64])
        >>> print(td["done"].shape)
        torch.Size([4, 1])
    """

    git_url = "https://github.com/openai/procgen"
    lib = procgen

    @_classproperty
    def available_envs(cls) -> list[str]:
        """Procgen game names, or an empty list when procgen is not installed."""
        if not _has_procgen:
            return []
        return _get_procgen_envs()

    def __init__(self, env, **kwargs):
        # Detect num_envs before calling parent __init__ so batch_size is set
        # before _make_specs() is called
        n = _get_num_envs(env)
        if n is not None and "batch_size" not in kwargs:
            kwargs["batch_size"] = torch.Size([n])
        super().__init__(env=env, **kwargs)

    def _check_kwargs(self, kwargs: dict) -> None:
        """Ensure the mandatory ``env`` keyword argument was supplied."""
        if "env" not in kwargs:
            raise TypeError("ProcgenWrapper requires an 'env' argument.")

    def _build_env(self, env, **_) -> procgen.ProcgenEnv:
        """Return the already constructed env unchanged; nothing is built here."""
        return env

    @property
    def observation_space(self):
        """Observation space of the wrapped env (``ob_space`` as a fallback)."""
        # gym3 uses ob_space instead of observation_space. Compare against None
        # explicitly so that a valid but falsy space object is never discarded.
        space = getattr(self._env, "observation_space", None)
        return self._env.ob_space if space is None else space

    @property
    def action_space(self):
        """Action space of the wrapped env (``ac_space`` as a fallback)."""
        # gym3 uses ac_space instead of action_space
        space = getattr(self._env, "action_space", None)
        return self._env.ac_space if space is None else space

    def _make_specs(self, env) -> None:
        """Build the observation, action, reward and done specs.

        All specs carry ``self.batch_size`` as leading dimensions. The ``env``
        argument is not used; the action space is read through
        :attr:`action_space`.
        """
        from torchrl.data.tensor_specs import Bounded

        batch_size = self.batch_size

        # Procgen observation is rgb with shape (64, 64, 3) per env
        # After permuting in _reset/_step it becomes (3, 64, 64) per env
        # With batch_size, full shape is (*batch_size, 3, 64, 64)
        self.observation_spec = Composite(
            observation=Bounded(
                low=0,
                high=255,
                shape=(*batch_size, 3, 64, 64),
                dtype=torch.uint8,
                device=self.device,
            ),
            shape=batch_size,
        )

        # Procgen has Discrete(15) action space
        with set_gym_backend("gym"):
            action_spec = _gym_to_torchrl_spec_transform(
                self.action_space,
                categorical_action_encoding=True,
                device=self.device,
            )
        # Expand action spec to include batch dimension
        if len(batch_size) > 0 and action_spec.shape[: len(batch_size)] != batch_size:
            action_spec = action_spec.expand(*batch_size, *action_spec.shape)
        self.action_spec = action_spec

        self.reward_spec = Composite(
            reward=Unbounded(
                shape=(*batch_size, 1), dtype=torch.float32, device=self.device
            ),
            shape=batch_size,
        )

        done_leaf = Categorical(
            n=2, shape=(*batch_size, 1), dtype=torch.bool, device=self.device
        )
        self.done_spec = Composite(
            done=done_leaf.clone(),
            terminated=done_leaf.clone(),
            truncated=done_leaf.clone(),
            shape=batch_size,
        )

    def _init_env(self) -> None:
        """Reset the wrapped env once; a failure is reported but not fatal."""
        # batch_size is set in __init__ before _make_specs() is called
        try:
            self._env.reset()
        except Exception as err:
            # Stay best-effort, but surface the failure instead of hiding it.
            warnings.warn(
                f"ProcgenWrapper: initial reset failed (best-effort): {err!r}"
            )

    def _set_seed(self, seed: int | None) -> None:
        """Seed the wrapped env through whichever hook it exposes.

        Tries ``seed()``, then ``set_seed()``, then assigning ``rand_seed``.
        Does nothing when ``seed`` is ``None`` and only warns if the chosen
        hook raises.
        """
        if seed is None:
            return
        try:
            if hasattr(self._env, "seed"):
                self._env.seed(seed)
            elif hasattr(self._env, "set_seed"):
                self._env.set_seed(seed)
            elif hasattr(self._env, "rand_seed"):
                self._env.rand_seed = seed
        except Exception:
            warnings.warn("ProcgenWrapper: seeding failed (best-effort).")

    def _reset(self, tensordict=None, **kwargs) -> TensorDict:
        """Reset the wrapped env and return the first observation.

        The returned tensordict holds the channel-first ``"observation"`` and
        all-``False`` ``"done"``, ``"terminated"`` and ``"truncated"`` flags.
        """
        obs = self._env.reset()
        # Some reset APIs return (obs, info); keep only the observation.
        if isinstance(obs, (tuple, list)):
            obs = obs[0]

        # Channel-last to channel-first: (N, H, W, C) -> (N, C, H, W)
        rgb = torch.from_numpy(obs["rgb"]).to(self.device).permute(0, 3, 1, 2)

        td = TensorDict(
            {"observation": rgb},
            batch_size=self.batch_size,
            device=self.device,
        )

        # Set done flags (required by TorchRL)
        zeros = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.bool)
        td.set("done", zeros)
        td.set("terminated", zeros.clone())
        td.set("truncated", zeros.clone())

        return td

    def _step(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        """Apply ``tensordict["action"]`` to the wrapped env for one step.

        Returns a tensordict with ``"observation"``, ``"reward"``, ``"done"``,
        ``"terminated"`` (a copy of ``"done"``) and ``"truncated"`` (always
        ``False``), plus the entries of ``info`` when it is a non-empty dict.
        """
        action = tensordict.get("action")
        # Procgen expects numpy arrays with shape (num_envs,)
        action_np = action.cpu().numpy().flatten()
        obs, reward, done, info = self._env.step(action_np)

        # Channel-last to channel-first: (N, H, W, C) -> (N, C, H, W)
        rgb = torch.from_numpy(obs["rgb"]).to(self.device).permute(0, 3, 1, 2)
        # Pin the dtype so the reward always matches the float32 reward_spec.
        reward = torch.as_tensor(
            reward, dtype=torch.float32, device=self.device
        ).view(-1, 1)
        done = torch.as_tensor(done, device=self.device).view(-1, 1).bool()

        td = TensorDict(
            {
                "observation": rgb,
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "truncated": torch.zeros_like(done),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

        # Expose info dict fields (e.g., level_seed, prev_level_complete)
        # Note: only a single dict is exposed; when procgen returns info as a
        # list of dicts it is ignored here.
        if info and isinstance(info, dict):
            for key, val in info.items():
                td.set(key, torch.as_tensor(val, device=self.device))

        return td
245
+
246
+
247
class ProcgenEnv(ProcgenWrapper):
    """OpenAI Procgen environment.

    Convenience class that constructs a Procgen environment by name.

    See https://github.com/openai/procgen for more details on Procgen.

    Args:
        env_name (str): name of the Procgen game (e.g. ``"coinrun"``).
            Available games: bigfish, bossfight, caveflyer, chaser, climber,
            coinrun, dodgeball, fruitbot, heist, jumper, leaper, maze, miner,
            ninja, plunder, starpilot.

    Keyword Args:
        num_envs (int, optional): number of parallel environments. Defaults to 1.
        distribution_mode (str, optional): Procgen distribution mode. One of
            ``"easy"``, ``"hard"``, ``"extreme"``, ``"memory"``, ``"exploration"``.
            Defaults to ``"hard"``.
        start_level (int, optional): the level id to start from. Defaults to 0.
        num_levels (int, optional): the number of unique levels that can be
            generated. Set to 0 for unlimited levels. Defaults to 0.
        use_sequential_levels (bool, optional): if ``True``, levels are played
            sequentially rather than randomly. Defaults to ``False``.
        center_agent (bool, optional): if ``True``, observations are centered
            on the agent. Defaults to ``True``.
        use_backgrounds (bool, optional): if ``True``, include background
            assets. Defaults to ``True``.
        use_monochrome_assets (bool, optional): if ``True``, use monochrome
            assets for simpler visuals. Defaults to ``False``.
        restrict_themes (bool, optional): if ``True``, restrict visual themes.
            Defaults to ``False``.
        use_generated_assets (bool, optional): if ``True``, use procedurally
            generated assets. Defaults to ``False``.
        paint_vel_info (bool, optional): if ``True``, paint velocity info on
            observations. Defaults to ``False``.
        seed (int, optional): random seed for the environment. It is forwarded
            to procgen as ``rand_seed``. Note that procgen environments must be
            seeded at construction time; calling ``set_seed()`` after
            construction will not work reliably.
        render_mode (str, optional): render mode for the environment.
        device (torch.device | str, optional): device for tensors.
        allow_done_after_reset (bool, optional): tolerate done after reset.

    Raises:
        ImportError: if the procgen package is not installed.
        ValueError: if ``env_name`` is not among the discovered game names.

    Examples:
        >>> from torchrl.envs.libs.procgen import ProcgenEnv
        >>> env = ProcgenEnv("coinrun", num_envs=8)
        >>> td = env.reset()
        >>> print(td["observation"].shape)
        torch.Size([8, 3, 64, 64])
        >>> print(td["done"].shape)
        torch.Size([8, 1])
        >>> print(env.available_envs)
        ['bigfish', 'bossfight', 'caveflyer', 'chaser', 'climber', 'coinrun', 'dodgeball', 'fruitbot', 'heist', 'jumper', 'leaper', 'maze', 'miner', 'ninja', 'plunder', 'starpilot']
    """

    def __init__(self, env_name: str, **kwargs):
        if not _has_procgen:
            raise ImportError(
                "procgen python package was not found. "
                "Install it from https://github.com/openai/procgen."
            )

        # Resolve the name list once. Validate only when names were actually
        # discovered; an empty list means discovery failed, and rejecting
        # every game with "Available envs: []" would be misleading.
        available = self.available_envs
        if available and env_name not in available:
            raise ValueError(
                f"Unknown Procgen environment '{env_name}'. "
                f"Available envs: {available}"
            )

        num_envs = kwargs.pop("num_envs", 1)
        # Procgen uses rand_seed for seeding at construction time
        seed = kwargs.pop("seed", None)
        if seed is not None:
            kwargs["rand_seed"] = seed

        # Extract procgen-specific kwargs before passing to parent
        procgen_keys = (
            "distribution_mode",
            "start_level",
            "num_levels",
            "use_sequential_levels",
            "center_agent",
            "use_backgrounds",
            "use_monochrome_assets",
            "restrict_themes",
            "use_generated_assets",
            "paint_vel_info",
            "render_mode",
            "rand_seed",
        )
        procgen_kwargs = {
            key: kwargs.pop(key) for key in procgen_keys if key in kwargs
        }
        env = procgen.ProcgenEnv(num_envs, env_name, **procgen_kwargs)

        # Pass batch_size to parent; it will be set before _make_specs()
        kwargs.setdefault("batch_size", torch.Size([num_envs]))
        super().__init__(env=env, **kwargs)