torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,203 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Gym-specific transforms."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+
12
+ import torch
13
+ from tensordict import TensorDictBase
14
+ from tensordict.utils import expand_as_right, NestedKey
15
+ from torchrl.data.tensor_specs import Unbounded
16
+
17
+ from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform
18
+
19
+
20
class EndOfLifeTransform(Transform):
    """Registers the end-of-life signal from a Gym env with a `lives` method.

    Proposed by DeepMind for the DQN and co. It helps value estimation.

    Args:
        eol_key (NestedKey, optional): the key where the end-of-life signal should
            be written. Defaults to ``"end-of-life"``.
        lives_key (NestedKey, optional): the key where the current number of lives
            should be written. Defaults to ``"lives"``.
        done_key (NestedKey, optional): a "done" key in the parent env done_spec,
            where the done value can be retrieved. This key must be unique and its
            shape must match the shape of the end-of-life entry. Defaults to ``"done"``.
        eol_attribute (str, optional): the location of the "lives" in the gym env.
            Defaults to ``"unwrapped.ale.lives"``. Supported attribute types are
            integer/array-like objects or callables that return these values.

    .. note::
        This transform should be used with gym envs that have a ``env.unwrapped.ale.lives``.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.envs.transforms.transforms import TransformedEnv
        >>> env = GymEnv("ALE/Breakout-v5")
        >>> env.rollout(100)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=cpu,
                    is_shared=False),
                pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)
        >>> eol_transform = EndOfLifeTransform()
        >>> env = TransformedEnv(env, eol_transform)
        >>> env.rollout(100)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                eol: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        end-of-life: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
                        pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=cpu,
                    is_shared=False),
                pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)

    The typical usage of this transform is to replace the "done" state by "end-of-life"
    within the loss module. The end-of-life signal isn't registered within the ``done_spec``
    because it should not instruct the env to reset.

    Examples:
        >>> from torchrl.objectives import DQNLoss
        >>> module = torch.nn.Identity() # used as a placeholder
        >>> loss = DQNLoss(module, action_space="categorical")
        >>> loss.set_keys(done="end-of-life", terminated="end-of-life")
        >>> # equivalently
        >>> eol_transform.register_keys(loss)
    """

    NO_PARENT_ERR = "The {} transform is being executed without a parent env. This is currently not supported."

    def __init__(
        self,
        eol_key: NestedKey = "end-of-life",
        lives_key: NestedKey = "lives",
        done_key: NestedKey = "done",
        eol_attribute="unwrapped.ale.lives",
    ):
        super().__init__(in_keys=[done_key], out_keys=[eol_key, lives_key])
        self.eol_key = eol_key
        self.lives_key = lives_key
        self.done_key = done_key
        # Stored as a path of attribute names ("unwrapped.ale.lives" ->
        # ["unwrapped", "ale", "lives"]) and resolved lazily in _get_lives.
        self.eol_attribute = eol_attribute.split(".")

    def _get_lives(self):
        """Resolves the current ``lives`` value on the base env via ``eol_attribute``.

        Returns either an integer/array-like value or, for list-of-envs setups
        (e.g. SerialEnv), a tensor of values — whatever the attribute chain yields.
        """
        from torchrl.envs.libs.gym import GymWrapper

        base_env = self.parent.base_env
        if not isinstance(base_env, GymWrapper):
            # Best-effort: non-gym envs may still expose the attribute chain,
            # so warn rather than raise.
            warnings.warn(
                f"The base_env is not a gym env. Compatibility of {type(self)} is not guaranteed with "
                f"environment types that do not inherit from GymWrapper.",
                category=UserWarning,
            )
        # getattr falls back on _env by default
        lives = getattr(base_env, self.eol_attribute[0])
        for att in self.eol_attribute[1:]:
            if isinstance(lives, list):
                # For SerialEnv (and who knows Parallel one day)
                lives = [getattr(_lives, att) for _lives in lives]
            else:
                lives = getattr(lives, att)
        # The terminal attribute may be a callable (e.g. ale.lives()) or a
        # list of callables (one per sub-env); call them to get the values.
        if callable(lives):
            lives = lives()
        elif isinstance(lives, list) and all(callable(_lives) for _lives in lives):
            lives = torch.as_tensor([_lives() for _lives in lives])
        return lives

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        # All the work happens in _step/_reset; _call is a deliberate no-op.
        return next_tensordict

    def _step(self, tensordict, next_tensordict):
        parent = self.parent
        if parent is None:
            raise RuntimeError(self.NO_PARENT_ERR.format(type(self)))

        lives = self._get_lives()
        # A life was lost if the previous "lives" count exceeds the current one.
        end_of_life = torch.as_tensor(
            tensordict.get(self.lives_key) > lives, device=self.parent.device
        )
        done = next_tensordict.get(self.done_key, None)  # TODO: None soon to be removed
        if done is None:
            # BUGFIX: the lookup above is on ``next_tensordict``, so report its
            # keys (the original message printed ``tensordict``'s keys, which
            # made the error misleading when diagnosing a wrong done_key).
            raise KeyError(
                f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {next_tensordict.keys(True, True)}. "
                f"Make sure to pass the appropriate done_key to the {type(self)} transform."
            )
        # End-of-life is also raised whenever the env itself is done.
        end_of_life = expand_as_right(end_of_life, done) | done
        next_tensordict.set(self.eol_key, end_of_life)
        next_tensordict.set(self.lives_key, lives)
        return next_tensordict

    def _reset(self, tensordict, tensordict_reset):
        parent = self.parent
        if parent is None:
            raise RuntimeError(self.NO_PARENT_ERR.format(type(self)))
        lives = self._get_lives()
        # A fresh episode starts with no end-of-life signal; expand the scalar
        # False to the done entry's shape so specs line up.
        end_of_life = False
        tensordict_reset.set(
            self.eol_key,
            torch.as_tensor(end_of_life).expand(
                parent.full_done_spec[self.done_key].shape
            ),
        )
        tensordict_reset.set(self.lives_key, lives)
        return tensordict_reset

    def transform_observation_spec(self, observation_spec):
        # eol mirrors the done entry; lives is an unbounded int64 per batch element.
        full_done_spec = self.parent.output_spec["full_done_spec"]
        observation_spec[self.eol_key] = full_done_spec[self.done_key].clone()
        observation_spec[self.lives_key] = Unbounded(
            self.parent.batch_size,
            device=self.parent.device,
            dtype=torch.int64,
        )
        return observation_spec

    def register_keys(
        self, loss_or_advantage: torchrl.objectives.common.LossModule  # noqa
    ):
        """Registers the end-of-life key at appropriate places within the loss.

        Args:
            loss_or_advantage (torchrl.objectives.LossModule or torchrl.objectives.value.ValueEstimatorBase): a module to instruct what the end-of-life key is.

        """
        loss_or_advantage.set_keys(done=self.eol_key, terminated=self.eol_key)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # This transform only makes sense inside an env pipeline, never as a
        # standalone module over collected data.
        raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))
