torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,94 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
def check_finite(tensor: torch.Tensor):
    """Validate that every element of ``tensor`` is finite.

    Args:
        tensor (torch.Tensor): the tensor to inspect.

    Raises:
        ValueError: if any element is ``inf``, ``-inf`` or ``nan``.
    """
    every_value_finite = torch.isfinite(tensor).all()
    if every_value_finite:
        return
    raise ValueError("Encountered a non-finite tensor.")
15
+
16
+
17
def _init_first(fun):
    """Decorate a method so that ``self._init()`` runs before it when needed.

    On every call the wrapper checks ``self.initialized`` and, when it is
    falsy, calls ``self._init()`` before forwarding all positional and keyword
    arguments to ``fun``. ``functools.wraps`` keeps the wrapped method's name,
    docstring and other metadata so introspection and docs stay correct.

    Args:
        fun: the method to wrap. Its owner must expose an ``initialized``
            attribute and an ``_init()`` method.

    Returns:
        The wrapped method.
    """
    from functools import wraps

    @wraps(fun)
    def new_fun(self, *args, **kwargs):
        if not self.initialized:
            self._init()
        return fun(self, *args, **kwargs)

    return new_fun
24
+
25
+
26
class _set_missing_tolerance:
    """Context manager to change the transform tolerance to missing values.

    If a transform has a ``missing_tolerance`` of ``True``, it will not raise an
    error if a key is missing during reset.

    This is implemented via :meth:`~torchrl.envs.transforms.Transform.set_missing_tolerance`.

    The way this is handled is that, if ``_reset`` calls the default ``_call``
    method, it will not raise an error if an input key is missing.

    For custom ``_reset`` methods, you should implement this yourself:

    Examples:
        >>> def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        ...     with _set_missing_tolerance(self, True):
        ...         tensordict_reset = self.foo(tensordict, tensordict_reset)
        ...     return tensordict_reset
        >>> def foo(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        ...     if self.in_keys[0] not in tensordict_reset and self.missing_tolerance:
        ...         return tensordict_reset
        ...     else:
        ...         ...  # your code here

    Because ``missing_tolerance`` will be turned off during calls to ``_step``,
    you can be sure that an appropriate ``KeyError`` will be raised if the
    input key is missing at that time.

    The previous tolerance is read when entering the ``with`` block and
    restored when leaving it, including when the block raises. The transform's
    ``set_missing_tolerance`` is only called when the requested mode differs
    from the current one.

    Args:
        transform: the transform whose tolerance is changed. It must expose a
            ``missing_tolerance`` attribute and a ``set_missing_tolerance``
            method.
        mode (bool): the tolerance to use inside the ``with`` block.
    """

    def __init__(self, transform, mode):
        self.transform = transform
        self.mode = mode

    def __enter__(self):
        # Remember the current tolerance so that __exit__ can restore it.
        self.exit_mode = self.transform.missing_tolerance
        if self.mode != self.exit_mode:
            self.transform.set_missing_tolerance(self.mode)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only undo what __enter__ changed. Returning None (falsy) lets any
        # exception raised inside the block propagate.
        if self.mode != self.exit_mode:
            self.transform.set_missing_tolerance(self.exit_mode)
65
+
66
+
67
def _get_reset(reset_key, tensordict):
    """Fetch the reset mask stored under ``reset_key``, with a default.

    The mask is looked up in ``tensordict``. The "parent" tensordict is the
    one holding ``reset_key``: the root for a plain key, or the sub-tensordict
    at ``reset_key[:-1]`` for a tuple key. The root is used if that
    sub-tensordict cannot be found.

    Args:
        reset_key: key of the reset entry. Tuple keys are expected to be
            unravelled already.
        tensordict: the tensordict to read from.

    Returns:
        A boolean mask shaped like the parent's batch size. If no mask is
        stored, an all-``True`` mask (an expanded view) is returned. A stored
        mask with more dims than the parent batch is reduced with ``any`` over
        the extra trailing dims.
    """
    reset_mask = tensordict.get(reset_key, None)
    # Nested keys live in the sub-tensordict identified by all but the last
    # key element.
    if isinstance(reset_key, tuple):
        owner = tensordict.get(reset_key[:-1], None)
    else:
        owner = tensordict
    if owner is None:
        # The nested sub-tensordict is absent: fall back on the root.
        owner = tensordict
    if reset_mask is None:
        # Nothing stored: treat every element of the owner's batch as reset.
        all_true = torch.ones((), dtype=torch.bool, device=owner.device)
        reset_mask = all_true.expand(owner.batch_size)
    extra_dims = reset_mask.ndim > owner.ndim
    if extra_dims:
        reset_mask = reset_mask.flatten(owner.ndim, -1).any(-1)
    return reset_mask
