torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,955 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ from torchrl.trainers.algorithms.configs.common import ConfigBase
12
+
13
+
14
@dataclass
class TransformConfig(ConfigBase):
    """Common parent of every transform configuration dataclass in this module.

    Concrete subclasses add their own fields plus a ``_target_`` string that
    holds the dotted import path of the transform class they describe.
    """

    def __post_init__(self) -> None:
        """Run after the dataclass ``__init__``; deliberately a no-op here.

        Subclasses chain into this via ``super().__post_init__()`` before doing
        their own normalisation.
        """
        # NOTE(review): ConfigBase presumably declares this hook (possibly as
        # abstract), which would explain the explicit override — confirm.
20
+
21
+
22
@dataclass
class NoopResetEnvConfig(TransformConfig):
    """Structured config describing a ``NoopResetEnv`` transform.

    Defaults to ``noops=30`` and ``random=True``. ``_target_`` records the
    dotted import path of the ``NoopResetEnv`` class.
    """

    noops: int = 30
    random: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.NoopResetEnv"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no NoopResetEnv-specific normalisation."""
        super().__post_init__()
33
+
34
+
35
@dataclass
class StepCounterConfig(TransformConfig):
    """Structured config describing a ``StepCounter`` transform.

    By default no step limit is set (``max_steps=None``). The truncation and
    step-count entries default to the keys ``"truncated"`` and
    ``"step_count"``, and ``update_done`` defaults to ``True``.
    """

    max_steps: int | None = None
    truncated_key: str | None = "truncated"
    step_count_key: str | None = "step_count"
    update_done: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.StepCounter"

    def __post_init__(self) -> None:
        """Delegate to the base hook; nothing StepCounter-specific to adjust."""
        super().__post_init__()
48
+
49
+
50
@dataclass
class ComposeConfig(TransformConfig):
    """Structured config describing a ``Compose`` transform.

    ``transforms`` holds the child transform entries. It may be omitted at
    construction time; after initialisation it is always a list.
    """

    transforms: list[Any] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Compose"

    def __post_init__(self) -> None:
        """Run the base hook, then turn a missing ``transforms`` into ``[]``."""
        super().__post_init__()
        # The field defaults to None rather than a list literal so that no
        # list object is shared between instances; build a fresh one here.
        self.transforms = [] if self.transforms is None else self.transforms
62
+
63
+
64
@dataclass
class DoubleToFloatConfig(TransformConfig):
    """Structured config describing a ``DoubleToFloat`` transform.

    All four key lists (forward ``in_keys``/``out_keys`` and inverse
    ``in_keys_inv``/``out_keys_inv``) are optional and default to ``None``.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DoubleToFloat"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no DoubleToFloat-specific work."""
        super().__post_init__()
77
+
78
+
79
@dataclass
class ToTensorImageConfig(TransformConfig):
    """Structured config describing a ``ToTensorImage`` transform.

    Defaults: ``from_int=None``, ``unsqueeze=False``, ``dtype=None``, no
    explicit keys, and ``shape_tolerant=False``.
    """

    from_int: bool | None = None
    unsqueeze: bool = False
    # Stored as a string; presumably resolved to a real dtype downstream —
    # confirm where the conversion happens.
    dtype: str | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    shape_tolerant: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.ToTensorImage"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no ToTensorImage-specific work."""
        super().__post_init__()
94
+
95
+
96
@dataclass
class ClipTransformConfig(TransformConfig):
    """Structured config describing a ``ClipTransform``.

    Forward and inverse key lists are optional. Both bounds (``low`` and
    ``high``) default to ``None``, i.e. no bound is set by this config.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    low: float | None = None
    high: float | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ClipTransform"

    def __post_init__(self) -> None:
        """Delegate to the base hook; bounds are passed through unvalidated."""
        super().__post_init__()
111
+
112
+
113
@dataclass
class ResizeConfig(TransformConfig):
    """Structured config describing a ``Resize`` transform.

    The target size defaults to ``w=84`` by ``h=84`` with ``"bilinear"``
    interpolation. Keys are optional.
    """

    w: int = 84
    h: int = 84
    interpolation: str = "bilinear"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Resize"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no Resize-specific work."""
        super().__post_init__()
