torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/llm/libs/mlgym.py
@@ -0,0 +1,869 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import functools
+
+import json
+import os
+import re
+from contextlib import contextmanager
+from dataclasses import asdict
+from pathlib import Path
+from typing import Any, Literal, TYPE_CHECKING
+
+import numpy as np
+
+import torch
+from tensordict import NestedKey, NonTensorData, TensorDict, TensorDictBase
+from tensordict.tensorclass import is_non_tensor
+
+from torchrl._utils import logger as torchrl_logger
+from torchrl.data import Choice, Composite, NonTensor
+from torchrl.data.llm import History
+from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv
+
+if TYPE_CHECKING:
+    import mlgym
+    import transformers
+
+# Inv transforms:
+# Transforms to apply before passing the model output to the env
+
+
+@contextmanager
+def _temp_cwd_mlgym():
+    """Temporarily change the current working directory to mlgym."""
+    import mlgym
+
+    path = Path(mlgym.__spec__.submodule_search_locations[0]).parent
+    old_pwd = os.getcwd()
+    os.chdir(str(path))
+    # sys.path.insert(-1, "mlgym")
+    try:
+        yield
+    finally:
+        # sys.path.pop()
+        os.chdir(old_pwd)
+
+
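# A minimal usage sketch (illustrative, not part of the package): because the
# helper above is built with @contextmanager, it can be used both as a `with`
# block and as a decorator (contextlib managers inherit from ContextDecorator).
# Assumes mlgym is importable.
with _temp_cwd_mlgym():
    pass  # relative paths such as "tasks/<name>.yaml" now resolve from the mlgym root

@_temp_cwd_mlgym()
def run_from_mlgym_root():
    pass  # the same cwd switch wraps every call to this function
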
+class MLGymBaseTransform(Transform):
+    """Base class for all MLGym transforms."""
+
+    @property
+    def config(self):
+        return self.parent.base_env.config
+
+    @property
+    def system_args(self):
+        return {
+            "command_docs": self.config.tools_handler.command_docs,
+            **self.config.tools_handler.env_variables,
+        }
+
+    @property
+    def task_args(self):
+        # Placeholder
+        task_args = getattr(self, "_task_args", None)
+        if task_args is None:
+            return self.parent.base_env.task.args
+        return task_args
+
+    @task_args.setter
+    def task_args(self, task_args):
+        self._task_args = task_args
+
+    @property
+    def name(self):
+        return "torchrl"
+
+    @property
+    def state_command(self):
+        return self.config.state_command.name
+
+    @property
+    def agent_args(self):
+        return self.parent.base_env.agent_args
+
+    @property
+    def model_name(self) -> Literal["human", "human_thought"]:
+        return self.agent_args.model.model_name
+
+
+#######################################################
+# Forward transforms: Format the env output
+
+
+# Transform #0: Resets the env
+class ResetModule(MLGymBaseTransform):
+    """Runs setup pipeline and enables multi-resets.
+
+    The reset method reads the 'system' initial input from the config and parses it to a History
+    object.
+    """
+
+    response_key: NestedKey = "text_response"
+
+    def __init__(self):
+        super().__init__(in_keys=[], out_keys=["history"])
+
+    @_temp_cwd_mlgym()
+    def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
+        base_env = self.parent.base_env._env
+        if tensordict is not None and "task" in tensordict:
+            import gymnasium as gym
+
+            task = tensordict["task"]
+            torchrl_logger.info(f"Resetting with {task=}")
+            if is_non_tensor(task):
+                task = task.data
+            task_id, agent_args = _TASK_IDS[task]
+            try:
+                base_env.close()
+            except Exception:
+                torchrl_logger.info(f"Failed to close {base_env=}")
+            base_env = gym.make(
+                f"mlgym/{task}",
+                devices=["cpu_0"],
+            ).unwrapped
+            base_env.config = agent_args.config
+            self.parent.base_env.set_env(base_env)
+            base_env.reset_container()
+            base_env.communicate(f"cd {Path(base_env.task_workspace).parent}")
+        return tensordict
+
+    @_temp_cwd_mlgym()
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        # TODO: what to do with this?
+        # reset model stats
+        # self.model.reset_stats(init_model_stats)
+        # env = self.parent.base_env._env
+
+        env = self.parent.base_env._env
+        self.set_environment_vars(env, self.config.env_variables)
+
+        system_msg = self.config.system_template.format(
+            **self.system_args, **asdict(self.task_args)
+        )
+        # self.logger.log(self._default_logging_level, f"SYSTEM ({self.name})\n{system_msg}")
+        history = History(
+            role="system",
+            content=system_msg,  # agent=self.name,
+            batch_size=(1,),
+            device=self.parent.device,
+        )
+        tensordict_reset["history"] = history
+
+        return tensordict_reset
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        # Placeholder
+        if "history" not in next_tensordict:
+            if "local_history" in tensordict:
+                local_history = tensordict["local_history"]
+            else:
+                local_history = None
+            history = tensordict["history"]
+            if local_history is not None:
+                history = history.append(local_history, inplace=False)
+                tensordict["history"] = history
+            next_tensordict["history"] = history
+        return next_tensordict
+
+    def set_environment_vars(
+        self, env: MLGymWrapper, env_variables: dict[str, Any]
+    ) -> None:
+        commands_to_execute = (
+            [self.config.state_command.code]
+            + # [code for code in self.config.util_functions] +
+            # [command.code for command in self.config._commands] +
+            [f"{k}={v}" for k, v in env_variables.items()]
+        )
+        commands = "\n".join(commands_to_execute)
+        try:
+            output = env.communicate(commands)
+            if env.returncode != 0:
+                msg = f"Nonzero return code: {env.returncode}\nOutput: {output}"
+                raise RuntimeError(msg)
+        except KeyboardInterrupt:
+            raise
+        except Exception as e:
+            raise e
+        command_files = []
+        for file in self.config.command_files:
+            datum = {}
+            with open(file) as f:
+                contents = f.read()
+            datum["contents"] = contents
+            filename = Path(file).name
+            if not contents.strip().startswith("#!"):
+                if filename.endswith(".sh"):
+                    # files are sourced, so they are not executable
+                    datum["name"] = Path(file).name
+                    datum["type"] = "source_file"
+                elif filename.startswith("_"):
+                    # files are sourced, so they are not executable
+                    datum["name"] = Path(file).name
+                    datum["type"] = "utility"
+                else:
+                    msg = (
+                        f"Non-shell script file {file} does not start with shebang.\n"
+                        "Either add a shebang (#!) or change the file extension to .sh if you want to source it.\n"
+                        "You can override this behavior by adding an underscore to the file name (e.g. _utils.py)."
+                    )
+                    raise ValueError(msg)
+            else:
+                # scripts are made executable
+                datum["name"] = Path(file).name.rsplit(".", 1)[0]
+                datum["type"] = "script"
+            command_files.append(datum)
+        # TODO: implement add commands method in environment
+        env.add_commands(command_files)
+
+    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+        observation_spec["history"] = History.default_spec()
+        return observation_spec
+
+    def transform_action_spec(self, action_spec: Composite) -> Composite:
+        if isinstance(action_spec, Composite):
+            action_spec[self.response_key] = self.transform_action_spec(
+                action_spec[self.response_key]
+            )
+            return action_spec
+        # make the "random" action just a choice between innocuous bash commands
+        return Choice(
+            [
+                NonTensor(example_data="ls -rtlh", shape=action_spec.shape),
+                NonTensor(example_data="pwd", shape=action_spec.shape),
+            ]
+        )
+
+    def transform_state_spec(self, state_spec: Composite) -> Composite:
+        state_spec["history"] = History.default_spec()
+        return state_spec
+
+
+class TaskSampler(Transform):
+    """A sampler for tasks in a certain task set."""
+
+    def __init__(self, tasks: list[str]):
+        super().__init__()
+        self.tasks = tasks
+
+    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+        observation_spec["task"] = NonTensor(example_data="<a task>", shape=())
+        return observation_spec
+
+    @_temp_cwd_mlgym()
+    def _reset_env_preprocess(
+        self, tensordict: TensorDictBase | None
+    ) -> TensorDictBase:
+        if tensordict is None:
+            tensordict = TensorDict(batch_size=self.parent.batch_size)
+        # Sample a task
+        task = np.random.choice(self.tasks)
+        tensordict["task"] = NonTensorData(task)
+        self._current_task = task
+        return tensordict
+
+    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+        next_tensordict["task"] = self._current_task
+        return next_tensordict
+
+
+# Transform #1: env -> state
+class ReadState(MLGymBaseTransform):
+    """Reads current state and writes it as a parsable str in the tensordict."""
+
+    # from mlgym/agent/base.py:BaseAgent:forward_model
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        base_mlgym_env = self.parent.base_env  # getattr is forwarded
+
+        command = self.state_command
+        state = base_mlgym_env.communicate(command) if self.state_command else None
+
+        next_tensordict["state"] = state
+        return next_tensordict
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        # tensordict_reset.setdefault("message", NonTensorData(""))
+        # tensordict_reset.setdefault("state", NonTensorData(""))
+        return self._step(tensordict_reset, tensordict_reset)
+
+    def transform_observation_spec(self, observation_spec):
+        observation_spec.set(
+            "state",
+            NonTensor(
+                example_data="a string",
+                device=observation_spec.device,
+                shape=observation_spec.shape,
+            ),
+        )
+        return observation_spec
+
+
+# Transform #2: state -> message
+class StateToMessage(MLGymBaseTransform):
+    """Parses the string using json to a given template.
+
+    Requires:
+        - a 'state' key from the ReadState transform
+        - an 'observation' key from the base environment
+
+    """
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        base_mlgym_env = self.parent.base_env  # getattr is forwarded
+        observation = tensordict["observation"]
+        state = tensordict["state"]
+        config = self.config
+
+        current_step = base_mlgym_env.current_step
+        max_steps = base_mlgym_env.max_steps
+        try:
+            state_vars = json.loads(state)
+        except json.JSONDecodeError as e:
+            msg = f"State {state!r} is not valid json. This is an internal error, please report it."
+            raise ValueError(msg) from e
+        # add step information to state_vars
+        state_vars["current_step"] = current_step
+        state_vars["remaining_steps"] = max_steps - current_step
+
+        # FIXME: we don't need to do this, we have our own observation space
+        # Determine observation template based on what prior observation was
+
+        history: History = tensordict["history"]
+        if history[..., -1].role == "system":
+            # Show task template if prev. obs. was initial system message
+            templates = [config.task_template]
+            if config.strategy_template is not None:
+                templates.append(config.strategy_template)
+        elif observation is None or observation.strip() == "":
+            # Show no output template if observation content was empty
+            assert config.next_step_no_output_template is not None  # linting
+            templates = [config.next_step_no_output_template]
+        else:
+            # Show standard output template if there is observation content
+            assert config.next_step_template is not None  # linting
+            templates = [config.next_step_template]
+
+        # Format selected template(s) with information
+        messages = []
+        assert self.task_args is not None
+        for template in templates:
+            messages.append(
+                template.format(
+                    **asdict(self.task_args),
+                    **self.system_args,
+                    **state_vars,
+                    observation=(observation if observation is not None else ""),
+                    # missing forwarded_vars because no attempts
+                ),
+            )
+
+        message = "\n".join(messages)
+        next_tensordict["message"] = message
+        # model query hooks here
+        return next_tensordict
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        # tensordict_reset.setdefault("message", NonTensorData(""))
+        # tensordict_reset.setdefault("state", NonTensorData(""))
+        return self._step(tensordict_reset, tensordict_reset)
+
+    def transform_observation_spec(self, observation_spec):
+        observation_spec.set(
+            "message",
+            NonTensor(
+                example_data="a string",
+                device=observation_spec.device,
+                shape=observation_spec.shape,
+            ),
+        )
+        return observation_spec
+
+
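# Sketch of the plain str.format fill performed by StateToMessage above, with a
# hypothetical template and made-up values (stdlib only):
state_vars = {"current_step": 3, "remaining_steps": 7}
template = "Step {current_step} done, {remaining_steps} to go.\nOutput:\n{observation}"
message = template.format(**state_vars, observation="files: a.py b.py")
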
+# Transform #3: Append message to history
+class MessageToHistory(MLGymBaseTransform):
+    """Parses the message string to a History object, then reparses the history to a complete message.
+
+    .. seealso:: HistoryToMessage
+
+    """
+
+    def __init__(self):
+        super().__init__(in_keys=["message", "history"], out_keys=["history", "chat"])
+
+    # from mlgym/agent/base.py:BaseAgent:local_history
+    # from mlgym/agent/base.py:BaseAgent:_append_history
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        # From PrepareDataForModel
+        message: str = next_tensordict["message"]
+        # from mlgym/agent/base.py:BaseAgent:forward_model
+        history = tensordict["history"]
+        cur_history = History(
+            role="user", content=message, batch_size=(), device=self.parent.device
+        )
+        # This is the basic thing our transform does: append the history to the existing one.
+        # (We should be able to extend the lazy stack directly)
+        history = history.append(cur_history, inplace=False)
+
+        next_tensordict["history"] = history
+        return next_tensordict
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        return self._step(tensordict_reset, tensordict_reset)
+
+
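# Sketch of the History-append pattern that ResetModule and MessageToHistory
# rely on, using only the torchrl.data.llm.History API imported by this module:
from torchrl.data.llm import History

chat = History(role="system", content="You are an ML agent.", batch_size=(1,))
turn = History(role="user", content="Step 1: ...", batch_size=())
chat = chat.append(turn, inplace=False)  # the message stack grows along the last dim
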
+# Inverse transforms:
+# Format the action from the model for the env
+
+
+class TemplateTransform(MLGymBaseTransform):
+    """A transform to apply the chat template to the History."""
+
+    response_key: NestedKey = "text_response"
+    prompt_key: NestedKey = "text"
+
+    # alternative to DummyFormat, wip
+    def __init__(
+        self,
+        in_keys=None,
+        out_keys=None,
+        in_keys_inv=None,
+        out_keys_inv=None,
+        tokenizer=None,
+        chat_template_name: Literal["chatml_format"] | None = None,
+        continue_final_message: bool = False,
+        tokenize: bool = False,
+        return_tensors: str = "pt",
+        return_dict: bool = False,
+        padding: bool | str = False,
+        truncation: bool | str = False,
+    ):
+        super().__init__(
+            in_keys=["history"] if in_keys is None else in_keys,
+            out_keys=[self.prompt_key] if out_keys is None else out_keys,
+            in_keys_inv=[self.prompt_key, self.response_key]
+            if in_keys_inv is None
+            else in_keys_inv,
+            # TODO: we should not use the response key here but another dedicated entry, like "action_parsed"
+            out_keys_inv=[self.response_key] if out_keys_inv is None else out_keys_inv,
+        )
+        self.chat_template_name = chat_template_name
+        self.tokenizer = tokenizer
+        self.tokenize = tokenize
+        self.continue_final_message = continue_final_message
+        self.return_tensors = return_tensors
+        self.return_dict = return_dict
+        self.padding = padding
+        self.truncation = truncation
+
+    def transform_observation_spec(self, observation_spec: Composite):
+        observation_spec[self.prompt_key] = NonTensor(
+            example_data="<some chat string>",
+            shape=observation_spec.shape,
+            device=observation_spec.device,
+        )
+        return observation_spec
+
+    @property
+    def _chat_template(self):
+        chat_template = None
+        if self.chat_template_name:
+            from torchrl.data.llm.datatypes.chat import _CHAT_TEMPLATES
+
+            chat_template = _CHAT_TEMPLATES[self.chat_template_name]
+        elif self.tokenizer.chat_template is not None:
+            chat_template = self.tokenizer.chat_template
+        else:
+            raise ValueError("Failed to determine chat template.")
+        return chat_template
+
+    def _apply_transform(self, history: History) -> NonTensorData:
+        if self.tokenizer is None:
+            raise RuntimeError("Cannot apply chat template without a tokenizer.")
+        result = history.apply_chat_template(
+            tokenizer=self.tokenizer,
+            add_generation_prompt=True,
+            chat_template=self._chat_template,
+            continue_final_message=self.continue_final_message,
+            tokenize=self.tokenize,
+            padding=self.padding,
+            truncation=self.truncation,
+            return_tensors=self.return_tensors,
+        )
+        return result
+
+    def _reset(self, tensordict, tensordict_reset):
+        return self._call(tensordict_reset)
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if self.in_keys_inv:
+            prompt = tensordict[self.prompt_key]
+            response = tensordict[self.response_key]
+            if isinstance(prompt, list):
+                action = [
+                    prompt + response for prompt, response in zip(prompt, response)
+                ]
+            else:
+                action = prompt + response
+            try:
+                history, action = self._inv_apply_transform(action)
+                tensordict["local_history"] = history
+                tensordict[self.response_key] = action
+            except RuntimeError as e:
+                if "Expected assistant role" in str(e):
+                    tensordict["local_history"] = History(role="assistant", content="")
+                    tensordict[self.response_key] = ""
+        return tensordict
+
+    def _inv_apply_transform(self, action):
+        if self.tokenize:
+            action = self.tokenizer.decode(action)
+
+        if not isinstance(action, (str, list)):
+            # unwrap the NonTensorData, recurse on its payload, then rewrap the
+            # parsed content with the original batch size and device
+            data = action
+            history, content = self._inv_apply_transform(data.data)
+            action = NonTensorData(
+                content, batch_size=data.batch_size, device=data.device
+            )
+            return history, action
+
+        history = History.from_text(
+            action,
+            # chat_template=self._chat_template,
+        )[..., -1]
+        if history.role != "assistant":
+            raise RuntimeError(f"Expected assistant role, got {history.role=}")
+        action = history.get("content")
+        return history, action
+
+
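# Sketch of the templating step in _apply_transform, assuming a Hugging Face
# tokenizer that ships a chat template (the model name is a hypothetical
# placeholder):
import transformers
from torchrl.data.llm import History

tok = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
chat = History(role="user", content="ls -rtlh", batch_size=(1,))
prompt = chat.apply_chat_template(
    tokenizer=tok, add_generation_prompt=True, tokenize=False
)  # the flat prompt string that ends up under the "text" key
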
+class IsolateCodeBlock(MLGymBaseTransform):
+    """A transform that isolates the code block in the action generated by the LLM.
+
+    Optionally, wrongly formatted actions are assigned a negative reward.
+    """
+
+    response_key: NestedKey = "text_response"
+
+    def __init__(self, reward_wrong_format: float | None = None):
+        super().__init__(
+            in_keys_inv=[self.response_key], out_keys_inv=[self.response_key]
+        )
+        from mlgym.agent.parsing import ThoughtActionParser
+
+        self.parser = ThoughtActionParser()
+        self.reward_wrong_format = reward_wrong_format
+        self._assign_reward = False
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        torchrl_logger.info("inv call with IsolateCodeBlock")
+        action = tensordict[self.response_key]
+        # if we didn't find an action, the action is empty
+        if not action:
+            torchrl_logger.info(
+                "Did not find a suitable action, skipping the call to step."
+            )
+            tensordict["retry"] = torch.ones(tensordict.shape, dtype=torch.bool)
+            self._assign_reward = True
+        else:
+            from mlgym.exceptions import FormatError
+
+            try:
+                action = self._inv_apply_transform(action)
+                tensordict[self.response_key] = action
+                torchrl_logger.info(f"Code block: {action}")
+                tensordict["retry"] = torch.zeros(tensordict.shape, dtype=torch.bool)
+                self._assign_reward = False
+            except FormatError:
+                tensordict["retry"] = torch.ones(tensordict.shape, dtype=torch.bool)
+                self._assign_reward = True
+        return tensordict
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if self._assign_reward:
+            torchrl_logger.info(
+                f"Assigning penalty for unsuitable action: {self.reward_wrong_format}"
+            )
+            if self.reward_wrong_format is not None:
+                tensordict[self.parent.reward_key] += self.reward_wrong_format
+        return tensordict
+
+    def _inv_apply_transform(self, action):
+        if not isinstance(action, (str, list)):
+            return NonTensorData(
+                self._inv_apply_transform(action.tolist()),
+                batch_size=action.batch_size,
+                device=action.device,
+            )
+        if isinstance(action, list):
+            return [self._inv_apply_transform(action) for action in action]
+        thought, action = self.parser(action, None)
+        return action
+
+
+class EvaluationOutputParser:
+    """Parser for the reward transform in MLGym.
+
+    .. seealso:: :class:`~torchrl.envs.llm.libs.mlgym.MLGymRewardAssignment`
+
+    """
+
+    def __init__(self):
+        # Regular expressions to match the required fields
+        self.patterns = {
+            "submission_artefact_path": r"valid submission artefact at (.*)\.",
+            "baseline_score": r"Baseline Score: \{'Score': (.*)\}",
+            "evaluation_score": r"Evaluation Score: \{'Score': (.*)\}",
+            "current_step": r"\(Current Step: (\d+),",
+            "remaining_steps": r"Remaining Steps: (\d+)\)",
+            "open_file": r"\(Open file: (.*)\)",
+            "current_directory": r"\(Current directory: (.*)\)",
+        }
+
+    def __call__(self, output_string):
+        parsed_data = {}
+
+        for key, pattern in self.patterns.items():
+            match = re.search(pattern, output_string)
+            if match:
+                parsed_data[key] = match.group(1).strip()
+        if "baseline_score" in parsed_data:
+            parsed_data["baseline_score"] = float(parsed_data["baseline_score"])
+
+        if "evaluation_score" in parsed_data:
+            parsed_data["evaluation_score"] = float(parsed_data["evaluation_score"])
+        if "current_step" in parsed_data:
+            parsed_data["current_step"] = int(parsed_data["current_step"])
+        if "remaining_steps" in parsed_data:
+            parsed_data["remaining_steps"] = int(parsed_data["remaining_steps"])
+
+        return parsed_data
+
+
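# Usage sketch for the parser above, on a string fabricated to match its
# regexes (all values are made up):
parser = EvaluationOutputParser()
parsed = parser(
    "Baseline Score: {'Score': 0.1}\n"
    "Evaluation Score: {'Score': 0.4}\n"
    "(Current Step: 3, Remaining Steps: 7)"
)
# parsed == {"baseline_score": 0.1, "evaluation_score": 0.4,
#            "current_step": 3, "remaining_steps": 7}
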
+class MLGymRewardAssignment(MLGymBaseTransform):
+    """Reward assignment through parsing of the last item in history.
+
+    By default, the :class:`~torchrl.envs.llm.libs.mlgym.EvaluationOutputParser` class is used as parser.
+
+    """
+
+    def __init__(self):
+        super().__init__(in_keys=["reward", "history"], out_keys=["reward"])
+        self.parser = EvaluationOutputParser()
+
+    def _call(self, tensordict):
+        history = tensordict.get("history")
+        if history is None:
+            raise KeyError(f"History is missing in tensordict {tensordict}")
+        if history.ndim != 1:
+            raise ValueError(f"History shape must be 1D, got {history.shape}")
+        content = history[-1].content
+        torchrl_logger.info(f"Parsing reward from: {content}")
+        parsed = self.parser(content)
+        reward = parsed.get("evaluation_score", 0.0) - parsed.get("baseline_score", 0.0)
+        torchrl_logger.info(f"Parsed reward: {reward}")
+        tensordict["reward"] = tensordict["reward"] + reward
+        return tensordict
+
+
+class _add_info_to_reset:
+    def __init__(self, func):
+        functools.update_wrapper(self, func)
+        self.func = func
+
+    def __call__(self, *args, **kwargs):
+        return self.func(*args, **kwargs), {}
+
+
+class _add_truncated_to_step:
+    def __init__(self, func):
+        functools.update_wrapper(self, func)
+        self.func = func
+
+    @_temp_cwd_mlgym()
+    def __call__(self, *args, **kwargs):
+        obs, r, done, info = self.func(*args, **kwargs)
+        return obs, r, done, False, info
+
+
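# Sketch of what the two adapters above do: lift the legacy Gym API onto the
# Gymnasium API that GymWrapper expects (reset returns (obs, info), step
# returns a 5-tuple with a `truncated` flag):
def _legacy_reset():
    return "initial observation"

obs, info = _add_info_to_reset(_legacy_reset)()  # -> ("initial observation", {})
# _add_truncated_to_step likewise turns (obs, r, done, info) into
# (obs, r, done, False, info), chdir-ing into mlgym around each step.
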
+class MLGymWrapper(GymWrapper):
+    """A thin wrapper for MLGym environments.
+
+    This specialized :class:`~torchrl.envs.GymWrapper` subclass defines the observation space with `observation=NonTensor()`
+    and the action space with `text_response=NonTensor()`, according to the :class:`~torchrl.envs.llm.ChatEnv` API.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.full_action_spec = Composite(
+            text_response=NonTensor(example_data="<a string>", shape=())
+        )
+        self.full_observation_spec = Composite(
+            observation=NonTensor(example_data="<a string>", shape=())
+        )
+        self.set_env()
+
+    def set_env(self, env: Any = None):
+        if env is not None:
+            self._env = env
+        self._patch_reset()
+        self._patch_step()
+
+    def _patch_reset(self):
+        if not isinstance(self._env.reset, _add_info_to_reset):
+            self._env.reset = _add_info_to_reset(self._env.reset)
+
+    def _patch_step(self):
+        # check the step method itself before wrapping it
+        if not isinstance(self._env.step, _add_truncated_to_step):
+            self._env.step = _add_truncated_to_step(self._env.step)
+
+    @_temp_cwd_mlgym()
+    def _reset(
+        self, tensordict: TensorDictBase | None = None, **kwargs
+    ) -> TensorDictBase:
+        return super()._reset(tensordict=tensordict, **kwargs)
+
+
+_TASK_IDS = {}
+
+
+def get_args(
+    task: Literal["prisonersDilemma"] = "prisonersDilemma",
+) -> tuple[
+    mlgym.environment.env.EnvironmentArguments,  # noqa
+    mlgym.agent.base.AgentArguments,  # noqa
+]:  # noqa
+    """Build, register and return the environment and agent arguments for a task.
+
+    Args:
+        task: The MLGym task name to register. Defaults to "prisonersDilemma".
+    """
+    import mlgym.environment.registration  # noqa
+    from mlgym import CONFIG_DIR
+    from mlgym.agent.base import AgentArguments
+    from mlgym.backend.base import ModelArguments
+    from mlgym.environment.env import EnvironmentArguments
+    from mlgym.environment.registration import register_task
+
+    environment_args = EnvironmentArguments(
+        task_config_path=f"tasks/{task}.yaml",
+        max_steps=10,
+        seed=42,
+        container_type="podman",
+        verbose=False,
+        aliases_file="docker/aliases.sh",
+    )
+
+    agent_args = AgentArguments(
+        # placeholder
+        model=ModelArguments(""),
+        # Despite using torchrl as an agent, we still need the agent config - see StateToMessage parser
+        agent_config_path=CONFIG_DIR / "agents" / "default.yaml",
+    )
+
+    register_task(environment_args)
+
+    _TASK_IDS[task] = (environment_args.task.id, agent_args)
+
+    return environment_args, agent_args
+
+
+def make_mlgym(
+    *,
+    task: Literal["prisonersDilemma"] | None = None,
+    tasks: list[Literal["prisonersDilemma"]] | None = None,
+    tokenizer: transformers.AutoTokenizer | str | None = None,  # noqa
+    device="cpu",
+    reward_wrong_format: float | None = None,
+) -> TransformedEnv:
+    """Wraps an MLGymEnv in a TorchRL Environment.
+
+    The appended transforms make sure that the data is formatted for the LLM (for the
+    outputs of `env.step`) and for the MLGym API (for the inputs to `env.step`).
+
+    Keyword Args:
+        task (str): The task to wrap. Exclusive with the `tasks` argument.
+
+            .. note:: The correct format is simply the task name, e.g., `"prisonersDilemma"`.
+
+        tasks (list[str]): The tasks available for the env. Exclusive with the `task` argument.
+
+            .. note:: The correct format is simply the task name, e.g., `"prisonersDilemma"`.
+
+        tokenizer (transformers.AutoTokenizer or str, optional): A tokenizer for the data.
+            If a string is passed, it will be converted to a `transformers.AutoTokenizer`.
+        device (str, optional): The device to set on the env. Defaults to "cpu".
+        reward_wrong_format (float, optional): The reward (negative penalty) for wrongly
+            formatted actions. Defaults to `None` (no penalty).
+
+    """
+    import gymnasium as gym
+
+    if isinstance(tokenizer, str):
+        import transformers
+
+        tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
+
+    with _temp_cwd_mlgym():
+
+        if task and not tasks:
+            environment_args, agent_args = get_args(task=task)
+        elif tasks and not task:
+            for task in tasks:
+                environment_args, agent_args = get_args(task=task)
+        else:
+            raise ValueError(
+                f"Either task or tasks should be provided, not both and not none. Got {task=} and {tasks=}."
+            )
+
+        base_env = gym.make(
+            f"mlgym/{_TASK_IDS[task][0]}",
+            devices=["cpu_0"],
+        ).unwrapped
+        # we need the env to have access to the config
+        base_env.config = agent_args.config
+        env = TransformedEnv(
+            MLGymWrapper(base_env, auto_reset=False, device=device), auto_unwrap=False
+        )
+
+        env.append_transform(ConditionalSkip(lambda td: td["retry"]))
+        env.append_transform(IsolateCodeBlock(reward_wrong_format=reward_wrong_format))
+
+        env.append_transform(ResetModule())
+        if tasks:
+            # Add a task sampler
+            env.append_transform(TaskSampler(tasks))
+        env.append_transform(ReadState())
+        env.append_transform(StateToMessage())
+        env.append_transform(MessageToHistory())
+        env.append_transform(TemplateTransform(tokenizer=tokenizer))
+        env.append_transform(MLGymRewardAssignment())
+        # # We want the env to have a batch-size of (1,) because it will be easier to interact with
+        # # LLMs
+        # env.append_transform(BatchSizeTransform(batch_size=(1,)))
+    return env
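
# End-to-end usage sketch, assuming a working MLGym + podman setup; the
# tokenizer name is a hypothetical placeholder:
env = make_mlgym(task="prisonersDilemma", tokenizer="Qwen/Qwen2.5-0.5B-Instruct")
td = env.reset()  # fills "history", "state", "message" and the "text" prompt
td["text_response"] = "<LLM completion goes here>"  # a policy would write this
td = env.step(td)  # the reward is parsed from the last history entry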