torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394):
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,138 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import importlib.util
9
+
10
+ import torch
11
+ from torchrl._utils import _make_ordinal_device
12
+ from torchrl.data.utils import DEVICE_TYPING
13
+ from torchrl.envs.common import EnvBase
14
+ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
15
+ from torchrl.envs.utils import _classproperty
16
+
17
# Availability flag for the optional ``habitat`` dependency. ``find_spec`` only
# locates the package on the import path; it does not import it.
_has_habitat = importlib.util.find_spec("habitat") is not None
18
+
19
+
20
def _wrap_import_error(fun):
    """Guard ``fun`` so that it fails early when habitat is unavailable.

    The returned wrapper reads the module-level ``_has_habitat`` flag on every
    call. If the flag is falsy, an :class:`ImportError` is raised before ``fun``
    runs. Otherwise all positional and keyword arguments are forwarded to
    ``fun`` and its result is returned unchanged.
    """
    message = (
        "Habitat could not be loaded. Consider installing "
        "it or solving the import bugs (see attached error message). "
        "Refer to TorchRL's knowledge base in the documentation to "
        "debug habitat installation."
    )

    @functools.wraps(fun)
    def guarded(*args, **kwargs):
        if _has_habitat:
            return fun(*args, **kwargs)
        raise ImportError(message)

    return guarded
33
+
34
+
35
@_wrap_import_error
def _get_available_envs():
    """Lazily yield every ``GymEnv.available_envs`` entry whose name starts with ``"Habitat"``."""
    yield from (
        name for name in GymEnv.available_envs if name.startswith("Habitat")
    )
40
+
41
+
42
class HabitatEnv(GymEnv):
    """A wrapper for habitat envs.

    This class currently serves as a placeholder and compatibility layer.
    It behaves exactly like the GymEnv wrapper.

    Doc: https://aihabitat.org/docs/

    GitHub: https://github.com/facebookresearch/habitat-lab

    URL: https://aihabitat.org/habitat3/

    Paper: https://ai.meta.com/static-resource/habitat3

    Args:
        env_name (str): The environment to execute.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these observations
            will be written under the ``"pixels"`` entry.
            The method being used varies
            depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output tensordict).
            If ``False``, observations (eg, states) and pixels will be returned
            whenever ``from_pixels=True``. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): the device on which the simulation
            will occur. Its index is passed to habitat as
            ``habitat.simulator.habitat_sim_v0.gpu_device_id``. An index-less
            CUDA device (e.g. ``"cuda"``) is resolved to the current CUDA
            device. This argument is consumed by the wrapper and is not
            forwarded to :class:`GymEnv`. Defaults to ``torch.device("cuda:0")``.
        override_options (list of str, optional): additional habitat config
            overrides. Entries setting ``habitat.simulator.habitat_sim_v0.gpu_device_id``
            or ``habitat.simulator.concur_render`` are replaced by the values
            the wrapper enforces; all other entries are kept.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (List[str]): a list of environments to build.

    Examples:
        >>> from torchrl.envs import HabitatEnv
        >>> env = HabitatEnv("HabitatRenderPick-v0", from_pixels=True)
        >>> env.rollout(3)

    """

    # Prefixes of the habitat config overrides this wrapper manages itself.
    _GPU_ID_OVERRIDE = "habitat.simulator.habitat_sim_v0.gpu_device_id"
    _MANAGED_OVERRIDE_PREFIXES = (
        "habitat.simulator.habitat_sim_v0.gpu_device_id",
        "habitat.simulator.concur_render",
    )

    @_wrap_import_error
    @set_gym_backend("gym")
    def __init__(self, env_name, **kwargs):
        import habitat  # noqa
        import habitat.gym  # noqa

        # ``device`` only selects the simulator GPU; it is popped so that it is
        # not forwarded to GymEnv (same as before).
        device = kwargs.pop("device", None)
        if device is None:
            device = 0
        # Resolve "cuda" (no index) to an ordinal device so that the override
        # never reads ``gpu_device_id=None``.
        device_num = _make_ordinal_device(torch.device(device)).index

        # Keep user overrides, but drop the keys this wrapper always sets, so
        # that each key appears only once.
        user_options = [
            option
            for option in (kwargs.pop("override_options", None) or [])
            if not option.startswith(self._MANAGED_OVERRIDE_PREFIXES)
        ]
        kwargs["override_options"] = [
            *user_options,
            f"habitat.simulator.habitat_sim_v0.gpu_device_id={device_num}",
            "habitat.simulator.concur_render=False",
        ]
        super().__init__(env_name=env_name, **kwargs)

    @_classproperty
    def available_envs(cls):
        if not _has_habitat:
            return []
        return list(_get_available_envs())

    def _build_gym_env(self, env, pixels_only):
        # When rendering pixels, the env is reset once before the parent class
        # builds it.
        if self.from_pixels:
            env.reset()
        return super()._build_gym_env(env, pixels_only)

    def to(self, device: DEVICE_TYPING) -> EnvBase:
        """Rebuild the habitat simulator on ``device`` and move the env there.

        Raises:
            ValueError: if ``device`` is not a CUDA device.
        """
        device = _make_ordinal_device(torch.device(device))
        if device.type != "cuda":
            raise ValueError("The device must be of type cuda for Habitat.")
        gpu_option = f"{self._GPU_ID_OVERRIDE}={device.index}"

        # Swap the GPU id in place and keep every other override. If no GPU id
        # override was stored, add one so that the new device is not ignored.
        options = []
        replaced = False
        for option in self._constructor_kwargs.get("override_options") or []:
            if option.startswith(self._GPU_ID_OVERRIDE):
                option = gpu_option
                replaced = True
            options.append(option)
        if not replaced:
            options.append(gpu_option)

        self._env.close()
        del self._env
        self.rebuild_with_kwargs(override_options=options)
        return super().to(device)
