torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,591 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import abc
8
+ import weakref
9
+ from collections.abc import Callable
10
+ from typing import Any, TypeVar
11
+
12
+ import torch
13
+ from tensordict import TensorDict, TensorDictBase
14
+ from tensordict.nn import TensorDictModuleBase
15
+ from torchrl._utils import logger as torchrl_logger
16
+
17
+ Policy = TypeVar("Policy", bound=TensorDictModuleBase)
18
+
19
+
20
+ class WeightUpdaterBase(metaclass=abc.ABCMeta):
21
+ """A base class for updating remote policy weights on inference workers.
22
+
23
+ .. deprecated::
24
+ WeightUpdaterBase is deprecated and will be removed in a future version.
25
+ Please use WeightSyncScheme from torchrl.weight_update.weight_sync_schemes instead.
26
+
27
+ The weight updater is the central piece of the weight update scheme:
28
+
29
+ - In leaf collector nodes, it is responsible for sending the weights to the policy, which can be as simple as
30
+ updating a state-dict, or more complex if an inference server is being used.
31
+ - In server collector nodes, it is responsible for sending the weights to the leaf collectors.
32
+
33
+ In a collector, the updater is called within :meth:`~torchrl.collectors.BaseCollector.update_policy_weights_`.
34
+
35
+ The main method of this class is the :meth:`~._push_weights` method, which updates the policy weights in the worker /
36
+ policy. This method is called by :meth:`~.push_weights`, which also calls the post-hooks: only `_push_weights` should
37
+ be implemented by child classes.
38
+
39
+ To extend this class, implement the following abstract methods:
40
+
41
+ - `_get_server_weights` (optional): Define how to retrieve the weights from the server if they are not passed to
42
+ the updater directly. This method is only called if the weights (or a handle to them) are not passed directly.
43
+ - `_sync_weights_with_worker`: Define how to synchronize weights with a specific worker.
44
+ This method must be implemented by child classes.
45
+ - `_maybe_map_weights`: Optionally transform the server weights before distribution.
46
+ By default, this method returns the weights unchanged.
47
+ - `all_worker_ids`: Provide a list of all worker identifiers.
48
+ Returns `None` by default (no worker id).
49
+ - `from_policy` (optional classmethod): Define how to create an instance of the weight updater from a policy.
50
+ If implemented, this method will be called before falling back to the default constructor when initializing
51
+ a weight updater in a collector.
52
+
53
+ Attributes:
54
+ collector: The collector (or any container) of the weight receiver. The collector is registered via
55
+ :meth:`~torchrl.collectors.WeightUpdaterBase.register_collector`.
56
+
57
+ Methods:
58
+ push_weights: Updates the weights on specified or all remote workers.
59
+ The `__call__` method is a proxy to `push_weights`.
60
+ register_collector: Registers the collector (or any container) in the receiver through a weakref.
61
+ This will be called automatically by the collector upon registration of the updater.
62
+ from_policy: Optional classmethod to create an instance from a policy.
63
+
64
+ Post-hooks:
65
+ - `register_post_hook`: Registers a post-hook to be called after the weights are updated.
66
+ The post-hook must be a callable that takes no arguments.
67
+ The post-hook will be called after the weights are updated.
68
+ The post-hook will be called in the same process as the weight updater.
69
+ The post-hook will be called in the same order as the post-hooks were registered.
70
+
71
+ .. seealso:: :meth:`~torchrl.collectors.BaseCollector.update_policy_weights_`.
72
+
73
+ """
74
+
75
+ _collector_wrs: list[Any] = None
76
+ _post_hooks: list[Callable[[], Any]] | None = None
77
+
78
+ def __init_subclass__(cls, **kwargs):
79
+ super().__init_subclass__(**kwargs)
80
+ import warnings
81
+
82
+ warnings.warn(
83
+ f"Creating {cls.__name__} which inherits from WeightUpdaterBase is deprecated. "
84
+ "Please use WeightSyncScheme from torchrl.weight_update.weight_sync_schemes instead. "
85
+ "This will be removed in a future version.",
86
+ DeprecationWarning,
87
+ stacklevel=2,
88
+ )
89
+
90
+ @property
91
+ def post_hooks(self) -> list[Callable[[], None]]:
92
+ """The list of post-hooks registered to the weight updater."""
93
+ if self._post_hooks is None:
94
+ self._post_hooks = []
95
+ return self._post_hooks
96
+
97
+ @classmethod
98
+ def from_policy(cls, policy: TensorDictModuleBase) -> WeightUpdaterBase | None:
99
+ """Optional classmethod to create a weight updater instance from a policy.
100
+
101
+ This method can be implemented by subclasses to provide custom initialization logic
102
+ based on the policy. If implemented, this method will be called before falling back
103
+ to the default constructor when initializing a weight updater in a collector.
104
+
105
+ Args:
106
+ policy (TensorDictModuleBase): The policy to create the weight updater from.
107
+
108
+ Returns:
109
+ WeightUpdaterBase | None: An instance of the weight updater, or None if the policy
110
+ cannot be used to create an instance.
111
+ """
112
+ return None
113
+
114
+ def register_collector(self, collector): # noqa
115
+ """Register a collector in the updater.
116
+
117
+ Once registered, the updater will not accept another collector.
118
+
119
+ Args:
120
+ collector (BaseCollector): The collector to register.
121
+
122
+ """
123
+ if self._collector_wrs is None:
124
+ self._collector_wrs = []
125
+ self._collector_wrs.append(weakref.ref(collector))
126
+
127
+ @property
128
+ def collector(self) -> Any | None: # noqa
129
+ """The collector or container of the receiver.
130
+
131
+ Returns `None` if the container is out-of-scope or not set.
132
+ """
133
+ if self._collector_wrs is None:
134
+ return None
135
+ if len(self._collector_wrs) > 1:
136
+ raise ValueError("Cannot access `collector` with multiple collectors.")
137
+ if self._collector_wrs:
138
+ collector = self._collector_wrs[0]()
139
+ else:
140
+ collector = None
141
+ return collector
142
+
143
+ @property
144
+ def collectors(self) -> list[Any] | None:
145
+ """The collectors or container of the receiver."""
146
+ if self._collector_wrs is None:
147
+ return None
148
+ if self._collector_wrs:
149
+ collectors = [
150
+ wr() if wr is not None else None for wr in self._collector_wrs
151
+ ]
152
+ else:
153
+ collectors = None
154
+ return collectors
155
+
156
+ def _push_weights(
157
+ self,
158
+ *,
159
+ policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None,
160
+ worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
161
+ ):
162
+ """Updates the weights of the policy, or on specified / all remote workers.
163
+
164
+ Args:
165
+ policy_or_weights: The source to get weights from. Can be:
166
+ - TensorDictModuleBase: A policy module whose weights will be extracted
167
+ - TensorDictBase: A TensorDict containing weights
168
+ - dict: A regular dict containing weights
169
+ - None: Will try to get weights from server using _get_server_weights()
170
+ worker_ids: An optional list of workers to update.
171
+
172
+ Returns: nothing.
173
+ """
174
+ if policy_or_weights is None:
175
+ # Get the weights on server (local)
176
+ server_weights = self._get_server_weights()
177
+ else:
178
+ server_weights = policy_or_weights
179
+
180
+ server_weights = self._maybe_map_weights(server_weights)
181
+
182
+ # Get the remote weights (inference workers)
183
+ if isinstance(worker_ids, (int, torch.device)):
184
+ worker_ids = [worker_ids]
185
+ elif worker_ids is None:
186
+ worker_ids = self.all_worker_ids()
187
+ if worker_ids is None:
188
+ self._sync_weights_with_worker(server_weights=server_weights)
189
+ return
190
+ for worker in worker_ids:
191
+ if self._skip_update(worker):
192
+ continue
193
+ self._sync_weights_with_worker(
194
+ worker_id=worker, server_weights=server_weights
195
+ )
196
+
197
+ def push_weights(
198
+ self,
199
+ policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None,
200
+ worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
201
+ ):
202
+ """Updates the weights of the policy, or on specified / all remote workers.
203
+
204
+ Args:
205
+ policy_or_weights: The source to get weights from. Can be:
206
+ - TensorDictModuleBase: A policy module whose weights will be extracted
207
+ - TensorDictBase: A TensorDict containing weights
208
+ - dict: A regular dict containing weights
209
+ - None: Will try to get weights from server using _get_server_weights()
210
+ worker_ids: An optional list of workers to update.
211
+
212
+ Returns: nothing.
213
+ """
214
+ self._push_weights(policy_or_weights=policy_or_weights, worker_ids=worker_ids)
215
+ self._call_post_hooks()
216
+
217
+ def init(self, *args, **kwargs):
218
+ """Initialize the weight updater with custom arguments.
219
+
220
+ This method can be overridden by subclasses to handle custom initialization.
221
+ By default, this is a no-op.
222
+
223
+ Args:
224
+ *args: Positional arguments for initialization
225
+ **kwargs: Keyword arguments for initialization
226
+ """
227
+ return
228
+
229
+ def register_post_hook(self, hook: Callable[[], None]):
230
+ """Registers a post-hook to be called after weights are updated.
231
+
232
+ Args:
233
+ hook (Callable[[], None]): The post-hook to register.
234
+ """
235
+ self.post_hooks.append(hook)
236
+
237
+ def _call_post_hooks(self):
238
+ """Calls all registered post-hooks in order."""
239
+ for hook in self.post_hooks:
240
+ hook()
241
+
242
+ def _skip_update(self, worker_id: int | torch.device) -> bool:
243
+ """Whether to skip updating weights for a worker.
244
+
245
+ By default, never skips updates. Subclasses can override this to implement
246
+ custom update frequency logic.
247
+
248
+ Args:
249
+ worker_id (int | torch.device): The worker ID to check.
250
+
251
+ Returns:
252
+ bool: Whether to skip the update.
253
+ """
254
+ return False
255
+
256
+ @abc.abstractmethod
257
+ def _sync_weights_with_worker(
258
+ self,
259
+ *,
260
+ worker_id: int | torch.device | None = None,
261
+ server_weights: TensorDictBase,
262
+ ) -> None:
263
+ """Synchronizes weights with a specific worker.
264
+
265
+ This method must be implemented by child classes to define how weights are
266
+ synchronized with workers.
267
+
268
+ Args:
269
+ worker_id (int | torch.device | None): The worker to sync with, if applicable.
270
+ server_weights (TensorDictBase): The weights from the server to sync.
271
+ """
272
+ raise NotImplementedError
273
+
274
+ def _get_server_weights(self) -> TensorDictBase | None:
275
+ """Gets the weights from the server.
276
+
277
+ This method is called when no weights are passed to push_weights().
278
+ By default returns None. Subclasses can override to implement custom
279
+ weight retrieval logic.
280
+
281
+ Returns:
282
+ TensorDictBase | None: The server weights, or None.
283
+ """
284
+ return None
285
+
286
def _maybe_map_weights(self, policy_or_weights: TensorDictBase) -> TensorDictBase:
    """Optionally transforms server weights before distribution.

    Accepts either a policy module (whose parameters are extracted) or a
    ready-made weight container, which is forwarded unchanged. Subclasses
    can override this to implement custom weight mapping logic.

    Args:
        policy_or_weights (Any): The weights - or any container or handler - to potentially transform, query or extract.

    Returns:
        TensorDictBase: The transformed weights.

    Raises:
        TypeError: If the input is not a module, TensorDict or dict.
    """
    if isinstance(policy_or_weights, TensorDictModuleBase):
        # A policy module: pull its parameters out as a TensorDict.
        return TensorDict.from_module(policy_or_weights).data
    if isinstance(policy_or_weights, (TensorDictBase, dict)):
        # Already a weight container: forward as-is.
        return policy_or_weights
    raise TypeError(
        f"policy_or_weights must be None, TensorDictModuleBase, TensorDictBase or dict, got {type(policy_or_weights)}"
    )
