torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,454 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Modifications from original script.
7
+
8
+ Modifications include:
9
+
10
+ - TensorDict embedding
11
+ - Modification of key names
12
+ - make IfEvalScorer a TorchRL transform
13
+
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import importlib.util
19
+ import re
20
+ from collections.abc import Callable
21
+
22
+ import torch
23
+ from tensordict import (
24
+ lazy_stack,
25
+ NestedKey,
26
+ NonTensorData,
27
+ TensorClass,
28
+ TensorDict,
29
+ TensorDictBase,
30
+ )
31
+ from tensordict.tensorclass import is_non_tensor
32
+ from torchrl._utils import logger as torchrl_logger
33
+
34
+ from torchrl.data.tensor_specs import Composite, Unbounded
35
+ from torchrl.envs import Transform
36
+
37
+ _has_langdetect = importlib.util.find_spec("langdetect") is not None
38
+ _has_nltk = importlib.util.find_spec("nltk") is not None
39
+ _has_immutabledict = importlib.util.find_spec("immutabledict") is not None
40
+
41
+
42
class IFEvalScoreData(TensorClass):
    """IFEval score container.

    Holds the four IFEval accuracy fields: strict and loose accuracy, each at
    prompt level and at instruction level. ``__post_init__`` normalizes every
    field so that it is a tensor carrying at least one trailing (event)
    dimension beyond the batch dimensions.
    """

    # The four score fields; ``None`` entries are replaced by zero tensors
    # during __post_init__.
    prompt_level_strict_acc: torch.Tensor | None
    inst_level_strict_acc: torch.Tensor | None
    prompt_level_loose_acc: torch.Tensor | None
    inst_level_loose_acc: torch.Tensor | None

    @classmethod
    def default_spec(
        cls, shape: torch.Size, device: torch.device | None = None
    ) -> Composite:
        """Return the default ``Composite`` spec for IFEval scores.

        Args:
            shape (torch.Size): batch shape shared by all four boolean entries.
            device (torch.device, optional): device of the spec. Defaults to ``None``.

        Returns:
            Composite: a spec with one boolean ``Unbounded`` entry per score field,
                tagged with ``data_cls=cls`` and ``step_mdp_static=True``.
        """
        # All four accuracy fields share the same boolean, unbounded spec, so
        # build them in one comprehension instead of repeating the call.
        field_specs = {
            name: Unbounded(shape=shape, dtype=torch.bool, device=device)
            for name in (
                "prompt_level_strict_acc",
                "inst_level_strict_acc",
                "prompt_level_loose_acc",
                "inst_level_loose_acc",
            )
        }
        return Composite(
            **field_specs,
            data_cls=cls,
            step_mdp_static=True,
        )

    def __post_init__(self):
        # Normalize each score field with a single shared rule:
        # - missing (None) fields become zero tensors with a trailing singleton dim;
        # - fields whose ndim equals the batch ndim get a trailing dim appended,
        #   so every field ends up with an explicit event dimension.
        for name in (
            "prompt_level_strict_acc",
            "inst_level_strict_acc",
            "prompt_level_loose_acc",
            "inst_level_loose_acc",
        ):
            value = self.get(name, as_padded_tensor=True)
            if value is None:
                setattr(self, name, torch.zeros(self.batch_size + (1,)))
            elif value.ndim == self.ndim:
                setattr(self, name, value.unsqueeze(-1))
