stable-baselines3 2.3.2__tar.gz → 2.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {stable_baselines3-2.3.2/stable_baselines3.egg-info → stable_baselines3-2.4.0}/PKG-INFO +5 -15
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/README.md +24 -14
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/pyproject.toml +18 -15
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/setup.py +18 -29
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/base_class.py +30 -8
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/buffers.py +2 -2
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/callbacks.py +5 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/env_checker.py +9 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/logger.py +3 -2
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/monitor.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/on_policy_algorithm.py +24 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/policies.py +2 -2
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/results_plotter.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/running_mean_std.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/torch_layers.py +45 -10
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/type_aliases.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/utils.py +2 -2
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/dummy_vec_env.py +4 -4
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/patch_gym.py +2 -2
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/subproc_vec_env.py +17 -17
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/util.py +4 -16
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_normalize.py +1 -1
- stable_baselines3-2.4.0/stable_baselines3/common/vec_env/vec_video_recorder.py +154 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/dqn/policies.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/her/her_replay_buffer.py +4 -4
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ppo/ppo.py +0 -4
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/sac/policies.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/td3/policies.py +1 -1
- stable_baselines3-2.4.0/stable_baselines3/version.txt +1 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0/stable_baselines3.egg-info}/PKG-INFO +5 -15
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/requires.txt +4 -15
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_buffers.py +17 -10
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_callbacks.py +26 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_cnn.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_custom_policy.py +56 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_dict_env.py +2 -3
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_envs.py +2 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_gae.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_her.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_logger.py +69 -15
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_run.py +12 -2
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_save_load.py +27 -3
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_utils.py +4 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_envs.py +1 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_normalize.py +1 -1
- stable_baselines3-2.3.2/stable_baselines3/common/vec_env/vec_video_recorder.py +0 -113
- stable_baselines3-2.3.2/stable_baselines3/version.txt +0 -1
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/LICENSE +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/NOTICE +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/setup.cfg +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/a2c/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/a2c/a2c.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/a2c/policies.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/atari_wrappers.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/distributions.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/env_util.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/bit_flipping_env.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/identity_env.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/envs/multi_input_envs.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/evaluation.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/noise.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/off_policy_algorithm.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/preprocessing.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/save_util.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/sb2_compat/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/base_vec_env.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/stacked_observations.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_check_nan.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_frame_stack.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_monitor.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/vec_env/vec_transpose.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ddpg/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ddpg/ddpg.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ddpg/policies.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/dqn/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/dqn/dqn.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/her/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/her/goal_selection_strategy.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ppo/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/ppo/policies.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/py.typed +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/sac/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/sac/sac.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/td3/__init__.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/td3/td3.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/SOURCES.txt +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/dependency_links.txt +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3.egg-info/top_level.txt +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_deterministic.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_distributions.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_env_checker.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_identity.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_monitor.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_predict.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_preprocessing.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_sde.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_spaces.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_tensorboard.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_train_eval_mode.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_check_nan.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_extract_dict_obs.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_monitor.py +0 -0
- {stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/tests/test_vec_stacked_obs.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: stable_baselines3
|
|
3
|
-
Version: 2.3.2
|
|
3
|
+
Version: 2.4.0
|
|
4
4
|
Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
|
|
5
5
|
Home-page: https://github.com/DLR-RM/stable-baselines3
|
|
6
6
|
Author: Antonin Raffin
|
|
@@ -22,8 +22,8 @@ Requires-Python: >=3.8
|
|
|
22
22
|
Description-Content-Type: text/markdown
|
|
23
23
|
License-File: LICENSE
|
|
24
24
|
License-File: NOTICE
|
|
25
|
-
Requires-Dist: gymnasium<0
|
|
26
|
-
Requires-Dist: numpy>=1.20
|
|
25
|
+
Requires-Dist: gymnasium<1.1.0,>=0.29.1
|
|
26
|
+
Requires-Dist: numpy<2.0,>=1.20
|
|
27
27
|
Requires-Dist: torch>=1.13
|
|
28
28
|
Requires-Dist: cloudpickle
|
|
29
29
|
Requires-Dist: pandas
|
|
@@ -37,7 +37,7 @@ Requires-Dist: mypy; extra == "tests"
|
|
|
37
37
|
Requires-Dist: ruff>=0.3.1; extra == "tests"
|
|
38
38
|
Requires-Dist: black<25,>=24.2.0; extra == "tests"
|
|
39
39
|
Provides-Extra: docs
|
|
40
|
-
Requires-Dist: sphinx<
|
|
40
|
+
Requires-Dist: sphinx<9,>=5; extra == "docs"
|
|
41
41
|
Requires-Dist: sphinx-autobuild; extra == "docs"
|
|
42
42
|
Requires-Dist: sphinx-rtd-theme>=1.3.0; extra == "docs"
|
|
43
43
|
Requires-Dist: sphinxcontrib.spelling; extra == "docs"
|
|
@@ -49,18 +49,8 @@ Requires-Dist: tensorboard>=2.9.1; extra == "extra"
|
|
|
49
49
|
Requires-Dist: psutil; extra == "extra"
|
|
50
50
|
Requires-Dist: tqdm; extra == "extra"
|
|
51
51
|
Requires-Dist: rich; extra == "extra"
|
|
52
|
-
Requires-Dist: shimmy[atari]~=1.3.0; extra == "extra"
|
|
52
|
+
Requires-Dist: ale-py>=0.9.0; extra == "extra"
|
|
53
53
|
Requires-Dist: pillow; extra == "extra"
|
|
54
|
-
Requires-Dist: autorom[accept-rom-license]~=0.6.1; extra == "extra"
|
|
55
|
-
Provides-Extra: extra-no-roms
|
|
56
|
-
Requires-Dist: opencv-python; extra == "extra-no-roms"
|
|
57
|
-
Requires-Dist: pygame; extra == "extra-no-roms"
|
|
58
|
-
Requires-Dist: tensorboard>=2.9.1; extra == "extra-no-roms"
|
|
59
|
-
Requires-Dist: psutil; extra == "extra-no-roms"
|
|
60
|
-
Requires-Dist: tqdm; extra == "extra-no-roms"
|
|
61
|
-
Requires-Dist: rich; extra == "extra-no-roms"
|
|
62
|
-
Requires-Dist: shimmy[atari]~=1.3.0; extra == "extra-no-roms"
|
|
63
|
-
Requires-Dist: pillow; extra == "extra-no-roms"
|
|
64
54
|
|
|
65
55
|
|
|
66
56
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
|
|
2
|
-
|
|
3
1
|
<!-- [](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
|
|
4
|
-

|
|
5
|
-
[](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [](https://
|
|
2
|
+
[](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
|
|
3
|
+
[](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
|
|
6
4
|
[](https://github.com/psf/black)
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
# Stable Baselines3
|
|
10
8
|
|
|
9
|
+
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
|
|
10
|
+
|
|
11
11
|
Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
|
|
12
12
|
|
|
13
13
|
You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
|
|
@@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
|
|
|
22
22
|
**The performance of each algorithm was tested** (see *Results* section in their respective page),
|
|
23
23
|
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
|
|
24
24
|
|
|
25
|
+
We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.
|
|
26
|
+
|
|
25
27
|
|
|
26
28
|
| **Features** | **Stable-Baselines3** |
|
|
27
29
|
| --------------------------- | ----------------------|
|
|
@@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin
|
|
|
41
43
|
|
|
42
44
|
### Planned features
|
|
43
45
|
|
|
44
|
-
|
|
46
|
+
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
|
|
47
|
+
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).
|
|
48
|
+
|
|
49
|
+
While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
|
|
50
|
+
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
|
|
51
|
+
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
|
|
52
|
+
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)
|
|
45
53
|
|
|
46
54
|
## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)
|
|
47
55
|
|
|
@@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/
|
|
|
79
87
|
|
|
80
88
|
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
|
|
81
89
|
|
|
82
|
-
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
|
|
90
|
+
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
|
|
83
91
|
|
|
84
92
|
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
|
|
85
93
|
|
|
@@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster
|
|
|
97
105
|
### Prerequisites
|
|
98
106
|
Stable Baselines3 requires Python 3.8+.
|
|
99
107
|
|
|
100
|
-
#### Windows
|
|
108
|
+
#### Windows
|
|
101
109
|
|
|
102
110
|
To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).
|
|
103
111
|
|
|
104
112
|
|
|
105
113
|
### Install using pip
|
|
106
114
|
Install the Stable Baselines3 package:
|
|
115
|
+
```sh
|
|
116
|
+
pip install 'stable-baselines3[extra]'
|
|
107
117
|
```
|
|
108
|
-
pip install stable-baselines3[extra]
|
|
109
|
-
```
|
|
110
|
-
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).
|
|
111
118
|
|
|
112
119
|
This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
|
|
113
120
|
```sh
|
|
@@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
|
|
|
177
184
|
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
|
|
178
185
|
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
|
|
179
186
|
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
|
187
|
+
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
|
|
180
188
|
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
|
|
181
189
|
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
|
|
182
190
|
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
|
|
@@ -191,8 +199,8 @@ All the following examples can be executed online using Google Colab notebooks:
|
|
|
191
199
|
|
|
192
200
|
<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.
|
|
193
201
|
|
|
194
|
-
Actions `
|
|
195
|
-
* `Box`: A N-dimensional box that
|
|
202
|
+
Actions `gymnasium.spaces`:
|
|
203
|
+
* `Box`: A N-dimensional box that contains every point in the action space.
|
|
196
204
|
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
|
|
197
205
|
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
|
|
198
206
|
* `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
|
|
@@ -218,9 +226,9 @@ To run a single test:
|
|
|
218
226
|
python3 -m pytest -v -k 'test_check_env_dict_action'
|
|
219
227
|
```
|
|
220
228
|
|
|
221
|
-
You can also do a static type check using `
|
|
229
|
+
You can also do a static type check using `mypy`:
|
|
222
230
|
```sh
|
|
223
|
-
pip install
|
|
231
|
+
pip install mypy
|
|
224
232
|
make type
|
|
225
233
|
```
|
|
226
234
|
|
|
@@ -252,6 +260,8 @@ To cite this repository in publications:
|
|
|
252
260
|
}
|
|
253
261
|
```
|
|
254
262
|
|
|
263
|
+
Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).
|
|
264
|
+
|
|
255
265
|
## Maintainers
|
|
256
266
|
|
|
257
267
|
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
|
|
@@ -13,11 +13,10 @@ ignore = ["B028", "RUF013"]
|
|
|
13
13
|
|
|
14
14
|
[tool.ruff.lint.per-file-ignores]
|
|
15
15
|
# Default implementation in abstract methods
|
|
16
|
-
"./stable_baselines3/common/callbacks.py"= ["B027"]
|
|
17
|
-
"./stable_baselines3/common/noise.py"= ["B027"]
|
|
16
|
+
"./stable_baselines3/common/callbacks.py" = ["B027"]
|
|
17
|
+
"./stable_baselines3/common/noise.py" = ["B027"]
|
|
18
18
|
# ClassVar, implicit optional check not needed for tests
|
|
19
|
-
"./tests/*.py"= ["RUF012", "RUF013"]
|
|
20
|
-
|
|
19
|
+
"./tests/*.py" = ["RUF012", "RUF013"]
|
|
21
20
|
|
|
22
21
|
[tool.ruff.lint.mccabe]
|
|
23
22
|
# Unlike Flake8, default to a complexity level of 10.
|
|
@@ -37,31 +36,35 @@ exclude = """(?x)(
|
|
|
37
36
|
|
|
38
37
|
[tool.pytest.ini_options]
|
|
39
38
|
# Deterministic ordering for tests; useful for pytest-xdist.
|
|
40
|
-
env = [
|
|
41
|
-
"PYTHONHASHSEED=0"
|
|
42
|
-
]
|
|
39
|
+
env = ["PYTHONHASHSEED=0"]
|
|
43
40
|
|
|
44
41
|
filterwarnings = [
|
|
45
42
|
# Tensorboard warnings
|
|
46
43
|
"ignore::DeprecationWarning:tensorboard",
|
|
47
44
|
# Gymnasium warnings
|
|
48
45
|
"ignore::UserWarning:gymnasium",
|
|
46
|
+
# tqdm warning about rich being experimental
|
|
47
|
+
"ignore:rich is experimental",
|
|
49
48
|
]
|
|
50
49
|
markers = [
|
|
51
|
-
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
|
|
50
|
+
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
|
|
52
51
|
]
|
|
53
52
|
|
|
54
53
|
[tool.coverage.run]
|
|
55
54
|
disable_warnings = ["couldnt-parse"]
|
|
56
55
|
branch = false
|
|
57
56
|
omit = [
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
57
|
+
"tests/*",
|
|
58
|
+
"setup.py",
|
|
59
|
+
# Require graphical interface
|
|
60
|
+
"stable_baselines3/common/results_plotter.py",
|
|
61
|
+
# Require ffmpeg
|
|
62
|
+
"stable_baselines3/common/vec_env/vec_video_recorder.py",
|
|
64
63
|
]
|
|
65
64
|
|
|
66
65
|
[tool.coverage.report]
|
|
67
|
-
exclude_lines = [
|
|
66
|
+
exclude_lines = [
|
|
67
|
+
"pragma: no cover",
|
|
68
|
+
"raise NotImplementedError()",
|
|
69
|
+
"if typing.TYPE_CHECKING:",
|
|
70
|
+
]
|
|
@@ -70,38 +70,14 @@ model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
|
|
|
70
70
|
|
|
71
71
|
""" # noqa:E501
|
|
72
72
|
|
|
73
|
-
# Atari Games download is sometimes problematic:
|
|
74
|
-
# https://github.com/Farama-Foundation/AutoROM/issues/39
|
|
75
|
-
# That's why we define extra packages without it.
|
|
76
|
-
extra_no_roms = [
|
|
77
|
-
# For render
|
|
78
|
-
"opencv-python",
|
|
79
|
-
"pygame",
|
|
80
|
-
# Tensorboard support
|
|
81
|
-
"tensorboard>=2.9.1",
|
|
82
|
-
# Checking memory taken by replay buffer
|
|
83
|
-
"psutil",
|
|
84
|
-
# For progress bar callback
|
|
85
|
-
"tqdm",
|
|
86
|
-
"rich",
|
|
87
|
-
# For atari games,
|
|
88
|
-
"shimmy[atari]~=1.3.0",
|
|
89
|
-
"pillow",
|
|
90
|
-
]
|
|
91
|
-
|
|
92
|
-
extra_packages = extra_no_roms + [ # noqa: RUF005
|
|
93
|
-
# For atari roms,
|
|
94
|
-
"autorom[accept-rom-license]~=0.6.1",
|
|
95
|
-
]
|
|
96
|
-
|
|
97
73
|
|
|
98
74
|
setup(
|
|
99
75
|
name="stable_baselines3",
|
|
100
76
|
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
|
|
101
77
|
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
|
|
102
78
|
install_requires=[
|
|
103
|
-
"gymnasium>=0.
|
|
104
|
-
"numpy>=1.20",
|
|
79
|
+
"gymnasium>=0.29.1,<1.1.0",
|
|
80
|
+
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
|
|
105
81
|
"torch>=1.13",
|
|
106
82
|
# For saving models
|
|
107
83
|
"cloudpickle",
|
|
@@ -125,7 +101,7 @@ setup(
|
|
|
125
101
|
"black>=24.2.0,<25",
|
|
126
102
|
],
|
|
127
103
|
"docs": [
|
|
128
|
-
"sphinx>=5,<
|
|
104
|
+
"sphinx>=5,<9",
|
|
129
105
|
"sphinx-autobuild",
|
|
130
106
|
"sphinx-rtd-theme>=1.3.0",
|
|
131
107
|
# For spelling
|
|
@@ -133,8 +109,21 @@ setup(
|
|
|
133
109
|
# Copy button for code snippets
|
|
134
110
|
"sphinx_copybutton",
|
|
135
111
|
],
|
|
136
|
-
"extra": extra_packages,
|
|
137
|
-
|
|
112
|
+
"extra": [
|
|
113
|
+
# For render
|
|
114
|
+
"opencv-python",
|
|
115
|
+
"pygame",
|
|
116
|
+
# Tensorboard support
|
|
117
|
+
"tensorboard>=2.9.1",
|
|
118
|
+
# Checking memory taken by replay buffer
|
|
119
|
+
"psutil",
|
|
120
|
+
# For progress bar callback
|
|
121
|
+
"tqdm",
|
|
122
|
+
"rich",
|
|
123
|
+
# For atari games,
|
|
124
|
+
"ale-py>=0.9.0",
|
|
125
|
+
"pillow",
|
|
126
|
+
],
|
|
138
127
|
},
|
|
139
128
|
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
|
|
140
129
|
author="Antonin Raffin",
|
|
@@ -48,7 +48,7 @@ def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
|
|
|
48
48
|
"""If env is a string, make the environment; otherwise, return env.
|
|
49
49
|
|
|
50
50
|
:param env: The environment to learn from.
|
|
51
|
-
:param verbose: Verbosity level: 0 for no output, 1 for indicating if
|
|
51
|
+
:param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
|
|
52
52
|
:return A Gym (vector) environment.
|
|
53
53
|
"""
|
|
54
54
|
if isinstance(env, str):
|
|
@@ -592,7 +592,7 @@ class BaseAlgorithm(ABC):
|
|
|
592
592
|
if isinstance(load_path_or_dict, dict):
|
|
593
593
|
params = load_path_or_dict
|
|
594
594
|
else:
|
|
595
|
-
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)
|
|
595
|
+
_, params, _ = load_from_zip_file(load_path_or_dict, device=device, load_data=False)
|
|
596
596
|
|
|
597
597
|
# Keep track which objects were updated.
|
|
598
598
|
# `_get_torch_save_params` returns [params, other_pytorch_variables].
|
|
@@ -692,10 +692,9 @@ class BaseAlgorithm(ABC):
|
|
|
692
692
|
if "device" in data["policy_kwargs"]:
|
|
693
693
|
del data["policy_kwargs"]["device"]
|
|
694
694
|
# backward compatibility, convert to new format
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
|
|
695
|
+
saved_net_arch = data["policy_kwargs"].get("net_arch")
|
|
696
|
+
if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
|
|
697
|
+
data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
|
|
699
698
|
|
|
700
699
|
if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
|
|
701
700
|
raise ValueError(
|
|
@@ -743,13 +742,13 @@ class BaseAlgorithm(ABC):
|
|
|
743
742
|
# put state_dicts back in place
|
|
744
743
|
model.set_parameters(params, exact_match=True, device=device)
|
|
745
744
|
except RuntimeError as e:
|
|
746
|
-
# Patch to load
|
|
745
|
+
# Patch to load policies saved using SB3 < 1.7.0
|
|
747
746
|
# the error is probably due to old policy being loaded
|
|
748
747
|
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
|
|
749
748
|
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
|
|
750
749
|
model.set_parameters(params, exact_match=False, device=device)
|
|
751
750
|
warnings.warn(
|
|
752
|
-
"You are probably loading a model saved with SB3 < 1.7.0, "
|
|
751
|
+
"You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
|
|
753
752
|
"we deactivated exact_match so you can save the model "
|
|
754
753
|
"again to avoid issues in the future "
|
|
755
754
|
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
|
|
@@ -758,6 +757,29 @@ class BaseAlgorithm(ABC):
|
|
|
758
757
|
)
|
|
759
758
|
else:
|
|
760
759
|
raise e
|
|
760
|
+
except ValueError as e:
|
|
761
|
+
# Patch to load DQN policies saved using SB3 < 2.4.0
|
|
762
|
+
# The target network params are no longer in the optimizer
|
|
763
|
+
# See https://github.com/DLR-RM/stable-baselines3/pull/1963
|
|
764
|
+
saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index]
|
|
765
|
+
n_params_saved = len(saved_optim_params)
|
|
766
|
+
n_params = len(model.policy.optimizer.param_groups[0]["params"])
|
|
767
|
+
if n_params_saved == 2 * n_params:
|
|
768
|
+
# Truncate to include only online network params
|
|
769
|
+
params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index]
|
|
770
|
+
|
|
771
|
+
model.set_parameters(params, exact_match=True, device=device)
|
|
772
|
+
warnings.warn(
|
|
773
|
+
"You are probably loading a DQN model saved with SB3 < 2.4.0, "
|
|
774
|
+
"we truncated the optimizer state so you can save the model "
|
|
775
|
+
"again to avoid issues in the future "
|
|
776
|
+
"(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
|
|
777
|
+
f"Original error: {e} \n"
|
|
778
|
+
"Note: the model should still work fine, this only a warning."
|
|
779
|
+
)
|
|
780
|
+
else:
|
|
781
|
+
raise e
|
|
782
|
+
|
|
761
783
|
# put other pytorch variables back in place
|
|
762
784
|
if pytorch_variables is not None:
|
|
763
785
|
for name in pytorch_variables:
|
|
@@ -419,12 +419,12 @@ class RolloutBuffer(BaseBuffer):
|
|
|
419
419
|
:param dones: if the last step was a terminal step (one bool for each env).
|
|
420
420
|
"""
|
|
421
421
|
# Convert to numpy
|
|
422
|
-
last_values = last_values.clone().cpu().numpy().flatten()
|
|
422
|
+
last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment]
|
|
423
423
|
|
|
424
424
|
last_gae_lam = 0
|
|
425
425
|
for step in reversed(range(self.buffer_size)):
|
|
426
426
|
if step == self.buffer_size - 1:
|
|
427
|
-
next_non_terminal = 1.0 - dones
|
|
427
|
+
next_non_terminal = 1.0 - dones.astype(np.float32)
|
|
428
428
|
next_values = last_values
|
|
429
429
|
else:
|
|
430
430
|
next_non_terminal = 1.0 - self.episode_starts[step + 1]
|
|
@@ -204,6 +204,10 @@ class CallbackList(BaseCallback):
|
|
|
204
204
|
for callback in self.callbacks:
|
|
205
205
|
callback.init_callback(self.model)
|
|
206
206
|
|
|
207
|
+
# Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
|
|
208
|
+
# pass through the parent callback to all children
|
|
209
|
+
callback.parent = self.parent
|
|
210
|
+
|
|
207
211
|
def _on_training_start(self) -> None:
|
|
208
212
|
for callback in self.callbacks:
|
|
209
213
|
callback.on_training_start(self.locals, self.globals)
|
|
@@ -606,7 +610,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
|
|
|
606
610
|
self.n_episodes = 0
|
|
607
611
|
|
|
608
612
|
def _init_callback(self) -> None:
|
|
609
|
-
# At start set total max according to number of
|
|
613
|
+
# At start set total max according to number of environments
|
|
610
614
|
self._total_max_episodes = self.max_episodes * self.training_env.num_envs
|
|
611
615
|
|
|
612
616
|
def _on_step(self) -> bool:
|
|
@@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|
|
98
98
|
"is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
|
|
99
99
|
)
|
|
100
100
|
|
|
101
|
+
if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
|
|
102
|
+
warnings.warn(
|
|
103
|
+
f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
|
|
104
|
+
"which is currently not supported by Stable-Baselines3. "
|
|
105
|
+
"Please convert it to a 1D array using a wrapper: "
|
|
106
|
+
"https://github.com/DLR-RM/stable-baselines3/issues/1836."
|
|
107
|
+
)
|
|
108
|
+
|
|
101
109
|
if isinstance(observation_space, spaces.Tuple):
|
|
102
110
|
warnings.warn(
|
|
103
111
|
"The observation space is a Tuple, "
|
|
@@ -397,7 +405,7 @@ def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover
|
|
|
397
405
|
"you may have trouble when calling `.render()`"
|
|
398
406
|
)
|
|
399
407
|
|
|
400
|
-
# Only check
|
|
408
|
+
# Only check current render mode
|
|
401
409
|
if env.render_mode:
|
|
402
410
|
env.render()
|
|
403
411
|
env.close()
|
|
@@ -412,8 +412,9 @@ class TensorBoardOutputFormat(KVWriter):
|
|
|
412
412
|
else:
|
|
413
413
|
self.writer.add_scalar(key, value, step)
|
|
414
414
|
|
|
415
|
-
if isinstance(value, th.Tensor):
|
|
416
|
-
|
|
415
|
+
if isinstance(value, (th.Tensor, np.ndarray)):
|
|
416
|
+
# Convert to Torch so it works with numpy<1.24 and torch<2.0
|
|
417
|
+
self.writer.add_histogram(key, th.as_tensor(value), step)
|
|
417
418
|
|
|
418
419
|
if isinstance(value, Video):
|
|
419
420
|
self.writer.add_video(key, value.frames, step, value.fps)
|
|
@@ -189,7 +189,7 @@ class ResultsWriter:
|
|
|
189
189
|
filename = os.path.realpath(filename)
|
|
190
190
|
# Create (if any) missing filename directories
|
|
191
191
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
|
192
|
-
# Append mode when not
|
|
192
|
+
# Append mode when not overriding existing file
|
|
193
193
|
mode = "w" if override_existing else "a"
|
|
194
194
|
# Prevent newline issue on Windows, see GH issue #692
|
|
195
195
|
self.file_handler = open(filename, f"{mode}t", newline="\n")
|
{stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/on_policy_algorithm.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
import time
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
@@ -135,6 +136,28 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|
|
135
136
|
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
|
|
136
137
|
)
|
|
137
138
|
self.policy = self.policy.to(self.device)
|
|
139
|
+
# Warn when not using CPU with MlpPolicy
|
|
140
|
+
self._maybe_recommend_cpu()
|
|
141
|
+
|
|
142
|
+
def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
|
|
143
|
+
"""
|
|
144
|
+
Recommend to use CPU only when using A2C/PPO with MlpPolicy.
|
|
145
|
+
|
|
146
|
+
:param: The name of the class for the default MlpPolicy.
|
|
147
|
+
"""
|
|
148
|
+
policy_class_name = self.policy_class.__name__
|
|
149
|
+
if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
|
|
150
|
+
warnings.warn(
|
|
151
|
+
f"You are trying to run {self.__class__.__name__} on the GPU, "
|
|
152
|
+
"but it is primarily intended to run on the CPU when not using a CNN policy "
|
|
153
|
+
f"(you are using {policy_class_name} which should be a MlpPolicy). "
|
|
154
|
+
"See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
|
|
155
|
+
"for more info. "
|
|
156
|
+
"You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
|
|
157
|
+
"Note: The model will train, but the GPU utilization will be poor and "
|
|
158
|
+
"the training might take longer than on CPU.",
|
|
159
|
+
UserWarning,
|
|
160
|
+
)
|
|
138
161
|
|
|
139
162
|
def collect_rollouts(
|
|
140
163
|
self,
|
|
@@ -208,7 +231,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|
|
208
231
|
# Reshape in case of discrete action
|
|
209
232
|
actions = actions.reshape(-1, 1)
|
|
210
233
|
|
|
211
|
-
# Handle timeout by
|
|
234
|
+
# Handle timeout by bootstrapping with value function
|
|
212
235
|
# see GitHub issue #633
|
|
213
236
|
for idx, done in enumerate(dones):
|
|
214
237
|
if (
|
|
@@ -367,7 +367,7 @@ class BasePolicy(BaseModel, ABC):
|
|
|
367
367
|
with th.no_grad():
|
|
368
368
|
actions = self._predict(obs_tensor, deterministic=deterministic)
|
|
369
369
|
# Convert to numpy, and reshape to the original action shape
|
|
370
|
-
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
|
|
370
|
+
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment]
|
|
371
371
|
|
|
372
372
|
if isinstance(self.action_space, spaces.Box):
|
|
373
373
|
if self.squash_output:
|
|
@@ -922,7 +922,7 @@ class ContinuousCritic(BaseModel):
|
|
|
922
922
|
By default, it creates two critic networks used to reduce overestimation
|
|
923
923
|
thanks to clipped Q-learning (cf TD3 paper).
|
|
924
924
|
|
|
925
|
-
:param observation_space:
|
|
925
|
+
:param observation_space: Observation space
|
|
926
926
|
:param action_space: Action space
|
|
927
927
|
:param net_arch: Network architecture
|
|
928
928
|
:param features_extractor: Network to extract features
|
{stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/results_plotter.py
RENAMED
|
@@ -46,7 +46,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
|
|
|
46
46
|
|
|
47
47
|
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
|
|
48
48
|
"""
|
|
49
|
-
Decompose a data frame variable to x
|
|
49
|
+
Decompose a data frame variable to x and ys
|
|
50
50
|
|
|
51
51
|
:param data_frame: the input data
|
|
52
52
|
:param x_axis: the axis for the x and y output
|
{stable_baselines3-2.3.2 → stable_baselines3-2.4.0}/stable_baselines3/common/running_mean_std.py
RENAMED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
class RunningMeanStd:
|
|
7
7
|
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
|
|
8
8
|
"""
|
|
9
|
-
|
|
9
|
+
Calculates the running mean and std of a data stream
|
|
10
10
|
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
|
11
11
|
|
|
12
12
|
:param epsilon: helps with arithmetic issues
|