torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,544 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import collections
8
+ import importlib
9
+ import os
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from torchrl._utils import logger as torchrl_logger, VERBOSE
16
+ from torchrl.data.tensor_specs import (
17
+ Bounded,
18
+ Categorical,
19
+ Composite,
20
+ OneHot,
21
+ TensorSpec,
22
+ Unbounded,
23
+ )
24
+ from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
25
+ from torchrl.envs.common import _EnvPostInit
26
+ from torchrl.envs.gym_like import GymLikeEnv
27
+ from torchrl.envs.utils import _classproperty
28
+
29
# On multi-GPU hosts, spread EGL rendering over devices 1..N-1, keyed by pid.
# Presumably device 0 is left free for compute — TODO confirm intent.
if torch.cuda.device_count() > 1:
    n = torch.cuda.device_count() - 1
    os.environ["EGL_DEVICE_ID"] = str(1 + (os.getpid() % n))
    if VERBOSE:
        torchrl_logger.info(f"EGL_DEVICE_ID: {os.environ['EGL_DEVICE_ID']}")

# True when dm_control is importable in the current environment (checked
# without importing it). Both aliases are kept for backward compatibility.
_has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None

__all__ = ["DMControlEnv", "DMControlWrapper"]
38
+
39
+
40
def _dmcontrol_to_torchrl_spec_transform(
    spec,
    dtype: torch.dtype | None = None,
    device: DEVICE_TYPING = None,
    categorical_discrete_encoding: bool = False,
) -> TensorSpec:
    """Convert a dm_env spec (or a mapping of specs) into a torchrl TensorSpec.

    Args:
        spec: a ``dm_env.specs`` spec instance, or a (possibly ordered) dict
            mapping keys to such specs.
        dtype (torch.dtype, optional): dtype override. If ``None``, it is
            derived from the dm_env spec's numpy dtype (discrete specs default
            to ``torch.long`` under one-hot encoding).
        device: device on which the resulting spec is placed.
        categorical_discrete_encoding (bool, optional): if ``True``, discrete
            specs map to ``Categorical``; otherwise to ``OneHot``.
            Defaults to ``False``.

    Returns:
        TensorSpec: the equivalent torchrl spec.

    Raises:
        NotImplementedError: if the spec type is not supported.
    """
    import dm_env

    # OrderedDict is a dict subclass, so a single isinstance check covers both
    # (the original tested both classes separately).
    if isinstance(spec, dict):
        spec = {
            k: _dmcontrol_to_torchrl_spec_transform(
                item,
                device=device,
                categorical_discrete_encoding=categorical_discrete_encoding,
            )
            for k, item in spec.items()
        }
        return Composite(**spec)
    if isinstance(spec, dm_env.specs.DiscreteArray):
        # DiscreteArray is a type of BoundedArray so this block needs to go first
        action_space_cls = Categorical if categorical_discrete_encoding else OneHot
        if dtype is None:
            dtype = (
                numpy_to_torch_dtype_dict[spec.dtype]
                if categorical_discrete_encoding
                else torch.long
            )
        return action_space_cls(spec.num_values, device=device, dtype=dtype)
    if isinstance(spec, dm_env.specs.BoundedArray):
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        # Scalar specs are promoted to shape (1,) to keep a trailing dim.
        shape = spec.shape if len(spec.shape) else torch.Size([1])
        return Bounded(
            shape=shape,
            low=spec.minimum,
            high=spec.maximum,
            dtype=dtype,
            device=device,
        )
    if isinstance(spec, dm_env.specs.Array):
        shape = spec.shape if len(spec.shape) else torch.Size([1])
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        # Both floating and integral dtypes map to an Unbounded spec; the
        # original returned the identical value from both arms of a dead
        # if/else on the dtype, collapsed here into a single return.
        return Unbounded(shape=shape, dtype=dtype, device=device)
    raise NotImplementedError(type(spec))
