torchrl 0.10.1__cp313-cp313-win_amd64.whl → 0.11.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/ecosystem/gym_env_throughput.py +1 -1
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +1 -1
- benchmarks/test_collectors_benchmark.py +1 -1
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +22 -0
- sota-implementations/a2c/a2c_atari.py +6 -8
- sota-implementations/a2c/a2c_mujoco.py +6 -8
- sota-implementations/cql/cql_offline.py +4 -8
- sota-implementations/cql/cql_online.py +4 -8
- sota-implementations/cql/discrete_cql_offline.py +4 -5
- sota-implementations/cql/discrete_cql_online.py +4 -8
- sota-implementations/crossq/crossq.py +6 -8
- sota-implementations/ddpg/ddpg.py +9 -16
- sota-implementations/decision_transformer/dt.py +4 -8
- sota-implementations/decision_transformer/online_dt.py +4 -8
- sota-implementations/dqn/dqn_atari.py +2 -8
- sota-implementations/dqn/dqn_cartpole.py +2 -8
- sota-implementations/dreamer/README.md +128 -6
- sota-implementations/dreamer/dreamer.py +408 -159
- sota-implementations/dreamer/dreamer_utils.py +456 -76
- sota-implementations/expert-iteration/ei_utils.py +24 -28
- sota-implementations/expert-iteration/expert-iteration-async.py +29 -26
- sota-implementations/expert-iteration/expert-iteration-sync.py +25 -28
- sota-implementations/expert-iteration/requirements_gsm8k.txt +2 -2
- sota-implementations/expert-iteration/requirements_ifeval.txt +2 -2
- sota-implementations/gail/gail.py +6 -10
- sota-implementations/grpo/grpo-async.py +26 -22
- sota-implementations/grpo/grpo-sync.py +21 -20
- sota-implementations/grpo/grpo_utils.py +40 -10
- sota-implementations/grpo/requirements_gsm8k.txt +11 -13
- sota-implementations/grpo/requirements_ifeval.txt +16 -16
- sota-implementations/impala/impala_multi_node_ray.py +2 -2
- sota-implementations/impala/impala_multi_node_submitit.py +2 -2
- sota-implementations/impala/impala_single_node.py +2 -2
- sota-implementations/iql/discrete_iql.py +4 -8
- sota-implementations/iql/iql_offline.py +6 -10
- sota-implementations/iql/iql_online.py +4 -8
- sota-implementations/multiagent/mappo_ippo.py +1 -1
- sota-implementations/multiagent/sac.py +1 -1
- sota-implementations/ppo/ppo_atari.py +6 -10
- sota-implementations/ppo/ppo_mujoco.py +6 -10
- sota-implementations/sac/sac-async.py +11 -8
- sota-implementations/sac/sac.py +6 -8
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +6 -9
- sota-implementations/td3_bc/td3_bc.py +6 -8
- torchrl/__init__.py +52 -29
- torchrl/_extension.py +12 -2
- torchrl/_torchrl.cp313-win_amd64.pyd +0 -0
- torchrl/_utils.py +472 -5
- torchrl/collectors/__init__.py +27 -14
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +54 -3877
- torchrl/collectors/distributed/__init__.py +10 -3
- torchrl/collectors/distributed/default_configs.py +99 -0
- torchrl/collectors/distributed/generic.py +393 -159
- torchrl/collectors/distributed/ray.py +367 -78
- torchrl/collectors/distributed/rpc.py +198 -103
- torchrl/collectors/distributed/sync.py +182 -93
- torchrl/collectors/distributed/utils.py +4 -5
- torchrl/collectors/llm/base.py +44 -29
- torchrl/collectors/llm/ray_collector.py +1 -1
- torchrl/collectors/llm/weight_update/vllm.py +25 -17
- torchrl/collectors/llm/weight_update/vllm_v2.py +20 -20
- torchrl/collectors/utils.py +177 -3
- torchrl/collectors/weight_update.py +25 -9
- torchrl/data/__init__.py +0 -24
- torchrl/data/datasets/d4rl.py +14 -6
- torchrl/data/datasets/minari_data.py +110 -8
- torchrl/data/datasets/openx.py +1 -1
- torchrl/data/postprocs/postprocs.py +2 -2
- torchrl/data/replay_buffers/checkpointers.py +54 -37
- torchrl/data/replay_buffers/ray_buffer.py +16 -6
- torchrl/data/replay_buffers/replay_buffers.py +248 -39
- torchrl/data/replay_buffers/samplers.py +46 -40
- torchrl/data/replay_buffers/storages.py +488 -18
- torchrl/data/replay_buffers/utils.py +4 -4
- torchrl/data/replay_buffers/writers.py +22 -0
- torchrl/data/tensor_specs.py +49 -170
- torchrl/data/utils.py +6 -9
- torchrl/envs/__init__.py +4 -4
- torchrl/envs/async_envs.py +83 -8
- torchrl/envs/batched_envs.py +347 -76
- torchrl/envs/common.py +164 -9
- torchrl/envs/custom/chess.py +5 -6
- torchrl/envs/env_creator.py +7 -2
- torchrl/envs/libs/__init__.py +3 -0
- torchrl/envs/libs/brax.py +76 -4
- torchrl/envs/libs/dm_control.py +64 -5
- torchrl/envs/libs/envpool.py +40 -0
- torchrl/envs/libs/gym.py +174 -29
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/vmas.py +10 -40
- torchrl/envs/llm/datasets/gsm8k.py +2 -2
- torchrl/envs/llm/envs.py +3 -1
- torchrl/envs/llm/transforms/__init__.py +25 -6
- torchrl/envs/llm/transforms/browser.py +4 -4
- torchrl/envs/llm/transforms/dataloading.py +3 -7
- torchrl/envs/llm/transforms/kl.py +3 -2
- torchrl/envs/llm/transforms/tools.py +1382 -251
- torchrl/envs/model_based/common.py +8 -5
- torchrl/envs/model_based/dreamer.py +28 -2
- torchrl/envs/transforms/__init__.py +5 -4
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/{llm/transforms → transforms}/ray_service.py +51 -32
- torchrl/envs/transforms/transforms.py +259 -36
- torchrl/envs/transforms/vecnorm.py +27 -0
- torchrl/envs/transforms/vip.py +6 -6
- torchrl/envs/utils.py +8 -32
- torchrl/modules/__init__.py +20 -0
- torchrl/modules/distributions/continuous.py +74 -11
- torchrl/modules/distributions/discrete.py +2 -2
- torchrl/modules/distributions/truncated_normal.py +6 -4
- torchrl/modules/distributions/utils.py +7 -3
- torchrl/modules/llm/__init__.py +29 -15
- torchrl/modules/llm/backends/__init__.py +42 -17
- torchrl/modules/llm/backends/vllm/__init__.py +48 -16
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +235 -152
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +26 -6
- torchrl/modules/llm/backends/vllm/vllm_utils.py +23 -4
- torchrl/modules/llm/policies/__init__.py +6 -1
- torchrl/modules/llm/policies/common.py +447 -9
- torchrl/modules/llm/policies/transformers_wrapper.py +31 -20
- torchrl/modules/llm/policies/vllm_wrapper.py +116 -48
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/decision_transformer.py +1 -1
- torchrl/modules/models/exploration.py +5 -5
- torchrl/modules/models/model_based.py +256 -36
- torchrl/modules/models/multiagent.py +28 -12
- torchrl/modules/planners/cem.py +5 -2
- torchrl/modules/planners/common.py +4 -1
- torchrl/modules/planners/mppi.py +5 -1
- torchrl/modules/tensordict_module/__init__.py +2 -0
- torchrl/modules/tensordict_module/actors.py +10 -10
- torchrl/modules/tensordict_module/common.py +2 -2
- torchrl/modules/tensordict_module/exploration.py +80 -11
- torchrl/modules/tensordict_module/probabilistic.py +2 -2
- torchrl/modules/tensordict_module/rnn.py +21 -15
- torchrl/objectives/a2c.py +24 -23
- torchrl/objectives/common.py +70 -12
- torchrl/objectives/cql.py +25 -6
- torchrl/objectives/crossq.py +19 -8
- torchrl/objectives/ddpg.py +22 -3
- torchrl/objectives/decision_transformer.py +5 -6
- torchrl/objectives/deprecated.py +14 -2
- torchrl/objectives/dqn.py +42 -6
- torchrl/objectives/dreamer.py +52 -21
- torchrl/objectives/iql.py +19 -8
- torchrl/objectives/llm/__init__.py +22 -2
- torchrl/objectives/llm/grpo.py +381 -65
- torchrl/objectives/llm/sft.py +4 -4
- torchrl/objectives/multiagent/qmixer.py +8 -1
- torchrl/objectives/ppo.py +39 -31
- torchrl/objectives/redq.py +16 -3
- torchrl/objectives/reinforce.py +8 -0
- torchrl/objectives/sac.py +106 -155
- torchrl/objectives/td3.py +21 -4
- torchrl/objectives/td3_bc.py +21 -4
- torchrl/objectives/utils.py +44 -4
- torchrl/objectives/value/advantages.py +16 -2
- torchrl/objectives/value/functional.py +41 -12
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +9 -1
- torchrl/record/loggers/wandb.py +10 -0
- torchrl/record/recorder.py +5 -3
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +6 -4
- torchrl/trainers/algorithms/__init__.py +2 -1
- torchrl/trainers/algorithms/configs/__init__.py +138 -13
- torchrl/trainers/algorithms/configs/collectors.py +44 -13
- torchrl/trainers/algorithms/configs/data.py +3 -0
- torchrl/trainers/algorithms/configs/modules.py +228 -4
- torchrl/trainers/algorithms/configs/objectives.py +108 -6
- torchrl/trainers/algorithms/configs/trainers.py +211 -7
- torchrl/trainers/algorithms/configs/transforms.py +33 -2
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/ppo.py +139 -26
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/collectors.py +15 -15
- torchrl/trainers/helpers/trainers.py +5 -5
- torchrl/trainers/trainers.py +431 -100
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- {torchrl-0.10.1.dist-info → torchrl-0.11.0.dist-info}/METADATA +23 -6
- torchrl-0.11.0.dist-info/RECORD +394 -0
- {torchrl-0.10.1.dist-info → torchrl-0.11.0.dist-info}/top_level.txt +2 -0
- torchrl/data/rlhf.py +0 -41
- torchrl/envs/transforms/llm.py +0 -329
- torchrl/envs/transforms/rlhf.py +0 -27
- torchrl/modules/models/rlhf.py +0 -17
- torchrl-0.10.1.dist-info/RECORD +0 -352
- {torchrl-0.10.1.dist-info → torchrl-0.11.0.dist-info}/LICENSE +0 -0
- {torchrl-0.10.1.dist-info → torchrl-0.11.0.dist-info}/WHEEL +0 -0
- {torchrl-0.10.1.dist-info → torchrl-0.11.0.dist-info}/entry_points.txt +0 -0
|
@@ -27,7 +27,7 @@ from torchrl.collectors import (
|
|
|
27
27
|
)
|
|
28
28
|
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
|
|
29
29
|
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
|
|
30
|
-
from torchrl.
|
|
30
|
+
from torchrl.modules import RandomPolicy
|
|
31
31
|
|
|
32
32
|
if __name__ == "__main__":
|
|
33
33
|
avail_devices = ("cpu",)
|
|
@@ -144,7 +144,7 @@ if __name__ == "__main__":
|
|
|
144
144
|
rank = args.rank
|
|
145
145
|
storage_type = args.storage
|
|
146
146
|
|
|
147
|
-
torchrl_logger.
|
|
147
|
+
torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")
|
|
148
148
|
|
|
149
149
|
os.environ["MASTER_ADDR"] = "localhost"
|
|
150
150
|
os.environ["MASTER_PORT"] = "29500"
|
|
@@ -18,7 +18,7 @@ from torchrl.data import LazyTensorStorage, ReplayBuffer
|
|
|
18
18
|
from torchrl.data.utils import CloudpickleWrapper
|
|
19
19
|
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
|
|
20
20
|
from torchrl.envs.libs.dm_control import DMControlEnv
|
|
21
|
-
from torchrl.
|
|
21
|
+
from torchrl.modules import RandomPolicy
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def single_collector_setup():
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import gc
|
|
8
|
+
import time
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
from tensordict import set_capture_non_tensor_stack
|
|
12
|
+
from torchrl.envs import ParallelEnv, SerialEnv
|
|
13
|
+
from torchrl.testing.mocking_classes import EnvWithMetadata
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _rollout(env, n_steps: int, break_when_any_done: bool) -> None:
|
|
17
|
+
env.rollout(n_steps, break_when_any_done=break_when_any_done)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.mark.parametrize("break_when_any_done", [True, False])
|
|
21
|
+
@pytest.mark.parametrize(
|
|
22
|
+
"kind,use_buffers",
|
|
23
|
+
[
|
|
24
|
+
pytest.param("single", None, id="single"),
|
|
25
|
+
pytest.param("serial", False, id="serial-no-buffers"),
|
|
26
|
+
pytest.param("serial", True, id="serial-buffers"),
|
|
27
|
+
pytest.param("parallel", False, id="parallel-no-buffers"),
|
|
28
|
+
pytest.param("parallel", True, id="parallel-buffers"),
|
|
29
|
+
],
|
|
30
|
+
)
|
|
31
|
+
@pytest.mark.parametrize("n_steps", [1000])
|
|
32
|
+
def test_non_tensor_env_rollout_speed(
|
|
33
|
+
benchmark,
|
|
34
|
+
break_when_any_done: bool,
|
|
35
|
+
kind: str,
|
|
36
|
+
use_buffers: bool | None,
|
|
37
|
+
n_steps: int,
|
|
38
|
+
):
|
|
39
|
+
"""Benchmarks a single rollout, after a warmup rollout, for non-tensor stacking envs.
|
|
40
|
+
|
|
41
|
+
Mirrors `test/test_envs.py::TestNonTensorEnv`'s option matrix (single/serial/parallel,
|
|
42
|
+
break_when_any_done, use_buffers).
|
|
43
|
+
"""
|
|
44
|
+
with set_capture_non_tensor_stack(False):
|
|
45
|
+
if kind == "single":
|
|
46
|
+
env = EnvWithMetadata()
|
|
47
|
+
elif kind == "serial":
|
|
48
|
+
env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers)
|
|
49
|
+
elif kind == "parallel":
|
|
50
|
+
env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
|
|
51
|
+
else:
|
|
52
|
+
raise RuntimeError(f"Unknown kind={kind}")
|
|
53
|
+
|
|
54
|
+
env.set_seed(0)
|
|
55
|
+
env.reset()
|
|
56
|
+
|
|
57
|
+
try:
|
|
58
|
+
# Warmup run (not timed)
|
|
59
|
+
_rollout(env, n_steps=n_steps, break_when_any_done=break_when_any_done)
|
|
60
|
+
|
|
61
|
+
# Timed run(s)
|
|
62
|
+
benchmark(
|
|
63
|
+
_rollout, env, n_steps=n_steps, break_when_any_done=break_when_any_done
|
|
64
|
+
)
|
|
65
|
+
finally:
|
|
66
|
+
env.close(raise_if_closed=False)
|
|
67
|
+
del env
|
|
68
|
+
# Give multiprocessing envs a brief chance to terminate cleanly.
|
|
69
|
+
time.sleep(0.05)
|
|
70
|
+
gc.collect()
|
|
@@ -172,6 +172,8 @@ def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3):
|
|
|
172
172
|
def test_dqn_speed(
|
|
173
173
|
benchmark, backward, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128
|
|
174
174
|
):
|
|
175
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
176
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
175
177
|
if compile:
|
|
176
178
|
torch._dynamo.reset_code_caches()
|
|
177
179
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -227,6 +229,8 @@ def test_dqn_speed(
|
|
|
227
229
|
def test_ddpg_speed(
|
|
228
230
|
benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
|
|
229
231
|
):
|
|
232
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
233
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
230
234
|
if compile:
|
|
231
235
|
torch._dynamo.reset_code_caches()
|
|
232
236
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -303,6 +307,8 @@ def test_ddpg_speed(
|
|
|
303
307
|
def test_sac_speed(
|
|
304
308
|
benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
|
|
305
309
|
):
|
|
310
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
311
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
306
312
|
if compile:
|
|
307
313
|
torch._dynamo.reset_code_caches()
|
|
308
314
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -393,6 +399,8 @@ def test_sac_speed(
|
|
|
393
399
|
def test_redq_speed(
|
|
394
400
|
benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
|
|
395
401
|
):
|
|
402
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
403
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
396
404
|
if compile:
|
|
397
405
|
torch._dynamo.reset_code_caches()
|
|
398
406
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -484,6 +492,8 @@ def test_redq_speed(
|
|
|
484
492
|
def test_redq_deprec_speed(
|
|
485
493
|
benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
|
|
486
494
|
):
|
|
495
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
496
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
487
497
|
if compile:
|
|
488
498
|
torch._dynamo.reset_code_caches()
|
|
489
499
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -573,6 +583,8 @@ def test_redq_deprec_speed(
|
|
|
573
583
|
def test_td3_speed(
|
|
574
584
|
benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
|
|
575
585
|
):
|
|
586
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
587
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
576
588
|
if compile:
|
|
577
589
|
torch._dynamo.reset_code_caches()
|
|
578
590
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -667,6 +679,8 @@ def test_td3_speed(
|
|
|
667
679
|
def test_cql_speed(
|
|
668
680
|
benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
|
|
669
681
|
):
|
|
682
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
683
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
670
684
|
if compile:
|
|
671
685
|
torch._dynamo.reset_code_caches()
|
|
672
686
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -763,6 +777,8 @@ def test_a2c_speed(
|
|
|
763
777
|
batch=128,
|
|
764
778
|
T=10,
|
|
765
779
|
):
|
|
780
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
781
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
766
782
|
if compile:
|
|
767
783
|
torch._dynamo.reset_code_caches()
|
|
768
784
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -866,6 +882,8 @@ def test_ppo_speed(
|
|
|
866
882
|
batch=128,
|
|
867
883
|
T=10,
|
|
868
884
|
):
|
|
885
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
886
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
869
887
|
if compile:
|
|
870
888
|
torch._dynamo.reset_code_caches()
|
|
871
889
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -969,6 +987,8 @@ def test_reinforce_speed(
|
|
|
969
987
|
batch=128,
|
|
970
988
|
T=10,
|
|
971
989
|
):
|
|
990
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
991
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
972
992
|
if compile:
|
|
973
993
|
torch._dynamo.reset_code_caches()
|
|
974
994
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -1072,6 +1092,8 @@ def test_iql_speed(
|
|
|
1072
1092
|
batch=128,
|
|
1073
1093
|
T=10,
|
|
1074
1094
|
):
|
|
1095
|
+
if compile == "reduce-overhead" and backward is not None:
|
|
1096
|
+
pytest.skip("reduce-overhead with backward causes segfaults in CI")
|
|
1075
1097
|
if compile:
|
|
1076
1098
|
torch._dynamo.reset_code_caches()
|
|
1077
1099
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
@@ -22,7 +22,7 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
22
22
|
from tensordict import from_module
|
|
23
23
|
from tensordict.nn import CudaGraphModule
|
|
24
24
|
|
|
25
|
-
from torchrl._utils import timeit
|
|
25
|
+
from torchrl._utils import get_available_device, timeit
|
|
26
26
|
from torchrl.collectors import SyncDataCollector
|
|
27
27
|
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
|
|
28
28
|
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
|
@@ -33,11 +33,9 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
33
33
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
34
34
|
from utils_atari import eval_model, make_parallel_env, make_ppo_models
|
|
35
35
|
|
|
36
|
-
device =
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
else:
|
|
40
|
-
device = torch.device(device)
|
|
36
|
+
device = (
|
|
37
|
+
torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
|
|
38
|
+
)
|
|
41
39
|
|
|
42
40
|
# Correct for frame_skip
|
|
43
41
|
frame_skip = 4
|
|
@@ -76,8 +74,8 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
76
74
|
actor_network=actor,
|
|
77
75
|
critic_network=critic,
|
|
78
76
|
loss_critic_type=cfg.loss.loss_critic_type,
|
|
79
|
-
|
|
80
|
-
|
|
77
|
+
entropy_coeff=cfg.loss.entropy_coeff,
|
|
78
|
+
critic_coeff=cfg.loss.critic_coeff,
|
|
81
79
|
)
|
|
82
80
|
|
|
83
81
|
# use end-of-life as done key
|
|
@@ -23,7 +23,7 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
23
23
|
from tensordict import from_module
|
|
24
24
|
from tensordict.nn import CudaGraphModule
|
|
25
25
|
|
|
26
|
-
from torchrl._utils import timeit
|
|
26
|
+
from torchrl._utils import get_available_device, timeit
|
|
27
27
|
from torchrl.collectors import SyncDataCollector
|
|
28
28
|
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
|
|
29
29
|
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
|
@@ -36,11 +36,9 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
36
36
|
|
|
37
37
|
# Define paper hyperparameters
|
|
38
38
|
|
|
39
|
-
device =
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
else:
|
|
43
|
-
device = torch.device(device)
|
|
39
|
+
device = (
|
|
40
|
+
torch.device(cfg.loss.device) if cfg.loss.device else get_available_device()
|
|
41
|
+
)
|
|
44
42
|
|
|
45
43
|
num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
|
|
46
44
|
total_network_updates = (
|
|
@@ -77,8 +75,8 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
77
75
|
actor_network=actor,
|
|
78
76
|
critic_network=critic,
|
|
79
77
|
loss_critic_type=cfg.loss.loss_critic_type,
|
|
80
|
-
|
|
81
|
-
|
|
78
|
+
entropy_coeff=cfg.loss.entropy_coeff,
|
|
79
|
+
critic_coeff=cfg.loss.critic_coeff,
|
|
82
80
|
)
|
|
83
81
|
|
|
84
82
|
# Create optimizers
|
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
|
18
18
|
import torch
|
|
19
19
|
import tqdm
|
|
20
20
|
from tensordict.nn import CudaGraphModule
|
|
21
|
-
from torchrl._utils import timeit
|
|
21
|
+
from torchrl._utils import get_available_device, timeit
|
|
22
22
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
23
23
|
from torchrl.objectives import group_optimizers
|
|
24
24
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
@@ -55,13 +55,9 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
55
55
|
# Set seeds
|
|
56
56
|
torch.manual_seed(cfg.env.seed)
|
|
57
57
|
np.random.seed(cfg.env.seed)
|
|
58
|
-
device =
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
device = "cuda:0"
|
|
62
|
-
else:
|
|
63
|
-
device = "cpu"
|
|
64
|
-
device = torch.device(device)
|
|
58
|
+
device = (
|
|
59
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
60
|
+
)
|
|
65
61
|
|
|
66
62
|
# Create replay buffer
|
|
67
63
|
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
|
|
@@ -21,7 +21,7 @@ import torch
|
|
|
21
21
|
import tqdm
|
|
22
22
|
from tensordict import TensorDict
|
|
23
23
|
from tensordict.nn import CudaGraphModule
|
|
24
|
-
from torchrl._utils import timeit
|
|
24
|
+
from torchrl._utils import get_available_device, timeit
|
|
25
25
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
26
26
|
from torchrl.objectives import group_optimizers
|
|
27
27
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
@@ -60,13 +60,9 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
60
60
|
# Set seeds
|
|
61
61
|
torch.manual_seed(cfg.env.seed)
|
|
62
62
|
np.random.seed(cfg.env.seed)
|
|
63
|
-
device =
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
device = "cuda:0"
|
|
67
|
-
else:
|
|
68
|
-
device = "cpu"
|
|
69
|
-
device = torch.device(device)
|
|
63
|
+
device = (
|
|
64
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
65
|
+
)
|
|
70
66
|
|
|
71
67
|
# Create env
|
|
72
68
|
train_env, eval_env = make_environment(
|
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
|
18
18
|
import torch
|
|
19
19
|
import tqdm
|
|
20
20
|
from tensordict.nn import CudaGraphModule
|
|
21
|
-
from torchrl._utils import timeit
|
|
21
|
+
from torchrl._utils import get_available_device, timeit
|
|
22
22
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
23
23
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
24
24
|
from utils import (
|
|
@@ -36,10 +36,9 @@ torch.set_float32_matmul_precision("high")
|
|
|
36
36
|
|
|
37
37
|
@hydra.main(version_base="1.1", config_path="", config_name="discrete_offline_config")
|
|
38
38
|
def main(cfg): # noqa: F821
|
|
39
|
-
device =
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
device = torch.device(device)
|
|
39
|
+
device = (
|
|
40
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
41
|
+
)
|
|
43
42
|
|
|
44
43
|
# Create logger
|
|
45
44
|
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
|
|
@@ -20,7 +20,7 @@ import torch
|
|
|
20
20
|
import torch.cuda
|
|
21
21
|
import tqdm
|
|
22
22
|
from tensordict.nn import CudaGraphModule
|
|
23
|
-
from torchrl._utils import timeit
|
|
23
|
+
from torchrl._utils import get_available_device, timeit
|
|
24
24
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
25
25
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
26
26
|
from utils import (
|
|
@@ -38,13 +38,9 @@ torch.set_float32_matmul_precision("high")
|
|
|
38
38
|
|
|
39
39
|
@hydra.main(version_base="1.1", config_path="", config_name="discrete_online_config")
|
|
40
40
|
def main(cfg: DictConfig): # noqa: F821
|
|
41
|
-
device =
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
device = "cuda:0"
|
|
45
|
-
else:
|
|
46
|
-
device = "cpu"
|
|
47
|
-
device = torch.device(device)
|
|
41
|
+
device = (
|
|
42
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
43
|
+
)
|
|
48
44
|
|
|
49
45
|
# Create logger
|
|
50
46
|
exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
|
|
@@ -21,7 +21,7 @@ import torch.cuda
|
|
|
21
21
|
import tqdm
|
|
22
22
|
from tensordict import TensorDict
|
|
23
23
|
from tensordict.nn import CudaGraphModule
|
|
24
|
-
from torchrl._utils import timeit
|
|
24
|
+
from torchrl._utils import get_available_device, timeit
|
|
25
25
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
26
26
|
from torchrl.objectives import group_optimizers
|
|
27
27
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
@@ -40,13 +40,11 @@ torch.set_float32_matmul_precision("high")
|
|
|
40
40
|
|
|
41
41
|
@hydra.main(version_base="1.1", config_path=".", config_name="config")
|
|
42
42
|
def main(cfg: DictConfig): # noqa: F821
|
|
43
|
-
device =
|
|
44
|
-
|
|
45
|
-
if
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
device = torch.device("cpu")
|
|
49
|
-
device = torch.device(device)
|
|
43
|
+
device = (
|
|
44
|
+
torch.device(cfg.network.device)
|
|
45
|
+
if cfg.network.device
|
|
46
|
+
else get_available_device()
|
|
47
|
+
)
|
|
50
48
|
|
|
51
49
|
# Create logger
|
|
52
50
|
exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name)
|
|
@@ -21,7 +21,7 @@ import torch.cuda
|
|
|
21
21
|
import tqdm
|
|
22
22
|
from tensordict import TensorDict
|
|
23
23
|
from tensordict.nn import CudaGraphModule
|
|
24
|
-
from torchrl._utils import timeit
|
|
24
|
+
from torchrl._utils import get_available_device, timeit
|
|
25
25
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
26
26
|
from torchrl.objectives import group_optimizers
|
|
27
27
|
from torchrl.record.loggers import generate_exp_name, get_logger
|
|
@@ -39,21 +39,14 @@ from utils import (
|
|
|
39
39
|
|
|
40
40
|
@hydra.main(version_base="1.1", config_path="", config_name="config")
|
|
41
41
|
def main(cfg: DictConfig): # noqa: F821
|
|
42
|
-
device =
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
collector_device = cfg.collector.device
|
|
51
|
-
if collector_device in ("", None):
|
|
52
|
-
if torch.cuda.is_available():
|
|
53
|
-
collector_device = "cuda:0"
|
|
54
|
-
else:
|
|
55
|
-
collector_device = "cpu"
|
|
56
|
-
collector_device = torch.device(collector_device)
|
|
42
|
+
device = (
|
|
43
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
44
|
+
)
|
|
45
|
+
collector_device = (
|
|
46
|
+
torch.device(cfg.collector.device)
|
|
47
|
+
if cfg.collector.device
|
|
48
|
+
else get_available_device()
|
|
49
|
+
)
|
|
57
50
|
|
|
58
51
|
# Create logger
|
|
59
52
|
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
|
|
@@ -17,7 +17,7 @@ import torch
|
|
|
17
17
|
import tqdm
|
|
18
18
|
from tensordict import TensorDict
|
|
19
19
|
from tensordict.nn import CudaGraphModule
|
|
20
|
-
from torchrl._utils import logger as torchrl_logger, timeit
|
|
20
|
+
from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
|
|
21
21
|
from torchrl.envs.libs.gym import set_gym_backend
|
|
22
22
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
23
23
|
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
|
|
@@ -38,13 +38,9 @@ from utils import (
|
|
|
38
38
|
def main(cfg: DictConfig): # noqa: F821
|
|
39
39
|
set_gym_backend(cfg.env.backend).set()
|
|
40
40
|
|
|
41
|
-
model_device =
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
model_device = "cuda:0"
|
|
45
|
-
else:
|
|
46
|
-
model_device = "cpu"
|
|
47
|
-
model_device = torch.device(model_device)
|
|
41
|
+
model_device = (
|
|
42
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
43
|
+
)
|
|
48
44
|
|
|
49
45
|
# Set seeds
|
|
50
46
|
torch.manual_seed(cfg.env.seed)
|
|
@@ -15,7 +15,7 @@ import numpy as np
|
|
|
15
15
|
import torch
|
|
16
16
|
import tqdm
|
|
17
17
|
from tensordict.nn import CudaGraphModule
|
|
18
|
-
from torchrl._utils import logger as torchrl_logger, timeit
|
|
18
|
+
from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
|
|
19
19
|
from torchrl.envs.libs.gym import set_gym_backend
|
|
20
20
|
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
21
21
|
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
|
|
@@ -36,13 +36,9 @@ from utils import (
|
|
|
36
36
|
def main(cfg: DictConfig): # noqa: F821
|
|
37
37
|
set_gym_backend(cfg.env.backend).set()
|
|
38
38
|
|
|
39
|
-
model_device =
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
model_device = "cuda:0"
|
|
43
|
-
else:
|
|
44
|
-
model_device = "cpu"
|
|
45
|
-
model_device = torch.device(model_device)
|
|
39
|
+
model_device = (
|
|
40
|
+
torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
|
|
41
|
+
)
|
|
46
42
|
|
|
47
43
|
# Set seeds
|
|
48
44
|
torch.manual_seed(cfg.env.seed)
|
|
@@ -17,7 +17,7 @@ import torch.nn
|
|
|
17
17
|
import torch.optim
|
|
18
18
|
import tqdm
|
|
19
19
|
from tensordict.nn import CudaGraphModule, TensorDictSequential
|
|
20
|
-
from torchrl._utils import timeit
|
|
20
|
+
from torchrl._utils import get_available_device, timeit
|
|
21
21
|
from torchrl.collectors import SyncDataCollector
|
|
22
22
|
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
|
23
23
|
from torchrl.envs import ExplorationType, set_exploration_type
|
|
@@ -33,13 +33,7 @@ torch.set_float32_matmul_precision("high")
|
|
|
33
33
|
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
|
|
34
34
|
def main(cfg: DictConfig): # noqa: F821
|
|
35
35
|
|
|
36
|
-
device = cfg.device
|
|
37
|
-
if device in ("", None):
|
|
38
|
-
if torch.cuda.is_available():
|
|
39
|
-
device = "cuda:0"
|
|
40
|
-
else:
|
|
41
|
-
device = "cpu"
|
|
42
|
-
device = torch.device(device)
|
|
36
|
+
device = torch.device(cfg.device) if cfg.device else get_available_device()
|
|
43
37
|
|
|
44
38
|
# Correct for frame_skip
|
|
45
39
|
frame_skip = 4
|
|
@@ -12,7 +12,7 @@ import torch.nn
|
|
|
12
12
|
import torch.optim
|
|
13
13
|
import tqdm
|
|
14
14
|
from tensordict.nn import CudaGraphModule, TensorDictSequential
|
|
15
|
-
from torchrl._utils import timeit
|
|
15
|
+
from torchrl._utils import get_available_device, timeit
|
|
16
16
|
from torchrl.collectors import SyncDataCollector
|
|
17
17
|
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
|
|
18
18
|
from torchrl.envs import ExplorationType, set_exploration_type
|
|
@@ -28,13 +28,7 @@ torch.set_float32_matmul_precision("high")
|
|
|
28
28
|
@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
|
|
29
29
|
def main(cfg: DictConfig): # noqa: F821
|
|
30
30
|
|
|
31
|
-
device = cfg.device
|
|
32
|
-
if device in ("", None):
|
|
33
|
-
if torch.cuda.is_available():
|
|
34
|
-
device = "cuda:0"
|
|
35
|
-
else:
|
|
36
|
-
device = "cpu"
|
|
37
|
-
device = torch.device(device)
|
|
31
|
+
device = torch.device(cfg.device) if cfg.device else get_available_device()
|
|
38
32
|
|
|
39
33
|
# Make the components
|
|
40
34
|
model = make_dqn_model(cfg.env.env_name, device=device)
|