torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,573 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This makes omegaconf unhappy with typing.Any
7
+ # Therefore we need Optional and Union
8
+ # from __future__ import annotations
9
+
10
+ import importlib.util
11
+ from collections.abc import Callable, Sequence
12
+ from copy import copy
13
+ from dataclasses import dataclass, field as dataclass_field
14
+ from typing import Any, Optional, Union
15
+
16
+ import torch
17
+ from torchrl._utils import logger as torchrl_logger, VERBOSE
18
+ from torchrl.envs import ParallelEnv
19
+ from torchrl.envs.common import EnvBase
20
+ from torchrl.envs.env_creator import env_creator, EnvCreator
21
+ from torchrl.envs.libs.dm_control import DMControlEnv
22
+ from torchrl.envs.libs.gym import GymEnv
23
+ from torchrl.envs.transforms import (
24
+ CatFrames,
25
+ CatTensors,
26
+ CenterCrop,
27
+ Compose,
28
+ DoubleToFloat,
29
+ GrayScale,
30
+ NoopResetEnv,
31
+ ObservationNorm,
32
+ Resize,
33
+ RewardScaling,
34
+ ToTensorImage,
35
+ TransformedEnv,
36
+ VecNorm,
37
+ )
38
+ from torchrl.envs.transforms.transforms import (
39
+ FlattenObservation,
40
+ gSDENoise,
41
+ InitTracker,
42
+ StepCounter,
43
+ )
44
+ from torchrl.record.loggers import Logger
45
+ from torchrl.record.recorder import VideoRecorder
46
+
47
# Maps the ``cfg.env_library`` string to the torchrl environment class that
# ``transformed_env_constructor`` instantiates.
LIBS = dict(
    gym=GymEnv,
    dm_control=DMControlEnv,
)
51
+
52
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
if not _has_omegaconf:

    class DictConfig:  # noqa
        """Empty stand-in for ``omegaconf.DictConfig`` when omegaconf is not installed.

        It only exists so that this module's type annotations resolve.
        """

else:
    from omegaconf import DictConfig