87
+
88
+
89
def _stateless_param(param):
    """Return a data-less copy of ``param`` placed on the ``"meta"`` device.

    Args:
        param: a tensor or an ``nn.Parameter``.

    Returns:
        For an ``nn.Parameter``, a new ``nn.Parameter`` wrapping the meta
        tensor with ``requires_grad=False``. Otherwise, the meta tensor itself.
    """
    meta_tensor = param.data.to("meta")
    if not isinstance(param, nn.Parameter):
        return meta_tensor
    return nn.Parameter(meta_tensor, requires_grad=False)
@@ -0,0 +1,307 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import importlib
9
+ import os
10
+ import subprocess
11
+ from functools import partial
12
+
13
+ import torch
14
+ from tensordict import TensorDictBase
15
+ from torch import nn
16
+ from torchrl._utils import logger as torchrl_logger
17
+
18
+ from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded
19
+ from torchrl.envs.transforms.transforms import (
20
+ CenterCrop,
21
+ Compose,
22
+ ObservationNorm,
23
+ Resize,
24
+ ToTensorImage,
25
+ Transform,
26
+ )
27
+ from torchrl.envs.transforms.utils import _set_missing_tolerance
28
+
29
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule is loaded, so import it explicitly before calling find_spec.
import importlib.util

