torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .common import ModelBasedEnvBase
7
+ from .dreamer import DreamerDecoder, DreamerEnv
8
+
9
# Public API of the model-based envs subpackage.
__all__ = [
    "ModelBasedEnvBase",
    "DreamerDecoder",
    "DreamerEnv",
]
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+ import warnings
9
+
10
+ import torch
11
+ from tensordict import TensorDict
12
+ from tensordict.nn import TensorDictModule
13
+ from torchrl.data.utils import DEVICE_TYPING
14
+ from torchrl.envs.common import EnvBase
15
+
16
+
17
class ModelBasedEnvBase(EnvBase):
    """Basic environment for model-based RL algorithms.

    Wrapper around the world model of an MBRL algorithm. It gives an
    environment interface to a world model (including but not limited to
    observation, reward, done-state and safety-constraint models) so that it
    can be used like a classical environment.

    This is a base class for other environments and should not be used
    directly: subclasses must implement :meth:`_reset` and set the
    environment specs.

    Example:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.data import Composite, Unbounded
        >>> class MyMBEnv(ModelBasedEnvBase):
        ...     def __init__(self, world_model, device="cpu", batch_size=None):
        ...         super().__init__(world_model, device=device, batch_size=batch_size)
        ...         self.observation_spec = Composite(
        ...             hidden_observation=Unbounded((4,))
        ...         )
        ...         self.state_spec = Composite(
        ...             hidden_observation=Unbounded((4,)),
        ...         )
        ...         self.action_spec = Unbounded((1,))
        ...         self.reward_spec = Unbounded((1,))
        ...
        ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
        ...         tensordict = TensorDict(
        ...             batch_size=self.batch_size,
        ...             device=self.device,
        ...         )
        ...         tensordict = tensordict.update(self.state_spec.rand())
        ...         tensordict = tensordict.update(self.observation_spec.rand())
        ...         return tensordict
        >>> # This environment is used as follows:
        >>> import torch.nn as nn
        >>> from torchrl.modules import MLP, WorldModelWrapper
        >>> world_model = WorldModelWrapper(
        ...     TensorDictModule(
        ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
        ...         in_keys=["hidden_observation", "action"],
        ...         out_keys=["hidden_observation"],
        ...     ),
        ...     TensorDictModule(
        ...         nn.Linear(4, 1),
        ...         in_keys=["hidden_observation"],
        ...         out_keys=["reward"],
        ...     ),
        ... )
        >>> env = MyMBEnv(world_model)
        >>> tensordict = env.rollout(max_steps=10)
        >>> print(tensordict)
        TensorDict(
            fields={
                action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
                done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
                next: LazyStackedTensorDict(
                    fields={
                        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
                    batch_size=torch.Size([10]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)


    Properties:
        observation_spec (Composite): sampling spec of the observations;
        action_spec (TensorSpec): sampling spec of the actions;
        reward_spec (TensorSpec): sampling spec of the rewards;
        input_spec (Composite): sampling spec of the inputs;
        batch_size (torch.Size): batch_size to be used by the env. If not set, the env accepts tensordicts of all batch sizes.
        device (torch.device): device where the env input and output are expected to live

    Args:
        world_model (TensorDictModule): model that generates world states and their corresponding rewards.
            It is moved to ``device`` at construction time.
        params (list[torch.Tensor], optional): parameters passed to ``world_model`` as the ``params``
            keyword argument in :meth:`_step`. If ``None`` (default), the world model is called
            with the tensordict only.
        buffers (list[torch.Tensor], optional): buffers passed to ``world_model`` as the ``buffers``
            keyword argument. Only used when ``params`` is not ``None``.
        device (torch.device, optional): device where the env input and output are expected to live.
            Defaults to ``"cpu"``.
        batch_size (torch.Size, optional): number of environments contained in the instance.
        run_type_checks (bool, optional): whether to run type checks on the step of the env.
            Defaults to ``False``.
        allow_done_after_reset (bool, optional): whether a done state is tolerated right after
            :meth:`reset`; forwarded to :class:`EnvBase`. Defaults to ``False``.

    Methods:
        step (TensorDict -> TensorDict): step in the environment
        reset (TensorDict, optional -> TensorDict): reset the environment
        set_seed (int -> int): sets the seed of the environment
        rand_step (TensorDict, optional -> TensorDict): random step given the action spec
        rollout (Callable, ... -> TensorDict): executes a rollout in the environment with the given policy (or random
            steps if no policy is provided)

    """

    def __init__(
        self,
        world_model: TensorDictModule,
        params: list[torch.Tensor] | None = None,
        buffers: list[torch.Tensor] | None = None,
        device: DEVICE_TYPING = "cpu",
        batch_size: torch.Size | None = None,
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )
        # Keep the world model on the env device so that _step runs it where
        # the env inputs live.
        self.world_model = world_model.to(self.device)
        self.world_model_params = params
        self.world_model_buffers = buffers

    @classmethod
    def __new__(cls, *args, **kwargs):
        # Model-based envs are created with in-place updates disabled and
        # without batch locking (see ``batch_size`` in the class docstring).
        # NOTE(review): because of ``@classmethod``, ``args`` also starts with
        # the class itself; this relies on ``EnvBase.__new__`` ignoring its
        # positional arguments — confirm before changing the decorator.
        return super().__new__(
            cls, *args, _inplace_update=False, _batch_locked=False, **kwargs
        )

    def set_specs_from_env(self, env: EnvBase):
        """Sets the specs of the environment from the specs of the given environment.

        The input and output specs of ``env`` are cloned (and moved to this
        env's device when one is set), so the two environments do not share
        spec objects.

        Args:
            env (EnvBase): the environment whose specs are copied.
        """
        device = self.device
        output_spec = env.output_spec.clone()
        input_spec = env.input_spec.clone()
        if device is not None:
            output_spec = output_spec.to(device)
            input_spec = input_spec.to(device)
        # Specs are written straight into the instance dict, presumably to
        # bypass the EnvBase spec setters — TODO confirm.
        self.__dict__["_output_spec"] = output_spec
        self.__dict__["_input_spec"] = input_spec
        # Presumably invalidates cached data derived from the previous specs.
        self.empty_cache()

    def _step(
        self,
        tensordict: TensorDict,
    ) -> TensorDict:
        """Runs the world model on a shallow copy of ``tensordict``.

        Returns only the observation, done and reward entries declared by the
        env specs; keys missing from the model output are skipped.
        """
        # step method requires to be immutable
        tensordict_out = tensordict.clone(recurse=False)
        # Compute world state
        if self.world_model_params is not None:
            tensordict_out = self.world_model(
                tensordict_out,
                params=self.world_model_params,
                buffers=self.world_model_buffers,
            )
        else:
            tensordict_out = self.world_model(tensordict_out)
        # done can be missing, it will be filled by `step`
        # Convert to list for torch.compile compatibility (dynamo can't unpack _CompositeSpecKeysView)
        keys_to_select = (
            list(self.observation_spec.keys())
            + list(self.full_done_spec.keys())
            + list(self.full_reward_spec.keys())
        )
        tensordict_out = tensordict_out.select(*keys_to_select, strict=False)
        return tensordict_out

    @abc.abstractmethod
    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        """Returns the initial state of the environment; must be implemented by subclasses."""
        raise NotImplementedError

    def _set_seed(self, seed: int | None) -> None:
        """Does not seed anything; only emits a warning."""
        warnings.warn("Set seed isn't needed for model based environments")
