torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,789 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import warnings
+ from collections.abc import Callable
+
+ from typing import Any, Literal, TYPE_CHECKING
+
+ import torch
+
+ from tensordict import (
+     is_leaf_nontensor,
+     LazyStackedTensorDict,
+     NestedKey,
+     set_list_to_stack,
+     TensorDict,
+     TensorDictBase,
+     unravel_key,
+ )
+ from tensordict.tensorclass import NonTensorData, NonTensorStack
+ from tensordict.utils import _zip_strict
+ from torch.utils.data import DataLoader
+
+ from torchrl._utils import _replace_last
+ from torchrl.data.map.hash import SipHash
+ from torchrl.data.tensor_specs import (
+     Bounded,
+     Categorical as CategoricalSpec,
+     Composite,
+     NonTensor,
+     Unbounded,
+ )
+ from torchrl.envs import EnvBase
+ from torchrl.envs.utils import _StepMDP
+ from torchrl.modules.utils.utils import _unpad_tensors
+
+ if TYPE_CHECKING:
+     import transformers
+
+
+ class LLMEnv(EnvBase):
+     """A text generation environment for language models.
+
+     This environment is designed to work with language models, where the observation is a string or a tensor of
+     integers representing a sequence of tokens. The action is also a string or a tensor of integers, which is
+     concatenated to the previous observation to form the new observation.
+
+     By default, this environment is meant to track history for a prompt. Users can append transforms to tailor
+     this to their use case, such as Chain of Thought (CoT) reasoning or other custom processing.
+
+     Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt.
+     Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via
+     :meth:`~from_dataloader`.
+
+     .. note:: The default arguments of the `LLMEnv` class are set to make it easy to run this environment with
+         the vllm backend (:class:`~torchrl.modules.vLLMWrapper`).
+
+     Keyword Args:
+         token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `from_text=False`).
+             Defaults to ``"tokens"``.
+         str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `from_text=True`).
+             Defaults to ``"text"``.
+         attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
+             Defaults to ``"attention_mask"``.
+         action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
+             ``"tokens_response"`` or ``"text_response"``.
+         reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
+             Defaults to ``"reward"``.
+         from_text (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
+         device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
+         vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
+             unbounded vocabulary. Defaults to ``None``.
+         has_attention (bool, optional): If ``True``, an attention mask is to be used under the key indicated by
+             :attr:`attention_key`. Defaults to ``True``.
+         assign_reward (bool, optional): If ``True``, a zero-valued reward of shape equal to the action shape
+             is written during calls to `step()`. Defaults to ``False``.
+         assign_done (bool, optional): If ``True``, a zero-valued done and terminated state of shape equal to the
+             action shape is written during calls to `step()`. Defaults to ``False``.
+
+             .. note:: Regardless of the value assigned to `assign_done`, a done state will be written at the root
+                 as it is a requirement for all TorchRL environments.
+
+         batch_size (int or torch.Size, optional): Batch size of the environment.
+             If left empty, an empty batch-size is assumed.
+             The batch size can be empty (`torch.Size([])`) or one-dimensional; batch sizes with more than one
+             dimension are not supported.
+
+             .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env
+                 and the transform should match.
+
+         eos_token_id (int, optional): The token id of the end of the sequence. If passed, the `done` state
+             is set to `True` when detected. Defaults to `None`.
+
+     .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples.
+
+     Methods:
+         from_dataloader: Creates an LLMEnv instance from a dataloader.
+
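+     Examples:
+         A minimal, hypothetical sketch in token mode (abridged, not taken from the
+         test suite; ``eos_token_id`` is set only to silence the missing-eos warning):
+
+         >>> import torch
+         >>> from tensordict import TensorDict
+         >>> from torchrl.envs.llm import LLMEnv
+         >>> env = LLMEnv(from_text=False, batch_size=(2,), eos_token_id=0)
+         >>> prompts = TensorDict(
+         ...     {"tokens": torch.randint(1, 100, (2, 5)),
+         ...      "attention_mask": torch.ones(2, 5, dtype=torch.int64)},
+         ...     batch_size=[2],
+         ... )
+         >>> td = env.reset(prompts)
+         >>> td["tokens"].shape
+         torch.Size([2, 5])
+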
+     """
+
+     _DEFAULT_TOKEN_KEY = "tokens"
+     _DEFAULT_STR_KEY = "text"
+     _DEFAULT_ATTENTION_KEY = "attention_mask"
+     _DEFAULT_ACTION_TOKENS_KEY = "tokens_response"
+     _DEFAULT_ACTION_STR_KEY = "text_response"
+
+     def __init__(
+         self,
+         *,
+         token_key: NestedKey | None = None,
+         str_key: NestedKey | None = None,
+         attention_key: NestedKey | None = None,
+         action_key: NestedKey | None = None,
+         reward_key: NestedKey = "reward",
+         from_text: bool = True,
+         device: torch.device | None = None,
+         vocab_size: int | None = None,
+         assign_reward: bool = False,
+         assign_done: bool = False,
+         batch_size: int | torch.Size | None = None,
+         has_attention: bool = True,
+         # Experimental
+         as_llm_data: bool = False,
+         eos_token_id: int | None = None,
+     ) -> None:
+         self._warn_deprecated()
+         self.as_llm_data = as_llm_data
+         if token_key is None:
+             token_key = self._DEFAULT_TOKEN_KEY
+         if str_key is None:
+             str_key = self._DEFAULT_STR_KEY
+         if attention_key is None:
+             attention_key = self._DEFAULT_ATTENTION_KEY
+         if action_key is None:
+             if from_text:
+                 action_key = self._DEFAULT_ACTION_STR_KEY
+             else:
+                 action_key = self._DEFAULT_ACTION_TOKENS_KEY
+         self._batch_locked = True
+         if batch_size is None:
+             batch_size = ()
+         else:
+             if not isinstance(batch_size, (tuple, list)):
+                 batch_size = (batch_size,)
+             elif len(batch_size) > 1:
+                 raise TypeError(
+                     f"batch-size of LLMEnv must be 0 or 1d. Got batch_size={batch_size}."
+                 )
+         super().__init__(
+             device=device,
+             batch_size=batch_size,
+         )
+         self.has_attention = has_attention
+         self.from_text = from_text
+         self.vocab_size = vocab_size
+         self.token_key = unravel_key(token_key)
+         self.str_key = unravel_key(str_key)
+         if attention_key is not None:
+             attention_key = unravel_key(attention_key)
+         self.attention_key = attention_key
+         self.assign_reward = assign_reward
+         self.assign_done = assign_done
+         self.eos_token_id = eos_token_id
+         if eos_token_id is None:
+             warnings.warn(
+                 "eos_token_id is missing. This means that the environment will not be able to capture its "
+                 "done state automatically. This may lead to undefined behaviors when the generated text reaches "
+                 "an eos_token.",
+                 category=UserWarning,
+             )
+
+         # self.action_key = unravel_key(action_key)
+         if from_text:
+             self.full_observation_spec_unbatched = Composite(
+                 {
+                     self.str_key: NonTensor(
+                         example_data="a string",
+                         batched=True,
+                         shape=(),
+                         device=device,
+                     )
+                 }
+             )
+             self.full_action_spec_unbatched = Composite(
+                 {
+                     action_key: NonTensor(
+                         example_data="a string", batched=True, shape=(), device=device
+                     )
+                 }
+             )
+         else:
+             if vocab_size is None:
+                 observation_spec = {
+                     token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device)
+                 }
+                 if self.has_attention:
+                     observation_spec[attention_key] = Unbounded(
+                         shape=(-1,), dtype=torch.int64, device=device
+                     )
+                 self.full_observation_spec_unbatched = Composite(observation_spec)
+                 self.full_action_spec_unbatched = Composite(
+                     {
+                         action_key: Unbounded(
+                             shape=(-1,), dtype=torch.int64, device=device
+                         )
+                     }
+                 )
+             else:
+                 self.full_observation_spec_unbatched = Composite(
+                     {
+                         token_key: Bounded(
+                             shape=(-1,),
+                             dtype=torch.int64,
+                             low=0,
+                             high=vocab_size,
+                             device=device,
+                         )
+                     }
+                 )
+                 self.full_action_spec_unbatched = Composite(
+                     {
+                         action_key: Bounded(
+                             shape=(-1,),
+                             dtype=torch.int64,
+                             low=0,
+                             high=vocab_size,
+                             device=device,
+                         )
+                     }
+                 )
+         STR2STR_ERR = ValueError(
+             "from_text cannot be True when either of assign_reward / assign_done are True. "
+             "Tokens are required to compute the reward shape."
+         )
+         if self.assign_reward:
+             if self.from_text:
+                 raise STR2STR_ERR
+             self.full_reward_spec_unbatched = Composite(
+                 {reward_key: Unbounded(shape=(-1,), device=device)}
+             )
+         else:
+             self.full_reward_spec_unbatched = Composite(device=device)
+
+         if not self.assign_done:
+             # Use a single done at the root
+             self.full_done_spec_unbatched = Composite(
+                 done=Unbounded(shape=(1,), dtype=torch.bool, device=device),
+                 terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
+             )
+         elif self.from_text:
+             raise STR2STR_ERR
+         else:
+             # Use a per-token done under "tokens_data" alongside the root done
+             self.full_done_spec_unbatched = Composite(
+                 tokens_data=Composite(
+                     done=Unbounded(shape=(-1,), dtype=torch.bool, device=device),
+                     terminated=Unbounded(shape=(-1,), dtype=torch.bool, device=device),
+                 ),
+                 done=Unbounded(shape=(1,), dtype=torch.bool, device=device),
+                 terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
+             )
+
+     @classmethod
+     def _warn_deprecated(cls):
+         warnings.warn(
+             "LLMEnv is deprecated. Please use ChatEnv instead.",
+             category=DeprecationWarning,
+         )
+
+     @classmethod
+     def from_dataloader(
+         cls,
+         dataloader: DataLoader,
+         *,
+         tokenizer: transformers.PreTrainedTokenizerBase | None = None,  # noqa
+         token_key: NestedKey | None = None,
+         str_key: NestedKey | None = None,
+         attention_key: NestedKey | None = None,
+         action_key: NestedKey | None = None,
+         reward_key: NestedKey = "reward",
+         from_text: bool = True,
+         device: torch.device | None = None,
+         vocab_size: int | None = None,
+         batch_size: int | torch.Size | None = None,
+         has_attention: bool = True,
+         assign_reward: bool = False,
+         assign_done: bool = False,
+         primers: Composite | None = None,
+         example_data: Any = None,
+         stack_method: Callable[[Any], Any]
+         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
+         repeats: int | None = None,
+         group_repeats: bool = True,
+         eos_token_id: int | None = None,
+     ) -> LLMEnv:
+         """Creates an LLMEnv instance from a dataloader.
+
+         This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which populates the observation key (``str_key`` or ``token_key``) with data from the provided dataloader when the environment is reset.
+
+         Args:
+             dataloader (DataLoader): The dataloader to load data from.
+
+         Keyword Args:
+             tokenizer (transformers.PreTrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+                 "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+                 pre-trained tokenizer.
+
+                 .. note:: Using the `tokenizer` will append a :class:`~torchrl.envs.Tokenizer` transform to the environment.
+                     If `from_text` is set to `True`, the tokenizer will be called during every iteration and the rollout
+                     will contain both tokens and text data.
+                     If `from_text` is set to `False`, the tokenizer will be called during reset only, and the only
+                     text data in the rollout will be the text sampled from the dataset.
+
+             token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `from_text=False`).
+                 Defaults to ``"tokens"``.
+             str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `from_text=True`).
+                 Defaults to ``"text"``.
+             attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
+                 Defaults to ``"attention_mask"``.
+             action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
+                 ``"text_response"`` (or ``"tokens_response"`` when `from_text=False`).
+             reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
+                 Defaults to ``"reward"``.
+             from_text (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
+             device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
+             vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
+                 unbounded vocabulary. Defaults to ``None``.
+             has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by
+                 :attr:`attention_key`. Defaults to ``True``.
+             assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape
+                 is written during calls to `step()`. Defaults to ``False``.
+             assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the
+                 action shape is written during calls to `step()`. Defaults to ``False``.
+
+                 .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root
+                     as it is a requirement for all TorchRL environments.
+
+             batch_size (int or torch.Size, optional): Batch size of the environment.
+                 If left empty, the batch size is inferred from `dataloader.batch_size` if that attribute exists, otherwise
+                 it is set to `()`.
+                 The batch size can be empty (`torch.Size([])`) or one-dimensional; batch sizes with more than one
+                 dimension are not supported.
+
+                 .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env
+                     and the transform should match.
+
+             primers (Composite | None, optional): The primers to use for each key in the dataloader.
+                 Defaults to ``None`` (inferred automatically from the first batch of data).
+             stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
+                 method to use for stacking the data. Defaults to ``None``.
+             repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
+                 situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
+                 samples (rather than an advantage module).
+             group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
+                 all repeats are grouped in a single batch collected from the buffer. Defaults to ``True``.
+             eos_token_id (int, optional): The token id of the end of the sequence. If passed, the `done` state
+                 is set to `True` when detected. Defaults to `None`.
+
+         Returns:
+             LLMEnv: The created LLMEnv instance.
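+
+         Examples:
+             A hypothetical sketch: the toy dataloader and its ``{"text": [...]}``
+             collate format below are illustrative assumptions, not part of the package:
+
+             >>> from torch.utils.data import DataLoader
+             >>> dl = DataLoader(
+             ...     ["Hello, world!", "Tell me a joke."],
+             ...     batch_size=2,
+             ...     collate_fn=lambda batch: {"text": batch},
+             ... )
+             >>> env = LLMEnv.from_dataloader(dl, from_text=True, batch_size=2, eos_token_id=0)
+             >>> td = env.reset()  # the primer loads the next prompts under "text"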
+         """
+         cls._warn_deprecated()
+
+         from torchrl.envs.llm import DataLoadingPrimer, Tokenizer
+
+         if str_key is None:
+             str_key = LLMEnv._DEFAULT_STR_KEY
+         if token_key is None:
+             token_key = LLMEnv._DEFAULT_TOKEN_KEY
+         if attention_key is None:
+             attention_key = LLMEnv._DEFAULT_ATTENTION_KEY
+         elif tokenizer is not None and attention_key != _replace_last(
+             token_key, "attention_mask"
+         ):
+             raise ValueError(
+                 "When using the Tokenizer, attention key must match `(*token_key[:-1], 'attention_mask')` where "
+                 f"`token_key` is a tuple-typed nested key. Got attention_key={attention_key} while expecting "
+                 f"{_replace_last(token_key, 'attention_mask')}."
+             )
+
+         if tokenizer is not None:
+             if from_text:
+                 # In this case, the tokenizer is appended to the env after each step
+                 if action_key is None:
+                     action_key = cls._DEFAULT_ACTION_STR_KEY
+                 tokenizer_transform = Tokenizer(
+                     tokenizer=tokenizer,
+                     in_keys=[str_key],
+                     out_keys=[token_key],
+                     # Assume that the tokens are named according to _DEFAULT_ACTION_TOKENS_KEY
+                     in_keys_inv=[action_key],
+                     out_keys_inv=[cls._DEFAULT_ACTION_TOKENS_KEY],
+                     call_before_reset=False,
+                     # We should always see the required entries
+                     missing_tolerance=False,
+                 )
+             else:
+                 # FIXME: This is broken - do we need it anyway?
+                 raise RuntimeError(
+                     "tokenizers can only be used whenever from_text is set to `True`."
+                 )
+
+         primer = DataLoadingPrimer(
+             dataloader=dataloader,
+             primers=primers,
+             stack_method=stack_method,
+             repeats=repeats,
+             device=device,
+             group_repeats=group_repeats,
+             batch_size=batch_size,
+         )
+         env = LLMEnv(
+             from_text=from_text,
+             device=device,
+             token_key=token_key,
+             str_key=str_key,
+             attention_key=attention_key,
+             action_key=action_key,
+             reward_key=reward_key,
+             vocab_size=vocab_size,
+             assign_reward=assign_reward,
+             assign_done=assign_done,
+             batch_size=primer.batch_size,
+             has_attention=has_attention,
+             eos_token_id=eos_token_id,
+         )
+         if tokenizer is not None:
+             env = env.append_transform(tokenizer_transform)
+         return env.append_transform(primer)
+
+     @staticmethod
+     def _check_obs_act_and_cat(obs, action, *, device):
+         if not isinstance(obs, str):
+             raise TypeError(f"Observation must be a string, got {type(obs)}.")
+         if not isinstance(action, str):
+             raise TypeError(f"Action must be a string, got {type(action)}.")
+         return NonTensorData(obs + action, device=device)
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         next_td = tensordict.empty()
+         self._make_next_obs(tensordict, next_td)
+         self._maybe_make_reward(tensordict, next_td)
+         self._maybe_make_done(tensordict, next_td)
+         if self.as_llm_data:
+             raise NotImplementedError()
+         return next_td
+
+     def _maybe_make_reward(
+         self, tensordict: TensorDictBase, next_td: TensorDictBase
+     ) -> TensorDictBase:
+         if self.assign_reward:
+             next_td.set(
+                 self.reward_key,
+                 torch.zeros_like(
+                     tensordict.get(self.action_key), dtype=self.reward_spec.dtype
+                 ),
+             )
+         return next_td
+
+     def _maybe_make_done(
+         self,
+         tensordict: TensorDictBase,
+         next_td: TensorDictBase,
+         resetting: bool = False,
+     ) -> TensorDictBase:
+         if self.assign_done:
+             action = tensordict.get(self.action_key)
+             if action is None:
+                 done = torch.zeros(
+                     tensordict.shape + (1,), dtype=torch.bool, device=self.device
+                 )
+             else:
+                 done = torch.zeros_like(action, dtype=torch.bool)
+             next_td.set(("tokens_data", "terminated"), done)
+             next_td.set(("tokens_data", "done"), done.clone())
+             next_td.set(
+                 "done", next_td.get(("tokens_data", "done")).any(-1, keepdim=True)
+             )
+             next_td.set(
+                 "terminated",
+                 next_td.get(("tokens_data", "terminated")).any(-1, keepdim=True),
+             )
+         if not resetting and self.eos_token_id is not None:
+             if self.from_text:
+                 token_action_key = self._DEFAULT_ACTION_TOKENS_KEY
+             else:
+                 token_action_key = self.action_key
+             action = tensordict.get(
+                 token_action_key, as_padded_tensor=True, padding_value=-1
+             )
+             if action is None:
+                 raise RuntimeError(
+                     f"Couldn't find the tokenized action with key {token_action_key} to set the done state in tensordict "
+                     f"with keys {list(tensordict.keys(True))}."
+                 )
+             # Padded positions hold the padding value and are masked out below
+             mask = action == -1
+             full_done = action == self.eos_token_id
+             done = full_done.any(-1, keepdim=True)
+             next_td.set("done", done)
+             next_td.set("terminated", done)
+             if self.assign_done:
+                 full_done = _unpad_tensors(full_done, mask)
+                 next_td.set(("tokens_data", "terminated"), full_done)
+                 next_td.set(("tokens_data", "done"), full_done)
+         return next_td
+
+     def _make_next_obs(
+         self, tensordict: TensorDictBase, next_td: TensorDictBase
+     ) -> TensorDictBase:
+         # Cat action entry with prev obs
+         if self.from_text:
+             obs = tensordict[self.str_key]
+             action = tensordict[self.action_key]
+             if not tensordict.batch_size:
+                 if not isinstance(obs, str) or not isinstance(action, str):
+                     raise TypeError(
+                         "The tensordict is batchless, yet the action and/or observations are not "
+                         f"strings but {type(action)} and {type(obs)}, respectively."
+                     )
+                 observation = self._check_obs_act_and_cat(
+                     obs, action, device=self.device
+                 )
+             else:
+                 observation = NonTensorStack(
+                     *[
+                         self._check_obs_act_and_cat(_obs, _action, device=self.device)
+                         for (_obs, _action) in _zip_strict(obs, action)
+                     ]
+                 )
+             return next_td.set(self.str_key, observation)
+         else:
+             try:
+                 obs: torch.Tensor = tensordict.get(self.token_key)
+                 action = tensordict.get(self.action_key)
+                 if getattr(obs, "is_nested", False):
+                     observation = torch.nested.as_nested_tensor(
+                         [
+                             torch.cat([_obs, _action], -1)
+                             for _obs, _action in _zip_strict(
+                                 obs.unbind(0), action.unbind(0)
+                             )
+                         ],
+                         layout=obs.layout,
+                     )
+                 else:
+                     observation = torch.cat([obs, action], -1)
+                     if self.has_attention:
+                         attention_mask = tensordict.get(self.attention_key)
+                         attention_mask = torch.cat(
+                             [attention_mask, attention_mask.new_ones(action.shape)], -1
+                         )
+                         next_td.set(self.attention_key, attention_mask)
+             except TypeError:
+                 raise TypeError(
+                     "Failed to cat action and observation tensors. Check that from_text argument is correctly "
+                     f"set in {type(self).__name__}."
+                 )
+             return next_td.set(self.token_key, observation)
+
+     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+         # We should have an observation by this time, if not raise an exception
+         def check_token():
+             return not self.from_text and (
+                 self.token_key not in tensordict.keys(isinstance(self.token_key, tuple))
+             )
+
+         def check_str():
+             return self.from_text and (
+                 self.str_key not in tensordict.keys(isinstance(self.str_key, tuple))
+             )
+
+         if tensordict is None or check_token() or check_str():
+             raise KeyError(
+                 f"Observation key {self.token_key}/{self.str_key} is not defined in tensordict with keys "
+                 f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (e.g., "
+                 f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
+             )
+         if not isinstance(tensordict, LazyStackedTensorDict) and tensordict.ndim:
+             tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
+         td_reset = tensordict.copy()
+         if td_reset.device != self.device:
+             if self.device is None:
+                 td_reset.clear_device_()
+             else:
+                 td_reset = td_reset.to(self.device)
+         tensordict = self._maybe_make_done(tensordict, td_reset, resetting=True)
+         if self.as_llm_data:
+             raise NotImplementedError()
+         return tensordict
+
+     def _set_seed(self, seed: int | None):
+         return seed
+
+
+ class LLMHashingEnv(EnvBase):
+     """A text generation environment that uses a hashing module to identify unique observations.
+
+     The primary goal of this environment is to identify token chains using a hashing function.
+     This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
+     identifiers, or easily prune repeated token chains in a data structure.
+
+     .. The following figure gives an overview of this workflow:
+     .. .. figure:: /_static/img/rollout-llm.png
+     ..   :alt: Data collection loop with our LLM environment.
+
+     Args:
+         vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
+
+     Keyword Args:
+         hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
+             A hashing function that takes a tensor as input and returns a hashed tensor.
+             Defaults to :class:`~torchrl.data.SipHash` if not provided.
+         observation_key (NestedKey, optional): The key for the observation in the TensorDict.
+             Defaults to "observation".
+         text_output (bool, optional): Whether to include the text output in the observation.
+             Defaults to `True`.
+         tokenizer (transformers.Tokenizer | None, optional):
+             A tokenizer function that converts text to tensors.
+             Only used when `text_output` is `True`.
+             Must implement the following methods: `decode` and `batch_decode`.
+             Defaults to ``None``.
+         text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
+             Defaults to "text".
+
+     Examples:
+         >>> from tensordict import TensorDict
+         >>> from torchrl.envs import LLMHashingEnv
+         >>> from transformers import GPT2Tokenizer
+         >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+         >>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
+         >>> env = LLMHashingEnv(tokenizer=tokenizer)
+         >>> td = TensorDict(observation=x, batch_size=[1])
+         >>> td = env.reset(td)
+         >>> print(td)
+         TensorDict(
+             fields={
+                 done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 hashing: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                 observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
+                 terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                 text: NonTensorStack(
+                     ['Check out TorchRL!'],
+                     batch_size=torch.Size([1]),
+                     device=None)},
+             batch_size=torch.Size([1]),
+             device=None,
+             is_shared=False)
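+
+         A hypothetical one-step continuation of the example above (the action is an
+         arbitrary next-token id; the hash is updated from the previous hash and action):
+
+         >>> import torch
+         >>> td["action"] = torch.tensor([[tokenizer.eos_token_id]])
+         >>> td = env.step(td)
+         >>> print(td["next", "hashing"].shape)
+         torch.Size([1, 1])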
+
+     """
+
+     def __init__(
+         self,
+         vocab_size: int | None = None,
+         *,
+         hashing_module: Callable[[torch.Tensor], torch.Tensor] = None,
+         observation_key: NestedKey = "observation",
+         text_output: bool = True,
+         tokenizer: Callable[[str | list[str]], torch.Tensor] | None = None,
+         text_key: NestedKey | None = "text",
+     ):
+         super().__init__()
+         if vocab_size is None:
+             if tokenizer is None:
+                 raise TypeError(
+                     "You must provide a vocab_size integer if tokenizer is `None`."
+                 )
+             vocab_size = tokenizer.vocab_size
+         self._batch_locked = False
+         if hashing_module is None:
+             hashing_module = SipHash()
+
+         self._hashing_module = hashing_module
+         self._tokenizer = tokenizer
+         self.observation_key = observation_key
+         observation_spec = {
+             observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
+             "hashing": Unbounded(shape=(1,), dtype=torch.int64),
+         }
+         self.text_output = text_output
+         if not text_output:
+             text_key = None
+         elif text_key is None:
+             text_key = "text"
+         if text_key is not None:
+             observation_spec[text_key] = NonTensor(shape=())
+         self.text_key = text_key
+         self.observation_spec = Composite(observation_spec)
+         self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
+         _StepMDP(self)
+
+     @set_list_to_stack(True)
+     def make_tensordict(self, input: str | list[str]) -> TensorDict:
+         """Converts a string or list of strings into a TensorDict with appropriate shape and device."""
+         list_len = len(input) if isinstance(input, list) else 0
+         tensordict = TensorDict(
+             {self.observation_key: self._tokenizer(input)}, device=self.device
+         )
+         if list_len:
+             tensordict.batch_size = [list_len]
+         return self.reset(tensordict)
+
+     def _reset(self, tensordict: TensorDictBase):
+         """Initializes the environment with a given observation.
+
+         Args:
+             tensordict (TensorDictBase): A TensorDict containing the initial observation.
+
+         Returns:
+             A TensorDict containing the initial observation, its hash, and other relevant information.
+
+         """
+         out = tensordict.empty()
+         obs = tensordict.get(self.observation_key, None)
+         if obs is None:
+             raise RuntimeError(
+                 f"Resetting the {type(self).__name__} environment requires a prompt."
+             )
+         if self.text_output:
+             if obs.ndim > 1:
+                 text = self._tokenizer.batch_decode(obs)
+                 text = NonTensorStack.from_list(text)
+             else:
+                 text = self._tokenizer.decode(obs)
+                 text = NonTensorData(text)
+             out.set(self.text_key, text)
+
+         if obs.ndim > 1:
+             out.set("hashing", self._hashing_module(obs).unsqueeze(-1))
+         else:
+             out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))
+
+         if not self.full_done_spec.is_empty():
+             out.update(self.full_done_spec.zero(tensordict.shape))
+         else:
+             out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool))
+             out.set(
+                 "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)
+             )
+         return out
+
+     def _step(self, tensordict):
+         """Takes an action (i.e., the next token to generate) and returns the next observation and reward.
+
+         Args:
+             tensordict: A TensorDict containing the current observation and action.
+
+         Returns:
+             A TensorDict containing the next observation, its hash, and other relevant information.
+         """
+         out = tensordict.empty()
+         action = tensordict.get("action")
+         obs = torch.cat([tensordict.get(self.observation_key), action], -1)
+         kwargs = {self.observation_key: obs}
+
+         catval = torch.cat([tensordict.get("hashing"), action], -1)
+         if obs.ndim > 1:
+             new_hash = self._hashing_module(catval).unsqueeze(-1)
+         else:
+             new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1)
+
+         if self.text_output:
+             if obs.ndim > 1:
+                 text = self._tokenizer.batch_decode(obs)
+                 text = NonTensorStack.from_list(text)
+             else:
+                 text = self._tokenizer.decode(obs)
+                 text = NonTensorData(text)
+             kwargs[self.text_key] = text
+         kwargs.update(
+             {
+                 "hashing": new_hash,
+                 "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
+                 "terminated": torch.zeros(
+                     (*tensordict.batch_size, 1), dtype=torch.bool
+                 ),
+             }
+         )
+         return out.update(kwargs)
+
+     def _set_seed(self, *args):
+         """Sets the seed for the environment's randomness.
+
+         .. note:: This environment has no randomness, so this method does nothing.
+         """