127
+
128
+
129
@dataclass
class CenterCropConfig(TransformConfig):
    """Structured config describing a ``CenterCrop`` transform.

    The crop size defaults to ``height=84`` and ``width=84``. Keys are
    optional.
    """

    # NOTE(review): ResizeConfig in this module names its size fields w / h,
    # whereas this config uses height / width. Confirm that CenterCrop's
    # constructor accepts these keyword names — TODO verify.
    height: int = 84
    width: int = 84
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CenterCrop"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no CenterCrop-specific work."""
        super().__post_init__()
142
+
143
+
144
@dataclass
class FlattenObservationConfig(TransformConfig):
    """Structured config describing a ``FlattenObservation`` transform.

    Only the optional input/output keys are configurable here.
    """

    # NOTE(review): no dimension range is exposed by this config; confirm
    # that the target FlattenObservation can be built without first/last-dim
    # arguments — TODO verify.
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FlattenObservation"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no FlattenObservation-specific work."""
        super().__post_init__()
155
+
156
+
157
@dataclass
class GrayScaleConfig(TransformConfig):
    """Structured config describing a ``GrayScale`` transform.

    Only the optional input/output keys are configurable.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.GrayScale"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no GrayScale-specific work."""
        super().__post_init__()
168
+
169
+
170
@dataclass
class ObservationNormConfig(TransformConfig):
    """Structured config describing an ``ObservationNorm`` transform.

    Defaults: ``loc=0.0``, ``scale=1.0``, ``standard_normal=False`` and
    ``eps=1e-8``. Keys are optional.
    """

    loc: float = 0.0
    scale: float = 1.0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    standard_normal: bool = False
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.ObservationNorm"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no ObservationNorm-specific work."""
        super().__post_init__()
185
+
186
+
187
@dataclass
class CatFramesConfig(TransformConfig):
    """Structured config describing a ``CatFrames`` transform.

    Defaults to ``N=4`` frames along ``dim=-3``. Keys are optional.
    """

    # Upper-case field name kept as-is; it presumably mirrors the target's
    # parameter name — confirm before renaming.
    N: int = 4
    dim: int = -3
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CatFrames"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no CatFrames-specific work."""
        super().__post_init__()
200
+
201
+
202
@dataclass
class RewardClippingConfig(TransformConfig):
    """Structured config describing a ``RewardClipping`` transform.

    Both clamp bounds default to ``None`` (unset by this config). Keys are
    optional.
    """

    clamp_min: float | None = None
    clamp_max: float | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RewardClipping"

    def __post_init__(self) -> None:
        """Delegate to the base hook; bounds are passed through unvalidated."""
        super().__post_init__()
215
+
216
+
217
@dataclass
class RewardScalingConfig(TransformConfig):
    """Structured config describing a ``RewardScaling`` transform.

    Defaults: ``loc=0.0``, ``scale=1.0``, ``standard_normal=False`` and
    ``eps=1e-8``. Keys are optional.
    """

    loc: float = 0.0
    scale: float = 1.0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    standard_normal: bool = False
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.RewardScaling"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no RewardScaling-specific work."""
        super().__post_init__()
232
+
233
+
234
@dataclass
class VecNormConfig(TransformConfig):
    """Structured config describing a ``VecNorm`` transform.

    Defaults: ``decay=0.99`` and ``eps=1e-8``. Keys are optional.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    decay: float = 0.99
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.VecNorm"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no VecNorm-specific work."""
        super().__post_init__()
247
+
248
+
249
@dataclass
class FrameSkipTransformConfig(TransformConfig):
    """Structured config describing a ``FrameSkipTransform``.

    ``frame_skip`` defaults to 4. Keys are optional.
    """

    frame_skip: int = 4
    # NOTE(review): confirm FrameSkipTransform accepts in_keys / out_keys; if
    # its constructor only takes frame_skip, instantiating with these fields
    # would fail — TODO verify.
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FrameSkipTransform"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no FrameSkipTransform-specific work."""
        super().__post_init__()
