torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/testing/mocking_classes.py (new file)
@@ -0,0 +1,2720 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import random
+ import string
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from tensordict import tensorclass, TensorDict, TensorDictBase
+ from tensordict.nn import TensorDictModuleBase
+ from tensordict.utils import expand_right, NestedKey
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.data import (
+     Binary,
+     Bounded,
+     Categorical,
+     Composite,
+     MultiOneHot,
+     NonTensor,
+     OneHot,
+     TensorSpec,
+     Unbounded,
+ )
+ from torchrl.data.utils import consolidate_spec
+ from torchrl.envs import Transform
+ from torchrl.envs.common import EnvBase
+ from torchrl.envs.model_based.common import ModelBasedEnvBase
+ from torchrl.envs.utils import (
+     _terminated_or_truncated,
+     check_marl_grouping,
+     MarlGroupMapType,
+ )
+
+ spec_dict = {
+     "bounded": Bounded,
+     "one_hot": OneHot,
+     "categorical": Categorical,
+     "unbounded": Unbounded,
+     "binary": Binary,
+     "mult_one_hot": MultiOneHot,
+     "composite": Composite,
+ }
+
+ default_spec_kwargs = {
+     OneHot: {"n": 7},
+     Categorical: {"n": 7},
+     Bounded: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
+     Unbounded: {
+         "shape": [
+             7,
+         ]
+     },
+     Binary: {"n": 7},
+     MultiOneHot: {"nvec": [7, 3, 5]},
+     Composite: {},
+ }
+
+
+ def make_spec(spec_str):
+     """Create a spec instance from a short spec name."""
+     target_class = spec_dict[spec_str]
+     return target_class(**default_spec_kwargs[target_class])
+
+
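(Editorial note, not part of the packaged file: a minimal sketch of how the spec helpers above might be exercised in a test, assuming only the public TensorSpec API.)

    spec = make_spec("one_hot")   # OneHot(n=7), per default_spec_kwargs
    sample = spec.rand()          # random one-hot sample of size 7
    assert spec.is_in(sample)     # sample conforms to the spec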
+ class _MockEnv(EnvBase):
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         **kwargs,
+     ):
+         for key, item in list(cls._output_spec["full_observation_spec"].items()):
+             cls._output_spec["full_observation_spec"][key] = item.to(
+                 torch.get_default_dtype()
+             )
+         reward_spec = cls._output_spec["full_reward_spec"]
+         if isinstance(reward_spec, Composite):
+             reward_spec = Composite(
+                 {
+                     key: item.to(torch.get_default_dtype())
+                     for key, item in reward_spec.items(True, True)
+                 },
+                 shape=reward_spec.shape,
+                 device=reward_spec.device,
+             )
+         else:
+             reward_spec = reward_spec.to(torch.get_default_dtype())
+         cls._output_spec["full_reward_spec"] = reward_spec
+         if not isinstance(cls._output_spec["full_reward_spec"], Composite):
+             cls._output_spec["full_reward_spec"] = Composite(
+                 reward=cls._output_spec["full_reward_spec"],
+                 shape=cls._output_spec["full_reward_spec"].shape[:-1],
+             )
+         if not isinstance(cls._output_spec["full_done_spec"], Composite):
+             cls._output_spec["full_done_spec"] = Composite(
+                 done=cls._output_spec["full_done_spec"].clone(),
+                 terminated=cls._output_spec["full_done_spec"].clone(),
+                 shape=cls._output_spec["full_done_spec"].shape[:-1],
+             )
+         if not isinstance(cls._input_spec["full_action_spec"], Composite):
+             cls._input_spec["full_action_spec"] = Composite(
+                 action=cls._input_spec["full_action_spec"],
+                 shape=cls._input_spec["full_action_spec"].shape[:-1],
+             )
+         dtype = kwargs.pop("dtype", torch.get_default_dtype())
+         for spec in (cls._output_spec, cls._input_spec):
+             if dtype != torch.get_default_dtype():
+                 for key, val in list(spec.items(True, True)):
+                     if val.dtype == torch.get_default_dtype():
+                         val = val.to(dtype)
+                     spec[key] = val
+         return super().__new__(cls, *args, **kwargs)
+
+     def __init__(
+         self,
+         *args,
+         seed: int = 100,
+         **kwargs,
+     ):
+         super().__init__(
+             device=kwargs.pop("device", "cpu"),
+             allow_done_after_reset=kwargs.pop("allow_done_after_reset", False),
+         )
+         self.set_seed(seed)
+         self.is_closed = False
+
+     @property
+     def maxstep(self) -> int:
+         return 100
+
+     def _set_seed(self, seed: int | None) -> None:
+         self.seed = seed
+         self.counter = seed % 17  # make counter a small number
+
+     def custom_fun(self) -> int:
+         return 0
+
+     custom_attr = 1
+
+     @property
+     def custom_prop(self) -> int:
+         return 2
+
+     @property
+     def custom_td(self) -> TensorDict:
+         return TensorDict({"a": torch.zeros(3)}, [])
+
+
+ class MockSerialEnv(EnvBase):
+     """A simple counting env that is reset after a predefined max number of steps."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         if action_spec is None:
+             action_spec = Unbounded(
+                 (
+                     *batch_size,
+                     1,
+                 )
+             )
+         if observation_spec is None:
+             observation_spec = Composite(
+                 observation=Unbounded(
+                     (
+                         *batch_size,
+                         1,
+                     )
+                 ),
+                 shape=batch_size,
+             )
+         if reward_spec is None:
+             reward_spec = Unbounded(
+                 (
+                     *batch_size,
+                     1,
+                 )
+             )
+         if done_spec is None:
+             done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+         if state_spec is None:
+             state_spec = Composite(shape=batch_size)
+         input_spec = Composite(
+             full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size
+         )
+         cls._output_spec = Composite(shape=batch_size)
+         cls._output_spec["full_reward_spec"] = reward_spec
+         cls._output_spec["full_done_spec"] = done_spec
+         cls._output_spec["full_observation_spec"] = observation_spec
+         cls._input_spec = input_spec
+
+         if not isinstance(cls._output_spec["full_reward_spec"], Composite):
+             cls._output_spec["full_reward_spec"] = Composite(
+                 reward=cls._output_spec["full_reward_spec"], shape=batch_size
+             )
+         if not isinstance(cls._output_spec["full_done_spec"], Composite):
+             cls._output_spec["full_done_spec"] = Composite(
+                 done=cls._output_spec["full_done_spec"], shape=batch_size
+             )
+         if not isinstance(cls._input_spec["full_action_spec"], Composite):
+             cls._input_spec["full_action_spec"] = Composite(
+                 action=cls._input_spec["full_action_spec"], shape=batch_size
+             )
+         return super().__new__(*args, **kwargs)
+
+     def __init__(self, device="cpu"):
+         super().__init__(device=device)
+         self.is_closed = False
+
+     def _set_seed(self, seed: int | None) -> None:
+         assert seed >= 1
+         self.seed = seed
+         self.counter = seed % 17  # make counter a small number
+         self.max_val = max(self.counter + 100, self.counter * 2)
+
+     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+         self.counter += 1
+         n = torch.tensor(
+             [self.counter], device=self.device, dtype=torch.get_default_dtype()
+         )
+         done = self.counter >= self.max_val
+         done = torch.tensor([done], dtype=torch.bool, device=self.device)
+         return TensorDict(
+             {
+                 "reward": n,
+                 "done": done,
+                 "terminated": done.clone(),
+                 "observation": n.clone(),
+             },
+             batch_size=[],
+             device=self.device,
+         )
+
+     def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
+         self.max_val = max(self.counter + 100, self.counter * 2)
+
+         n = torch.tensor(
+             [self.counter], device=self.device, dtype=torch.get_default_dtype()
+         )
+         done = self.counter >= self.max_val
+         done = torch.tensor([done], dtype=torch.bool, device=self.device)
+         return TensorDict(
+             {"done": done, "terminated": done.clone(), "observation": n},
+             [],
+             device=self.device,
+         )
+
+     def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
+         return self.step(tensordict)
+
+
+ class MockBatchedLockedEnv(EnvBase):
+     """Mocks an env whose batch_size defines the size of the output tensordict."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         if action_spec is None:
+             action_spec = Unbounded(
+                 (
+                     *batch_size,
+                     1,
+                 )
+             )
+         if state_spec is None:
+             state_spec = Composite(
+                 observation=Unbounded(
+                     (
+                         *batch_size,
+                         1,
+                     )
+                 ),
+                 shape=batch_size,
+             )
+         if observation_spec is None:
+             observation_spec = Composite(
+                 observation=Unbounded(
+                     (
+                         *batch_size,
+                         1,
+                     )
+                 ),
+                 shape=batch_size,
+             )
+         if reward_spec is None:
+             reward_spec = Unbounded(
+                 (
+                     *batch_size,
+                     1,
+                 )
+             )
+         if done_spec is None:
+             done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+         cls._output_spec = Composite(shape=batch_size)
+         cls._output_spec["full_reward_spec"] = reward_spec
+         cls._output_spec["full_done_spec"] = done_spec
+         cls._output_spec["full_observation_spec"] = observation_spec
+         cls._input_spec = Composite(
+             full_action_spec=action_spec,
+             full_state_spec=state_spec,
+             shape=batch_size,
+         )
+         if not isinstance(cls._output_spec["full_reward_spec"], Composite):
+             cls._output_spec["full_reward_spec"] = Composite(
+                 reward=cls._output_spec["full_reward_spec"], shape=batch_size
+             )
+         if not isinstance(cls._output_spec["full_done_spec"], Composite):
+             cls._output_spec["full_done_spec"] = Composite(
+                 done=cls._output_spec["full_done_spec"], shape=batch_size
+             )
+         if not isinstance(cls._input_spec["full_action_spec"], Composite):
+             cls._input_spec["full_action_spec"] = Composite(
+                 action=cls._input_spec["full_action_spec"], shape=batch_size
+             )
+         return super().__new__(cls, *args, **kwargs)
+
+     def __init__(self, device="cpu", batch_size=None):
+         super().__init__(device=device, batch_size=batch_size)
+         self.counter = 0
+
+     rand_step = MockSerialEnv.rand_step
+
+     def _set_seed(self, seed: int | None) -> None:
+         assert seed >= 1
+         self.seed = seed
+         self.counter = seed % 17  # make counter a small number
+         self.max_val = max(self.counter + 100, self.counter * 2)
+
+     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+         if len(self.batch_size):
+             leading_batch_size = (
+                 tensordict.shape[: -len(self.batch_size)]
+                 if tensordict is not None
+                 else []
+             )
+         else:
+             leading_batch_size = tensordict.shape if tensordict is not None else []
+         self.counter += 1
+         # We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
+         n = torch.full(
+             [*leading_batch_size, *self.observation_spec["observation"].shape],
+             self.counter,
+             device=self.device,
+             dtype=torch.get_default_dtype(),
+         )
+         done = self.counter >= self.max_val
+         done = torch.full(
+             (*leading_batch_size, *self.batch_size, 1),
+             done,
+             dtype=torch.bool,
+             device=self.device,
+         )
+         return TensorDict(
+             {"reward": n, "done": done, "terminated": done.clone(), "observation": n},
+             batch_size=tensordict.batch_size,
+             device=self.device,
+         )
+
+     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+         self.max_val = max(self.counter + 100, self.counter * 2)
+         batch_size = self.batch_size
+         if len(batch_size):
+             leading_batch_size = (
+                 tensordict.shape[: -len(self.batch_size)]
+                 if tensordict is not None
+                 else []
+             )
+         else:
+             leading_batch_size = tensordict.shape if tensordict is not None else []
+
+         n = torch.full(
+             [*leading_batch_size, *self.observation_spec["observation"].shape],
+             self.counter,
+             device=self.device,
+             dtype=torch.get_default_dtype(),
+         )
+         done = self.counter >= self.max_val
+         done = torch.full(
+             (*leading_batch_size, *batch_size, 1),
+             done,
+             dtype=torch.bool,
+             device=self.device,
+         )
+         return TensorDict(
+             {"done": done, "terminated": done.clone(), "observation": n},
+             [
+                 *leading_batch_size,
+                 *batch_size,
+             ],
+             device=self.device,
+         )
+
+
+ class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
+     """Mocks an env whose batch_size does not define the size of the output tensordict.
+
+     The size of the output tensordict is defined by the input tensordict itself.
+
+     """
+
+     def __init__(self, device="cpu", batch_size=None):
+         super().__init__(batch_size=batch_size, device=device)
+
+     @classmethod
+     def __new__(cls, *args, **kwargs):
+         return super().__new__(cls, *args, _batch_locked=False, **kwargs)
+
+
+ class StateLessCountingEnv(EnvBase):
+     """A simple counting environment with no internal state beyond the input tensordict."""
+
+     def __init__(self):
+         self.observation_spec = Composite(
+             count=Unbounded((1,), dtype=torch.int32),
+             max_count=Unbounded((1,), dtype=torch.int32),
+         )
+         self.full_action_spec = Composite(
+             action=Unbounded((1,), dtype=torch.int32),
+         )
+         self.full_done_spec = Composite(
+             done=Unbounded((1,), dtype=torch.bool),
+             terminated=Unbounded((1,), dtype=torch.bool),
+             truncated=Unbounded((1,), dtype=torch.bool),
+         )
+         self.reward_spec = Composite(reward=Unbounded((1,), dtype=torch.float))
+         super().__init__()
+         self._batch_locked = False
+
+     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+
+         max_count = None
+         count = None
+         if tensordict is not None:
+             max_count = tensordict.get("max_count")
+             count = tensordict.get("count")
+             tensordict = TensorDict(
+                 batch_size=tensordict.batch_size, device=tensordict.device
+             )
+             shape = tensordict.batch_size
+         else:
+             shape = ()
+             tensordict = TensorDict(device=self.device)
+         tensordict.update(
+             TensorDict(
+                 count=torch.zeros(
+                     (
+                         *shape,
+                         1,
+                     ),
+                     dtype=torch.int32,
+                 )
+                 if count is None
+                 else count,
+                 max_count=torch.randint(
+                     10,
+                     20,
+                     (
+                         *shape,
+                         1,
+                     ),
+                     dtype=torch.int32,
+                 )
+                 if max_count is None
+                 else max_count,
+                 **self.done_spec.zero(shape),
+                 **self.full_reward_spec.zero(shape),
+             )
+         )
+         return tensordict
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         action = tensordict["action"]
+         count = tensordict["count"] + action
+         terminated = done = count >= tensordict["max_count"]
+         truncated = torch.zeros_like(done)
+         return TensorDict(
+             count=count,
+             max_count=tensordict["max_count"],
+             done=done,
+             terminated=terminated,
+             truncated=truncated,
+             reward=self.reward_spec.zero(tensordict.shape),
+             batch_size=tensordict.batch_size,
+             device=tensordict.device,
+         )
+
+     def _set_seed(self, seed: int | None) -> None:
+         ...
+
+
+ class DiscreteActionVecMockEnv(_MockEnv):
+     """Mock env with vector observations and discrete (one-hot/categorical) actions."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         from_pixels=False,
+         categorical_action_encoding=False,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         size = cls.size = 7
+         if observation_spec is None:
+             cls.out_key = "observation"
+             observation_spec = Composite(
+                 observation=Unbounded(shape=torch.Size([*batch_size, size])),
+                 observation_orig=Unbounded(shape=torch.Size([*batch_size, size])),
+                 shape=batch_size,
+             )
+         if action_spec is None:
+             if categorical_action_encoding:
+                 action_spec_cls = Categorical
+                 action_spec = action_spec_cls(n=7, shape=batch_size)
+             else:
+                 action_spec_cls = OneHot
+                 action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
+         if reward_spec is None:
+             reward_spec = Composite(reward=Unbounded(shape=(1,)))
+         if done_spec is None:
+             done_spec = Composite(
+                 terminated=Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+             )
+
+         if state_spec is None:
+             cls._out_key = "observation_orig"
+             state_spec = Composite(
+                 {
+                     cls._out_key: observation_spec["observation"],
+                 },
+                 shape=batch_size,
+             )
+         cls._output_spec = Composite(shape=batch_size)
+         cls._output_spec["full_reward_spec"] = reward_spec
+         cls._output_spec["full_done_spec"] = done_spec
+         cls._output_spec["full_observation_spec"] = observation_spec
+         cls._input_spec = Composite(
+             full_action_spec=action_spec,
+             full_state_spec=state_spec,
+             shape=batch_size,
+         )
+         cls.from_pixels = from_pixels
+         cls.categorical_action_encoding = categorical_action_encoding
+         return super().__new__(*args, **kwargs)
+
+     def _get_in_obs(self, obs):
+         return obs
+
+     def _get_out_obs(self, obs):
+         return obs
+
+     def _reset(self, tensordict: TensorDictBase = None) -> TensorDictBase:
+         self.counter += 1
+         state = torch.zeros(self.size) + self.counter
+         if tensordict is None:
+             tensordict = TensorDict(batch_size=self.batch_size, device=self.device)
+         tensordict = tensordict.empty().set(self.out_key, self._get_out_obs(state))
+         tensordict = tensordict.set(self._out_key, self._get_out_obs(state))
+         tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
+         tensordict.set(
+             "terminated", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)
+         )
+         return tensordict
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         tensordict = tensordict.to(self.device)
+         a = tensordict.get("action")
+
+         if not self.categorical_action_encoding:
+             assert (a.sum(-1) == 1).all()
+
+         obs = self._get_in_obs(tensordict.get(self._out_key)) + a / self.maxstep
+         tensordict = tensordict.empty()
+
+         tensordict.set(self.out_key, self._get_out_obs(obs))
+         tensordict.set(self._out_key, self._get_out_obs(obs))
+
+         done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
+         reward = done.any(-1).unsqueeze(-1)
+
+         # set done to False
+         done = torch.zeros_like(done).all(-1).unsqueeze(-1)
+         tensordict.set("reward", reward.to(torch.get_default_dtype()))
+         tensordict.set("done", done)
+         tensordict.set("terminated", done.clone())
+         return tensordict
+
+
+ class ContinuousActionVecMockEnv(_MockEnv):
+     """Mock env with vector observations and continuous (bounded) actions."""
+
+     adapt_dtype: bool = True
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         from_pixels=False,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         size = cls.size = 7
+         if observation_spec is None:
+             cls.out_key = "observation"
+             observation_spec = Composite(
+                 observation=Unbounded(shape=torch.Size([*batch_size, size])),
+                 observation_orig=Unbounded(shape=torch.Size([*batch_size, size])),
+                 shape=batch_size,
+             )
+         if action_spec is None:
+             action_spec = Bounded(
+                 -1,
+                 1,
+                 (
+                     *batch_size,
+                     7,
+                 ),
+             )
+         if reward_spec is None:
+             reward_spec = Unbounded(shape=(*batch_size, 1))
+         if done_spec is None:
+             done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+
+         if state_spec is None:
+             cls._out_key = "observation_orig"
+             state_spec = Composite(
+                 {
+                     cls._out_key: observation_spec["observation"],
+                 },
+                 shape=batch_size,
+             )
+         cls._output_spec = Composite(shape=batch_size)
+         cls._output_spec["full_reward_spec"] = reward_spec
+         cls._output_spec["full_done_spec"] = done_spec
+         cls._output_spec["full_observation_spec"] = observation_spec
+         cls._input_spec = Composite(
+             full_action_spec=action_spec,
+             full_state_spec=state_spec,
+             shape=batch_size,
+         )
+         cls.from_pixels = from_pixels
+         return super().__new__(cls, *args, **kwargs)
+
+     def _get_in_obs(self, obs):
+         return obs
+
+     def _get_out_obs(self, obs):
+         return obs
+
+     def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+         self.counter += 1
+         self.step_count = 0
+         # state = torch.zeros(self.size) + self.counter
+         if tensordict is None:
+             tensordict = TensorDict(batch_size=self.batch_size, device=self.device)
+
+         tensordict = tensordict.empty()
+         tensordict.update(self.observation_spec.rand())
+         # tensordict.set("next_" + self.out_key, self._get_out_obs(state))
+         # tensordict.set("next_" + self._out_key, self._get_out_obs(state))
+         tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool))
+         tensordict.set(
+             "terminated", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)
+         )
+         return tensordict
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         self.step_count += 1
+         tensordict = tensordict.to(self.device)
+         a = tensordict.get("action")
+
+         obs = self._obs_step(self._get_in_obs(tensordict.get(self._out_key)), a)
+
+         tensordict = tensordict.empty()  # empty tensordict
+
+         tensordict.set(self.out_key, self._get_out_obs(obs))
+         tensordict.set(self._out_key, self._get_out_obs(obs))
+
+         done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1))
+         while done.shape != tensordict.shape:
+             done = done.any(-1)
+         done = reward = done.unsqueeze(-1)
+         tensordict.set(
+             "reward",
+             reward.to(
+                 self.reward_spec.dtype
+                 if self.adapt_dtype
+                 else torch.get_default_dtype()
+             ).expand(self.reward_spec.shape),
+         )
+         tensordict.set("done", done)
+         tensordict.set("terminated", done)
+         return tensordict
+
+     def _obs_step(self, obs, a):
+         return obs + a / self.maxstep
+
+
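(Editorial note, not part of the packaged file: a rough usage sketch of the vector mock envs above, assuming the public torchrl EnvBase API.)

    from torchrl.envs.utils import check_env_specs

    env = ContinuousActionVecMockEnv()
    check_env_specs(env)                 # compare declared specs against _reset/_step outputs
    rollout = env.rollout(max_steps=4)   # short random-action rollout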
+ class DiscreteActionVecPolicy(TensorDictModuleBase):
+     """Deterministic policy for `DiscreteActionVecMockEnv`-like observations."""
+
+     in_keys = ["observation"]
+     out_keys = ["action"]
+
+     def _get_in_obs(self, tensordict):
+         obs = tensordict.get(*self.in_keys)
+         return obs
+
+     def __call__(self, tensordict: TensorDictBase) -> TensorDictBase:
+         obs = self._get_in_obs(tensordict)
+         max_obs = (obs == obs.max(dim=-1, keepdim=True)[0]).cumsum(-1).argmax(-1)
+         k = tensordict.get(*self.in_keys).shape[-1]
+         max_obs = (max_obs + 1) % k
+         action = torch.nn.functional.one_hot(max_obs, k)
+         tensordict.set(*self.out_keys, action)
+         return tensordict
+
+
+ class DiscreteActionConvMockEnv(DiscreteActionVecMockEnv):
+     """Mock env with image-like observations and discrete (one-hot) actions."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         from_pixels=True,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         if observation_spec is None:
+             cls.out_key = "pixels"
+             observation_spec = Composite(
+                 pixels=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])),
+                 pixels_orig=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])),
+                 shape=batch_size,
+             )
+         if action_spec is None:
+             action_spec = OneHot(7, shape=(*batch_size, 7))
+         if reward_spec is None:
+             reward_spec = Unbounded(shape=(*batch_size, 1))
+         if done_spec is None:
+             done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+
+         if state_spec is None:
+             cls._out_key = "pixels_orig"
+             state_spec = Composite(
+                 {
+                     cls._out_key: observation_spec["pixels_orig"].clone(),
+                 },
+                 shape=batch_size,
+             )
+         return super().__new__(
+             *args,
+             observation_spec=observation_spec,
+             action_spec=action_spec,
+             reward_spec=reward_spec,
+             state_spec=state_spec,
+             from_pixels=from_pixels,
+             done_spec=done_spec,
+             **kwargs,
+         )
+
+     def _get_out_obs(self, obs):
+         obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(0)
+         return obs
+
+     def _get_in_obs(self, obs):
+         return obs.diagonal(0, -1, -2).squeeze()
+
+
+ class DiscreteActionConvMockEnvNumpy(DiscreteActionConvMockEnv):
+     """Numpy-style variant of `DiscreteActionConvMockEnv` (channels-last pixels)."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         from_pixels=True,
+         categorical_action_encoding=False,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         if observation_spec is None:
+             cls.out_key = "pixels"
+             observation_spec = Composite(
+                 pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
+                 pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
+                 shape=batch_size,
+             )
+         if action_spec is None:
+             action_spec_cls = Categorical if categorical_action_encoding else OneHot
+             action_spec = action_spec_cls(7, shape=(*batch_size, 7))
+         if state_spec is None:
+             cls._out_key = "pixels_orig"
+             state_spec = Composite(
+                 {
+                     cls._out_key: observation_spec["pixels_orig"],
+                 },
+                 shape=batch_size,
+             )
+
+         return super().__new__(
+             *args,
+             observation_spec=observation_spec,
+             action_spec=action_spec,
+             reward_spec=reward_spec,
+             state_spec=state_spec,
+             from_pixels=from_pixels,
+             categorical_action_encoding=categorical_action_encoding,
+             **kwargs,
+         )
+
+     def _get_out_obs(self, obs):
+         obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(-1)
+         obs = obs.expand(*obs.shape[:-1], 3)
+         return obs
+
+     def _get_in_obs(self, obs):
+         return obs.diagonal(0, -2, -3)[..., 0, :]
+
+     def _obs_step(self, obs, a):
+         return obs + a.unsqueeze(-1) / self.maxstep
+
+
+ class ContinuousActionConvMockEnv(ContinuousActionVecMockEnv):
+     """Mock env with image-like observations and continuous (bounded) actions."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         from_pixels=True,
+         pixel_shape=None,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         if pixel_shape is None:
+             pixel_shape = [1, 7, 7]
+         if observation_spec is None:
+             cls.out_key = "pixels"
+             observation_spec = Composite(
+                 pixels=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])),
+                 pixels_orig=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])),
+                 shape=batch_size,
+             )
+
+         if action_spec is None:
+             action_spec = Bounded(-1, 1, [*batch_size, pixel_shape[-1]])
+         if reward_spec is None:
+             reward_spec = Unbounded(shape=(*batch_size, 1))
+         if done_spec is None:
+             done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+         if state_spec is None:
+             cls._out_key = "pixels_orig"
+             state_spec = Composite(
+                 {cls._out_key: observation_spec["pixels"]}, shape=batch_size
+             )
+         return super().__new__(
+             *args,
+             observation_spec=observation_spec,
+             action_spec=action_spec,
+             reward_spec=reward_spec,
+             from_pixels=from_pixels,
+             state_spec=state_spec,
+             done_spec=done_spec,
+             **kwargs,
+         )
+
+     def _get_out_obs(self, obs):
+         obs = torch.diag_embed(obs, 0, -2, -1)
+         return obs
+
+     def _get_in_obs(self, obs):
+         obs = obs.diagonal(0, -1, -2)
+         return obs
+
+
+ class ContinuousActionConvMockEnvNumpy(ContinuousActionConvMockEnv):
+     """Numpy-style variant of `ContinuousActionConvMockEnv` (channels-last pixels)."""
+
+     @classmethod
+     def __new__(
+         cls,
+         *args,
+         observation_spec=None,
+         action_spec=None,
+         state_spec=None,
+         reward_spec=None,
+         done_spec=None,
+         from_pixels=True,
+         **kwargs,
+     ):
+         batch_size = kwargs.setdefault("batch_size", torch.Size([]))
+         if observation_spec is None:
+             cls.out_key = "pixels"
+             observation_spec = Composite(
+                 pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
+                 pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
+             )
+         return super().__new__(
+             *args,
+             observation_spec=observation_spec,
+             action_spec=action_spec,
+             reward_spec=reward_spec,
+             state_spec=state_spec,
+             from_pixels=from_pixels,
+             **kwargs,
+         )
+
+     def _get_out_obs(self, obs):
+         obs = torch.diag_embed(obs, 0, -2, -1).unsqueeze(-1)
+         obs = obs.expand(*obs.shape[:-1], 3)
+         return obs
+
+     def _get_in_obs(self, obs):
+         return obs.diagonal(0, -2, -3)[..., 0, :]
+
+     def _obs_step(self, obs, a):
+         return obs + a / self.maxstep
+
+
+ class DiscreteActionConvPolicy(DiscreteActionVecPolicy):
+     """Policy for discrete-action convolutional mock environments."""
+
+     in_keys = ["pixels"]
+     out_keys = ["action"]
+
+     def _get_in_obs(self, tensordict):
+         obs = tensordict.get(*self.in_keys).diagonal(0, -1, -2).squeeze()
+         return obs
+
+
+ class DummyModelBasedEnvBase(ModelBasedEnvBase):
+     """Dummy environment for Model Based RL sota-implementations.
+
+     This class is meant to be used to test the model based environment.
+
+     Args:
+         world_model (WorldModel): the world model to use for the environment.
+         device (str or torch.device, optional): the device to use for the environment.
+         dtype (torch.dtype, optional): the dtype to use for the environment.
+         batch_size (sequence of int, optional): the batch size to use for the environment.
+     """
+
+     def __init__(
+         self,
+         world_model,
+         device="cpu",
+         dtype=None,
+         batch_size=None,
+     ):
+         super().__init__(
+             world_model,
+             device=device,
+             batch_size=batch_size,
+         )
+         self.observation_spec = Composite(
+             hidden_observation=Unbounded(
+                 (
+                     *self.batch_size,
+                     4,
+                 )
+             ),
+             shape=self.batch_size,
+         )
+         self.state_spec = Composite(
+             hidden_observation=Unbounded(
+                 (
+                     *self.batch_size,
+                     4,
+                 )
+             ),
+             shape=self.batch_size,
+         )
+         self.action_spec = Unbounded(
+             (
+                 *self.batch_size,
+                 1,
+             )
+         )
+         self.reward_spec = Unbounded(
+             (
+                 *self.batch_size,
+                 1,
+             )
+         )
+
+     def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
+         td = TensorDict(
+             {
+                 "hidden_observation": self.state_spec["hidden_observation"].rand(),
+             },
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+         return td
+
+
+ class ActionObsMergeLinear(nn.Module):
+     """Linear layer that consumes concatenated observation and action tensors."""
+
+     def __init__(self, in_size, out_size):
+         super().__init__()
+         self.linear = nn.Linear(in_size, out_size)
+
+     def forward(self, observation, action):
+         return self.linear(torch.cat([observation, action], dim=-1))
+
+
+ class CountingEnvCountPolicy(TensorDictModuleBase):
+     """Policy that always returns an increment action for counting environments."""
+
+     def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
+         super().__init__()
+         assert not isinstance(action_spec, Composite)
+         self.action_spec = action_spec
+         self.action_key = action_key
+         self.in_keys = []
+         self.out_keys = [action_key]
+
+     def __call__(self, td: TensorDictBase) -> TensorDictBase:
+         return td.set(self.action_key, self.action_spec.zero() + 1)
+
+
+ class CountingEnvCountModule(nn.Module):
+     """Module that returns a constant increment action given an action spec."""
+
+     def __init__(self, action_spec: TensorSpec):
+         super().__init__()
+         self.action_spec = action_spec
+
+     def forward(self, t):
+         return self.action_spec.zero() + 1
+
+
+ class CountingEnv(EnvBase):
+     """An env that is done after a given number of steps.
+
+     The action is the count increment.
+
+     """
+
+     def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
+         super().__init__(**kwargs)
+         self.max_steps = max_steps
+         self.start_val = start_val
+
+         self.observation_spec = Composite(
+             observation=Unbounded(
+                 (
+                     *self.batch_size,
+                     1,
+                 ),
+                 dtype=torch.int32,
+                 device=self.device,
+             ),
+             shape=self.batch_size,
+             device=self.device,
+         )
+         self.reward_spec = Unbounded(
+             (
+                 *self.batch_size,
+                 1,
+             ),
+             device=self.device,
+         )
+         self.done_spec = Categorical(
+             2,
+             dtype=torch.bool,
+             shape=(*self.batch_size, 1),
+             device=self.device,
+         )
+         self.action_spec = Binary(n=1, shape=[*self.batch_size, 1], device=self.device)
+         self.register_buffer(
+             "count",
+             torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
+         )
+
+     def _set_seed(self, seed: int | None) -> None:
+         torch.manual_seed(seed)
+
+     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+         if tensordict is not None and "_reset" in tensordict.keys():
+             _reset = tensordict.get("_reset")
+             self.count[_reset] = self.start_val
+         else:
+             self.count[:] = self.start_val
+         return TensorDict(
+             source={
+                 "observation": self.count.clone(),
+                 "done": self.count > self.max_steps,
+                 "terminated": self.count > self.max_steps,
+             },
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+
+     def _step(
+         self,
+         tensordict: TensorDictBase,
+     ) -> TensorDictBase:
+         action = tensordict.get(self.action_key)
+         try:
+             device = self.full_action_spec[self.action_key].device
+         except KeyError:
+             device = self.device
+         self.count += action.to(
+             dtype=torch.int,
+             device=device if self.device is None else self.device,
+         )
+         if self.reward_keys:
+             reward_spec = self.full_reward_spec[self.reward_keys[0]]
+             reward_spec_dtype = reward_spec.dtype
+         else:
+             reward_spec_dtype = torch.get_default_dtype()
+         tensordict = TensorDict(
+             source={
+                 "observation": self.count.clone(),
+                 "done": self.count > self.max_steps,
+                 "terminated": self.count > self.max_steps,
+                 "reward": torch.zeros_like(self.count, dtype=reward_spec_dtype),
+             },
+             batch_size=self.batch_size,
+             device=self.device,
+         )
+         return tensordict
+
+
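(Editorial note, not part of the packaged file: a brief sketch of how CountingEnv and the companion CountingEnvCountPolicy above could be combined; names not defined in this file follow the public torchrl API.)

    env = CountingEnv(max_steps=5)
    policy = CountingEnvCountPolicy(env.full_action_spec[env.action_key], env.action_key)
    rollout = env.rollout(max_steps=10, policy=policy)  # count increments by 1 per step until done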
1185
+ def get_random_string(min_size, max_size):
1186
+ """Return a random ASCII lowercase string with length in [min_size, max_size]."""
1187
+ size = random.randint(min_size, max_size)
1188
+ return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
1189
+
1190
+
1191
+ class CountingEnvWithString(CountingEnv):
1192
+ """`CountingEnv` variant that adds a non-tensor string observation."""
1193
+
1194
+ def __init__(self, *args, **kwargs):
1195
+ self.max_size = kwargs.pop("max_size", 30)
1196
+ self.min_size = kwargs.pop("min_size", 4)
1197
+ super().__init__(*args, **kwargs)
1198
+ self.observation_spec.set(
1199
+ "string",
1200
+ NonTensor(
1201
+ shape=self.batch_size,
1202
+ device=self.device,
1203
+ example_data=self.get_random_string(),
1204
+ ),
1205
+ )
1206
+
1207
+ def get_random_string(self):
1208
+ return get_random_string(self.min_size, self.max_size)
1209
+
1210
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1211
+ res = super()._reset(tensordict, **kwargs)
1212
+ random_string = self.get_random_string()
1213
+ res["string"] = random_string
1214
+ return res
1215
+
1216
+ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
1217
+ res = super()._step(tensordict)
1218
+ random_string = self.get_random_string()
1219
+ res["string"] = random_string
1220
+ return res
1221
+
1222
+
1223
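# Illustrative sketch (editorial addition): the "string" entry is a plain Python string
# carried as a NonTensor observation next to the integer count. The helper name is
# hypothetical; assumes the module-level imports above.
def _example_counting_env_with_string():
    env = CountingEnvWithString(max_steps=3, min_size=4, max_size=8)
    td = env.reset()
    assert isinstance(td["string"], str)
    # Constant increment action, mirroring CountingEnvCountPolicy above.
    td[env.action_key] = env.full_action_spec[env.action_key].zero() + 1
    td = env.step(td)
    return td["next", "string"]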
+ class MultiAgentCountingEnv(EnvBase):
1224
+ """A multi-agent env that is done after a given number of steps.
1225
+
1226
+ All agents have identical specs.
1227
+
1228
+ The count is incremented by 1 on each step.
1229
+
1230
+ """
1231
+
1232
+ def __init__(
1233
+ self,
1234
+ n_agents: int,
1235
+ group_map: MarlGroupMapType
1236
+ | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
1237
+ max_steps: int = 5,
1238
+ start_val: int = 0,
1239
+ **kwargs,
1240
+ ):
1241
+ super().__init__(**kwargs)
1242
+ self.max_steps = max_steps
1243
+ self.start_val = start_val
1244
+ self.n_agents = n_agents
1245
+ self.agent_names = [f"agent_{idx}" for idx in range(n_agents)]
1246
+
1247
+ if isinstance(group_map, MarlGroupMapType):
1248
+ group_map = group_map.get_group_map(self.agent_names)
1249
+ check_marl_grouping(group_map, self.agent_names)
1250
+
1251
+ self.group_map = group_map
1252
+
1253
+ observation_specs = {}
1254
+ reward_specs = {}
1255
+ done_specs = {}
1256
+ action_specs = {}
1257
+
1258
+ for group_name, agents in group_map.items():
1259
+ observation_specs[group_name] = {}
1260
+ reward_specs[group_name] = {}
1261
+ done_specs[group_name] = {}
1262
+ action_specs[group_name] = {}
1263
+
1264
+ for agent_name in agents:
1265
+ observation_specs[group_name][agent_name] = Composite(
1266
+ observation=Unbounded(
1267
+ (
1268
+ *self.batch_size,
1269
+ 3,
1270
+ 4,
1271
+ ),
1272
+ dtype=torch.float32,
1273
+ device=self.device,
1274
+ ),
1275
+ shape=self.batch_size,
1276
+ device=self.device,
1277
+ )
1278
+ reward_specs[group_name][agent_name] = Composite(
1279
+ reward=Unbounded(
1280
+ (
1281
+ *self.batch_size,
1282
+ 1,
1283
+ ),
1284
+ device=self.device,
1285
+ ),
1286
+ shape=self.batch_size,
1287
+ device=self.device,
1288
+ )
1289
+ done_specs[group_name][agent_name] = Composite(
1290
+ done=Categorical(
1291
+ 2,
1292
+ dtype=torch.bool,
1293
+ shape=(
1294
+ *self.batch_size,
1295
+ 1,
1296
+ ),
1297
+ device=self.device,
1298
+ ),
1299
+ shape=self.batch_size,
1300
+ device=self.device,
1301
+ )
1302
+ action_specs[group_name][agent_name] = Composite(
1303
+ action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device),
1304
+ shape=self.batch_size,
1305
+ device=self.device,
1306
+ )
1307
+
1308
+ self.observation_spec = Composite(observation_specs)
1309
+ self.reward_spec = Composite(reward_specs)
1310
+ self.done_spec = Composite(done_specs)
1311
+ self.action_spec = Composite(action_specs)
1312
+ self.register_buffer(
1313
+ "count",
1314
+ torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
1315
+ )
1316
+
1317
+ def _set_seed(self, seed: int | None) -> None:
1318
+ torch.manual_seed(seed)
1319
+
1320
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1321
+ if tensordict is not None and "_reset" in tensordict.keys():
1322
+ _reset = tensordict.get("_reset")
1323
+ self.count[_reset] = self.start_val
1324
+ else:
1325
+ self.count[:] = self.start_val
1326
+
1327
+ source = {}
1328
+ for group_name, agents in self.group_map.items():
1329
+ source[group_name] = {}
1330
+ for agent_name in agents:
1331
+ source[group_name][agent_name] = TensorDict(
1332
+ source={
1333
+ "observation": torch.rand(
1334
+ (*self.batch_size, 3, 4),
1335
+ device=self.device,
1336
+ dtype=self.full_observation_spec[
1337
+ group_name, agent_name, "observation"
1338
+ ].dtype,
1339
+ ),
1340
+ "done": self.count > self.max_steps,
1341
+ "terminated": self.count > self.max_steps,
1342
+ },
1343
+ batch_size=self.batch_size,
1344
+ device=self.device,
1345
+ )
1346
+
1347
+ tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
1348
+ return tensordict
1349
+
1350
+ def _step(
1351
+ self,
1352
+ tensordict: TensorDictBase,
1353
+ ) -> TensorDictBase:
1354
+ self.count += 1
1355
+ source = {}
1356
+ for group_name, agents in self.group_map.items():
1357
+ source[group_name] = {}
1358
+ for agent_name in agents:
1359
+ source[group_name][agent_name] = TensorDict(
1360
+ source={
1361
+ "observation": torch.rand(
1362
+ (*self.batch_size, 3, 4),
1363
+ device=self.device,
1364
+ dtype=self.full_observation_spec[
1365
+ group_name, agent_name, "observation"
1366
+ ].dtype,
1367
+ ),
1368
+ "done": self.count > self.max_steps,
1369
+ "terminated": self.count > self.max_steps,
1370
+ "reward": torch.zeros_like(
1371
+ self.count,
1372
+ dtype=self.full_reward_spec[
1373
+ group_name, agent_name, "reward"
1374
+ ].dtype,
1375
+ ),
1376
+ },
1377
+ batch_size=self.batch_size,
1378
+ device=self.device,
1379
+ )
1380
+ tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
1381
+ return tensordict
1382
+
1383
+
1384
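# Illustrative sketch (editorial addition): with the default ALL_IN_ONE_GROUP mapping the
# agents are expected to be gathered under a single "agents" group, so entries are addressed
# as ("agents", "agent_0", "observation") and so on. The helper name is hypothetical.
def _example_multi_agent_counting_env():
    env = MultiAgentCountingEnv(n_agents=3)
    td = env.reset()
    obs = td["agents", "agent_0", "observation"]  # shape (*batch_size, 3, 4)
    return obs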
+ class IncrementingEnv(CountingEnv):
1385
+ """`CountingEnv` variant that always increments the count by 1 regardless of action."""
1386
+
1387
+ def _step(
1388
+ self,
1389
+ tensordict: TensorDictBase,
1390
+ ) -> TensorDictBase:
1391
+ self.count += 1  # The only difference from CountingEnv.
1392
+ tensordict = TensorDict(
1393
+ source={
1394
+ "observation": self.count.clone(),
1395
+ "done": self.count > self.max_steps,
1396
+ "terminated": self.count > self.max_steps,
1397
+ "reward": torch.zeros_like(self.count, dtype=torch.float),
1398
+ },
1399
+ batch_size=self.batch_size,
1400
+ device=self.device,
1401
+ )
1402
+ return tensordict
1403
+
1404
+
1405
+ class NestedCountingEnv(CountingEnv):
1406
+ """Counting environment with nested observation/action/reward/done structures."""
1407
+
1408
+ def __init__(
1409
+ self,
1410
+ max_steps: int = 5,
1411
+ start_val: int = 0,
1412
+ nest_obs_action: bool = True,
1413
+ nest_done: bool = True,
1414
+ nest_reward: bool = True,
1415
+ nested_dim: int = 3,
1416
+ has_root_done: bool = False,
1417
+ **kwargs,
1418
+ ):
1419
+ super().__init__(max_steps=max_steps, start_val=start_val, **kwargs)
1420
+
1421
+ self.nested_dim = nested_dim
1422
+ self.has_root_done = has_root_done
1423
+
1424
+ self.nested_obs_action = nest_obs_action
1425
+ self.nested_done = nest_done
1426
+ self.nested_reward = nest_reward
1427
+
1428
+ if self.nested_obs_action:
1429
+ self.observation_spec = Composite(
1430
+ {
1431
+ "data": Composite(
1432
+ {
1433
+ "states": self.observation_spec["observation"]
1434
+ .unsqueeze(-1)
1435
+ .expand(*self.batch_size, self.nested_dim, 1)
1436
+ },
1437
+ shape=(
1438
+ *self.batch_size,
1439
+ self.nested_dim,
1440
+ ),
1441
+ )
1442
+ },
1443
+ shape=self.batch_size,
1444
+ )
1445
+ action_spec = self.full_action_spec[self.action_key]
1446
+ assert not isinstance(action_spec, Composite)
1447
+ self.full_action_spec = Composite(
1448
+ {
1449
+ "data": Composite(
1450
+ {
1451
+ "action": action_spec.unsqueeze(-1).expand(
1452
+ *self.batch_size, self.nested_dim, 1
1453
+ )
1454
+ },
1455
+ shape=(
1456
+ *self.batch_size,
1457
+ self.nested_dim,
1458
+ ),
1459
+ )
1460
+ },
1461
+ shape=self.batch_size,
1462
+ )
1463
+
1464
+ if self.nested_reward:
1465
+ self.reward_spec = Composite(
1466
+ {
1467
+ "data": Composite(
1468
+ {
1469
+ "reward": self.reward_spec.unsqueeze(-1).expand(
1470
+ *self.batch_size, self.nested_dim, 1
1471
+ )
1472
+ },
1473
+ shape=(
1474
+ *self.batch_size,
1475
+ self.nested_dim,
1476
+ ),
1477
+ )
1478
+ },
1479
+ shape=self.batch_size,
1480
+ )
1481
+
1482
+ if self.nested_done:
1483
+ done_spec = self.full_done_spec.unsqueeze(-1).expand(
1484
+ *self.batch_size, self.nested_dim
1485
+ )
1486
+ done_spec = Composite(
1487
+ {"data": done_spec},
1488
+ shape=self.batch_size,
1489
+ )
1490
+ if self.has_root_done:
1491
+ done_spec["done"] = Categorical(
1492
+ 2,
1493
+ shape=(
1494
+ *self.batch_size,
1495
+ 1,
1496
+ ),
1497
+ dtype=torch.bool,
1498
+ )
1499
+ self.done_spec = done_spec
1500
+
1501
+ def _reset(self, tensordict):
1502
+
1503
+ # check that reset works as expected
1504
+ if tensordict is not None:
1505
+ if self.nested_done:
1506
+ if not self.has_root_done:
1507
+ assert "_reset" not in tensordict.keys()
1508
+ else:
1509
+ assert ("data", "_reset") not in tensordict.keys(True)
1510
+
1511
+ tensordict_reset = super()._reset(tensordict)
1512
+
1513
+ if self.nested_done:
1514
+ for done_key in self.done_keys:
1515
+ if isinstance(done_key, str):
1516
+ continue
1517
+ else:
1518
+ done = tensordict_reset.pop(done_key[-1], None)
1519
+ if done is None:
1520
+ continue
1521
+ tensordict_reset.set(
1522
+ done_key,
1523
+ (done.unsqueeze(-2).expand(*self.batch_size, self.nested_dim, 1)),
1524
+ )
1525
+ if self.nested_obs_action:
1526
+ obs = tensordict_reset.pop("observation")
1527
+ tensordict_reset.set(
1528
+ ("data", "states"),
1529
+ (obs.unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)),
1530
+ )
1531
+ if "data" in tensordict_reset.keys():
1532
+ tensordict_reset.get("data").batch_size = (
1533
+ *self.batch_size,
1534
+ self.nested_dim,
1535
+ )
1536
+ return tensordict_reset
1537
+
1538
+ def _step(self, tensordict):
1539
+ if self.nested_obs_action:
1540
+ tensordict = tensordict.clone()
1541
+ tensordict["data"].batch_size = self.batch_size
1542
+ tensordict[self.action_key] = tensordict[self.action_key].max(-2)[0]
1543
+ next_tensordict = super()._step(tensordict)
1544
+ if self.nested_obs_action:
1545
+ tensordict[self.action_key] = (
1546
+ tensordict[self.action_key]
1547
+ .unsqueeze(-1)
1548
+ .expand(*self.batch_size, self.nested_dim, 1)
1549
+ )
1550
+ if "data" in tensordict.keys():
1551
+ tensordict["data"].batch_size = (*self.batch_size, self.nested_dim)
1552
+ if self.nested_done:
1553
+ for done_key in self.done_keys:
1554
+ if isinstance(done_key, str):
1555
+ continue
1556
+ else:
1557
+ done = next_tensordict.pop(done_key[-1], None)
1558
+ if done is None:
1559
+ continue
1560
+ next_tensordict.set(
1561
+ done_key,
1562
+ (done.unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)),
1563
+ )
1564
+ if self.nested_obs_action:
1565
+ next_tensordict.set(
1566
+ ("data", "states"),
1567
+ (
1568
+ next_tensordict.pop("observation")
1569
+ .unsqueeze(-1)
1570
+ .expand(*self.batch_size, self.nested_dim, 1)
1571
+ ),
1572
+ )
1573
+ if self.nested_reward:
1574
+ next_tensordict.set(
1575
+ self.reward_key,
1576
+ (
1577
+ next_tensordict.pop("reward")
1578
+ .unsqueeze(-1)
1579
+ .expand(*self.batch_size, self.nested_dim, 1)
1580
+ ),
1581
+ )
1582
+ if "data" in next_tensordict.keys():
1583
+ next_tensordict.get("data").batch_size = (*self.batch_size, self.nested_dim)
1584
+ return next_tensordict
1585
+
1586
+
1587
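# Illustrative sketch (editorial addition): with nesting enabled, the action, observation,
# reward and done entries move under a "data" sub-tensordict that carries an extra
# ``nested_dim`` dimension. The helper name is hypothetical.
def _example_nested_counting_env():
    env = NestedCountingEnv(nested_dim=3)
    assert env.action_key == ("data", "action")
    td = env.reset()
    return td["data", "states"]  # shape (*batch_size, nested_dim, 1)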
+ class CountingBatchedEnv(EnvBase):
1588
+ """An env that is done after a given number of steps.
1589
+
1590
+ The action is the count increment.
1591
+
1592
+ Unlike ``CountingEnv``, each env in the batch can have its own ``max_steps``.
1593
+ """
1594
+
1595
+ def __init__(
1596
+ self,
1597
+ max_steps: torch.Tensor = None,
1598
+ start_val: torch.Tensor = None,
1599
+ **kwargs,
1600
+ ):
1601
+ super().__init__(**kwargs)
1602
+ if max_steps is None:
1603
+ max_steps = torch.tensor(5)
1604
+ if start_val is None:
1605
+ start_val = torch.zeros((), dtype=torch.int32)
1606
+ if max_steps.shape != self.batch_size:
1607
+ raise RuntimeError(
1608
+ f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}."
1609
+ )
1610
+
1611
+ self.max_steps = max_steps
1612
+
1613
+ self.observation_spec = Composite(
1614
+ observation=Unbounded(
1615
+ (
1616
+ *self.batch_size,
1617
+ 1,
1618
+ ),
1619
+ dtype=torch.int32,
1620
+ ),
1621
+ shape=self.batch_size,
1622
+ )
1623
+ self.reward_spec = Unbounded(
1624
+ (
1625
+ *self.batch_size,
1626
+ 1,
1627
+ )
1628
+ )
1629
+ self.done_spec = Categorical(
1630
+ 2,
1631
+ dtype=torch.bool,
1632
+ shape=(
1633
+ *self.batch_size,
1634
+ 1,
1635
+ ),
1636
+ )
1637
+ self.action_spec = Binary(n=1, shape=[*self.batch_size, 1])
1638
+
1639
+ self.count = torch.zeros(
1640
+ (*self.batch_size, 1), device=self.device, dtype=torch.int
1641
+ )
1642
+ if start_val.numel() == self.batch_size.numel():
1643
+ self.start_val = start_val.view(*self.batch_size, 1)
1644
+ elif start_val.numel() <= 1:
1645
+ self.start_val = start_val.expand_as(self.count)
1646
+
1647
+ def _set_seed(self, seed: int | None) -> None:
1648
+ torch.manual_seed(seed)
1649
+
1650
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1651
+ if tensordict is not None and "_reset" in tensordict.keys():
1652
+ _reset = tensordict.get("_reset")
1653
+ self.count[_reset] = self.start_val[_reset].view_as(self.count[_reset])
1654
+ else:
1655
+ self.count[:] = self.start_val.view_as(self.count)
1656
+ return TensorDict(
1657
+ source={
1658
+ "observation": self.count.clone(),
1659
+ "done": self.count > self.max_steps.view_as(self.count),
1660
+ "terminated": self.count > self.max_steps.view_as(self.count),
1661
+ },
1662
+ batch_size=self.batch_size,
1663
+ device=self.device,
1664
+ )
1665
+
1666
+ def _step(
1667
+ self,
1668
+ tensordict: TensorDictBase,
1669
+ ) -> TensorDictBase:
1670
+ action = tensordict.get("action")
1671
+ self.count += action.to(torch.int).view_as(self.count)
1672
+ tensordict = TensorDict(
1673
+ source={
1674
+ "observation": self.count.clone(),
1675
+ "done": self.count > self.max_steps.unsqueeze(-1),
1676
+ "terminated": self.count > self.max_steps.unsqueeze(-1),
1677
+ "reward": torch.zeros_like(self.count, dtype=torch.float),
1678
+ },
1679
+ batch_size=self.batch_size,
1680
+ device=self.device,
1681
+ )
1682
+ return tensordict
1683
+
1684
+
1685
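# Illustrative sketch (editorial addition): each env in the batch terminates at its own
# ``max_steps``, so the tensors passed at construction must match ``batch_size``. The helper
# name is hypothetical; assumes the module-level ``torch`` import.
def _example_counting_batched_env():
    env = CountingBatchedEnv(
        max_steps=torch.tensor([3, 5]),
        start_val=torch.tensor([0, 0]),
        batch_size=[2],
    )
    td = env.reset()
    return td["observation"]  # shape (2, 1), dtype torch.int32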
+ class HeterogeneousCountingEnvPolicy(TensorDictModuleBase):
1686
+ """Policy for `HeterogeneousCountingEnv` that outputs increment (or zero) actions."""
1687
+
1688
+ def __init__(self, full_action_spec: TensorSpec, count: bool = True):
1689
+ super().__init__()
1690
+ self.full_action_spec = full_action_spec
1691
+ self.count = count
1692
+
1693
+ def __call__(self, td: TensorDictBase) -> TensorDictBase:
1694
+ action_td = self.full_action_spec.zero()
1695
+ if self.count:
1696
+ action_td.apply_(lambda x: x + 1)
1697
+ return td.update(action_td)
1698
+
1699
+
1700
+ class HeterogeneousCountingEnv(EnvBase):
1701
+ """A heterogeneous, counting Env."""
1702
+
1703
+ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
1704
+ super().__init__(**kwargs)
1705
+ self.n_nested_dim = 3
1706
+ self.max_steps = max_steps
1707
+ self.start_val = start_val
1708
+
1709
+ count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
1710
+ count[:] = self.start_val
1711
+
1712
+ self.register_buffer("count", count)
1713
+ self._make_specs()
1714
+
1715
+ def _make_specs(self):
1716
+ obs_specs = []
1717
+ action_specs = []
1718
+ for index in range(self.n_nested_dim):
1719
+ obs_specs.append(self.get_agent_obs_spec(index))
1720
+ action_specs.append(self.get_agent_action_spec(index))
1721
+ obs_specs = torch.stack(obs_specs, dim=0)
1722
+ obs_spec_unlazy = consolidate_spec(obs_specs)
1723
+ action_specs = torch.stack(action_specs, dim=0)
1724
+
1725
+ self.observation_spec_unbatched = Composite(
1726
+ lazy=obs_spec_unlazy,
1727
+ state=Unbounded(shape=(64, 64, 3)),
1728
+ device=self.device,
1729
+ )
1730
+
1731
+ self.action_spec_unbatched = Composite(
1732
+ lazy=action_specs,
1733
+ device=self.device,
1734
+ )
1735
+ self.reward_spec_unbatched = Composite(
1736
+ {
1737
+ "lazy": Composite(
1738
+ {"reward": Unbounded(shape=(self.n_nested_dim, 1))},
1739
+ shape=(self.n_nested_dim,),
1740
+ )
1741
+ },
1742
+ device=self.device,
1743
+ )
1744
+ self.done_spec_unbatched = Composite(
1745
+ {
1746
+ "lazy": Composite(
1747
+ {
1748
+ "done": Categorical(
1749
+ n=2,
1750
+ shape=(self.n_nested_dim, 1),
1751
+ dtype=torch.bool,
1752
+ ),
1753
+ },
1754
+ shape=(self.n_nested_dim,),
1755
+ )
1756
+ },
1757
+ device=self.device,
1758
+ )
1759
+
1760
+ def get_agent_obs_spec(self, i):
1761
+ camera = Bounded(low=0, high=200, shape=(7, 7, 3))
1762
+ vector_3d = Unbounded(shape=(3,))
1763
+ vector_2d = Unbounded(shape=(2,))
1764
+ lidar = Bounded(low=0, high=5, shape=(8,))
1765
+
1766
+ tensor_0 = Unbounded(shape=(1,))
1767
+ tensor_1 = Bounded(low=0, high=3, shape=(1, 2))
1768
+ tensor_2 = Unbounded(shape=(1, 2, 3))
1769
+
1770
+ if i == 0:
1771
+ return Composite(
1772
+ {
1773
+ "camera": camera,
1774
+ "lidar": lidar,
1775
+ "vector": vector_3d,
1776
+ "tensor_0": tensor_0,
1777
+ },
1778
+ device=self.device,
1779
+ )
1780
+ elif i == 1:
1781
+ return Composite(
1782
+ {
1783
+ "camera": camera,
1784
+ "lidar": lidar,
1785
+ "vector": vector_2d,
1786
+ "tensor_1": tensor_1,
1787
+ },
1788
+ device=self.device,
1789
+ )
1790
+ elif i == 2:
1791
+ return Composite(
1792
+ {
1793
+ "camera": camera,
1794
+ "vector": vector_2d,
1795
+ "tensor_2": tensor_2,
1796
+ },
1797
+ device=self.device,
1798
+ )
1799
+ else:
1800
+ raise ValueError(f"Index {i} undefined for index 3")
1801
+
1802
+ def get_agent_action_spec(self, i):
1803
+ action_3d = Bounded(low=-1, high=1, shape=(3,))
1804
+ action_2d = Bounded(low=-1, high=1, shape=(2,))
1805
+
1806
+ # Some agents have a 2d action, others a 3d one
1807
+ # TODO Introduce composite heterogeneous actions
1808
+ if i == 0:
1809
+ ret = action_3d
1810
+ elif i == 1:
1811
+ ret = action_2d
1812
+ elif i == 2:
1813
+ ret = action_2d
1814
+ else:
1815
+ raise ValueError(f"Index {i} undefined for index 3")
1816
+
1817
+ return Composite({"action": ret})
1818
+
1819
+ def _reset(
1820
+ self,
1821
+ tensordict: TensorDictBase = None,
1822
+ **kwargs,
1823
+ ) -> TensorDictBase:
1824
+ if tensordict is not None and self.reset_keys[0] in tensordict.keys(True):
1825
+ _reset = tensordict.get(self.reset_keys[0]).squeeze(-1).any(-1)
1826
+ self.count[_reset] = self.start_val
1827
+ else:
1828
+ self.count[:] = self.start_val
1829
+
1830
+ reset_td = self.observation_spec.zero()
1831
+ reset_td.apply_(lambda x: x + expand_right(self.count, x.shape))
1832
+ reset_td.update(self.output_spec["full_done_spec"].zero())
1833
+
1834
+ assert reset_td.batch_size == self.batch_size
1835
+ for key in reset_td.keys(True):
1836
+ assert "_reset" not in key
1837
+ return reset_td
1838
+
1839
+ def _step(
1840
+ self,
1841
+ tensordict: TensorDictBase,
1842
+ ) -> TensorDictBase:
1843
+ actions = torch.zeros_like(self.count.squeeze(-1), dtype=torch.bool)
1844
+ for i in range(self.n_nested_dim):
1845
+ action = tensordict["lazy"][..., i]["action"]
1846
+ action = action[..., 0].to(torch.bool)
1847
+ actions += action
1848
+
1849
+ self.count += actions.unsqueeze(-1).to(torch.int)
1850
+
1851
+ td = self.observation_spec.zero()
1852
+ td.apply_(lambda x: x + expand_right(self.count, x.shape))
1853
+ td.update(self.output_spec["full_done_spec"].zero())
1854
+ td.update(self.output_spec["full_reward_spec"].zero())
1855
+
1856
+ assert td.batch_size == self.batch_size
1857
+ for done_key in self.done_keys:
1858
+ td[done_key] = expand_right(
1859
+ self.count > self.max_steps,
1860
+ self.full_done_spec[done_key].shape,
1861
+ )
1862
+
1863
+ return td
1864
+
1865
+ def _set_seed(self, seed: int | None) -> None:
1866
+ torch.manual_seed(seed)
1867
+
1868
+
1869
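# Illustrative sketch (editorial addition): the per-agent specs differ in shape, so the
# stacked ("lazy") entries are heterogeneous and rollouts are typically collected with
# ``return_contiguous=False``. The helper name is hypothetical.
def _example_heterogeneous_counting_env():
    env = HeterogeneousCountingEnv()
    policy = HeterogeneousCountingEnvPolicy(env.full_action_spec)
    return env.rollout(3, policy, return_contiguous=False)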
+ class MultiKeyCountingEnvPolicy(TensorDictModuleBase):
1870
+ """Policy for `MultiKeyCountingEnv` that can count deterministically or stochastically."""
1871
+
1872
+ def __init__(
1873
+ self,
1874
+ full_action_spec: TensorSpec,
1875
+ count: bool = True,
1876
+ deterministic: bool = False,
1877
+ ):
1878
+ super().__init__()
1879
+ if not deterministic and not count:
1880
+ raise ValueError("Not counting policy is always deterministic")
1881
+
1882
+ self.full_action_spec = full_action_spec
1883
+ self.count = count
1884
+ self.deterministic = deterministic
1885
+
1886
+ def __call__(self, td: TensorDictBase) -> TensorDictBase:
1887
+ action_td = self.full_action_spec.zero()
1888
+ if self.count:
1889
+ if self.deterministic:
1890
+ action_td["nested_1", "action"] += 1
1891
+ action_td["nested_2", "azione"] += 1
1892
+ action_td["action"][..., 1] = 1
1893
+ else:
1894
+ # We choose an action at random
1895
+ choice = torch.randint(0, 3, ()).item()
1896
+ if choice == 0:
1897
+ action_td["nested_1", "action"] += 1
1898
+ elif choice == 1:
1899
+ action_td["nested_2", "azione"] += 1
1900
+ else:
1901
+ action_td["action"][..., 1] = 1
1902
+ return td.update(action_td)
1903
+
1904
+
1905
+ class MultiKeyCountingEnv(EnvBase):
1906
+ """Counting env with multiple action/observation keys and nested structures."""
1907
+
1908
+ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
1909
+ super().__init__(**kwargs)
1910
+
1911
+ self.max_steps = max_steps
1912
+ self.start_val = start_val
1913
+ self.nested_dim_1 = 3
1914
+ self.nested_dim_2 = 2
1915
+
1916
+ count = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int)
1917
+ count_nested_1 = torch.zeros(
1918
+ (*self.batch_size, self.nested_dim_1, 1),
1919
+ device=self.device,
1920
+ dtype=torch.int,
1921
+ )
1922
+ count_nested_2 = torch.zeros(
1923
+ (*self.batch_size, self.nested_dim_2, 1),
1924
+ device=self.device,
1925
+ dtype=torch.int,
1926
+ )
1927
+
1928
+ count[:] = self.start_val
1929
+ count_nested_1[:] = self.start_val
1930
+ count_nested_2[:] = self.start_val
1931
+
1932
+ self.register_buffer("count", count)
1933
+ self.register_buffer("count_nested_1", count_nested_1)
1934
+ self.register_buffer("count_nested_2", count_nested_2)
1935
+
1936
+ self.make_specs()
1937
+
1938
+ def make_specs(self):
1939
+ self.observation_spec_unbatched = Composite(
1940
+ nested_1=Composite(
1941
+ observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)),
1942
+ shape=(self.nested_dim_1,),
1943
+ ),
1944
+ nested_2=Composite(
1945
+ observation=Unbounded(shape=(self.nested_dim_2, 2)),
1946
+ shape=(self.nested_dim_2,),
1947
+ ),
1948
+ observation=Unbounded(
1949
+ shape=(
1950
+ 10,
1951
+ 10,
1952
+ 3,
1953
+ )
1954
+ ),
1955
+ )
1956
+
1957
+ self.action_spec_unbatched = Composite(
1958
+ nested_1=Composite(
1959
+ action=Categorical(n=2, shape=(self.nested_dim_1,)),
1960
+ shape=(self.nested_dim_1,),
1961
+ ),
1962
+ nested_2=Composite(
1963
+ azione=Bounded(low=0, high=100, shape=(self.nested_dim_2, 1)),
1964
+ shape=(self.nested_dim_2,),
1965
+ ),
1966
+ action=OneHot(n=2),
1967
+ )
1968
+
1969
+ self.reward_spec_unbatched = Composite(
1970
+ nested_1=Composite(
1971
+ gift=Unbounded(shape=(self.nested_dim_1, 1)),
1972
+ shape=(self.nested_dim_1,),
1973
+ ),
1974
+ nested_2=Composite(
1975
+ reward=Unbounded(shape=(self.nested_dim_2, 1)),
1976
+ shape=(self.nested_dim_2,),
1977
+ ),
1978
+ reward=Unbounded(shape=(1,)),
1979
+ )
1980
+
1981
+ self.done_spec_unbatched = Composite(
1982
+ nested_1=Composite(
1983
+ done=Categorical(
1984
+ n=2,
1985
+ shape=(self.nested_dim_1, 1),
1986
+ dtype=torch.bool,
1987
+ ),
1988
+ terminated=Categorical(
1989
+ n=2,
1990
+ shape=(self.nested_dim_1, 1),
1991
+ dtype=torch.bool,
1992
+ ),
1993
+ shape=(self.nested_dim_1,),
1994
+ ),
1995
+ nested_2=Composite(
1996
+ done=Categorical(
1997
+ n=2,
1998
+ shape=(self.nested_dim_2, 1),
1999
+ dtype=torch.bool,
2000
+ ),
2001
+ terminated=Categorical(
2002
+ n=2,
2003
+ shape=(self.nested_dim_2, 1),
2004
+ dtype=torch.bool,
2005
+ ),
2006
+ shape=(self.nested_dim_2,),
2007
+ ),
2008
+ # done at the root always prevails
2009
+ done=Categorical(
2010
+ n=2,
2011
+ shape=(1,),
2012
+ dtype=torch.bool,
2013
+ ),
2014
+ terminated=Categorical(
2015
+ n=2,
2016
+ shape=(1,),
2017
+ dtype=torch.bool,
2018
+ ),
2019
+ )
2020
+
2021
+ def _reset(
2022
+ self,
2023
+ tensordict: TensorDictBase = None,
2024
+ **kwargs,
2025
+ ) -> TensorDictBase:
2026
+ reset_all = False
2027
+ if tensordict is not None:
2028
+ _reset = tensordict.get("_reset", None)
2029
+ if _reset is not None:
2030
+ self.count[_reset.squeeze(-1)] = self.start_val
2031
+ self.count_nested_1[_reset.squeeze(-1)] = self.start_val
2032
+ self.count_nested_2[_reset.squeeze(-1)] = self.start_val
2033
+ else:
2034
+ reset_all = True
2035
+
2036
+ if tensordict is None or reset_all:
2037
+ self.count[:] = self.start_val
2038
+ self.count_nested_1[:] = self.start_val
2039
+ self.count_nested_2[:] = self.start_val
2040
+
2041
+ reset_td = self.observation_spec.zero()
2042
+ reset_td["observation"] += expand_right(
2043
+ self.count, reset_td["observation"].shape
2044
+ )
2045
+ reset_td["nested_1", "observation"] += expand_right(
2046
+ self.count_nested_1, reset_td["nested_1", "observation"].shape
2047
+ )
2048
+ reset_td["nested_2", "observation"] += expand_right(
2049
+ self.count_nested_2, reset_td["nested_2", "observation"].shape
2050
+ )
2051
+
2052
+ reset_td.update(self.output_spec["full_done_spec"].zero())
2053
+
2054
+ assert reset_td.batch_size == self.batch_size
2055
+
2056
+ return reset_td
2057
+
2058
+ def _step(
2059
+ self,
2060
+ tensordict: TensorDictBase,
2061
+ ) -> TensorDictBase:
2062
+
2063
+ # Each action has a corresponding reward, done, and observation
2064
+ reward = self.output_spec["full_reward_spec"].zero()
2065
+ done = self.output_spec["full_done_spec"].zero()
2066
+ td = self.observation_spec.zero()
2067
+
2068
+ one_hot_action = tensordict["action"]
2069
+ one_hot_action = one_hot_action.long().argmax(-1).unsqueeze(-1)
2070
+ reward["reward"] += one_hot_action.to(torch.float)
2071
+ self.count += one_hot_action.to(torch.int)
2072
+ td["observation"] += expand_right(self.count, td["observation"].shape)
2073
+ done["done"] = self.count > self.max_steps
2074
+ done["terminated"] = self.count > self.max_steps
2075
+
2076
+ discrete_action = tensordict["nested_1"]["action"].unsqueeze(-1)
2077
+ reward["nested_1"]["gift"] += discrete_action.to(torch.float)
2078
+ self.count_nested_1 += discrete_action.to(torch.int)
2079
+ td["nested_1", "observation"] += expand_right(
2080
+ self.count_nested_1, td["nested_1", "observation"].shape
2081
+ )
2082
+ done["nested_1", "done"] = self.count_nested_1 > self.max_steps
2083
+ done["nested_1", "terminated"] = self.count_nested_1 > self.max_steps
2084
+
2085
+ continuous_action = tensordict["nested_2"]["azione"]
2086
+ reward["nested_2"]["reward"] += continuous_action.to(torch.float)
2087
+ self.count_nested_2 += continuous_action.to(torch.bool)
2088
+ td["nested_2", "observation"] += expand_right(
2089
+ self.count_nested_2, td["nested_2", "observation"].shape
2090
+ )
2091
+ done["nested_2", "done"] = self.count_nested_2 > self.max_steps
2092
+ done["nested_2", "terminated"] = self.count_nested_2 > self.max_steps
2093
+
2094
+ td.update(done)
2095
+ td.update(reward)
2096
+
2097
+ assert td.batch_size == self.batch_size
2098
+ return td
2099
+
2100
+ def _set_seed(self, seed: int | None) -> None:
2101
+ torch.manual_seed(seed)
2102
+
2103
+
2104
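# Illustrative sketch (editorial addition): three action entries ("action",
# ("nested_1", "action"), ("nested_2", "azione")) each drive their own counter, reward and
# done entries. The helper name is hypothetical.
def _example_multi_key_counting_env():
    env = MultiKeyCountingEnv()
    policy = MultiKeyCountingEnvPolicy(env.full_action_spec, deterministic=True)
    rollout = env.rollout(3, policy)
    return rollout["next", "nested_1", "gift"]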
+ class EnvWithMetadata(EnvBase):
2105
+ """Environment that emits both tensor and non-tensor observations (for metadata tests)."""
2106
+
2107
+ def __init__(self):
2108
+ super().__init__()
2109
+ self.observation_spec = Composite(
2110
+ tensor=Unbounded(3),
2111
+ non_tensor=NonTensor(shape=()),
2112
+ )
2113
+ self._saved_obs_spec = self.observation_spec.clone()
2114
+ self.state_spec = Composite(
2115
+ non_tensor=NonTensor(shape=()),
2116
+ )
2117
+ self._saved_state_spec = self.state_spec.clone()
2118
+ self.reward_spec = Unbounded(1)
2119
+ self._saved_full_reward_spec = self.full_reward_spec.clone()
2120
+ self.action_spec = Unbounded(1)
2121
+ self._saved_full_action_spec = self.full_action_spec.clone()
2122
+
2123
+ def _reset(self, tensordict):
2124
+ data = self._saved_obs_spec.zero()
2125
+ data.set_non_tensor("non_tensor", 0)
2126
+ data.update(self.full_done_spec.zero())
2127
+ return data
2128
+
2129
+ def _step(
2130
+ self,
2131
+ tensordict: TensorDictBase,
2132
+ ) -> TensorDictBase:
2133
+ data = self._saved_obs_spec.zero()
2134
+ data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
2135
+ data.update(self.full_done_spec.zero())
2136
+ data.update(self._saved_full_reward_spec.zero())
2137
+ return data
2138
+
2139
+ def _set_seed(self, seed: int | None) -> None:
2140
+ ...
2141
+
2142
+
2143
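# Illustrative sketch (editorial addition): the "non_tensor" entry is plain Python data kept
# alongside tensors; it starts at 0 on reset and the step increments it. The helper name is
# hypothetical, and retrieving the stacked non-tensor values as a list is the typical behaviour.
def _example_env_with_metadata():
    env = EnvWithMetadata()
    rollout = env.rollout(3)
    return rollout["next", "non_tensor"]  # typically a list like [1, 2, 3]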
+ class AutoResettingCountingEnv(CountingEnv):
2144
+ """`CountingEnv` variant that auto-resets when done is reached."""
2145
+
2146
+ def _step(self, tensordict):
2147
+ tensordict = super()._step(tensordict)
2148
+ if tensordict["done"].any():
2149
+ td_reset = super().reset()
2150
+ tensordict.update(td_reset.exclude(*self.done_keys))
2151
+ return tensordict
2152
+
2153
+ def _reset(self, tensordict=None):
2154
+ if tensordict is not None and "_reset" in tensordict:
2155
+ raise RuntimeError
2156
+ return super()._reset(tensordict)
2157
+
2158
+
2159
+ class AutoResetHeteroCountingEnv(HeterogeneousCountingEnv):
2160
+ """`HeterogeneousCountingEnv` variant that partially resets done sub-episodes."""
2161
+
2162
+ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
2163
+ super().__init__(**kwargs)
2164
+ self.n_nested_dim = 3
2165
+ self.max_steps = max_steps
2166
+ self.start_val = start_val
2167
+
2168
+ count = torch.zeros(
2169
+ (*self.batch_size, self.n_nested_dim, 1),
2170
+ device=self.device,
2171
+ dtype=torch.int,
2172
+ )
2173
+ count[:] = self.start_val
2174
+
2175
+ self.register_buffer("count", count)
2176
+ self._make_specs()
2177
+
2178
+ def _step(self, tensordict):
2179
+ for i in range(self.n_nested_dim):
2180
+ action = tensordict["lazy"][..., i]["action"]
2181
+ action = action[..., 0].to(torch.bool)
2182
+ self.count[..., i, 0] += action
2183
+
2184
+ td = self.observation_spec.zero()
2185
+ for done_key in self.done_keys:
2186
+ td[done_key] = self.count > self.max_steps
2187
+
2188
+ any_done = _terminated_or_truncated(
2189
+ td,
2190
+ full_done_spec=self.output_spec["full_done_spec"],
2191
+ key=None,
2192
+ )
2193
+ if any_done:
2194
+ self.count[td["lazy", "done"]] = 0
2195
+
2196
+ for i in range(self.n_nested_dim):
2197
+ lazy = tensordict["lazy"][..., i]
2198
+ for obskey in self.observation_spec.keys(True, True):
2199
+ if isinstance(obskey, tuple) and obskey[0] == "lazy":
2200
+ lazy[obskey[1:]] += expand_right(
2201
+ self.count[..., i, 0], lazy[obskey[1:]].shape
2202
+ ).clone()
2203
+ td.update(self.full_done_spec.zero())
2204
+ td.update(self.full_reward_spec.zero())
2205
+
2206
+ assert td.batch_size == self.batch_size
2207
+ return td
2208
+
2209
+ def _reset(self, tensordict=None):
2210
+ if tensordict is not None and self.reset_keys[0] in tensordict.keys(True):
2211
+ raise RuntimeError
2212
+ self.count[:] = self.start_val
2213
+
2214
+ reset_td = self.observation_spec.zero()
2215
+ reset_td.update(self.full_done_spec.zero())
2216
+ assert reset_td.batch_size == self.batch_size
2217
+ return reset_td
2218
+
2219
+
2220
+ class EnvWithDynamicSpec(EnvBase):
2221
+ """Environment with dynamic (ragged) observation specs that grow over time."""
2222
+
2223
+ def __init__(self, max_count=5):
2224
+ super().__init__(batch_size=())
2225
+ self.observation_spec = Composite(
2226
+ observation=Unbounded(shape=(3, -1, 2)),
2227
+ )
2228
+ self.action_spec = Bounded(low=-1, high=1, shape=(2,))
2229
+ self.full_done_spec = Composite(
2230
+ done=Binary(1, shape=(1,), dtype=torch.bool),
2231
+ terminated=Binary(1, shape=(1,), dtype=torch.bool),
2232
+ truncated=Binary(1, shape=(1,), dtype=torch.bool),
2233
+ )
2234
+ self.reward_spec = Unbounded((1,), dtype=torch.float)
2235
+ self.count = 0
2236
+ self.max_count = max_count
2237
+
2238
+ def _reset(self, tensordict=None):
2239
+ self.count = 0
2240
+ data = TensorDict(
2241
+ {
2242
+ "observation": torch.full(
2243
+ (3, self.count + 1, 2),
2244
+ self.count,
2245
+ dtype=self.observation_spec["observation"].dtype,
2246
+ )
2247
+ }
2248
+ )
2249
+ data.update(self.done_spec.zero())
2250
+ return data
2251
+
2252
+ def _step(
2253
+ self,
2254
+ tensordict: TensorDictBase,
2255
+ ) -> TensorDictBase:
2256
+ self.count += 1
2257
+ done = self.count >= self.max_count
2258
+ observation = TensorDict(
2259
+ {
2260
+ "observation": torch.full(
2261
+ (3, self.count + 1, 2),
2262
+ self.count,
2263
+ dtype=self.observation_spec["observation"].dtype,
2264
+ )
2265
+ }
2266
+ )
2267
+ done = self.full_done_spec.zero() | done
2268
+ reward = self.full_reward_spec.zero()
2269
+ return observation.update(done).update(reward)
2270
+
2271
+ def _set_seed(self, seed: int | None) -> None:
2272
+ self.manual_seed = seed
2273
+
2274
+
2275
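# Illustrative sketch (editorial addition): the observation's second dimension is declared
# as -1 (ragged), so each step returns a differently-shaped tensor and rollouts are
# collected without stacking to a contiguous tensor. The helper name is hypothetical.
def _example_env_with_dynamic_spec():
    env = EnvWithDynamicSpec(max_count=4)
    return env.rollout(4, return_contiguous=False)  # lazily stacked tensordict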
+ class EnvWithScalarAction(EnvBase):
2276
+ """Environment exposing a scalar (or singleton) action spec for edge-case testing."""
2277
+
2278
+ def __init__(self, singleton: bool = False, **kwargs):
2279
+ super().__init__(**kwargs)
2280
+ self.singleton = singleton
2281
+ self.action_spec = Bounded(
2282
+ -1,
2283
+ 1,
2284
+ shape=(
2285
+ *self.batch_size,
2286
+ 1,
2287
+ )
2288
+ if self.singleton
2289
+ else self.batch_size,
2290
+ )
2291
+ self.observation_spec = Composite(
2292
+ observation=Unbounded(
2293
+ shape=(
2294
+ *self.batch_size,
2295
+ 3,
2296
+ )
2297
+ ),
2298
+ shape=self.batch_size,
2299
+ )
2300
+ self.done_spec = Composite(
2301
+ done=Unbounded(self.batch_size + (1,), dtype=torch.bool),
2302
+ terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
2303
+ truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
2304
+ shape=self.batch_size,
2305
+ )
2306
+ self.reward_spec = Unbounded(
2307
+ shape=(
2308
+ *self.batch_size,
2309
+ 1,
2310
+ )
2311
+ )
2312
+
2313
+ def _reset(self, td: TensorDict):
2314
+ return TensorDict(
2315
+ observation=torch.randn(*self.batch_size, 3, device=self.device),
2316
+ done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
2317
+ truncated=torch.zeros(
2318
+ *self.batch_size, 1, dtype=torch.bool, device=self.device
2319
+ ),
2320
+ terminated=torch.zeros(
2321
+ *self.batch_size, 1, dtype=torch.bool, device=self.device
2322
+ ),
2323
+ device=self.device,
2324
+ )
2325
+
2326
+ def _step(
2327
+ self,
2328
+ tensordict: TensorDictBase,
2329
+ ) -> TensorDictBase:
2330
+ return TensorDict(
2331
+ observation=torch.randn(*self.batch_size, 3, device=self.device),
2332
+ reward=torch.zeros(1, device=self.device),
2333
+ done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
2334
+ truncated=torch.zeros(
2335
+ *self.batch_size, 1, dtype=torch.bool, device=self.device
2336
+ ),
2337
+ terminated=torch.zeros(
2338
+ *self.batch_size, 1, dtype=torch.bool, device=self.device
2339
+ ),
2340
+ )
2341
+
2342
+ def _set_seed(self, seed: int | None) -> None:
2343
+ ...
2344
+
2345
+
2346
+ class EnvThatDoesNothing(EnvBase):
2347
+ """Environment whose reset/step return empty tensordicts (for plumbing tests)."""
2348
+
2349
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2350
+ return TensorDict(batch_size=self.batch_size, device=self.device)
2351
+
2352
+ def _step(
2353
+ self,
2354
+ tensordict: TensorDictBase,
2355
+ ) -> TensorDictBase:
2356
+ return TensorDict(batch_size=self.batch_size, device=self.device)
2357
+
2358
+ def _set_seed(self, seed: int | None) -> None:
2359
+ ...
2360
+
2361
+
2362
+ class Str2StrEnv(EnvBase):
2363
+ """String-to-string environment with non-tensor observation/action fields."""
2364
+
2365
+ def __init__(self, min_size=4, max_size=10, **kwargs):
2366
+ self.observation_spec = Composite(
2367
+ observation=NonTensor(example_data="an observation!", shape=())
2368
+ )
2369
+ self.full_action_spec = Composite(
2370
+ action=NonTensor(example_data="an action!", shape=())
2371
+ )
2372
+ self.reward_spec = Unbounded(shape=(1,), dtype=torch.float)
2373
+ self.min_size = min_size
2374
+ self.max_size = max_size
2375
+ super().__init__(**kwargs)
2376
+
2377
+ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2378
+ assert isinstance(tensordict["action"], str)
2379
+ out = tensordict.empty()
2380
+ out.set("observation", self.get_random_string())
2381
+ out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
2382
+ out.set("reward", torch.zeros(1, dtype=torch.float).bernoulli_(0.01))
2383
+ return out
2384
+
2385
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2386
+ out = tensordict.empty() if tensordict is not None else TensorDict()
2387
+ out.set("observation", self.get_random_string())
2388
+ out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
2389
+ return out
2390
+
2391
+ def get_random_string(self):
2392
+ return get_random_string(self.min_size, self.max_size)
2393
+
2394
+ def _set_seed(self, seed: int | None) -> None:
2395
+ random.seed(seed)
2396
+ torch.manual_seed(0)
2397
+
2398
+
2399
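# Illustrative sketch (editorial addition): both observation and action are plain strings
# (NonTensor specs), as used by the LLM-oriented tests. The helper name is hypothetical.
def _example_str2str_env():
    env = Str2StrEnv()
    td = env.reset()
    assert isinstance(td["observation"], str)
    td["action"] = "an action!"
    td = env.step(td)
    return td["next", "observation"]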
+ class EnvThatErrorsAfter10Iters(EnvBase):
2400
+ """Environment that raises after 10 steps (used to validate error propagation)."""
2401
+
2402
+ def __init__(self):
2403
+ self.action_spec = Composite(action=Unbounded((1,)))
2404
+ self.reward_spec = Composite(reward=Unbounded((1,)))
2405
+ self.done_spec = Composite(done=Unbounded((1,)))
2406
+ self.observation_spec = Composite(observation=Unbounded((1,)))
2407
+ self.counter = 0
2408
+ super().__init__()
2409
+
2410
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
2411
+ return self.full_observation_spec.zero().update(self.full_done_spec.zero())
2412
+
2413
+ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
2414
+ if self.counter >= 10:
2415
+ raise RuntimeError("max steps!")
2416
+ self.counter += 1
2417
+ return (
2418
+ self.full_observation_spec.zero()
2419
+ .update(self.full_done_spec.zero())
2420
+ .update(self.full_reward_spec.zero())
2421
+ )
2422
+
2423
+ def _set_seed(self, seed: int | None) -> None:
2424
+ ...
2425
+
2426
+
2427
+ @tensorclass()
2428
+ class TC:
2429
+ """Simple tensorclass used by `EnvWithTensorClass`."""
2430
+
2431
+ field0: str
2432
+ field1: torch.Tensor
2433
+
2434
+
2435
+ class EnvWithTensorClass(CountingEnv):
2436
+ """`CountingEnv` variant that carries a tensorclass observation."""
2437
+
2438
+ tc_cls = TC
2439
+
2440
+ def __init__(self, **kwargs):
2441
+ super().__init__(**kwargs)
2442
+ self.observation_spec["tc"] = Composite(
2443
+ field0=NonTensor(example_data="an observation!", shape=self.batch_size),
2444
+ field1=Unbounded(shape=self.batch_size),
2445
+ shape=self.batch_size,
2446
+ data_cls=TC,
2447
+ )
2448
+
2449
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2450
+ td = super()._reset(tensordict, **kwargs)
2451
+ td["tc"] = TC("0", torch.zeros(self.batch_size))
2452
+ return td
2453
+
2454
+ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2455
+ td = super()._step(tensordict, **kwargs)
2456
+ default = TC("0", 0)
2457
+ f0 = tensordict.get("tc", default).field0
2458
+ if f0 is None:
2459
+ f0 = "0"
2460
+ f1 = tensordict.get("tc", default).field1
2461
+ if f1 is None:
2462
+ f1 = torch.zeros(self.batch_size)
2463
+ td["tc"] = TC(
2464
+ str(int(f0) + 1),
2465
+ f1 + 1,
2466
+ )
2467
+ return td
2468
+
2469
+
2470
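# Illustrative sketch (editorial addition): the "tc" entry is a TC tensorclass whose string
# and tensor fields are both advanced by one at every step. The helper name is hypothetical.
def _example_env_with_tensorclass():
    env = EnvWithTensorClass()
    td = env.reset()
    assert isinstance(td["tc"], TC)
    td = env.step(env.rand_action(td))
    return td["next", "tc"].field0  # "1" after one step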
+ @tensorclass
2471
+ class History:
2472
+ """Simple history record (role/content) used by `HistoryTransform`."""
2473
+
2474
+ role: str
2475
+ content: str
2476
+
2477
+
2478
+ class HistoryTransform(Transform):
2479
+ """A mocking class to record history."""
2480
+
2481
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
2482
+ defaults = {
2483
+ "role": NonTensor(
2484
+ example_data="a role!",
2485
+ shape=(-1,),
2486
+ ),
2487
+ "content": NonTensor(
2488
+ example_data="a content!",
2489
+ shape=(-1,),
2490
+ ),
2491
+ }
2492
+ observation_spec["history"] = Composite(
2493
+ defaults,
2494
+ shape=(-1,),
2495
+ data_cls=History,
2496
+ )
2497
+ assert observation_spec.device == self.parent.device
2498
+ assert observation_spec["history"].device == self.parent.device
2499
+ return observation_spec
2500
+
2501
+ def _reset(
2502
+ self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
2503
+ ) -> TensorDictBase:
2504
+ assert tensordict_reset.device == self.parent.device
2505
+ tensordict_reset["history"] = torch.stack(
2506
+ [
2507
+ History(role="system", content="0"),
2508
+ History(role="user", content="1"),
2509
+ ]
2510
+ )
2511
+ assert tensordict_reset["history"].device == self.parent.device
2512
+ return tensordict_reset
2513
+
2514
+ def _step(
2515
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
2516
+ ) -> TensorDictBase:
2517
+ assert next_tensordict.device == self.parent.device
2518
+ history = tensordict["history"]
2519
+ local_history = History(
2520
+ role=np.random.choice(["user", "system", "assistant"]),
2521
+ content=str(int(history.content[-1]) + 1),
2522
+ device=history.device,
2523
+ )
2524
+ # history = tensordict["history"].append(local_history)
2525
+ try:
2526
+ history = torch.stack(list(history.unbind(0)) + [local_history])
2527
+ except Exception:
2528
+ raise
2529
+ assert isinstance(history, History)
2530
+ next_tensordict["history"] = history
2531
+ assert next_tensordict["history"].device == self.parent.device, (
2532
+ next_tensordict["history"],
2533
+ self.parent.device,
2534
+ )
2535
+ return next_tensordict
2536
+
2537
+
2538
+ class DummyStrDataLoader:
2539
+ """Minimal iterator that yields random strings (for LLM tests)."""
2540
+
2541
+ def __init__(self, batch_size=0):
2542
+ if isinstance(batch_size, tuple):
2543
+ batch_size = torch.Size(batch_size).numel()
2544
+ self.batch_size = batch_size
2545
+
2546
+ def generate_random_string(self, length=10):
2547
+ """Generate a random string of a given length."""
2548
+ return "".join(random.choice(string.ascii_lowercase) for _ in range(length))
2549
+
2550
+ def __iter__(self):
2551
+ return self
2552
+
2553
+ def __next__(self):
2554
+ if self.batch_size == 0:
2555
+ return {"text": self.generate_random_string()}
2556
+ else:
2557
+ return {
2558
+ "text": [self.generate_random_string() for _ in range(self.batch_size)]
2559
+ }
2560
+
2561
+
2562
+ class DummyTensorDataLoader:
2563
+ """Minimal iterator that yields random token tensors (for LLM tests)."""
2564
+
2565
+ def __init__(self, batch_size=0, max_length=10, padding=False):
2566
+ if isinstance(batch_size, tuple):
2567
+ batch_size = torch.Size(batch_size).numel()
2568
+ self.batch_size = batch_size
2569
+ self.max_length = max_length
2570
+ self.padding = padding
2571
+
2572
+ def generate_random_tensor(self):
2573
+ """Generate a tensor of random int64 values."""
2574
+ length = random.randint(1, self.max_length)
2575
+ rt = torch.randint(1, 10000, (length,))
2576
+ return rt
2577
+
2578
+ def pad_tensor(self, tensor):
2579
+ """Pad a tensor to the maximum length."""
2580
+ padding_length = self.max_length - len(tensor)
2581
+ return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))
2582
+
2583
+ def __iter__(self):
2584
+ return self
2585
+
2586
+ def __next__(self):
2587
+ if self.batch_size == 0:
2588
+ tensor = self.generate_random_tensor()
2589
+ tokens = self.pad_tensor(tensor) if self.padding else tensor
2590
+ else:
2591
+ tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
2592
+ if self.padding:
2593
+ tensors = [self.pad_tensor(tensor) for tensor in tensors]
2594
+ tokens = torch.stack(tensors)
2595
+ else:
2596
+ tokens = tensors
2597
+ return {"tokens": tokens, "attention_mask": tokens != 0}
2598
+
2599
+
2600
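# Illustrative sketch (editorial addition): both loaders are infinite iterators yielding
# dictionaries, mimicking the text / token batches consumed by the LLM tests. The helper
# name is hypothetical.
def _example_dummy_dataloaders():
    str_batch = next(iter(DummyStrDataLoader(batch_size=2)))  # {"text": [str, str]}
    tok_batch = next(iter(DummyTensorDataLoader(batch_size=2, padding=True)))
    # With padding=True, "tokens" stacks to a (2, max_length) int64 tensor and
    # "attention_mask" is a boolean mask of the same shape.
    return str_batch, tok_batch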
+ class MockNestedResetEnv(EnvBase):
2601
+ """To test behaviour of envs with nested done states - where the root done prevails over others."""
2602
+
2603
+ def __init__(self, num_steps: int, done_at_root: bool) -> None:
2604
+ super().__init__(device="cpu")
2605
+ self._num_steps = num_steps
2606
+ self._counter = 0
2607
+ self.done_at_root = done_at_root
2608
+ self.done_spec = Composite(
2609
+ {
2610
+ ("agent_1", "done"): Binary(1, dtype=torch.bool),
2611
+ ("agent_2", "done"): Binary(1, dtype=torch.bool),
2612
+ }
2613
+ )
2614
+ if done_at_root:
2615
+ self.full_done_spec["done"] = Binary(1, dtype=torch.bool)
2616
+
2617
+ def _reset(self, tensordict: TensorDict) -> TensorDict:
2618
+ torchrl_logger.info(f"Reset after {self._counter} steps!")
2619
+ if tensordict is not None:
2620
+ torchrl_logger.info(f"tensordict at reset {tensordict.to_dict()}")
2621
+ self._counter = 0
2622
+ result = TensorDict(
2623
+ {
2624
+ ("agent_1", "done"): torch.tensor([False], dtype=torch.bool),
2625
+ ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
2626
+ },
2627
+ )
2628
+ if self.done_at_root:
2629
+ result["done"] = torch.tensor([False], dtype=torch.bool)
2630
+ return result
2631
+
2632
+ def _step(self, tensordict: TensorDict) -> TensorDict:
2633
+ self._counter += 1
2634
+ done = torch.tensor([self._counter >= self._num_steps], dtype=torch.bool)
2635
+ if self.done_at_root:
2636
+ return TensorDict(
2637
+ {
2638
+ "done": done,
2639
+ ("agent_1", "done"): torch.tensor([True], dtype=torch.bool),
2640
+ ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
2641
+ },
2642
+ )
2643
+ else:
2644
+ return TensorDict(
2645
+ {
2646
+ ("agent_1", "done"): done,
2647
+ ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
2648
+ },
2649
+ )
2650
+
2651
+ def _set_seed(self, seed: int | None) -> None:
2652
+ pass
2653
+
2654
+
2655
+ class EnvThatErrorsBecauseOfStack(EnvBase):
2656
+ """Environment crafted to trigger stacking errors with certain batch shapes."""
2657
+
2658
+ def __init__(self, target: int = 5, batch_size: int | None = None):
2659
+ super().__init__(device="cpu", batch_size=batch_size)
2660
+ self.target = target
2661
+ self.observation_spec = Bounded(
2662
+ low=0, high=self.target, shape=(1,), dtype=torch.int64
2663
+ )
2664
+ self.action_spec = Categorical(n=2, shape=(1,), dtype=torch.int64)
2665
+ self.reward_spec = Unbounded(shape=(1,), dtype=torch.float32)
2666
+ self.done_spec = Categorical(n=2, shape=(1,), dtype=torch.bool)
2667
+
2668
+ def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
2669
+ if tensordict is None:
2670
+ tensordict = TensorDict(batch_size=self.batch_size, device=self.device)
2671
+
2672
+ observation = torch.zeros(
2673
+ self.batch_size, dtype=self.observation_spec.dtype, device=self.device
2674
+ )
2675
+ reward = torch.zeros(
2676
+ self.batch_size + torch.Size([1]),
2677
+ dtype=self.reward_spec.dtype,
2678
+ device=self.device,
2679
+ )
2680
+ done = torch.zeros(
2681
+ self.batch_size + torch.Size([1]), dtype=torch.bool, device=self.device
2682
+ )
2683
+ terminated = torch.zeros_like(done)
2684
+ action = torch.zeros(
2685
+ self.batch_size + torch.Size([1]), dtype=torch.int64, device=self.device
2686
+ )
2687
+
2688
+ tensordict.set(self.observation_keys[0], observation)
2689
+ tensordict.set(self.reward_key, reward)
2690
+ tensordict.set(self.done_keys[0], done)
2691
+ tensordict.set("terminated", terminated)
2692
+ tensordict.set(self.action_keys[0], action)
2693
+
2694
+ return tensordict
2695
+
2696
+ def _step(self, tensordict: TensorDict) -> TensorDict:
2697
+ obs = tensordict.get(
2698
+ self.observation_keys[0]
2699
+ )  # the counter value (or counter values when the env is batched)
2700
+ action = tensordict.get(self.action_keys[0]).squeeze(-1)
2701
+
2702
+ new_obs = obs + (action == 1).to(obs.dtype)
2703
+ new_obs = new_obs.clamp_max(self.target)
2704
+ reward = (new_obs == self.target).to(self.reward_spec.dtype).unsqueeze(-1)
2705
+ done = (new_obs == self.target).to(torch.bool).unsqueeze(-1)
2706
+ terminated = done.clone()
2707
+ return TensorDict(
2708
+ {
2709
+ self.observation_keys[0]: new_obs,
2710
+ self.reward_keys[0]: reward,
2711
+ self.done_keys[0]: done,
2712
+ "terminated": terminated,
2713
+ self.action_keys[0]: action.unsqueeze(-1),
2714
+ },
2715
+ batch_size=self.batch_size,
2716
+ device=self.device,
2717
+ )
2718
+
2719
+ def _set_seed(self, seed: int | None) -> None:
2720
+ ...