59
+
60
+
61
def correct_for_frame_skip(cfg: DictConfig) -> DictConfig:  # noqa: F821
    """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.

    This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
    of 1M but actually collecting frame_skip * 1M frames.

    Args:
        cfg (DictConfig): DictConfig with a ``frame_skip`` attribute and some of the frame-counting arguments
            "max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames",
            "init_random_frames", "init_env_steps", "noops". Each present, non-``None`` field is
            floor-divided by ``frame_skip``. Absent fields and fields set to ``None`` are left untouched.

    Returns:
        the input DictConfig, modified in-place.

    """
    frame_skip = cfg.frame_skip
    if frame_skip == 1:
        # Nothing to rescale.
        return cfg
    frame_count_fields = (
        "max_frames_per_traj",
        "total_frames",
        "frames_per_batch",
        "record_frames",
        "annealing_frames",
        "init_random_frames",
        "init_env_steps",
        "noops",
    )
    for name in frame_count_fields:
        value = getattr(cfg, name, None)
        # ``None`` cannot be floor-divided (TypeError); leave such fields as they are.
        if value is not None:
            setattr(cfg, name, value // frame_skip)
    return cfg
92
+
93
+
94
def make_env_transforms(
    env,
    cfg,
    video_tag,
    logger,
    env_name,
    stats,
    norm_obs_only,
    env_library,
    action_dim_gsde,
    state_dim_gsde,
    batch_dims=0,
    obs_norm_state_dict=None,
):
    """Creates the typical transforms for an env.

    The input env is wrapped in a :class:`TransformedEnv` and transforms are appended in this order:
    optional :class:`VideoRecorder`, pixel or vector observation processing and normalization,
    optional :class:`RewardScaling`, optional :class:`gSDENoise`, and finally :class:`StepCounter`
    and :class:`InitTracker`.

    Args:
        env: the environment to transform.
        cfg: config object. Attributes read here: ``from_pixels``, ``vecnorm``, ``norm_rewards``,
            ``reward_scaling``, ``reward_loc``, ``center_crop``, ``catframes``, ``image_size``,
            ``grayscale`` and, if present, ``gSDE``.
        video_tag (str): if non-empty, a :class:`VideoRecorder` tagged
            ``f"{video_tag}_{env_name}_video"`` is appended first.
        logger: logger handed to the :class:`VideoRecorder`.
        env_name (str): environment name, only used to build the video tag.
        stats (dict, optional): keyword arguments for :class:`ObservationNorm` (e.g. ``loc`` and
            ``scale``). Takes precedence over ``obs_norm_state_dict``.
        norm_obs_only (bool): if ``True``, the reward is not added to the :class:`VecNorm` keys and
            the reward loc/scale are overridden to ``0.0``/``1.0``.
        env_library: not used by this function.
        action_dim_gsde (int, optional): action dimension forwarded to :class:`gSDENoise`.
        state_dim_gsde (int, optional): state dimension forwarded to :class:`gSDENoise`.
        batch_dims (int, optional): not used by this function. Defaults to ``0``.
        obs_norm_state_dict (dict, optional): keyword arguments for :class:`ObservationNorm`,
            used only when ``stats`` is ``None``.

    Returns:
        the :class:`TransformedEnv` with all transforms appended.

    Raises:
        RuntimeError: if ``cfg.from_pixels`` is set while ``cfg.catframes`` is falsy.
    """
    env = TransformedEnv(env)

    from_pixels = cfg.from_pixels
    vecnorm = cfg.vecnorm
    # Rewards are only normalized online when VecNorm is enabled.
    norm_rewards = vecnorm and cfg.norm_rewards
    # VecNorm only covers the observation vector unless reward normalization is
    # requested and not disabled by the caller.
    _norm_obs_only = norm_obs_only or not norm_rewards
    reward_scaling = cfg.reward_scaling
    reward_loc = cfg.reward_loc

    if len(video_tag):
        center_crop = cfg.center_crop
        if center_crop:
            # Only the first entry of cfg.center_crop is passed to VideoRecorder.
            center_crop = center_crop[0]
        env.append_transform(
            VideoRecorder(
                logger=logger,
                tag=f"{video_tag}_{env_name}_video",
                center_crop=center_crop,
            ),
        )

    if from_pixels:
        if not cfg.catframes:
            raise RuntimeError(
                "this env builder currently only accepts positive catframes values "
                "when pixels are being used."
            )
        env.append_transform(ToTensorImage())
        if cfg.center_crop:
            env.append_transform(CenterCrop(*cfg.center_crop))
        env.append_transform(Resize(cfg.image_size, cfg.image_size))
        if cfg.grayscale:
            env.append_transform(GrayScale())
        # Flattens dims 0 through -3 of the pixel tensor into one dim (presumably
        # folding leading dims into the channel dim — confirm).
        env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
        env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
        # Normalization statistics: explicit `stats` win over `obs_norm_state_dict`;
        # with neither, ObservationNorm is built without loc/scale.
        if stats is None and obs_norm_state_dict is None:
            obs_stats = {}
        elif stats is None:
            obs_stats = copy(obs_norm_state_dict)
        else:
            obs_stats = copy(stats)
        # The copy avoids mutating the caller's dict.
        obs_stats["standard_normal"] = True
        obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
        env.append_transform(obs_norm)
    # When VecNorm normalizes rewards, or only observations should be normalized,
    # the configured reward loc/scale are replaced by 0.0/1.0.
    if norm_rewards:
        reward_scaling = 1.0
        reward_loc = 0.0
    if norm_obs_only:
        reward_scaling = 1.0
        reward_loc = 0.0
    if reward_scaling is not None:
        env.append_transform(RewardScaling(reward_loc, reward_scaling))

    if not from_pixels:
        # Every non-pixel observation leaf that is not also part of the state spec.
        selected_keys = [
            key
            for key in env.observation_spec.keys(True, True)
            if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
        ]

        # even if there is a single tensor, it'll be renamed in "observation_vector"
        out_key = "observation_vector"
        env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

        if not vecnorm:
            # Same stats precedence as in the pixel branch.
            if stats is None and obs_norm_state_dict is None:
                _stats = {}
            elif stats is None:
                _stats = copy(obs_norm_state_dict)
            else:
                _stats = copy(stats)
            _stats.update({"standard_normal": True})
            obs_norm = ObservationNorm(
                **_stats,
                in_keys=[out_key],
            )
            env.append_transform(obs_norm)
        else:
            # Running normalization; "reward" is included only when rewards are normalized.
            env.append_transform(
                VecNorm(
                    in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
                    decay=0.9999,
                )
            )

        env.append_transform(DoubleToFloat())

        if hasattr(cfg, "catframes") and cfg.catframes:
            env.append_transform(CatFrames(N=cfg.catframes, in_keys=[out_key], dim=-1))

    else:
        env.append_transform(DoubleToFloat())

    if hasattr(cfg, "gSDE") and cfg.gSDE:
        env.append_transform(
            gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde)
        )

    env.append_transform(StepCounter())
    env.append_transform(InitTracker())

    return env
