torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,142 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+
9
+ import os
10
+ from collections.abc import Sequence
11
+ from tempfile import TemporaryDirectory
12
+ from typing import Any
13
+
14
+ from torch import Tensor
15
+
16
+ from torchrl.record.loggers.common import Logger
17
+
18
+ _has_tv = importlib.util.find_spec("torchvision") is not None
19
+
20
+ _has_mlflow = importlib.util.find_spec("mlflow") is not None
21
+ _has_omegaconf = importlib.util.find_spec("omegaconf") is not None
22
+
23
+
24
class MLFlowLogger(Logger):
    """Wrapper for the mlflow logger.

    Args:
        exp_name (str): The name of the experiment.
        tracking_uri (str): A tracking URI to a datastore that supports MLFlow or a local directory.
        tags (dict[str, Any], optional): Tags passed to ``mlflow.create_experiment``
            when the experiment does not exist yet. Defaults to ``None``.

    Keyword Args:
        video_fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
        **kwargs: Extra keyword arguments. They are accepted but not used by this class.

    """

    def __init__(
        self,
        exp_name: str,
        tracking_uri: str,
        tags: dict[str, Any] | None = None,
        *,
        video_fps: int = 30,
        **kwargs,
    ) -> None:
        # Fail with the explicit message before attempting the import.
        if not _has_mlflow:
            raise ImportError("MLFlow is not installed")
        import mlflow

        # Arguments forwarded verbatim to ``mlflow.create_experiment``.
        self._mlflow_kwargs = {
            "name": exp_name,
            "artifact_location": tracking_uri,
            "tags": tags,
        }
        mlflow.set_tracking_uri(tracking_uri)
        # NOTE(review): the base ``Logger`` presumably calls
        # ``_create_experiment`` here, so ``_mlflow_kwargs`` must be set first.
        super().__init__(exp_name=exp_name, log_dir=tracking_uri)
        self.video_log_counter = 0
        self.video_fps = video_fps

    def _create_experiment(self) -> mlflow.ActiveRun:  # noqa
        """Creates an mlflow experiment.

        The experiment is looked up by name and only created when it does not
        exist yet; its id is stored in ``self.id``.

        Returns:
            mlflow.ActiveRun: The mlflow run started for the experiment.

        Raises:
            ImportError: if mlflow is not installed.
        """
        if not _has_mlflow:
            raise ImportError("MLFlow is not installed")
        import mlflow

        # Only create the experiment if it doesn't exist
        experiment = mlflow.get_experiment_by_name(self._mlflow_kwargs["name"])
        if experiment is None:
            self.id = mlflow.create_experiment(**self._mlflow_kwargs)
        else:
            self.id = experiment.experiment_id
        return mlflow.start_run(experiment_id=self.id)

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to mlflow.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged.
                Defaults to None.
        """
        import mlflow

        mlflow.set_experiment(experiment_id=self.id)
        mlflow.log_metric(key=name, value=value, step=step)

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log video inputs to mlflow.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged, expected to be in (T, C, H, W) format
                for consistency with other loggers. A 5-dimensional (N, T, C, H, W)
                input is reduced to its last element along the first dimension.
            **kwargs: Other keyword arguments. By construction, log_video
                supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).

        Raises:
            ImportError: if torchvision is not installed.
            ValueError: if the video does not have 3 color channels.
        """
        # Check availability first so the explicit message is raised rather
        # than the bare import failure.
        if not _has_tv:
            raise ImportError(
                "Logging a video with MLFlow requires torchvision to be installed."
            )
        import mlflow
        import torchvision

        mlflow.set_experiment(experiment_id=self.id)
        if video.ndim == 5:
            video = video[-1]  # N T C H W -> T C H W
        video = video.permute(0, 2, 3, 1)  # T C H W -> T H W C
        if video.size(dim=-1) != 3:
            raise ValueError(
                "The MLFlow logger only supports videos with 3 color channels."
            )
        self.video_log_counter += 1
        fps = kwargs.pop("fps", self.video_fps)
        step = kwargs.pop("step", None)
        # ``step == 0`` is a valid step index, so test against None rather
        # than truthiness.
        if step is not None:
            video_name = f"{name}_step_{step:04}.mp4"
        else:
            video_name = f"{name}.mp4"
        with TemporaryDirectory() as temp_dir:
            # write_video is given a path and writes the file itself, so no
            # separate open handle is kept on the same path.
            video_path = os.path.join(temp_dir, video_name)
            torchvision.io.write_video(
                filename=video_path, video_array=video, fps=fps
            )
            # The artifact must be logged before the temporary directory is removed.
            mlflow.log_artifact(video_path, "videos")

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.
        """
        import mlflow

        mlflow.set_experiment(experiment_id=self.id)
        if not isinstance(cfg, dict) and _has_omegaconf:
            # Imported lazily so that plain dicts can be logged without omegaconf.
            from omegaconf import OmegaConf

            cfg = OmegaConf.to_container(cfg, resolve=True)
        mlflow.log_params(cfg)

    def __repr__(self) -> str:
        return f"MLFlowLogger(experiment={self.experiment!r})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Histograms are not supported by this logger.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Logging histograms in mlflow is not permitted.")
@@ -0,0 +1,139 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+
9
+ import os
10
+ from collections.abc import Sequence
11
+
12
+ from torch import Tensor
13
+
14
+ from .common import Logger
15
+
16
+ _has_tb = importlib.util.find_spec("tensorboard") is not None
17
+ _has_omgaconf = importlib.util.find_spec("omegaconf") is not None
18
+
19
+
20
class TensorboardLogger(Logger):
    """Wrapper for the Tensorboard logger.

    Args:
        exp_name (str): The name of the experiment.
        log_dir (str): the tensorboard log_dir. Defaults to ``tb_logs``.

    """

    def __init__(self, exp_name: str, log_dir: str = "tb_logs") -> None:
        super().__init__(exp_name=exp_name, log_dir=log_dir)
        # re-write log_dir with the directory reported by the experiment object
        self.log_dir = self.experiment.log_dir

        # moviepy is only imported on the first call to log_video.
        self._has_imported_moviepy = False

    def _create_experiment(self) -> SummaryWriter:  # noqa
        """Creates a tensorboard experiment.

        The writer logs to ``<log_dir>/<exp_name>``.

        Returns:
            SummaryWriter: The tensorboard experiment.

        Raises:
            ImportError: if tensorboard is not installed.

        """
        if not _has_tb:
            raise ImportError("torch.utils.tensorboard could not be imported")

        from torch.utils.tensorboard import SummaryWriter

        log_dir = str(os.path.join(self.log_dir, self.exp_name))
        return SummaryWriter(log_dir=log_dir)

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to the tensorboard.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged. Defaults to None.

        """
        self.experiment.add_scalar(name, value, global_step=step)

    def log_video(
        self, name: str, video: Tensor, step: int | None = None, **kwargs
    ) -> None:
        """Log videos inputs to the tensorboard.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged, in ((N), T, C, H, W) format.
                A 4-dimensional (T, C, H, W) input gets a leading batch
                dimension of size 1.
            step (int, optional): The step at which the video is logged. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to ``add_video``.

        Raises:
            Exception: if the video has the wrong format or moviepy is missing.

        """
        # The batch dimension N is optional: promote (T, C, H, W) to (1, T, C, H, W).
        if video.dim() == 4:
            video = video.unsqueeze(0)
        # check for correct format of the video tensor ((N), T, C, H, W)
        # check that the color channel (C) is either 1 or 3
        if video.dim() != 5 or video.size(dim=2) not in {1, 3}:
            raise Exception(
                "Wrong format of the video tensor. Should be ((N), T, C, H, W)"
            )
        if not self._has_imported_moviepy:
            try:
                import moviepy  # noqa

                self._has_imported_moviepy = True
            except ImportError as err:
                raise Exception(
                    "moviepy not found, videos cannot be logged with TensorboardLogger"
                ) from err
        self.experiment.add_video(
            tag=name,
            vid_tensor=video,
            global_step=step,
            **kwargs,
        )

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.

        Raises:
            ImportError: if ``cfg`` is not a dict and OmegaConf is not installed.

        """
        if not isinstance(cfg, dict):
            # Non-dict configs are converted through OmegaConf, so it must be
            # available (``_has_omgaconf`` is the module-level flag, spelled
            # as defined in this file).
            if not _has_omgaconf:
                raise ImportError(
                    "OmegaConf could not be imported. "
                    "Cannot log hydra configs without OmegaConf."
                )
            from omegaconf import OmegaConf

            cfg = OmegaConf.to_container(cfg, resolve=True)
        self.experiment.add_hparams(cfg, metric_dict={})

    def __repr__(self) -> str:
        return f"TensorboardLogger(experiment={self.experiment!r})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Add histogram to summary.

        Args:
            name (str): Data identifier
            data (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram

        Keyword Args:
            step (int): Global step value to record
            bins (str): One of {'tensorflow', 'auto', 'fd', ...}. This determines how the bins are made.
                Defaults to ``"tensorflow"``. You can find other options in:
                https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
            walltime (:obj:`float`): Optional override default walltime (time.time()) seconds after epoch of event

        Raises:
            TypeError: if unrecognised keyword arguments are passed.

        """
        global_step = kwargs.pop("step", None)
        # ``bins`` is optional; fall back to the default used by
        # ``SummaryWriter.add_histogram`` instead of raising KeyError.
        bins = kwargs.pop("bins", "tensorflow")
        walltime = kwargs.pop("walltime", None)
        if kwargs:
            raise TypeError(f"Unrecognised arguments {kwargs}.")
        self.experiment.add_histogram(
            tag=name, values=data, global_step=global_step, bins=bins, walltime=walltime
        )