309
+
310
+ def all_worker_ids(self) -> list[int] | list[torch.device] | None:
311
+ """Gets list of all worker IDs.
312
+
313
+ Returns None by default. Subclasses should override to return actual worker IDs.
314
+
315
+ Returns:
316
+ list[int] | list[torch.device] | None: List of worker IDs or None.
317
+ """
318
+ return None
319
+
320
def __call__(
    self,
    policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None,
    worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
):
    """Updates the weights of the policy, or on specified / all remote workers.

    Convenience alias that simply delegates to :meth:`push_weights`.

    Args:
        policy_or_weights: The source to get weights from. Can be:
            - TensorDictModuleBase: A policy module whose weights will be extracted
            - TensorDictBase: A TensorDict containing weights
            - dict: A regular dict containing weights
            - None: Will try to get weights from server using _get_server_weights()
        worker_ids: An optional list of workers to update.

    Returns: nothing.
    """
    return self.push_weights(
        policy_or_weights=policy_or_weights,
        worker_ids=worker_ids,
    )
340
+
341
def increment_version(self):
    """Bump the policy version on every collector attached to this updater."""
    for coll in self.collectors:
        coll.increment_version()
345
+
346
+
347
+ # Specialized classes
348
class VanillaWeightUpdater(WeightUpdaterBase):
    """A simple implementation of :class:`~torchrl.collectors.WeightUpdaterBase` for updating local policy weights.

    The `VanillaWeightUpdater` class provides a basic mechanism for updating the weights
    of a local policy by directly fetching them from a specified source. It is typically used
    in scenarios where the weight update logic is straightforward and does not require any
    complex mapping or transformation.

    This class is used by default in the `Collector` when no custom weight updater
    is provided.

    .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and :class:`~torchrl.collectors.Collector`.

    Keyword Args:
        weight_getter (Callable[[], TensorDictBase], optional): a callable that returns the weights from the server.
            If not provided, the weights must be passed to ``push_weights`` directly.
        policy_weights (TensorDictBase): a TensorDictBase containing the policy weights to be updated
            in-place. Use ``push_weights`` to update the weights.
    """

    @classmethod
    def from_policy(cls, policy: TensorDictModuleBase) -> WeightUpdaterBase | None:
        """Creates a VanillaWeightUpdater instance from a policy.

        The updater keeps a (locked) TensorDict view of the policy's parameters
        and updates them in-place when new weights are pushed.

        Args:
            policy (TensorDictModuleBase): The policy to create the weight updater from.

        Returns:
            VanillaWeightUpdater: An instance of the weight updater configured to update
                the policy's weights.
        """
        policy_weights = TensorDict.from_module(policy)
        # Lock the weights: the structure is frozen, only in-place value
        # updates (via update_) are allowed afterwards.
        return cls(policy_weights=policy_weights.lock_())

    def __init__(
        self,
        *,
        weight_getter: Callable[[], TensorDictBase] | None = None,
        policy_weights: TensorDictBase,
    ):
        self.weight_getter = weight_getter
        self.policy_weights = policy_weights

    def _get_server_weights(self) -> TensorDictBase | None:
        # Fetch from the server only if a getter was configured.
        return self.weight_getter() if self.weight_getter is not None else None

    def _get_local_weights(self) -> TensorDictBase:
        return self.policy_weights

    def _map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        # No transformation: server weights apply directly.
        return server_weights

    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        return server_weights

    def _sync_weights_with_worker(
        self, *, worker_id: None = None, server_weights: TensorDictBase
    ) -> None:
        if server_weights is None:
            # Nothing to push.
            return
        # In-place update keeps the (possibly locked) weight structure intact.
        self.policy_weights.update_(server_weights)
