torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/data/llm/prompt.py
@@ -0,0 +1,198 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import torch
+ from tensordict import tensorclass, TensorDict
+
+ from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
+
+ DEFAULT_DATASET = "CarperAI/openai_summarize_tldr"
+
+
+ @tensorclass
+ class PromptData:
+     """A prompt dataset."""
+
+     input_ids: torch.Tensor
+     attention_mask: torch.Tensor
+     prompt_rindex: torch.Tensor
+     labels: torch.Tensor | None = None
+     logits: torch.Tensor | None = None
+     loss: torch.Tensor | None = None
+
+     def mask_label(self, pad_token_id=50256):
+         _, block_size = self.input_ids.shape
+         attention_mask = (
+             torch.arange(block_size, device=self.prompt_rindex.device)
+             < self.prompt_rindex[:, None]
+         ).to(torch.int64)
+         input_ids = torch.where(attention_mask == 1, self.input_ids, pad_token_id)
+         return self.__class__(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             prompt_rindex=self.prompt_rindex,
+             loss=self.loss,
+             batch_size=[],
+         )
+
+     @classmethod
+     def from_dataset(
+         cls,
+         split,
+         dataset_name=None,
+         max_length=550,
+         root_dir=None,
+         from_disk=False,
+         num_workers: int | None = None,
+     ):
+         """Returns a :class:`PromptData` from a dataset name.
+
+         Args:
+             split (str): ``"train"`` or ``"valid"`` depending on the data split needed.
+             dataset_name (str, optional): name of the dataset to be processed. Defaults to
+                 ``"CarperAI/openai_summarize_tldr"``.
+             max_length (int, optional): maximum length of the dataset sequences.
+                 Defaults to 550.
+             root_dir (path, optional): the path where the datasets are stored.
+                 Defaults to ``"$HOME/.cache/torchrl/data"``
+             from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+                 will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+                 Defaults to ``False``.
+             num_workers (int, optional): number of workers for :meth:`datasets.dataset.map`
+                 which is called during tokenization.
+                 Defaults to ``max(os.cpu_count() // 2, 1)``.
+
+         Returns: a :class:`PromptData` instance containing a memory-mapped
+             version of the required dataset.
+
+         Examples:
+             >>> data = PromptData.from_dataset("train")
+             >>> print(data)
+             PromptDataTLDR(
+                 attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
+                 labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 logits=None,
+                 loss=None,
+                 batch_size=torch.Size([116722]),
+                 device=None,
+                 is_shared=False)
+             >>> # data can be sampled from using regular indexing
+             >>> sub_data = data[:3]
+
+         """
+         dataset_name = dataset_name if dataset_name is not None else DEFAULT_DATASET
+         loader = TokenizedDatasetLoader(
+             split,
+             max_length,
+             dataset_name,
+             PromptTensorDictTokenizer,
+             root_dir=root_dir,
+             from_disk=from_disk,
+             num_workers=num_workers,
+         )
+         data = loader.load()
+         return cls(**data, labels=data["input_ids"], batch_size=data.shape)
+
+
+ class PromptTensorDictTokenizer(TensorDictTokenizer):
+     """Tokenization recipe for prompt datasets.
+
+     Returns a tokenizer function, which reads an example containing a prompt
+     and a label and tokenizes them.
+
+     Args:
+         tokenizer (tokenizer from transformers library): the tokenizer to use.
+         max_length (int): maximum length of the sequence.
+         key (str, optional): the key where to find the text. Defaults to ``"prompt"``.
+         padding (str, optional): type of padding. Defaults to ``"max_length"``.
+         truncation (bool, optional): whether the sequences should be truncated to max_length.
+         return_tensordict (bool, optional): if ``True``, a TensorDict is returned.
+             Otherwise, the original data will be returned.
+         device (torch.device, optional): the device where to store the data.
+             This option is ignored if ``return_tensordict=False``.
+
+     The :meth:`__call__` method of this class will execute the following operations:
+
+     - Read the ``prompt`` string concatenated with the ``label`` string and tokenize
+       them. The results will be stored in the ``"input_ids"`` TensorDict entry.
+     - Write a ``"prompt_rindex"`` entry with the index of the last valid
+       token from the prompt.
+     - Write a ``"valid_sample"`` entry which identifies which entry in the
+       tensordict has enough tokens to meet the ``max_length`` criterion.
+     - Return a :class:`tensordict.TensorDict` instance with tokenized inputs.
+
+     The tensordict batch-size will match the batch-size of the input.
+
+     Examples:
+         >>> from transformers import AutoTokenizer
+         >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> example = {
+         ...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
+         ...     "label": ["Indeed it is.", 'It might as well be.'],
+         ... }
+         >>> fn = PromptTensorDictTokenizer(tokenizer, 50)
+         >>> print(fn(example))
+         TensorDict(
+             fields={
+                 attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
+                 prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                 valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
+             batch_size=torch.Size([2]),
+             device=None,
+             is_shared=False)
+
+     """
+
+     def __init__(
+         self,
+         tokenizer,
+         max_length,
+         key="prompt",
+         padding="max_length",
+         truncation=True,
+         return_tensordict=True,
+         device=None,
+     ):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.key = key
+         self.padding = padding
+         self.truncation = truncation
+         self.return_tensordict = return_tensordict
+         self.device = device
+
+     def __call__(self, sample):
+         tokenizer = self.tokenizer
+         max_length = self.max_length
+
+         tokenized_prompts = tokenizer(
+             sample[self.key], max_length=max_length, truncation=True
+         )
+         prompt_rindex = [len(prompt) - 1 for prompt in tokenized_prompts["input_ids"]]
+         tokenized_example = tokenizer(
+             [
+                 prompt + label
+                 for prompt, label in zip(sample[self.key], sample["label"])
+             ],
+             max_length=max_length,
+             padding=self.padding,
+             truncation=self.truncation,
+         )
+         tokenized_example["prompt_rindex"] = prompt_rindex
+         # drop any examples whose total length when tokenized exceeds block size
+         # with recommended block size of 550, this is only ~0.1% of available examples.
+         # NOTE: to mark as discarded we just save the mask as we cannot change the shape here
+         tokenized_example["valid_sample"] = [True] * len(tokenized_example["input_ids"])
+         for i, input_ids in enumerate(tokenized_example["input_ids"]):
+             if input_ids[-1] != tokenizer.eos_token_id:
+                 tokenized_example["valid_sample"][i] = False
+         if self.return_tensordict:
+             return TensorDict.from_dict(dict(tokenized_example), device=self.device)
+         return tokenized_example
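
A brief usage sketch may help tie the two classes above together. This is illustrative only and not part of the package diff; it assumes a Hugging Face `transformers` tokenizer is available and that the TL;DR dataset can be fetched or is already cached by `TokenizedDatasetLoader`.

    # Hypothetical usage sketch for torchrl/data/llm/prompt.py (not part of the diff).
    from transformers import AutoTokenizer

    from torchrl.data.llm.prompt import PromptData, PromptTensorDictTokenizer

    # Tokenize a couple of raw examples directly with the recipe class.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    fn = PromptTensorDictTokenizer(tokenizer, max_length=50)
    td = fn({"prompt": ["Summarize: the cat sat on the mat."], "label": ["A cat sat."]})

    # Load the memory-mapped dataset and mask out the label tokens, keeping only the
    # prompt visible (everything past prompt_rindex is replaced by the pad token id).
    data = PromptData.from_dataset("train")
    prompts_only = data[:8].mask_label(pad_token_id=tokenizer.eos_token_id)
    print(prompts_only.input_ids.shape)  # expected: torch.Size([8, 550])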
torchrl/data/llm/reward.py
@@ -0,0 +1,225 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import importlib
+
+ import torch
+ from tensordict import tensorclass
+ from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
+
+ DEFAULT_DATASET = "CarperAI/openai_summarize_comparisons"
+ _has_datasets = importlib.util.find_spec("datasets") is not None
+ _has_tqdm = importlib.util.find_spec("tqdm") is not None
+
+
+ @tensorclass
+ class RewardData:
+     """A dataclass for reward model training."""
+
+     input_ids: torch.Tensor
+     attention_mask: torch.Tensor
+     rewards: torch.Tensor | None = None
+     end_scores: torch.Tensor | None = None
+
+
+ @tensorclass
+ class PairwiseDataset:
+     """Represents a dataset in a pairwise manner (chosen vs rejected).
+
+     Attributes:
+         chosen_data: data to be chosen.
+         rejected_data: corresponding data to be rejected.
+
+     Examples:
+         >>> data = PairwiseDataset.from_dataset("train", max_length=550)
+         >>> print(data)
+         PairwiseDataset(
+             chosen_data=RewardData(
+                 attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 rewards=None,
+                 end_scores=None,
+                 batch_size=torch.Size([92534]),
+                 device=None,
+                 is_shared=False),
+             rejected_data=RewardData(
+                 attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                 rewards=None,
+                 end_scores=None,
+                 batch_size=torch.Size([92534]),
+                 device=None,
+                 is_shared=False),
+             batch_size=torch.Size([92534]),
+             device=None,
+             is_shared=False)
+
+     """
+
+     chosen_data: RewardData
+     rejected_data: RewardData
+
+     @classmethod
+     def from_dataset(
+         cls,
+         split,
+         dataset_name: str | None = None,
+         max_length: int = 550,
+         root_dir: str | None = None,
+         from_disk: bool = False,
+         num_workers: int | None = None,
+     ):
+         """Returns a :class:`PairwiseDataset` from a dataset name.
+
+         Args:
+             split (str): ``"train"`` or ``"valid"`` depending on the data split needed.
+             dataset_name (str, optional): name of the dataset to be processed. Defaults to
+                 ``"CarperAI/openai_summarize_comparisons"``.
+             max_length (int, optional): maximum length of the dataset sequences.
+                 Defaults to 550.
+             root_dir (path, optional): the path where the datasets are stored.
+                 Defaults to ``"$HOME/.cache/torchrl/data"``
+             from_disk (bool, optional): if ``True``, :func:`datasets.load_from_disk`
+                 will be used. Otherwise, :func:`datasets.load_dataset` will be used.
+                 Defaults to ``False``.
+
+         Returns: a :class:`PairwiseDataset` instance containing a memory-mapped
+             version of the required dataset.
+
+         Examples:
+             >>> data = PairwiseDataset.from_dataset("train")
+             >>> print(data)
+             PairwiseDataset(
+                 chosen_data=RewardData(
+                     attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                     input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                     rewards=None,
+                     end_scores=None,
+                     batch_size=torch.Size([92534]),
+                     device=None,
+                     is_shared=False),
+                 rejected_data=RewardData(
+                     attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                     input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+                     rewards=None,
+                     end_scores=None,
+                     batch_size=torch.Size([92534]),
+                     device=None,
+                     is_shared=False),
+                 batch_size=torch.Size([92534]),
+                 device=None,
+                 is_shared=False)
+             >>> # data can be sampled from using regular indexing
+             >>> sub_data = data[:3]
+
+         """
+         if dataset_name is None:
+             dataset_name = DEFAULT_DATASET
+         loader = TokenizedDatasetLoader(
+             split,
+             max_length,
+             dataset_name,
+             TensorDictTokenizer,
+             pre_tokenization_hook,
+             root_dir=root_dir,
+             from_disk=from_disk,
+             num_workers=num_workers,
+         )
+         data = loader.load()
+         maxidx = data.shape[0] // 2
+         batch_size = [maxidx]
+         # this is a zero-copy creation, as we index memmap-arrays without
+         # creating new storage.
+         chosen_data = data[:maxidx]
+         rejected_data = data[maxidx:]
+         return cls(
+             chosen_data=RewardData(
+                 **chosen_data,
+                 batch_size=batch_size,
+             ),
+             rejected_data=RewardData(
+                 **rejected_data,
+                 batch_size=batch_size,
+             ),
+             batch_size=batch_size,
+         )
+
+
+ def pre_tokenization_hook(dataset, min_length=5):
+     """Pre-tokenizer for the reward model (comparison) dataset.
+
+     This function selects all samples where the length of the prompt is
+     sufficient and where the ``"chosen"`` and ``"rejected"`` entries differ.
+
+     Args:
+         dataset (datasets.Dataset): the dataset to process. Should have entries
+             ``"prompt"``, ``"chosen"`` and ``"rejected"``.
+         min_length (int, optional): minimum length of a prompt (in word count).
+
+     Returns: a new ``datasets.Dataset`` with selected prompts under ``"text"``.
+         The first half are the chosen strings and the second the rejected ones,
+         always preceded by the original prompt.
+
+     Examples:
+         >>> from datasets import Dataset
+         >>> data = Dataset.from_dict({
+         ...     "prompt": ["I'm the king"],
+         ...     "chosen": ["It is true, you are the king"],
+         ...     "rejected": ["No, I am the king, you are not"]})
+         >>> print(pre_tokenization_hook(data))
+         Dataset({
+             features: ['text'],
+             num_rows: 2
+         })
+         >>> data = Dataset.from_dict({
+         ...     "prompt": ["I'm the king"],
+         ...     "chosen": ["It is true, you are the king"],
+         ...     "rejected": ["It is true, you are the king"]}) # chosen and rejected match
+         >>> print(pre_tokenization_hook(data))
+         Dataset({
+             features: [],
+             num_rows: 0
+         })
+         >>> data = Dataset.from_dict({
+         ...     "prompt": ["I'm the king"],
+         ...     "chosen": ["Yes"],
+         ...     "rejected": ["No"]}) # chosen and rejected are too short
+         >>> print(pre_tokenization_hook(data))
+         Dataset({
+             features: [],
+             num_rows: 0
+         })
+
+     """
+     if not _has_datasets:
+         raise ImportError(
+             "datasets module couldn't be found. Make sure it is installed in your environment."
+         )
+     from datasets import Dataset as HFDataset
+
+     chosen = []
+     rejected = []
+     if _has_tqdm:
+         from tqdm import tqdm
+
+         pbar = tqdm(dataset)
+     else:
+         pbar = dataset
+     for sample in pbar:
+         prompt = sample["prompt"]
+         chosen_summary = sample["chosen"]
+         rejected_summary = sample["rejected"]
+         if chosen_summary == rejected_summary:
+             continue
+         if (
+             len(chosen_summary.split()) < min_length
+             or len(rejected_summary.split()) < min_length
+         ):
+             continue
+         chosen.append({"text": prompt + "\n" + chosen_summary})
+         rejected.append({"text": prompt + "\n" + rejected_summary})
+
+     return HFDataset.from_list(chosen + rejected)
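
For orientation, the pairwise layout above is typically consumed by scoring the chosen and rejected halves separately and maximizing the margin between them. Below is a minimal, hedged sketch; `reward_model` is a placeholder for any module returning one scalar per sequence and is not provided by this package.

    # Illustrative only (not part of the diff): pairwise ranking loss over PairwiseDataset batches.
    import torch.nn.functional as F

    from torchrl.data.llm.reward import PairwiseDataset

    data = PairwiseDataset.from_dataset("train", max_length=550)
    batch = data[:16]  # zero-copy slice of the memory-mapped storage

    def pairwise_loss(reward_model, batch):
        # `reward_model` is assumed to map (input_ids, attention_mask) to one
        # score per sequence; any sequence classifier would do here.
        chosen_scores = reward_model(
            batch.chosen_data.input_ids, batch.chosen_data.attention_mask
        )
        rejected_scores = reward_model(
            batch.rejected_data.input_ids, batch.rejected_data.attention_mask
        )
        # Bradley-Terry style objective: the chosen summary should outscore the rejected one.
        return -F.logsigmoid(chosen_scores - rejected_scores).mean()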
torchrl/data/llm/topk.py
@@ -0,0 +1,186 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from collections import defaultdict, deque
+ from typing import Any
+
+ import torch
+ from tensordict import NestedKey, TensorDictBase
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.envs.transforms import Transform
+
+
+ class TopKRewardSelector(Transform):
+     """A replay-buffer transform that selects the top-k rewards for each prompt.
+
+     Args:
+         total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
+         topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
+         prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to ("text", "prompt").
+         rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
+         done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
+         verbose (bool): Whether to print verbose information. Defaults to `True`.
+
+     Example:
+         >>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
+         >>> from tensordict import TensorDict, lazy_stack
+         >>> import torch
+         >>> from torchrl.data.llm.topk import TopKRewardSelector
+         >>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
+         >>> rb = ReplayBuffer(
+         ...     storage=LazyStackStorage(50),
+         ...     sampler=SamplerWithoutReplacement,
+         ...     batch_size=5,
+         ... )
+         >>> # Create a tensordict with 50 items, each with 10 dialog turns
+         >>> td = lazy_stack(
+         ...     [
+         ...         TensorDict(
+         ...             {
+         ...                 ("next", "done"): torch.full((1, 1), True),
+         ...                 # Reward for i+5 tokens
+         ...                 ("next", "reward"): torch.full((i + 5, 1), i),
+         ...                 # total of 10 dialogs per prompt
+         ...                 "text": f"Prompt {i // 5}",
+         ...             }
+         ...         )
+         ...         for i in range(50)
+         ...     ]
+         ... )
+         >>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
+         >>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
+         >>> rb.append_transform(topk)
+         >>> for _td in td.chunk(25):
+         ...     rb.extend(_td)
+         >>> # Only wrote top3 of 50 items in 10 groups of 5
+         >>> assert rb.write_count == 30
+         >>> assert len(rb) == 30
+         >>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
+         >>> # 0 and 1 are missing because they're not part of the top-k
+         >>> assert (
+         ...     r3 == torch.tensor(
+         ...         [
+         ...             [4, 4, 4, 4, 4, 4, 4, 4, 4],
+         ...             [3, 3, 3, 3, 3, 3, 3, 3, 0],
+         ...             [2, 2, 2, 2, 2, 2, 2, 0, 0],
+         ...         ]
+         ...     )
+         ... ).all()
+     """
+
+     def __init__(
+         self,
+         total_dialog_turns: int,
+         topk_size: int,
+         prompt_key: NestedKey = ("text", "prompt"),
+         rewards_key: NestedKey = ("next", "reward"),
+         done_key: NestedKey = ("next", "done"),
+         verbose: bool = True,
+     ):
+         super().__init__()
+         self.in_keys = [prompt_key, rewards_key, done_key]
+         self.prompt_key = prompt_key
+         self.rewards_key = rewards_key
+         self.done_key = done_key
+         self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
+         self.total_dialog_turns = total_dialog_turns
+         self.topk_size = topk_size
+         if topk_size > total_dialog_turns:
+             raise ValueError(
+                 f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
+             )
+         self.verbose = verbose
+
+     def forward(self, tensordict: TensorDictBase) -> Any:
+         return tensordict
+
+     def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+         # Tensordict can be any number of dims, but it must contain entire trajectories
+         if tensordict.ndim == 1:
+             # Check how many done states we have
+             num_done = tensordict[self.done_key].sum()
+             if num_done > 1:
+                 done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+                 splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+                 tensordicts = tensordict.split(splits)
+                 tensordicts = [self._inv_call(td) for td in tensordicts]
+                 tensordicts = [td for td in tensordicts if td is not None]
+                 return torch.cat(tensordicts) if tensordicts else None
+             # Then we have a single trajectory
+             if not tensordict[-1][self.done_key].all():
+                 raise RuntimeError("Expected the trajectory to be done.")
+             prompt = tensordict[0][self.prompt_key]
+             if not isinstance(prompt, str):
+                 raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+             self.queues[prompt].append(tensordict)
+             if len(self.queues[prompt]) == self.total_dialog_turns:
+                 if self.verbose:
+                     torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
+                 # Cat is the most robust way to combine the trajs
+                 tds = torch.cat(list(self.queues[prompt]), -1)
+                 # Collect rewards
+                 reward = tds.get(self.rewards_key, as_nested_tensor=True)
+                 reward = self._aggregate_rewards(reward)
+                 # Check if all rewards are equal
+                 if (reward == reward[0]).all():
+                     # If all rewards are equal, we can't select top-k
+                     if self.verbose:
+                         torchrl_logger.warning(
+                             f"All rewards are equal ({reward.unique()=})"
+                         )
+                     return
+                 # Filter out rewards below median
+                 median_reward = reward.median(dim=-1, keepdim=True)[0]
+                 mask = reward > median_reward
+                 filtered_reward = reward[mask]
+                 filtered_indices = mask.nonzero(as_tuple=True)[0]
+                 # Get top-k from filtered rewards
+                 topk_reward = filtered_reward.topk(
+                     k=min(self.topk_size, len(filtered_indices)), dim=-1
+                 )
+                 if not topk_reward.indices.numel():
+                     if self.verbose:
+                         torchrl_logger.warning(
+                             f"No top-{self.topk_size} rewards found ({reward=})"
+                         )
+                     return
+                 # Map back to original indices
+                 selected_indices = filtered_indices[topk_reward.indices]
+                 tds = tds[selected_indices]
+                 if self.verbose:
+                     torchrl_logger.info(
+                         f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
+                     )
+                 return tds
+             return
+         elif tensordict.ndim > 2:
+             # keep the time dim at the end
+             tensordict = tensordict.flatten(0, -2)
+         trajs = tensordict.unbind(-1)
+         # Iterate over the trajectories
+         result = []
+         for traj in trajs:
+             td_out = self._inv_call(traj)
+             if td_out is None:
+                 continue
+             result.append(td_out)
+         if result:
+             return torch.cat(result, -1)
+         return
+
+     def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+         """Aggregate the rewards across the dialog turns.
+
+         `reward` is expected to be a nested tensor.
+
+         The default implementation is to take the mean of the rewards across the dialog turns.
+         """
+         # reward = reward.to_padded_tensor(padding=0.0)
+         if reward.ndim < 2 or reward.ndim > 3:
+             raise ValueError(
+                 f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+             )
+         return reward.mean(dim=-2).squeeze(-1)
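
Since `_aggregate_rewards` is the natural extension point of the transform above, here is a short hypothetical sketch of an alternative aggregation (ranking dialogs by their final-step reward instead of the mean). The subclass name is invented for illustration and is not part of the package.

    # Hypothetical subclass (not part of the diff): rank dialogs by last-step reward.
    import torch

    from torchrl.data.llm.topk import TopKRewardSelector

    class LastStepRewardSelector(TopKRewardSelector):
        def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
            # `reward` arrives as a nested tensor of per-token rewards, one row per dialog.
            padded = reward.to_padded_tensor(padding=0.0)
            lengths = torch.tensor(
                [r.shape[0] for r in reward.unbind(0)], device=padded.device
            )
            # Keep the reward of the last valid token of each dialog.
            last = padded[torch.arange(padded.shape[0], device=padded.device), lengths - 1]
            return last.squeeze(-1)

    selector = LastStepRewardSelector(total_dialog_turns=5, topk_size=3)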