211
+
212
+
213
def get_norm_state_dict(env):
    """Gets the normalization loc and scale from the env state_dict.

    Args:
        env: any object exposing a ``state_dict()`` method that returns a mapping.

    Returns:
        a new dict containing only the state-dict entries whose key ends with
        ``"loc"`` or ``"scale"``, in their original order.
    """
    wanted_suffixes = ("loc", "scale")
    full_state = env.state_dict()
    return {
        name: value
        for name, value in full_state.items()
        if name.endswith(wanted_suffixes)
    }
222
+
223
+
224
def transformed_env_constructor(
    cfg: DictConfig,  # noqa: F821
    video_tag: str = "",
    logger: Optional[Logger] = None,  # noqa
    stats: Optional[dict] = None,
    norm_obs_only: bool = False,
    use_env_creator: bool = False,
    custom_env_maker: Optional[Callable] = None,
    custom_env: Optional[EnvBase] = None,
    return_transformed_envs: bool = True,
    action_dim_gsde: Optional[int] = None,
    state_dim_gsde: Optional[int] = None,
    batch_dims: Optional[int] = 0,
    obs_norm_state_dict: Optional[dict] = None,
) -> Union[Callable, EnvCreator]:
    """Returns an environment constructor built from the script configuration.

    Args:
        cfg (DictConfig): a DictConfig containing the arguments of the script.
        video_tag (str, optional): video tag to be passed to the Logger object. If non-empty,
            pixels are rendered by the environment and a video recorder is appended.
        logger (Logger, optional): logger associated with the script
        stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform
        norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online.
            Default is `False`.
        use_env_creator (bool, optional): whether the `EnvCreator` class should be used. By using `EnvCreator`,
            one can make sure that running statistics will be put in shared memory and accessible for all workers
            when using a `VecNorm` transform. Default is `False`.
        custom_env_maker (callable, optional): if your env maker is not part
            of torchrl env wrappers, a custom callable
            can be passed instead. In this case it will override the
            constructor retrieved from `cfg`.
        custom_env (EnvBase, optional): if an existing environment needs to be
            transformed, it can be passed directly to this helper. `custom_env_maker`
            and `custom_env` are exclusive features.
        return_transformed_envs (bool, optional): if ``True`` (default), a transformed environment
            is returned. Otherwise, the bare environment (with an optional ``NoopResetEnv``) is returned.
        action_dim_gsde (int, Optional): if gSDE is used, this can present the action dim to initialize the noise.
            Make sure this is indicated in environment executed in parallel.
        state_dim_gsde: if gSDE is used, this can present the state dim to initialize the noise.
            Make sure this is indicated in environment executed in parallel.
        batch_dims (int, optional): number of dimensions of a batch of data. If a single env is
            used, it should be 0 (default). If multiple envs are being transformed in parallel,
            it should be set to 1 (or the number of dims of the batch).
        obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the
            environment

    Returns:
        a ``make_transformed_env(**kwargs)`` callable, or ``env_creator(make_transformed_env)`` when
        ``use_env_creator`` is ``True``. The keyword arguments given to the callable are forwarded to the
        environment library constructor (or to ``custom_env_maker``).

    Raises:
        (when the returned callable is invoked)
        ValueError: if ``cfg.collector_device`` is neither a string nor a non-empty sequence.
        NotImplementedError: if ``cfg.categorical_action_encoding`` is set for a non-gym library.
        RuntimeError: if both ``custom_env`` and ``custom_env_maker`` are provided.
    """

    def make_transformed_env(**kwargs) -> TransformedEnv:
        env_name = cfg.env_name
        env_task = cfg.env_task
        env_library = LIBS[cfg.env_library]
        frame_skip = cfg.frame_skip
        from_pixels = cfg.from_pixels
        categorical_action_encoding = cfg.categorical_action_encoding

        if custom_env is None and custom_env_maker is None:
            collector_device = cfg.collector_device
            if isinstance(collector_device, str):
                device = collector_device
            elif isinstance(collector_device, Sequence) and len(collector_device) > 0:
                # A single env is built here: use the first listed device.
                device = collector_device[0]
            else:
                raise ValueError(
                    "collector_device must be either a string or a non-empty sequence of strings"
                )
            env_kwargs = {
                "env_name": env_name,
                "device": device,
                "frame_skip": frame_skip,
                # Pixels are also needed for video recording; pass a real bool
                # rather than the int returned by len().
                "from_pixels": bool(from_pixels or len(video_tag)),
                "pixels_only": from_pixels,
            }
            if env_library is GymEnv:
                env_kwargs.update(
                    {"categorical_action_encoding": categorical_action_encoding}
                )
            elif categorical_action_encoding:
                raise NotImplementedError(
                    "categorical_action_encoding=True is currently only compatible with GymEnvs."
                )
            if env_library is DMControlEnv:
                env_kwargs.update({"task_name": env_task})
            env_kwargs.update(kwargs)
            env = env_library(**env_kwargs)
        elif custom_env is None and custom_env_maker is not None:
            env = custom_env_maker(**kwargs)
        elif custom_env_maker is None and custom_env is not None:
            env = custom_env
        else:
            raise RuntimeError("cannot provide both custom_env and custom_env_maker")

        if cfg.noops and custom_env is None:
            # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv
            # that already has its NoopResetEnv set for the contained envs.
            # There is a risk however that we're just skipping the NoopsReset instantiation
            env = TransformedEnv(env, NoopResetEnv(cfg.noops))
        if not return_transformed_envs:
            return env

        return make_env_transforms(
            env,
            cfg,
            video_tag,
            logger,
            env_name,
            stats,
            norm_obs_only,
            env_library,
            action_dim_gsde,
            state_dim_gsde,
            batch_dims=batch_dims,
            obs_norm_state_dict=obs_norm_state_dict,
        )

    if use_env_creator:
        return env_creator(make_transformed_env)
    return make_transformed_env
