stable-baselines3 2.3.0a5__tar.gz → 2.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. {stable_baselines3-2.3.0a5/stable_baselines3.egg-info → stable_baselines3-2.3.2}/PKG-INFO +1 -1
  2. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/README.md +1 -1
  3. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/save_util.py +2 -1
  4. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/utils.py +3 -1
  5. stable_baselines3-2.3.2/stable_baselines3/version.txt +1 -0
  6. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2/stable_baselines3.egg-info}/PKG-INFO +1 -1
  7. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_save_load.py +14 -0
  8. stable_baselines3-2.3.0a5/stable_baselines3/version.txt +0 -1
  9. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/LICENSE +0 -0
  10. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/NOTICE +0 -0
  11. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/pyproject.toml +0 -0
  12. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/setup.cfg +0 -0
  13. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/setup.py +0 -0
  14. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/__init__.py +0 -0
  15. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/a2c/__init__.py +0 -0
  16. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/a2c/a2c.py +0 -0
  17. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/a2c/policies.py +0 -0
  18. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/__init__.py +0 -0
  19. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/atari_wrappers.py +0 -0
  20. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/base_class.py +0 -0
  21. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/buffers.py +0 -0
  22. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/callbacks.py +0 -0
  23. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/distributions.py +0 -0
  24. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/env_checker.py +0 -0
  25. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/env_util.py +0 -0
  26. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/envs/__init__.py +0 -0
  27. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/envs/bit_flipping_env.py +0 -0
  28. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/envs/identity_env.py +0 -0
  29. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/envs/multi_input_envs.py +0 -0
  30. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/evaluation.py +0 -0
  31. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/logger.py +0 -0
  32. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/monitor.py +0 -0
  33. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/noise.py +0 -0
  34. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/off_policy_algorithm.py +0 -0
  35. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/on_policy_algorithm.py +0 -0
  36. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/policies.py +0 -0
  37. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/preprocessing.py +0 -0
  38. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/results_plotter.py +0 -0
  39. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/running_mean_std.py +0 -0
  40. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/sb2_compat/__init__.py +0 -0
  41. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +0 -0
  42. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/torch_layers.py +0 -0
  43. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/type_aliases.py +0 -0
  44. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/__init__.py +0 -0
  45. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/base_vec_env.py +0 -0
  46. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/dummy_vec_env.py +0 -0
  47. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/patch_gym.py +0 -0
  48. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/stacked_observations.py +0 -0
  49. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/subproc_vec_env.py +0 -0
  50. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/util.py +0 -0
  51. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_check_nan.py +0 -0
  52. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +0 -0
  53. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_frame_stack.py +0 -0
  54. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_monitor.py +0 -0
  55. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_normalize.py +0 -0
  56. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_transpose.py +0 -0
  57. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/common/vec_env/vec_video_recorder.py +0 -0
  58. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/ddpg/__init__.py +0 -0
  59. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/ddpg/ddpg.py +0 -0
  60. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/ddpg/policies.py +0 -0
  61. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/dqn/__init__.py +0 -0
  62. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/dqn/dqn.py +0 -0
  63. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/dqn/policies.py +0 -0
  64. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/her/__init__.py +0 -0
  65. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/her/goal_selection_strategy.py +0 -0
  66. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/her/her_replay_buffer.py +0 -0
  67. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/ppo/__init__.py +0 -0
  68. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/ppo/policies.py +0 -0
  69. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/ppo/ppo.py +0 -0
  70. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/py.typed +0 -0
  71. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/sac/__init__.py +0 -0
  72. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/sac/policies.py +0 -0
  73. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/sac/sac.py +0 -0
  74. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/td3/__init__.py +0 -0
  75. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/td3/policies.py +0 -0
  76. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3/td3/td3.py +0 -0
  77. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3.egg-info/SOURCES.txt +0 -0
  78. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3.egg-info/dependency_links.txt +0 -0
  79. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3.egg-info/requires.txt +0 -0
  80. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/stable_baselines3.egg-info/top_level.txt +0 -0
  81. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_buffers.py +0 -0
  82. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_callbacks.py +0 -0
  83. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_cnn.py +0 -0
  84. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_custom_policy.py +0 -0
  85. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_deterministic.py +0 -0
  86. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_dict_env.py +0 -0
  87. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_distributions.py +0 -0
  88. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_env_checker.py +0 -0
  89. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_envs.py +0 -0
  90. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_gae.py +0 -0
  91. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_her.py +0 -0
  92. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_identity.py +0 -0
  93. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_logger.py +0 -0
  94. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_monitor.py +0 -0
  95. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_predict.py +0 -0
  96. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_preprocessing.py +0 -0
  97. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_run.py +0 -0
  98. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_sde.py +0 -0
  99. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_spaces.py +0 -0
  100. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_tensorboard.py +0 -0
  101. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_train_eval_mode.py +0 -0
  102. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_utils.py +0 -0
  103. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_vec_check_nan.py +0 -0
  104. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_vec_envs.py +0 -0
  105. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_vec_extract_dict_obs.py +0 -0
  106. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_vec_monitor.py +0 -0
  107. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_vec_normalize.py +0 -0
  108. {stable_baselines3-2.3.0a5 → stable_baselines3-2.3.2}/tests/test_vec_stacked_obs.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: stable_baselines3