@@ -0,0 +1,112 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+ from tensordict import TensorDict
9
+ from tensordict.nn import TensorDictModule
10
+ from torchrl.data.tensor_specs import Composite
11
+ from torchrl.data.utils import DEVICE_TYPING
12
+ from torchrl.envs.common import EnvBase
13
+ from torchrl.envs.model_based import ModelBasedEnvBase
14
+ from torchrl.envs.transforms.transforms import Transform
15
+
16
+
17
class DreamerEnv(ModelBasedEnvBase):
    """Dreamer simulation environment.

    This environment is used for imagination rollouts in Dreamer training.
    It never terminates (done is always False) since imagination runs for a
    fixed horizon. The done-checking methods are overridden to avoid CUDA
    synchronization overhead from Python control flow on CUDA tensors.

    Args:
        world_model (TensorDictModule): the world model run at each step.
        prior_shape (tuple of int): shape of the prior state, stored as
            ``self.prior_shape``.
        belief_shape (tuple of int): shape of the belief state, stored as
            ``self.belief_shape``.
        obs_decoder (TensorDictModule, optional): module used by
            :meth:`decode_obs` to decode observations. Defaults to ``None``.
        device (torch.device, optional): device of the env. Defaults to ``"cpu"``.
        batch_size (torch.Size, optional): batch size of the env.
    """

    def __init__(
        self,
        world_model: TensorDictModule,
        prior_shape: tuple[int, ...],
        belief_shape: tuple[int, ...],
        obs_decoder: TensorDictModule | None = None,
        device: DEVICE_TYPING = "cpu",
        batch_size: torch.Size | None = None,
    ):
        super().__init__(
            world_model,
            device=device,
            batch_size=batch_size,
            # Skip done validation in reset() — imagination never terminates.
            allow_done_after_reset=True,
        )
        self.obs_decoder = obs_decoder
        self.prior_shape = prior_shape
        self.belief_shape = belief_shape

    def any_done(self, tensordict) -> bool:
        """Returns False — imagination rollouts never terminate.

        Overridden to avoid CUDA sync from `done.any()` in parent class.
        """
        return False

    def maybe_reset(self, tensordict):
        """No-op — imagination rollouts don't need partial resets.

        Overridden to avoid CUDA sync from done checks in parent class.
        """
        return tensordict

    def set_specs_from_env(self, env: EnvBase):
        """Sets the specs of the environment from the specs of the given environment.

        On top of the base behavior, the state spec is rebuilt from the
        ``"state"`` and ``"belief"`` entries of the observation spec, with the
        batch size of ``env``.
        """
        super().set_specs_from_env(env)
        self.action_spec = self.action_spec.to(self.device)
        self.state_spec = Composite(
            state=self.observation_spec["state"],
            belief=self.observation_spec["belief"],
            shape=env.batch_size,
        )

    def _reset(self, tensordict=None, **kwargs) -> TensorDict:
        """Returns the starting point of an imagination rollout.

        If ``tensordict`` is given, a clone of it is returned as-is. Otherwise,
        a random state, action, reward and observation are sampled from the
        specs and placed on the env device.
        """
        if tensordict is not None:
            return tensordict.clone()

        batch_size = []
        td = self.state_spec.rand(shape=batch_size)
        # NOTE(review): actions are sampled at random here instead of reusing
        # the actions taken at those steps.
        td.set("action", self.action_spec.rand(shape=batch_size))
        td[("next", "reward")] = self.reward_spec.rand(shape=batch_size)
        td.update(self.observation_spec.rand(shape=batch_size))
        device = self.device
        if device is not None:
            td = td.to(device, non_blocking=True)
            # Wait for the non-blocking copy to complete before returning.
            if torch.cuda.is_available() and device.type == "cpu":
                torch.cuda.synchronize()
            elif torch.backends.mps.is_available():
                torch.mps.synchronize()
        return td

    def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDict:
        """Decodes observations with ``obs_decoder``.

        Args:
            tensordict (TensorDict): input passed to the decoder.
            compute_latents (bool, optional): if ``True``, the world model is
                run on ``tensordict`` first. Defaults to ``False``.

        Raises:
            ValueError: if no observation decoder was provided.
        """
        if self.obs_decoder is None:
            raise ValueError("No observation decoder provided")
        if compute_latents:
            tensordict = self.world_model(tensordict)
        return self.obs_decoder(tensordict)