93
+
94
+
95
def _get_envs(to_dict: bool = True):
    """Enumerate the environment/task pairs shipped with ``dm_control.suite``.

    Args:
        to_dict (bool, optional): if ``True`` (default), return a dict items
            view of ``(env_name, [task_name, ...])`` pairs; if ``False``,
            return a flat tuple of ``(env_name, task_name)`` pairs.

    Raises:
        ImportError: if dm_control is not installed.

    Note:
        The previous return annotation (``dict[str, Any]``) was incorrect:
        the function actually returns either a tuple or a ``dict_items``
        view, both of which callers simply iterate. The annotation has been
        dropped and the actual behavior documented here instead.
    """
    if not _has_dm_control:
        raise ImportError("Cannot find dm_control in virtual environment.")
    from dm_control import suite

    all_pairs = tuple(suite.BENCHMARKING) + tuple(suite.EXTRA)
    if not to_dict:
        return all_pairs
    d: dict[str, list] = {}
    # Group task names by environment name, preserving suite ordering.
    for tup in all_pairs:
        d.setdefault(tup[0], []).append(tup[1])
    return d.items()
110
+
111
+
112
+ def _robust_to_tensor(array: float | np.ndarray) -> torch.Tensor:
113
+ if isinstance(array, np.ndarray):
114
+ return torch.as_tensor(array.copy())
115
+ else:
116
+ return torch.as_tensor(array)
117
+
118
+
119
class _DMControlMeta(_EnvPostInit):
    """Metaclass that turns ``DMControlEnv(..., num_workers>1)`` into a ParallelEnv.

    Instantiating ``DMControlEnv`` with more than one worker is intercepted
    here and a lazy :class:`~torchrl.envs.ParallelEnv` is returned instead of
    a plain environment. The ParallelEnv does not spawn its workers until it
    is actually used (reset/step or spec access), so parameters can still be
    adjusted beforehand via
    :meth:`torchrl.envs.batched_envs.BatchedEnvBase.configure_parallel`.
    """

    def __call__(cls, *args, num_workers: int | None = None, **kwargs):
        # num_workers may arrive as an explicit keyword or buried in kwargs;
        # normalize to a single int, defaulting to 1.
        if num_workers is None:
            num_workers = kwargs.pop("num_workers", 1)
        else:
            kwargs.pop("num_workers", None)
        workers = 1 if num_workers is None else int(num_workers)

        # Only the DMControlEnv front-door with several workers is special-cased.
        if workers <= 1 or cls.__name__ != "DMControlEnv":
            return super().__call__(*args, **kwargs)

        from torchrl.envs import ParallelEnv

        # env_name/task_name may be positional or keyword; resolve both.
        env_name = args[0] if args else kwargs.get("env_name")
        task_name = args[1] if len(args) > 1 else kwargs.get("task_name")

        # Strip name keys: the factory passes them positionally instead.
        child_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in ("env_name", "task_name")
        }

        # Default-argument binding freezes the current values in the closure.
        def make_single(name=env_name, task=task_name, opts=child_kwargs):
            return cls(name, task, num_workers=1, **opts)

        # Lazy ParallelEnv: no worker process is started here.
        return ParallelEnv(workers, make_single)
