torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/transforms/policy_version.py
@@ -0,0 +1,189 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import uuid
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import cast, Literal
+
+ import torch
+ from tensordict import NonTensorData, TensorDictBase
+ from torchrl.data.tensor_specs import Composite, NonTensor, Unbounded
+ from torchrl.envs.transforms.transforms import Transform
+
+
+ @dataclass
+ class VersionChange:
+     """Records a single version change event."""
+
+     timestamp: datetime
+     old_version: str | int | None
+     new_version: str | int
+
+
+ class PolicyVersion(Transform):
+     """A transform that keeps track of the version of the policy.
+
+     This transform is used to track policy versions during training, particularly in asynchronous
+     settings where policy weights are updated periodically. It is designed to work seamlessly with
+     :class:`~torchrl.collectors.llm.LLMCollector` to ensure data collection and training remain in sync.
+
+     The version can be either a UUID (string) or an integer counter. When used with :class:`~torchrl.collectors.llm.LLMCollector`,
+     the version is automatically incremented each time the policy weights are updated.
+
+     Example usage with :class:`~torchrl.collectors.llm.LLMCollector`:
+
+     .. code-block:: python
+
+         >>> # Create a policy version tracker
+         >>> policy_version = PolicyVersion(version_type="int") # or "uuid" for UUID-based versioning
+         >>> # Create collector with version tracking
+         >>> collector = LLMCollector(
+         ...     env=env,
+         ...     policy=policy,
+         ...     track_policy_version=policy_version, # Pass the version tracker
+         ...     # ... other arguments
+         ... )
+         >>> # The version will be automatically incremented when weights are updated
+         >>> collector.update_policy_weights_(new_weights)
+         >>> # The version is stored in the collected data
+         >>> for batch in collector:
+         ...     current_version = batch["policy_version"]
+
+     Args:
+         version_type: The type of versioning to use. Can be either:
+             - str or "uuid": Uses UUID4 for versions (good for distributed systems)
+             - int or "int": Uses incrementing integers (good for debugging)
+     """
+
+     def __init__(self, version_type: type | Literal["uuid", "int"] = int):
+         super().__init__()
+         self.version_type = version_type
+         self.version_history: list[VersionChange] = [] # Track version changes
+         self._current_version: str | int | None = None
+         self._increment_version(init=True)
+         self.cal_on_reset = True
+
+     @property
+     def version(self) -> str | int:
+         """The current version of the policy."""
+         if self._current_version is None:
+             raise RuntimeError("Version not initialized")
+         return self._current_version
+
+     @version.setter
+     def version(self, value: str | int) -> None:
+         self._current_version = value
+
+     def increment_version(self) -> None:
+         """Increment the version number.
+
+         This is called automatically by LLMCollector when policy weights are updated.
+         Can also be called manually if needed.
+         """
+         self._increment_version()
+
+     def _increment_version(self, init: bool = False) -> str | int:
+         """Internal method to handle version incrementing with history tracking."""
+         old_version = self._current_version
+         if self.version_type in (str, "uuid"):
+             self._increment_version_uuid(init)
+         elif self.version_type in (int, "int"):
+             self._increment_version_int(init)
+         else:
+             raise ValueError(f"Invalid version type: {self.version_type}")
+
+         # Record the version change
+         self.version_history.append(
+             VersionChange(
+                 timestamp=datetime.now(),
+                 old_version=old_version,
+                 new_version=self.version,
+             )
+         )
+         return self.version
+
+     def _increment_version_uuid(self, init: bool = False) -> None:
+         """Generate a new UUID version.
+
+         Args:
+             init: If True, this is the initial version creation.
+         """
+         self.version = str(uuid.uuid4())
+
+     def _increment_version_int(self, init: bool = False) -> None:
+         """Increment the integer version counter.
+
+         Args:
+             init: If True, initialize counter to 0, otherwise increment by 1.
+         """
+         if init:
+             self.version = 0
+         else:
+             # Cast to int to ensure type safety
+             current = cast(int, self.version)
+             self.version = current + 1
+
+     def _reset(
+         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+     ) -> TensorDictBase:
+         """Reset the environment and update version in the new tensordict.
+
+         Args:
+             tensordict: The current tensordict
+             tensordict_reset: The tensordict to reset to
+
+         Returns:
+             The reset tensordict with updated version
+         """
+         tensordict_reset = self._step(None, tensordict_reset)
+         return tensordict_reset
+
+     def _step(
+         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+     ) -> TensorDictBase:
+         """Add the current version to the tensordict.
+
+         This method is called on each environment step to ensure the collected
+         data is tagged with the correct policy version.
+
+         Args:
+             tensordict: The tensordict to update with version info
+
+         Returns:
+             The tensordict with added version information
+         """
+         if self.version_type in (str, "uuid"):
+             version = NonTensorData(self.version).expand(next_tensordict.shape)
+         elif self.version_type in (int, "int"):
+             # Cast to float for torch.full
+             version = torch.full(next_tensordict.shape, float(cast(int, self.version)))
+         else:
+             raise ValueError(f"Invalid version type: {self.version_type}")
+
+         next_tensordict.set("policy_version", version)
+         return next_tensordict
+
+     def transform_observation_spec(self, spec: Composite) -> Composite:
+         """Update the environment spec to include the version field.
+
+         Args:
+             spec: The environment spec to update
+
+         Returns:
+             Updated spec including the version field
+         """
+         if self.version_type in (str, "uuid"):
+             spec["policy_version"] = NonTensor(
+                 example_data=uuid.uuid4(), shape=spec.shape, device=spec.device
+             )
+         elif self.version_type in (int, "int"):
+             spec["policy_version"] = Unbounded(
+                 shape=spec.shape, dtype=torch.int64, device=spec.device
+             )
+         else:
+             raise ValueError(f"Invalid version type: {self.version_type}")
+         return spec
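
For orientation, a minimal sketch of how the version tracker above can be exercised on its own, assuming `PolicyVersion` is exported from `torchrl.envs.llm.transforms` as the file listing suggests (an assumption, not something shown in this diff). In a real run the transform is appended to a chat-style env and driven by `LLMCollector`, which calls `increment_version()` whenever `update_policy_weights_()` runs; here the calls are made by hand:

    from torchrl.envs.llm.transforms import PolicyVersion  # assumed export path

    tracker = PolicyVersion(version_type=int)
    print(tracker.version)               # 0: set during __init__ via _increment_version(init=True)
    tracker.increment_version()          # normally triggered by LLMCollector.update_policy_weights_()
    print(tracker.version)               # 1
    print(len(tracker.version_history))  # 2: one VersionChange per increment, including the initial one

    uuid_tracker = PolicyVersion(version_type="uuid")
    print(uuid_tracker.version)          # a fresh UUID4 string

Once attached to an environment, each step stamps the outgoing tensordict with a "policy_version" entry, as `_step` above shows.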
torchrl/envs/llm/transforms/reason.py
@@ -0,0 +1,323 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import re
+ from collections.abc import Callable
+ from typing import Literal
+
+ from tensordict import lazy_stack, TensorDictBase
+ from torchrl._utils import logger as torchrl_logger
+
+ from torchrl.data.llm.history import History
+ from torchrl.envs import Transform
+ from torchrl.envs.common import EnvBase
+
+
+ class AddThinkingPrompt(Transform):
+     """A transform that adds thinking prompts to encourage the LLM to reconsider its response.
+
+     This transform can either add a new thinking prompt as a separate message or edit the last
+     assistant response to include a thinking prompt before the final answer. This is useful for
+     training LLMs to self-correct and think more carefully when their initial responses are
+     incorrect or incomplete.
+
+     Args:
+         cond (Callable[[TensorDictBase], bool], optional): Condition function that determines
+             when to add the thinking prompt. Takes a tensordict and returns `True` if the prompt
+             should be added.
+         prompt (str, optional): The thinking prompt to add. If None, a default prompt is used.
+             Defaults to `"But wait, let me think about this more carefully..."`.
+         random_prompt (bool, optional): Whether to randomly select from predefined prompts.
+             Defaults to `False`.
+         role (Literal["user", "assistant"], optional): The role for the thinking prompt.
+             If `"assistant"`, the prompt is added to the assistant's response. If `"user"`, it's
+             added as a separate user message. Defaults to `"assistant"`.
+         edit_last_turn (bool, optional): Whether to edit the last assistant response instead
+             of adding a new message. Only works with `role="assistant"`. Defaults to `True`.
+         zero_reward (bool, optional): Whether to zero out the reward when the thinking prompt
+             is added. If `None`, defaults to the value of `edit_last_turn`. Defaults to the same value as `edit_last_turn`.
+         undo_done (bool, optional): Whether to undo the done flag when the thinking prompt
+             is added. Defaults to `True`.
+         egocentric (bool, optional): Whether the thinking prompt is written from the perspective of the assistant.
+             Defaults to `None`, which means that the prompt is written from the perspective of the user if `role="user"`
+             and from the perspective of the assistant if `role="assistant"`.
+
+     Examples:
+         >>> from torchrl.envs.llm.transforms import AddThinkingPrompt
+         >>> from torchrl.envs.llm import GSM8KEnv
+         >>> from transformers import AutoTokenizer
+         >>> import torch
+         >>>
+         >>> # Create environment with thinking prompt transform
+         >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+         >>> env = GSM8KEnv(tokenizer=tokenizer, max_steps=10)
+         >>> env = env.append_transform(
+         ...     AddThinkingPrompt(
+         ...         cond=lambda td: td["reward"] < 50,
+         ...         role="assistant",
+         ...         edit_last_turn=True,
+         ...         zero_reward=True,
+         ...         undo_done=True
+         ...     )
+         ... )
+         >>>
+         >>> # Test with wrong answer (low reward)
+         >>> reset = env.reset()
+         >>> wrong_answer = (
+         ...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
+         ...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+         ...     "To find the total, I need to add April and May: 48 + 24 = 72. "
+         ...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
+         ...     "<answer>322 clips</answer><|im_end|>"
+         ... )
+         >>> reset["text_response"] = [wrong_answer]
+         >>> s = env.step(reset)
+         >>> assert (s["next", "reward"] == 0).all() # Reward zeroed
+         >>> assert (s["next", "done"] == 0).all() # Done undone
+         >>> assert s["next", "history"].shape == (1, 3) # History modified
+         >>>
+         >>> # Test with correct answer (high reward)
+         >>> reset = env.reset()
+         >>> correct_answer = (
+         ...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
+         ...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+         ...     "To find the total, I need to add April and May: 48 + 24 = 72. "
+         ...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
+         ...     "<answer>72</answer><|im_end|>"
+         ... )
+         >>> reset["text_response"] = [correct_answer]
+         >>> s = env.step(reset)
+         >>> assert (s["next", "reward"] != 0).all() # Reward not zeroed
+         >>> assert s["next", "done"].all() # Done remains True
+         >>> assert s["next", "history"].shape == (1, 3) # History unchanged
+     """
+
+     # Predefined thinking prompts
+     DEFAULT_PROMPTS_EG = [
+         "But wait, let me think about this more carefully...",
+         "Actually, let me reconsider this...",
+         "But we can do better. Let me think about it step by step...",
+         "Wait, I need to double-check my reasoning...",
+         "Actually, let me think about it more carefully...",
+         "It looks like I made a mistake. Let me think about it step by step...",
+     ]
+     DEFAULT_PROMPTS_COG = [
+         "But wait, think about this more carefully...",
+         "Actually, reconsider this...",
+         "But we can do better. Let's think about it step by step...",
+         "Wait, you need to double-check your reasoning...",
+         "Actually, think about it more carefully...",
+         "It looks like you made a mistake. Can you see what went wrong? Let's think about it step by step...",
+     ]
+
+     def __init__(
+         self,
+         cond: Callable[[TensorDictBase], bool],
+         prompt: str | None = None,
+         random_prompt: bool = False,
+         role: Literal["user", "assistant"] = "assistant",
+         edit_last_turn: bool = True,
+         zero_reward: bool | None = None,
+         undo_done: bool = True,
+         egocentric: bool | None = None,
+     ) -> None:
+         super().__init__()
+
+         # Set condition and role
+         self.cond = cond
+         self.role = role
+         if egocentric is None:
+             egocentric = role == "assistant"
+         self.egocentric = egocentric
+
+         # Set the prompt
+         if prompt is None:
+             prompt = (
+                 self.DEFAULT_PROMPTS_EG[0]
+                 if egocentric
+                 else self.DEFAULT_PROMPTS_COG[0]
+             )
+         self._prompt = prompt
+         self.random_prompt = random_prompt
+
+         # Validate edit_last_turn constraint
+         if edit_last_turn and role != "assistant":
+             raise ValueError("edit_last_turn can only be used with role='assistant'")
+         self.edit_last_turn = edit_last_turn
+
+         # Set zero_reward behavior
+         if zero_reward is None:
+             zero_reward = edit_last_turn
+         self.zero_reward = zero_reward
+         self.undo_done = undo_done
+
+     @property
+     def prompt(self) -> str:
+         if self.random_prompt:
+             import random
+
+             return random.choice(
+                 self.DEFAULT_PROMPTS_EG if self.egocentric else self.DEFAULT_PROMPTS_COG
+             )
+         return self._prompt
+
+     def _step(
+         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+     ) -> TensorDictBase:
+         """Process the tensordict and add thinking prompts based on the condition.
+
+         Args:
+             tensordict: The current tensordict
+             next_tensordict: The next tensordict containing the most recent history and reward
+
+         Returns:
+             The modified next_tensordict
+         """
+         # Handle batch dimensions
+         if next_tensordict.batch_dims >= 1:
+             ntds = []
+             for td, next_td in zip(tensordict.unbind(0), next_tensordict.unbind(0)):
+                 ntds.append(self._step(td, next_td))
+             next_tensordict.update(lazy_stack(ntds))
+             return next_tensordict
+
+         # Check that base_env is on history mode
+         parent = self.parent
+         if parent is None:
+             raise RuntimeError("AddThinkingPrompt must be used with a ChatEnv")
+         base_env = parent.base_env
+         if base_env.input_mode != "history":
+             raise RuntimeError(
+                 "AddThinkingPrompt must be used with a ChatEnv in history mode"
+             )
+
+         # Check if we should add the thinking prompt
+         if self.cond(next_tensordict):
+             torchrl_logger.info("Adding thinking prompt.")
+             history: History = next_tensordict["history"].prompt
+             last_turn = history[..., -1]
+
+             if self.edit_last_turn:
+
+                 # Edit the last assistant response
+                 content = last_turn.content
+                 modified_content = self._replace_answer_with_prompt(content)
+
+                 # Create new history entry with modified content
+                 new_turn = History(
+                     role="assistant",
+                     content=modified_content,
+                     batch_size=last_turn.batch_size,
+                     device=last_turn.device,
+                 )
+
+                 # Replace the last turn in history
+                 history = history[..., :-1].append(new_turn)
+                 next_tensordict["history"].prompt = history
+
+             else:
+                 # Add a new message
+                 prompt = self.prompt
+
+                 history = history.append(History(role=self.role, content=prompt))
+                 next_tensordict["history"].prompt = history
+
+             if self.undo_done:
+                 parent: EnvBase = self.parent
+                 if parent is not None:
+                     done_keys = parent.done_keys
+                     for key in done_keys:
+                         done = next_tensordict.get(key)
+                         if done is not None:
+                             next_tensordict.set(key, done.zero_())
+
+             # Zero out reward if requested
+             if self.zero_reward:
+                 parent: EnvBase = self.parent
+                 if parent is not None:
+                     reward_keys = parent.reward_keys
+                     for key in reward_keys:
+                         reward = next_tensordict.get(key)
+                         if reward is not None:
+                             next_tensordict.set(key, reward.zero_())
+         else:
+             torchrl_logger.info("Not adding thinking prompt.")
+         return next_tensordict
+
+     def _replace_answer_with_prompt(self, content: str) -> str:
+         """Replace the last answer section with a thinking prompt.
+
+         This method uses regex to find and replace the last <answer>...</answer> section
+         with the thinking prompt, preserving any content before the answer tag.
+         Only the last answer block is replaced to avoid interfering with earlier
+         examples or instructions that might contain answer tags.
+
+         Args:
+             content: The original content string
+
+         Returns:
+             The modified content with the last answer replaced by the thinking prompt
+         """
+         # Pattern to match <answer>...</answer> with optional EOS token
+         # Use non-greedy matching and be more specific about the end
+         answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
+
+         # Check if there's an answer tag
+         if "<answer>" in content:
+             # Find all matches to get the last one
+             matches = list(re.finditer(answer_pattern, content, flags=re.DOTALL))
+
+             if matches:
+                 # Get the last match
+                 last_match = matches[-1]
+                 start, end = last_match.span()
+
+                 # Replace only the last answer section with the thinking prompt
+                 prompt = self.prompt
+                 modified_content = content[:start] + prompt + content[end:]
+
+                 # Clean up any trailing whitespace
+                 modified_content = modified_content.rstrip()
+
+                 # Ensure we end with the EOS token if the original content had it
+                 if content.endswith("<|im_end|>"):
+                     modified_content = modified_content.rstrip() + "<|im_end|>"
+
+                 # Ensure proper spacing around the prompt
+                 if not modified_content.endswith(prompt):
+                     # If the prompt wasn't properly inserted, append it
+                     modified_content = content.rstrip()
+                     if modified_content.endswith("<|im_end|>"):
+                         modified_content = modified_content[
+                             : -len("<|im_end|>")
+                         ].rstrip()
+                     modified_content = modified_content + "\n\n" + prompt + "<|im_end|>"
+             else:
+                 # No matches found, just append the prompt
+                 prompt = self.prompt
+                 modified_content = content.rstrip() + "\n\n" + prompt
+
+         else:
+             # No answer tag found, just append the prompt
+             prompt = self.prompt
+             modified_content = content.rstrip() + "\n\n" + prompt
+
+         return modified_content
+
+     def _reset(
+         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+     ) -> TensorDictBase:
+         """Reset the transform state.
+
+         Args:
+             tensordict: The current tensordict
+             tensordict_reset: The reset tensordict
+
+         Returns:
+             The reset tensordict
+         """
+         return tensordict_reset
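
To make the answer-rewriting step above concrete, here is a small self-contained sketch (standard-library `re` only, no torchrl imports) of how the regex in `_replace_answer_with_prompt` isolates the last `<answer>...</answer>` block and splices a thinking prompt over it. The helper name `swap_last_answer` is illustrative and not part of the package, and the sketch omits the extra `<|im_end|>` bookkeeping and fallback handling the real method performs:

    import re

    # Same pattern as in _replace_answer_with_prompt above.
    ANSWER = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
    PROMPT = "But wait, let me think about this more carefully..."

    def swap_last_answer(content: str) -> str:
        # Illustrative only: replace the last answer block with the thinking prompt.
        matches = list(re.finditer(ANSWER, content, flags=re.DOTALL))
        if not matches:
            # No answer block: the transform simply appends the prompt.
            return content.rstrip() + "\n\n" + PROMPT
        start, end = matches[-1].span()
        return (content[:start] + PROMPT + content[end:]).rstrip()

    text = "<think>Half of 48 is 24, so 48 + 24 = 72.</think><answer>322 clips</answer>"
    print(swap_last_answer(text))
    # <think>Half of 48 is 24, so 48 + 24 = 72.</think>But wait, let me think about this more carefully...

Replacing only the last match keeps earlier answer tags (for example, ones quoted in few-shot examples or instructions) untouched, which is the rationale stated in the method's docstring.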