95
+
96
+
97
class DreamerDecoder(Transform):
    """A transform to record the decoded observations in Dreamer.

    On every step and reset, the ``obs_decoder`` of the parent's base env
    (expected to be a :class:`DreamerEnv`) is applied to the output
    tensordict. The observation spec is left unchanged.

    Examples:
        >>> model_based_env = DreamerEnv(...)
        >>> model_based_env_eval = model_based_env.append_transform(DreamerDecoder())
    """

    def _get_obs_decoder(self):
        """Returns the base env's observation decoder, raising if none was provided."""
        obs_decoder = self.parent.base_env.obs_decoder
        if obs_decoder is None:
            # Same error as DreamerEnv.decode_obs, instead of an opaque
            # "'NoneType' object is not callable".
            raise ValueError("No observation decoder provided")
        return obs_decoder

    def _call(self, next_tensordict):
        return self._get_obs_decoder()(next_tensordict)

    def _reset(self, tensordict, tensordict_reset):
        return self._call(tensordict_reset)

    def transform_observation_spec(self, observation_spec):
        return observation_spec
@@ -0,0 +1,147 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .gym_transforms import EndOfLifeTransform
7
+ from .module import ModuleTransform
8
+ from .r3m import R3MTransform
9
+ from .ray_service import RayTransform
10
+ from .rb_transforms import MultiStepTransform
11
+ from .transforms import (
12
+ ActionDiscretizer,
13
+ ActionMask,
14
+ AutoResetEnv,
15
+ AutoResetTransform,
16
+ BatchSizeTransform,
17
+ BinarizeReward,
18
+ BurnInTransform,
19
+ CatFrames,
20
+ CatTensors,
21
+ CenterCrop,
22
+ ClipTransform,
23
+ Compose,
24
+ ConditionalPolicySwitch,
25
+ ConditionalSkip,
26
+ Crop,
27
+ DeviceCastTransform,
28
+ DiscreteActionProjection,
29
+ DoubleToFloat,
30
+ DTypeCastTransform,
31
+ ExcludeTransform,
32
+ FiniteTensorDictCheck,
33
+ FlattenObservation,
34
+ FrameSkipTransform,
35
+ GrayScale,
36
+ gSDENoise,
37
+ Hash,
38
+ InitTracker,
39
+ LineariseRewards,
40
+ MultiAction,
41
+ NoopResetEnv,
42
+ ObservationNorm,
43
+ ObservationTransform,
44
+ PermuteTransform,
45
+ PinMemoryTransform,
46
+ RandomCropTensorDict,
47
+ RemoveEmptySpecs,
48
+ RenameTransform,
49
+ Resize,
50
+ Reward2GoTransform,
51
+ RewardClipping,
52
+ RewardScaling,
53
+ RewardSum,
54
+ SelectTransform,
55
+ SignTransform,
56
+ SqueezeTransform,
57
+ Stack,
58
+ StepCounter,
59
+ TargetReturn,
60
+ TensorDictPrimer,
61
+ TimeMaxPool,
62
+ Timer,
63
+ Tokenizer,
64
+ ToTensorImage,
65
+ TrajCounter,
66
+ Transform,
67
+ TransformedEnv,
68
+ UnaryTransform,
69
+ UnsqueezeTransform,
70
+ VecGymEnvTransform,
71
+ VecNorm,
72
+ )
73
+ from .vc1 import VC1Transform
74
+ from .vecnorm import VecNormV2
75
+ from .vip import VIPRewardTransform, VIPTransform
76
+
77
# Public API of this package, i.e. what ``from ... import *`` exposes.
# Every name imported above appears here exactly once, in ASCII-sorted order
# (uppercase before lowercase, hence ``gSDENoise`` last). When adding a new
# import, add its name here as well.
__all__ = [
    "ActionDiscretizer",
    "ActionMask",
    "AutoResetEnv",
    "AutoResetTransform",
    "BatchSizeTransform",
    "BinarizeReward",
    "BurnInTransform",
    "CatFrames",
    "CatTensors",
    "CenterCrop",
    "ClipTransform",
    "Compose",
    "ConditionalPolicySwitch",
    "ConditionalSkip",
    "Crop",
    "DTypeCastTransform",
    "DeviceCastTransform",
    "DiscreteActionProjection",
    "DoubleToFloat",
    "EndOfLifeTransform",
    "ExcludeTransform",
    "FiniteTensorDictCheck",
    "FlattenObservation",
    "FrameSkipTransform",
    "GrayScale",
    "Hash",
    "InitTracker",
    "LineariseRewards",
    "ModuleTransform",
    "MultiAction",
    "MultiStepTransform",
    "NoopResetEnv",
    "ObservationNorm",
    "ObservationTransform",
    "PermuteTransform",
    "PinMemoryTransform",
    "R3MTransform",
    "RandomCropTensorDict",
    "RayTransform",
    "RemoveEmptySpecs",
    "RenameTransform",
    "Resize",
    "Reward2GoTransform",
    "RewardClipping",
    "RewardScaling",
    "RewardSum",
    "SelectTransform",
    "SignTransform",
    "SqueezeTransform",
    "Stack",
    "StepCounter",
    "TargetReturn",
    "TensorDictPrimer",
    "TimeMaxPool",
    "Timer",
    "ToTensorImage",
    "Tokenizer",
    "TrajCounter",
    "Transform",
    "TransformedEnv",
    "UnaryTransform",
    "UnsqueezeTransform",
    "VC1Transform",
    "VIPRewardTransform",
    "VIPTransform",
    "VecGymEnvTransform",
    "VecNorm",
    "VecNormV2",
    "gSDENoise",
]
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from torch import Tensor
8
+
9
+
10
+ # copied from torchvision
11
def _get_image_num_channels(img: Tensor) -> int:
    """Return the number of channels of an image tensor.

    A 2-D tensor counts as a single-channel image. For tensors with more
    dimensions, the channel count is read from dimension ``-3``.

    Raises:
        TypeError: if ``img`` has fewer than 2 dimensions.
    """
    if img.ndim < 2:
        raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
    return 1 if img.ndim == 2 else img.shape[-3]
