stable-baselines3 2.2.1__tar.gz → 2.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/PKG-INFO +4 -4
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/README.md +1 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/pyproject.toml +4 -2
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/setup.py +3 -3
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/base_class.py +4 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/distributions.py +1 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/env_checker.py +28 -16
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/on_policy_algorithm.py +25 -11
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/policies.py +3 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/save_util.py +2 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/type_aliases.py +1 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/util.py +1 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_frame_stack.py +6 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ddpg/ddpg.py +3 -3
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/dqn/dqn.py +1 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/td3/td3.py +3 -3
- stable_baselines3-2.3.0/stable_baselines3/version.txt +1 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/PKG-INFO +5 -5
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/requires.txt +2 -2
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_envs.py +5 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_logger.py +91 -1
- stable_baselines3-2.2.1/stable_baselines3/version.txt +0 -1
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/LICENSE +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/NOTICE +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/setup.cfg +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/a2c/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/a2c/a2c.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/a2c/policies.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/atari_wrappers.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/buffers.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/callbacks.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/env_util.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/bit_flipping_env.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/identity_env.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/multi_input_envs.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/evaluation.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/logger.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/monitor.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/noise.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/off_policy_algorithm.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/preprocessing.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/results_plotter.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/running_mean_std.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/sb2_compat/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/torch_layers.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/utils.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/base_vec_env.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/dummy_vec_env.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/patch_gym.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/stacked_observations.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/subproc_vec_env.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_check_nan.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_monitor.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_normalize.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_transpose.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_video_recorder.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ddpg/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ddpg/policies.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/dqn/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/dqn/policies.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/goal_selection_strategy.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/her_replay_buffer.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ppo/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ppo/policies.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ppo/ppo.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/py.typed +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/sac/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/sac/policies.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/sac/sac.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/td3/__init__.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/td3/policies.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/SOURCES.txt +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/dependency_links.txt +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/top_level.txt +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_buffers.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_callbacks.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_cnn.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_custom_policy.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_deterministic.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_dict_env.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_distributions.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_env_checker.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_gae.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_her.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_identity.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_monitor.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_predict.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_preprocessing.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_run.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_save_load.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_sde.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_spaces.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_tensorboard.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_train_eval_mode.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_utils.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_check_nan.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_envs.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_extract_dict_obs.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_monitor.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_normalize.py +0 -0
- {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_stacked_obs.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: stable_baselines3
|
|
3
|
-
Version: 2.2.1
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
|
|
5
5
|
Home-page: https://github.com/DLR-RM/stable-baselines3
|
|
6
6
|
Author: Antonin Raffin
|
|
@@ -34,8 +34,8 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
|
34
34
|
Requires-Dist: pytest-env; extra == "tests"
|
|
35
35
|
Requires-Dist: pytest-xdist; extra == "tests"
|
|
36
36
|
Requires-Dist: mypy; extra == "tests"
|
|
37
|
-
Requires-Dist: ruff>=0.
|
|
38
|
-
Requires-Dist: black<24
|
|
37
|
+
Requires-Dist: ruff>=0.3.1; extra == "tests"
|
|
38
|
+
Requires-Dist: black<25,>=24.2.0; extra == "tests"
|
|
39
39
|
Provides-Extra: docs
|
|
40
40
|
Requires-Dist: sphinx<8,>=5; extra == "docs"
|
|
41
41
|
Requires-Dist: sphinx-autobuild; extra == "docs"
|
|
@@ -99,7 +99,7 @@ import gymnasium
|
|
|
99
99
|
|
|
100
100
|
from stable_baselines3 import PPO
|
|
101
101
|
|
|
102
|
-
env = gymnasium.make("CartPole-v1")
|
|
102
|
+
env = gymnasium.make("CartPole-v1", render_mode="human")
|
|
103
103
|
|
|
104
104
|
model = PPO("MlpPolicy", env, verbose=1)
|
|
105
105
|
model.learn(total_timesteps=10_000)
|
|
@@ -3,13 +3,15 @@
|
|
|
3
3
|
line-length = 127
|
|
4
4
|
# Assume Python 3.8
|
|
5
5
|
target-version = "py38"
|
|
6
|
+
|
|
7
|
+
[tool.ruff.lint]
|
|
6
8
|
# See https://beta.ruff.rs/docs/rules/
|
|
7
9
|
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
|
8
10
|
# B028: Ignore explicit stacklevel`
|
|
9
11
|
# RUF013: Too many false positives (implicit optional)
|
|
10
12
|
ignore = ["B028", "RUF013"]
|
|
11
13
|
|
|
12
|
-
[tool.ruff.per-file-ignores]
|
|
14
|
+
[tool.ruff.lint.per-file-ignores]
|
|
13
15
|
# Default implementation in abstract methods
|
|
14
16
|
"./stable_baselines3/common/callbacks.py"= ["B027"]
|
|
15
17
|
"./stable_baselines3/common/noise.py"= ["B027"]
|
|
@@ -17,7 +19,7 @@ ignore = ["B028", "RUF013"]
|
|
|
17
19
|
"./tests/*.py"= ["RUF012", "RUF013"]
|
|
18
20
|
|
|
19
21
|
|
|
20
|
-
[tool.ruff.mccabe]
|
|
22
|
+
[tool.ruff.lint.mccabe]
|
|
21
23
|
# Unlike Flake8, default to a complexity level of 10.
|
|
22
24
|
max-complexity = 15
|
|
23
25
|
|
|
@@ -43,7 +43,7 @@ import gymnasium
|
|
|
43
43
|
|
|
44
44
|
from stable_baselines3 import PPO
|
|
45
45
|
|
|
46
|
-
env = gymnasium.make("CartPole-v1")
|
|
46
|
+
env = gymnasium.make("CartPole-v1", render_mode="human")
|
|
47
47
|
|
|
48
48
|
model = PPO("MlpPolicy", env, verbose=1)
|
|
49
49
|
model.learn(total_timesteps=10_000)
|
|
@@ -120,9 +120,9 @@ setup(
|
|
|
120
120
|
# Type check
|
|
121
121
|
"mypy",
|
|
122
122
|
# Lint code and sort imports (flake8 and isort replacement)
|
|
123
|
-
"ruff>=0.
|
|
123
|
+
"ruff>=0.3.1",
|
|
124
124
|
# Reformat
|
|
125
|
-
"black>=
|
|
125
|
+
"black>=24.2.0,<25",
|
|
126
126
|
],
|
|
127
127
|
"docs": [
|
|
128
128
|
"sphinx>=5,<8",
|
|
@@ -523,7 +523,10 @@ class BaseAlgorithm(ABC):
|
|
|
523
523
|
|
|
524
524
|
:param total_timesteps: The total number of samples (env steps) to train on
|
|
525
525
|
:param callback: callback(s) called at every step with state of the algorithm.
|
|
526
|
-
:param log_interval:
|
|
526
|
+
:param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of
|
|
527
|
+
training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging;
|
|
528
|
+
for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before
|
|
529
|
+
logging.
|
|
527
530
|
:param tb_log_name: the name of the run for TensorBoard logging
|
|
528
531
|
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
|
|
529
532
|
:param progress_bar: Display a progress bar using tqdm and rich.
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/distributions.py
RENAMED
|
@@ -113,7 +113,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
|
|
113
113
|
so we can sum components of the ``log_prob`` or the entropy.
|
|
114
114
|
|
|
115
115
|
:param tensor: shape: (n_batch, n_actions) or (n_batch,)
|
|
116
|
-
:return: shape: (n_batch,)
|
|
116
|
+
:return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
|
|
117
117
|
"""
|
|
118
118
|
if len(tensor.shape) > 1:
|
|
119
119
|
tensor = tensor.sum(dim=1)
|
|
@@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
|
|
|
17
17
|
return not isinstance(space, (spaces.Dict, spaces.Tuple))
|
|
18
18
|
|
|
19
19
|
|
|
20
|
+
def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
|
|
21
|
+
"""
|
|
22
|
+
Return False if a (Multi)Discrete space has a non-zero start.
|
|
23
|
+
"""
|
|
24
|
+
return np.allclose(space.start, np.zeros_like(space.start))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
|
|
28
|
+
"""
|
|
29
|
+
:param space: Observation or action space
|
|
30
|
+
:param space_type: information about whether it is an observation or action space
|
|
31
|
+
(for the warning message)
|
|
32
|
+
:param key: When the observation space comes from a Dict space, we pass the
|
|
33
|
+
corresponding key to have more precise warning messages. Defaults to "".
|
|
34
|
+
"""
|
|
35
|
+
if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space):
|
|
36
|
+
maybe_key = f"(key='{key}')" if key else ""
|
|
37
|
+
warnings.warn(
|
|
38
|
+
f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
|
|
39
|
+
"is not supported by Stable-Baselines3. "
|
|
40
|
+
f"You can use a wrapper or update your {space_type} space."
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
20
44
|
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
|
21
45
|
"""
|
|
22
46
|
Check that the input will be compatible with Stable-Baselines
|
|
23
47
|
when the observation is apparently an image.
|
|
24
48
|
|
|
25
49
|
:param observation_space: Observation space
|
|
26
|
-
:key: When the observation space comes from a Dict space, we pass the
|
|
50
|
+
:param key: When the observation space comes from a Dict space, we pass the
|
|
27
51
|
corresponding key to have more precise warning messages. Defaults to "".
|
|
28
52
|
"""
|
|
29
53
|
if observation_space.dtype != np.uint8:
|
|
@@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|
|
63
87
|
for key, space in observation_space.spaces.items():
|
|
64
88
|
if isinstance(space, spaces.Dict):
|
|
65
89
|
nested_dict = True
|
|
66
|
-
|
|
67
|
-
warnings.warn(
|
|
68
|
-
f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. "
|
|
69
|
-
"You can use a wrapper or update your observation space."
|
|
70
|
-
)
|
|
90
|
+
_check_non_zero_start(space, "observation", key)
|
|
71
91
|
|
|
72
92
|
if nested_dict:
|
|
73
93
|
warnings.warn(
|
|
@@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|
|
87
107
|
"which is supported by SB3."
|
|
88
108
|
)
|
|
89
109
|
|
|
90
|
-
|
|
91
|
-
warnings.warn(
|
|
92
|
-
"Discrete observation space with a non-zero start is not supported by Stable-Baselines3. "
|
|
93
|
-
"You can use a wrapper or update your observation space."
|
|
94
|
-
)
|
|
110
|
+
_check_non_zero_start(observation_space, "observation")
|
|
95
111
|
|
|
96
112
|
if isinstance(observation_space, spaces.Sequence):
|
|
97
113
|
warnings.warn(
|
|
@@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|
|
100
116
|
"Note: The checks for returned values are skipped."
|
|
101
117
|
)
|
|
102
118
|
|
|
103
|
-
|
|
104
|
-
warnings.warn(
|
|
105
|
-
"Discrete action space with a non-zero start is not supported by Stable-Baselines3. "
|
|
106
|
-
"You can use a wrapper or update your action space."
|
|
107
|
-
)
|
|
119
|
+
_check_non_zero_start(action_space, "action")
|
|
108
120
|
|
|
109
121
|
if not _is_numpy_array_space(action_space):
|
|
110
122
|
warnings.warn(
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/on_policy_algorithm.py
RENAMED
|
@@ -92,6 +92,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|
|
92
92
|
use_sde=use_sde,
|
|
93
93
|
sde_sample_freq=sde_sample_freq,
|
|
94
94
|
support_multi_env=True,
|
|
95
|
+
monitor_wrapper=monitor_wrapper,
|
|
95
96
|
seed=seed,
|
|
96
97
|
stats_window_size=stats_window_size,
|
|
97
98
|
tensorboard_log=tensorboard_log,
|
|
@@ -200,7 +201,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|
|
200
201
|
if not callback.on_step():
|
|
201
202
|
return False
|
|
202
203
|
|
|
203
|
-
self._update_info_buffer(infos)
|
|
204
|
+
self._update_info_buffer(infos, dones)
|
|
204
205
|
n_steps += 1
|
|
205
206
|
|
|
206
207
|
if isinstance(self.action_space, spaces.Discrete):
|
|
@@ -250,6 +251,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|
|
250
251
|
"""
|
|
251
252
|
raise NotImplementedError
|
|
252
253
|
|
|
254
|
+
def _dump_logs(self, iteration: int) -> None:
|
|
255
|
+
"""
|
|
256
|
+
Write log.
|
|
257
|
+
|
|
258
|
+
:param iteration: Current logging iteration
|
|
259
|
+
"""
|
|
260
|
+
assert self.ep_info_buffer is not None
|
|
261
|
+
assert self.ep_success_buffer is not None
|
|
262
|
+
|
|
263
|
+
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
|
264
|
+
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
|
265
|
+
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
|
266
|
+
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
|
267
|
+
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
|
268
|
+
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
|
269
|
+
self.logger.record("time/fps", fps)
|
|
270
|
+
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
|
271
|
+
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
|
272
|
+
if len(self.ep_success_buffer) > 0:
|
|
273
|
+
self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
|
|
274
|
+
self.logger.dump(step=self.num_timesteps)
|
|
275
|
+
|
|
253
276
|
def learn(
|
|
254
277
|
self: SelfOnPolicyAlgorithm,
|
|
255
278
|
total_timesteps: int,
|
|
@@ -285,16 +308,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|
|
285
308
|
# Display training infos
|
|
286
309
|
if log_interval is not None and iteration % log_interval == 0:
|
|
287
310
|
assert self.ep_info_buffer is not None
|
|
288
|
-
|
|
289
|
-
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
|
290
|
-
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
|
291
|
-
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
|
292
|
-
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
|
293
|
-
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
|
294
|
-
self.logger.record("time/fps", fps)
|
|
295
|
-
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
|
296
|
-
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
|
297
|
-
self.logger.dump(step=self.num_timesteps)
|
|
311
|
+
self._dump_logs(iteration)
|
|
298
312
|
|
|
299
313
|
self.train()
|
|
300
314
|
|
|
@@ -173,7 +173,9 @@ class BaseModel(nn.Module):
|
|
|
173
173
|
:return:
|
|
174
174
|
"""
|
|
175
175
|
device = get_device(device)
|
|
176
|
-
|
|
176
|
+
# Note(antonin): we cannot use `weights_only=True` here because we need to allow
|
|
177
|
+
# gymnasium imports for the policy to be loaded successfully
|
|
178
|
+
saved_variables = th.load(path, map_location=device, weights_only=False)
|
|
177
179
|
|
|
178
180
|
# Create policy object
|
|
179
181
|
model = cls(**saved_variables["data"])
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Save util taken from stable_baselines
|
|
3
3
|
used to serialize data (class parameters) of model classes
|
|
4
4
|
"""
|
|
5
|
+
|
|
5
6
|
import base64
|
|
6
7
|
import functools
|
|
7
8
|
import io
|
|
@@ -446,7 +447,7 @@ def load_from_zip_file(
|
|
|
446
447
|
file_content.seek(0)
|
|
447
448
|
# Load the parameters with the right ``map_location``.
|
|
448
449
|
# Remove ".pth" ending with splitext
|
|
449
|
-
th_object = th.load(file_content, map_location=device)
|
|
450
|
+
th_object = th.load(file_content, map_location=device, weights_only=True)
|
|
450
451
|
# "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
|
|
451
452
|
if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
|
|
452
453
|
# PyTorch variables (not state_dicts)
|
|
@@ -29,7 +29,12 @@ class VecFrameStack(VecEnvWrapper):
|
|
|
29
29
|
|
|
30
30
|
def step_wait(
|
|
31
31
|
self,
|
|
32
|
-
) -> Tuple[
|
|
32
|
+
) -> Tuple[
|
|
33
|
+
Union[np.ndarray, Dict[str, np.ndarray]],
|
|
34
|
+
np.ndarray,
|
|
35
|
+
np.ndarray,
|
|
36
|
+
List[Dict[str, Any]],
|
|
37
|
+
]:
|
|
33
38
|
observations, rewards, dones, infos = self.venv.step_wait()
|
|
34
39
|
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
|
|
35
40
|
return observations, rewards, dones, infos
|
|
@@ -60,11 +60,11 @@ class DDPG(TD3):
|
|
|
60
60
|
learning_rate: Union[float, Schedule] = 1e-3,
|
|
61
61
|
buffer_size: int = 1_000_000, # 1e6
|
|
62
62
|
learning_starts: int = 100,
|
|
63
|
-
batch_size: int = 100,
|
|
63
|
+
batch_size: int = 256,
|
|
64
64
|
tau: float = 0.005,
|
|
65
65
|
gamma: float = 0.99,
|
|
66
|
-
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
|
|
67
|
-
gradient_steps: int = -1,
|
|
66
|
+
train_freq: Union[int, Tuple[int, str]] = 1,
|
|
67
|
+
gradient_steps: int = 1,
|
|
68
68
|
action_noise: Optional[ActionNoise] = None,
|
|
69
69
|
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
|
70
70
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
|
@@ -79,7 +79,7 @@ class DQN(OffPolicyAlgorithm):
|
|
|
79
79
|
env: Union[GymEnv, str],
|
|
80
80
|
learning_rate: Union[float, Schedule] = 1e-4,
|
|
81
81
|
buffer_size: int = 1_000_000, # 1e6
|
|
82
|
-
learning_starts: int = 50000,
|
|
82
|
+
learning_starts: int = 100,
|
|
83
83
|
batch_size: int = 32,
|
|
84
84
|
tau: float = 1.0,
|
|
85
85
|
gamma: float = 0.99,
|
|
@@ -83,11 +83,11 @@ class TD3(OffPolicyAlgorithm):
|
|
|
83
83
|
learning_rate: Union[float, Schedule] = 1e-3,
|
|
84
84
|
buffer_size: int = 1_000_000, # 1e6
|
|
85
85
|
learning_starts: int = 100,
|
|
86
|
-
batch_size: int = 100,
|
|
86
|
+
batch_size: int = 256,
|
|
87
87
|
tau: float = 0.005,
|
|
88
88
|
gamma: float = 0.99,
|
|
89
|
-
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
|
|
90
|
-
gradient_steps: int = -1,
|
|
89
|
+
train_freq: Union[int, Tuple[int, str]] = 1,
|
|
90
|
+
gradient_steps: int = 1,
|
|
91
91
|
action_noise: Optional[ActionNoise] = None,
|
|
92
92
|
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
|
|
93
93
|
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
2.3.0
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
|
-
Name:
|
|
3
|
-
Version: 2.2.1
|
|
2
|
+
Name: stable_baselines3
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
|
|
5
5
|
Home-page: https://github.com/DLR-RM/stable-baselines3
|
|
6
6
|
Author: Antonin Raffin
|
|
@@ -34,8 +34,8 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
|
34
34
|
Requires-Dist: pytest-env; extra == "tests"
|
|
35
35
|
Requires-Dist: pytest-xdist; extra == "tests"
|
|
36
36
|
Requires-Dist: mypy; extra == "tests"
|
|
37
|
-
Requires-Dist: ruff>=0.
|
|
38
|
-
Requires-Dist: black<24
|
|
37
|
+
Requires-Dist: ruff>=0.3.1; extra == "tests"
|
|
38
|
+
Requires-Dist: black<25,>=24.2.0; extra == "tests"
|
|
39
39
|
Provides-Extra: docs
|
|
40
40
|
Requires-Dist: sphinx<8,>=5; extra == "docs"
|
|
41
41
|
Requires-Dist: sphinx-autobuild; extra == "docs"
|
|
@@ -99,7 +99,7 @@ import gymnasium
|
|
|
99
99
|
|
|
100
100
|
from stable_baselines3 import PPO
|
|
101
101
|
|
|
102
|
-
env = gymnasium.make("CartPole-v1")
|
|
102
|
+
env = gymnasium.make("CartPole-v1", render_mode="human")
|
|
103
103
|
|
|
104
104
|
model = PPO("MlpPolicy", env, verbose=1)
|
|
105
105
|
model.learn(total_timesteps=10_000)
|
|
@@ -123,6 +123,8 @@ def test_high_dimension_action_space():
|
|
|
123
123
|
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
|
|
124
124
|
# Non zero start index
|
|
125
125
|
spaces.Discrete(3, start=-1),
|
|
126
|
+
# Non zero start index (MultiDiscrete)
|
|
127
|
+
spaces.MultiDiscrete([4, 4], start=[1, 0]),
|
|
126
128
|
# Non zero start index inside a Dict
|
|
127
129
|
spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
|
|
128
130
|
],
|
|
@@ -164,6 +166,8 @@ def test_non_default_spaces(new_obs_space):
|
|
|
164
166
|
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
|
|
165
167
|
# Non zero start index
|
|
166
168
|
spaces.Discrete(3, start=-1),
|
|
169
|
+
# Non zero start index (MultiDiscrete)
|
|
170
|
+
spaces.MultiDiscrete([4, 4], start=[1, 0]),
|
|
167
171
|
],
|
|
168
172
|
)
|
|
169
173
|
def test_non_default_action_spaces(new_action_space):
|
|
@@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space):
|
|
|
179
183
|
env.action_space = new_action_space
|
|
180
184
|
|
|
181
185
|
# Discrete action space
|
|
182
|
-
if isinstance(new_action_space, spaces.Discrete):
|
|
186
|
+
if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)):
|
|
183
187
|
with pytest.warns(UserWarning):
|
|
184
188
|
check_env(env)
|
|
185
189
|
return
|
|
@@ -14,7 +14,7 @@ from gymnasium import spaces
|
|
|
14
14
|
from matplotlib import pyplot as plt
|
|
15
15
|
from pandas.errors import EmptyDataError
|
|
16
16
|
|
|
17
|
-
from stable_baselines3 import A2C, DQN
|
|
17
|
+
from stable_baselines3 import A2C, DQN, PPO
|
|
18
18
|
from stable_baselines3.common.env_checker import check_env
|
|
19
19
|
from stable_baselines3.common.logger import (
|
|
20
20
|
DEBUG,
|
|
@@ -33,6 +33,7 @@ from stable_baselines3.common.logger import (
|
|
|
33
33
|
read_csv,
|
|
34
34
|
read_json,
|
|
35
35
|
)
|
|
36
|
+
from stable_baselines3.common.monitor import Monitor
|
|
36
37
|
|
|
37
38
|
KEY_VALUES = {
|
|
38
39
|
"test": 1,
|
|
@@ -474,3 +475,92 @@ def test_human_output_format_custom_test_io(base_class):
|
|
|
474
475
|
"""
|
|
475
476
|
|
|
476
477
|
assert printed == desired_printed
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class DummySuccessEnv(gym.Env):
|
|
481
|
+
"""
|
|
482
|
+
Create a dummy success environment that returns wether True or False for info['is_success']
|
|
483
|
+
at the end of an episode according to its dummy successes list
|
|
484
|
+
"""
|
|
485
|
+
|
|
486
|
+
def __init__(self, dummy_successes, ep_steps):
|
|
487
|
+
"""Init the dummy success env
|
|
488
|
+
|
|
489
|
+
:param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies
|
|
490
|
+
the success value of log iteration i at episode j
|
|
491
|
+
:param ep_steps: number of steps per episode (to activate truncated)
|
|
492
|
+
"""
|
|
493
|
+
self.n_steps = 0
|
|
494
|
+
self.log_id = 0
|
|
495
|
+
self.ep_id = 0
|
|
496
|
+
|
|
497
|
+
self.ep_steps = ep_steps
|
|
498
|
+
|
|
499
|
+
self.dummy_success = dummy_successes
|
|
500
|
+
self.num_logs = len(dummy_successes)
|
|
501
|
+
self.ep_per_log = len(dummy_successes[0])
|
|
502
|
+
self.steps_per_log = self.ep_per_log * self.ep_steps
|
|
503
|
+
|
|
504
|
+
self.action_space = spaces.Discrete(2)
|
|
505
|
+
self.observation_space = spaces.Discrete(2)
|
|
506
|
+
|
|
507
|
+
def reset(self, seed=None, options=None):
|
|
508
|
+
"""
|
|
509
|
+
Reset the env and advance to the next episode_id to get the next dummy success
|
|
510
|
+
"""
|
|
511
|
+
self.n_steps = 0
|
|
512
|
+
|
|
513
|
+
if self.ep_id == self.ep_per_log:
|
|
514
|
+
self.ep_id = 0
|
|
515
|
+
self.log_id = (self.log_id + 1) % self.num_logs
|
|
516
|
+
|
|
517
|
+
return self.observation_space.sample(), {}
|
|
518
|
+
|
|
519
|
+
def step(self, action):
|
|
520
|
+
"""
|
|
521
|
+
Step and return a dummy success when an episode is truncated
|
|
522
|
+
"""
|
|
523
|
+
self.n_steps += 1
|
|
524
|
+
truncated = self.n_steps >= self.ep_steps
|
|
525
|
+
|
|
526
|
+
info = {}
|
|
527
|
+
if truncated:
|
|
528
|
+
maybe_success = self.dummy_success[self.log_id][self.ep_id]
|
|
529
|
+
info["is_success"] = maybe_success
|
|
530
|
+
self.ep_id += 1
|
|
531
|
+
return self.observation_space.sample(), 0.0, False, truncated, info
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def test_rollout_success_rate_on_policy_algorithm(tmp_path):
    """
    Test if the rollout/success_rate information is correctly logged with on policy algorithms

    To do so, create a dummy environment that takes as argument dummy successes (i.e. whether each
    episode is going to be successful or not).
    """

    STATS_WINDOW_SIZE = 10
    # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE
    dummy_successes = [
        [True] * 3 + [False] * 7,
        [True] * 5 + [False] * 5,
        [True] * 8 + [False] * 2,
    ]
    ep_steps = 64

    # Keep a direct handle on the unwrapped env: reading custom attributes through a
    # wrapper relies on Gymnasium's attribute forwarding, deprecated in 0.29 and removed in 1.0
    dummy_env = DummySuccessEnv(dummy_successes, ep_steps)
    steps_per_log = dummy_env.steps_per_log

    # Monitor the env to track the success info
    monitor_file = str(tmp_path / "monitor.csv")
    env = Monitor(dummy_env, filename=monitor_file, info_keywords=("is_success",))

    # Equip the model with a custom logger to check the success_rate info.
    # One rollout (n_steps) plays exactly one row of dummy_successes.
    model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1)
    logger = InMemoryLogger()
    model.set_logger(logger)

    # Make the model learn and check that the success rate corresponds to the ratio of dummy successes
    model.learn(total_timesteps=steps_per_log, log_interval=1)
    assert logger.name_to_value["rollout/success_rate"] == 0.3
    model.learn(total_timesteps=steps_per_log, log_interval=1)
    assert logger.name_to_value["rollout/success_rate"] == 0.5
    model.learn(total_timesteps=steps_per_log, log_interval=1)
    assert logger.name_to_value["rollout/success_rate"] == 0.8
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
2.2.1
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/atari_wrappers.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/identity_env.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/off_policy_algorithm.py
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/preprocessing.py
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/results_plotter.py
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/running_mean_std.py
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/sb2_compat/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/torch_layers.py
RENAMED
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/__init__.py
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/base_vec_env.py
RENAMED
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/patch_gym.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_monitor.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/goal_selection_strategy.py
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/her_replay_buffer.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|