261
+
262
+
263
@dataclass
class EndOfLifeTransformConfig(TransformConfig):
    """Structured config describing an ``EndOfLifeTransform``.

    The target lives in ``gym_transforms`` rather than ``transforms``. The
    key names default to ``"end-of-life"``, ``"lives"`` and ``"done"``, and
    ``eol_attribute`` defaults to the attribute path
    ``"unwrapped.ale.lives"``.
    """

    eol_key: str = "end-of-life"
    # NOTE(review): confirm EndOfLifeTransform accepts a `lives_key` argument.
    lives_key: str = "lives"
    done_key: str = "done"
    eol_attribute: str = "unwrapped.ale.lives"
    _target_: str = "torchrl.envs.transforms.gym_transforms.EndOfLifeTransform"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no EndOfLifeTransform-specific work."""
        super().__post_init__()
276
+
277
+
278
@dataclass
class MultiStepTransformConfig(TransformConfig):
    """Structured config describing a ``MultiStepTransform``.

    The target lives in ``rb_transforms``. Defaults: ``n_steps=3`` and
    ``gamma=0.99``. Keys are optional.
    """

    n_steps: int = 3
    gamma: float = 0.99
    # NOTE(review): confirm MultiStepTransform accepts in_keys / out_keys —
    # TODO verify against its signature.
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.rb_transforms.MultiStepTransform"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no MultiStepTransform-specific work."""
        super().__post_init__()
291
+
292
+
293
@dataclass
class TargetReturnConfig(TransformConfig):
    """Structured config describing a ``TargetReturn`` transform.

    Defaults: ``target_return=10.0`` and ``mode="reduce"``. The keys and
    ``reset_key`` are optional.
    """

    target_return: float = 10.0
    mode: str = "reduce"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    reset_key: str | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TargetReturn"

    def __post_init__(self) -> None:
        """Delegate to the base hook; ``mode`` is passed through unvalidated."""
        super().__post_init__()
307
+
308
+
309
@dataclass
class BinarizeRewardConfig(TransformConfig):
    """Structured config describing a ``BinarizeReward`` transform.

    Only the optional input/output keys are configurable.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.BinarizeReward"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no BinarizeReward-specific work."""
        super().__post_init__()
320
+
321
+
322
@dataclass
class ActionDiscretizerConfig(TransformConfig):
    """Structured config describing an ``ActionDiscretizer`` transform.

    Defaults: ``num_intervals=10``, ``action_key="action"``,
    ``categorical=True``, with ``out_action_key`` and ``sampling`` unset
    (``None``).
    """

    num_intervals: int = 10
    action_key: str = "action"
    out_action_key: str | None = None
    sampling: str | None = None
    categorical: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.ActionDiscretizer"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no ActionDiscretizer-specific work."""
        super().__post_init__()
336
+
337
+
338
@dataclass
class AutoResetTransformConfig(TransformConfig):
    """Structured config describing an ``AutoResetTransform``.

    Defaults: ``replace=None``, ``fill_float="nan"``, ``fill_int=-1`` and
    ``fill_bool=False``.
    """

    replace: bool | None = None
    # Kept as the string "nan" (not a float) — presumably parsed by the
    # target transform; confirm before changing the type.
    fill_float: str = "nan"
    fill_int: int = -1
    fill_bool: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.AutoResetTransform"

    def __post_init__(self) -> None:
        """Delegate to the base hook; no AutoResetTransform-specific work."""
        super().__post_init__()
351
+
352
+
353
@dataclass
class BatchSizeTransformConfig(TransformConfig):
    """Declarative settings for building a ``BatchSizeTransform``.

    ``reshape_fn`` and ``reset_func`` are untyped (``Any``) so that callables
    or nested configs can be supplied.
    """

    batch_size: list[int] | None = None
    reshape_fn: Any = None
    reset_func: Any = None
    env_kwarg: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.BatchSizeTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; BatchSizeTransform needs no extra processing."""
        super().__post_init__()
366
+
367
+
368
@dataclass
class DeviceCastTransformConfig(TransformConfig):
    """Declarative settings for building a ``DeviceCastTransform``.

    The target device is given as a string (``"cpu"`` by default); forward
    and inverse key routing are configured separately.
    """

    device: str = "cpu"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DeviceCastTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; DeviceCastTransform needs no extra processing."""
        super().__post_init__()
