torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/custom/tictactoeenv.py
@@ -0,0 +1,288 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch
+ from tensordict import TensorDict, TensorDictBase
+ from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
+ from torchrl.envs.common import EnvBase
+
+
+ class TicTacToeEnv(EnvBase):
+     """A Tic-Tac-Toe implementation.
+
+     Keyword Args:
+         single_player (bool, optional): whether one or two players have to be
+             accounted for. ``single_player=True`` means that ``"player1"`` is
+             playing randomly. If ``False`` (default), at each turn,
+             one of the two players has to play.
+         device (torch.device, optional): the device where to put the tensors.
+             Defaults to ``None`` (default device).
+
+     The environment is stateless. To run it across multiple batches, call
+
+         >>> env.reset(TensorDict(batch_size=desired_batch_size))
+
+     If the ``"mask"`` entry is present, ``rand_action`` takes it into account to
+     generate the next action. Any policy executed on this env should take this
+     mask into account, as well as the turn of the player (stored in the ``"turn"``
+     output entry).
+
+     Specs:
+         >>> print(env.specs)
+         Composite(
+             output_spec: Composite(
+                 full_observation_spec: Composite(
+                     board: Categorical(
+                         shape=torch.Size([3, 3]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.int32,
+                         domain=discrete),
+                     turn: Categorical(
+                         shape=torch.Size([1]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.int32,
+                         domain=discrete),
+                     mask: Categorical(
+                         shape=torch.Size([9]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.bool,
+                         domain=discrete),
+                     shape=torch.Size([])),
+                 full_reward_spec: Composite(
+                     player0: Composite(
+                         reward: UnboundedContinuous(
+                             shape=torch.Size([1]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             dtype=torch.float32,
+                             domain=continuous),
+                         shape=torch.Size([])),
+                     player1: Composite(
+                         reward: UnboundedContinuous(
+                             shape=torch.Size([1]),
+                             space=ContinuousBox(
+                                 low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
+                                 high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
+                             dtype=torch.float32,
+                             domain=continuous),
+                         shape=torch.Size([])),
+                     shape=torch.Size([])),
+                 full_done_spec: Composite(
+                     done: Categorical(
+                         shape=torch.Size([1]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.bool,
+                         domain=discrete),
+                     terminated: Categorical(
+                         shape=torch.Size([1]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.bool,
+                         domain=discrete),
+                     truncated: Categorical(
+                         shape=torch.Size([1]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.bool,
+                         domain=discrete),
+                     shape=torch.Size([])),
+                 shape=torch.Size([])),
+             input_spec: Composite(
+                 full_state_spec: Composite(
+                     board: Categorical(
+                         shape=torch.Size([3, 3]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.int32,
+                         domain=discrete),
+                     turn: Categorical(
+                         shape=torch.Size([1]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.int32,
+                         domain=discrete),
+                     mask: Categorical(
+                         shape=torch.Size([9]),
+                         space=DiscreteBox(n=2),
+                         dtype=torch.bool,
+                         domain=discrete), shape=torch.Size([])),
+                 full_action_spec: Composite(
+                     action: Categorical(
+                         shape=torch.Size([1]),
+                         space=DiscreteBox(n=9),
+                         dtype=torch.int64,
+                         domain=discrete),
+                     shape=torch.Size([])),
+                 shape=torch.Size([])),
+             shape=torch.Size([]))
+
+     To run a dummy rollout, execute the following command:
+
+     Examples:
+         >>> env = TicTacToeEnv()
+         >>> env.rollout(10)
+         TensorDict(
+             fields={
+                 action: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                 board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False),
+                 done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False),
+                 next: TensorDict(
+                     fields={
+                         board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False),
+                         done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False),
+                         player0: TensorDict(
+                             fields={
+                                 reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                             batch_size=torch.Size([9]),
+                             device=None,
+                             is_shared=False),
+                         player1: TensorDict(
+                             fields={
+                                 reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                             batch_size=torch.Size([9]),
+                             device=None,
+                             is_shared=False),
+                         terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                         turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
+                     batch_size=torch.Size([9]),
+                     device=None,
+                     is_shared=False),
+                 terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
+             batch_size=torch.Size([9]),
+             device=None,
+             is_shared=False)
+
+     """
+
+     # batch_locked is set to False since various batch sizes can be provided to the env
+     batch_locked: bool = False
+
+     def __init__(self, *, single_player: bool = False, device=None):
+         super().__init__(device=device)
+         self.single_player = single_player
+         self.action_spec: Unbounded = Categorical(
+             n=9,
+             shape=(),
+             device=device,
+         )
+
+         self.full_observation_spec: Composite = Composite(
+             board=Unbounded(shape=(3, 3), dtype=torch.int, device=device),
+             turn=Categorical(
+                 2,
+                 shape=(1,),
+                 dtype=torch.int,
+                 device=device,
+             ),
+             mask=Categorical(
+                 2,
+                 shape=(9,),
+                 dtype=torch.bool,
+                 device=device,
+             ),
+             device=device,
+         )
+         self.state_spec: Composite = self.observation_spec.clone()
+
+         self.reward_spec: Unbounded = Composite(
+             {
+                 ("player0", "reward"): Unbounded(shape=(1,), device=device),
+                 ("player1", "reward"): Unbounded(shape=(1,), device=device),
+             },
+             device=device,
+         )
+
+         self.full_done_spec: Categorical = Composite(
+             done=Categorical(2, shape=(1,), dtype=torch.bool, device=device),
+             device=device,
+         )
+         self.full_done_spec["terminated"] = self.full_done_spec["done"].clone()
+         self.full_done_spec["truncated"] = self.full_done_spec["done"].clone()
+
+     def _reset(self, reset_td: TensorDict) -> TensorDict:
+         shape = reset_td.shape if reset_td is not None else ()
+         state = self.state_spec.zero(shape)
+         state["board"] -= 1
+         state["mask"].fill_(True)
+         return state.update(self.full_done_spec.zero(shape))
+
+     def _step(self, state: TensorDict) -> TensorDict:
+         board = state["board"].clone()
+         turn = state["turn"].clone()
+         action = state["action"]
+         board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
+         wins = self.win(board, action)
+
+         mask = board.flatten(-2, -1) == -1
+         done = wins | ~mask.any(-1, keepdim=True)
+         terminated = done.clone()
+
+         reward_0 = wins & (turn == 0)
+         reward_1 = wins & (turn == 1)
+
+         state = TensorDict(
+             {
+                 "done": done,
+                 "terminated": terminated,
+                 ("player0", "reward"): reward_0.float(),
+                 ("player1", "reward"): reward_1.float(),
+                 "board": torch.where(board == -1, board, 1 - board),
+                 "turn": 1 - turn,
+                 "mask": mask,
+             },
+             batch_size=state.batch_size,
+         )
+         if self.single_player:
+             select = (~done & (turn == 0)).squeeze(-1)
+             if select.all():
+                 state_select = state
+             elif select.any():
+                 state_select = state[select]
+             else:
+                 return state
+             state_select = self._step(self.rand_action(state_select))
+             if select.all():
+                 return state_select
+             return torch.where(done, state, state_select)
+         return state
+
+     def _set_seed(self, seed: int | None) -> None:
+         ...
+
+     @staticmethod
+     def win(board: torch.Tensor, action: torch.Tensor):
+         row = action // 3  # type: ignore
+         col = action % 3  # type: ignore
+         if board[..., row, :].sum() == 3:
+             return True
+         if board[..., col].sum() == 3:
+             return True
+         if board.diagonal(0, -2, -1).sum() == 3:
+             return True
+         if board.flip(-1).diagonal(0, -2, -1).sum() == 3:
+             return True
+         return False
+
+     @staticmethod
+     def full(board: torch.Tensor) -> bool:
+         return torch.sym_int(board.abs().sum()) == 9
+
+     @staticmethod
+     def get_action_mask():
+         pass
+
+     def rand_action(self, tensordict: TensorDictBase | None = None):
+         mask = tensordict.get("mask")
+         action_spec = self.action_spec
+         if tensordict.ndim:
+             action_spec = action_spec.expand(tensordict.shape)
+         else:
+             action_spec = action_spec.clone()
+         action_spec.update_mask(mask)
+         tensordict.set(self.action_key, action_spec.rand())
+         return tensordict
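
For orientation, here is a minimal usage sketch of the TicTacToeEnv file added above. It is not part of the diff: the import path is taken from the file list, and it relies on EnvBase.rollout falling back to the mask-aware rand_action defined at the end of the class when no policy is passed.

    from torchrl.envs.custom.tictactoeenv import TicTacToeEnv

    env = TicTacToeEnv(single_player=True)  # "player1" moves are sampled by the env itself
    td = env.reset()                        # empty board (all -1) and a fully open action mask
    rollout = env.rollout(10)               # random rollout; rand_action honours the "mask" entry
    print(rollout["next", "player0", "reward"].sum())  # total reward collected by player0
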
torchrl/envs/env_creator.py
@@ -0,0 +1,263 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ from collections import OrderedDict
+ from collections.abc import Callable
+ from multiprocessing.sharedctypes import Synchronized
+ from multiprocessing.synchronize import Lock, RLock
+
+ import torch
+ from tensordict import TensorDictBase
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.data.utils import CloudpickleWrapper
+ from torchrl.envs.common import EnvBase, EnvMetaData
+
+
+ class EnvCreator:
+     """Environment creator class.
+
+     EnvCreator is a generic environment creator class that can substitute
+     lambda functions when creating environments in multiprocessing contexts.
+     If the environment created on a subprocess must share information with the
+     main process (e.g. for the VecNorm transform), EnvCreator will pass the
+     pointers to the tensordicts in shared memory to each process such that
+     all of them are synchronised.
+
+     Args:
+         create_env_fn (callable): a callable that returns an EnvBase
+             instance.
+         create_env_kwargs (dict, optional): the kwargs of the env creator.
+         share_memory (bool, optional): if False, the resulting tensordict
+             from the environment won't be placed in shared memory.
+         **kwargs: additional keyword arguments to be passed to the environment
+             during construction.
+
+     Examples:
+         >>> # We create the same environment on 2 processes using VecNorm
+         >>> # and check that the discounted count of observations matches on
+         >>> # both workers, even if one has not executed any step
+         >>> import time
+         >>> from torchrl.envs.libs.gym import GymEnv
+         >>> from torchrl.envs.transforms import VecNorm, TransformedEnv
+         >>> from torchrl.envs import EnvCreator
+         >>> from torch import multiprocessing as mp
+         >>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
+         >>> env_creator = EnvCreator(env_fn)
+         >>>
+         >>> def test_env1(env_creator):
+         ...     env = env_creator()
+         ...     tensordict = env.reset()
+         ...     for _ in range(10):
+         ...         env.rand_step(tensordict)
+         ...         if tensordict.get(("next", "done")):
+         ...             tensordict = env.reset(tensordict)
+         ...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
+         >>>
+         >>> def test_env2(env_creator):
+         ...     env = env_creator()
+         ...     time.sleep(5)
+         ...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
+         >>>
+         >>> if __name__ == "__main__":
+         ...     ps = []
+         ...     p1 = mp.Process(target=test_env1, args=(env_creator,))
+         ...     p1.start()
+         ...     ps.append(p1)
+         ...     p2 = mp.Process(target=test_env2, args=(env_creator,))
+         ...     p2.start()
+         ...     ps.append(p1)
+         ...     for p in ps:
+         ...         p.join()
+         env 1: tensor([11.9934])
+         env 2: tensor([11.9934])
+     """
+
+     def __init__(
+         self,
+         create_env_fn: Callable[..., EnvBase],
+         create_env_kwargs: dict | None = None,
+         share_memory: bool = True,
+         **kwargs,
+     ) -> None:
+         if not isinstance(create_env_fn, (EnvCreator, CloudpickleWrapper)):
+             self.create_env_fn = CloudpickleWrapper(create_env_fn)
+         else:
+             self.create_env_fn = create_env_fn
+
+         self.create_env_kwargs = kwargs
+         if isinstance(create_env_kwargs, dict):
+             self.create_env_kwargs.update(create_env_kwargs)
+         self.initialized = False
+         self._meta_data = None
+         self._share_memory = share_memory
+         self.init_()
+
+     def make_variant(self, **kwargs) -> EnvCreator:
+         """Creates a variant of the EnvCreator, pointing to the same underlying metadata but with different keyword arguments during construction.
+
+         This can be useful with transforms that share a state, like :class:`~torchrl.envs.TrajCounter`.
+
+         Examples:
+             >>> from torchrl.envs import GymEnv
+             >>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1")
+             >>> env_creator_cartpole = env_creator_pendulum.make_variant(env_name="CartPole-v1")
+
+         """
+         # Copy self
+         out = type(self).__new__(type(self))
+         out.__dict__.update(self.__dict__)
+         out.create_env_kwargs.update(kwargs)
+         return out
+
+     def share_memory(self, state_dict: OrderedDict) -> None:
+         for key, item in list(state_dict.items()):
+             if isinstance(item, (TensorDictBase,)):
+                 if not item.is_shared():
+                     item.share_memory_()
+                 else:
+                     torchrl_logger.info(
+                         f"{self.env_type}: {item} is already shared"
+                     )  # , deleting key'val)
+                     del state_dict[key]
+             elif isinstance(item, OrderedDict):
+                 self.share_memory(item)
+             elif isinstance(item, torch.Tensor):
+                 del state_dict[key]
+
+     @property
+     def meta_data(self) -> EnvMetaData:
+         if self._meta_data is None:
+             raise RuntimeError(
+                 "meta_data is None in EnvCreator. " "Make sure init_() has been called."
+             )
+         return self._meta_data
+
+     @meta_data.setter
+     def meta_data(self, value: EnvMetaData):
+         self._meta_data = value
+
+     @staticmethod
+     def _is_mp_value(val):
+         if isinstance(val, (Synchronized,)) and hasattr(val, "_obj"):
+             return True
+         # Also check for lock types which need to be shared across processes
+         if isinstance(val, (Lock, RLock)):
+             return True
+         return False
+
+     @classmethod
+     def _find_mp_values(cls, env_or_transform, values, prefix=()):
+         from torchrl.envs.transforms.transforms import Compose, TransformedEnv
+
+         if isinstance(env_or_transform, EnvBase) and isinstance(
+             env_or_transform, TransformedEnv
+         ):
+             cls._find_mp_values(
+                 env_or_transform.transform,
+                 values=values,
+                 prefix=prefix + ("transform",),
+             )
+             cls._find_mp_values(
+                 env_or_transform.base_env, values=values, prefix=prefix + ("base_env",)
+             )
+         elif isinstance(env_or_transform, Compose):
+             for i, t in enumerate(env_or_transform.transforms):
+                 cls._find_mp_values(t, values=values, prefix=prefix + (i,))
+         for k, v in env_or_transform.__dict__.items():
+             if cls._is_mp_value(v):
+                 values.append((prefix + (k,), v))
+         return values
+
+     def init_(self) -> EnvCreator:
+         shadow_env = self.create_env_fn(**self.create_env_kwargs)
+         tensordict = shadow_env.reset()
+         shadow_env.rand_step(tensordict)
+         self.env_type = type(shadow_env)
+         self._transform_state_dict = shadow_env.state_dict()
+         # Extract any mp.Value object from the env
+         self._mp_values = self._find_mp_values(shadow_env, values=[])
+
+         if self._share_memory:
+             self.share_memory(self._transform_state_dict)
+         self.initialized = True
+         self.meta_data = EnvMetaData.metadata_from_env(shadow_env)
+         shadow_env.close()
+         del shadow_env
+         return self
+
+     @classmethod
+     def _set_mp_value(cls, env, key, value):
+         if len(key) > 1:
+             if isinstance(key[0], int):
+                 return cls._set_mp_value(env[key[0]], key[1:], value)
+             else:
+                 return cls._set_mp_value(getattr(env, key[0]), key[1:], value)
+         else:
+             setattr(env, key[0], value)
+
+     def __call__(self, **kwargs) -> EnvBase:
+         if not self.initialized:
+             raise RuntimeError("EnvCreator must be initialized before being called.")
+         kwargs.update(self.create_env_kwargs)  # create_env_kwargs precedes
+         env = self.create_env_fn(**kwargs)
+         if self._mp_values:
+             for k, v in self._mp_values:
+                 self._set_mp_value(env, k, v)
+         env.load_state_dict(self._transform_state_dict, strict=False)
+         return env
+
+     def state_dict(self) -> OrderedDict:
+         if self._transform_state_dict is None:
+             return OrderedDict()
+         return self._transform_state_dict
+
+     def load_state_dict(self, state_dict: OrderedDict) -> None:
+         if self._transform_state_dict is not None:
+             for key, item in state_dict.items():
+                 item_to_update = self._transform_state_dict[key]
+                 item_to_update.copy_(item)
+
+     def __repr__(self) -> str:
+         substr = ", ".join(
+             [f"{key}: {type(item)}" for key, item in self.create_env_kwargs]
+         )
+         return f"EnvCreator({self.create_env_fn}({substr}))"
+
+
+ def env_creator(fun: Callable) -> EnvCreator:
+     """Helper function to call `EnvCreator`."""
+     return EnvCreator(fun)
+
+
+ def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = None):
+     """Retrieves a EnvMetaData object from an env."""
+     if isinstance(env_or_creator, (EnvBase,)):
+         return EnvMetaData.metadata_from_env(env_or_creator)
+     elif not isinstance(env_or_creator, EnvBase) and not isinstance(
+         env_or_creator, EnvCreator
+     ):
+         # then env is a creator
+         if kwargs is None:
+             kwargs = {}
+         env = env_or_creator(**kwargs)
+         return EnvMetaData.metadata_from_env(env)
+     elif isinstance(env_or_creator, EnvCreator):
+         if not (
+             kwargs == env_or_creator.create_env_kwargs
+             or kwargs is None
+             or len(kwargs) == 0
+         ):
+             raise RuntimeError(
+                 "kwargs mismatch between EnvCreator and the kwargs provided to get_env_metadata:"
+                 f"got EnvCreator.create_env_kwargs={env_or_creator.create_env_kwargs} and "
+                 f"kwargs = {kwargs}"
+             )
+         return env_or_creator.meta_data.clone()
+     else:
+         raise NotImplementedError(
+             f"env of type {type(env_or_creator)} is not supported by get_env_metadata."
+         )
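
For context, a short sketch (not part of the diff) of how EnvCreator is typically used as a picklable substitute for an env-building lambda. ParallelEnv and GymEnv are assumed from torchrl's public API, and the example requires a Gym backend such as gymnasium to be installed.

    from torchrl.envs import EnvCreator, ParallelEnv
    from torchrl.envs.libs.gym import GymEnv

    make_env = EnvCreator(lambda: GymEnv("Pendulum-v1"))  # lambda is wrapped in a CloudpickleWrapper
    penv = ParallelEnv(2, make_env)  # each worker rebuilds the env; the cached state_dict is reused
    rollout = penv.rollout(5)
    penv.close()

Because init_() builds a shadow environment eagerly, the construction cost and metadata extraction are paid once when the EnvCreator is created rather than in every worker.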