rlgym-learn-algos 0.2.1__cp38-cp38-musllinux_1_2_i686.whl → 0.2.2__cp38-cp38-musllinux_1_2_i686.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
--- a/rlgym_learn_algos/ppo/experience_buffer.py
+++ b/rlgym_learn_algos/ppo/experience_buffer.py
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from pydantic import BaseModel, Field, model_validator
 from rlgym.api import ActionType, AgentID, ObsType, RewardType
+
 from rlgym_learn_algos.util.torch_functions import get_device
 from rlgym_learn_algos.util.torch_pydantic import PydanticTorchDevice
 
@@ -24,6 +25,7 @@ EXPERIENCE_BUFFER_FILE = "experience_buffer.pkl"
 class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
     max_size: int = 100000
     device: PydanticTorchDevice = "auto"
+    save_experience_buffer_in_checkpoint: bool = True
     trajectory_processor_config: Dict[str, Any] = Field(default_factory=dict)
 
     @model_validator(mode="before")
@@ -40,8 +42,9 @@ class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
                 data["trajectory_processor_config"] = data[
                     "trajectory_processor_config"
                 ].model_dump()
-            if "device" not in data or data["device"] == "auto":
-                data["device"] = get_device("auto")
+            if "device" not in data:
+                data["device"] = "auto"
+            data["device"] = get_device(data["device"])
         return data
 
 
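
The before-validator in 0.2.1 only resolved the literal "auto" sentinel, leaving any other configured device string untouched; in 0.2.2 every configured value is routed through get_device. A minimal sketch of the difference, assuming ExperienceBufferConfigModel is importable from the module path listed in the RECORD below (the resolved value depends on local hardware):

    from rlgym_learn_algos.ppo.experience_buffer import ExperienceBufferConfigModel

    # 0.2.1 resolved only "auto"; 0.2.2 also resolves aliases such as "gpu".
    config = ExperienceBufferConfigModel(device="gpu")
    print(config.device)  # "cuda:0" on a CUDA machine
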
@@ -165,21 +168,22 @@ class ExperienceBuffer(
 
     def save_checkpoint(self, folder_path):
         os.makedirs(folder_path, exist_ok=True)
-        with open(
-            os.path.join(folder_path, EXPERIENCE_BUFFER_FILE),
-            "wb",
-        ) as f:
-            pickle.dump(
-                {
-                    "agent_ids": self.agent_ids,
-                    "observations": self.observations,
-                    "actions": self.actions,
-                    "log_probs": self.log_probs,
-                    "values": self.values,
-                    "advantages": self.advantages,
-                },
-                f,
-            )
+        if self.config.experience_buffer_config.save_experience_buffer_in_checkpoint:
+            with open(
+                os.path.join(folder_path, EXPERIENCE_BUFFER_FILE),
+                "wb",
+            ) as f:
+                pickle.dump(
+                    {
+                        "agent_ids": self.agent_ids,
+                        "observations": self.observations,
+                        "actions": self.actions,
+                        "log_probs": self.log_probs,
+                        "values": self.values,
+                        "advantages": self.advantages,
+                    },
+                    f,
+                )
         self.trajectory_processor.save_checkpoint(folder_path)
 
     # TODO: update docs
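
The new save_experience_buffer_in_checkpoint flag makes pickling the buffer contents optional, which keeps checkpoints small when the buffer does not need to survive a restart. A hypothetical configuration fragment (the key names mirror ExperienceBufferConfigModel above; how this dict reaches the buffer depends on the surrounding rlgym-learn setup):

    experience_buffer_config = {
        "max_size": 100000,
        "device": "auto",
        "save_experience_buffer_in_checkpoint": False,  # new in 0.2.2, defaults to True
    }
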
--- a/rlgym_learn_algos/ppo/ppo_learner.py
+++ b/rlgym_learn_algos/ppo/ppo_learner.py
@@ -39,17 +39,25 @@ class PPOLearnerConfigModel(BaseModel, extra="forbid"):
     clip_range: float = 0.2
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
+    advantage_normalization: bool = True
     device: PydanticTorchDevice = "auto"
+    cudnn_benchmark_mode: bool = True
 
     @model_validator(mode="before")
     @classmethod
     def set_device(cls, data):
-        if isinstance(data, dict) and (
-            "device" not in data or data["device"] == "auto"
-        ):
-            data["device"] = get_device("auto")
+        if isinstance(data, dict):
+            if "device" not in data:
+                data["device"] = "auto"
+            data["device"] = get_device(data["device"])
         return data
 
+    @model_validator(mode="after")
+    def validate_cudnn_benchmark(self):
+        if self.device.type != "cuda":
+            self.cudnn_benchmark_mode = False
+        return self
+
 
 @dataclass
 class DerivedPPOLearnerConfig:
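
Because validate_cudnn_benchmark runs in "after" mode, it sees the device field already resolved to a torch.device, so a cuDNN benchmarking request on a non-CUDA device is silently downgraded rather than rejected. A minimal sketch, assuming PPOLearnerConfigModel is importable from the module path listed in the RECORD below:

    from rlgym_learn_algos.ppo.ppo_learner import PPOLearnerConfigModel

    config = PPOLearnerConfigModel(device="cpu", cudnn_benchmark_mode=True)
    print(config.cudnn_benchmark_mode)  # False: forced off whenever device.type != "cuda"
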
@@ -107,6 +115,12 @@ class PPOLearner(
     def load(self, config: DerivedPPOLearnerConfig):
         self.config = config
 
+        if (
+            config.learner_config.cudnn_benchmark_mode
+            and config.learner_config.device.type == "cuda"
+        ):
+            torch.backends.cudnn.benchmark = True
+
         self.actor = self.actor_factory(
             config.obs_space, config.action_space, config.learner_config.device
         )
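
torch.backends.cudnn.benchmark is a process-wide PyTorch flag that lets cuDNN autotune convolution algorithms for the input shapes it encounters; it pays off when shapes are static, as with fixed-size PPO minibatches, and can hurt when they vary. A standalone illustration of what the guarded block above toggles:

    import torch

    # Process-wide switch; 0.2.2 sets it only when the learner device is CUDA
    # and cudnn_benchmark_mode is enabled, instead of inside get_device.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
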
@@ -292,6 +306,10 @@ class PPOLearner(
                 advantages = batch_advantages[start:stop].to(
                     self.config.learner_config.device
                 )
+                if self.config.learner_config.advantage_normalization:
+                    advantages = (advantages - torch.mean(advantages)) / (
+                        torch.std(advantages) + 1e-8
+                    )
                 old_probs = batch_old_probs[start:stop].to(
                     self.config.learner_config.device
                 )
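
Per-minibatch advantage normalization is a common PPO stabilization technique: advantages are shifted to zero mean and scaled to unit standard deviation, with the 1e-8 term guarding against division by zero. A standalone illustration of the arithmetic applied above:

    import torch

    advantages = torch.tensor([1.0, 2.0, 3.0, 4.0])
    normalized = (advantages - torch.mean(advantages)) / (torch.std(advantages) + 1e-8)
    print(normalized.mean().item(), normalized.std().item())  # ~0.0, ~1.0
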
--- a/rlgym_learn_algos/util/torch_functions.py
+++ b/rlgym_learn_algos/util/torch_functions.py
@@ -13,7 +13,6 @@ import torch.nn as nn
 
 def get_device(device: str):
     if device in ["auto", "gpu"] and torch.cuda.is_available():
-        torch.backends.cudnn.benchmark = True
         return "cuda:0"
     elif device == "auto" and not torch.cuda.is_available():
         return "cpu"
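
With this removal, get_device no longer mutates global cuDNN state as a side effect of resolving a device string; enabling benchmark mode is now the explicit job of PPOLearner.load, shown earlier. A hedged sketch of the function's observable behavior after the change (output depends on local hardware):

    from rlgym_learn_algos.util.torch_functions import get_device

    # Pure string-to-string mapping, no torch.backends.* side effects.
    print(get_device("auto"))  # "cuda:0" if CUDA is available, otherwise "cpu"
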
--- a/rlgym_learn_algos/util/torch_pydantic.py
+++ b/rlgym_learn_algos/util/torch_pydantic.py
@@ -42,7 +42,7 @@ device_str_regex = (
             "privateuseone",
         ]
     )
-    + ")(:\d+)"
+    + ")(:\d+)?"
 )
 
 
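
Making the trailing index group optional means bare device strings such as "cuda" now pass validation, not only indexed forms such as "cuda:0". A reduced sketch of the pattern change, using a hypothetical two-alternative device list in place of the full one:

    import re

    old_pattern = re.compile(r"(cuda|cpu)(:\d+)")   # 0.2.1: index required
    new_pattern = re.compile(r"(cuda|cpu)(:\d+)?")  # 0.2.2: index optional
    print(bool(old_pattern.fullmatch("cuda")))    # False
    print(bool(new_pattern.fullmatch("cuda")))    # True
    print(bool(new_pattern.fullmatch("cuda:0")))  # True
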
--- a/rlgym_learn_algos-0.2.1.dist-info/METADATA
+++ b/rlgym_learn_algos-0.2.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rlgym-learn-algos
-Version: 0.2.1
+Version: 0.2.2
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Requires-Dist: pydantic>=2.8.2
--- a/rlgym_learn_algos-0.2.1.dist-info/RECORD
+++ b/rlgym_learn_algos-0.2.2.dist-info/RECORD
@@ -1,6 +1,6 @@
-rlgym_learn_algos-0.2.1.dist-info/METADATA,sha256=u9tM8KxVpm3nQ86a01tlDy5csGAeTmCMDZdccXgrUrg,2403
-rlgym_learn_algos-0.2.1.dist-info/WHEEL,sha256=X1CAAG5PtBwuCUbvZLy-66J-Jfin3nupluhqSslkPtI,103
-rlgym_learn_algos-0.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+rlgym_learn_algos-0.2.2.dist-info/METADATA,sha256=9m49XNjxVc1H0Hj27PV8BszTk9DnOBNBNIjdxb0uwcQ,2403
+rlgym_learn_algos-0.2.2.dist-info/WHEEL,sha256=zrTr349f3UPOfsn1HRNexsaeA05ZzRsCL0d94NnDDZc,103
+rlgym_learn_algos-0.2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
 rlgym_learn_algos.libs/libgcc_s-b5472b99.so.1,sha256=wh8CpjXz9IccAyeERcB7YDEx7NH2jF-PykwOyYNeRRI,453841
 rlgym_learn_algos/__init__.py,sha256=dZeTgNro6qG1Hu0l0UBhgHOYiyeCwPWndC84dJAp__U,203
 rlgym_learn_algos/conversion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,18 +16,18 @@ rlgym_learn_algos/ppo/continuous_actor.py,sha256=A4FQ0lKqlB47AeSrDdrPXMvNKXhl5to
 rlgym_learn_algos/ppo/critic.py,sha256=XPleWDO8uM25zlzptWDvZQpUpKvib5kRs9JpmWTVPuY,669
 rlgym_learn_algos/ppo/discrete_actor.py,sha256=TZC7b7ss16giobPC1oz-maOSDX-SrNBUzS1wIV2Rzgw,2547
 rlgym_learn_algos/ppo/env_trajectories.py,sha256=PaO6dmpNkQ3yDLaHIRc0ipn45t5zAjE5U1D_N-LQtgY,3684
-rlgym_learn_algos/ppo/experience_buffer.py,sha256=f_baTo18JKBQjTQoO0FJoCOWlZFV8vJgMXqevg-TSi8,10552
+rlgym_learn_algos/ppo/experience_buffer.py,sha256=4wWSfq1tobXv7lmBbkM5sbTVuPJxrdAbxi5rNHc04g4,10769
 rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=6AOGQjDn_dHLS9bmxJW_cGEjBUbe8u5VWS0LVlpIdmY,4617
 rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=gv5kxvvPnK7SyQIAq6MbOFILIMdPlzoLZwM8TRmtNWw,5302
 rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=cq7qbK0mcLDXRzA6-pKW0OC50X52XhT5himcOTD6Ei4,6657
 rlgym_learn_algos/ppo/multi_discrete_actor.py,sha256=Mik0X79dUy2ZRIMol4RMTZE9qzsOk6f_6bDaOl5ghxs,3039
 rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=iVmCvN7H1IcKX7VrZnMParQ148EnPqs9yZ9CCgtdsq8,24524
-rlgym_learn_algos/ppo/ppo_learner.py,sha256=oRDBQszbgFvoVmGjRqpklQWTOI2NazOLzhDU_c9-SAU,14827
+rlgym_learn_algos/ppo/ppo_learner.py,sha256=utEWkikXCpC6Xc1D3asohO0HsIaq3tLyoTlb7fXLOw4,15522
 rlgym_learn_algos/ppo/ppo_metrics_logger.py,sha256=iUyUc2GPwDIIjZeJPZWxoeRrzUWV_qLOac0vApQBkp0,2803
 rlgym_learn_algos/ppo/trajectory.py,sha256=_xyS9ueU6iVvqMUpFr-kb42wEHHZy4zCse7_r660n5E,690
 rlgym_learn_algos/ppo/trajectory_processor.py,sha256=3XRsXXexHWp6UV5nAeBLYvWqvQ9EbNHSN3Yooi4cezo,2031
 rlgym_learn_algos/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rlgym_learn_algos/rlgym_learn_algos.cpython-38-i386-linux-gnu.so,sha256=VL0hdEQVoAI5zvHCFLRz6jYIARLjUJwoGNHQwhBgHyA,719349
+rlgym_learn_algos/rlgym_learn_algos.cpython-38-i386-linux-gnu.so,sha256=DseEGF0EbwC1IS58CCtlI3cmsPJhtH-Cob_e1wQpHX8,707061
 rlgym_learn_algos/rlgym_learn_algos.pyi,sha256=B9Kt9uK8xCqASRxWvzLdV501TSCMO4vTNqvZ0MhOHyo,1164
 rlgym_learn_algos/stateful_functions/__init__.py,sha256=OAVy6cQIS85Utyp18jjHgdmascX_8nkwk3A0OpFJxT4,230
 rlgym_learn_algos/stateful_functions/batch_reward_type_numpy_converter.py,sha256=a3q2l5SIgDI36ImF_kYoa684pghnFnlV2vGYvV2zcV0,743
@@ -35,6 +35,6 @@ rlgym_learn_algos/stateful_functions/numpy_obs_standardizer.py,sha256=Xa_fuJCSGu
 rlgym_learn_algos/stateful_functions/obs_standardizer.py,sha256=m2nw1JUg2MKYthn6tWrv2HYIuQ-GfUm48RX9s99jXF4,589
 rlgym_learn_algos/util/__init__.py,sha256=hq7M00Q7zAfyQmIGmXOif0vI40aj_FQ5SqI5dnuGvb0,46
 rlgym_learn_algos/util/running_stats.py,sha256=KtzdKKT75-5ZC58JRqaDXk6sBqa3ZSjQQZrRajAw3Yk,4339
-rlgym_learn_algos/util/torch_functions.py,sha256=6esZL8FeVwWWQWvwLEhkh_B0WqfIWzIWkCDbuhmRlnk,3366
-rlgym_learn_algos/util/torch_pydantic.py,sha256=zXllJoV8HgqJxguPKJ4Y3DIWEwDeJlBW9CIps-yxM44,3412
-rlgym_learn_algos-0.2.1.dist-info/RECORD,,
+rlgym_learn_algos/util/torch_functions.py,sha256=ImgDw4I3ZixGDi17YRkW6UbaiaQTbvOCUCS7N0QVSsU,3320
+rlgym_learn_algos/util/torch_pydantic.py,sha256=khPGA6kWh4_WHoploDkl_SCIGX8SkKkFT40RE06PImc,3413
+rlgym_learn_algos-0.2.2.dist-info/RECORD,,
--- a/rlgym_learn_algos-0.2.1.dist-info/WHEEL
+++ b/rlgym_learn_algos-0.2.2.dist-info/WHEEL
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: maturin (1.8.6)
+Generator: maturin (1.9.0)
 Root-Is-Purelib: false
 Tag: cp38-cp38-musllinux_1_2_i686