161
+
162
+
163
+ class DMControlWrapper(GymLikeEnv):
164
+ """DeepMind Control lab environment wrapper.
165
+
166
+ The DeepMind control library can be found here: https://github.com/deepmind/dm_control.
167
+
168
+ Paper: https://arxiv.org/abs/2006.12983
169
+
170
+ Args:
171
+ env (dm_control.suite env): :class:`~dm_control.suite.base.Task`
172
+ environment instance.
173
+
174
+ Keyword Args:
175
+ from_pixels (bool, optional): if ``True``, an attempt to return the pixel
176
+ observations from the env will be performed.
177
+ By default, these observations
178
+ will be written under the ``"pixels"`` entry.
179
+ Defaults to ``False``.
180
+ pixels_only (bool, optional): if ``True``, only the pixel observations will
181
+ be returned (by default under the ``"pixels"`` entry in the output tensordict).
182
+ If ``False``, observations (eg, states) and pixels will be returned
183
+ whenever ``from_pixels=True``. Defaults to ``True``.
184
+ frame_skip (int, optional): if provided, indicates for how many steps the
185
+ same action is to be repeated. The observation returned will be the
186
+ last observation of the sequence, whereas the reward will be the sum
187
+ of rewards across steps.
188
+ device (torch.device, optional): if provided, the device on which the data
189
+ is to be cast. Defaults to ``torch.device("cpu")``.
190
+ batch_size (torch.Size, optional): the batch size of the environment.
191
+ Should match the leading dimensions of all observations, done states,
192
+ rewards, actions and infos.
193
+ Defaults to ``torch.Size([])``.
194
+ allow_done_after_reset (bool, optional): if ``True``, it is tolerated
195
+ for envs to be ``done`` just after :meth:`reset` is called.
196
+ Defaults to ``False``.
197
+
198
+ Attributes:
199
+ available_envs (list): a list of ``Tuple[str, List[str]]`` representing the
200
+ environment / task pairs available.
201
+
202
+ Examples:
203
+ >>> from dm_control import suite
204
+ >>> from torchrl.envs import DMControlWrapper
205
+ >>> env = suite.load("cheetah", "run")
206
+ >>> env = DMControlWrapper(env,
207
+ ... from_pixels=True, frame_skip=4)
208
+ >>> td = env.rand_step()
209
+ >>> print(td)
210
+ TensorDict(
211
+ fields={
212
+ action: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float64, is_shared=False),
213
+ next: TensorDict(
214
+ fields={
215
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
216
+ pixels: Tensor(shape=torch.Size([240, 320, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
217
+ position: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float64, is_shared=False),
218
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float64, is_shared=False),
219
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
220
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
221
+ velocity: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.float64, is_shared=False)},
222
+ batch_size=torch.Size([]),
223
+ device=cpu,
224
+ is_shared=False)},
225
+ batch_size=torch.Size([]),
226
+ device=cpu,
227
+ is_shared=False)
228
+ >>> print(env.available_envs)
229
+ [('acrobot', ['swingup', 'swingup_sparse']), ('ball_in_cup', ['catch']), ('cartpole', ['balance', 'balance_sparse', 'swingup', 'swingup_sparse', 'three_poles', 'two_poles']), ('cheetah', ['run']), ('finger', ['spin', 'turn_easy', 'turn_hard']), ('fish', ['upright', 'swim']), ('hopper', ['stand', 'hop']), ('humanoid', ['stand', 'walk', 'run', 'run_pure_state']), ('manipulator', ['bring_ball', 'bring_peg', 'insert_ball', 'insert_peg']), ('pendulum', ['swingup']), ('point_mass', ['easy', 'hard']), ('reacher', ['easy', 'hard']), ('swimmer', ['swimmer6', 'swimmer15']), ('walker', ['stand', 'walk', 'run']), ('dog', ['fetch', 'run', 'stand', 'trot', 'walk']), ('humanoid_CMU', ['run', 'stand', 'walk']), ('lqr', ['lqr_2_1', 'lqr_6_2']), ('quadruped', ['escape', 'fetch', 'run', 'walk']), ('stacker', ['stack_2', 'stack_4'])]
230
+
231
+ """
232
+
233
+ git_url = "https://github.com/deepmind/dm_control"
234
+ libname = "dm_control"
235
+
236
+ @_classproperty
237
+ def available_envs(cls):
238
+ if not _has_dm_control:
239
+ return []
240
+ return list(_get_envs())
241
+
242
    @property
    def lib(self):
        """The ``dm_control`` module backing this wrapper (imported lazily)."""
        import dm_control

        return dm_control
