stable-baselines3 2.3.2__tar.gz → 2.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {stable_baselines3-2.3.2/stable_baselines3.egg-info → stable_baselines3-2.4.0}/PKG-INFO +5 -15
  2. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/README.md +24 -14
  3. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/pyproject.toml +18 -15
  4. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/setup.py +18 -29
  5. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/base_class.py +30 -8
  6. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/buffers.py +2 -2
  7. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/callbacks.py +5 -1
  8. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/env_checker.py +9 -1
  9. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/logger.py +3 -2
  10. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/monitor.py +1 -1
  11. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/on_policy_algorithm.py +24 -1
  12. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/policies.py +2 -2
  13. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/results_plotter.py +1 -1
  14. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/running_mean_std.py +1 -1
  15. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/torch_layers.py +45 -10
  16. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/type_aliases.py +1 -1
  17. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/utils.py +2 -2
  18. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/dummy_vec_env.py +4 -4
  19. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/patch_gym.py +2 -2
  20. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/subproc_vec_env.py +17 -17
  21. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/util.py +4 -16
  22. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_normalize.py +1 -1
  23. stable_baselines3-2.4.0/stable_baselines3/common/vec_env/vec_video_recorder.py +154 -0
  24. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/dqn/policies.py +1 -1
  25. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/her/her_replay_buffer.py +4 -4
  26. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ppo/ppo.py +0 -4
  27. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/sac/policies.py +1 -1
  28. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/td3/policies.py +1 -1
  29. stable_baselines3-2.4.0/stable_baselines3/version.txt +1 -0
  30. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0/stable_baselines3.egg-info}/PKG-INFO +5 -15
  31. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/requires.txt +4 -15
  32. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_buffers.py +17 -10
  33. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_callbacks.py +26 -0
  34. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_cnn.py +1 -1
  35. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_custom_policy.py +56 -0
  36. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_dict_env.py +2 -3
  37. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_envs.py +2 -0
  38. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_gae.py +1 -1
  39. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_her.py +1 -1
  40. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_logger.py +69 -15
  41. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_run.py +12 -2
  42. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_save_load.py +27 -3
  43. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_utils.py +4 -1
  44. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_envs.py +1 -1
  45. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_normalize.py +1 -1
  46. stable_baselines3-2.3.2/stable_baselines3/common/vec_env/vec_video_recorder.py +0 -113
  47. stable_baselines3-2.3.2/stable_baselines3/version.txt +0 -1
  48. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/LICENSE +0 -0
  49. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/NOTICE +0 -0
  50. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/setup.cfg +0 -0
  51. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/__init__.py +0 -0
  52. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/a2c/__init__.py +0 -0
  53. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/a2c/a2c.py +0 -0
  54. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/a2c/policies.py +0 -0
  55. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/__init__.py +0 -0
  56. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/atari_wrappers.py +0 -0
  57. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/distributions.py +0 -0
  58. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/env_util.py +0 -0
  59. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/__init__.py +0 -0
  60. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/bit_flipping_env.py +0 -0
  61. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/identity_env.py +0 -0
  62. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/multi_input_envs.py +0 -0
  63. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/evaluation.py +0 -0
  64. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/noise.py +0 -0
  65. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/off_policy_algorithm.py +0 -0
  66. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/preprocessing.py +0 -0
  67. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/save_util.py +0 -0
  68. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/sb2_compat/__init__.py +0 -0
  69. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +0 -0
  70. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/__init__.py +0 -0
  71. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/base_vec_env.py +0 -0
  72. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/stacked_observations.py +0 -0
  73. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_check_nan.py +0 -0
  74. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +0 -0
  75. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_frame_stack.py +0 -0
  76. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_monitor.py +0 -0
  77. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_transpose.py +0 -0
  78. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ddpg/__init__.py +0 -0
  79. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ddpg/ddpg.py +0 -0
  80. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ddpg/policies.py +0 -0
  81. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/dqn/__init__.py +0 -0
  82. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/dqn/dqn.py +0 -0
  83. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/her/__init__.py +0 -0
  84. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/her/goal_selection_strategy.py +0 -0
  85. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ppo/__init__.py +0 -0
  86. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ppo/policies.py +0 -0
  87. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/py.typed +0 -0
  88. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/sac/__init__.py +0 -0
  89. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/sac/sac.py +0 -0
  90. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/td3/__init__.py +0 -0
  91. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/td3/td3.py +0 -0
  92. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/SOURCES.txt +0 -0
  93. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/dependency_links.txt +0 -0
  94. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/top_level.txt +0 -0
  95. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_deterministic.py +0 -0
  96. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_distributions.py +0 -0
  97. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_env_checker.py +0 -0
  98. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_identity.py +0 -0
  99. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_monitor.py +0 -0
  100. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_predict.py +0 -0
  101. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_preprocessing.py +0 -0
  102. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_sde.py +0 -0
  103. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_spaces.py +0 -0
  104. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_tensorboard.py +0 -0
  105. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_train_eval_mode.py +0 -0
  106. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_check_nan.py +0 -0
  107. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_extract_dict_obs.py +0 -0
  108. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_monitor.py +0 -0
  109. {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_stacked_obs.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: stable_baselines3
3
- Version: 2.3.2
3
+ Version: 2.4.0
4
4
  Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
5
5
  Home-page: https://github.com/DLR-RM/stable-baselines3
6
6
  Author: Antonin Raffin
@@ -22,8 +22,8 @@ Requires-Python: >=3.8
22
22
  Description-Content-Type: text/markdown
23
23
  License-File: LICENSE
24
24
  License-File: NOTICE
25
- Requires-Dist: gymnasium<0.30,>=0.28.1
26
- Requires-Dist: numpy>=1.20
25
+ Requires-Dist: gymnasium<1.1.0,>=0.29.1
26
+ Requires-Dist: numpy<2.0,>=1.20
27
27
  Requires-Dist: torch>=1.13
28
28
  Requires-Dist: cloudpickle
29
29
  Requires-Dist: pandas
@@ -37,7 +37,7 @@ Requires-Dist: mypy; extra == "tests"
37
37
  Requires-Dist: ruff>=0.3.1; extra == "tests"
38
38
  Requires-Dist: black<25,>=24.2.0; extra == "tests"
39
39
  Provides-Extra: docs
40
- Requires-Dist: sphinx<8,>=5; extra == "docs"
40
+ Requires-Dist: sphinx<9,>=5; extra == "docs"
41
41
  Requires-Dist: sphinx-autobuild; extra == "docs"
42
42
  Requires-Dist: sphinx-rtd-theme>=1.3.0; extra == "docs"
43
43
  Requires-Dist: sphinxcontrib.spelling; extra == "docs"
@@ -49,18 +49,8 @@ Requires-Dist: tensorboard>=2.9.1; extra == "extra"
49
49
  Requires-Dist: psutil; extra == "extra"
50
50
  Requires-Dist: tqdm; extra == "extra"
51
51
  Requires-Dist: rich; extra == "extra"
52
- Requires-Dist: shimmy[atari]~=1.3.0; extra == "extra"
52
+ Requires-Dist: ale-py>=0.9.0; extra == "extra"
53
53
  Requires-Dist: pillow; extra == "extra"
54
- Requires-Dist: autorom[accept-rom-license]~=0.6.1; extra == "extra"
55
- Provides-Extra: extra-no-roms
56
- Requires-Dist: opencv-python; extra == "extra-no-roms"
57
- Requires-Dist: pygame; extra == "extra-no-roms"
58
- Requires-Dist: tensorboard>=2.9.1; extra == "extra-no-roms"
59
- Requires-Dist: psutil; extra == "extra-no-roms"
60
- Requires-Dist: tqdm; extra == "extra-no-roms"
61
- Requires-Dist: rich; extra == "extra-no-roms"
62
- Requires-Dist: shimmy[atari]~=1.3.0; extra == "extra-no-roms"
63
- Requires-Dist: pillow; extra == "extra-no-roms"
64
54
 
65
55
 
66
56
 
@@ -1,13 +1,13 @@
1
- <img src="docs/\_static/img/logo.png" align="right" width="40%"/>
2
-
3
1
  <!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
4
- ![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
5
- [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
2
+ [![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
3
+ [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
6
4
  [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
7
5
 
8
6
 
9
7
  # Stable Baselines3
10
8
 
9
+ <img src="docs/\_static/img/logo.png" align="right" width="40%"/>
10
+
11
11
  Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
12
12
 
13
13
  You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
@@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
22
22
  **The performance of each algorithm was tested** (see *Results* section in their respective page),
23
23
  you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
24
24
 
25
+ We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.
26
+
25
27
 
26
28
  | **Features** | **Stable-Baselines3** |
27
29
  | --------------------------- | ----------------------|
@@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin
41
43
 
42
44
  ### Planned features
43
45
 
44
- Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
46
+ Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
47
+ If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).
48
+
49
+ While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
50
+ - newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
51
+ - faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
52
+ - the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)
45
53
 
46
54
  ## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)
47
55
 
@@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/
79
87
 
80
88
  We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
81
89
 
82
- This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
90
+ This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
83
91
 
84
92
  Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
85
93
 
@@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster
97
105
  ### Prerequisites
98
106
  Stable Baselines3 requires Python 3.8+.
99
107
 
100
- #### Windows 10
108
+ #### Windows
101
109
 
102
110
  To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).
103
111
 
104
112
 
105
113
  ### Install using pip
106
114
  Install the Stable Baselines3 package:
115
+ ```sh
116
+ pip install 'stable-baselines3[extra]'
107
117
  ```
108
- pip install stable-baselines3[extra]
109
- ```
110
- **Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).
111
118
 
112
119
  This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
113
120
  ```sh
@@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
177
184
  | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
178
185
  | ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
179
186
  | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
187
+ | CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
180
188
  | DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
181
189
  | DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
182
190
  | HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
@@ -191,8 +199,8 @@ All the following examples can be executed online using Google Colab notebooks:
191
199
 
192
200
  <b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.
193
201
 
194
- Actions `gym.spaces`:
195
- * `Box`: A N-dimensional box that containes every point in the action space.
202
+ Actions `gymnasium.spaces`:
203
+ * `Box`: A N-dimensional box that contains every point in the action space.
196
204
  * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
197
205
  * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
198
206
  * `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
@@ -218,9 +226,9 @@ To run a single test:
218
226
  python3 -m pytest -v -k 'test_check_env_dict_action'
219
227
  ```
220
228
 
221
- You can also do a static type check using `pytype` and `mypy`:
229
+ You can also do a static type check using `mypy`:
222
230
  ```sh
223
- pip install pytype mypy
231
+ pip install mypy
224
232
  make type
225
233
  ```
226
234
 
@@ -252,6 +260,8 @@ To cite this repository in publications:
252
260
  }
253
261
  ```
254
262
 
263
+ Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).
264
+
255
265
  ## Maintainers
256
266
 
257
267
  Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
@@ -13,11 +13,10 @@ ignore = ["B028", "RUF013"]
13
13
 
14
14
  [tool.ruff.lint.per-file-ignores]
15
15
  # Default implementation in abstract methods
16
- "./stable_baselines3/common/callbacks.py"= ["B027"]
17
- "./stable_baselines3/common/noise.py"= ["B027"]
16
+ "./stable_baselines3/common/callbacks.py" = ["B027"]
17
+ "./stable_baselines3/common/noise.py" = ["B027"]
18
18
  # ClassVar, implicit optional check not needed for tests
19
- "./tests/*.py"= ["RUF012", "RUF013"]
20
-
19
+ "./tests/*.py" = ["RUF012", "RUF013"]
21
20
 
22
21
  [tool.ruff.lint.mccabe]
23
22
  # Unlike Flake8, default to a complexity level of 10.
@@ -37,31 +36,35 @@ exclude = """(?x)(
37
36
 
38
37
  [tool.pytest.ini_options]
39
38
  # Deterministic ordering for tests; useful for pytest-xdist.
40
- env = [
41
- "PYTHONHASHSEED=0"
42
- ]
39
+ env = ["PYTHONHASHSEED=0"]
43
40
 
44
41
  filterwarnings = [
45
42
  # Tensorboard warnings
46
43
  "ignore::DeprecationWarning:tensorboard",
47
44
  # Gymnasium warnings
48
45
  "ignore::UserWarning:gymnasium",
46
+ # tqdm warning about rich being experimental
47
+ "ignore:rich is experimental",
49
48
  ]
50
49
  markers = [
51
- "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
50
+ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
52
51
  ]
53
52
 
54
53
  [tool.coverage.run]
55
54
  disable_warnings = ["couldnt-parse"]
56
55
  branch = false
57
56
  omit = [
58
- "tests/*",
59
- "setup.py",
60
- # Require graphical interface
61
- "stable_baselines3/common/results_plotter.py",
62
- # Require ffmpeg
63
- "stable_baselines3/common/vec_env/vec_video_recorder.py",
57
+ "tests/*",
58
+ "setup.py",
59
+ # Require graphical interface
60
+ "stable_baselines3/common/results_plotter.py",
61
+ # Require ffmpeg
62
+ "stable_baselines3/common/vec_env/vec_video_recorder.py",
64
63
  ]
65
64
 
66
65
  [tool.coverage.report]
67
- exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
66
+ exclude_lines = [
67
+ "pragma: no cover",
68
+ "raise NotImplementedError()",
69
+ "if typing.TYPE_CHECKING:",
70
+ ]
@@ -70,38 +70,14 @@ model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
70
70
 
71
71
  """ # noqa:E501
72
72
 
73
- # Atari Games download is sometimes problematic:
74
- # https://github.com/Farama-Foundation/AutoROM/issues/39
75
- # That's why we define extra packages without it.
76
- extra_no_roms = [
77
- # For render
78
- "opencv-python",
79
- "pygame",
80
- # Tensorboard support
81
- "tensorboard>=2.9.1",
82
- # Checking memory taken by replay buffer
83
- "psutil",
84
- # For progress bar callback
85
- "tqdm",
86
- "rich",
87
- # For atari games,
88
- "shimmy[atari]~=1.3.0",
89
- "pillow",
90
- ]
91
-
92
- extra_packages = extra_no_roms + [ # noqa: RUF005
93
- # For atari roms,
94
- "autorom[accept-rom-license]~=0.6.1",
95
- ]
96
-
97
73
 
98
74
  setup(
99
75
  name="stable_baselines3",
100
76
  packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
101
77
  package_data={"stable_baselines3": ["py.typed", "version.txt"]},
102
78
  install_requires=[
103
- "gymnasium>=0.28.1,<0.30",
104
- "numpy>=1.20",
79
+ "gymnasium>=0.29.1,<1.1.0",
80
+ "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
105
81
  "torch>=1.13",
106
82
  # For saving models
107
83
  "cloudpickle",
@@ -125,7 +101,7 @@ setup(
125
101
  "black>=24.2.0,<25",
126
102
  ],
127
103
  "docs": [
128
- "sphinx>=5,<8",
104
+ "sphinx>=5,<9",
129
105
  "sphinx-autobuild",
130
106
  "sphinx-rtd-theme>=1.3.0",
131
107
  # For spelling
@@ -133,8 +109,21 @@ setup(
133
109
  # Copy button for code snippets
134
110
  "sphinx_copybutton",
135
111
  ],
136
- "extra": extra_packages,
137
- "extra_no_roms": extra_no_roms,
112
+ "extra": [
113
+ # For render
114
+ "opencv-python",
115
+ "pygame",
116
+ # Tensorboard support
117
+ "tensorboard>=2.9.1",
118
+ # Checking memory taken by replay buffer
119
+ "psutil",
120
+ # For progress bar callback
121
+ "tqdm",
122
+ "rich",
123
+ # For atari games,
124
+ "ale-py>=0.9.0",
125
+ "pillow",
126
+ ],
138
127
  },
139
128
  description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
140
129
  author="Antonin Raffin",
@@ -48,7 +48,7 @@ def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
48
48
  """If env is a string, make the environment; otherwise, return env.
49
49
 
50
50
  :param env: The environment to learn from.
51
- :param verbose: Verbosity level: 0 for no output, 1 for indicating if envrironment is created
51
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
52
52
  :return A Gym (vector) environment.
53
53
  """
54
54
  if isinstance(env, str):
@@ -592,7 +592,7 @@ class BaseAlgorithm(ABC):
592
592
  if isinstance(load_path_or_dict, dict):
593
593
  params = load_path_or_dict
594
594
  else:
595
- _, params, _ = load_from_zip_file(load_path_or_dict, device=device)
595
+ _, params, _ = load_from_zip_file(load_path_or_dict, device=device, load_data=False)
596
596
 
597
597
  # Keep track which objects were updated.
598
598
  # `_get_torch_save_params` returns [params, other_pytorch_variables].
@@ -692,10 +692,9 @@ class BaseAlgorithm(ABC):
692
692
  if "device" in data["policy_kwargs"]:
693
693
  del data["policy_kwargs"]["device"]
694
694
  # backward compatibility, convert to new format
695
- if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
696
- saved_net_arch = data["policy_kwargs"]["net_arch"]
697
- if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
698
- data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
695
+ saved_net_arch = data["policy_kwargs"].get("net_arch")
696
+ if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
697
+ data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
699
698
 
700
699
  if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
701
700
  raise ValueError(
@@ -743,13 +742,13 @@ class BaseAlgorithm(ABC):
743
742
  # put state_dicts back in place
744
743
  model.set_parameters(params, exact_match=True, device=device)
745
744
  except RuntimeError as e:
746
- # Patch to load Policy saved using SB3 < 1.7.0
745
+ # Patch to load policies saved using SB3 < 1.7.0
747
746
  # the error is probably due to old policy being loaded
748
747
  # See https://github.com/DLR-RM/stable-baselines3/issues/1233
749
748
  if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
750
749
  model.set_parameters(params, exact_match=False, device=device)
751
750
  warnings.warn(
752
- "You are probably loading a model saved with SB3 < 1.7.0, "
751
+ "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
753
752
  "we deactivated exact_match so you can save the model "
754
753
  "again to avoid issues in the future "
755
754
  "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
@@ -758,6 +757,29 @@ class BaseAlgorithm(ABC):
758
757
  )
759
758
  else:
760
759
  raise e
760
+ except ValueError as e:
761
+ # Patch to load DQN policies saved using SB3 < 2.4.0
762
+ # The target network params are no longer in the optimizer
763
+ # See https://github.com/DLR-RM/stable-baselines3/pull/1963
764
+ saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index]
765
+ n_params_saved = len(saved_optim_params)
766
+ n_params = len(model.policy.optimizer.param_groups[0]["params"])
767
+ if n_params_saved == 2 * n_params:
768
+ # Truncate to include only online network params
769
+ params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index]
770
+
771
+ model.set_parameters(params, exact_match=True, device=device)
772
+ warnings.warn(
773
+ "You are probably loading a DQN model saved with SB3 < 2.4.0, "
774
+ "we truncated the optimizer state so you can save the model "
775
+ "again to avoid issues in the future "
776
+ "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
777
+ f"Original error: {e} \n"
778
+ "Note: the model should still work fine, this only a warning."
779
+ )
780
+ else:
781
+ raise e
782
+
761
783
  # put other pytorch variables back in place
762
784
  if pytorch_variables is not None:
763
785
  for name in pytorch_variables:
@@ -419,12 +419,12 @@ class RolloutBuffer(BaseBuffer):
419
419
  :param dones: if the last step was a terminal step (one bool for each env).
420
420
  """
421
421
  # Convert to numpy
422
- last_values = last_values.clone().cpu().numpy().flatten()
422
+ last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment]
423
423
 
424
424
  last_gae_lam = 0
425
425
  for step in reversed(range(self.buffer_size)):
426
426
  if step == self.buffer_size - 1:
427
- next_non_terminal = 1.0 - dones
427
+ next_non_terminal = 1.0 - dones.astype(np.float32)
428
428
  next_values = last_values
429
429
  else:
430
430
  next_non_terminal = 1.0 - self.episode_starts[step + 1]
@@ -204,6 +204,10 @@ class CallbackList(BaseCallback):
204
204
  for callback in self.callbacks:
205
205
  callback.init_callback(self.model)
206
206
 
207
+ # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
208
+ # pass through the parent callback to all children
209
+ callback.parent = self.parent
210
+
207
211
  def _on_training_start(self) -> None:
208
212
  for callback in self.callbacks:
209
213
  callback.on_training_start(self.locals, self.globals)
@@ -606,7 +610,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
606
610
  self.n_episodes = 0
607
611
 
608
612
  def _init_callback(self) -> None:
609
- # At start set total max according to number of envirnments
613
+ # At start set total max according to number of environments
610
614
  self._total_max_episodes = self.max_episodes * self.training_env.num_envs
611
615
 
612
616
  def _on_step(self) -> bool:
@@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
98
98
  "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
99
99
  )
100
100
 
101
+ if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
102
+ warnings.warn(
103
+ f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
104
+ "which is currently not supported by Stable-Baselines3. "
105
+ "Please convert it to a 1D array using a wrapper: "
106
+ "https://github.com/DLR-RM/stable-baselines3/issues/1836."
107
+ )
108
+
101
109
  if isinstance(observation_space, spaces.Tuple):
102
110
  warnings.warn(
103
111
  "The observation space is a Tuple, "
@@ -397,7 +405,7 @@ def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover
397
405
  "you may have trouble when calling `.render()`"
398
406
  )
399
407
 
400
- # Only check currrent render mode
408
+ # Only check current render mode
401
409
  if env.render_mode:
402
410
  env.render()
403
411
  env.close()
@@ -412,8 +412,9 @@ class TensorBoardOutputFormat(KVWriter):
412
412
  else:
413
413
  self.writer.add_scalar(key, value, step)
414
414
 
415
- if isinstance(value, th.Tensor):
416
- self.writer.add_histogram(key, value, step)
415
+ if isinstance(value, (th.Tensor, np.ndarray)):
416
+ # Convert to Torch so it works with numpy<1.24 and torch<2.0
417
+ self.writer.add_histogram(key, th.as_tensor(value), step)
417
418
 
418
419
  if isinstance(value, Video):
419
420
  self.writer.add_video(key, value.frames, step, value.fps)
@@ -189,7 +189,7 @@ class ResultsWriter:
189
189
  filename = os.path.realpath(filename)
190
190
  # Create (if any) missing filename directories
191
191
  os.makedirs(os.path.dirname(filename), exist_ok=True)
192
- # Append mode when not overridding existing file
192
+ # Append mode when not overriding existing file
193
193
  mode = "w" if override_existing else "a"
194
194
  # Prevent newline issue on Windows, see GH issue #692
195
195
  self.file_handler = open(filename, f"{mode}t", newline="\n")
@@ -1,5 +1,6 @@
1
1
  import sys
2
2
  import time
3
+ import warnings
3
4
  from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
4
5
 
5
6
  import numpy as np
@@ -135,6 +136,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
135
136
  self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
136
137
  )
137
138
  self.policy = self.policy.to(self.device)
139
+ # Warn when not using CPU with MlpPolicy
140
+ self._maybe_recommend_cpu()
141
+
142
+ def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
143
+ """
144
+ Recommend to use CPU only when using A2C/PPO with MlpPolicy.
145
+
146
+ :param: The name of the class for the default MlpPolicy.
147
+ """
148
+ policy_class_name = self.policy_class.__name__
149
+ if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
150
+ warnings.warn(
151
+ f"You are trying to run {self.__class__.__name__} on the GPU, "
152
+ "but it is primarily intended to run on the CPU when not using a CNN policy "
153
+ f"(you are using {policy_class_name} which should be a MlpPolicy). "
154
+ "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
155
+ "for more info. "
156
+ "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
157
+ "Note: The model will train, but the GPU utilization will be poor and "
158
+ "the training might take longer than on CPU.",
159
+ UserWarning,
160
+ )
138
161
 
139
162
  def collect_rollouts(
140
163
  self,
@@ -208,7 +231,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
208
231
  # Reshape in case of discrete action
209
232
  actions = actions.reshape(-1, 1)
210
233
 
211
- # Handle timeout by bootstraping with value function
234
+ # Handle timeout by bootstrapping with value function
212
235
  # see GitHub issue #633
213
236
  for idx, done in enumerate(dones):
214
237
  if (
@@ -367,7 +367,7 @@ class BasePolicy(BaseModel, ABC):
367
367
  with th.no_grad():
368
368
  actions = self._predict(obs_tensor, deterministic=deterministic)
369
369
  # Convert to numpy, and reshape to the original action shape
370
- actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
370
+ actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment]
371
371
 
372
372
  if isinstance(self.action_space, spaces.Box):
373
373
  if self.squash_output:
@@ -922,7 +922,7 @@ class ContinuousCritic(BaseModel):
922
922
  By default, it creates two critic networks used to reduce overestimation
923
923
  thanks to clipped Q-learning (cf TD3 paper).
924
924
 
925
- :param observation_space: Obervation space
925
+ :param observation_space: Observation space
926
926
  :param action_space: Action space
927
927
  :param net_arch: Network architecture
928
928
  :param features_extractor: Network to extract features
@@ -46,7 +46,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
46
46
 
47
47
  def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
48
48
  """
49
- Decompose a data frame variable to x ans ys
49
+ Decompose a data frame variable to x and ys
50
50
 
51
51
  :param data_frame: the input data
52
52
  :param x_axis: the axis for the x and y output
@@ -6,7 +6,7 @@ import numpy as np
6
6
  class RunningMeanStd:
7
7
  def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
8
8
  """
9
- Calulates the running mean and std of a data stream
9
+ Calculates the running mean and std of a data stream
10
10
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
11
11
 
12
12
  :param epsilon: helps with arithmetic issues