412
+
413
+
414
class MultiProcessedWeightUpdater(WeightUpdaterBase):
    """A remote weight updater for synchronizing policy weights across multiple processes or devices.

    `MultiProcessedWeightUpdater` distributes the weights of a policy from a
    central server to the inference workers of a multiprocessed data collector.
    Each process or device holds its own copy of the policy weights, which is
    refreshed in-place whenever new server weights are pushed.

    Keyword Args:
        get_server_weights (Callable[[], TensorDictBase] | None): A callable that retrieves the
            latest policy weights from the server or another centralized source.
        policy_weights (Dict[torch.device, TensorDictBase]): A dictionary mapping each device or
            process to its current policy weights, which will be updated.

    .. note::
        This class assumes that the server weights can be directly applied to the workers without
        any additional processing. If your use case requires more complex weight mapping or synchronization
        logic, consider extending `WeightUpdaterBase` with a custom implementation.

    .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
        :class:`~torchrl.collectors.BaseCollector`.

    """

    def __init__(
        self,
        *,
        get_server_weights: Callable[[], TensorDictBase] | None,
        policy_weights: dict[torch.device, TensorDictBase],
    ):
        self.weights_getter = get_server_weights
        self._policy_weights = policy_weights

    def all_worker_ids(self) -> list[int] | list[torch.device]:
        # The dict keys double as the worker identifiers.
        return list(self._policy_weights.keys())

    def _sync_weights_with_worker(
        self, worker_id: int | torch.device, server_weights: TensorDictBase | None
    ) -> None:
        if server_weights is None:
            # Nothing to push to this worker.
            return
        self._policy_weights[worker_id].data.update_(server_weights)

    def _get_server_weights(self) -> TensorDictBase | None:
        # The weights getter can be None if no mapping is required.
        if self.weights_getter is None:
            return None
        weights = self.weights_getter()
        return weights.data if weights is not None else None

    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        return server_weights