# True when the optional ``vc_models`` package (eai-vc) can be imported.
_has_vc = importlib.util.find_spec("vc_models") is not None
30
+
31
+
32
class VC1Transform(Transform):
    """VC1 Transform class.

    VC1 (Visual Cortex 1) provides pre-trained Vision Transformer (ViT)
    encoders aimed at producing visual embeddings for embodied-AI and robotic
    tasks. The models are distributed through the
    `eai-vc <https://github.com/facebookresearch/eai-vc>`_ repository
    (``vc_models`` package).

    See the paper:
        Where are we in the search for an Artificial Visual Cortex for
        Embodied Intelligence? (Majumdar et al., 2023)
        https://arxiv.org/abs/2303.18240

    The model is loaded eagerly: :meth:`_init` is called from the constructor,
    so ``vc_models`` must be importable when the transform is created (see
    :meth:`install_vc_models`).

    Args:
        in_keys (list of NestedKeys): list of input keys holding the images to
            embed (e.g. ``["pixels"]``).
        out_keys (list of NestedKeys): list of output keys receiving the
            embeddings, paired one-to-one with ``in_keys``
            (e.g. ``["VC1_vec"]``).
        model_name (str): One of ``"large"``, ``"base"`` or any other compatible
            model name (see the `github repo <https://github.com/facebookresearch/eai-vc>`_
            for more info). Passing ``"default"`` provides a small, untrained
            model for testing.
        del_keys (bool, optional): If ``True`` (default), the input keys are
            discarded from the returned tensordict.

    Examples:
        >>> transform = VC1Transform(
        ...     in_keys=["pixels"], out_keys=["VC1_vec"], model_name="default"
        ... )
        >>> env.append_transform(transform)
        >>> env.reset()
    """

    inplace = False
    IMPORT_ERROR = (
        "Could not load vc_models. You can install it via "
        "VC1Transform.install_vc_models()."
    )

    def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
        if model_name == "default":
            # Write the config of a tiny, randomly initialized ViT into the
            # vc_models package so that it can be loaded by name in _init.
            self.make_noload_model()
            model_name = "vc1_vitb_noload"
        self.model_name = model_name
        self.del_keys = del_keys

        super().__init__(in_keys=in_keys, out_keys=out_keys)
        # Eagerly load the model, its embedding size and its preprocessing.
        self._init()

    def _init(self):
        """Load the VC1 model, embedding size and preprocessing transforms.

        Raises:
            ModuleNotFoundError: if ``vc_models`` is not installed.
        """
        try:
            from vc_models.models.vit import model_utils
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(self.IMPORT_ERROR) from err

        if self.model_name == "base":
            model_name = model_utils.VC1_BASE_NAME
        elif self.model_name == "large":
            model_name = model_utils.VC1_LARGE_NAME
        else:
            model_name = self.model_name

        model, embd_size, model_transforms, _model_info = model_utils.load_model(
            model_name
        )
        self.model = model
        self.embd_size = embd_size
        self.model_transforms = self._map_tv_to_torchrl(model_transforms)

    def _map_tv_to_torchrl(
        self,
        model_transforms,
        in_keys=None,
    ):
        """Convert a torchvision preprocessing pipeline into TorchRL transforms.

        Args:
            model_transforms: a torchvision ``Resize``, ``CenterCrop``,
                ``Normalize``, ``ToTensor`` or a ``Compose`` of those.
            in_keys (list of NestedKeys, optional): keys the resulting
                transforms act on. Defaults to ``self.in_keys``.

        Returns:
            The equivalent TorchRL transform (a :class:`Compose` for a
            torchvision ``Compose``).

        Raises:
            RuntimeError: if a torchvision ``Compose`` holds no transform.
            NotImplementedError: for any other torchvision transform type.
        """
        if in_keys is None:
            in_keys = self.in_keys
        from torchvision import transforms

        if isinstance(model_transforms, transforms.Resize):
            size = model_transforms.size
            if isinstance(size, int):
                # NOTE: an int size is mapped to a square resize, whereas
                # torchvision matches the shorter image edge to it.
                size = (size, size)
            return Resize(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.CenterCrop):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size,)
            return CenterCrop(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Normalize):
            # Per-channel mean/std reshaped to broadcast over (C, H, W) images.
            return ObservationNorm(
                in_keys=in_keys,
                loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1),
                scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1),
                standard_normal=True,
            )
        elif isinstance(model_transforms, transforms.ToTensor):
            return ToTensorImage(
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Compose):
            # ToTensor is moved to the front of the pipeline; presumably so
            # that images are converted to float tensors before the other
            # mapped transforms run on them.
            transform_list = []
            for t in model_transforms.transforms:
                if isinstance(t, transforms.ToTensor):
                    transform_list.insert(0, t)
                else:
                    transform_list.append(t)
            if len(transform_list) == 0:
                raise RuntimeError("Did not find any transform.")
            # Forward in_keys so nested transforms act on the same keys as
            # the enclosing Compose.
            return Compose(
                *[self._map_tv_to_torchrl(t, in_keys=in_keys) for t in transform_list]
            )
        else:
            raise NotImplementedError(type(model_transforms))

    def _call(self, next_tensordict):
        """Preprocess the in_keys, embed them and write the out_keys."""
        if not self.del_keys:
            # The mapped preprocessing transforms are built with in_keys only
            # and presumably write back into them, so snapshot the raw inputs
            # (except those also used as outputs) to restore them afterwards.
            in_keys = [
                in_key
                for in_key, out_key in zip(self.in_keys, self.out_keys)
                if in_key != out_key
            ]
            saved_td = next_tensordict.select(*in_keys)
        # Work on a flat batch view; changes are written back on exit.
        with next_tensordict.view(-1) as tensordict_view:
            super()._call(self.model_transforms(tensordict_view))
        if self.del_keys:
            next_tensordict.exclude(*self.in_keys, inplace=True)
        else:
            # reset in_keys
            next_tensordict.update(saved_td)
        return next_tensordict

    forward = _call

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # TODO: Check this makes sense
        # Tolerate missing in_keys during reset (see _set_missing_tolerance).
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @torch.no_grad()
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images with the VC1 model (no gradient).

        Inputs with more than 4 dims have all leading dims (everything but the
        last three) flattened into a single batch dim for the model call, and
        restored on the output.
        """
        shape = None
        if obs.ndimension() > 4:
            shape = obs.shape[:-3]
            obs = obs.flatten(0, -4)
        out = self.model(obs)
        if shape is not None:
            out = out.view(*shape, *out.shape[1:])
        return out

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Replace the image specs by ``embd_size``-wide embedding specs.

        Raises:
            ValueError: if ``observation_spec`` is not a :class:`Composite`.
            KeyError: if none of the ``in_keys`` is in ``observation_spec``.
        """
        if not isinstance(observation_spec, Composite):
            raise ValueError("VC1Transform can only infer Composite")

        keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
        if not keys:
            raise KeyError(
                f"None of the in_keys {self.in_keys} were found in the observation spec."
            )
        # Device and leading (batch) dims are taken from the first matching
        # image spec; the last three dims are assumed to be the image dims.
        device = observation_spec[keys[0]].device
        dim = observation_spec[keys[0]].shape[:-3]

        observation_spec = observation_spec.clone()
        if self.del_keys:
            for in_key in keys:
                del observation_spec[in_key]

        for out_key in self.out_keys:
            observation_spec[out_key] = Unbounded(
                shape=torch.Size([*dim, self.embd_size]), device=device
            )

        return observation_spec

    def to(self, dest: DEVICE_TYPING | torch.dtype):
        """Move or cast the transform and record the target device or dtype."""
        if isinstance(dest, torch.dtype):
            self._dtype = dest
        else:
            self._device = dest
        return super().to(dest)

    @property
    def device(self):
        """Device recorded by :meth:`to`, or ``None`` if none was recorded."""
        return getattr(self, "_device", None)

    @property
    def dtype(self):
        """Dtype recorded by :meth:`to`, or ``None`` if none was recorded."""
        return getattr(self, "_dtype", None)

    @classmethod
    def install_vc_models(cls, auto_exit=False):
        """Clone and install the ``vc_models`` package if it is missing.

        The eai-vc repository is cloned into ``~/.cache/torchrl/eai-vc``
        (unless that directory already exists). Its ``vc_models`` sub-package
        is then installed in develop mode with the running Python interpreter.
        After installing, the interpreter is terminated with ``SystemExit``.

        Args:
            auto_exit (bool, optional): if ``False`` (default), wait for the
                user to press Enter before exiting.

        Raises:
            ModuleNotFoundError: if GitPython (``git``) is needed but missing.
            subprocess.CalledProcessError: if the installation command fails.
        """
        try:
            from vc_models import models  # noqa: F401

            torchrl_logger.info("vc_models found, no need to install.")
        except ModuleNotFoundError:
            import sys

            # expanduser also works when HOME is unset (unlike os.environ["HOME"]).
            vcdir = os.path.join(os.path.expanduser("~"), ".cache", "torchrl", "eai-vc")
            parentdir = os.path.dirname(os.path.abspath(vcdir))
            os.makedirs(parentdir, exist_ok=True)
            if not os.path.isdir(vcdir):
                # Cloning into an existing directory fails, so an earlier
                # clone (e.g. from an interrupted install) is reused.
                try:
                    from git import Repo
                except ModuleNotFoundError as err:
                    raise ModuleNotFoundError(
                        "Could not load git. Make sure that `git` has been installed "
                        "in your virtual environment."
                    ) from err
                Repo.clone_from("https://github.com/facebookresearch/eai-vc.git", vcdir)
            # Install into the current interpreter, without changing the
            # process working directory, and fail loudly on errors.
            subprocess.check_call(
                [sys.executable, "setup.py", "develop"],
                cwd=os.path.join(vcdir, "vc_models"),
            )
            if not auto_exit:
                input(
                    "VC1 has been successfully installed. Exit this python run and "
                    "relaunch it again. Press Enter to exit..."
                )
            sys.exit()

    @classmethod
    def make_noload_model(cls):
        """Creates a naive model config inside the installed ``vc_models`` package.

        Writes ``conf/model/vc1_vitb_noload.yaml`` in the ``vc_models`` package
        directory unless it already exists. The config targets
        :func:`_vit_base_patch16` with an empty ``checkpoint_path``, giving a
        small untrained ViT for testing.

        Raises:
            ModuleNotFoundError: if ``vc_models`` is not installed.
        """
        try:
            import vc_models
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(cls.IMPORT_ERROR) from err

        models_filepath = os.path.dirname(os.path.abspath(vc_models.__file__))
        cfg_path = os.path.join(
            models_filepath, "conf", "model", "vc1_vitb_noload.yaml"
        )
        if os.path.exists(cfg_path):
            return
        config = """_target_: vc_models.models.load_model
model:
  _target_: vc_models.models.vit.vit.load_mae_encoder
  checkpoint_path:
  model:
    _target_: torchrl.envs.transforms.vc1._vit_base_patch16
    img_size: 224
    use_cls: True
    drop_path_rate: 0.0
transform:
  _target_: vc_models.transforms.vit_transforms
metadata:
  algo: mae
  model: vit_base_patch16
  data:
    - ego
    - imagenet
    - inav
  comment: 182_epochs
"""
        os.makedirs(os.path.dirname(cfg_path), exist_ok=True)
        with open(cfg_path, "w", encoding="utf-8") as file:
            file.write(config)
292
+
293
+
294
def _vit_base_patch16(**kwargs):
    """Build the tiny ViT referenced by the ``vc1_vitb_noload`` model config.

    Despite the name (kept because the YAML config targets it by path), this
    is a deliberately small network: embedding width 16, 4 blocks and 4
    attention heads over 16x16 patches. It is used for testing without
    downloading weights.

    Args:
        **kwargs: extra arguments forwarded to ``VisionTransformer`` (the
            config passes ``img_size``, ``use_cls`` and ``drop_path_rate``).

    Returns:
        A ``vc_models`` ``VisionTransformer`` instance.
    """
    from vc_models.models.vit.vit import VisionTransformer

    architecture = dict(
        patch_size=16,
        embed_dim=16,
        depth=4,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**architecture, **kwargs)