340
+
341
+
342
def parallel_env_constructor(
    cfg: DictConfig, **kwargs  # noqa: F821
) -> Union[ParallelEnv, EnvCreator]:
    """Returns a (possibly parallel) environment built from a config object.

    If ``cfg.env_per_collector == 1``, no :obj:`ParallelEnv` is created and the
    :obj:`EnvCreator` returned by ``transformed_env_constructor`` is returned as-is.
    Otherwise ``cfg.env_per_collector`` workers are run in a :obj:`ParallelEnv`.
    The workers skip ``make_env_transforms``; the transforms are then built once,
    on top of the batched environment (``batch_dims=1``).

    Args:
        cfg (DictConfig): config containing user-defined arguments. The fields
            ``batch_transform``, ``env_per_collector`` and ``pin_memory`` are read here.
        kwargs: keyword arguments for the ``transformed_env_constructor`` method.

    Returns:
        an :obj:`EnvCreator` when ``cfg.env_per_collector == 1``. Otherwise,
        the environment produced by ``transformed_env_constructor`` around the
        :obj:`ParallelEnv`.

    Raises:
        NotImplementedError: if ``cfg.batch_transform`` is falsy.
    """
    if not cfg.batch_transform:
        raise NotImplementedError(
            "batch_transform must be set to True for the recorder to be synced "
            "with the collection envs."
        )

    kwargs.update({"cfg": cfg, "use_env_creator": True})
    if cfg.env_per_collector == 1:
        return transformed_env_constructor(**kwargs)

    # batch_transform is guaranteed True past the check above: each worker
    # returns its env before make_env_transforms, and the transforms are
    # applied once to the batched env below.
    make_transformed_env = transformed_env_constructor(
        return_transformed_envs=False, **kwargs
    )
    parallel_env = ParallelEnv(
        num_workers=cfg.env_per_collector,
        create_env_fn=make_transformed_env,
        create_env_kwargs=None,
        pin_memory=cfg.pin_memory,
    )
    kwargs.update(
        {
            "cfg": cfg,
            "use_env_creator": False,
            "custom_env": parallel_env,
            "batch_dims": 1,
        }
    )
    return transformed_env_constructor(**kwargs)()