@@ -0,0 +1,341 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable
8
+ from contextlib import nullcontext
9
+ from typing import overload, TYPE_CHECKING
10
+
11
+ import torch
12
+ from tensordict import TensorDictBase
13
+ from tensordict.nn import TensorDictModuleBase
14
+ from torchrl._utils import logger as torchrl_logger
15
+
16
+ from torchrl.data.tensor_specs import TensorSpec
17
+ from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
18
+ from torchrl.envs.transforms.transforms import Transform
19
+
20
+ if TYPE_CHECKING:
21
+ from torchrl.weight_update import WeightSyncScheme
22
+
23
+ __all__ = ["ModuleTransform", "RayModuleTransform"]
24
+
25
+
26
class RayModuleTransform(RayTransform):
    """Ray-based ModuleTransform for distributed processing.

    This transform creates a Ray actor that wraps a ModuleTransform,
    allowing module execution in a separate Ray worker process.

    Args:
        weight_sync_scheme: Optional weight synchronization scheme for updating
            the module's weights from a parent collector. When provided, the scheme
            is initialized on the receiver side (the Ray actor) and can receive
            weight updates via torch.distributed.
        **kwargs: Additional arguments passed to RayTransform and ModuleTransform.

    Example:
        >>> from torchrl.weight_update import RayModuleTransformScheme
        >>> scheme = RayModuleTransformScheme()
        >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme)
        >>> # The scheme can then be registered with a collector for weight updates
    """

    def __init__(self, *, weight_sync_scheme=None, **kwargs):
        self._weight_sync_scheme = weight_sync_scheme
        super().__init__(**kwargs)

        # After actor is created, initialize the scheme on the receiver side
        if weight_sync_scheme is not None:
            # Store transform reference in the scheme for sender initialization
            weight_sync_scheme._set_transform(self)

            weight_sync_scheme.init_on_sender()

            # Initialize receiver in the actor
            torchrl_logger.debug(
                "Setting up weight sync scheme on sender -- sender will do the remote call"
            )
            weight_sync_scheme.connect()

    @property
    def in_keys(self):
        # Delegate to the remote ModuleTransform actor.
        return self._ray.get(self._actor._getattr.remote("in_keys"))

    @property
    def out_keys(self):
        # Delegate to the remote ModuleTransform actor.
        return self._ray.get(self._actor._getattr.remote("out_keys"))

    def _create_actor(self, **kwargs):
        """Creates (and waits for) the Ray actor wrapping a ModuleTransform.

        Resource options (num_gpus/num_cpus) and the actor name are forwarded
        via ``.options(...)`` only when set, so Ray defaults apply otherwise.
        """
        import ray

        remote = self._ray.remote(ModuleTransform)
        ray_kwargs = {}
        num_gpus = self._num_gpus
        if num_gpus is not None:
            ray_kwargs["num_gpus"] = num_gpus
        num_cpus = self._num_cpus
        if num_cpus is not None:
            ray_kwargs["num_cpus"] = num_cpus
        actor_name = self._actor_name
        if actor_name is not None:
            ray_kwargs["name"] = actor_name
        if ray_kwargs:
            remote = remote.options(**ray_kwargs)
        actor = remote.remote(**kwargs)
        # wait till the actor is ready
        ray.get(actor._ready.remote())
        return actor

    @overload
    def update_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
        ...

    @overload
    def update_weights(self, params: TensorDictBase) -> None:
        ...

    def update_weights(self, *args, **kwargs) -> None:
        """Pushes new weights to the remote module (blocking on the Ray call).

        Accepts either a TensorDict of params or a state_dict, positionally or
        by keyword, depending on ``self._update_weights_method``.

        Raises:
            ValueError: if the required argument is missing or the update
                method is unknown.
        """
        import ray

        if self._update_weights_method == "tensordict":
            # BUGFIX: kwargs.get("params", args[0]) evaluated args[0] eagerly,
            # so a keyword-only call (no positional args) raised IndexError and
            # was misreported as "params must be provided".
            if "params" in kwargs:
                td = kwargs["params"]
            elif args:
                td = args[0]
            else:
                raise ValueError("params must be provided")
            return ray.get(self._actor._update_weights_tensordict.remote(params=td))
        elif self._update_weights_method == "state_dict":
            if "state_dict" in kwargs:
                state_dict = kwargs["state_dict"]
            elif args:
                state_dict = args[0]
            else:
                raise ValueError("state_dict must be provided")
            return ray.get(
                self._actor._update_weights_state_dict.remote(state_dict=state_dict)
            )
        else:
            raise ValueError(
                f"Invalid update_weights_method: {self._update_weights_method}"
            )