@@ -0,0 +1,87 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+ from torchrl.envs.libs.gym import GymWrapper
9
+
10
+
11
class IsaacLabWrapper(GymWrapper):
    """A wrapper for IsaacLab environments.

    Args:
        env (isaaclab.envs.ManagerBasedRLEnv or equivalent): the environment instance to wrap.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``True``.
        convert_actions_to_numpy (bool, optional): forwarded to
            :class:`~torchrl.envs.GymWrapper`. Defaults to ``False``.
        device (torch.device, optional): the device passed to
            :class:`~torchrl.envs.GymWrapper`. Defaults to ``torch.device("cuda:0")``.

    For other arguments, see the :class:`torchrl.envs.GymWrapper` documentation.

    Refer to `the Isaac Lab doc for installation instructions <https://isaac-sim.github.io/IsaacLab/main/source/setup/installation/pip_installation.html>`_.

    Example:
        >>> # This code block ensures that the Isaac app is started in headless mode
        >>> from isaaclab.app import AppLauncher
        >>> import argparse

        >>> parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
        >>> AppLauncher.add_app_launcher_args(parser)
        >>> args_cli, hydra_args = parser.parse_known_args(["--headless"])
        >>> app_launcher = AppLauncher(args_cli)

        >>> # Imports and env
        >>> import gymnasium as gym
        >>> import isaaclab_tasks  # noqa: F401
        >>> from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
        >>> from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

        >>> env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
        >>> env = IsaacLabWrapper(env)

    """

    def __init__(
        self,
        env: isaaclab.envs.ManagerBasedRLEnv,  # noqa: F821
        *,
        categorical_action_encoding: bool = False,
        allow_done_after_reset: bool = True,
        convert_actions_to_numpy: bool = False,
        device: torch.device | None = None,
        **kwargs,
    ):
        if device is None:
            device = torch.device("cuda:0")
        super().__init__(
            env,
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            **kwargs,
        )

    def seed(self, seed: int | None):
        """Seed the wrapped environment through :meth:`_set_seed`."""
        self._set_seed(seed)

    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        """Map IsaacLab's 5-tuple step output to TorchRL's 6-tuple convention.

        Returns ``(observations, reward, terminated, truncated, done, info)``.
        """
        observations, reward, terminated, truncated, info = step_outputs_tuple
        # IsaacLab will modify the `terminated` and `truncated` tensors in-place.
        # We clone them here to make sure data doesn't inadvertently get modified.
        terminated = terminated.clone()
        truncated = truncated.clone()
        # ``|`` allocates a fresh tensor, so ``done`` needs no extra clone.
        done = terminated | truncated
        # The reward is cloned for the same reason. Manager-based IsaacLab envs
        # reuse their reward buffer across steps, and a bare ``unsqueeze`` would
        # only return a view of it.
        reward = reward.unsqueeze(-1).clone()  # to get to (num_envs, 1)
        return (
            observations,
            reward,
            terminated,
            truncated,
            done,
            info,
        )
