torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,4 @@
1
+ # Datasets
2
+
3
+ This folder contains utils for specific datasets, such as reward parsers or pre-build
4
+ environments.
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
8
+ from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
9
+
10
+ __all__ = [
11
+ "make_gsm8k_env",
12
+ "GSM8KPrepareQuestion",
13
+ "GSM8KEnv",
14
+ "IFEvalEnv",
15
+ "IFEvalData",
16
+ "IfEvalScorer",
17
+ ]
@@ -0,0 +1,353 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import warnings
8
+ from collections.abc import Callable
9
+ from typing import Any, Literal, TYPE_CHECKING
10
+
11
+ import torch
12
+ from tensordict import NestedKey, TensorDict, TensorDictBase
13
+ from tensordict.tensorclass import NonTensorData, NonTensorStack
14
+ from tensordict.utils import _zip_strict
15
+ from torch.utils.data import DataLoader
16
+ from torchrl.data import TensorSpec
17
+ from torchrl.envs import StepCounter, Transform
18
+
19
+ from torchrl.envs.llm.chat import DatasetChatEnv
20
+
21
+ from torchrl.envs.llm.envs import LLMEnv
22
+ from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
23
+
24
+ if TYPE_CHECKING:
25
+ import transformers
26
+
27
+ BASE_PROMPT = (
28
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
29
+ "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
30
+ "The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively, "
31
+ "i.e., <think>reasoning process here</think> <answer>answer here</answer>. User: %s. Assistant: <think>"
32
+ )
33
+
34
+
35
+ class GSM8KPrepareQuestion(Transform):
36
+ """A transform to prepare the prompt when using GSM8k within an LLMEnv."""
37
+
38
+ def __init__(
39
+ self,
40
+ in_keys: list[NestedKey] | None = None,
41
+ out_keys: list[NestedKey] | None = None,
42
+ ):
43
+ if in_keys is None:
44
+ in_keys = ["text"]
45
+ if out_keys is None:
46
+ out_keys = list(in_keys)
47
+ super().__init__(in_keys, out_keys)
48
+
49
+ def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
50
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
51
+ string = tensordict.get(in_key)
52
+ tensordict.set(out_key, self._modify_str(string))
53
+ return tensordict
54
+
55
+ def _modify_str(
56
+ self, obs: str | list[str] | NonTensorData | NonTensorStack
57
+ ) -> NonTensorData | NonTensorStack:
58
+ if isinstance(obs, NonTensorData):
59
+ return self._modify_str(obs.data)
60
+ if isinstance(obs, NonTensorStack):
61
+ return self._modify_str(obs.tolist())
62
+ if isinstance(obs, list):
63
+ return NonTensorStack(*[BASE_PROMPT % obs for obs in obs])
64
+ return NonTensorData(BASE_PROMPT % obs)
65
+
66
+ def _apply_transform(self, obs: torch.Tensor) -> None:
67
+ return obs
68
+
69
+ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
70
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
71
+ if out_key != in_key:
72
+ observation_spec[out_key] = observation_spec[in_key].clone()
73
+ return observation_spec
74
+
75
+
76
+ def _collate_fn(batch):
77
+ batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
78
+ batch.rename_key_("question", "query")
79
+ return batch
80
+
81
+
82
+ def make_gsm8k_env(
83
+ dataset: str = "openai/gsm8k",
84
+ num_envs: int = 1,
85
+ repeats: int | None = None,
86
+ batch_size_dl: int = 1,
87
+ seed: int | None = None,
88
+ group_repeats: bool = False,
89
+ tokenizer: transformers.PretrainedTokenizer | None = None, # noqa
90
+ ):
91
+ """A builder for an LLMEnv-based GSM8K environment.
92
+
93
+ .. note:: Prefer `torchrl.envs.llm.GSM8KEnv` to interact with this dataset.
94
+
95
+ """
96
+ warnings.warn("This constructor is to be deprecated. Use GSM8KEnv instead.")
97
+ from datasets import load_dataset
98
+
99
+ dataset = load_dataset(dataset, "main")
100
+ train_dataset = dataset["train"]
101
+
102
+ # Env
103
+ if seed is None:
104
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
105
+ generator = torch.Generator(device=torch.get_default_device())
106
+ generator.manual_seed(seed)
107
+
108
+ dataloader = DataLoader( # noqa: TOR401
109
+ train_dataset,
110
+ batch_size=batch_size_dl,
111
+ shuffle=True,
112
+ collate_fn=_collate_fn,
113
+ generator=generator,
114
+ )
115
+ env = LLMEnv.from_dataloader(
116
+ dataloader=dataloader,
117
+ # tokenizer=tokenizer,
118
+ from_text=True,
119
+ batch_size=(num_envs,),
120
+ repeats=repeats,
121
+ group_repeats=group_repeats,
122
+ # assign_reward=True,
123
+ )
124
+ env.insert_transform(0, GSM8KPrepareQuestion())
125
+
126
+ # Finally, we want the env to stop after the first step
127
+ env.append_transform(StepCounter(max_steps=1))
128
+
129
+ if tokenizer is not None:
130
+ env.append_transform(
131
+ GSM8KRewardParser(
132
+ tokenizer=tokenizer,
133
+ input_mode="text",
134
+ in_keys=["text_response", "answer"],
135
+ )
136
+ )
137
+ else:
138
+ warnings.warn("No tokenizer specified - reward will not be assigned.")
139
+
140
+ return env
141
+
142
+
143
+ class GSM8KEnv(DatasetChatEnv):
144
+ r"""GSM8K dataset environment.
145
+
146
+ Keyword Args:
147
+ dataset (str, optional): The name of the dataset. Defaults to `"gsm8k"`.
148
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
149
+ num_envs (int, optional): The number of environments to create. Defaults to `1`.
150
+ repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
151
+ based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
152
+ batch_size_dl (int, optional): The batch size for data loading. Defaults to `1`.
153
+ seed (int | None, optional): The random seed for reproducibility. If `None`, a random seed is used. Defaults to `None`.
154
+ group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
155
+ tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
156
+
157
+ .. note:: It is recommended to pass a tokenizer to the environment. This is an easy way to ensure that the
158
+ template applied to the chat history is consistent with the format required by the model.
159
+
160
+ device (torch.device | None, optional): The device to use for computations. Defaults to None.
161
+ template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
162
+ apply_template (bool | None, optional): Whether to apply the template to the text. Defaults to `False`.
163
+ compute_reward (bool, optional): Whether to compute rewards. Defaults to `True`.
164
+ collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
165
+ collate function is used. Defaults to `None`.
166
+ max_steps (int, optional): The maximum number of steps allowed in an episode. Defaults to `1`.
167
+ input_mode (Literal["history", "text", "tokens"], optional): The mode of input to use. Defaults to `"history"`.
168
+ ray_backend (bool, optional): Whether to use the Ray backend for data loading. Defaults to `False`.
169
+ Using this backend allows for explicit resource control and avoids serialization issues, as well as
170
+ sharing the same dataloader across multiple environments and actors.
171
+ dataloader_actor_name (str, optional): Name of the Ray actor to use for data loading.
172
+ Defaults to `"gsm8k_dataloader"`.
173
+
174
+ Examples:
175
+ >>> import transformers
176
+ >>> from torchrl.envs.llm.datasets.gsm8k import GSM8KEnv
177
+ >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
178
+ >>> env = GSM8KEnv(tokenizer=tokenizer, apply_template=True)
179
+ >>> r = env.reset()
180
+ >>> assert "history" in r
181
+ >>> # We have an instruction step (role="system") and a question (role="user")
182
+ >>> assert r["history"].shape == (1, 2)
183
+ >>> assert "text" in r
184
+ >>> r = r.clone()
185
+ >>> print(r)
186
+ LazyStackedTensorDict(
187
+ fields={
188
+ answer: NonTensorStack(
189
+ ['Adam bought 3 sandwiches, so he paid 3 * 3 = $<<...,
190
+ batch_size=torch.Size([1]),
191
+ device=None),
192
+ done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
193
+ history: History(
194
+ content=NonTensorStack(
195
+ [['A conversation between User and Assistant. The ...,
196
+ batch_size=torch.Size([1, 2]),
197
+ device=None),
198
+ role=NonTensorStack(
199
+ [['system', 'user']],
200
+ batch_size=torch.Size([1, 2]),
201
+ device=None),
202
+ batch_size=torch.Size([1, 2]),
203
+ device=None,
204
+ is_shared=False),
205
+ step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
206
+ terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
207
+ text: NonTensorStack(
208
+ ['<|im_start|>system\nA conversation between User ...,
209
+ batch_size=torch.Size([1]),
210
+ device=None),
211
+ truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
212
+ exclusive_fields={
213
+ },
214
+ batch_size=torch.Size([1]),
215
+ device=None,
216
+ is_shared=False,
217
+ stack_dim=0)
218
+ >>> response = "<think>First, calculate the total number of snakes in the breeding balls. There are 3 breeding balls with 8 snakes each, so 3 * 8 = 24 snakes. Next, calculate the number of snakes in the additional pairs. There are 6 pairs of snakes, and each pair has 2 snakes, so 6 * 2 = 12 snakes. Finally, add the number of snakes from the breeding balls and the additional pairs: 24 + 12 = 36 snakes.</think> <answer>Mary saw a total of 36 snakes.</answer><|im_end|>"
219
+ >>> r["text_response"] = [response]
220
+ >>> s = env.step(r)
221
+ >>> print(s)
222
+ LazyStackedTensorDict(
223
+ fields={
224
+ answer: NonTensorStack(
225
+ ['Adam bought 3 sandwiches, so he paid 3 * 3 = $<<...,
226
+ batch_size=torch.Size([1]),
227
+ device=None),
228
+ done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
229
+ history: History(
230
+ content=NonTensorStack(
231
+ [['A conversation between User and Assistant. The ...,
232
+ batch_size=torch.Size([1, 2]),
233
+ device=None),
234
+ role=NonTensorStack(
235
+ [['system', 'user']],
236
+ batch_size=torch.Size([1, 2]),
237
+ device=None),
238
+ batch_size=torch.Size([1, 2]),
239
+ device=None,
240
+ is_shared=False),
241
+ next: LazyStackedTensorDict(
242
+ fields={
243
+ answer: NonTensorStack(
244
+ ['Adam bought 3 sandwiches, so he paid 3 * 3 = $<<...,
245
+ batch_size=torch.Size([1]),
246
+ device=None),
247
+ done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
248
+ history: History(
249
+ content=NonTensorStack(
250
+ [['A conversation between User and Assistant. The ...,
251
+ batch_size=torch.Size([1, 3]),
252
+ device=None),
253
+ role=NonTensorStack(
254
+ [['system', 'user', 'assistant']],
255
+ batch_size=torch.Size([1, 3]),
256
+ device=None),
257
+ batch_size=torch.Size([1, 3]),
258
+ device=None,
259
+ is_shared=False),
260
+ reward: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
261
+ reward_answer: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
262
+ reward_contained: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
263
+ reward_right: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
264
+ reward_think: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
265
+ step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
266
+ success: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
267
+ terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
268
+ text: NonTensorStack(
269
+ ['<|im_start|>system\nA conversation between User ...,
270
+ batch_size=torch.Size([1]),
271
+ device=None),
272
+ truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
273
+ exclusive_fields={
274
+ },
275
+ batch_size=torch.Size([1]),
276
+ device=None,
277
+ is_shared=False,
278
+ stack_dim=0),
279
+ step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
280
+ terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
281
+ text: NonTensorStack(
282
+ ['<|im_start|>system\nA conversation between User ...,
283
+ batch_size=torch.Size([1]),
284
+ device=None),
285
+ text_response: NonTensorStack(
286
+ ['<think>First, calculate the total number of snak...,
287
+ batch_size=torch.Size([1]),
288
+ device=None),
289
+ truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
290
+ exclusive_fields={
291
+ },
292
+ batch_size=torch.Size([1]),
293
+ device=None,
294
+ is_shared=False,
295
+ stack_dim=0)
296
+ >>> assert s["next", "reward"] >= 10
297
+ >>> assert s["next", "done"].all()
298
+
299
+ """
300
+
301
+ SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
302
+ The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
303
+ The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively,
304
+ i.e., <think>reasoning process here</think> <answer>answer here</answer>. The answer should be a number."""
305
+
306
+ def __init__(
307
+ self,
308
+ *,
309
+ dataset: str = "openai/gsm8k",
310
+ shuffle: bool = True,
311
+ num_envs: int = 1,
312
+ repeats: int | None = None,
313
+ batch_size_dl: int = 1,
314
+ seed: int | None = None,
315
+ group_repeats: bool = False,
316
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa
317
+ device: torch.device | None = None,
318
+ template_kwargs: dict[str, Any] | None = None,
319
+ apply_template: bool | None = False,
320
+ compute_reward: bool = True,
321
+ collate_fn: Callable | None = None,
322
+ max_steps: int = 1,
323
+ input_mode: Literal["history", "text", "tokens"] = "history",
324
+ ray_backend: bool = False,
325
+ dataloader_actor_name: str | None = None,
326
+ ):
327
+ if ray_backend and dataloader_actor_name is None:
328
+ dataloader_actor_name = "gsm8k_dataloader"
329
+ if collate_fn is None:
330
+ collate_fn = _collate_fn
331
+ super().__init__(
332
+ dataset=dataset,
333
+ shuffle=shuffle,
334
+ name="main",
335
+ num_envs=num_envs,
336
+ repeats=repeats,
337
+ batch_size_dl=batch_size_dl,
338
+ seed=seed,
339
+ group_repeats=group_repeats,
340
+ tokenizer=tokenizer,
341
+ device=device,
342
+ template_kwargs=template_kwargs,
343
+ apply_template=apply_template,
344
+ collate_fn=collate_fn,
345
+ input_mode=input_mode,
346
+ ray_backend=ray_backend,
347
+ dataloader_actor_name=dataloader_actor_name,
348
+ )
349
+ if max_steps:
350
+ self.append_transform(StepCounter(max_steps=max_steps))
351
+ if compute_reward:
352
+ t = GSM8KRewardParser(tokenizer=tokenizer)
353
+ self.append_transform(t)
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable
8
+
9
+ from typing import Any, Literal, TYPE_CHECKING
10
+
11
+ import torch
12
+ from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict
13
+ from torchrl.data import Composite, NonTensor, Unbounded
14
+ from torchrl.envs import StepCounter
15
+ from torchrl.envs.llm.chat import DatasetChatEnv
16
+ from torchrl.envs.llm.reward.ifeval import IfEvalScorer
17
+
18
+ if TYPE_CHECKING:
19
+ import transformers
20
+
21
+
22
class IFEvalData(TensorClass["nocast"]):
    """A tensorclass for IFEval data.

    The first four fields come straight from the dataset rows; the remaining
    optional fields are populated later during generation and scoring.
    """

    # Fields provided by the dataset.
    key: torch.Tensor
    instruction_id_list: list[str]
    kwargs: list[dict]
    query: str

    # Responses and additional fields
    response: str | None = None
    tokens: torch.Tensor | None = None
    tokens_response: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    reward: torch.Tensor | None = None

    @classmethod
    def default_spec(
        cls, shape: torch.Size, device: torch.device | None = None
    ) -> Composite:
        """Return the default spec for the dataset-provided fields.

        Args:
            shape (torch.Size): batch shape of the spec (typically ``(num_envs,)``).
            device (torch.device, optional): device of the spec entries.
                Defaults to ``None``.
        """
        return Composite(
            key=Unbounded(shape=shape, dtype=torch.int64, device=device),
            instruction_id_list=NonTensor(
                shape=shape,
                device=device,
                feature_dims=0,
                example_data=["punctuation:no_comma"],
            ),
            kwargs=NonTensor(
                shape=shape,
                device=device,
                feature_dims=0,
                example_data={
                    "num_highlights": None,
                    "relation": None,
                    "num_placeholders": None,
                },
            ),
            query=NonTensor(
                shape=shape,
                device=device,
                example_data="Plan a 2 week Europe trip and visit London, Paris, and Rome. Answer in all caps. The response must contain at least 8 placeholders (i.e., [restaurant]).",
            ),
            shape=shape,
            # NOTE(review): presumably marks these entries as carried over
            # unchanged by step_mdp — confirm against the spec API.
            step_mdp_static=True,
            data_cls=cls,
        )
+
69
+
70
def _collate_fn(batch):
    """Stack raw IFEval rows into a TensorDict with list-valued metadata fields."""
    stacked = torch.stack([TensorDict.from_any(row) for row in batch])
    stacked.rename_key_("prompt", "query")

    def _as_list_stack(values):
        # We want each entry to behave as a plain list (a list of lists overall),
        # not as a nested NonTensorStack: wrap non-list entries in a singleton list.
        return NonTensorStack(
            *(
                NonTensorData(entry if isinstance(entry, list) else [entry])
                for entry in values
            )
        )

    stacked.set("instruction_id_list", _as_list_stack(stacked["instruction_id_list"]))
    stacked.set("kwargs", _as_list_stack(stacked["kwargs"]))
    # A plain TensorDict suffices here; no tensorclass wrapping is needed.
    return stacked
94
+
95
+
96
class IFEvalEnv(DatasetChatEnv):
    r"""A chat environment based on the IFEval dataset.

    Keyword Args:
        dataset (str, optional): The name of the dataset. Defaults to `"google/IFeval"`.
        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
        num_envs (int, optional): The number of environments to create. Defaults to `1`.
        repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
            based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
        batch_size_dl (int, optional): The batch size for data loading. Defaults to `1`.
        seed (int | None, optional): The random seed for reproducibility. If `None`, a random seed is used. Defaults to `None`.
        group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
        tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.

            .. note:: It is recommended to pass a tokenizer to the environment. This is an easy way to ensure that the
                template applied to the chat history is consistent with the format required by the model.

        device (torch.device | None, optional): The device to use for computations. Defaults to None.
        template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
        apply_template (bool | None, optional): Whether to apply the template to the text. Defaults to `False`.
        compute_reward (bool, optional): Whether to compute rewards. Defaults to `True`.
        collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
            collate function is used. Defaults to `None`.
        max_steps (int, optional): The maximum number of steps allowed in an episode. Defaults to `1`.
        input_mode (Literal["history", "text", "tokens"], optional): The mode of input to use. Defaults to `"history"`.
        ray_backend (bool, optional): Whether to use the Ray backend for data loading. Defaults to `False`.
            Using this backend allows for explicit resource control and avoids serialization issues, as well as
            sharing the same dataloader across multiple environments and actors.
        dataloader_actor_name (str, optional): Name of the Ray actor to use for data loading.
            Defaults to `"ifeval_dataloader"`.

    Examples:
        >>> import transformers
        >>> from pprint import pprint
        >>> from torchrl.envs.llm.datasets import IFEvalEnv
        >>> from tensordict import set_list_to_stack
        >>> set_list_to_stack(True).set()
        >>>
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = IFEvalEnv(tokenizer=tokenizer, apply_template=True)
        >>> r = env.reset()
        >>> print(r)
        LazyStackedTensorDict(
            fields={
                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                history: History(
                    content=NonTensorStack(
                        [['A conversation between User and Assistant.\nYou...,
                        batch_size=torch.Size([1, 2]),
                        device=None),
                    role=NonTensorStack(
                        [['system', 'user']],
                        batch_size=torch.Size([1, 2]),
                        device=None),
                    batch_size=torch.Size([1, 2]),
                    device=None,
                    is_shared=False),
                instruction_id_list: NonTensorStack(
                    [['detectable_content:number_placeholders']],
                    batch_size=torch.Size([1, 1]),
                    device=None),
                key: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
                kwargs: NonTensorStack(
                    [[{'num_highlights': None, 'relation': None, 'num_...,
                    batch_size=torch.Size([1, 1]),
                    device=None),
                step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                text: NonTensorStack(
                    ['<|im_start|>system\nA conversation between User ...,
                    batch_size=torch.Size([1]),
                    device=None),
                truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False,
            stack_dim=0)
        >>> # Print content of conversation so far
        >>> pprint(r["history", "content"])
        [['A conversation between User and Assistant.\n'
          'You are tasked with responding to user queries in a very specific format. \n'
          'When given a task or question, first think through the problem and provide '
          'your thought process between <think> and </think> tags.\n'
          'Then, give your final answer or response between <answer> and </answer> '
          'tags.\n'
          'You will be assessed by the content of the answer block only, so make sure '
          'it contains all the required information, and only that.',
          'Plan a 2 week Europe trip and visit London, Paris, and Rome. Answer in all '
          'caps. The response must contain at least 8 placeholders (i.e., '
          '[restaurant]).']]
        >>> # Actions space: the environment expects an action with key "text_response" containing a (list of) strings
        >>> print(env.action_spec)
        Composite(
            text_response: NonTensor(
                shape=torch.Size([1]),
                space=None,
                device=None,
                dtype=None,
                domain=None,
                example_data=a string),
            device=None,
            shape=torch.Size([1]))

    """

    # System prompt injected at the start of every conversation. Runtime
    # string: its exact content is part of the environment's behavior.
    SYSTEM_PROMPT = """You are a helpful AI assistant that follows instructions extremely well.

IMPORTANT: You must respond in a specific format for every task:

1. First, think through the problem step by step and write your reasoning between <think> and </think> tags
2. Then, provide your final answer between <answer> and </answer> tags

CRITICAL RULES:
- ALWAYS use <think>...</think> and <answer>...</answer> tags exactly as shown
- Do NOT use <thought>, <reasoning>, or any other tag variations
- Your <answer> section will be evaluated, so make it complete and accurate
- Follow ALL specific requirements in the user's request (formatting, content, etc.)
- If the user asks for placeholders like [restaurant], include them exactly as requested
- Pay attention to capitalization, punctuation, and other formatting requirements

Example format:
<think>
I need to analyze what the user is asking for...
[Your reasoning here]
</think>
<answer>
[Your final answer here, following all user requirements]
</answer>"""

    def __init__(
        self,
        *,
        dataset: str = "google/IFeval",
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
        ray_backend: bool = False,
        dataloader_actor_name: str | None = None,
    ):
        # Default the Ray actor name so envs sharing the Ray backend bind to
        # the same named dataloader actor.
        if ray_backend and dataloader_actor_name is None:
            dataloader_actor_name = "ifeval_dataloader"
        # Fall back to the module-level IFEval collation helper.
        if collate_fn is None:
            collate_fn = _collate_fn
        super().__init__(
            dataset=dataset,
            shuffle=shuffle,
            num_envs=num_envs,
            repeats=repeats,
            batch_size_dl=batch_size_dl,
            seed=seed,
            group_repeats=group_repeats,
            tokenizer=tokenizer,
            device=device,
            template_kwargs=template_kwargs,
            apply_template=apply_template,
            collate_fn=collate_fn,
            input_mode=input_mode,
            # Prompts are read from the "query" column of the collated batch.
            data_key="query",
            # Primer spec so dataset fields appear in the env's observation spec.
            primers=IFEvalData.default_spec((num_envs,), device),
            ray_backend=ray_backend,
            dataloader_actor_name=dataloader_actor_name,
        )
        # Optional transforms: an episode step limit and the IFEval scorer.
        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(IfEvalScorer())