torchrl 0.10.0__cp313-cp313-win_amd64.whl → 0.11.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (232)
  1. benchmarks/ecosystem/gym_env_throughput.py +1 -1
  2. benchmarks/storage/benchmark_sample_latency_over_rpc.py +1 -1
  3. benchmarks/test_collectors_benchmark.py +1 -1
  4. benchmarks/test_llm.py +5 -0
  5. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  6. benchmarks/test_objectives_benchmarks.py +22 -0
  7. sota-implementations/a2c/a2c_atari.py +6 -8
  8. sota-implementations/a2c/a2c_mujoco.py +6 -8
  9. sota-implementations/cql/cql_offline.py +4 -8
  10. sota-implementations/cql/cql_online.py +4 -8
  11. sota-implementations/cql/discrete_cql_offline.py +4 -5
  12. sota-implementations/cql/discrete_cql_online.py +4 -8
  13. sota-implementations/crossq/crossq.py +6 -8
  14. sota-implementations/ddpg/ddpg.py +9 -16
  15. sota-implementations/decision_transformer/dt.py +4 -8
  16. sota-implementations/decision_transformer/online_dt.py +4 -8
  17. sota-implementations/dqn/dqn_atari.py +2 -8
  18. sota-implementations/dqn/dqn_cartpole.py +2 -8
  19. sota-implementations/dreamer/README.md +128 -6
  20. sota-implementations/dreamer/dreamer.py +408 -159
  21. sota-implementations/dreamer/dreamer_utils.py +456 -76
  22. sota-implementations/expert-iteration/ei_utils.py +24 -28
  23. sota-implementations/expert-iteration/expert-iteration-async.py +29 -26
  24. sota-implementations/expert-iteration/expert-iteration-sync.py +25 -28
  25. sota-implementations/expert-iteration/requirements_gsm8k.txt +2 -2
  26. sota-implementations/expert-iteration/requirements_ifeval.txt +2 -2
  27. sota-implementations/gail/gail.py +6 -10
  28. sota-implementations/grpo/grpo-async.py +26 -22
  29. sota-implementations/grpo/grpo-sync.py +21 -20
  30. sota-implementations/grpo/grpo_utils.py +40 -10
  31. sota-implementations/grpo/requirements_gsm8k.txt +11 -13
  32. sota-implementations/grpo/requirements_ifeval.txt +16 -16
  33. sota-implementations/impala/impala_multi_node_ray.py +2 -2
  34. sota-implementations/impala/impala_multi_node_submitit.py +2 -2
  35. sota-implementations/impala/impala_single_node.py +2 -2
  36. sota-implementations/iql/discrete_iql.py +4 -8
  37. sota-implementations/iql/iql_offline.py +6 -10
  38. sota-implementations/iql/iql_online.py +4 -8
  39. sota-implementations/multiagent/mappo_ippo.py +1 -1
  40. sota-implementations/multiagent/sac.py +1 -1
  41. sota-implementations/ppo/ppo_atari.py +6 -10
  42. sota-implementations/ppo/ppo_mujoco.py +6 -10
  43. sota-implementations/sac/sac-async.py +11 -8
  44. sota-implementations/sac/sac.py +6 -8
  45. sota-implementations/sac_trainer/train.py +16 -0
  46. sota-implementations/td3/td3.py +6 -9
  47. sota-implementations/td3_bc/td3_bc.py +6 -8
  48. torchrl/__init__.py +54 -28
  49. torchrl/_extension.py +12 -2
  50. torchrl/_torchrl.cp313-win_amd64.pyd +0 -0
  51. torchrl/_utils.py +475 -6
  52. torchrl/collectors/__init__.py +27 -14
  53. torchrl/collectors/_base.py +1058 -0
  54. torchrl/collectors/_constants.py +88 -0
  55. torchrl/collectors/_multi_async.py +324 -0
  56. torchrl/collectors/_multi_base.py +1805 -0
  57. torchrl/collectors/_multi_sync.py +464 -0
  58. torchrl/collectors/_runner.py +581 -0
  59. torchrl/collectors/_single.py +2009 -0
  60. torchrl/collectors/_single_async.py +259 -0
  61. torchrl/collectors/collectors.py +54 -3876
  62. torchrl/collectors/distributed/__init__.py +10 -3
  63. torchrl/collectors/distributed/default_configs.py +99 -0
  64. torchrl/collectors/distributed/generic.py +396 -162
  65. torchrl/collectors/distributed/ray.py +377 -82
  66. torchrl/collectors/distributed/rpc.py +201 -106
  67. torchrl/collectors/distributed/sync.py +185 -96
  68. torchrl/collectors/distributed/utils.py +4 -5
  69. torchrl/collectors/llm/base.py +44 -29
  70. torchrl/collectors/llm/ray_collector.py +1 -1
  71. torchrl/collectors/llm/weight_update/vllm.py +25 -17
  72. torchrl/collectors/llm/weight_update/vllm_v2.py +20 -20
  73. torchrl/collectors/utils.py +177 -3
  74. torchrl/collectors/weight_update.py +25 -9
  75. torchrl/data/__init__.py +0 -24
  76. torchrl/data/datasets/d4rl.py +14 -6
  77. torchrl/data/datasets/minari_data.py +110 -8
  78. torchrl/data/datasets/openx.py +1 -1
  79. torchrl/data/map/tree.py +2 -2
  80. torchrl/data/postprocs/postprocs.py +2 -2
  81. torchrl/data/replay_buffers/checkpointers.py +54 -37
  82. torchrl/data/replay_buffers/ray_buffer.py +16 -6
  83. torchrl/data/replay_buffers/replay_buffers.py +248 -39
  84. torchrl/data/replay_buffers/samplers.py +46 -40
  85. torchrl/data/replay_buffers/storages.py +488 -18
  86. torchrl/data/replay_buffers/utils.py +4 -4
  87. torchrl/data/replay_buffers/writers.py +22 -0
  88. torchrl/data/tensor_specs.py +240 -178
  89. torchrl/data/utils.py +6 -9
  90. torchrl/envs/__init__.py +4 -4
  91. torchrl/envs/async_envs.py +83 -8
  92. torchrl/envs/batched_envs.py +348 -77
  93. torchrl/envs/common.py +165 -10
  94. torchrl/envs/custom/chess.py +5 -6
  95. torchrl/envs/custom/llm.py +3 -3
  96. torchrl/envs/env_creator.py +7 -2
  97. torchrl/envs/libs/__init__.py +3 -0
  98. torchrl/envs/libs/brax.py +76 -4
  99. torchrl/envs/libs/dm_control.py +64 -5
  100. torchrl/envs/libs/envpool.py +40 -0
  101. torchrl/envs/libs/gym.py +174 -29
  102. torchrl/envs/libs/procgen.py +351 -0
  103. torchrl/envs/libs/vmas.py +10 -40
  104. torchrl/envs/llm/datasets/gsm8k.py +2 -2
  105. torchrl/envs/llm/envs.py +6 -4
  106. torchrl/envs/llm/transforms/__init__.py +25 -6
  107. torchrl/envs/llm/transforms/browser.py +4 -4
  108. torchrl/envs/llm/transforms/dataloading.py +3 -7
  109. torchrl/envs/llm/transforms/kl.py +3 -2
  110. torchrl/envs/llm/transforms/tools.py +1382 -251
  111. torchrl/envs/model_based/common.py +8 -5
  112. torchrl/envs/model_based/dreamer.py +28 -2
  113. torchrl/envs/transforms/__init__.py +5 -4
  114. torchrl/envs/transforms/module.py +341 -0
  115. torchrl/envs/{llm/transforms → transforms}/ray_service.py +51 -32
  116. torchrl/envs/transforms/transforms.py +261 -38
  117. torchrl/envs/transforms/vecnorm.py +27 -0
  118. torchrl/envs/transforms/vip.py +6 -6
  119. torchrl/envs/utils.py +8 -32
  120. torchrl/modules/__init__.py +20 -0
  121. torchrl/modules/distributions/continuous.py +74 -11
  122. torchrl/modules/distributions/discrete.py +3 -3
  123. torchrl/modules/distributions/truncated_normal.py +6 -4
  124. torchrl/modules/distributions/utils.py +7 -3
  125. torchrl/modules/llm/__init__.py +29 -15
  126. torchrl/modules/llm/backends/__init__.py +42 -17
  127. torchrl/modules/llm/backends/vllm/__init__.py +48 -16
  128. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  129. torchrl/modules/llm/backends/vllm/vllm_async.py +236 -153
  130. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  131. torchrl/modules/llm/backends/vllm/vllm_sync.py +26 -6
  132. torchrl/modules/llm/backends/vllm/vllm_utils.py +23 -4
  133. torchrl/modules/llm/policies/__init__.py +6 -1
  134. torchrl/modules/llm/policies/common.py +447 -9
  135. torchrl/modules/llm/policies/transformers_wrapper.py +33 -21
  136. torchrl/modules/llm/policies/vllm_wrapper.py +117 -48
  137. torchrl/modules/mcts/__init__.py +21 -0
  138. torchrl/modules/mcts/scores.py +579 -0
  139. torchrl/modules/models/decision_transformer.py +1 -1
  140. torchrl/modules/models/exploration.py +5 -5
  141. torchrl/modules/models/model_based.py +256 -36
  142. torchrl/modules/models/multiagent.py +28 -12
  143. torchrl/modules/planners/cem.py +5 -2
  144. torchrl/modules/planners/common.py +4 -1
  145. torchrl/modules/planners/mppi.py +5 -1
  146. torchrl/modules/tensordict_module/__init__.py +2 -0
  147. torchrl/modules/tensordict_module/actors.py +10 -10
  148. torchrl/modules/tensordict_module/common.py +2 -2
  149. torchrl/modules/tensordict_module/exploration.py +80 -11
  150. torchrl/modules/tensordict_module/probabilistic.py +2 -2
  151. torchrl/modules/tensordict_module/rnn.py +21 -15
  152. torchrl/objectives/a2c.py +27 -26
  153. torchrl/objectives/common.py +70 -12
  154. torchrl/objectives/cql.py +27 -8
  155. torchrl/objectives/crossq.py +21 -10
  156. torchrl/objectives/ddpg.py +23 -4
  157. torchrl/objectives/decision_transformer.py +7 -8
  158. torchrl/objectives/deprecated.py +16 -4
  159. torchrl/objectives/dqn.py +46 -10
  160. torchrl/objectives/dreamer.py +52 -21
  161. torchrl/objectives/gail.py +1 -1
  162. torchrl/objectives/iql.py +23 -12
  163. torchrl/objectives/llm/__init__.py +22 -2
  164. torchrl/objectives/llm/grpo.py +381 -65
  165. torchrl/objectives/llm/sft.py +4 -4
  166. torchrl/objectives/multiagent/qmixer.py +9 -2
  167. torchrl/objectives/ppo.py +39 -31
  168. torchrl/objectives/redq.py +18 -5
  169. torchrl/objectives/reinforce.py +11 -3
  170. torchrl/objectives/sac.py +111 -160
  171. torchrl/objectives/td3.py +23 -6
  172. torchrl/objectives/td3_bc.py +23 -6
  173. torchrl/objectives/utils.py +44 -4
  174. torchrl/objectives/value/advantages.py +16 -2
  175. torchrl/objectives/value/functional.py +41 -12
  176. torchrl/record/loggers/trackio.py +163 -0
  177. torchrl/record/loggers/utils.py +9 -1
  178. torchrl/record/loggers/wandb.py +13 -3
  179. torchrl/record/recorder.py +5 -3
  180. torchrl/services/__init__.py +79 -0
  181. torchrl/services/base.py +109 -0
  182. torchrl/services/ray_service.py +453 -0
  183. torchrl/testing/__init__.py +107 -0
  184. torchrl/testing/assertions.py +179 -0
  185. torchrl/testing/dist_utils.py +122 -0
  186. torchrl/testing/env_creators.py +227 -0
  187. torchrl/testing/env_helper.py +35 -0
  188. torchrl/testing/gym_helpers.py +156 -0
  189. torchrl/testing/llm_mocks.py +119 -0
  190. torchrl/testing/mocking_classes.py +2720 -0
  191. torchrl/testing/modules.py +295 -0
  192. torchrl/testing/mp_helpers.py +15 -0
  193. torchrl/testing/ray_helpers.py +293 -0
  194. torchrl/testing/utils.py +190 -0
  195. torchrl/trainers/__init__.py +6 -4
  196. torchrl/trainers/algorithms/__init__.py +2 -1
  197. torchrl/trainers/algorithms/configs/__init__.py +138 -13
  198. torchrl/trainers/algorithms/configs/collectors.py +44 -13
  199. torchrl/trainers/algorithms/configs/data.py +3 -0
  200. torchrl/trainers/algorithms/configs/modules.py +228 -4
  201. torchrl/trainers/algorithms/configs/objectives.py +108 -6
  202. torchrl/trainers/algorithms/configs/trainers.py +211 -7
  203. torchrl/trainers/algorithms/configs/transforms.py +33 -2
  204. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  205. torchrl/trainers/algorithms/ppo.py +139 -26
  206. torchrl/trainers/algorithms/sac.py +308 -0
  207. torchrl/trainers/helpers/collectors.py +15 -15
  208. torchrl/trainers/helpers/trainers.py +5 -5
  209. torchrl/trainers/trainers.py +431 -100
  210. torchrl/weight_update/__init__.py +33 -0
  211. torchrl/weight_update/_distributed.py +749 -0
  212. torchrl/weight_update/_mp.py +624 -0
  213. torchrl/weight_update/_noupdate.py +102 -0
  214. torchrl/weight_update/_ray.py +1032 -0
  215. torchrl/weight_update/_rpc.py +284 -0
  216. torchrl/weight_update/_shared.py +891 -0
  217. torchrl/weight_update/llm/__init__.py +32 -0
  218. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  219. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  220. torchrl/weight_update/utils.py +73 -0
  221. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  222. {torchrl-0.10.0.dist-info → torchrl-0.11.0.dist-info}/METADATA +23 -6
  223. torchrl-0.11.0.dist-info/RECORD +394 -0
  224. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  225. {torchrl-0.10.0.dist-info → torchrl-0.11.0.dist-info}/top_level.txt +2 -0
  226. torchrl/data/rlhf.py +0 -41
  227. torchrl/envs/transforms/llm.py +0 -329
  228. torchrl/envs/transforms/rlhf.py +0 -27
  229. torchrl/modules/models/rlhf.py +0 -17
  230. torchrl-0.10.0.dist-info/RECORD +0 -351
  231. {torchrl-0.10.0.dist-info → torchrl-0.11.0.dist-info}/LICENSE +0 -0
  232. {torchrl-0.10.0.dist-info → torchrl-0.11.0.dist-info}/WHEEL +0 -0
