rlgym_learn_algos-0.2.1-cp311-cp311-win32.whl → rlgym_learn_algos-0.2.2-cp311-cp311-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their respective public registries.
```diff
--- a/rlgym_learn_algos/ppo/experience_buffer.py
+++ b/rlgym_learn_algos/ppo/experience_buffer.py
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from pydantic import BaseModel, Field, model_validator
 from rlgym.api import ActionType, AgentID, ObsType, RewardType
+
 from rlgym_learn_algos.util.torch_functions import get_device
 from rlgym_learn_algos.util.torch_pydantic import PydanticTorchDevice
 
```
```diff
@@ -24,6 +25,7 @@ EXPERIENCE_BUFFER_FILE = "experience_buffer.pkl"
 class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
     max_size: int = 100000
     device: PydanticTorchDevice = "auto"
+    save_experience_buffer_in_checkpoint: bool = True
     trajectory_processor_config: Dict[str, Any] = Field(default_factory=dict)
 
     @model_validator(mode="before")
```
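For orientation, a minimal usage sketch of the new flag. The import path and constructor call are assumptions; the field names and defaults come from the diff above.

```python
from rlgym_learn_algos.ppo.experience_buffer import ExperienceBufferConfigModel

# Hypothetical usage sketch: skip pickling the experience buffer in checkpoints.
buffer_config = ExperienceBufferConfigModel(
    max_size=100000,
    device="auto",  # resolved to a concrete device by the validator (see next hunk)
    save_experience_buffer_in_checkpoint=False,  # new in 0.2.2, defaults to True
)
```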
```diff
@@ -40,8 +42,9 @@ class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
                 data["trajectory_processor_config"] = data[
                     "trajectory_processor_config"
                 ].model_dump()
-        if "device" not in data or data["device"] == "auto":
-            data["device"] = get_device("auto")
+        if "device" not in data:
+            data["device"] = "auto"
+        data["device"] = get_device(data["device"])
         return data
 
 
```
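The practical effect: 0.2.1 resolved only a missing device or the literal "auto", passing explicit strings through untouched, while 0.2.2 routes every value through get_device. A hedged sketch of the difference, using only the branches of get_device visible later in this diff:

```python
from rlgym_learn_algos.util.torch_functions import get_device

# In 0.2.1 the validator called get_device only for a missing or "auto"
# device; explicit values such as "gpu" bypassed it. In 0.2.2 every value
# is normalized through get_device.
print(get_device("auto"))  # "cuda:0" when CUDA is available, otherwise "cpu"
print(get_device("gpu"))   # "cuda:0" when CUDA is available
```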
```diff
@@ -165,21 +168,22 @@ class ExperienceBuffer(
 
     def save_checkpoint(self, folder_path):
         os.makedirs(folder_path, exist_ok=True)
-        with open(
-            os.path.join(folder_path, EXPERIENCE_BUFFER_FILE),
-            "wb",
-        ) as f:
-            pickle.dump(
-                {
-                    "agent_ids": self.agent_ids,
-                    "observations": self.observations,
-                    "actions": self.actions,
-                    "log_probs": self.log_probs,
-                    "values": self.values,
-                    "advantages": self.advantages,
-                },
-                f,
-            )
+        if self.config.experience_buffer_config.save_experience_buffer_in_checkpoint:
+            with open(
+                os.path.join(folder_path, EXPERIENCE_BUFFER_FILE),
+                "wb",
+            ) as f:
+                pickle.dump(
+                    {
+                        "agent_ids": self.agent_ids,
+                        "observations": self.observations,
+                        "actions": self.actions,
+                        "log_probs": self.log_probs,
+                        "values": self.values,
+                        "advantages": self.advantages,
+                    },
+                    f,
+                )
         self.trajectory_processor.save_checkpoint(folder_path)
 
     # TODO: update docs
```
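Since the pickle is now optional, checkpoint folders written with the flag off will not contain experience_buffer.pkl, while the trajectory processor state is still saved unconditionally. A consumer-side sketch; the guard is illustrative and not part of this diff:

```python
import os
import pickle

folder_path = "checkpoints/latest"  # hypothetical checkpoint path
buffer_file = os.path.join(folder_path, "experience_buffer.pkl")
# The buffer file may legitimately be absent in 0.2.2 checkpoints.
if os.path.isfile(buffer_file):
    with open(buffer_file, "rb") as f:
        state = pickle.load(f)
```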
```diff
--- a/rlgym_learn_algos/ppo/ppo_learner.py
+++ b/rlgym_learn_algos/ppo/ppo_learner.py
@@ -39,17 +39,25 @@ class PPOLearnerConfigModel(BaseModel, extra="forbid"):
     clip_range: float = 0.2
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
+    advantage_normalization: bool = True
     device: PydanticTorchDevice = "auto"
+    cudnn_benchmark_mode: bool = True
 
     @model_validator(mode="before")
     @classmethod
     def set_device(cls, data):
-        if isinstance(data, dict) and (
-            "device" not in data or data["device"] == "auto"
-        ):
-            data["device"] = get_device("auto")
+        if isinstance(data, dict):
+            if "device" not in data:
+                data["device"] = "auto"
+            data["device"] = get_device(data["device"])
         return data
 
+    @model_validator(mode="after")
+    def validate_cudnn_benchmark(self):
+        if self.device.type != "cuda":
+            self.cudnn_benchmark_mode = False
+        return self
+
 
 @dataclass
 class DerivedPPOLearnerConfig:
```
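A minimal sketch of the two new fields and the after-validator's effect. The import path and constructor call are assumptions; names, defaults, and the validator behavior come from the diff above.

```python
from rlgym_learn_algos.ppo.ppo_learner import PPOLearnerConfigModel

# Hypothetical usage sketch: unspecified fields keep their defaults.
learner_config = PPOLearnerConfigModel(
    advantage_normalization=True,  # new in 0.2.2, enabled by default
    device="cpu",
    cudnn_benchmark_mode=True,     # forced back to False by validate_cudnn_benchmark
)
# The after-validator disables benchmark mode on non-CUDA devices.
assert learner_config.cudnn_benchmark_mode is False
```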
```diff
@@ -107,6 +115,12 @@ class PPOLearner(
     def load(self, config: DerivedPPOLearnerConfig):
         self.config = config
 
+        if (
+            config.learner_config.cudnn_benchmark_mode
+            and config.learner_config.device.type == "cuda"
+        ):
+            torch.backends.cudnn.benchmark = True
+
         self.actor = self.actor_factory(
             config.obs_space, config.action_space, config.learner_config.device
         )
```
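For context, a sketch of what cudnn_benchmark_mode gates; the trade-off description is general PyTorch behavior, not taken from this diff:

```python
import torch

# cuDNN benchmark mode autotunes convolution algorithms per input shape.
# It helps fixed-shape minibatch training but adds warm-up cost and some
# nondeterminism, which is why 0.2.2 makes it an explicit, opt-out flag.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
```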
```diff
@@ -292,6 +306,10 @@ class PPOLearner(
                 advantages = batch_advantages[start:stop].to(
                     self.config.learner_config.device
                 )
+                if self.config.learner_config.advantage_normalization:
+                    advantages = (advantages - torch.mean(advantages)) / (
+                        torch.std(advantages) + 1e-8
+                    )
                 old_probs = batch_old_probs[start:stop].to(
                     self.config.learner_config.device
                 )
```
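A standalone sketch of the normalization being added, runnable outside the training loop:

```python
import torch

# Advantages are standardized per minibatch (zero mean, unit std), with
# 1e-8 guarding against division by zero when the variance is ~0.
advantages = torch.tensor([0.5, -1.2, 2.0, 0.1])
normalized = (advantages - torch.mean(advantages)) / (torch.std(advantages) + 1e-8)
print(normalized.mean())  # ~0
print(normalized.std())   # ~1
```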
```diff
--- a/rlgym_learn_algos/util/torch_functions.py
+++ b/rlgym_learn_algos/util/torch_functions.py
@@ -13,7 +13,6 @@ import torch.nn as nn
 
 def get_device(device: str):
     if device in ["auto", "gpu"] and torch.cuda.is_available():
-        torch.backends.cudnn.benchmark = True
         return "cuda:0"
     elif device == "auto" and not torch.cuda.is_available():
         return "cpu"
```
```diff
--- a/rlgym_learn_algos/util/torch_pydantic.py
+++ b/rlgym_learn_algos/util/torch_pydantic.py
@@ -42,7 +42,7 @@ device_str_regex = (
             "privateuseone",
         ]
     )
-    + ")(:\d+)"
+    + ")(:\d+)?"
 )
 
 
```
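A quick sketch of what the regex change permits, using an abbreviated device list for illustration:

```python
import re

# The trailing "?" makes the ":<index>" suffix optional, so bare device
# names now validate instead of requiring an explicit index.
pattern = re.compile(r"^(cpu|cuda|mps)(:\d+)?$")
assert pattern.match("cuda")    # rejected by 0.2.1's pattern, accepted in 0.2.2
assert pattern.match("cuda:1")  # accepted by both versions
```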
```diff
--- a/rlgym_learn_algos-0.2.1.dist-info/METADATA
+++ b/rlgym_learn_algos-0.2.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rlgym-learn-algos
-Version: 0.2.1
+Version: 0.2.2
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Requires-Dist: pydantic>=2.8.2
```
```diff
--- a/rlgym_learn_algos-0.2.1.dist-info/RECORD
+++ b/rlgym_learn_algos-0.2.2.dist-info/RECORD
@@ -1,6 +1,6 @@
-rlgym_learn_algos-0.2.1.dist-info/METADATA,sha256=pRplMtq88vWNms7sVhWJfwf-W7GATsQH3617hrkNl3s,2431
-rlgym_learn_algos-0.2.1.dist-info/WHEEL,sha256=1cEMGItI5ycdglW0xuhYFX4p-uaMeYRsFVmN9jJO6VY,92
-rlgym_learn_algos-0.2.1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+rlgym_learn_algos-0.2.2.dist-info/METADATA,sha256=4wwr9xqqVWvZ7HYM4cumHiRdz79gkixfpe11b4MyvSU,2431
+rlgym_learn_algos-0.2.2.dist-info/WHEEL,sha256=zfc_r7GoDDc6Hz8pkKR77dEQzJMJDgktYOCKUorzovQ,92
+rlgym_learn_algos-0.2.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
 rlgym_learn_algos/__init__.py,sha256=C7cRdL4lZrpk3ge_4_lGAbGodqWJXM56FfgO0keRPAY,207
 rlgym_learn_algos/conversion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rlgym_learn_algos/conversion/convert_rlgym_ppo_checkpoint.py,sha256=A9nvzjp3DQNRNL5TAt-u3xE80JDIpYEDqAGNReHvFG0,908
```
```diff
@@ -15,18 +15,18 @@ rlgym_learn_algos/ppo/continuous_actor.py,sha256=1vdBUw2mQNFNu6A6ZrAztBjd4DmwjGk
 rlgym_learn_algos/ppo/critic.py,sha256=RB89WtiN52BEq5QCpGAPrASUnasac-Bpg7B0lM3UXHw,689
 rlgym_learn_algos/ppo/discrete_actor.py,sha256=Nuc3EndIQud3NGrkBIQgy-Z-okhXVrj6p6okSGD1KNY,2620
 rlgym_learn_algos/ppo/env_trajectories.py,sha256=gzQBRkzwZhlZeSvWL50cc8AOgBfsg5zUys0aTJj6aZU,3775
-rlgym_learn_algos/ppo/experience_buffer.py,sha256=QdyFMMM8YpEYrmtFaeaHXvFlNT2pCZwQKBEqsrv4v2I,10838
+rlgym_learn_algos/ppo/experience_buffer.py,sha256=xDm8NIMdErpv3GyWUBcTvzkLBQa8tW1TXb7OrKRDIu4,11059
 rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=Apk4x-pfRnitKJPW6LBZyOPIhgeJs_5EG7BbTCqMwjk,4761
 rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=JK958vasIIiuf3ALcFNlvBgGNhFshK8MhQJjwvxhrAM,5453
 rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=RpyDR6GQ1JXvwtoKkx5V3z3WvU9ElJdzfNtpPiZDaTc,6831
 rlgym_learn_algos/ppo/multi_discrete_actor.py,sha256=zSYeBBirjguSv_wO-peo06hioHiVhZQjnd-NYwJxmag,3127
 rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=h0UR-o2k-_LyeFTzvII3HQHHWyeMJewqLlca8ThtyfA,25105
-rlgym_learn_algos/ppo/ppo_learner.py,sha256=3YTfs7LhjiJ0u3-k84rYWcmAQxKIf2yp1i1UVY4v8Oc,15229
+rlgym_learn_algos/ppo/ppo_learner.py,sha256=Cbbuz0AMwPCmkQ1YPDdZLkbgZOdyrOLEx89Camn-nGE,15942
 rlgym_learn_algos/ppo/ppo_metrics_logger.py,sha256=niW8xgQLEBCGgTaVyiE_JqsU6RTjV6h-JzM-7c3JT38,2868
 rlgym_learn_algos/ppo/trajectory.py,sha256=IIH_IG8B_HkyxRPf-YsCyF1jQqNjDx752hgzAehG25I,719
 rlgym_learn_algos/ppo/trajectory_processor.py,sha256=5eY_mNGjqIkhqnbKeaqDvqIWPdg6wD6Ai3fXH2WoXbw,2091
 rlgym_learn_algos/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rlgym_learn_algos/rlgym_learn_algos.cp311-win32.pyd,sha256=TQ1jNEan9yh_o1CkLLdtdke1q8DvOPsosdz3Fcdt1PE,337920
+rlgym_learn_algos/rlgym_learn_algos.cp311-win32.pyd,sha256=wSNbyQLW2s8XROJJ3CwpPBEPaYWqpAsZJ8syylePLNE,337920
 rlgym_learn_algos/rlgym_learn_algos.pyi,sha256=NwY-sDZWM06TUiKPzxpfH1Td6G6E8TdxtRPgBSh-PPE,1203
 rlgym_learn_algos/stateful_functions/__init__.py,sha256=QS0KYjuzagNkYiYllXQmjoJn14-G7KZawq1Zvwh8alY,236
 rlgym_learn_algos/stateful_functions/batch_reward_type_numpy_converter.py,sha256=1yte5qYyl9LWdClHZ_YsF7R9dJqQeYfINMdgNF_59Gs,767
```
```diff
@@ -34,6 +34,6 @@ rlgym_learn_algos/stateful_functions/numpy_obs_standardizer.py,sha256=OgtwCaxBGT
 rlgym_learn_algos/stateful_functions/obs_standardizer.py,sha256=qPPc3--J_3mpJJ-QHJjta6dbWWBobL7SYdK5MUP-XMw,606
 rlgym_learn_algos/util/__init__.py,sha256=VPM6SN4T_625H9t30s9EiLeXiEEWgcyRVHa-LLVNrn4,47
 rlgym_learn_algos/util/running_stats.py,sha256=0tiGFpKtHWzMa1CxM_ueBzd_ryX4bJBriC8MXcSLg8w,4479
-rlgym_learn_algos/util/torch_functions.py,sha256=CTTHzTIi7u1O9HyX0cVJOrnYVbAtnlVs0g1fO9s3ano,3458
-rlgym_learn_algos/util/torch_pydantic.py,sha256=pgj3I-3q8iW9qtOCv1fgjNkZgA00G_Rdkb4qJPk5gxo,3530
-rlgym_learn_algos-0.2.1.dist-info/RECORD,,
+rlgym_learn_algos/util/torch_functions.py,sha256=_uAXhq1YYPneWI3_XXRYsSA3Hn1a8wGjUnI3m9UojdU,3411
+rlgym_learn_algos/util/torch_pydantic.py,sha256=5AbXQcfQtVgLRBSgCj0Hvi_H42WHLu4Oty4l_i22nAo,3531
+rlgym_learn_algos-0.2.2.dist-info/RECORD,,
```
```diff
--- a/rlgym_learn_algos-0.2.1.dist-info/WHEEL
+++ b/rlgym_learn_algos-0.2.2.dist-info/WHEEL
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: maturin (1.8.6)
+Generator: maturin (1.9.0)
 Root-Is-Purelib: false
 Tag: cp311-cp311-win32
```