470
+
471
+
472
class RemoteModuleWeightUpdater(WeightUpdaterBase):
    """A weight updater for remote nn.Modules that requires explicit weight passing.

    Designed for setups where the master collector has no direct handle on the
    worker weights (e.g., when using policy_factory). Nothing is fetched from a
    server or from the workers automatically: weights must be supplied
    explicitly on every call to update_policy_weights_(), and an exception is
    raised otherwise, so that weight synchronization stays an explicit,
    user-driven step.

    Raises:
        RuntimeError: If update_policy_weights_() is called without providing weights.

    .. note::
        This weight updater is primarily used to suppress warnings in tests and scenarios
        where the weight synchronization is handled externally or the workers manage
        their own weight updates.

    .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase`
    """

    def __init__(self):
        # Stateless updater: nothing to configure.
        pass

    def _get_server_weights(self) -> TensorDictBase | None:
        """Returns None since this updater doesn't manage server weights."""
        return None

    def _sync_weights_with_worker(
        self,
        *,
        worker_id: int | torch.device | None = None,
        server_weights: TensorDictBase,
    ) -> None:
        """Validate that weights were passed explicitly; raise otherwise.

        The master has no access to remote worker weights, so the actual
        synchronization is assumed to be handled elsewhere - this method only
        enforces the explicit weight-passing contract (a validating no-op).
        """
        if server_weights is None:
            raise RuntimeError(
                "RemoteModuleWeightUpdater requires weights to be passed explicitly. "
                "Call update_policy_weights_(weights) with the weights to be synchronized."
            )

    def all_worker_ids(self) -> None:
        """Returns None since this updater doesn't manage specific workers."""
        return None