100
+
101
+
102
def _process_results(
    data: TensorDict,
    response: str | NonTensorData,
    verbose: bool = False,
    prompt: str | None = None,
) -> IFEvalScoreData:
    """Score a single response against the IFEval instructions stored in ``data``.

    Args:
        data (TensorDict): must contain the ``"key"``, ``"instruction_id_list"``
            and ``"kwargs"`` entries, plus ``"text"`` when ``prompt`` is not given.
        response (str | NonTensorData): the model response to evaluate.
        verbose (bool, optional): whether to log inputs and results. Defaults to ``False``.
        prompt (str, optional): the prompt; falls back to ``data["text"]`` when ``None``.

    Returns:
        IFEvalScoreData: strict and loose accuracies, at prompt and instruction level.

    Raises:
        ImportError: if ``langdetect`` or ``immutabledict`` is not installed.
    """
    if not _has_langdetect:
        raise ImportError("langdetect must be installed to use IFEvalScorer.")
    if not _has_immutabledict:
        raise ImportError("immutabledict must be installed to use IFEvalScorer.")

    # Imported lazily: these helpers pull in the optional third-party deps
    # checked above.
    from ._instructions_main import (
        _InputExample,
        _test_instruction_following_loose,
        _test_instruction_following_strict,
    )

    if prompt is None:
        prompt = data["text"]

    inp = _InputExample(
        key=data["key"],
        instruction_id_list=data["instruction_id_list"],
        # data["text"] may itself be None, hence the second guard.
        prompt=prompt if prompt is not None else "",
        kwargs=data["kwargs"],
    )

    if verbose:
        torchrl_logger.info(f"Processing {inp=} {response=}")
    out_strict = _test_instruction_following_strict(inp, response)
    out_loose = _test_instruction_following_loose(inp, response)

    result = IFEvalScoreData(
        prompt_level_strict_acc=out_strict.follow_all_instructions,
        inst_level_strict_acc=out_strict.follow_instruction_list,
        prompt_level_loose_acc=out_loose.follow_all_instructions,
        inst_level_loose_acc=out_loose.follow_instruction_list,
        batch_size=data.batch_size,
        device=data.device,
    )

    if verbose:
        torchrl_logger.info(f"Result: {result.to_dict()=}")
    return result