121
+
122
+
123
class ModuleTransform(Transform, metaclass=_RayServiceMetaClass):
    """A transform that wraps a module.

    Keyword Args:
        module (TensorDictModuleBase): The module to wrap. Exclusive with `module_factory`. At least one of `module` or `module_factory` must be provided.
        module_factory (Callable[[], TensorDictModuleBase]): The factory to create the module. Exclusive with `module`. At least one of `module` or `module_factory` must be provided.
        no_grad (bool, optional): Whether to use gradient computation. Default is `False`.
        inverse (bool, optional): Whether to use the inverse of the module. Default is `False`.
        device (torch.device, optional): The device to use. Default is `None`.
        use_ray_service (bool, optional): Whether to use Ray service. Default is `False`.
        num_gpus (int, optional): The number of GPUs to use if using Ray. Default is `None`.
        num_cpus (int, optional): The number of CPUs to use if using Ray. Default is `None`.
        actor_name (str, optional): The name of the actor to use. Default is `None`. If an actor name is provided and
            an actor with this name already exists, the existing actor will be used.
        observation_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the observation
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (observation specs remain unchanged).
        done_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the done
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (done specs remain unchanged).
        reward_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the reward
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (reward specs remain unchanged).
        state_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the state
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (state specs remain unchanged).
    """

    _RayServiceClass = RayModuleTransform

    def __init__(
        self,
        *,
        module: TensorDictModuleBase | None = None,
        module_factory: Callable[[], TensorDictModuleBase] | None = None,
        no_grad: bool = False,
        inverse: bool = False,
        device: torch.device | None = None,
        use_ray_service: bool = False,  # noqa
        actor_name: str | None = None,  # noqa
        num_gpus: int | None = None,
        num_cpus: int | None = None,
        observation_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        action_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        reward_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        done_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        state_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
    ):
        # NOTE(review): use_ray_service / actor_name / num_gpus / num_cpus are
        # accepted but unused here -- presumably consumed by _RayServiceMetaClass
        # when dispatching to RayModuleTransform; confirm against the metaclass.
        super().__init__()
        # Exactly one of `module` / `module_factory` may be supplied.
        if module is None and module_factory is None:
            raise ValueError(
                "At least one of `module` or `module_factory` must be provided."
            )
        if module is not None and module_factory is not None:
            raise ValueError(
                "Only one of `module` or `module_factory` must be provided."
            )
        self.module = module if module is not None else module_factory()
        self.no_grad = no_grad
        self.inverse = inverse
        self.device = device
        self.observation_spec_transform = observation_spec_transform
        self.action_spec_transform = action_spec_transform
        self.reward_spec_transform = reward_spec_transform
        self.done_spec_transform = done_spec_transform
        self.state_spec_transform = state_spec_transform

    # Key views are derived from the wrapped module and cannot be overridden.
    @property
    def in_keys(self) -> list[str]:
        return self._in_keys()

    def _in_keys(self):
        return self.module.in_keys if not self.inverse else []

    @in_keys.setter
    def in_keys(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(f"in_keys {value} cannot be set for ModuleTransform")

    @property
    def out_keys(self) -> list[str]:
        return self._out_keys()

    def _out_keys(self):
        return self.module.out_keys if not self.inverse else []

    @out_keys.setter
    def out_keys(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(f"out_keys {value} cannot be set for ModuleTransform")

    @property
    def in_keys_inv(self) -> list[str]:
        return self._in_keys_inv()

    def _in_keys_inv(self):
        return self.module.out_keys if self.inverse else []

    @in_keys_inv.setter
    def in_keys_inv(self, value: list[str]):
        if value is not None:
            raise RuntimeError(f"in_keys_inv {value} cannot be set for ModuleTransform")

    @property
    def out_keys_inv(self) -> list[str]:
        return self._out_keys_inv()

    def _out_keys_inv(self):
        return self.module.in_keys if self.inverse else []

    @out_keys_inv.setter
    def out_keys_inv(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(
                f"out_keys_inv {value} cannot be set for ModuleTransform"
            )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _apply_module(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Runs the wrapped module, optionally without grad and on `self.device`.

        The `.to(device)` context restores the tensordict's original device on
        exit; `nullcontext(tensordict)` keeps the same code path when no device
        is configured.
        """
        with torch.no_grad() if self.no_grad else nullcontext():
            with (
                tensordict.to(self.device)
                if self.device is not None
                else nullcontext(tensordict)
            ) as td:
                return self.module(td)

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Forward direction is a no-op for inverse transforms.
        if self.inverse:
            return tensordict
        return self._apply_module(tensordict)

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Inverse direction is a no-op for forward transforms.
        if not self.inverse:
            return tensordict
        return self._apply_module(tensordict)

    def _update_weights_tensordict(self, params: TensorDictBase) -> None:
        # In-place weight swap from a TensorDict of parameters.
        params.to_module(self.module)

    def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        # In-place weight swap from a plain state_dict.
        self.module.load_state_dict(state_dict)

    def _init_weight_sync_scheme(self, scheme: WeightSyncScheme, model_id: str) -> None:
        """Initialize weight sync scheme on the receiver side (called in Ray actor).

        This method is called by RayModuleTransform after the actor is created
        to set up the receiver side of the weight synchronization scheme.

        Args:
            scheme: The weight sync scheme instance (e.g., RayModuleTransformScheme).
            model_id: Identifier for the model being synchronized.
        """
        torchrl_logger.debug(f"Initializing weight sync scheme for {model_id=}")
        scheme.init_on_receiver(model_id=model_id, context=self)
        torchrl_logger.debug(f"Setup weight sync scheme for {model_id=}")
        scheme.connect()
        self._weight_sync_scheme = scheme

    def _receive_weights_scheme(self):
        # Blocking receive of a pending weight update via the registered scheme.
        self._weight_sync_scheme.receive()

    @staticmethod
    def _maybe_transform_spec(spec: TensorSpec, spec_transform) -> TensorSpec:
        """Shared spec-rewrite rule: a TensorSpec replaces the spec wholesale,
        a callable rewrites it, and None leaves it untouched."""
        if spec_transform is None:
            return spec
        if isinstance(spec_transform, TensorSpec):
            return spec_transform
        return spec_transform(spec)

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        return self._maybe_transform_spec(
            observation_spec, self.observation_spec_transform
        )

    def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
        return self._maybe_transform_spec(action_spec, self.action_spec_transform)

    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
        return self._maybe_transform_spec(reward_spec, self.reward_spec_transform)

    def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec:
        return self._maybe_transform_spec(done_spec, self.done_spec_transform)

    def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec:
        return self._maybe_transform_spec(state_spec, self.state_spec_transform)