@@ -0,0 +1,203 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import importlib.util
8
+ import itertools
9
+ import warnings
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import torch
14
+ from tensordict import TensorDictBase
15
+ from torchrl.data.tensor_specs import Composite
16
+ from torchrl.envs.libs.gym import GymWrapper
17
+ from torchrl.envs.utils import _classproperty, make_composite_from_td
18
+
19
+ _has_isaac = importlib.util.find_spec("isaacgym") is not None
20
+
21
+
22
class IsaacGymWrapper(GymWrapper):
    """Wrapper for IsaacGymEnvs environments.

    The original library can be found `here <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_
    and is based on IsaacGym which can be downloaded `through NVIDIA's webpage <https://developer.nvidia.com/isaac-gym>`_.

    .. note:: IsaacGym environments cannot be executed consecutively, i.e. instantiating one
        environment after another (even if it has been cleared) will cause
        CUDA memory issues. We recommend creating one environment per process only.
        If you need more than one environment, the best way to achieve that is
        to spawn them across processes.

    .. note:: IsaacGym works on CUDA devices by design. Make sure your machine
        has GPUs available and the required setup for IsaacGym (e.g., Ubuntu 20.04).

    """

    @property
    def lib(self):
        import isaacgym

        return isaacgym

    def __init__(
        self, env: isaacgymenvs.tasks.base.vec_task.Env, **kwargs  # noqa: F821
    ):
        warnings.warn(
            "IsaacGym environment support is an experimental feature that may change in the future."
        )
        # The wrapper is placed on the same device as the wrapped env.
        super().__init__(
            env, torch.device(env.device), batch_size=torch.Size([]), **kwargs
        )
        if not hasattr(self, "task"):
            # by convention in IsaacGymEnvs
            self.task = env.__name__

    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
        super()._make_specs(env, batch_size=self.batch_size)
        # Remove the trailing singleton dimension from every done-like spec.
        self.full_done_spec = Composite(
            {
                key: spec.squeeze(-1)
                for key, spec in self.full_done_spec.items(True, True)
            },
            shape=self.batch_size,
        )

        # Expose the "observation" entry under the "obs" key instead.
        self.observation_spec["obs"] = self.observation_spec["observation"]
        del self.observation_spec["observation"]

        # Infer specs for the remaining entries from real data: run a short
        # rollout and keep index 0 of the last dimension of its "next" data
        # (presumably the rollout's time dimension).
        data = self.rollout(3).get("next")[..., 0]
        del data[self.reward_key]
        for done_key in self.done_keys:
            try:
                del data[done_key]
            except KeyError:
                continue
        specs = make_composite_from_td(data)

        obs_spec = self.observation_spec
        obs_spec.unlock_(recurse=True)
        obs_spec.update(specs)
        obs_spec.lock_(recurse=True)

    def _output_transform(self, output):
        obs, reward, done, info = output
        if self.from_pixels:
            obs["pixels"] = self._env.render(mode="rgb_array")
        # Returned order: (obs, reward, terminated, truncated, done, info).
        # The env reports a single ``done`` flag. ``done ^ done`` is an all-zero
        # tensor shaped like ``done``, so nothing is marked as terminated and
        # ``done`` is reported as truncation.
        return obs, reward, done ^ done, done, done, info

    def _reset_output_transform(self, reset_data):
        reset_data.pop("reward", None)
        if self.from_pixels:
            reset_data["pixels"] = self._env.render(mode="rgb_array")
        return reset_data, {}

    @classmethod
    def _make_envs(cls, *, task, num_envs, device, seed=None, headless=False, **kwargs):
        """Build the vectorized env with ``isaacgymenvs.make``.

        The simulation and RL devices are both set to ``str(device)``.
        ``from_pixels`` is removed from ``kwargs`` before forwarding; all other
        keyword arguments are passed through to ``isaacgymenvs.make``.
        """
        import isaacgym  # noqa
        import isaacgymenvs  # noqa

        _ = kwargs.pop("from_pixels", None)
        envs = isaacgymenvs.make(
            seed=seed,
            task=task,
            num_envs=num_envs,
            sim_device=str(device),
            rl_device=str(device),
            headless=headless,
            **kwargs,
        )
        return envs

    def _set_seed(self, seed: int | None) -> None:
        # as of #665c32170d84b4be66722eea405a1e08b6e7f761 the seed points nowhere in gym.make for IsaacGymEnvs
        ...

    def read_action(self, action):
        """Reads the action obtained from the input TensorDict and transforms it in the format expected by the contained environment.

        Args:
            action (Tensor or TensorDict): an action to be taken in the environment

        Returns: an action in a format compatible with the contained environment.

        """
        return action

    def read_done(
        self,
        terminated: torch.Tensor | None = None,
        truncated: torch.Tensor | None = None,
        done: torch.Tensor | None = None,
    ) -> tuple[
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | bool,
    ]:
        """Cast the done-like tensors to ``torch.bool``.

        Args:
            terminated (torch.Tensor, optional): termination flags.
            truncated (torch.Tensor, optional): truncation flags.
            done (torch.Tensor, optional): done flags. When omitted but both
                ``terminated`` and ``truncated`` are given, it is computed as
                ``terminated | truncated``.

        Returns:
            A 4-tuple ``(terminated, truncated, done, any_done)``. Each of the
            first three is a bool tensor, or ``None`` if it was not provided and
            could not be derived. ``any_done`` is ``done.any()``, or ``False``
            when no done information is available. It is presumably consumed by
            the caller to stop its frame-skip loop; confirm against
            ``GymLikeEnv.read_done``.
        """
        if terminated is not None:
            terminated = terminated.bool()
        if truncated is not None:
            truncated = truncated.bool()
        if done is not None:
            done = done.bool()
        elif terminated is not None and truncated is not None:
            done = terminated | truncated
        any_done = done.any() if done is not None else False
        return terminated, truncated, done, any_done

    def read_reward(self, total_reward):
        return total_reward

    def read_obs(
        self, observations: dict[str, Any] | torch.Tensor | np.ndarray
    ) -> dict[str, Any]:
        """Reads an observation from the environment and returns an observation compatible with the output TensorDict.

        Args:
            observations (observation under a format dictated by the inner env): observation to be read.

        """
        if isinstance(observations, dict):
            if "state" in observations and "observation" not in observations:
                # we rename "state" in "observation" as "observation" is the conventional name
                # for single observation in torchrl.
                # naming it 'state' will result in envs that have a different name for the state vector
                # when queried with and without pixels
                observations["observation"] = observations.pop("state")
        if not isinstance(observations, (TensorDictBase, dict)):
            # A bare tensor/array is stored under the first leaf key of the observation spec.
            (key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
            observations = {key: observations}
        return observations
166
+
167
+
168
class IsaacGymEnv(IsaacGymWrapper):
    """A TorchRL Env interface for IsaacGym environments.

    See :class:`~.IsaacGymWrapper` for more information.

    Args:
        task (str, optional): name of the IsaacGymEnvs task to build.
            Exactly one of ``task`` and ``env`` must be provided.

    Keyword Args:
        env (optional): alias for ``task``.
        num_envs (int): number of environments to create.
        device: device used both as simulation and RL device.
        from_pixels (bool, optional): forwarded to the wrapper only.
            Defaults to ``False``.
        **kwargs: forwarded both to ``isaacgymenvs.make`` (through
            ``_make_envs``) and to :class:`IsaacGymWrapper`.

    Raises:
        RuntimeError: if both or neither of ``task`` and ``env`` are provided.

    Examples:
        >>> env = IsaacGymEnv(task="Ant", num_envs=2000, device="cuda:0")
        >>> rollout = env.rollout(3)
        >>> assert env.batch_size == (2000,)

    """

    @_classproperty
    def available_envs(cls):
        if not _has_isaac:
            return []

        import isaacgymenvs  # noqa

        return list(isaacgymenvs.tasks.isaacgym_task_map.keys())

    def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
        if env is not None and task is not None:
            raise RuntimeError("Cannot provide both `task` and `env` arguments.")
        if env is not None:
            task = env
        if task is None:
            # Fail early with a clear message instead of passing task=None
            # down to isaacgymenvs.make.
            raise RuntimeError("One of `task` or `env` arguments must be provided.")
        from_pixels = kwargs.pop("from_pixels", False)
        envs = self._make_envs(
            task=task,
            num_envs=num_envs,
            device=device,
            virtual_screen_capture=False,
            **kwargs,
        )
        # Set before the parent __init__, which only falls back to
        # env.__name__ when `task` is not already an attribute.
        self.task = task
        super().__init__(envs, from_pixels=from_pixels, **kwargs)
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+ import importlib.util
9
+
10
+ # import jax
11
+ import numpy as np
12
+ import torch
13
+
14
+ # from jax import dlpack as jax_dlpack, numpy as jnp
15
+ from tensordict import make_tensordict, TensorDictBase
16
+ from torch.utils import dlpack as torch_dlpack
17
+ from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
18
+ from torchrl.data.utils import numpy_to_torch_dtype_dict
19
+
20
+ _has_jax = importlib.util.find_spec("jax") is not None
21
+
22
+
23
+ def _tree_reshape(x, batch_size: torch.Size):
24
+ import jax
25
+
26
+ shape, n = batch_size, 1
27
+ return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x)
28
+
29
+
30
+ def _tree_flatten(x, batch_size: torch.Size):
31
+ import jax
32
+
33
+ shape, n = (batch_size.numel(),), len(batch_size)
34
+ return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x)
35
+
36
+
37
+ _dtype_conversion = {
38
+ np.dtype("uint16"): np.int16,
39
+ np.dtype("uint32"): np.int32,
40
+ np.dtype("uint64"): np.int64,
41
+ }
42
+
43
+
44
def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor:  # noqa: F821
    """Convert a JAX or NumPy array to a torch tensor through DLPack.

    Arrays whose dtype appears in ``_dtype_conversion`` (uint16/32/64) are first
    bit-reinterpreted as the signed integer type of the same width.

    Raises:
        NotImplementedError: if ``value`` is neither a JAX nor a NumPy array.
    """
    from jax import numpy as jnp

    # JAX arrays generated by jax.vmap would have Numpy dtypes.
    signed_type = _dtype_conversion.get(value.dtype)
    if signed_type is not None:
        value = value.view(signed_type)
    if not isinstance(value, (jnp.ndarray, np.ndarray)):
        raise NotImplementedError(f"unsupported data type {type(value)}")
    tensor = torch_dlpack.from_dlpack(value.__dlpack__())
    # dtype can be messed up by dlpack
    return tensor.to(numpy_to_torch_dtype_dict[value.dtype])