523
+
524
+
525
class RayWeightUpdater(WeightUpdaterBase):
    """A remote weight updater for synchronizing policy weights across remote workers using Ray.

    `RayWeightUpdater` distributes policy weights to remote collectors managed
    by Ray. The server weights are placed once in Ray's object store and each
    remote collector is then asked to update itself from that shared reference.
    A per-worker staleness counter optionally throttles how often each worker
    is refreshed.

    Args:
        policy_weights (TensorDictBase): The current weights of the policy that need to be distributed
            to remote workers.
        remote_collectors (List): A list of remote collectors that will receive the updated policy weights.
        max_interval (int, optional): The maximum number of batches between weight updates for each worker.
            Defaults to 0, meaning weights are updated every batch.

    .. note::
        This class assumes that the server weights can be directly applied to the remote workers without
        any additional processing. If your use case requires more complex weight mapping or synchronization
        logic, consider extending `WeightUpdaterBase` with a custom implementation.

    .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
        :class:`~torchrl.collectors.distributed.RayCollector`.

    """

    def __init__(
        self,
        policy_weights: TensorDictBase,
        remote_collectors: list,
        max_interval: int = 0,
    ):
        self.policy_weights = policy_weights
        self.remote_collectors = remote_collectors
        # Clamp to a non-negative interval.
        self.max_interval = max(0, max_interval)
        # One staleness counter per remote collector.
        self._batches_since_weight_update = [0 for _ in self.remote_collectors]

    def all_worker_ids(self) -> list[int] | list[torch.device]:
        # Workers are addressed by their index in the remote collector list.
        return list(range(len(self.remote_collectors)))

    def _get_server_weights(self) -> Any:
        import ray

        # Put the weights in Ray's object store so all workers can read
        # the same shared reference.
        return ray.put(self.policy_weights.data)

    def _maybe_map_weights(self, server_weights: Any) -> Any:
        # No transformation required before distribution.
        return server_weights

    def _sync_weights_with_worker(self, worker_id: int, server_weights: Any) -> Any:
        torchrl_logger.debug(f"syncing weights with worker {worker_id}")
        collector = self.remote_collectors[worker_id]
        collector.update_policy_weights_.remote(policy_weights=server_weights)
        # Fresh weights delivered: reset this worker's staleness counter.
        self._batches_since_weight_update[worker_id] = 0

    def _skip_update(self, worker_id: int) -> bool:
        self._batches_since_weight_update[worker_id] += 1
        # Skip while the (just-incremented) counter has not yet exceeded
        # the configured interval; with max_interval == 0 we never skip.
        return self._batches_since_weight_update[worker_id] <= self.max_interval