@@ -0,0 +1,163 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+
9
+ from collections.abc import Sequence
10
+
11
+ import numpy as np
12
+
13
+ from torch import Tensor
14
+
15
+ from .common import Logger
16
+
17
+ _has_trackio = importlib.util.find_spec("trackio") is not None
18
+ _has_omegaconf = importlib.util.find_spec("omegaconf") is not None
19
+
20
+
21
class TrackioLogger(Logger):
    """Wrapper for the trackio logger.

    Args:
        exp_name (str): The name of the experiment.
        project (str): The name of the project.

    Keyword Args:
        video_fps (int, optional): Number of frames per second when recording
            videos. Defaults to ``32``.
        **kwargs: Extra keyword arguments for ``trackio.init``.

    """

    # NOTE(review): ``__new__`` is implicitly a static method, so stacking
    # ``@classmethod`` on it makes ``cls`` appear a second time at the head of
    # ``*args``. Kept unchanged to preserve the existing construction path.
    @classmethod
    def __new__(cls, *args, **kwargs):
        return super().__new__(cls)

    def __init__(
        self,
        exp_name: str,
        project: str,
        *,
        video_fps: int = 32,
        **kwargs,
    ) -> None:
        if not _has_trackio:
            raise ImportError("trackio could not be imported")

        self.video_fps = video_fps
        # ``**kwargs`` is expanded last so callers can override any of the
        # defaults above it (e.g. ``resume``).
        self._trackio_kwargs = {
            "name": exp_name,
            "project": project,
            "resume": "allow",
            **kwargs,
        }

        super().__init__(exp_name=exp_name, log_dir=project)

    def _create_experiment(self):
        """Creates a trackio experiment.

        The experiment is initialised from the keyword arguments collected in
        ``__init__`` (``self._trackio_kwargs``).

        Returns:
            The object returned by ``trackio.init``.

        Raises:
            ImportError: if trackio is not installed.
        """
        if not _has_trackio:
            raise ImportError("Trackio is not installed")
        import trackio

        return trackio.init(**self._trackio_kwargs)

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        """Logs a scalar value to trackio.

        Args:
            name (str): The name of the scalar.
            value (float): The value of the scalar.
            step (int, optional): The step at which the scalar is logged.
                Defaults to None.
        """
        self.experiment.log({name: value}, step=step)

    def log_video(self, name: str, video: Tensor, **kwargs) -> None:
        """Log videos inputs to trackio.

        Args:
            name (str): The name of the video.
            video (Tensor): The video to be logged. It is detached, moved to
                CPU and cast to ``uint8`` before being handed to trackio.
            **kwargs: Other keyword arguments. By construction, log_video
                supports 'step' (integer indicating the step index), 'format'
                (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
                passed as-is to the :obj:`experiment.log` method.
        """
        import trackio

        fps = kwargs.pop("fps", self.video_fps)
        video_format = kwargs.pop("format", "mp4")
        # ``Tensor.numpy()`` only works for CPU tensors that do not require
        # grad, so detach and move to host memory before converting.
        frames = video.detach().cpu().numpy().astype(np.uint8)
        self.experiment.log(
            {name: trackio.Video(frames, fps=fps, format=video_format)},
            **kwargs,
        )

    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        """Logs the hyperparameters of the experiment.

        A ``dict`` (or ``dict`` subclass) is used as-is. Any other object is
        treated as an OmegaConf config and converted to a plain, fully
        resolved container first.

        Args:
            cfg (DictConfig or dict): The configuration of the experiment.

        Raises:
            ImportError: if ``cfg`` is not a ``dict`` and OmegaConf is not
                installed.

        """
        if not isinstance(cfg, dict):
            # The availability check has to live inside the non-dict branch:
            # folding it into the outer condition made this error unreachable.
            if not _has_omegaconf:
                raise ImportError(
                    "OmegaConf could not be imported. "
                    "Cannot log hydra configs without OmegaConf."
                )
            from omegaconf import OmegaConf

            cfg = OmegaConf.to_container(cfg, resolve=True)
        self.experiment.config.update(cfg)

    def __repr__(self) -> str:
        return f"TrackioLogger(experiment={self.experiment.__repr__()})"

    def log_histogram(self, name: str, data: Sequence, **kwargs):
        """Add histogram to log.

        Args:
            name (str): Data identifier
            data (torch.Tensor, numpy.ndarray): Values to build histogram

        Keyword Args:
            step (int): Global step value to record
            bins (int): Number of bins to use for the histogram. When omitted,
                ``trackio.Histogram`` is left to use its own default.

        """
        import trackio

        num_bins = kwargs.pop("bins", None)
        step = kwargs.pop("step", None)
        # Only forward ``num_bins`` when the caller supplied one, so that an
        # explicit ``None`` never replaces trackio's default bin count.
        histogram_kwargs = {} if num_bins is None else {"num_bins": num_bins}
        self.experiment.log(
            {name: trackio.Histogram(data, **histogram_kwargs)}, step=step
        )

    def log_str(self, name: str, value: str, step: int | None = None) -> None:
        """Logs a string value to trackio using a table format for better visualization.

        Args:
            name (str): The name of the string data.
            value (str): The string value to log.
            step (int, optional): The step at which the string is logged.
                Defaults to None.
        """
        import trackio

        # Create a table with a single row
        table = trackio.Table(columns=["text"], data=[[value]])
        self.experiment.log({name: table}, step=step)
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import pathlib
9
+ import uuid
10
+ from datetime import datetime
11
+
12
+ from torchrl.record.loggers.common import Logger
13
+
14
+
15
def generate_exp_name(model_name: str, experiment_name: str) -> str:
    """Build a unique identifier (str) for the described experiment.

    The identifier is ``<model_name>_<experiment_name>_<uid>_<timestamp>``,
    where ``<uid>`` is the first 8 characters of a random UUID4 and
    ``<timestamp>`` is the current local time formatted as
    ``%y_%m_%d-%H_%M_%S``.
    """
    short_uid = str(uuid.uuid4())[:8]
    timestamp = datetime.now().strftime("%y_%m_%d-%H_%M_%S")
    parts = [model_name, experiment_name, short_uid, timestamp]
    return "_".join(parts)