59
+
60
+
61
+ def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821
62
+ from jax import dlpack as jax_dlpack
63
+
64
+ # Detach the tensor to remove gradients before converting to DLPack
65
+ value = value.contiguous().detach()
66
+ return jax_dlpack.from_dlpack(value)
67
+
68
+
69
+ def _get_object_fields(obj) -> dict:
70
+ """Converts an object (named tuple or dataclass or dict) to a dict."""
71
+ if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
72
+ return dict(zip(obj._fields, obj))
73
+ elif dataclasses.is_dataclass(obj):
74
+ return {
75
+ field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)
76
+ }
77
+ elif isinstance(obj, dict):
78
+ return obj
79
+ elif obj is None:
80
+ return {}
81
+ else:
82
+ raise NotImplementedError(f"unsupported data type {type(obj)}")
83
+
84
+
85
def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase | None:
    """Recursively convert a namedtuple, dataclass or dict to a TensorDict.

    Args:
        obj: a namedtuple, dataclass, dict or ``None`` (see ``_get_object_fields``).
        device: device the converted tensors and the TensorDict are placed on.
        batch_size: batch size given to the TensorDict and to every nested one.

    Returns:
        A TensorDict with one entry per converted field, or ``None`` when no
        field produced a value (empty tensordicts are discarded).

    Raises:
        NotImplementedError: if a field is neither a scalar, a JAX/NumPy array,
            nor a container supported by ``_get_object_fields``.
    """
    from jax import numpy as jnp

    t = {}
    for name, value in _get_object_fields(obj).items():
        if isinstance(value, (np.number, np.bool_, int, float)):
            # Python / NumPy scalars are wrapped into a 1-element array first.
            # np.bool_ is not an np.number, so it is listed explicitly.
            t[name] = _ndarray_to_tensor(np.asarray([value])).to(device)
        elif isinstance(value, (jnp.ndarray, np.ndarray)):
            t[name] = _ndarray_to_tensor(value).to(device)
        else:
            nested = _object_to_tensordict(value, device, batch_size)
            if nested is not None:
                t[name] = nested
    if t:
        return make_tensordict(t, device=device, batch_size=batch_size)
    # discard empty tensordicts
    return None