383
+
384
+
385
@torch.no_grad()
def get_stats_random_rollout(
    cfg: DictConfig,  # noqa: F821
    proof_environment: EnvBase = None,
    key: Optional[Union[str, tuple[str, ...]]] = None,
):
    """Gathers stats (loc and scale) from an environment using random rollouts.

    Args:
        cfg (DictConfig): a config object with an ``init_env_steps`` field, indicating
            the total number of frames to be collected to compute the stats.
        proof_environment (EnvBase instance, optional): if provided, this env will
            be used to execute the rollouts and is left open. If not, one is created
            using the cfg object and closed before returning (also on error).
        key (str or tuple of str, optional): if provided, the stats of this key will
            be gathered. If not, it is expected that only one (leaf) key exists in
            ``env.observation_spec``.

    Returns:
        a dict ``{"loc": loc, "scale": scale}``. For the ``"pixels"`` key these are
        scalars computed over all values. Otherwise they are computed along the
        first dimension of the concatenated rollout data. Entries whose std is 0
        get ``loc=0`` and ``scale=1``.

    Raises:
        AttributeError: if ``cfg`` has no ``init_env_steps`` attribute.
        RuntimeError: if ``key`` is None and the observation spec has more than
            one leaf key, or if the computed stats contain non-finite values.
    """
    # Validate the config before (possibly) building an environment, so that an
    # invalid config cannot leave an unclosed environment behind.
    if not hasattr(cfg, "init_env_steps"):
        raise AttributeError("init_env_steps missing from arguments.")

    proof_env_is_none = proof_environment is None
    if proof_env_is_none:
        proof_environment = transformed_env_constructor(
            cfg=cfg, use_env_creator=False, stats={"loc": 0.0, "scale": 1.0}
        )()

    try:
        if VERBOSE:
            torchrl_logger.info("computing state stats")

        # The key must be known before collecting data: it selects which entry
        # of each rollout is kept.
        if key is None:
            keys = list(proof_environment.observation_spec.keys(True, True))
            key = keys.pop()
            if len(keys):
                raise RuntimeError(
                    f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                    "thus get_stats_random_rollout cannot infer which to compute the stats of."
                )

        n = 0
        val_stats = []
        while n < cfg.init_env_steps:
            td_stats = proof_environment.rollout(max_steps=cfg.init_env_steps)
            n += td_stats.numel()
            val_stats.append(td_stats.get(key).cpu())
            # Drop the full rollout early; only the selected entry is kept.
            del td_stats
        val_stats = torch.cat(val_stats, 0)

        if key == "pixels":
            # Scalar stats over every pixel value.
            m = val_stats.mean()
            s = val_stats.std()
        else:
            # Per-entry stats along the first (concatenation) dimension.
            m = val_stats.mean(dim=0)
            s = val_stats.std(dim=0)
        # Zero-spread entries get a neutral loc/scale (presumably so that a
        # downstream normalization never divides by a zero scale).
        m[s == 0] = 0.0
        s[s == 0] = 1.0

        if VERBOSE:
            torchrl_logger.info(
                f"stats computed for {n} steps. Got: \n"
                f"loc = {m}, \n"
                f"scale = {s}"
            )
        if not torch.isfinite(m).all():
            raise RuntimeError("non-finite values found in mean")
        if not torch.isfinite(s).all():
            raise RuntimeError("non-finite values found in sd")
        stats = {"loc": m, "scale": s}
    finally:
        # Only environments created here are closed; a caller-provided env is
        # left untouched.
        if proof_env_is_none:
            proof_environment.close()
            if (
                proof_environment.device != torch.device("cpu")
                and torch.cuda.device_count() > 0
            ):
                torch.cuda.empty_cache()
    return stats