26
+
27
+
28
def get_logger(
    logger_type: str, logger_name: str, experiment_name: str, **kwargs
) -> Logger | None:
    """Get a logger instance of the provided `logger_type`.

    Args:
        logger_type (str): One of tensorboard / csv / wandb / mlflow / trackio.
            If empty (``""`` or ``None``), ``None`` is returned.
        logger_name (str): Name to be used as a log_dir
        experiment_name (str): Name of the experiment
        kwargs (dict[str]): might contain either `wandb_kwargs`, `mlflow_kwargs` or `trackio_kwargs`

    Returns:
        The requested logger, or ``None`` when ``logger_type`` is empty.

    Raises:
        NotImplementedError: if ``logger_type`` is not one of the supported values.
    """
    if logger_type == "tensorboard":
        from torchrl.record.loggers.tensorboard import TensorboardLogger

        logger = TensorboardLogger(log_dir=logger_name, exp_name=experiment_name)
    elif logger_type == "csv":
        from torchrl.record.loggers.csv import CSVLogger

        logger = CSVLogger(
            log_dir=logger_name, exp_name=experiment_name, video_format="mp4"
        )
    elif logger_type == "wandb":
        from torchrl.record.loggers.wandb import WandbLogger

        wandb_kwargs = kwargs.get("wandb_kwargs", {})
        logger = WandbLogger(
            log_dir=logger_name, exp_name=experiment_name, **wandb_kwargs
        )
    elif logger_type == "mlflow":
        from torchrl.record.loggers.mlflow import MLFlowLogger

        mlflow_kwargs = kwargs.get("mlflow_kwargs", {})
        logger = MLFlowLogger(
            tracking_uri=pathlib.Path(os.path.abspath(logger_name)).as_uri(),
            exp_name=experiment_name,
            **mlflow_kwargs,
        )
    elif logger_type == "trackio":
        from torchrl.record.loggers.trackio import TrackioLogger

        # Work on a copy: popping ``project`` from the caller's own dict would
        # silently drop it for any later call that reuses the same kwargs.
        trackio_kwargs = dict(kwargs.get("trackio_kwargs", {}))
        project = trackio_kwargs.pop("project", "torchrl")
        logger = TrackioLogger(
            project=project, exp_name=experiment_name, **trackio_kwargs
        )
    elif logger_type in ("", None):
        return None
    else:
        raise NotImplementedError(f"Unsupported logger_type: '{logger_type}'")
    return logger