@@ -27,7 +27,7 @@ from torchrl.collectors import (
27
27
  )
28
28
  from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
29
29
  from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
30
- from torchrl.envs.utils import RandomPolicy
30
+ from torchrl.modules import RandomPolicy
31
31
 
32
32
  if __name__ == "__main__":
33
33
  avail_devices = ("cpu",)
@@ -144,7 +144,7 @@ if __name__ == "__main__":
144
144
  rank = args.rank
145
145
  storage_type = args.storage
146
146
 
147
- torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}")
147
+ torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")
148
148
 
149
149
  os.environ["MASTER_ADDR"] = "localhost"
150
150
  os.environ["MASTER_PORT"] = "29500"
@@ -18,7 +18,7 @@ from torchrl.data import LazyTensorStorage, ReplayBuffer
18
18
  from torchrl.data.utils import CloudpickleWrapper
19
19
  from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
20
20
  from torchrl.envs.libs.dm_control import DMControlEnv
21
- from torchrl.envs.utils import RandomPolicy
21
+ from torchrl.modules import RandomPolicy
22
22
 
23
23
 
24
24
  def single_collector_setup():
benchmarks/test_llm.py CHANGED
@@ -16,6 +16,11 @@ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrappe
16
16
 
17
17
  _has_transformers = importlib.import_module("transformers") is not None