462
+
463
+
464
def initialize_observation_norm_transforms(
    proof_environment: EnvBase,
    num_iter: int = 1000,
    key: Optional[Union[str, tuple[str, ...]]] = None,
):
    """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.

    An :obj:`ObservationNorm` whose ``initialized`` attribute is already True is left untouched.
    If the environment has no transform, or its transform (or the children of a
    :obj:`Compose`) contains no uninitialized :obj:`ObservationNorm`, a call to this function has no effect.
    The key is only inferred when at least one :obj:`ObservationNorm` needs to be
    initialized and no key was provided. In that case, if the observations of the base
    environment contain more than one key, an exception is raised.

    Args:
        proof_environment (EnvBase): the environment whose transforms are initialized.
            Expected to be a :obj:`TransformedEnv`; an environment without a
            ``transform`` attribute is ignored.
        num_iter (int, optional): number of iterations passed to ``init_stats`` for
            each :obj:`ObservationNorm`. Defaults to 1000.
        key (str or tuple of str, optional): if provided, the stats of this key will be gathered.
            If not, it is expected that only one key exists in ``env.base_env.observation_spec``.

    Raises:
        RuntimeError: if the key must be inferred and more than one observation key exists.
    """
    transform = getattr(proof_environment, "transform", None)
    if transform is None:
        return

    # A Compose is inspected child by child; any other transform is a single candidate.
    candidates = transform if isinstance(transform, Compose) else (transform,)
    for candidate in candidates:
        if not isinstance(candidate, ObservationNorm) or candidate.initialized:
            continue
        if key is None:
            # Inferred lazily: only required once an ObservationNorm actually
            # needs initialization, so the no-op cases never raise.
            keys = list(proof_environment.base_env.observation_spec.keys(True, True))
            key = keys.pop()
            if len(keys):
                raise RuntimeError(
                    f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                    "thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
                )
        candidate.init_stats(num_iter=num_iter, key=key)
505
+
506
+
507
def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
    """Collects the state dicts of the :obj:`ObservationNorm` transforms of an environment.

    Args:
        proof_environment (TransformedEnv): the environment whose transform is inspected.

    Returns:
        a list of ``(idx, state_dict)`` tuples. When the transform is a :obj:`Compose`,
        there is one tuple per :obj:`ObservationNorm` child, and ``idx`` is its
        position in the :obj:`Compose`. When the transform is itself an
        :obj:`ObservationNorm`, the list holds a single tuple with ``idx == 0``.
        Otherwise the list is empty.
    """
    env_transform = proof_environment.transform
    if isinstance(env_transform, Compose):
        return [
            (position, child.state_dict())
            for position, child in enumerate(env_transform)
            if isinstance(child, ObservationNorm)
        ]
    if isinstance(env_transform, ObservationNorm):
        return [(0, env_transform.state_dict())]
    return []
