torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from .gsm8k import GSM8KRewardParser
8
+ from .ifeval import IFEvalScoreData, IfEvalScorer
9
+
10
+ __all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"]
@@ -0,0 +1,324 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from typing import Literal
8
+
9
+ import torch
10
+ from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
11
+ from tensordict.utils import _zip_strict, is_non_tensor
12
+ from torchrl.data import Composite, Unbounded
13
+ from torchrl.envs import Transform
14
+ from torchrl.envs.common import EnvBase
15
+
16
+
17
class GSM8KRewardParser(Transform):
    """Reward parser for GSM8KEnv or make_gsm8k_env.

    The transform reads the model completion and the reference answer, parses the
    ``<think>`` / ``<answer>`` tags of the completion and writes a decomposed reward
    (see :meth:`_single_shaped_correctness_reward`) in the next tensordict.

    When ``in_keys`` is not provided, it is inferred from the ``input_mode`` of the
    parent's base environment:

    - "history" mode: the completion is read from ("history", "full") and is a History object
    - "text" mode: the completion is read from ("text", "full") and is text
    - "tokens" mode: the completion is read from ("tokens", "full") and is tokens

    In every mode the second input key is ``"answer"``: the reference answer, whose
    final result is expected after a ``"#### "`` marker.

    Args:
        tokenizer (AutoTokenizer from transformers): the tokenizer associated with the model.
        in_keys (list of NestedKey): the input keys. If None, will be automatically determined based on parent's input_mode.
        out_keys (list of NestedKey): the output keys. Defaults to `[ "reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"]`.
        eos_token (str): the end of sentence token. Defaults to `tokenizer.eos_token` if not provided.
        set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
        input_mode (Literal["history", "text", "tokens"]): the input mode of the parent environment.
            Defaults to `None` (will be automatically determined based on parent's input_mode).
    """

    def __init__(
        self,
        tokenizer,
        in_keys: list[NestedKey] | None = None,
        out_keys: list[NestedKey] | None = None,
        eos_token: str | None = None,
        set_done_if_answer: bool = True,
        input_mode: Literal["history", "text", "tokens"] | None = None,
    ):
        # Initialise the base Transform exactly once, before any attribute is assigned.
        super().__init__()
        self.tokenizer = tokenizer
        # Fall back on the tokenizer's EOS token when none is given explicitly.
        if eos_token is None and tokenizer is not None:
            eos_token = tokenizer.eos_token
        self.eos_token = eos_token
        self.set_done_if_answer = set_done_if_answer
        self._input_mode = input_mode

        if out_keys is None:
            out_keys = [
                "reward_answer",
                "reward_think",
                "reward_right",
                "reward_contained",
                "reward",
                "success",
            ]
        # in_keys are left empty when not provided: they are resolved lazily from
        # the parent environment in `_maybe_get_in_keys`.
        if in_keys is not None:
            self.in_keys = in_keys
        self.out_keys = out_keys

    def _maybe_get_in_keys(self):
        """Infer ``in_keys`` from the parent's base env if they are not set yet.

        Does nothing when ``in_keys`` is already populated or when the transform
        has no parent. If the base env's ``input_mode`` is not one of "history",
        "text" or "tokens", ``in_keys`` is left untouched.

        Raises:
            ValueError: if the parent exists but exposes no ``base_env``.
        """
        if self.in_keys:
            return
        parent = getattr(self, "parent", None)
        if parent is None:
            return
        base_env = getattr(parent, "base_env", None)
        if base_env is None:
            raise ValueError(
                f"No base env found for {self} with container {self.container}"
            )
        mode = getattr(base_env, "input_mode", None)
        for candidate in ("history", "text", "tokens"):
            if mode == candidate:
                self.in_keys = [(candidate, "full"), "answer"]
                break

    def set_container(self, container: Transform | EnvBase) -> None:
        """Register the container and try to resolve ``in_keys`` from it."""
        result = super().set_container(container)
        self._maybe_get_in_keys()
        return result

    # Class-level default so that the property below works even if `__init__`
    # did not run.
    _input_mode = None

    @property
    def input_mode(self):
        """The input mode used to interpret the completion.

        Returns the value passed to the constructor if any. Otherwise the value is
        read from the parent's ``input_mode`` attribute (``"history"`` if the parent
        has none) and cached. While the transform has no parent, ``"history"`` is
        returned without being cached, so that the mode can still be resolved once
        a parent is attached.
        """
        if self._input_mode is None:
            parent = getattr(self, "parent", None)
            if parent is None:
                return "history"
            self._input_mode = getattr(parent, "input_mode", "history")
        return self._input_mode

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Compute the rewards for a step and write them in ``next_tensordict``.

        Args:
            tensordict (TensorDictBase): the input of the step; the completion is
                read from ``in_keys[0]``.
            next_tensordict (TensorDictBase): the output of the step; the reference
                answer is read from ``in_keys[1]`` and the rewards are written here.

        Returns:
            The updated ``next_tensordict``. When ``set_done_if_answer`` is set and
            some ``"reward_answer"`` entry is positive, the existing ``"done"`` and
            ``"terminated"`` entries are OR-ed with that mask.

        Raises:
            ValueError: if the input mode is unknown or if a reference answer does
                not contain the ``"#### "`` marker.
        """
        from xml.etree import ElementTree as ET

        if next_tensordict.batch_dims > 1:
            # Work on flattened views: the nested call updates them in place.
            with tensordict.view(-1) as td_view, next_tensordict.view(
                -1
            ) as next_td_view:
                self._step(td_view, next_td_view)
            # did update in place
            return next_tensordict

        # Get the completion based on input_mode
        self._maybe_get_in_keys()
        responses = tensordict[self.in_keys[0]]  # batch_size, grpo_size, L

        # Handle different response types based on input_mode
        input_mode = self.input_mode
        if input_mode == "history":
            # responses is a History object: keep the last item of each element
            responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
            if hasattr(responses, "content"):
                # If it's a History object with content attribute
                text_completion = responses.content
                if is_non_tensor(text_completion):
                    text_completion = text_completion.tolist()
                if not isinstance(text_completion, list):
                    text_completion = [text_completion]
            elif hasattr(responses, "apply_chat_template"):
                # If it's a History object, apply chat template to get text
                text_completion = responses.apply_chat_template(
                    tokenizer=self.tokenizer, add_generation_prompt=False
                )
                if not isinstance(text_completion, list):
                    text_completion = [text_completion]
            else:
                # Fallback: try to convert to string
                text_completion = [str(responses)]
        elif input_mode == "text":
            # responses is already text
            if isinstance(responses, str):
                text_completion = [
                    responses for _ in range(next_tensordict.batch_size[0])
                ]
            elif not isinstance(responses, list):
                text_completion = [responses]
            else:
                text_completion = responses
        elif input_mode == "tokens":
            # responses is tokens, need to decode
            if isinstance(responses, torch.Tensor):
                # NOTE(review): the first two dims are flattened before decoding and
                # a non-list result is replicated over the batch - confirm whether a
                # per-sample (batched) decode is intended here.
                text_completion = self.tokenizer.decode(
                    responses.flatten(0, 1).tolist()
                )
                if not isinstance(text_completion, list):
                    text_completion = [
                        text_completion for _ in range(next_tensordict.batch_size[0])
                    ]
            else:
                # Assume it's already a list of token sequences
                text_completion = []
                for token_seq in responses:
                    if isinstance(token_seq, torch.Tensor):
                        text_completion.append(
                            self.tokenizer.decode(token_seq.tolist())
                        )
                    else:
                        text_completion.append(str(token_seq))
        else:
            raise ValueError(f"Unknown input_mode: {input_mode}")

        if self.eos_token is not None:
            text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
        answers = next_tensordict[self.in_keys[1]]  # batch_size, grpo_size

        # Decomposed reward
        tds = []
        for answer, compl in _zip_strict(answers, text_completion):
            # The opening tag is prepended when the completion does not start with it.
            if not compl.startswith("<think>"):
                compl = "<think>" + compl
            compl = compl.removesuffix("<|im_end|>")
            try:
                cot, potential_answer = self.extract_tags(compl)
            except ET.ParseError:
                # `extract_tags` already handles parsing errors; this guards
                # against overrides that let them through.
                cot, potential_answer = ("", "")
            if potential_answer is None:
                potential_answer = ""
            if cot is None:
                cot = ""
            # TODO: in tune, the answer is parsed during dataloading
            #  we could create a similar dataclass for both proposed and real answer
            #  With tensorclass comparison should be easy
            # The final result is whatever follows the last "#### " marker.
            _, marker, true_answer = answer.rpartition("#### ")
            if not marker:
                raise ValueError(
                    f"Expected the reference answer to contain '#### ', got {answer!r}."
                )
            tds.append(
                self._single_shaped_correctness_reward(
                    true_answer, [potential_answer], [cot]
                )
            )
        tds = torch.stack(tds)
        if isinstance(responses, torch.Tensor) and responses.ndim == 3:
            batch_size, grpo_size, _ = responses.shape
            tds = tds.reshape(batch_size, grpo_size)
        # Rewards need to have shape broadcastable to [batch x tokens x 1]
        tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
        # Add the rewards, in case some have already been written
        next_td_exist = next_tensordict.select(*tds.keys(True, True), strict=False)
        if not next_td_exist.is_empty():
            tds = tds.add(
                next_td_exist, default=torch.zeros((), device=next_tensordict.device)
            )
        next_tensordict = next_tensordict.update(tds)
        if (
            self.set_done_if_answer
            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
        ):
            done = next_tensordict.get("done")
            if done is not None:
                next_tensordict.set("done", reward_answer.view_as(done) | done)
            terminated = next_tensordict.get("terminated")
            if terminated is not None:
                next_tensordict.set(
                    "terminated", reward_answer.view_as(terminated) | terminated
                )
        return next_tensordict

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        """Add the reward entries written by this transform to ``reward_spec``.

        Every entry has the shape of ``reward_spec`` followed by two singleton
        dimensions; ``"success"`` is boolean.
        """
        shape = reward_spec.shape + (1, 1)
        reward_spec.update(
            Composite(
                reward_answer=Unbounded(shape),
                reward_think=Unbounded(shape),
                reward_right=Unbounded(shape),
                reward_contained=Unbounded(shape),
                reward=Unbounded(shape),
                success=Unbounded(shape, dtype=torch.bool),
            )
        )
        return reward_spec

    @classmethod
    def _single_shaped_correctness_reward(
        cls, true_answer: str, potential_answer: list[str], cot: list[str]
    ) -> TensorDict:
        """Build the decomposed reward of a single completion.

        Args:
            true_answer (str): the reference answer.
            potential_answer (list of str): the answer(s) extracted from the
                completion. A single string is treated as a one-element list.
            cot (list of str): the reasoning block(s) extracted from the
                completion. A single string is treated as a one-element list.

        Returns:
            A TensorDict with entries:

            - ``reward_answer``: 5.0 if exactly one answer was extracted, else 0.
            - ``reward_think``: 5.0 if exactly one reasoning block was extracted, else 0.
            - ``reward_right``: 20.0 if any extracted answer equals ``true_answer``, else 0.
            - ``reward_contained``: 10.0 if any extracted answer contains ``true_answer``, else 0.
            - ``reward``: the sum of the above, plus 60.0 in case of success.
            - ``success``: whether the last extracted answer equals ``true_answer``.
        """
        # TODO: In tune, these end up being lists
        if isinstance(potential_answer, str):
            potential_answer = [potential_answer]
        if isinstance(cot, str):
            cot = [cot]

        # Format quality rewards (always applied)
        reward_answer = 5.0 * (len(potential_answer) == 1)
        reward_think = 5.0 * (len(cot) == 1)

        # Answer correctness rewards
        reward_right = 20.0 * (
            any(attempt == true_answer for attempt in potential_answer)
        )
        reward_contained = 10.0 * (
            any((true_answer in attempt) for attempt in potential_answer)
        )

        success = len(potential_answer) > 0 and potential_answer[-1] == true_answer

        # Base success reward (lower than before to make format quality more important)
        base_success_reward = 60.0 if success else 0.0

        # Compose the rewards - always include format quality, even when successful
        reward = (
            base_success_reward
            + reward_answer
            + reward_think
            + reward_contained
            + reward_right
        )

        rewards = TensorDict(
            reward_answer=reward_answer,
            reward_think=reward_think,
            reward_right=reward_right,
            reward_contained=reward_contained,
            reward=reward,
            success=success,
        )
        return rewards

    @staticmethod
    def extract_tags(text: str) -> tuple[str, str]:
        """Parse the XML-like ``<think>`` and ``<answer>`` tags from text.

        The text is wrapped in a ``<root>`` element and parsed as XML; only the
        first ``<think>`` and first ``<answer>`` direct children are considered.

        Args:
            text (str): the completion to parse.

        Returns:
            A tuple ``(think, answer)`` of strings. An entry is the empty string
            if the corresponding tag is absent or has no text, and both entries
            are empty strings if the text cannot be parsed as XML.

        """
        from xml.etree import ElementTree as ET

        xml_string = f"<root>{text}</root>"
        try:
            root = ET.fromstring(xml_string)
        except ET.ParseError:
            return ("", "")

        think_elem = root.find("think")
        answer_elem = root.find("answer")
        return (
            think_elem.text
            if think_elem is not None and think_elem.text is not None
            else "",
            answer_elem.text
            if answer_elem is not None and answer_elem.text is not None
            else "",
        )
@@ -0,0 +1,13 @@
1
+ # Adapted Code from SkyThought
2
+
3
+ This project includes code adapted from [SkyThought](https://github.com/NovaSky-AI/SkyThought), specifically the file
4
+ [`ifeval_scorer.py`](https://github.com/NovaSky-AI/SkyThought/blob/2e5db2b26be63c5545d93be4ad08f5ca46449776/skythought/evals/scoring/ifeval/ifeval_scorer.py).
5
+
6
+ Parts of these files are themselves copied from other sources with a similar license.
7
+
8
+ The original code is distributed under the Apache 2.0 license, which can be found in the SkyThought repository: [Apache 2.0 License](https://github.com/NovaSky-AI/SkyThought/blob/main/LICENSE).
9
+
10
+ ### Modifications
11
+
12
+ Modifications were made to the original code according to the terms of the Apache 2.0 license. The changes include
13
+ TorchRL formatting of the data using TensorDict and TorchRL's transforms.
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from ._scorer import IFEvalScoreData, IfEvalScorer
9
+
10
+ __all__ = ["IfEvalScorer", "IFEvalScoreData"]