18
18
 
19
+ # Skip all these tests if gpu is not available
20
+ pytestmark = pytest.mark.skipif(
21
+ not torch.cuda.is_available(), reason="GPU not available"
22
+ )
23
+
19
24
 
20
25
  @pytest.fixture(scope="module")
21
26
  def transformers_wrapper():
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import gc
8
+ import time
9
+
10
+ import pytest
11
+ from tensordict import set_capture_non_tensor_stack
12
+ from torchrl.envs import ParallelEnv, SerialEnv
13
+ from torchrl.testing.mocking_classes import EnvWithMetadata
14
+
15
+
16
+ def _rollout(env, n_steps: int, break_when_any_done: bool) -> None:
17
+ env.rollout(n_steps, break_when_any_done=break_when_any_done)
18
+
19
+
20
+ @pytest.mark.parametrize("break_when_any_done", [True, False])
21
+ @pytest.mark.parametrize(
22
+ "kind,use_buffers",
23
+ [
24
+ pytest.param("single", None, id="single"),
25
+ pytest.param("serial", False, id="serial-no-buffers"),
26
+ pytest.param("serial", True, id="serial-buffers"),
27
+ pytest.param("parallel", False, id="parallel-no-buffers"),
28
+ pytest.param("parallel", True, id="parallel-buffers"),
29
+ ],
30
+ )
31
+ @pytest.mark.parametrize("n_steps", [1000])
32
+ def test_non_tensor_env_rollout_speed(
33
+ benchmark,
34
+ break_when_any_done: bool,
35
+ kind: str,
36
+ use_buffers: bool | None,
37
+ n_steps: int,
38
+ ):
39
+ """Benchmarks a single rollout, after a warmup rollout, for non-tensor stacking envs.
40
+
41
+ Mirrors `test/test_envs.py::TestNonTensorEnv`'s option matrix (single/serial/parallel,
42
+ break_when_any_done, use_buffers).
43
+ """
44
+ with set_capture_non_tensor_stack(False):
45
+ if kind == "single":
46
+ env = EnvWithMetadata()
47
+ elif kind == "serial":
48
+ env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers)
49
+ elif kind == "parallel":
50
+ env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
51
+ else:
52
+ raise RuntimeError(f"Unknown kind={kind}")
53
+
54
+ env.set_seed(0)
55
+ env.reset()
56
+
57
+ try:
58
+ # Warmup run (not timed)
59
+ _rollout(env, n_steps=n_steps, break_when_any_done=break_when_any_done)
60
+
61
+ # Timed run(s)
62
+ benchmark(
63
+ _rollout, env, n_steps=n_steps, break_when_any_done=break_when_any_done
64
+ )
65
+ finally:
66
+ env.close(raise_if_closed=False)
67
+ del env
68
+ # Give multiprocessing envs a brief chance to terminate cleanly.
69
+ time.sleep(0.05)
70
+ gc.collect()
@@ -172,6 +172,8 @@ def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3):
172
172
  def test_dqn_speed(
173
173
  benchmark, backward, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128
174
174
  ):