528
+
529
+
530
@dataclass
class EnvConfig:
    """Environment config struct.

    Each field is documented by the comment that follows it.
    """

    env_library: str = "gym"
    # env_library used for the simulated environment. Default=gym
    env_name: str = "Humanoid-v2"
    # name of the environment to be created. Default=Humanoid-v2
    env_task: str = ""
    # task (if any) for the environment (e.g. the task name of DMControl envs).
    # Default="" (no task).
    from_pixels: bool = False
    # whether the environment output should be state vector(s) (default) or the pixels.
    frame_skip: int = 1
    # frame_skip for the environment. Note that this value does NOT impact the buffer size,
    # maximum steps per trajectory, frames per batch or any other factor in the algorithm,
    # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4
    # the actual number of frames retrieved will be 200e6. Default=1.
    reward_scaling: Any = None  # noqa
    # scale of the reward. Default=None.
    reward_loc: float = 0.0
    # location of the reward. Default=0.0.
    init_env_steps: int = 1000
    # number of random steps to compute normalizing constants. Default=1000.
    vecnorm: bool = False
    # Normalizes the environment observation and reward outputs with the running statistics obtained across processes.
    norm_rewards: bool = False
    # If True, rewards will be normalized on the fly. This may interfere with SAC update rule and should be used cautiously.
    norm_stats: bool = True
    # If True (default), the normalization based on a random collection of data is used.
    # Set to False to deactivate it.
    noops: int = 0
    # number of random steps to do after reset. Default is 0
    catframes: int = 0
    # Number of frames to concatenate through time. Default is 0 (do not use CatFrames).
    center_crop: Any = dataclass_field(default_factory=list)
    # center crop size. Default is an empty list (a fresh list per instance).
    grayscale: bool = True
    # If True (default), a grayscale transform is used. Set to False to disable it.
    max_frames_per_traj: int = 1000
    # Number of steps before a reset of the environment is called (if it has not been flagged as done before).
    batch_transform: bool = False
    # if ``True``, the transforms will be applied to the parallel env, and not to each individual env.
    # Note: parallel_env_constructor raises NotImplementedError unless this is True.
    image_size: int = 84
    # image size used by the pixel transforms. Default=84.
    # NOTE(review): presumably the side length (in pixels) of the resized frames; confirm in make_env_transforms.
    categorical_action_encoding: bool = False
    # if True and environment has discrete action space, then it is encoded as categorical values rather than one-hot.
    # Currently only compatible with GymEnv (other libraries raise NotImplementedError).
@@ -0,0 +1,33 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import Any
9
+
10
+
11
@dataclass
class LoggerConfig:
    """Logger config data-class.

    Each field is documented by the comment that follows it.
    """

    logger: str = "csv"
    # recorder type to be used. One of 'tensorboard', 'wandb' or 'csv'. Default='csv'.
    record_video: bool = False
    # whether a video of the task should be rendered during logging.
    no_video: bool = True
    # NOTE(review): presumably the negated counterpart of ``record_video`` (True means no
    # video is rendered); confirm how consumers combine the two flags.
    exp_name: str = ""
    # experiment name. Used for logging directory.
    # A date and uuid will be joined to account for multiple experiments with the same name.
    record_interval: int = 1000
    # number of batch collections in between two collections of validation rollouts. Default=1000.
    record_frames: int = 1000
    # number of steps in validation rollouts. Default=1000.
    recorder_log_keys: Any = field(default=None)
    # Keys to log in the recorder. Default=None.
    offline_logging: bool = True
    # If True, Wandb will do the logging offline
    project_name: str = ""
    # The name of the project for WandB