stable-baselines3 2.2.1__tar.gz → 2.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/PKG-INFO +4 -4
  2. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/README.md +1 -1
  3. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/pyproject.toml +4 -2
  4. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/setup.py +3 -3
  5. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/base_class.py +4 -1
  6. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/distributions.py +1 -1
  7. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/env_checker.py +28 -16
  8. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/on_policy_algorithm.py +25 -11
  9. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/policies.py +3 -1
  10. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/save_util.py +2 -1
  11. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/type_aliases.py +1 -0
  12. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/util.py +1 -0
  13. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_frame_stack.py +6 -1
  14. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ddpg/ddpg.py +3 -3
  15. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/dqn/dqn.py +1 -1
  16. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/td3/td3.py +3 -3
  17. stable_baselines3-2.3.0/stable_baselines3/version.txt +1 -0
  18. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/PKG-INFO +5 -5
  19. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/requires.txt +2 -2
  20. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_envs.py +5 -1
  21. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_logger.py +91 -1
  22. stable_baselines3-2.2.1/stable_baselines3/version.txt +0 -1
  23. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/LICENSE +0 -0
  24. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/NOTICE +0 -0
  25. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/setup.cfg +0 -0
  26. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/__init__.py +0 -0
  27. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/a2c/__init__.py +0 -0
  28. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/a2c/a2c.py +0 -0
  29. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/a2c/policies.py +0 -0
  30. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/__init__.py +0 -0
  31. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/atari_wrappers.py +0 -0
  32. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/buffers.py +0 -0
  33. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/callbacks.py +0 -0
  34. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/env_util.py +0 -0
  35. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/__init__.py +0 -0
  36. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/bit_flipping_env.py +0 -0
  37. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/identity_env.py +0 -0
  38. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/envs/multi_input_envs.py +0 -0
  39. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/evaluation.py +0 -0
  40. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/logger.py +0 -0
  41. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/monitor.py +0 -0
  42. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/noise.py +0 -0
  43. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/off_policy_algorithm.py +0 -0
  44. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/preprocessing.py +0 -0
  45. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/results_plotter.py +0 -0
  46. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/running_mean_std.py +0 -0
  47. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/sb2_compat/__init__.py +0 -0
  48. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +0 -0
  49. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/torch_layers.py +0 -0
  50. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/utils.py +0 -0
  51. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/__init__.py +0 -0
  52. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/base_vec_env.py +0 -0
  53. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/dummy_vec_env.py +0 -0
  54. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/patch_gym.py +0 -0
  55. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/stacked_observations.py +0 -0
  56. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/subproc_vec_env.py +0 -0
  57. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_check_nan.py +0 -0
  58. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +0 -0
  59. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_monitor.py +0 -0
  60. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_normalize.py +0 -0
  61. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_transpose.py +0 -0
  62. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/common/vec_env/vec_video_recorder.py +0 -0
  63. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ddpg/__init__.py +0 -0
  64. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ddpg/policies.py +0 -0
  65. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/dqn/__init__.py +0 -0
  66. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/dqn/policies.py +0 -0
  67. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/__init__.py +0 -0
  68. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/goal_selection_strategy.py +0 -0
  69. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/her/her_replay_buffer.py +0 -0
  70. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ppo/__init__.py +0 -0
  71. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ppo/policies.py +0 -0
  72. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/ppo/ppo.py +0 -0
  73. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/py.typed +0 -0
  74. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/sac/__init__.py +0 -0
  75. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/sac/policies.py +0 -0
  76. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/sac/sac.py +0 -0
  77. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/td3/__init__.py +0 -0
  78. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3/td3/policies.py +0 -0
  79. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/SOURCES.txt +0 -0
  80. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/dependency_links.txt +0 -0
  81. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/stable_baselines3.egg-info/top_level.txt +0 -0
  82. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_buffers.py +0 -0
  83. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_callbacks.py +0 -0
  84. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_cnn.py +0 -0
  85. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_custom_policy.py +0 -0
  86. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_deterministic.py +0 -0
  87. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_dict_env.py +0 -0
  88. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_distributions.py +0 -0
  89. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_env_checker.py +0 -0
  90. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_gae.py +0 -0
  91. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_her.py +0 -0
  92. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_identity.py +0 -0
  93. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_monitor.py +0 -0
  94. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_predict.py +0 -0
  95. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_preprocessing.py +0 -0
  96. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_run.py +0 -0
  97. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_save_load.py +0 -0
  98. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_sde.py +0 -0
  99. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_spaces.py +0 -0
  100. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_tensorboard.py +0 -0
  101. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_train_eval_mode.py +0 -0
  102. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_utils.py +0 -0
  103. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_check_nan.py +0 -0
  104. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_envs.py +0 -0
  105. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_extract_dict_obs.py +0 -0
  106. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_monitor.py +0 -0
  107. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_normalize.py +0 -0
  108. {stable_baselines3-2.2.1 → stable_baselines3-2.3.0}/tests/test_vec_stacked_obs.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: stable_baselines3