247
+
248
+ def __init__(self, env=None, **kwargs):
249
+ if env is not None:
250
+ kwargs["env"] = env
251
+ super().__init__(**kwargs)
252
+
253
+ def _build_env(
254
+ self,
255
+ env,
256
+ _seed: int | None = None,
257
+ from_pixels: bool = False,
258
+ render_kwargs: dict | None = None,
259
+ pixels_only: bool = False,
260
+ camera_id: int | str = 0,
261
+ **kwargs,
262
+ ):
263
+ self.from_pixels = from_pixels
264
+ self.pixels_only = pixels_only
265
+
266
+ if from_pixels:
267
+ from dm_control.suite.wrappers import pixels
268
+
269
+ self._set_egl_device(self.device)
270
+ self.render_kwargs = {"camera_id": camera_id}
271
+ if render_kwargs is not None:
272
+ self.render_kwargs.update(render_kwargs)
273
+ env = pixels.Wrapper(
274
+ env,
275
+ pixels_only=self.pixels_only,
276
+ render_kwargs=self.render_kwargs,
277
+ )
278
+ return env
279
+
280
    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
        # Specs are defined lazily, the first time this is called, by translating
        # the wrapped dm_control env's specs into torchrl specs. Note: the spec
        # assignments below go through EnvBase property setters, so statement
        # order matters.
        self.observation_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.observation_spec(), device=self.device
        )
        reward_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.reward_spec(), device=self.device
        )
        # dm_control rewards are scalar; give them the trailing dim of 1 that
        # torchrl expects.
        if len(reward_spec.shape) == 0:
            reward_spec.shape = torch.Size([1])
        self.reward_spec = reward_spec
        # populate default done spec: boolean flags of shape (*batch_size, 1)
        # for done / truncated / terminated alike.
        done_spec = Categorical(
            n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
        )
        self.done_spec = Composite(
            done=done_spec.clone(),
            truncated=done_spec.clone(),
            terminated=done_spec.clone(),
            device=self.device,
        )
        self.action_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.action_spec(), device=self.device
        )
304
+
305
+ def _check_kwargs(self, kwargs: dict):
306
+ dm_control = self.lib
307
+ from dm_control.suite.wrappers import pixels
308
+
309
+ if "env" not in kwargs:
310
+ raise TypeError("Could not find environment key 'env' in kwargs.")
311
+ env = kwargs["env"]
312
+ if not isinstance(env, (dm_control.rl.control.Environment, pixels.Wrapper)):
313
+ raise TypeError(
314
+ "env is not of type 'dm_control.rl.control.Environment' or `dm_control.suite.wrappers.pixels.Wrapper`."
315
+ )
316
+
317
+ def _set_egl_device(self, device: DEVICE_TYPING):
318
+ # Deprecated as lead to unreliable rendering
319
+ # egl device needs to be set before importing mujoco bindings: in
320
+ # distributed settings, it'll be easy to tell which cuda device to use.
321
+ # In mp settings, we'll need to use mp.Pool with a specific init function
322
+ # that defines the EGL device before importing libraries. For now, we'll
323
+ # just use a common EGL_DEVICE_ID environment variable for all processes.
324
+ return
325
+
326
    def to(self, device: DEVICE_TYPING) -> DMControlEnv:
        """Move the environment to *device* and return ``self``.

        Delegates to the base class, then re-applies the (deprecated, no-op)
        EGL device selection for the resolved ``self.device``.
        """
        super().to(device)
        self._set_egl_device(self.device)
        return self
