torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/chat.py
@@ -0,0 +1,730 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import functools
+
+ from collections.abc import Callable
+
+ from typing import Any, Literal, TYPE_CHECKING
+
+ import torch
+ from tensordict import lazy_stack, TensorDictBase
+ from tensordict.utils import _zip_strict
+ from torch.utils.data import DataLoader
+ from torchrl.data import Composite, NonTensor
+ from torchrl.data.llm.history import History
+ from torchrl.envs import EnvBase, TransformedEnv
+
+ from torchrl.envs.llm.transforms.dataloading import (
+     DataLoadingPrimer,
+     RayDataLoadingPrimer,
+ )
+ from torchrl.modules.llm.policies.common import ChatHistory, Text, Tokens
+
+ if TYPE_CHECKING:
+     import transformers
+
+
+ def _default_collate_fn(batch):
+     # We want to rename the "text" key to "query"
+     # otherwise it will conflict with the "text" key in the tensordict returned by TorchRL components
+     if isinstance(batch, dict) and "text" in batch:
+         batch["query"] = batch.pop("text")
+     elif isinstance(batch, list):
+         for item in batch:
+             if "text" in item:
+                 item["query"] = item.pop("text")
+     return batch
+
+
+ class ChatEnv(EnvBase):
+     r"""A chat-based environment for LLMs, designed as a blank canvas for conversation and RL.
+
+     This environment is designed to work seamlessly with both :class:`~torchrl.modules.llm.policies.TransformersWrapper` and
+     :class:`~torchrl.modules.llm.policies.vLLMWrapper`. It provides the fundamental structure for managing conversation state
+     using the :class:`~torchrl.data.llm.History` format (or, alternatively, tokens or text), but is intentionally minimal to allow
+     maximum flexibility through transforms.
+
+     Core Functionality
+         The environment operates in three main modes:
+
+         - **History mode**: Uses :class:`~torchrl.data.llm.History` objects for conversation management
+         - **Text mode**: Uses simple text strings for input/output
+         - **Tokens mode**: Uses tokenized data for input/output
+
+     Reset Operation
+         During reset, the environment:
+
+         1. Takes input text from the `data_key` (default: `"query"`) in the tensordict
+         2. Creates a :class:`~torchrl.data.llm.History` object with the user's message
+         3. Optionally prepends a system prompt if provided
+         4. Formats the conversation according to the selected input mode (history, text, or tokens)
+         5. Returns the formatted prompt ready for the LLM
+
+     Step Operation
+         During step, the environment:
+
+         1. Takes the LLM's response (containing both prompt and generated text)
+         2. Extracts the full conversation history
+         3. Prepares the next prompt by setting the full history as the new prompt
+         4. Returns the updated conversation state
+
+     This design enables natural multi-turn conversations where each step extends the conversation
+     history, making it ideal for dialogue systems and reinforcement learning applications.
+
+     Integration with Transforms
+         ChatEnv is designed to be extended with transforms that add specific capabilities:
+
+         - **Reward computation**: :class:`~torchrl.envs.llm.transforms.KLRewardTransform` for KL divergence rewards
+         - **Tool execution**: :class:`~torchrl.envs.llm.transforms.PythonInterpreter` for Python code execution
+         - **Data loading**: :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` for loading prompts from datasets
+         - **Thinking prompts**: :class:`~torchrl.envs.llm.transforms.AddThinkingPrompt` for chain-of-thought reasoning
+
+     Keyword Args:
+         input_mode (Literal["history", "text", "tokens"]): The mode of input to the environment.
+             Defaults to `"history"`.
+         batch_size (torch.Size): Expected batch size of the input. Defaults to `(1,)` (null batch sizes such as `()`
+             are not recommended as they don't play well with generators).
+         system_prompt (str, optional): An optional `"system"` prompt string to use during reset calls.
+             Defaults to `None`.
+         tokenizer (transformers.PreTrainedTokenizer, optional): A tokenizer that will be used to tokenize the text.
+             Defaults to `None`.
+         template_kwargs (dict[str, any], optional): Keyword arguments passed to :meth:`~torchrl.data.llm.History.apply_chat_template`.
+             Defaults to `None`.
+         system_role (str, optional): The role of the system (at reset time). Defaults to `"system"`.
+         user_role (str, optional): The role of the user (at reset time). Defaults to `"user"`.
+         policy_role (str, optional): The role of the policy/assistant. Defaults to `"assistant"`.
+         data_key (str, optional): The key of the data input to the env at reset time (from dataloader). Defaults to `"query"`.
+         device (torch.device, optional): The device to use for computations. Defaults to `None`.
+
+     Methods:
+         reset (TensorDict): Resets the state of the environment. A tensordict or equivalent with a `"query"` entry
+             (originating from the dataloader) must be passed. This key name is defined as a class attribute `data_key`.
+         step (TensorDict): Makes a step in the environment. A tensordict or equivalent with the LLM's response must be passed.
+             The response key is defined as a class attribute `response_key`.
+
+     .. seealso:: To see examples of a `ChatEnv` in action, see :class:`~torchrl.envs.llm.chat.DatasetChatEnv`,
+         :class:`~torchrl.envs.llm.GSM8KEnv` and :class:`~torchrl.envs.llm.IFEvalEnv`.
+
+     Examples:
+         >>> from torchrl.envs.llm import ChatEnv
+         >>> from torchrl.data.llm import History
+         >>> from tensordict import TensorDict
+         >>>
+         >>> # Create a basic chat environment
+         >>> env = ChatEnv(
+         ...     system_prompt="You are a helpful assistant.",
+         ...     input_mode="history"
+         ... )
+         >>>
+         >>> # Reset with a user query
+         >>> reset_data = TensorDict({"query": "Hello, how are you?"}, batch_size=(1,))
+         >>> obs = env.reset(reset_data)
+         >>> print(obs["history"].prompt)  # History with system prompt + user message
+         >>>
+         >>> # Simulate LLM response and step
+         >>> response_data = TensorDict({
+         ...     "history": History.from_chats([[
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "Hello, how are you?"},
+         ...         {"role": "assistant", "content": "I'm doing well, thank you!"}
+         ...     ]])
+         ... }, batch_size=(1,))
+         >>> next_obs = env.step(response_data)
+         >>> print(next_obs["history"].prompt)  # Full conversation history
+
+     """
+
+     # Nested key corresponding to the text input to the LLM
+     text_key = ("text", "prompt")
+     # Nested key corresponding to the response from the LLM
+     response_key = ("text", "response")
+     # Nested key corresponding to the data input to the env at reset time (from dataloader)
+     data_key = "query"
+
+     def __init__(
+         self,
+         *,
+         input_mode: Literal["history", "text"] = "history",
+         batch_size: tuple | torch.Size | None = None,
+         system_prompt: str | None = None,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         template_kwargs: dict[str, Any] | None = None,
+         system_role: str = "system",
+         user_role: str = "user",
+         policy_role: str | None = "assistant",
+         data_key: str | None = None,
+         device: torch.device | None = None,
+     ):
+         self.input_mode = input_mode
+         if batch_size is None:
+             batch_size = (1,)
+         if isinstance(batch_size, int):
+             batch_size = (batch_size,)
+         if isinstance(batch_size, list):
+             batch_size = torch.Size(batch_size)
+         if batch_size == ():
+             raise ValueError(f"{type(self).__name__} must have at least one dimension")
+         if data_key is not None:
+             self.data_key = data_key
+         super().__init__(batch_size=batch_size, device=device)
+         self.batch_size = batch_size
+
+         self.system_prompt = system_prompt
+
+         if template_kwargs is None:
+             template_kwargs = {}
+         self.template_kwargs = template_kwargs
+
+         self.system_role = system_role
+         self.user_role = user_role
+         self.policy_role = policy_role
+         self.tokenizer = tokenizer
+
+         self._make_specs()
+
+     def _make_specs(self):
+         if self.input_mode == "history":
+             self._make_specs_history()
+         elif self.input_mode == "text":
+             self._make_specs_text()
+         elif self.input_mode == "tokens":
+             self._make_specs_tokens()
+         else:
+             raise ValueError(f"Invalid input mode: {self.input_mode}")
+
+     def _make_specs_history(self):
+         # we output prompt
+         self.full_observation_spec = Composite(
+             history=ChatHistory.default_spec(shape=self.batch_size, keys=["prompt"]).to(
+                 self.device
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         # We receive prompt, response and full
+         self.full_action_spec = Composite(
+             history=ChatHistory.default_spec(shape=self.batch_size, keys=["full"]).to(
+                 self.device
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         self.full_state_spec = Composite(
+             {
+                 self.data_key: NonTensor(
+                     example_data="a string", shape=self.batch_size, device=self.device
+                 )
+             },
+             shape=self.batch_size,
+             device=self.device,
+         )
+
+     def _make_specs_text(self):
+         # we output prompt
+         self.full_observation_spec = Composite(
+             text=Text.default_spec(shape=self.batch_size, keys=["prompt"]).to(
+                 self.device
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         # We receive prompt, response and full
+         self.full_action_spec = Composite(
+             text=Text.default_spec(shape=self.batch_size, keys=["full"]).to(
+                 self.device
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         self.full_state_spec = Composite(
+             {
+                 self.data_key: NonTensor(
+                     example_data="a string", shape=self.batch_size, device=self.device
+                 )
+             },
+             shape=self.batch_size,
+             device=self.device,
+         )
+
+     def _make_specs_tokens(self):
+         # we output prompt
+         self.full_observation_spec = Composite(
+             tokens=Tokens.default_spec(shape=self.batch_size, keys=["prompt"]).to(
+                 self.device
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         # We receive prompt, response and full
+         self.full_action_spec = Composite(
+             tokens=Tokens.default_spec(shape=self.batch_size, keys=["full"]).to(
+                 self.device
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         self.full_state_spec = Composite(
+             {
+                 self.data_key: NonTensor(
+                     example_data="a string", shape=self.batch_size, device=self.device
+                 )
+             },
+             shape=self.batch_size,
+             device=self.device,
+         )
+
+     @classmethod
+     def from_dataloader(
+         cls,
+         dataloader: DataLoader,
+         *,
+         repeats: int | None = None,
+         device: torch.device | None = None,
+         group_repeats: bool = False,
+         batch_size: tuple | torch.Size | None = None,
+         primers: Composite | None = None,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         template_kwargs: dict[str, Any] | None = None,
+         input_mode: Literal["history", "text", "tokens"] = "history",
+         data_key: str | None = None,
+         system_prompt: str | None = None,
+     ):
+         """Create a chat environment from a dataloader.
+
+         Args:
+             dataloader (DataLoader): The dataloader to use.
+
+         Keyword Args:
+             repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+                 based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+             device (torch.device | None, optional): The device to use for computations. Defaults to None.
+             group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+             batch_size (tuple | torch.Size | None, optional): The batch size for data loading. Defaults to `1`.
+             primers (Composite | None, optional): The primers to use for data loading. Defaults to `None`.
+             tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+             template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+             input_mode (Literal["history", "text", "tokens"], optional): The mode of input to the environment. Defaults to `"history"`.
+             data_key (str, optional): The spec of the data returned by the dataloader (or better, its collate_fn).
+                 Defaults to `None` (automatically determined based on the input_mode).
+             system_prompt (str | None, optional): The system prompt to use for the environment. Defaults to `None`.
+
+         Returns:
+             DatasetChatEnv: The chat environment.
+         """
+         return DatasetChatEnv.from_dataloader(
+             dataloader=dataloader,
+             repeats=repeats,
+             device=device,
+             group_repeats=group_repeats,
+             batch_size=batch_size,
+             primers=primers,
+             tokenizer=tokenizer,
+             template_kwargs=template_kwargs,
+             input_mode=input_mode,
+             data_key=data_key,
+             system_prompt=system_prompt,
+         )
+
+     # def _post_step_mdp_hooks(self, tensordict: TensorDictBase) -> TensorDictBase:
+     #     """Allows modification of the tensordict after the step_mdp."""
+     #     if self.input_mode == "history":
+     #         tensordict.exclude(
+     #             ("history", "response"), ("history", "full"), inplace=True
+     #         )
+     #     if self.input_mode in ("text", "history"):
+     #         tensordict.exclude(("text", "response"), ("text", "full"), inplace=True)
+     #     if self.input_mode in ("tokens", "history", "text"):
+     #         tensordict.exclude(("tokens", "response"), ("tokens", "full"), inplace=True)
+     #     if "log_probs" in tensordict.keys():
+     #         tensordict.exclude(
+     #             ("log_probs", "response"), ("log_probs", "full"), inplace=True
+     #         )
+     #     return tensordict
+
+     def _step(self, tensordict):
+         if self.input_mode == "history":
+             return self._step_history(tensordict)
+         if self.input_mode in ("text", "history"):
+             return self._step_text(tensordict)
+         if self.input_mode in ("tokens", "history", "text"):
+             return self._step_tokens(tensordict)
+         else:
+             raise ValueError(f"Invalid input mode: {self.input_mode}")
+
+     def _step_history(self, tensordict):
+         """Step the environment in history mode."""
+         # get history from tensordict
+         chat_history: ChatHistory = tensordict["history"]
+         # prompt = chat_history.prompt
+         full = chat_history.full
+         # response = chat_history.response
+         empty_td = tensordict.empty(device=self.device)
+         # Old full will be new prompt - can be modified at will
+         new_history = ChatHistory(prompt=full)
+         empty_td.set("history", new_history)
+         return empty_td
+
+     def _step_text(self, tensordict):
+         """Step the environment in text mode."""
+         # get text from tensordict
+         text: Text = tensordict["text"]
+         full = text.full
+         empty_td = tensordict.empty(device=self.device)
+         new_history = Text(prompt=full)
+         empty_td.set("text", new_history)
+         return empty_td
+
+     def _step_tokens(self, tensordict):
+         """Step the environment in tokens mode."""
+         # get tokens from tensordict
+         tokens: Tokens = tensordict["tokens"]
+         full = tokens.full
+         empty_td = tensordict.empty(device=self.device)
+         new_history = Tokens(prompt=full)
+         empty_td.set("tokens", new_history)
+         return empty_td
+
+     def _reset(self, tensordict: TensorDictBase | None, **kwargs):
+         if tensordict is None:
+             raise RuntimeError(
+                 f"{type(self).__name__} expects a tensordict as input. Got `None`."
+             )
+         # Find the total text
+         content = tensordict.get(self.data_key)
+         if content is None:
+             raise RuntimeError(
+                 f"{type(self).__name__} expects a tensordict with a {self.data_key} key, got {tensordict.keys()}"
+             )
+         if content.batch_size != self.batch_size:
+             for s in reversed(self.batch_size):
+                 content = [content for _ in range(s)]
+
+         # FIXME: Assume the text is not formatted and this is just content
+         role = self.user_role
+         for s in reversed(self.batch_size):
+             role = [role for _ in range(s)]
+         history = History(role=role, content=content, batch_size=self.batch_size)
+         if self.system_prompt is not None:
+             system_role = self.system_role
+             history_system = History(
+                 role=system_role,
+                 content=self.system_prompt,
+             )
+             for s in reversed(self.batch_size):
+                 history_system = lazy_stack([history_system for _ in range(s)])
+             history = lazy_stack([history_system, history], -1)
+         else:
+             history = history.unsqueeze(-1)
+
+         # Now that we have the history, call the specific reset method
+         if self.input_mode == "history":
+             return (
+                 self._reset_history(tensordict, history)
+                 .update(tensordict)
+                 .to_lazystack(0)
+             )
+         elif self.input_mode == "text":
+             return (
+                 self._reset_text(tensordict, history).update(tensordict).to_lazystack(0)
+             )
+         elif self.input_mode == "tokens":
+             return (
+                 self._reset_tokens(tensordict, history)
+                 .update(tensordict)
+                 .to_lazystack(0)
+             )
+         else:
+             raise ValueError(f"Invalid input mode: {self.input_mode}")
+
+     def _reset_history(self, tensordict: TensorDictBase, history: History):
+         # Simplest case: history is the prompt
+         chat_history = ChatHistory._from_tensordict(
+             tensordict.empty(device=self.device)
+         )
+         chat_history.prompt = history
+         return tensordict.empty(device=self.device).set("history", chat_history)
+
+     def _reset_text(self, tensordict: TensorDictBase, history: History):
+         # We need to parse the history to a text
+         text = history.apply_chat_template(
+             tokenizer=self.tokenizer, add_generation_prompt=True, **self.template_kwargs
+         )
+         txt = Text._from_tensordict(tensordict.empty())
+         txt.prompt = text
+         result = tensordict.empty(device=self.device).set("text", txt)
+         return result
+
+     def _reset_tokens(self, tensordict: TensorDictBase, history: History):
+         # We need to parse the history to a tokens
+         tokens = history.apply_chat_template(
+             tokenizer=self.tokenizer,
+             add_generation_prompt=True,
+             return_tensors="pt",
+             return_dict=True,
+             **self.template_kwargs,
+         )
+         tokens_obj = Tokens._from_tensordict(tensordict.empty().to_lazystack(0))
+         for to, tok in _zip_strict(tokens_obj.unbind(0), tokens["input_ids"]):
+             to.prompt = tok
+         result = tensordict.empty(device=self.device).set("tokens", tokens_obj)
+         return result
+
+     def _set_seed(self, seed):
+         return
+
+
+ class DatasetChatEnv(TransformedEnv):
+     """Base class for chat environment with queries pulled from a dataset.
+
+     Typical usage include RLHF (Reinforcement Learning from Human feedback) or RLVR (Reinforcement learning with Verifiable rewards).
+
+     Keyword Args:
+         dataset (str): The name of the dataset.
+         shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
+         name (str, optional): name of the dataset configuration.
+         split (str, optional): the split to use (usually from `"train"`, `"val"` or `"test"`). Defaults to `None` (no split).
+         num_envs (int, optional): The number of environments to create. Defaults to `1`.
+         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+         batch_size_dl (int, optional): The batch size for data loading. Defaults to `1`.
+         seed (int | None, optional): The random seed for reproducibility. If `None`, a random seed is used. Defaults to `None`.
+         group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+         tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+
+             .. note:: It is recommended to pass a tokenizer to the environment. This is an easy way to ensure that the
+                 template applied to the chat history is consistent with the format required by the model.
+
+         device (torch.device | None, optional): The device to use for computations. Defaults to None.
+         template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+         apply_template (bool | None, optional): Whether to apply the template to the text. Defaults to `False`.
+         collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
+             collate function is used that renames the `"text"` key to `"query"` to avoid conflicts with the `"text"` key
+             in the tensordict returned by TorchRL components. Defaults to `None`.
+         input_mode (Literal["history", "text", "tokens"], optional): The mode of input to the environment. Defaults to `"history"`.
+         data_key (str, optional): The spec of the data returned by the dataloader (or better, its collate_fn).
+             Defaults to `None` (automatically determined based on the input_mode).
+         system_prompt (str | None, optional): The system prompt to use for the environment. Defaults to `None`.
+         ray_backend (bool, optional): Whether to use the Ray backend for data loading. Defaults to `False`.
+             Using this backend allows for explicit resource control and avoids serialization issues, as well as
+             sharing the same dataloader across multiple environments and actors.
+         dataloader_actor_name (str | None, optional): Name of the Ray actor to use for data loading.
+             Ignored if `ray_backend` is `None`.
+
+     .. seealso:: `DatasetChatEnv` is a thin wrapper around :class:`~torchrl.envs.llm.ChatEnv` bucketed with a
+         :class:`~torchrl.envs.llm.DataLoadingPrimer` transform. See these two classes for more insight on data format
+         and functionality.
+
+     .. seealso:: Examples of `DatasetChatEnv` include :class:`~torchrl.envs.llm.GSM8KEnv` and :class:`~torchrl.envs.llm.IFEvalEnv`.
+
+     """
+
+     SYSTEM_PROMPT: str | None = None
+
+     def __init__(
+         self,
+         *,
+         dataset: str,
+         shuffle: bool = True,
+         name: str | None = None,
+         split: Literal["train", "val", "test"] | None = None,
+         num_envs: int = 1,
+         repeats: int | None = None,
+         batch_size_dl: int = 1,
+         seed: int | None = None,
+         group_repeats: bool = False,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         device: torch.device | None = None,
+         template_kwargs: dict[str, Any] | None = None,
+         apply_template: bool | None = False,
+         collate_fn: Callable[[Any], Any] | None = None,
+         input_mode: Literal["history", "text", "tokens"] = "history",
+         data_key: str | None = None,
+         primers: Composite | None = None,
+         system_prompt: str | None = None,
+         ray_backend: bool = False,
+         dataloader_actor_name: str | None = None,
+     ):
+         from tensordict import list_to_stack
+
+         if not list_to_stack():
+             raise RuntimeError(
+                 "list_to_stack() must return True. Use LIST_TO_STACK=1 or `tensordict.set_list_to_stack(True).set()` "
+                 "at the beginning of the script."
+             )
+
+         batch_size = (num_envs,)
+
+         dataloader_factory = functools.partial(
+             self._dataloader_factory,
+             dataset=dataset,
+             name=name,
+             split=split,
+             seed=seed,
+             batch_size_dl=batch_size_dl,
+             shuffle=shuffle,
+             collate_fn=collate_fn,
+         )
+         self._from_dataloader(
+             self,
+             dataloader=None,
+             dataloader_factory=dataloader_factory,
+             ray_backend=ray_backend,
+             repeats=repeats,
+             device=device,
+             group_repeats=group_repeats,
+             batch_size=batch_size,
+             primers=primers,
+             tokenizer=tokenizer,
+             template_kwargs=template_kwargs,
+             input_mode=input_mode,
+             data_key=data_key,
+             system_prompt=system_prompt,
+             dataloader_actor_name=dataloader_actor_name,
+         )
+
+     @staticmethod
+     def _dataloader_factory(
+         dataset, name, split, seed, batch_size_dl, shuffle, collate_fn
+     ):
+         from datasets import load_dataset
+
+         dataset_obj = load_dataset(dataset, name)
+         if split is None and "train" in dataset_obj:
+             split = "train"
+         if split is not None:
+             dataset_obj = dataset_obj[split]
+         # Env
+         if seed is None:
+             seed = int(torch.empty((), dtype=torch.int64).random_().item())
+         generator = torch.Generator(device=torch.get_default_device())
+         generator.manual_seed(seed)
+
+         dataloader = DataLoader(  # noqa: TOR401
+             dataset_obj,
+             batch_size=batch_size_dl,
+             shuffle=shuffle,
+             collate_fn=collate_fn if collate_fn is not None else _default_collate_fn,
+             generator=generator,
+         )
+         return dataloader
+
+     @classmethod
+     def from_dataloader(
+         cls,
+         dataloader: DataLoader,
+         *,
+         repeats: int | None = None,
+         device: torch.device | None = None,
+         group_repeats: bool = False,
+         batch_size: tuple | torch.Size | None = None,
+         primers: Composite | None = None,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         template_kwargs: dict[str, Any] | None = None,
+         input_mode: Literal["history", "text", "tokens"] = "history",
+         data_key: str | None = None,
+         system_prompt: str | None = None,
+     ):
+         """Create a chat environment from a dataloader.
+
+         Args:
+             dataloader (DataLoader): The dataloader to use.
+
+         Keyword Args:
+             repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+                 based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+             device (torch.device | None, optional): The device to use for computations. Defaults to None.
+             group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+             batch_size (tuple | torch.Size | None, optional): The batch size for data loading. Defaults to `1`.
+             primers (Composite | None, optional): The primers to use for data loading. Defaults to `None`.
+             tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+             template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+             input_mode (Literal["history", "text", "tokens"], optional): The mode of input to the environment. Defaults to `"history"`.
+             data_key (str, optional): The spec of the data returned by the dataloader (or better, its collate_fn).
+                 Defaults to `None` (automatically determined based on the input_mode).
+             system_prompt (str | None, optional): The system prompt to use for the environment. Defaults to `None`.
+
+         Returns:
+             ChatEnv: The chat environment.
+         """
+         self = cls.__new__(cls)
+         return cls._from_dataloader(
+             self,
+             dataloader=dataloader,
+             repeats=repeats,
+             device=device,
+             group_repeats=group_repeats,
+             batch_size=batch_size,
+             primers=primers,
+             tokenizer=tokenizer,
+             template_kwargs=template_kwargs,
+             input_mode=input_mode,
+             data_key=data_key,
+             system_prompt=system_prompt,
+         )
+
+     @classmethod
+     def _from_dataloader(
+         cls,
+         self,
+         dataloader=None,
+         *,
+         dataloader_factory=None,
+         repeats: int | None = None,
+         device: torch.device | None = None,
+         group_repeats: bool = False,
+         batch_size: tuple | torch.Size | None = None,
+         primers: Composite | None = None,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         template_kwargs: dict[str, Any] | None = None,
+         input_mode: Literal["history", "text", "tokens"] = "history",
+         data_key: str | None = None,
+         system_prompt: str | None = None,
+         ray_backend: bool = False,
+         dataloader_actor_name: str | None = None,
+     ):
+         if ray_backend:
+             dl_cls = functools.partial(
+                 RayDataLoadingPrimer, actor_name=dataloader_actor_name
+             )
+         else:
+             if dataloader_actor_name is not None:
+                 raise ValueError(
+                     "dataloader_actor_name must be None if ray_backend is False"
+                 )
+             dl_cls = DataLoadingPrimer
+         primer = dl_cls(
+             dataloader=dataloader,
+             dataloader_factory=dataloader_factory,
+             repeats=repeats,
+             device=device,
+             group_repeats=group_repeats,
+             batch_size=batch_size,
+             primers=primers,
+         )
+         env_base = ChatEnv(
+             batch_size=batch_size,
+             system_prompt=cls.SYSTEM_PROMPT if system_prompt is None else system_prompt,
+             tokenizer=tokenizer,
+             template_kwargs=template_kwargs,
+             input_mode=input_mode,
+             data_key=data_key,
+             device=device,
+         )
+         TransformedEnv.__init__(self, env_base, primer)
+         return self
+
+     def reset_dataloader(self):
+         """Reset the dataloader.
+
+         This is useful when the dataloader is not infinite and we want to reset it.
+
+         Returns:
+             self: The environment itself.
+         """
+         if hasattr(self.transform, "__getitem__"):
+             self.transform[0].reset_dataloader()
+         return self
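
For orientation, below is a minimal usage sketch of the `ChatEnv.from_dataloader` entry point added in this file. The toy dataset, prompts, and batch size are illustrative assumptions and are not part of the package; only the imports, the method signature, and the `list_to_stack` requirement come from the code shown above.

from tensordict import set_list_to_stack
from torch.utils.data import DataLoader
from torchrl.envs.llm import ChatEnv

set_list_to_stack(True).set()  # DatasetChatEnv requires list_to_stack() to be True (see its __init__ above)

# Stand-in dataset (hypothetical); the collate_fn yields a dict with a "query" entry,
# mirroring the shape that _default_collate_fn produces after renaming "text" to "query".
data = [{"query": "What is 2 + 2?"}, {"query": "Name a prime number."}]
dataloader = DataLoader(
    data,
    batch_size=1,
    collate_fn=lambda batch: {"query": [item["query"] for item in batch]},
)

env = ChatEnv.from_dataloader(  # returns a DatasetChatEnv: ChatEnv + DataLoadingPrimer
    dataloader,
    batch_size=(1,),
    input_mode="history",
    system_prompt="You are a helpful assistant.",
)
reset = env.reset()  # the DataLoadingPrimer is expected to inject the next "query" here
print(reset["history"].prompt)  # History with the system prompt and the loaded user message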