3
- Version: 2.2.1
3
+ Version: 2.3.0
4
4
  Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
5
5
  Home-page: https://github.com/DLR-RM/stable-baselines3
6
6
  Author: Antonin Raffin
@@ -34,8 +34,8 @@ Requires-Dist: pytest-cov; extra == "tests"
34
34
  Requires-Dist: pytest-env; extra == "tests"
35
35
  Requires-Dist: pytest-xdist; extra == "tests"
36
36
  Requires-Dist: mypy; extra == "tests"
37
- Requires-Dist: ruff>=0.0.288; extra == "tests"
38
- Requires-Dist: black<24,>=23.9.1; extra == "tests"
37
+ Requires-Dist: ruff>=0.3.1; extra == "tests"
38
+ Requires-Dist: black<25,>=24.2.0; extra == "tests"
39
39
  Provides-Extra: docs
40
40
  Requires-Dist: sphinx<8,>=5; extra == "docs"
41
41
  Requires-Dist: sphinx-autobuild; extra == "docs"
@@ -99,7 +99,7 @@ import gymnasium
99
99
 
100
100
  from stable_baselines3 import PPO
101
101
 
102
- env = gymnasium.make("CartPole-v1")
102
+ env = gymnasium.make("CartPole-v1", render_mode="human")
103
103
 
104
104
  model = PPO("MlpPolicy", env, verbose=1)
105
105
  model.learn(total_timesteps=10_000)
@@ -127,7 +127,7 @@ import gymnasium as gym
127
127
 
128
128
  from stable_baselines3 import PPO
129
129
 
130
- env = gym.make("CartPole-v1")
130
+ env = gym.make("CartPole-v1", render_mode="human")
131
131
 
132
132
  model = PPO("MlpPolicy", env, verbose=1)
133
133
  model.learn(total_timesteps=10_000)
@@ -3,13 +3,15 @@
3
3
  line-length = 127
4
4
  # Assume Python 3.8
5
5
  target-version = "py38"
6
+
7
+ [tool.ruff.lint]
6
8
  # See https://beta.ruff.rs/docs/rules/
7
9
  select = ["E", "F", "B", "UP", "C90", "RUF"]
8
10
  # B028: Ignore explicit stacklevel
9
11
  # RUF013: Too many false positives (implicit optional)
10
12
  ignore = ["B028", "RUF013"]
11
13
 
12
- [tool.ruff.per-file-ignores]
14
+ [tool.ruff.lint.per-file-ignores]
13
15
  # Default implementation in abstract methods
14
16
  "./stable_baselines3/common/callbacks.py"= ["B027"]
15
17
  "./stable_baselines3/common/noise.py"= ["B027"]
@@ -17,7 +19,7 @@ ignore = ["B028", "RUF013"]
17
19
  "./tests/*.py"= ["RUF012", "RUF013"]
18
20
 
19
21
 
20
- [tool.ruff.mccabe]
22
+ [tool.ruff.lint.mccabe]
21
23
  # Unlike Flake8, default to a complexity level of 10.
22
24
  max-complexity = 15
23
25
 
@@ -43,7 +43,7 @@ import gymnasium
43
43
 
44
44
  from stable_baselines3 import PPO
45
45
 
46
- env = gymnasium.make("CartPole-v1")
46
+ env = gymnasium.make("CartPole-v1", render_mode="human")
47
47
 
48
48
  model = PPO("MlpPolicy", env, verbose=1)
49
49
  model.learn(total_timesteps=10_000)
@@ -120,9 +120,9 @@ setup(
120
120
  # Type check
121
121
  "mypy",
122
122
  # Lint code and sort imports (flake8 and isort replacement)
123
- "ruff>=0.0.288",
123
+ "ruff>=0.3.1",
124
124
  # Reformat
125
- "black>=23.9.1,<24",
125
+ "black>=24.2.0,<25",
126
126
  ],