146
+
147
+
148
+ class IfEvalScorer(Transform):
149
+ """Scorer for the IF-Eval task.
150
+
151
+ For the IFEval dataset format, see https://huggingface.co/datasets/google/IFEval
152
+
153
+ The score data is written under the `score_key` using the :class:`~torchrl.envs.llm.IFEvalScoreData` data structure.
154
+ Scores can be aggregated on a single reward by using the `aggregate_reward` keyword argument in the constructor, which
155
+ can be a bool or a function.
156
+
157
+ Keyword Args:
158
+ instruction_ids_key (NestedKey, optional): The column name for the list of instruction ids. Defaults to "instruction_id_list".
159
+ prompt_key (NestedKey, optional): The column name for the prompt. Defaults to "text".
160
+ keyword_args_key (NestedKey, optional): The column name for the keyword arguments to the instruction builder. Defaults to "kwargs".
161
+ id_key (NestedKey, optional): The column name for the unique identifier. Defaults to "key".
162
+ response_column (NestedKey, optional): The column name for the response. Defaults to "text_response".
163
+ score_key (NestedKey, optional): The key to store the score. Defaults to "ifeval_score".
164
+ aggregate_reward (bool, callable, optional): Whether to aggregate the reward or not. If a Callable is passed,
165
+ it must take as input an :class:`~torchrl.envs.llm.IFEvalScoreData` instance, and optionally `think_blocks`, `answer_blocks` and `complete` keyword arguments,
166
+ where `think_blocks` and `answer_blocks` are the lists of think and answer blocks, and `complete` indicates whether the response is complete.
167
+ It must return a tensor with shape identical to the env batch-size with an additional trailing singleton dimension.
168
+ Defaults to `True`. The default aggregator is a simple sum over the fields of :class:`~torchrl.envs.llm.IFEvalScoreData`.
169
+ format_weights (list[float], optional): The weights for the format fields (`prompt_level_strict_acc`, `inst_level_strict_acc`,
170
+ `prompt_level_loose_acc`, `inst_level_loose_acc`, in that order). Defaults to `[0.4, 0.3, 0.2, 0.1]`.
171
+ This is only used if `aggregate_reward` is `True` and the default aggregator is used.
172
+ verbose (bool, optional): Whether to print verbose information. Defaults to `False`.
173
+ set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
174
+
175
+ .. note:: `IFEvalScorer` requires the following libraries to be installed: `langdetect`, `nltk` and `immutabledict`.
176
+
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ *,
182
+ instruction_ids_key: NestedKey = "instruction_id_list",
183
+ prompt_key: NestedKey = "text",
184
+ keyword_args_key: NestedKey = "kwargs",
185
+ id_key: NestedKey = "key",
186
+ response_column: NestedKey = "text_response",
187
+ score_key: NestedKey = "ifeval_score",
188
+ aggregate_reward: bool
189
+ | Callable[
190
+ [IFEvalScoreData, list[str] | None, list[str] | None], torch.Tensor
191
+ ] = True,
192
+ format_weights: list[float] | None = None,
193
+ verbose: bool = False,
194
+ set_done_if_answer: bool = True,
195
+ ):
196
+ self.aggregate_reward = aggregate_reward
197
+ self.score_key = score_key
198
+ self.set_done_if_answer = set_done_if_answer
199
+ out_keys = [self.score_key]
200
+ if aggregate_reward:
201
+ out_keys.append("reward")
202
+ super().__init__(
203
+ in_keys=[
204
+ instruction_ids_key,
205
+ prompt_key,
206
+ keyword_args_key,
207
+ id_key,
208
+ response_column,
209
+ ],
210
+ out_keys=out_keys,
211
+ )
212
+ if not _has_langdetect:
213
+ raise ImportError("langdetect must be installed to user IFEvalScorer.")
214
+ if not _has_nltk:
215
+ raise ImportError("nltk must be installed to user IFEvalScorer.")
216
+ self.instruction_ids_key = instruction_ids_key
217
+ self.response_key = response_column
218
+ self.keyword_args_key = keyword_args_key
219
+ self.prompt_key = prompt_key
220
+ self.id_key = id_key
221
+ self.format_weights = (
222
+ format_weights if format_weights is not None else [0.4, 0.3, 0.2, 0.1]
223
+ )
224
+ self.verbose = verbose
225
+
226
+ def default_reward_aggregator(
227
+ self,
228
+ score: IFEvalScoreData,
229
+ think_blocks: list[str] | None = None,
230
+ answer_blocks: list[str] | None = None,
231
+ complete: bool | torch.Tensor | None = None,
232
+ ) -> torch.Tensor:
233
+ r"""Improved reward aggregation function with tiered multiplicative scoring.
234
+
235
+ Args:
236
+ score (IFEvalScoreData): The score data.
237
+ think_blocks (list[str], optional): The list of think blocks.
238
+ answer_blocks (list[str], optional): The list of answer blocks.
239
+ complete (bool, optional): Whether the response is complete (ends with a eos token).
240
+
241
+ The reward uses a tiered multiplicative system:
242
+
243
+ 1. Critical failure check: No answer blocks = 0 reward
244
+ 2. Base format score (0-1): Weighted average of format metrics
245
+ 3. Structure multiplier (0.1-1.0): Penalties for missing/multiple blocks
246
+ 4. Quality bonus (0-0.5): Rewards for high quality and completion
247
+ 5. Task complexity scaling: More requirements = higher potential rewards
248
+
249
+ The final formula is:
250
+ reward = (format_score + quality_bonus) * structure_multiplier * complexity_scale
251
+
252
+ This provides better learning signals by:
253
+ - Requiring critical elements (answer tags) for meaningful rewards
254
+ - Using multiplicative scaling to reward doing everything well
255
+ - Scaling rewards based on task complexity
256
+ - Providing clear failure modes and success incentives
257
+
258
+ Reward range: 0.0 to ~1.5-2.7 depending on task complexity (more instructions = higher max reward).
259
+ """
260
+ default_dtype = torch.get_default_dtype()
261
+ score = score.to(default_dtype)
262
+
263
+ # Critical failure check - no answer = no reward
264
+ if not answer_blocks:
265
+ return torch.zeros(
266
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
267
+ )
268
+
269
+ # Base format score calculation (0-1)
270
+ format_components = torch.stack(
271
+ [
272
+ score.prompt_level_strict_acc.sum(-1, keepdim=True)
273
+ if score.prompt_level_strict_acc is not None
274
+ else torch.zeros(
275
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
276
+ ), # Single value
277
+ score.inst_level_strict_acc.mean(-1, keepdim=True)
278
+ if score.inst_level_strict_acc is not None
279
+ else torch.zeros(
280
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
281
+ ), # Average across instructions
282
+ score.prompt_level_loose_acc.sum(-1, keepdim=True)
283
+ if score.prompt_level_loose_acc is not None
284
+ else torch.zeros(
285
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
286
+ ), # Single value
287
+ score.inst_level_loose_acc.mean(-1, keepdim=True)
288
+ if score.inst_level_loose_acc is not None
289
+ else torch.zeros(
290
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
291
+ ), # Average across instructions
292
+ ],
293
+ -1,
294
+ )
295
+ weights = torch.tensor(
296
+ self.format_weights,
297
+ device=format_components.device,
298
+ dtype=default_dtype,
299
+ )
300
+ format_score = (format_components * weights).sum(dim=-1, keepdim=True)
301
+
302
+ # Structure multiplier (0.1-1.0)
303
+ structure_multiplier = 1.0
304
+
305
+ # Heavy penalty for missing think blocks (but not zero)
306
+ if not think_blocks:
307
+ structure_multiplier *= 0.3
308
+ elif len(think_blocks) > 1:
309
+ structure_multiplier *= 0.7 # Penalty for multiple think blocks
310
+
311
+ # Penalty for multiple answer blocks
312
+ if len(answer_blocks) > 1:
313
+ structure_multiplier *= 0.7
314
+
315
+ # Quality bonus (0-0.5)
316
+ quality_bonus = torch.zeros_like(format_score)
317
+
318
+ # Bonus for high quality responses
319
+ if format_score > 0.8:
320
+ quality_bonus += 0.3
321
+
322
+ # Completion bonus
323
+ if complete is not None:
324
+ if isinstance(complete, torch.Tensor):
325
+ completion_bonus = complete.to(default_dtype) * 0.2
326
+ else:
327
+ completion_bonus = float(complete) * 0.2
328
+ quality_bonus += completion_bonus
329
+
330
+ # Task complexity scaling based on number of instructions
331
+ # More instructions = higher potential rewards
332
+ if (
333
+ score.inst_level_strict_acc is not None
334
+ and score.inst_level_strict_acc.numel() > 0
335
+ ):
336
+ num_instructions = score.inst_level_strict_acc.shape[-1]
337
+ else:
338
+ num_instructions = 1
339
+ complexity_scale = (
340
+ 1.0 + (num_instructions - 1) * 0.2
341
+ ) # 1.0 for 1 instruction, 1.2 for 2, etc.
342
+
343
+ # Final reward: (format + quality) * structure_multiplier * complexity_scale
344
+ final_reward = (
345
+ (format_score + quality_bonus) * structure_multiplier * complexity_scale
346
+ )
347
+ final_reward = final_reward.to(default_dtype)
348
+
349
+ return final_reward
350
+
351
+ def _step(
352
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
353
+ ) -> TensorDictBase:
354
+ if not getattr(self.parent.base_env, "input_mode", "history") == "history":
355
+ raise ValueError("IFEvalScorer only supports history input mode")
356
+
357
+ if tensordict.ndim:
358
+ return lazy_stack(
359
+ [
360
+ self._step(td, next_td)
361
+ for td, next_td in zip(
362
+ tensordict.unbind(0), next_tensordict.unbind(0)
363
+ )
364
+ ]
365
+ )
366
+ h = tensordict["history", "full"][..., -1]
367
+ prompt = tensordict["history", "prompt"][..., -1].content
368
+ response = h.content
369
+ complete = h.is_complete
370
+ # response = tensordict.get(self.response_key)
371
+ if is_non_tensor(response):
372
+ response = response.data
373
+
374
+ # TODO: This should be a distinct module
375
+ # Regular expression patterns to match think and answer blocks
376
+ think_pattern = r"<think>(.*?)</think>"
377
+ answer_pattern = r"<answer>(.*?)</answer>"
378
+ # Extract think block
379
+ think_blocks = re.findall(think_pattern, response, re.DOTALL)
380
+
381
+ # Extract answer block
382
+ answer_blocks = re.findall(answer_pattern, response, re.DOTALL)
383
+
384
+ score = _process_results(
385
+ tensordict.copy().auto_device_(),
386
+ answer_blocks[0] if answer_blocks else "",
387
+ verbose=self.verbose,
388
+ prompt=prompt,
389
+ )
390
+ next_tensordict.set(
391
+ self.score_key,
392
+ score,
393
+ )
394
+ if self.aggregate_reward:
395
+ if callable(self.aggregate_reward):
396
+ reward_func = self.aggregate_reward
397
+ else:
398
+ reward_func = self.default_reward_aggregator
399
+ reward = reward_func(
400
+ score,
401
+ think_blocks=think_blocks,
402
+ answer_blocks=answer_blocks,
403
+ complete=complete,
404
+ )
405
+ reward = reward.view(
406
+ next_tensordict.batch_size
407
+ + (
408
+ 1,
409
+ 1,
410
+ )
411
+ )
412
+ next_tensordict.set("reward", reward)
413
+ if self.set_done_if_answer and bool(answer_blocks):
414
+ next_tensordict.set(
415
+ "done",
416
+ torch.ones(
417
+ next_tensordict.batch_size + (1,),
418
+ device=next_tensordict.device,
419
+ dtype=torch.bool,
420
+ ),
421
+ )
422
+ next_tensordict.set(
423
+ "terminated",
424
+ torch.ones(
425
+ next_tensordict.batch_size + (1,),
426
+ device=next_tensordict.device,
427
+ dtype=torch.bool,
428
+ ),
429
+ )
430
+ return next_tensordict
431
+
432
+ @property
433
+ def expected_keys(self) -> list[str]:
434
+ return [
435
+ self.instruction_ids_key,
436
+ self.prompt_key,
437
+ self.keyword_args_key,
438
+ self.id_key,
439
+ self.response_key,
440
+ ]
441
+
442
+ def transform_reward_spec(self, reward_spec: Composite) -> Composite:
443
+ reward_spec["reward"] = Unbounded(
444
+ reward_spec.shape + (1, 1),
445
+ dtype=torch.get_default_dtype(),
446
+ device=reward_spec.device,
447
+ )
448
+ return reward_spec
449
+
450
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
451
+ observation_spec[self.score_key] = IFEvalScoreData.default_spec(
452
+ observation_spec.shape, device=observation_spec.device
453
+ )
454
+ return observation_spec
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .browser import BrowserTransform
7
+ from .dataloading import (
8
+ as_nested_tensor,
9
+ as_padded_tensor,
10
+ DataLoadingPrimer,
11
+ RayDataLoadingPrimer,
12
+ )
13
+ from .format import TemplateTransform
14
+ from .kl import KLComputation, KLRewardTransform, RetrieveKL, RetrieveLogProb
15
+ from .policy_version import PolicyVersion
16
+ from .reason import AddThinkingPrompt
17
+ from .tokenizer import Tokenizer
18
+ from .tools import (
19
+ ExecuteToolsInOrder,
20
+ JSONCallParser,
21
+ MCPToolTransform,
22
+ PythonExecutorService,
23
+ PythonInterpreter,
24
+ SimpleToolTransform,
25
+ ToolCall,
26
+ ToolRegistry,
27
+ ToolService,
28
+ XMLBlockParser,
29
+ )
30
+
31
+ __all__ = [
32
+ "AddThinkingPrompt",
33
+ "BrowserTransform",
34
+ "DataLoadingPrimer",
35
+ "ExecuteToolsInOrder",
36
+ "JSONCallParser",
37
+ "KLComputation",
38
+ "KLRewardTransform",
39
+ "MCPToolTransform",
40
+ "PolicyVersion",
41
+ "PythonExecutorService",
42
+ "PythonInterpreter",
43
+ "RayDataLoadingPrimer",
44
+ "RetrieveKL",
45
+ "RetrieveLogProb",
46
+ "SimpleToolTransform",
47
+ "TemplateTransform",
48
+ "Tokenizer",
49
+ "ToolCall",
50
+ "ToolRegistry",
51
+ "ToolService",
52
+ "XMLBlockParser",
53
+ "as_nested_tensor",
54
+ "as_padded_tensor",
55
+ ]