330
+
331
+ def _init_env(self, seed: int | None = None) -> int | None:
332
+ seed = self.set_seed(seed)
333
+ return seed
334
+
335
+ def _set_seed(self, _seed: int | None) -> None:
336
+ from dm_control.suite.wrappers import pixels
337
+
338
+ if _seed is None:
339
+ return None
340
+ random_state = np.random.RandomState(_seed)
341
+ if isinstance(self._env, pixels.Wrapper):
342
+ if not hasattr(self._env._env.task, "_random"):
343
+ raise RuntimeError("self._env._env.task._random does not exist")
344
+ self._env._env.task._random = random_state
345
+ else:
346
+ if not hasattr(self._env.task, "_random"):
347
+ raise RuntimeError("self._env._env.task._random does not exist")
348
+ self._env.task._random = random_state
349
+ self.reset()
350
+
351
+ def _output_transform(
352
+ self, timestep_tuple: tuple[TimeStep] # noqa: F821
353
+ ) -> tuple[np.ndarray, float, bool, bool, dict]:
354
+ from dm_env import StepType
355
+
356
+ if type(timestep_tuple) is not tuple:
357
+ timestep_tuple = (timestep_tuple,)
358
+ reward = timestep_tuple[0].reward
359
+
360
+ truncated = terminated = False
361
+ if timestep_tuple[0].step_type == StepType.LAST:
362
+ if np.isclose(timestep_tuple[0].discount, 1):
363
+ truncated = True
364
+ else:
365
+ terminated = True
366
+ done = truncated or terminated
367
+
368
+ observation = timestep_tuple[0].observation
369
+ info = {}
370
+
371
+ return observation, reward, terminated, truncated, done, info
372
+
373
+ def _reset_output_transform(self, reset_data):
374
+ (
375
+ observation,
376
+ reward,
377
+ terminated,
378
+ truncated,
379
+ done,
380
+ info,
381
+ ) = self._output_transform(reset_data)
382
+ return observation, info
383
+
384
+ def __repr__(self) -> str:
385
+ return (
386
+ f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
387
+ )
388
+
389
+
390
+ class DMControlEnv(DMControlWrapper, metaclass=_DMControlMeta):
391
+ """DeepMind Control lab environment wrapper.
392
+
393
+ The DeepMind control library can be found here: https://github.com/deepmind/dm_control.
394
+
395
+ Paper: https://arxiv.org/abs/2006.12983
396
+
397
+ Args:
398
+ env_name (str): name of the environment.
399
+ task_name (str): name of the task.
400
+ num_workers (int, optional): number of parallel environments. Defaults to 1.
401
+ When ``num_workers > 1``, a lazy :class:`~torchrl.envs.ParallelEnv` is
402
+ returned instead of a single environment. The parallel environment
403
+ is not started until it is actually used (e.g., via reset/step or
404
+ accessing specs). Use :meth:`~torchrl.envs.BatchedEnvBase.configure_parallel`
405
+ to set parallel execution parameters before the environment starts.
406
+
407
+ Keyword Args:
408
+ from_pixels (bool, optional): if ``True``, an attempt to return the pixel
409
+ observations from the env will be performed.
410
+ By default, these observations
411
+ will be written under the ``"pixels"`` entry.
412
+ Defaults to ``False``.
413
+ pixels_only (bool, optional): if ``True``, only the pixel observations will
414
+ be returned (by default under the ``"pixels"`` entry in the output tensordict).
415
+ If ``False``, observations (eg, states) and pixels will be returned
416
+ whenever ``from_pixels=True``. Defaults to ``True``.
417
+ frame_skip (int, optional): if provided, indicates for how many steps the
418
+ same action is to be repeated. The observation returned will be the
419
+ last observation of the sequence, whereas the reward will be the sum
420
+ of rewards across steps.
421
+ device (torch.device, optional): if provided, the device on which the data
422
+ is to be cast. Defaults to ``torch.device("cpu")``.
423
+ batch_size (torch.Size, optional): the batch size of the environment.
424
+ Should match the leading dimensions of all observations, done states,
425
+ rewards, actions and infos.
426
+ Defaults to ``torch.Size([])``.
427
+ allow_done_after_reset (bool, optional): if ``True``, it is tolerated
428
+ for envs to be ``done`` just after :meth:`reset` is called.
429
+ Defaults to ``False``.
430
+
431
+ Attributes:
432
+ available_envs (list): a list of ``Tuple[str, List[str]]`` representing the
433
+ environment / task pairs available.
434
+
435
+ Examples:
436
+ >>> from torchrl.envs import DMControlEnv
437
+ >>> env = DMControlEnv(env_name="cheetah", task_name="run",
438
+ ... from_pixels=True, frame_skip=4)
439
+ >>> td = env.rand_step()
440
+ >>> print(td)
441
+ TensorDict(
442
+ fields={
443
+ action: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float64, is_shared=False),
444
+ next: TensorDict(
445
+ fields={
446
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
447
+ pixels: Tensor(shape=torch.Size([240, 320, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
448
+ position: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float64, is_shared=False),
449
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float64, is_shared=False),
450
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
451
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
452
+ velocity: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.float64, is_shared=False)},
453
+ batch_size=torch.Size([]),
454
+ device=cpu,
455
+ is_shared=False)},
456
+ batch_size=torch.Size([]),
457
+ device=cpu,
458
+ is_shared=False)
459
+ >>> print(env.available_envs)
460
+ [('acrobot', ['swingup', 'swingup_sparse']), ...]
461
+ >>> # For running multiple envs in parallel (returns a lazy ParallelEnv)
462
+ >>> env = DMControlEnv("cheetah", "run", num_workers=4)
463
+ >>> # Configure parallel parameters before the env starts
464
+ >>> env.configure_parallel(use_buffers=True, num_threads=2)
465
+ >>> # Environment starts when first used
466
+ >>> env.reset()
467
+ """
468
+
469
+ def __init__(self, env_name, task_name, **kwargs):
470
+ if not _has_dmc:
471
+ raise ImportError(
472
+ "dm_control python package was not found. Please install this dependency."
473
+ )
474
+
475
+ kwargs["env_name"] = env_name
476
+ kwargs["task_name"] = task_name
477
+
478
+ super().__init__(**kwargs)
479
+
480
+ def _build_env(
481
+ self,
482
+ env_name: str,
483
+ task_name: str,
484
+ _seed: int | None = None,
485
+ **kwargs,
486
+ ):
487
+ from dm_control import suite
488
+
489
+ self.env_name = env_name
490
+ self.task_name = task_name
491
+
492
+ from_pixels = kwargs.get("from_pixels")
493
+ if "from_pixels" in kwargs:
494
+ del kwargs["from_pixels"]
495
+ pixels_only = kwargs.get("pixels_only")
496
+ if "pixels_only" in kwargs:
497
+ del kwargs["pixels_only"]
498
+
499
+ if not _has_dmc:
500
+ raise ImportError(
501
+ f"dm_control not found, unable to create {env_name}:"
502
+ f" {task_name}. Consider downloading and installing "
503
+ f"dm_control from {self.git_url}"
504
+ )
505
+
506
+ camera_id = kwargs.pop("camera_id", 0)
507
+ if _seed is not None:
508
+ random_state = np.random.RandomState(_seed)
509
+ kwargs["random"] = random_state
510
+ env = suite.load(env_name, task_name, task_kwargs=kwargs)
511
+ return super()._build_env(
512
+ env,
513
+ _seed=_seed,
514
+ from_pixels=from_pixels,
515
+ pixels_only=pixels_only,
516
+ camera_id=camera_id,
517
+ **kwargs,
518
+ )
519
+
520
    def rebuild_with_kwargs(self, **new_kwargs):
        # Merge the new kwargs into the stored constructor kwargs, then rebuild
        # the underlying env and regenerate its specs.
        # NOTE(review): self._build_env() is called with no arguments here,
        # although DMControlEnv._build_env requires env_name/task_name
        # positionally — presumably the constructor kwargs are injected by
        # machinery outside this view; confirm before relying on this method.
        self._constructor_kwargs.update(new_kwargs)
        self._env = self._build_env()
        self._make_specs(self._env)
524
+
525
+ def _check_kwargs(self, kwargs: dict):
526
+ if "env_name" in kwargs:
527
+ env_name = kwargs["env_name"]
528
+ if "task_name" in kwargs:
529
+ task_name = kwargs["task_name"]
530
+ available_envs = dict(self.available_envs)
531
+ if (
532
+ env_name not in available_envs
533
+ or task_name not in available_envs[env_name]
534
+ ):
535
+ raise RuntimeError(
536
+ f"{env_name} with task {task_name} is unknown in {self.libname}"
537
+ )
538
+ else:
539
+ raise TypeError("dm_control requires task_name to be specified")
540
+ else:
541
+ raise TypeError("dm_control requires env_name to be specified")
542
+
543
+ def __repr__(self) -> str:
544
+ return f"{self.__class__.__name__}(env={self.env_name}, task={self.task_name}, batch_size={self.batch_size})"