127
127
  "docs": [
128
128
  "sphinx>=5,<8",
@@ -523,7 +523,10 @@ class BaseAlgorithm(ABC):
523
523
 
524
524
  :param total_timesteps: The total number of samples (env steps) to train on
525
525
  :param callback: callback(s) called at every step with state of the algorithm.
526
- :param log_interval: The number of episodes before logging.
526
+ :param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of
527
+ training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging;
528
+ for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before
529
+ logging.
527
530
  :param tb_log_name: the name of the run for TensorBoard logging
528
531
  :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
529
532
  :param progress_bar: Display a progress bar using tqdm and rich.
@@ -113,7 +113,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
113
113
  so we can sum components of the ``log_prob`` or the entropy.
114
114
 
115
115
  :param tensor: shape: (n_batch, n_actions) or (n_batch,)
116
- :return: shape: (n_batch,)
116
+ :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
117
117
  """
118
118
  if len(tensor.shape) > 1:
119
119
  tensor = tensor.sum(dim=1)
@@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
17
17
  return not isinstance(space, (spaces.Dict, spaces.Tuple))
18
18
 
19
19
 
20
+ def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
21
+ """
22
+ Return False if a (Multi)Discrete space has a non-zero start.
23
+ """
24
+ return np.allclose(space.start, np.zeros_like(space.start))
25
+
26
+
27
+ def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
28
+ """
29
+ :param space: Observation or action space
30
+ :param space_type: information about whether it is an observation or action space
31
+ (for the warning message)
32
+ :param key: When the observation space comes from a Dict space, we pass the
33
+ corresponding key to have more precise warning messages. Defaults to "".
34
+ """
35
+ if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space):
36
+ maybe_key = f"(key='{key}')" if key else ""
37
+ warnings.warn(
38
+ f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
39
+ "is not supported by Stable-Baselines3. "
40
+ f"You can use a wrapper or update your {space_type} space."
41
+ )
42
+
43
+
20
44
  def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
21
45
  """
22
46
  Check that the input will be compatible with Stable-Baselines
23
47
  when the observation is apparently an image.
24
48
 
25
49
  :param observation_space: Observation space
26
- :key: When the observation space comes from a Dict space, we pass the
50
+ :param key: When the observation space comes from a Dict space, we pass the
27
51
  corresponding key to have more precise warning messages. Defaults to "".
28
52
  """
29
53
  if observation_space.dtype != np.uint8:
@@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
63
87
  for key, space in observation_space.spaces.items():
64
88
  if isinstance(space, spaces.Dict):
65
89
  nested_dict = True
66
- if isinstance(space, spaces.Discrete) and space.start != 0:
67
- warnings.warn(
68
- f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. "
69
- "You can use a wrapper or update your observation space."
70
- )
90
+ _check_non_zero_start(space, "observation", key)
71
91
 
72
92
  if nested_dict:
73
93
  warnings.warn(
@@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
87
107
  "which is supported by SB3."
88
108
  )
89
109
 
90
- if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0:
91
- warnings.warn(
92
- "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. "
93
- "You can use a wrapper or update your observation space."
94
- )
110
+ _check_non_zero_start(observation_space, "observation")
95
111
 
96
112
  if isinstance(observation_space, spaces.Sequence):
97
113
  warnings.warn(
@@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
100
116
  "Note: The checks for returned values are skipped."
101
117
  )
102
118
 
103
- if isinstance(action_space, spaces.Discrete) and action_space.start != 0:
104
- warnings.warn(
105
- "Discrete action space with a non-zero start is not supported by Stable-Baselines3. "
106
- "You can use a wrapper or update your action space."
107
- )
119
+ _check_non_zero_start(action_space, "action")
108
120
 
109
121
  if not _is_numpy_array_space(action_space):
110
122
  warnings.warn(
@@ -92,6 +92,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
92
92
  use_sde=use_sde,
93
93
  sde_sample_freq=sde_sample_freq,
94
94
  support_multi_env=True,
95
+ monitor_wrapper=monitor_wrapper,
95
96
  seed=seed,
96
97
  stats_window_size=stats_window_size,
97
98
  tensorboard_log=tensorboard_log,
@@ -200,7 +201,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
200
201
  if not callback.on_step():
201
202
  return False
202
203
 
203
- self._update_info_buffer(infos)
204
+ self._update_info_buffer(infos, dones)
204
205
  n_steps += 1
205
206
 
206
207
  if isinstance(self.action_space, spaces.Discrete):
@@ -250,6 +251,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
250
251
  """
251
252
  raise NotImplementedError
252
253
 
254
+ def _dump_logs(self, iteration: int) -> None:
255
+ """
256
+ Write log.
257
+
258
+ :param iteration: Current logging iteration
259
+ """
260
+ assert self.ep_info_buffer is not None
261
+ assert self.ep_success_buffer is not None
262
+
263
+ time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
264
+ fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
265
+ self.logger.record("time/iterations", iteration, exclude="tensorboard")
266
+ if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
267
+ self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
268
+ self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
269
+ self.logger.record("time/fps", fps)
270
+ self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
271
+ self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
272
+ if len(self.ep_success_buffer) > 0:
273
+ self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
274
+ self.logger.dump(step=self.num_timesteps)
275
+
253
276
  def learn(
254
277
  self: SelfOnPolicyAlgorithm,
255
278
  total_timesteps: int,
@@ -285,16 +308,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
285
308
  # Display training infos
286
309
  if log_interval is not None and iteration % log_interval == 0:
287
310
  assert self.ep_info_buffer is not None
288
- time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
289
- fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
290
- self.logger.record("time/iterations", iteration, exclude="tensorboard")
291
- if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
292
- self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
293
- self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
294
- self.logger.record("time/fps", fps)
295
- self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
296
- self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
297
- self.logger.dump(step=self.num_timesteps)
311
+ self._dump_logs(iteration)
298
312
 
299
313
  self.train()
300
314
 
@@ -173,7 +173,9 @@ class BaseModel(nn.Module):
173
173
  :return:
174
174
  """
175
175
  device = get_device(device)
176
- saved_variables = th.load(path, map_location=device)
176
+ # Note(antonin): we cannot use `weights_only=True` here because we need to allow
177
+ # gymnasium imports for the policy to be loaded successfully
178
+ saved_variables = th.load(path, map_location=device, weights_only=False)
177
179
 
178
180
  # Create policy object
179
181
  model = cls(**saved_variables["data"])
@@ -2,6 +2,7 @@
2
2
  Save util taken from stable_baselines
3
3
  used to serialize data (class parameters) of model classes
4
4
  """
5
+
5
6
  import base64
6
7
  import functools
7
8
  import io
@@ -446,7 +447,7 @@ def load_from_zip_file(
446
447
  file_content.seek(0)
447
448
  # Load the parameters with the right ``map_location``.
448
449
  # Remove ".pth" ending with splitext
449
- th_object = th.load(file_content, map_location=device)
450
+ th_object = th.load(file_content, map_location=device, weights_only=True)
450
451
  # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
451
452
  if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
452
453
  # PyTorch variables (not state_dicts)
@@ -1,4 +1,5 @@
1
1
  """Common aliases for type hints"""
2
+
2
3
  from enum import Enum
3
4
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
4
5
 
@@ -1,6 +1,7 @@
1
1
  """
2
2
  Helpers for dealing with vectorized environments.
3
3
  """
4
+
4
5
  from collections import OrderedDict
5
6
  from typing import Any, Dict, List, Tuple
6
7
 
@@ -29,7 +29,12 @@ class VecFrameStack(VecEnvWrapper):
29
29
 
30
30
  def step_wait(
31
31
  self,
32
- ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
32
+ ) -> Tuple[
33
+ Union[np.ndarray, Dict[str, np.ndarray]],
34
+ np.ndarray,
35
+ np.ndarray,
36
+ List[Dict[str, Any]],
37
+ ]:
33
38
  observations, rewards, dones, infos = self.venv.step_wait()
34
39
  observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
35
40
  return observations, rewards, dones, infos
@@ -60,11 +60,11 @@ class DDPG(TD3):
60
60
  learning_rate: Union[float, Schedule] = 1e-3,
61
61
  buffer_size: int = 1_000_000, # 1e6
62
62
  learning_starts: int = 100,
63
- batch_size: int = 100,
63
+ batch_size: int = 256,
64
64
  tau: float = 0.005,
65
65
  gamma: float = 0.99,
66
- train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
67
- gradient_steps: int = -1,
66
+ train_freq: Union[int, Tuple[int, str]] = 1,
67
+ gradient_steps: int = 1,
68
68
  action_noise: Optional[ActionNoise] = None,
69
69
  replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
70
70
  replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
@@ -79,7 +79,7 @@ class DQN(OffPolicyAlgorithm):
79
79
  env: Union[GymEnv, str],
80
80
  learning_rate: Union[float, Schedule] = 1e-4,
81
81
  buffer_size: int = 1_000_000, # 1e6
82
- learning_starts: int = 50000,
82
+ learning_starts: int = 100,
83
83
  batch_size: int = 32,
84
84
  tau: float = 1.0,
85
85
  gamma: float = 0.99,
@@ -83,11 +83,11 @@ class TD3(OffPolicyAlgorithm):
83
83
  learning_rate: Union[float, Schedule] = 1e-3,
84
84
  buffer_size: int = 1_000_000, # 1e6
85
85
  learning_starts: int = 100,
86
- batch_size: int = 100,
86
+ batch_size: int = 256,
87
87
  tau: float = 0.005,
88
88
  gamma: float = 0.99,
89
- train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
90
- gradient_steps: int = -1,
89
+ train_freq: Union[int, Tuple[int, str]] = 1,
90
+ gradient_steps: int = 1,
91
91
  action_noise: Optional[ActionNoise] = None,
92
92
  replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
93
93
  replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
- Name: stable-baselines3
3
- Version: 2.2.1
2
+ Name: stable_baselines3
3
+ Version: 2.3.0
4
4
  Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
5
5
  Home-page: https://github.com/DLR-RM/stable-baselines3
6
6
  Author: Antonin Raffin
@@ -34,8 +34,8 @@ Requires-Dist: pytest-cov; extra == "tests"
34
34
  Requires-Dist: pytest-env; extra == "tests"
35
35
  Requires-Dist: pytest-xdist; extra == "tests"
36
36
  Requires-Dist: mypy; extra == "tests"
37
- Requires-Dist: ruff>=0.0.288; extra == "tests"
38
- Requires-Dist: black<24,>=23.9.1; extra == "tests"
37
+ Requires-Dist: ruff>=0.3.1; extra == "tests"
38
+ Requires-Dist: black<25,>=24.2.0; extra == "tests"
39
39
  Provides-Extra: docs
40
40
  Requires-Dist: sphinx<8,>=5; extra == "docs"
41
41
  Requires-Dist: sphinx-autobuild; extra == "docs"
@@ -99,7 +99,7 @@ import gymnasium
99
99
 
100
100
  from stable_baselines3 import PPO
101
101
 
102
- env = gymnasium.make("CartPole-v1")
102
+ env = gymnasium.make("CartPole-v1", render_mode="human")
103
103
 
104
104
  model = PPO("MlpPolicy", env, verbose=1)
105
105
  model.learn(total_timesteps=10_000)
@@ -39,5 +39,5 @@ pytest-cov
39
39
  pytest-env
40
40
  pytest-xdist
41
41
  mypy
42
- ruff>=0.0.288
43
- black<24,>=23.9.1
42
+ ruff>=0.3.1
43
+ black<25,>=24.2.0
@@ -123,6 +123,8 @@ def test_high_dimension_action_space():
123
123
  spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
124
124
  # Non zero start index
125
125
  spaces.Discrete(3, start=-1),
126
+ # Non zero start index (MultiDiscrete)
127
+ spaces.MultiDiscrete([4, 4], start=[1, 0]),
126
128
  # Non zero start index inside a Dict
127
129
  spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
128
130
  ],
@@ -164,6 +166,8 @@ def test_non_default_spaces(new_obs_space):
164
166
  spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
165
167
  # Non zero start index
166
168
  spaces.Discrete(3, start=-1),
169
+ # Non zero start index (MultiDiscrete)
170
+ spaces.MultiDiscrete([4, 4], start=[1, 0]),
167
171
  ],
168
172
  )