382
+
383
+
384
@dataclass
class DTypeCastTransformConfig(TransformConfig):
    """Declarative settings for building a ``DTypeCastTransform``.

    The dtype is held as a string (``"torch.float32"`` by default); forward
    and inverse key routing are configured separately.
    """

    dtype: str = "torch.float32"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DTypeCastTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; DTypeCastTransform needs no extra processing."""
        super().__post_init__()
398
+
399
+
400
@dataclass
class UnsqueezeTransformConfig(TransformConfig):
    """Declarative settings for building an ``UnsqueezeTransform``.

    ``dim`` defaults to ``0``; ``_target_`` names the transform class.
    """

    dim: int = 0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.UnsqueezeTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; UnsqueezeTransform needs no extra processing."""
        super().__post_init__()
412
+
413
+
414
@dataclass
class SqueezeTransformConfig(TransformConfig):
    """Declarative settings for building a ``SqueezeTransform``.

    ``dim`` defaults to ``0``; ``_target_`` names the transform class.
    """

    dim: int = 0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.SqueezeTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; SqueezeTransform needs no extra processing."""
        super().__post_init__()
426
+
427
+
428
@dataclass
class PermuteTransformConfig(TransformConfig):
    """Declarative settings for building a ``PermuteTransform``.

    When ``dims`` is left unset, ``__post_init__`` substitutes ``[0, 2, 1]``.
    """

    dims: list[int] | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.PermuteTransform"

    def __post_init__(self) -> None:
        """Run the parent hook, then supply the default permutation if needed."""
        super().__post_init__()
        # A fresh list per instance, so configs never share one mutable default.
        self.dims = [0, 2, 1] if self.dims is None else self.dims
442
+
443
+
444
@dataclass
class CatTensorsConfig(TransformConfig):
    """Declarative settings for building a ``CatTensors`` transform.

    ``dim`` defaults to ``-1``; ``_target_`` names the transform class.
    """

    dim: int = -1
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CatTensors"

    def __post_init__(self) -> None:
        """Defer to the parent hook; CatTensors needs no extra processing."""
        super().__post_init__()
456
+
457
+
458
@dataclass
class StackConfig(TransformConfig):
    """Declarative settings for building a ``Stack`` transform.

    ``dim`` defaults to ``0``; ``_target_`` names the transform class.
    """

    dim: int = 0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Stack"

    def __post_init__(self) -> None:
        """Defer to the parent hook; Stack needs no extra processing."""
        super().__post_init__()
470
+
471
+
472
@dataclass
class DiscreteActionProjectionConfig(TransformConfig):
    """Declarative settings for building a ``DiscreteActionProjection``.

    ``num_actions`` defaults to ``4``; ``_target_`` names the transform class.
    """

    num_actions: int = 4
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DiscreteActionProjection"

    def __post_init__(self) -> None:
        """Defer to the parent hook; DiscreteActionProjection needs no extra processing."""
        super().__post_init__()
484
+
485
+
486
@dataclass
class TensorDictPrimerConfig(TransformConfig):
    """Declarative settings for building a ``TensorDictPrimer`` transform.

    ``primer_spec`` is untyped (``Any``); what it must contain is decided by
    the target class, not by this config.
    """

    primer_spec: Any = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TensorDictPrimer"

    def __post_init__(self) -> None:
        """Defer to the parent hook; TensorDictPrimer needs no extra processing."""
        super().__post_init__()
498
+
499
+
500
@dataclass
class PinMemoryTransformConfig(TransformConfig):
    """Declarative settings for building a ``PinMemoryTransform``.

    Only the key routing is configurable here; ``_target_`` names the class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.PinMemoryTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; PinMemoryTransform needs no extra processing."""
        super().__post_init__()
511
+
512
+
513
@dataclass
class RewardSumConfig(TransformConfig):
    """Declarative settings for building a ``RewardSum`` transform.

    Only the key routing is configurable here; ``_target_`` names the class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RewardSum"

    def __post_init__(self) -> None:
        """Defer to the parent hook; RewardSum needs no extra processing."""
        super().__post_init__()
524
+
525
+
526
@dataclass
class ExcludeTransformConfig(TransformConfig):
    """Declarative settings for building an ``ExcludeTransform``.

    An unset ``exclude_keys`` is normalised to an empty list after init.
    """

    exclude_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ExcludeTransform"

    def __post_init__(self) -> None:
        """Run the parent hook, then replace a missing key list with ``[]``."""
        super().__post_init__()
        self.exclude_keys = [] if self.exclude_keys is None else self.exclude_keys