175
+ if compile == "reduce-overhead" and backward is not None:
176
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
175
177
  if compile:
176
178
  torch._dynamo.reset_code_caches()
177
179
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -227,6 +229,8 @@ def test_dqn_speed(
227
229
  def test_ddpg_speed(
228
230
  benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
229
231
  ):
232
+ if compile == "reduce-overhead" and backward is not None:
233
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
230
234
  if compile:
231
235
  torch._dynamo.reset_code_caches()
232
236
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -303,6 +307,8 @@ def test_ddpg_speed(
303
307
  def test_sac_speed(
304
308
  benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
305
309
  ):
310
+ if compile == "reduce-overhead" and backward is not None:
311
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
306
312
  if compile:
307
313
  torch._dynamo.reset_code_caches()
308
314
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -393,6 +399,8 @@ def test_sac_speed(
393
399
  def test_redq_speed(
394
400
  benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
395
401
  ):
402
+ if compile == "reduce-overhead" and backward is not None:
403
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
396
404
  if compile:
397
405
  torch._dynamo.reset_code_caches()
398
406
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -484,6 +492,8 @@ def test_redq_speed(
484
492
  def test_redq_deprec_speed(
485
493
  benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
486
494
  ):
495
+ if compile == "reduce-overhead" and backward is not None:
496
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
487
497
  if compile:
488
498
  torch._dynamo.reset_code_caches()
489
499
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -573,6 +583,8 @@ def test_redq_deprec_speed(
573
583
  def test_td3_speed(
574
584
  benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
575
585
  ):
586
+ if compile == "reduce-overhead" and backward is not None:
587
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
576
588
  if compile:
577
589
  torch._dynamo.reset_code_caches()
578
590
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -667,6 +679,8 @@ def test_td3_speed(
667
679
  def test_cql_speed(
668
680
  benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
669
681
  ):
682
+ if compile == "reduce-overhead" and backward is not None:
683
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
670
684
  if compile:
671
685
  torch._dynamo.reset_code_caches()
672
686
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -763,6 +777,8 @@ def test_a2c_speed(
763
777
  batch=128,
764
778
  T=10,
765
779
  ):
780
+ if compile == "reduce-overhead" and backward is not None:
781
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
766
782
  if compile:
767
783
  torch._dynamo.reset_code_caches()
768
784
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -866,6 +882,8 @@ def test_ppo_speed(
866
882
  batch=128,
867
883
  T=10,
868
884
  ):
885
+ if compile == "reduce-overhead" and backward is not None:
886
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
869
887
  if compile:
870
888
  torch._dynamo.reset_code_caches()
871
889
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -969,6 +987,8 @@ def test_reinforce_speed(
969
987
  batch=128,
970
988
  T=10,
971
989
  ):
990
+ if compile == "reduce-overhead" and backward is not None:
991
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
972
992
  if compile:
973
993
  torch._dynamo.reset_code_caches()
974
994
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -1072,6 +1092,8 @@ def test_iql_speed(
1072
1092
  batch=128,
1073
1093
  T=10,
1074
1094
  ):
1095
+ if compile == "reduce-overhead" and backward is not None:
1096
+ pytest.skip("reduce-overhead with backward causes segfaults in CI")
1075
1097
  if compile:
1076
1098
  torch._dynamo.reset_code_caches()
1077
1099
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -22,7 +22,7 @@ def main(cfg: DictConfig): # noqa: F821
22
22
  from tensordict import from_module
23
23
  from tensordict.nn import CudaGraphModule
24
24
 
25
- from torchrl._utils import timeit
25
+ from torchrl._utils import get_available_device, timeit
26
26
  from torchrl.collectors import SyncDataCollector
27
27
  from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
28
28
  from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
@@ -33,11 +33,9 @@ def main(cfg: DictConfig): # noqa: F821
33
33
  from torchrl.record.loggers import generate_exp_name, get_logger
34
34
  from utils_atari import eval_model, make_parallel_env, make_ppo_models
35
35
 
36
- device = cfg.loss.device
37
- if not device:
38
- device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
39
- else:
40
- device = torch.device(device)
36
+ device = (
37
+ torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
38
+ )
41
39
 
42
40
  # Correct for frame_skip
43
41
  frame_skip = 4
@@ -76,8 +74,8 @@ def main(cfg: DictConfig): # noqa: F821
76
74
  actor_network=actor,
77
75
  critic_network=critic,
78
76
  loss_critic_type=cfg.loss.loss_critic_type,
79
- entropy_coef=cfg.loss.entropy_coef,
80
- critic_coef=cfg.loss.critic_coef,
77
+ entropy_coeff=cfg.loss.entropy_coeff,
78
+ critic_coeff=cfg.loss.critic_coeff,
81
79
  )
82
80
 
83
81
  # use end-of-life as done key
@@ -23,7 +23,7 @@ def main(cfg: DictConfig): # noqa: F821
23
23
  from tensordict import from_module
24
24
  from tensordict.nn import CudaGraphModule
25
25
 
26
- from torchrl._utils import timeit
26
+ from torchrl._utils import get_available_device, timeit
27
27
  from torchrl.collectors import SyncDataCollector
28
28
  from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
29
29
  from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
@@ -36,11 +36,9 @@ def main(cfg: DictConfig): # noqa: F821
36
36
 
37
37
  # Define paper hyperparameters
38
38
 
39
- device = cfg.loss.device
40
- if not device:
41
- device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
42
- else:
43
- device = torch.device(device)
39
+ device = (
40
+ torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
41
+ )
44
42
 
45
43
  num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
46
44
  total_network_updates = (
@@ -77,8 +75,8 @@ def main(cfg: DictConfig): # noqa: F821
77
75
  actor_network=actor,
78
76
  critic_network=critic,
79
77
  loss_critic_type=cfg.loss.loss_critic_type,
80
- entropy_coef=cfg.loss.entropy_coef,
81
- critic_coef=cfg.loss.critic_coef,
78
+ entropy_coeff=cfg.loss.entropy_coeff,
79
+ critic_coeff=cfg.loss.critic_coeff,
82
80
  )
83
81
 
84
82
  # Create optimizers
@@ -18,7 +18,7 @@ import numpy as np
18
18
  import torch
19
19
  import tqdm
20
20
  from tensordict.nn import CudaGraphModule
21
- from torchrl._utils import timeit
21
+ from torchrl._utils import get_available_device, timeit
22
22
  from torchrl.envs.utils import ExplorationType, set_exploration_type
23
23
  from torchrl.objectives import group_optimizers
24
24
  from torchrl.record.loggers import generate_exp_name, get_logger
@@ -55,13 +55,9 @@ def main(cfg: DictConfig): # noqa: F821
55
55
  # Set seeds
56
56
  torch.manual_seed(cfg.env.seed)
57
57
  np.random.seed(cfg.env.seed)
58
- device = cfg.optim.device
59
- if device in ("", None):
60
- if torch.cuda.is_available():
61
- device = "cuda:0"
62
- else:
63
- device = "cpu"
64
- device = torch.device(device)
58
+ device = (
59
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
60
+ )
65
61
 
66
62
  # Create replay buffer
67
63
  replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
@@ -21,7 +21,7 @@ import torch
21
21
  import tqdm
22
22
  from tensordict import TensorDict
23
23
  from tensordict.nn import CudaGraphModule
24
- from torchrl._utils import timeit
24
+ from torchrl._utils import get_available_device, timeit
25
25
  from torchrl.envs.utils import ExplorationType, set_exploration_type
26
26
  from torchrl.objectives import group_optimizers
27
27
  from torchrl.record.loggers import generate_exp_name, get_logger
@@ -60,13 +60,9 @@ def main(cfg: DictConfig): # noqa: F821
60
60
  # Set seeds
61
61
  torch.manual_seed(cfg.env.seed)
62
62
  np.random.seed(cfg.env.seed)
63
- device = cfg.optim.device
64
- if device in ("", None):
65
- if torch.cuda.is_available():
66
- device = "cuda:0"
67
- else:
68
- device = "cpu"
69
- device = torch.device(device)
63
+ device = (
64
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
65
+ )
70
66
 
71
67
  # Create env
72
68
  train_env, eval_env = make_environment(
@@ -18,7 +18,7 @@ import numpy as np
18
18
  import torch
19
19
  import tqdm
20
20
  from tensordict.nn import CudaGraphModule
21
- from torchrl._utils import timeit
21
+ from torchrl._utils import get_available_device, timeit
22
22
  from torchrl.envs.utils import ExplorationType, set_exploration_type
23
23
  from torchrl.record.loggers import generate_exp_name, get_logger
24
24
  from utils import (
@@ -36,10 +36,9 @@ torch.set_float32_matmul_precision("high")
36
36
 
37
37
  @hydra.main(version_base="1.1", config_path="", config_name="discrete_offline_config")
38
38
  def main(cfg): # noqa: F821
39
- device = cfg.optim.device
40
- if device in ("", None):
41
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
42
- device = torch.device(device)
39
+ device = (
40
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
41
+ )
43
42
 
44
43
  # Create logger
45
44
  exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
@@ -20,7 +20,7 @@ import torch
20
20
  import torch.cuda
21
21
  import tqdm
22
22
  from tensordict.nn import CudaGraphModule
23
- from torchrl._utils import timeit
23
+ from torchrl._utils import get_available_device, timeit
24
24
  from torchrl.envs.utils import ExplorationType, set_exploration_type
25
25
  from torchrl.record.loggers import generate_exp_name, get_logger
26
26
  from utils import (
@@ -38,13 +38,9 @@ torch.set_float32_matmul_precision("high")
38
38
 
39
39
  @hydra.main(version_base="1.1", config_path="", config_name="discrete_online_config")
40
40
  def main(cfg: DictConfig): # noqa: F821
41
- device = cfg.optim.device
42
- if device in ("", None):
43
- if torch.cuda.is_available():
44
- device = "cuda:0"
45
- else:
46
- device = "cpu"
47
- device = torch.device(device)
41
+ device = (
42
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
43
+ )
48
44
 
49
45
  # Create logger
50
46
  exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
@@ -21,7 +21,7 @@ import torch.cuda
21
21
  import tqdm
22
22
  from tensordict import TensorDict
23
23
  from tensordict.nn import CudaGraphModule
24
- from torchrl._utils import timeit
24
+ from torchrl._utils import get_available_device, timeit
25
25
  from torchrl.envs.utils import ExplorationType, set_exploration_type
26
26
  from torchrl.objectives import group_optimizers
27
27
  from torchrl.record.loggers import generate_exp_name, get_logger
@@ -40,13 +40,11 @@ torch.set_float32_matmul_precision("high")
40
40
 
41
41
  @hydra.main(version_base="1.1", config_path=".", config_name="config")
42
42
  def main(cfg: DictConfig): # noqa: F821
43
- device = cfg.network.device
44
- if device in ("", None):
45
- if torch.cuda.is_available():
46
- device = torch.device("cuda:0")
47
- else:
48
- device = torch.device("cpu")
49
- device = torch.device(device)
43
+ device = (
44
+ torch.device(cfg.network.device)
45
+ if cfg.network.device
46
+ else get_available_device()
47
+ )
50
48
 
51
49
  # Create logger
52
50
  exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
@@ -21,7 +21,7 @@ import torch.cuda
21
21
  import tqdm
22
22
  from tensordict import TensorDict
23
23
  from tensordict.nn import CudaGraphModule
24
- from torchrl._utils import timeit
24
+ from torchrl._utils import get_available_device, timeit
25
25
  from torchrl.envs.utils import ExplorationType, set_exploration_type
26
26
  from torchrl.objectives import group_optimizers
27
27
  from torchrl.record.loggers import generate_exp_name, get_logger
@@ -39,21 +39,14 @@ from utils import (
39
39
 
40
40
  @hydra.main(version_base="1.1", config_path="", config_name="config")
41
41
  def main(cfg: DictConfig): # noqa: F821
42
- device = cfg.optim.device
43
- if device in ("", None):
44
- if torch.cuda.is_available():
45
- device = "cuda:0"
46
- else:
47
- device = "cpu"
48
- device = torch.device(device)
49
-
50
- collector_device = cfg.collector.device
51
- if collector_device in ("", None):
52
- if torch.cuda.is_available():
53
- collector_device = "cuda:0"
54
- else:
55
- collector_device = "cpu"
56
- collector_device = torch.device(collector_device)
42
+ device = (
43
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
44
+ )
45
+ collector_device = (
46
+ torch.device(cfg.collector.device)
47
+ if cfg.collector.device
48
+ else get_available_device()
49
+ )
57
50
 
58
51
  # Create logger
59
52
  exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
@@ -17,7 +17,7 @@ import torch
17
17
  import tqdm
18
18
  from tensordict import TensorDict
19
19
  from tensordict.nn import CudaGraphModule
20
- from torchrl._utils import logger as torchrl_logger, timeit
20
+ from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
21
21
  from torchrl.envs.libs.gym import set_gym_backend
22
22
  from torchrl.envs.utils import ExplorationType, set_exploration_type
23
23
  from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
@@ -38,13 +38,9 @@ from utils import (
38
38
  def main(cfg: DictConfig): # noqa: F821
39
39
  set_gym_backend(cfg.env.backend).set()
40
40
 
41
- model_device = cfg.optim.device
42
- if model_device in ("", None):
43
- if torch.cuda.is_available():
44
- model_device = "cuda:0"
45
- else:
46
- model_device = "cpu"
47
- model_device = torch.device(model_device)
41
+ model_device = (
42
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
43
+ )
48
44
 
49
45
  # Set seeds
50
46
  torch.manual_seed(cfg.env.seed)
@@ -15,7 +15,7 @@ import numpy as np
15
15
  import torch
16
16
  import tqdm
17
17
  from tensordict.nn import CudaGraphModule
18
- from torchrl._utils import logger as torchrl_logger, timeit
18
+ from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
19
19
  from torchrl.envs.libs.gym import set_gym_backend
20
20
  from torchrl.envs.utils import ExplorationType, set_exploration_type
21
21
  from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
@@ -36,13 +36,9 @@ from utils import (
36
36
  def main(cfg: DictConfig): # noqa: F821
37
37
  set_gym_backend(cfg.env.backend).set()
38
38
 
39
- model_device = cfg.optim.device
40
- if model_device in ("", None):
41
- if torch.cuda.is_available():
42
- model_device = "cuda:0"
43
- else:
44
- model_device = "cpu"
45
- model_device = torch.device(model_device)
39
+ model_device = (
40
+ torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
41
+ )
46
42
 
47
43
  # Set seeds
48
44
  torch.manual_seed(cfg.env.seed)
@@ -17,7 +17,7 @@ import torch.nn
17
17
  import torch.optim
18
18
  import tqdm
19
19
  from tensordict.nn import CudaGraphModule, TensorDictSequential
20
- from torchrl._utils import timeit
20
+ from torchrl._utils import get_available_device, timeit
21
21
  from torchrl.collectors import SyncDataCollector
22
22
  from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
23
23
  from torchrl.envs import ExplorationType, set_exploration_type
@@ -33,13 +33,7 @@ torch.set_float32_matmul_precision("high")
33
33
  @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
34
34
  def main(cfg: DictConfig): # noqa: F821
35
35
 
36
- device = cfg.device
37
- if device in ("", None):
38
- if torch.cuda.is_available():
39
- device = "cuda:0"
40
- else:
41
- device = "cpu"
42
- device = torch.device(device)
36
+ device = torch.device(cfg.device) if cfg.device else get_available_device()
43
37
 
44
38
  # Correct for frame_skip
45
39
  frame_skip = 4
@@ -12,7 +12,7 @@ import torch.nn
12
12
  import torch.optim
13
13
  import tqdm
14
14
  from tensordict.nn import CudaGraphModule, TensorDictSequential
15
- from torchrl._utils import timeit
15
+ from torchrl._utils import get_available_device, timeit
16
16
  from torchrl.collectors import SyncDataCollector
17
17
  from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
18
18
  from torchrl.envs import ExplorationType, set_exploration_type
@@ -28,13 +28,7 @@ torch.set_float32_matmul_precision("high")
28
28
  @hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
29
29
  def main(cfg: DictConfig): # noqa: F821
30
30
 
31
- device = cfg.device
32
- if device in ("", None):
33
- if torch.cuda.is_available():
34
- device = "cuda:0"
35
- else:
36
- device = "cpu"
37
- device = torch.device(device)
31
+ device = torch.device(cfg.device) if cfg.device else get_available_device()
38
32
 
39
33
  # Make the components
40
34
  model = make_dqn_model(cfg.env.env_name, device=device)