169
173
  def test_non_default_action_spaces(new_action_space):
@@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space):
179
183
  env.action_space = new_action_space
180
184
 
181
185
  # Discrete action space
182
- if isinstance(new_action_space, spaces.Discrete):
186
+ if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)):
183
187
  with pytest.warns(UserWarning):
184
188
  check_env(env)
185
189
  return
@@ -14,7 +14,7 @@ from gymnasium import spaces
14
14
  from matplotlib import pyplot as plt
15
15
  from pandas.errors import EmptyDataError
16
16
 
17
- from stable_baselines3 import A2C, DQN
17
+ from stable_baselines3 import A2C, DQN, PPO
18
18
  from stable_baselines3.common.env_checker import check_env
19
19
  from stable_baselines3.common.logger import (
20
20
  DEBUG,
@@ -33,6 +33,7 @@ from stable_baselines3.common.logger import (
33
33
  read_csv,
34
34
  read_json,
35
35
  )
36
+ from stable_baselines3.common.monitor import Monitor
36
37
 
37
38
  KEY_VALUES = {
38
39
  "test": 1,
@@ -474,3 +475,92 @@ def test_human_output_format_custom_test_io(base_class):
474
475
  """
475
476
 
476
477
  assert printed == desired_printed
478
+
479
+
480
+ class DummySuccessEnv(gym.Env):
481
+ """
482
+ Create a dummy success environment that returns whether True or False for info['is_success']
483
+ at the end of an episode according to its dummy successes list
484
+ """
485
+
486
+ def __init__(self, dummy_successes, ep_steps):
487
+ """Init the dummy success env
488
+
489
+ :param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies
490
+ the success value of log iteration i at episode j
491
+ :param ep_steps: number of steps per episode (to activate truncated)
492
+ """
493
+ self.n_steps = 0
494
+ self.log_id = 0
495
+ self.ep_id = 0
496
+
497
+ self.ep_steps = ep_steps
498
+
499
+ self.dummy_success = dummy_successes
500
+ self.num_logs = len(dummy_successes)
501
+ self.ep_per_log = len(dummy_successes[0])
502
+ self.steps_per_log = self.ep_per_log * self.ep_steps
503
+
504
+ self.action_space = spaces.Discrete(2)
505
+ self.observation_space = spaces.Discrete(2)
506
+
507
+ def reset(self, seed=None, options=None):
508
+ """
509
+ Reset the env and advance to the next episode_id to get the next dummy success
510
+ """
511
+ self.n_steps = 0
512
+
513
+ if self.ep_id == self.ep_per_log:
514
+ self.ep_id = 0
515
+ self.log_id = (self.log_id + 1) % self.num_logs
516
+
517
+ return self.observation_space.sample(), {}
518
+
519
+ def step(self, action):
520
+ """
521
+ Step and return a dummy success when an episode is truncated
522
+ """
523
+ self.n_steps += 1
524
+ truncated = self.n_steps >= self.ep_steps
525
+
526
+ info = {}
527
+ if truncated:
528
+ maybe_success = self.dummy_success[self.log_id][self.ep_id]
529
+ info["is_success"] = maybe_success
530
+ self.ep_id += 1
531
+ return self.observation_space.sample(), 0.0, False, truncated, info
532
+
533
+
534
+ def test_rollout_success_rate_on_policy_algorithm(tmp_path):
535
+ """
536
+ Test if the rollout/success_rate information is correctly logged with on policy algorithms
537
+
538
+ To do so, create a dummy environment that takes as argument dummy successes (i.e. whether an episode
539
+ is going to be successful or not).
540
+ """
541
+
542
+ STATS_WINDOW_SIZE = 10
543
+ # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE
544
+ dummy_successes = [
545
+ [True] * 3 + [False] * 7,
546
+ [True] * 5 + [False] * 5,
547
+ [True] * 8 + [False] * 2,
548
+ ]
549
+ ep_steps = 64
550
+
551
+ # Monitor the env to track the success info
552
+ monitor_file = str(tmp_path / "monitor.csv")
553
+ env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",))
554
+
555
+ # Equip the model with a custom logger to check the success_rate info
556
+ model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1)
557
+ logger = InMemoryLogger()
558
+ model.set_logger(logger)
559
+
560
+ # Make the model learn and check that the success rate corresponds to the ratio of dummy successes
561
+ model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
562
+ assert logger.name_to_value["rollout/success_rate"] == 0.3
563
+ model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
564
+ assert logger.name_to_value["rollout/success_rate"] == 0.5
565
+ model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
566
+ assert logger.name_to_value["rollout/success_rate"] == 0.8
@@ -1 +0,0 @@
1
- 2.2.1