3
- Version: 2.3.0a5
3
+ Version: 2.3.2
4
4
  Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
5
5
  Home-page: https://github.com/DLR-RM/stable-baselines3
6
6
  Author: Antonin Raffin
@@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/
85
85
 
86
86
  ## Stable-Baselines Jax (SBX)
87
87
 
88
- [Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
88
+ [Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.
89
89
 
90
90
  It provides a minimal number of features compared to SB3 but can be much faster (up to 20x faster!): https://twitter.com/araffin2/status/1590714558628253698
91
91
 
@@ -447,7 +447,8 @@ def load_from_zip_file(
447
447
  file_content.seek(0)
448
448
  # Load the parameters with the right ``map_location``.
449
449
  # Remove ".pth" ending with splitext
450
- th_object = th.load(file_content, map_location=device, weights_only=True)
450
+ # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
451
+ th_object = th.load(file_content, map_location=device, weights_only=False)
451
452
  # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
452
453
  if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
453
454
  # PyTorch variables (not state_dicts)
@@ -92,7 +92,9 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
92
92
  value_schedule = constant_fn(float(value_schedule))
93
93
  else:
94
94
  assert callable(value_schedule)
95
- return value_schedule
95
+ # Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900
96
+ # Some types have odd behaviors when part of a Schedule, like numpy floats
97
+ return lambda progress_remaining: float(value_schedule(progress_remaining))
96
98
 
97
99
 
98
100
  def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: stable_baselines3
3
- Version: 2.3.0a5
3
+ Version: 2.3.2
4
4
  Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
5
5
  Home-page: https://github.com/DLR-RM/stable-baselines3
6
6
  Author: Antonin Raffin
@@ -783,3 +783,17 @@ def test_no_resource_warning(tmp_path):
783
783
  fp.seek(0)
784
784
  model.load_replay_buffer(fp)
785
785
  assert not fp.closed
786
+
787
+
788
+ def test_cast_lr_schedule(tmp_path):
789
+ # See GH#1900
790
+ model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda t: t * np.sin(1.0))
791
+ # Note: for recent versions of numpy, np.float64 is a subclass of float
792
+ # so we need to use type here
793
+ # assert isinstance(model.lr_schedule(1.0), float)
794
+ assert type(model.lr_schedule(1.0)) is float # noqa: E721
795
+ assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))
796
+ model.save(tmp_path / "ppo.zip")
797
+ model = PPO.load(tmp_path / "ppo.zip")
798
+ assert type(model.lr_schedule(1.0)) is float # noqa: E721
799
+ assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))
@@ -1 +0,0 @@
1
- 2.3.0a5