@@ -0,0 +1,38 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ //
3
+ // This source code is licensed under the MIT license found in the
4
+ // LICENSE file in the root directory of this source tree.
5
+
6
+ #pragma once
7
+
8
+ #include <pybind11/numpy.h>
9
+
10
+ #include <algorithm>
11
+ #include <cstdint>
12
+ #include <vector>
13
+
14
+ namespace py = pybind11;
15
+
16
+ namespace torchrl {
17
+ namespace utils {
18
+
19
+ template <typename T>
20
+ std::vector<int64_t> NumpyArrayShape(const py::array_t<T>& arr) {
21
+ const int64_t ndim = arr.ndim();
22
+ std::vector<int64_t> shape(ndim);
23
+ for (int64_t i = 0; i < ndim; ++i) {
24
+ shape[i] = static_cast<int64_t>(arr.shape(i));
25
+ }
26
+ return shape;
27
+ }
28
+
29
+ template <typename T_SRC, typename T_DST = T_SRC>
30
+ py::array_t<T_DST> NumpyEmptyLike(const py::array_t<T_SRC>& src) {
31
+ py::array_t<T_DST> dst(src.size());
32
+ const std::vector<int64_t> shape = NumpyArrayShape(src);
33
+ dst.resize(shape);
34
+ return dst;
35
+ }
36
+
37
+ } // namespace utils
38
+ } // namespace torchrl
@@ -0,0 +1,27 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ //
3
+ // This source code is licensed under the MIT license found in the
4
+ // LICENSE file in the root directory of this source tree.
5
+
6
+ #include <pybind11/numpy.h>
7
+ #include <pybind11/pybind11.h>
8
+ #include <torch/extension.h>
9
+ #include <torch/torch.h>
10
+
11
+ #include <memory>
12
+
13
+ #include "segment_tree.h"
14
+ #include "utils.h"
15
+
16
+ namespace py = pybind11;
17
+
18
// Python bindings for the native `_torchrl` extension module.
PYBIND11_MODULE(_torchrl, m) {
  // Expose sum/min segment-tree bindings in fp32 and fp64 variants
  // (defined in segment_tree.h).
  torchrl::DefineSumSegmentTree<float>("Fp32", m);
  torchrl::DefineSumSegmentTree<double>("Fp64", m);

  torchrl::DefineMinSegmentTree<float>("Fp32", m);
  torchrl::DefineMinSegmentTree<double>("Fp64", m);

  // tanh/atanh helpers declared in project headers; presumably
  // numerically-stable variants - confirm against their definitions.
  m.def("safetanh", &safetanh, "Safe Tanh");
  m.def("safeatanh", &safeatanh, "Safe Inverse Tanh");
}