104
+
105
+
106
def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None):
    """Convert a TensorDict back to an object of the same type as ``object_example``.

    Fields are looked up in ``tensordict`` by name and converted to JAX arrays
    through DLPack. Nested TensorDicts are converted recursively, and the
    result is built with ``type(object_example)(**fields)``.

    Args:
        tensordict: source data. A plain dict is also accepted; the function
            itself passes ``{}`` when a dict-typed field is missing.
        object_example: namedtuple, dataclass or dict whose field names, and
            whose per-field ``dtype``/``shape``, define the output.
        batch_size: leading dimensions prepended to each field's example shape
            when reshaping. ``None`` (default) means no batch dimensions.

    Returns:
        An instance of ``type(object_example)``. A missing non-dict field is
        set to ``None``. A missing dict field becomes a dict with the example's
        keys, filled recursively.
    """
    from jax import dlpack as jax_dlpack, numpy as jnp

    if batch_size is None:
        batch_size = []
    t = {}
    _fields = _get_object_fields(object_example)
    for name, example in _fields.items():
        value = tensordict.get(name, None)
        if isinstance(value, TensorDictBase):
            t[name] = _tensordict_to_object(value, example, batch_size=batch_size)
        elif value is None:
            if isinstance(example, dict):
                t[name] = _tensordict_to_object({}, example, batch_size=batch_size)
            else:
                t[name] = None
        else:
            # bool tensors are exported as uint8. The result is viewed back as
            # example.dtype below whenever the dtypes differ.
            if value.dtype is torch.bool:
                value = value.to(torch.uint8)
            # Remember the torch shape: multi-dim tensors are flattened before export.
            shape = value.shape
            # We need to flatten to fix https://github.com/pytorch/rl/issues/2184
            value = value.contiguous()
            value = value.detach()
            if value.ndim > 1:
                value = value.flatten().clone()
            else:
                # Need this because otherwise an exception is raised
                # ValueError: INTERNAL: Address of buffer 1 must be a multiple of 10, but was 0x7efccec00824
                value = value.clone()
            value = jax_dlpack.from_dlpack(value)
            if shape.numel() == 1 and not value.shape:
                # Single-element, 0-d result: grow it back to the torch shape.
                # NOTE(review): value is 0-d here only when the source tensor
                # was 0-d (multi-dim tensors were flattened to 1-D above), in
                # which case shape == () and this loop never iterates; confirm.
                while value.shape != shape:
                    value = jnp.expand_dims(value, 0)
                if value.dtype != example.dtype:
                    t[name] = value.view(example.dtype)
                else:
                    t[name] = value
            else:
                # Restore the original shape, reinterpret as the example dtype,
                # then reshape to (*batch_size, *example.shape).
                value = jnp.reshape(value, tuple(shape))
                t[name] = value.view(example.dtype).reshape(
                    (*batch_size, *example.shape)
                )
    return type(object_example)(**t)
150
+
151
+
152
def _extract_spec(data: torch.Tensor | TensorDictBase, key=None) -> TensorSpec:
    """Build an unbounded spec that mirrors ``data``.

    Args:
        data: a tensor, or a TensorDict (converted entry by entry).
        key: the entry name of ``data``. For ``"reward"`` and ``"done"`` a
            trailing singleton dimension is appended to the spec shape.

    Returns:
        An ``Unbounded`` spec (same dtype and device) for a tensor, or a
        ``Composite`` of per-entry specs for a TensorDict.

    Raises:
        TypeError: if ``data`` is neither a tensor nor a TensorDict.
    """
    if isinstance(data, torch.Tensor):
        # Floating and non-floating dtypes are handled the same way.
        spec_shape = (*data.shape, 1) if key in ("reward", "done") else data.shape
        return Unbounded(shape=spec_shape, dtype=data.dtype, device=data.device)
    if isinstance(data, TensorDictBase):
        sub_specs = {
            sub_key: _extract_spec(sub_value, key=sub_key)
            for sub_key, sub_value in data.items()
        }
        return Composite(sub_specs)
    raise TypeError(f"Unsupported data type {type(data)}")