538
+
539
+
540
@dataclass
class SelectTransformConfig(TransformConfig):
    """Declarative settings for building a ``SelectTransform``.

    An unset ``include_keys`` is normalised to an empty list after init.
    """

    include_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.SelectTransform"

    def __post_init__(self) -> None:
        """Run the parent hook, then replace a missing key list with ``[]``."""
        super().__post_init__()
        self.include_keys = [] if self.include_keys is None else self.include_keys
552
+
553
+
554
@dataclass
class TimeMaxPoolConfig(TransformConfig):
    """Declarative settings for building a ``TimeMaxPool`` transform.

    ``dim`` defaults to ``-1``; ``_target_`` names the transform class.
    """

    dim: int = -1
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TimeMaxPool"

    def __post_init__(self) -> None:
        """Defer to the parent hook; TimeMaxPool needs no extra processing."""
        super().__post_init__()
566
+
567
+
568
@dataclass
class RandomCropTensorDictConfig(TransformConfig):
    """Declarative settings for building a ``RandomCropTensorDict`` transform.

    When ``crop_size`` is left unset, ``__post_init__`` substitutes ``[84, 84]``.
    """

    crop_size: list[int] | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RandomCropTensorDict"

    def __post_init__(self) -> None:
        """Run the parent hook, then supply the default crop size if needed."""
        super().__post_init__()
        self.crop_size = [84, 84] if self.crop_size is None else self.crop_size
582
+
583
+
584
@dataclass
class InitTrackerConfig(TransformConfig):
    """Declarative settings for building an ``InitTracker`` transform.

    ``init_key`` is optional and left as ``None`` unless provided.
    """

    init_key: str | None = None
    _target_: str = "torchrl.envs.transforms.transforms.InitTracker"

    def __post_init__(self) -> None:
        """Defer to the parent hook; InitTracker needs no extra processing."""
        super().__post_init__()