18
+
19
+
20
def _assert_channels(img: Tensor, permitted: list[int]) -> None:
    """Check that ``img`` has an allowed number of channels.

    Args:
        img: image tensor whose channel count is determined by
            ``_get_image_num_channels``.
        permitted: channel counts that are accepted.

    Raises:
        TypeError: if the channel count of ``img`` is not in ``permitted``.
    """
    num_channels = _get_image_num_channels(img)
    if num_channels in permitted:
        return
    raise TypeError(
        f"Input image tensor permitted channel values are {permitted}, but found "
        f"{num_channels} (full shape: {img.shape})"
    )
27
+
28
+
29
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Turns an RGB image into grayscale.

    The luminance is ``0.2989 * R + 0.587 * G + 0.114 * B``. It is cast back
    to the input dtype with ``.to(img.dtype)``, so integer inputs are
    truncated rather than rounded.

    Args:
        img: image tensor with its channels on dimension ``-3``. It must have
            at least 3 dimensions and exactly 3 channels.
        num_output_channels: ``1`` returns a single luminance channel; ``3``
            replicates it across three channels.

    Returns:
        A tensor whose dimension ``-3`` has size 1 when
        ``num_output_channels == 1``. Otherwise, a tensor with the same shape
        as ``img``. The 3-channel result is an expanded view (``expand``), so
        its channels share storage rather than being copies.

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or does not have
            exactly 3 channels.
        ValueError: if ``num_output_channels`` is neither 1 nor 3.
    """
    if img.ndim < 3:
        # The original message was built by implicit string concatenation and
        # lacked a separating space ("...but found2").
        raise TypeError(
            f"Input image tensor should have at least 3 dimensions, but found {img.ndim}"
        )
    _assert_channels(img, [3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    r, g, b = img.unbind(dim=-3)
    # The weighted sum is computed in floating point, then cast back.
    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
    l_img = l_img.unsqueeze(dim=-3)

    if num_output_channels == 3:
        return l_img.expand(img.shape)

    return l_img