594
+
595
+
596
@dataclass
class RenameTransformConfig(TransformConfig):
    """Declarative settings for building a ``RenameTransform``.

    An unset ``key_mapping`` is normalised to an empty dict after init.
    """

    key_mapping: dict[str, str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RenameTransform"

    def __post_init__(self) -> None:
        """Run the parent hook, then replace a missing mapping with ``{}``."""
        super().__post_init__()
        self.key_mapping = {} if self.key_mapping is None else self.key_mapping
608
+
609
+
610
@dataclass
class Reward2GoTransformConfig(TransformConfig):
    """Declarative settings for building a ``Reward2GoTransform``.

    ``gamma`` defaults to ``0.99``; ``_target_`` names the transform class.
    """

    gamma: float = 0.99
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Reward2GoTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; Reward2GoTransform needs no extra processing."""
        super().__post_init__()
622
+
623
+
624
@dataclass
class ActionMaskConfig(TransformConfig):
    """Declarative settings for building an ``ActionMask`` transform.

    ``mask_key`` defaults to ``"action_mask"``; ``_target_`` names the class.
    """

    mask_key: str = "action_mask"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ActionMask"

    def __post_init__(self) -> None:
        """Defer to the parent hook; ActionMask needs no extra processing."""
        super().__post_init__()
636
+
637
+
638
@dataclass
class VecGymEnvTransformConfig(TransformConfig):
    """Declarative settings for building a ``VecGymEnvTransform``.

    Only the key routing is configurable here; ``_target_`` names the class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.VecGymEnvTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; VecGymEnvTransform needs no extra processing."""
        super().__post_init__()
649
+
650
+
651
@dataclass
class BurnInTransformConfig(TransformConfig):
    """Declarative settings for building a ``BurnInTransform``.

    ``burn_in`` defaults to ``10``; ``_target_`` names the transform class.
    """

    burn_in: int = 10
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.BurnInTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; BurnInTransform needs no extra processing."""
        super().__post_init__()
663
+
664
+
665
@dataclass
class SignTransformConfig(TransformConfig):
    """Declarative settings for building a ``SignTransform``.

    Only the key routing is configurable here; ``_target_`` names the class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.SignTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; SignTransform needs no extra processing."""
        super().__post_init__()
676
+
677
+
678
@dataclass
class RemoveEmptySpecsConfig(TransformConfig):
    """Declarative settings for building a ``RemoveEmptySpecs`` transform.

    This config carries no options beyond the ``_target_`` class path.
    """

    _target_: str = "torchrl.envs.transforms.transforms.RemoveEmptySpecs"

    def __post_init__(self) -> None:
        """Defer to the parent hook; RemoveEmptySpecs needs no extra processing."""
        super().__post_init__()
687
+
688
+
689
@dataclass
class TrajCounterConfig(TransformConfig):
    """Declarative settings for building a ``TrajCounter`` transform.

    ``out_key`` defaults to ``"traj_count"``; ``repeats`` is optional.
    """

    out_key: str = "traj_count"
    repeats: int | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TrajCounter"

    def __post_init__(self) -> None:
        """Defer to the parent hook; TrajCounter needs no extra processing."""
        super().__post_init__()
700
+
701
+
702
@dataclass
class LineariseRewardsConfig(TransformConfig):
    """Declarative settings for building a ``LineariseRewards`` transform.

    An unset ``in_keys`` is normalised to an empty list after init; unlike
    most configs here, it never stays ``None``.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    weights: list[float] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.LineariseRewards"

    def __post_init__(self) -> None:
        """Run the parent hook, then replace missing input keys with ``[]``."""
        super().__post_init__()
        self.in_keys = [] if self.in_keys is None else self.in_keys
716
+
717
+
718
@dataclass
class ConditionalSkipConfig(TransformConfig):
    """Declarative settings for building a ``ConditionalSkip`` transform.

    ``cond`` is untyped (``Any``) so a callable or nested config can be given.
    """

    cond: Any = None
    _target_: str = "torchrl.envs.transforms.transforms.ConditionalSkip"

    def __post_init__(self) -> None:
        """Defer to the parent hook; ConditionalSkip needs no extra processing."""
        super().__post_init__()
728
+
729
+
730
@dataclass
class MultiActionConfig(TransformConfig):
    """Declarative settings for building a ``MultiAction`` transform.

    ``dim`` defaults to ``1``; the two boolean flags control stacking of
    rewards and observations respectively.
    """

    dim: int = 1
    stack_rewards: bool = True
    stack_observations: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.MultiAction"

    def __post_init__(self) -> None:
        """Defer to the parent hook; MultiAction needs no extra processing."""
        super().__post_init__()
742
+
743
+
744
@dataclass
class TimerConfig(TransformConfig):
    """Declarative settings for building a ``Timer`` transform.

    ``time_key`` defaults to ``"time"``; ``out_keys`` is optional.
    """

    out_keys: list[str] | None = None
    time_key: str = "time"
    _target_: str = "torchrl.envs.transforms.transforms.Timer"

    def __post_init__(self) -> None:
        """Defer to the parent hook; Timer needs no extra processing."""
        super().__post_init__()
755
+
756
+
757
@dataclass
class ConditionalPolicySwitchConfig(TransformConfig):
    """Declarative settings for building a ``ConditionalPolicySwitch``.

    ``policy`` and ``condition`` are untyped (``Any``) so callables or nested
    configs can be supplied.
    """

    policy: Any = None
    condition: Any = None
    _target_: str = "torchrl.envs.transforms.transforms.ConditionalPolicySwitch"

    def __post_init__(self) -> None:
        """Defer to the parent hook; ConditionalPolicySwitch needs no extra processing."""
        super().__post_init__()
768
+
769
+
770
@dataclass
class KLRewardTransformConfig(TransformConfig):
    """Declarative settings for building a ``KLRewardTransform``.

    Unlike most configs in this group, ``_target_`` points into the
    ``torchrl.envs.transforms.llm`` module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.llm.KLRewardTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; KLRewardTransform needs no extra processing."""
        super().__post_init__()
781
+
782
+
783
@dataclass
class R3MTransformConfig(TransformConfig):
    """Declarative settings for building an ``R3MTransform``.

    ``model_name`` defaults to ``"resnet18"`` and ``device`` to ``"cpu"``.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    model_name: str = "resnet18"
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.r3m.R3MTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; R3MTransform needs no extra processing."""
        super().__post_init__()
796
+
797
+
798
@dataclass
class VC1TransformConfig(TransformConfig):
    """Declarative settings for building a ``VC1Transform``.

    ``device`` defaults to ``"cpu"``; ``_target_`` names the transform class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.vc1.VC1Transform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; VC1Transform needs no extra processing."""
        super().__post_init__()
810
+
811
+
812
@dataclass
class VIPTransformConfig(TransformConfig):
    """Declarative settings for building a ``VIPTransform``.

    ``device`` defaults to ``"cpu"``; ``_target_`` names the transform class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.vip.VIPTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; VIPTransform needs no extra processing."""
        super().__post_init__()
824
+
825
+
826
@dataclass
class VIPRewardTransformConfig(TransformConfig):
    """Declarative settings for building a ``VIPRewardTransform``.

    ``device`` defaults to ``"cpu"``; ``_target_`` names the transform class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.vip.VIPRewardTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; VIPRewardTransform needs no extra processing."""
        super().__post_init__()
838
+
839
+
840
@dataclass
class VecNormV2Config(TransformConfig):
    """Declarative settings for building a ``VecNormV2`` transform.

    ``decay`` defaults to ``0.99`` and ``eps`` to ``1e-8``.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    decay: float = 0.99
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.vecnorm.VecNormV2"

    def __post_init__(self) -> None:
        """Defer to the parent hook; VecNormV2 needs no extra processing."""
        super().__post_init__()
853
+
854
+
855
@dataclass
class FiniteTensorDictCheckConfig(TransformConfig):
    """Declarative settings for building a ``FiniteTensorDictCheck`` transform.

    Only the key routing is configurable here; ``_target_`` names the class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FiniteTensorDictCheck"

    def __post_init__(self) -> None:
        """Defer to the parent hook; FiniteTensorDictCheck needs no extra processing."""
        super().__post_init__()
866
+
867
+
868
@dataclass
class UnaryTransformConfig(TransformConfig):
    """Declarative settings for building a ``UnaryTransform``.

    ``fn`` is untyped (``Any``) so a callable or nested config can be given.
    """

    fn: Any = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.UnaryTransform"

    def __post_init__(self) -> None:
        """Defer to the parent hook; UnaryTransform needs no extra processing."""
        super().__post_init__()
880
+
881
+
882
@dataclass
class HashConfig(TransformConfig):
    """Declarative settings for building a ``Hash`` transform.

    Only the key routing is configurable here; ``_target_`` names the class.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Hash"

    def __post_init__(self) -> None:
        """Defer to the parent hook; Hash needs no extra processing."""
        super().__post_init__()
893
+
894
+
895
@dataclass
class TokenizerConfig(TransformConfig):
    """Declarative settings for building a ``Tokenizer`` transform.

    ``vocab_size`` defaults to ``1000``; ``_target_`` names the class.
    """

    vocab_size: int = 1000
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Tokenizer"

    def __post_init__(self) -> None:
        """Defer to the parent hook; Tokenizer needs no extra processing."""
        super().__post_init__()
907
+
908
+
909
@dataclass
class CropConfig(TransformConfig):
    """Declarative settings for building a ``Crop`` transform.

    The crop window defaults to an 84x84 region anchored at offset (0, 0).
    """

    top: int = 0
    left: int = 0
    height: int = 84
    width: int = 84
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Crop"

    def __post_init__(self) -> None:
        """Defer to the parent hook; Crop needs no extra processing."""
        super().__post_init__()
924
+
925
+
926
@dataclass
class FlattenTensorDictConfig(TransformConfig):
    """Declarative settings for the ``FlattenTensorDict`` transform.

    The described transform gives the tensordict a flat batch dimension
    during the inverse pass. This helps replay buffers that store data with
    a flat batch structure. The config has no options beyond ``_target_``.
    """

    _target_: str = "torchrl.envs.transforms.transforms.FlattenTensorDict"

    def __post_init__(self) -> None:
        """Defer to the parent hook; FlattenTensorDict needs no extra processing."""
        super().__post_init__()
940
+
941
+
942
+ @dataclass
943
+ class ModuleTransformConfig(TransformConfig):
944
+ """Configuration for ModuleTransform."""
945
+
946
+ module: Any = None
947
+ device: Any = None
948
+ no_grad: bool = False
949
+ inverse: bool = False
950
+ _target_: str = "torchrl.envs.transforms.module.ModuleTransform"
951
+ _partial_: bool = False
952
+
953
+ def __post_init__(self) -> None:
954
+ """Post-initialization hook for ModuleTransform configuration."""
955
+ super().__post_init__()