rlgym-learn-algos 0.2.1-cp311-cp311-win32.whl → 0.2.2-cp311-cp311-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rlgym_learn_algos/ppo/experience_buffer.py +21 -17
- rlgym_learn_algos/ppo/ppo_learner.py +22 -4
- rlgym_learn_algos/rlgym_learn_algos.cp311-win32.pyd +0 -0
- rlgym_learn_algos/util/torch_functions.py +0 -1
- rlgym_learn_algos/util/torch_pydantic.py +1 -1
- {rlgym_learn_algos-0.2.1.dist-info → rlgym_learn_algos-0.2.2.dist-info}/METADATA +1 -1
- {rlgym_learn_algos-0.2.1.dist-info → rlgym_learn_algos-0.2.2.dist-info}/RECORD +9 -9
- {rlgym_learn_algos-0.2.1.dist-info → rlgym_learn_algos-0.2.2.dist-info}/WHEEL +1 -1
- {rlgym_learn_algos-0.2.1.dist-info → rlgym_learn_algos-0.2.2.dist-info}/licenses/LICENSE +0 -0
rlgym_learn_algos/ppo/experience_buffer.py

@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from pydantic import BaseModel, Field, model_validator
 from rlgym.api import ActionType, AgentID, ObsType, RewardType
+
 from rlgym_learn_algos.util.torch_functions import get_device
 from rlgym_learn_algos.util.torch_pydantic import PydanticTorchDevice
 
@@ -24,6 +25,7 @@ EXPERIENCE_BUFFER_FILE = "experience_buffer.pkl"
 class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
     max_size: int = 100000
     device: PydanticTorchDevice = "auto"
+    save_experience_buffer_in_checkpoint: bool = True
     trajectory_processor_config: Dict[str, Any] = Field(default_factory=dict)
 
     @model_validator(mode="before")
@@ -40,8 +42,9 @@ class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
             data["trajectory_processor_config"] = data[
                 "trajectory_processor_config"
             ].model_dump()
-        if "device" not in data…
-            data["device"] =…
+        if "device" not in data:
+            data["device"] = "auto"
+        data["device"] = get_device(data["device"])
         return data
 
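The validator above now defaults a missing device entry to "auto" and then resolves whatever value is present through get_device. The implementation of get_device in rlgym_learn_algos/util/torch_functions.py is not shown in this diff, so the following is only a sketch of how a typical "auto" resolver behaves, not the library's actual code:

    import torch

    def get_device_sketch(device: str) -> torch.device:
        # Hypothetical stand-in for rlgym_learn_algos.util.torch_functions.get_device:
        # map "auto" to CUDA when available, otherwise fall back to CPU.
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)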
@@ -165,21 +168,22 @@ class ExperienceBuffer(
 
     def save_checkpoint(self, folder_path):
         os.makedirs(folder_path, exist_ok=True)
-        with open(
-            os.path.join(folder_path, EXPERIENCE_BUFFER_FILE),
-            "wb",
-        ) as f:
-            pickle.dump(
-                {
-                    "agent_ids": self.agent_ids,
-                    "observations": self.observations,
-                    "actions": self.actions,
-                    "log_probs": self.log_probs,
-                    "values": self.values,
-                    "advantages": self.advantages,
-                },
-                f,
-            )
+        if self.config.experience_buffer_config.save_experience_buffer_in_checkpoint:
+            with open(
+                os.path.join(folder_path, EXPERIENCE_BUFFER_FILE),
+                "wb",
+            ) as f:
+                pickle.dump(
+                    {
+                        "agent_ids": self.agent_ids,
+                        "observations": self.observations,
+                        "actions": self.actions,
+                        "log_probs": self.log_probs,
+                        "values": self.values,
+                        "advantages": self.advantages,
+                    },
+                    f,
+                )
         self.trajectory_processor.save_checkpoint(folder_path)
 
     # TODO: update docs
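The net effect of the experience_buffer.py changes is a new save_experience_buffer_in_checkpoint option that controls whether the buffer is pickled into experience_buffer.pkl during save_checkpoint. A minimal sketch of how it might be used, assuming the package is installed and that fields not visible in this diff keep their defaults:

    from rlgym_learn_algos.ppo.experience_buffer import ExperienceBufferConfigModel

    config = ExperienceBufferConfigModel(
        max_size=100_000,
        device="auto",  # resolved to a concrete torch device by the mode="before" validator
        save_experience_buffer_in_checkpoint=False,  # new in 0.2.2: skip pickling on save_checkpoint
    )
    print(config.device)  # e.g. a cuda or cpu device, depending on availability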
rlgym_learn_algos/ppo/ppo_learner.py

@@ -39,17 +39,25 @@ class PPOLearnerConfigModel(BaseModel, extra="forbid"):
     clip_range: float = 0.2
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
+    advantage_normalization: bool = True
     device: PydanticTorchDevice = "auto"
+    cudnn_benchmark_mode: bool = True
 
     @model_validator(mode="before")
     @classmethod
     def set_device(cls, data):
-        if isinstance(data, dict)…
-            "device" not in data…
-        …
-            data["device"] = get_device("…
+        if isinstance(data, dict):
+            if "device" not in data:
+                data["device"] = "auto"
+            data["device"] = get_device(data["device"])
         return data
 
+    @model_validator(mode="after")
+    def validate_cudnn_benchmark(self):
+        if self.device.type != "cuda":
+            self.cudnn_benchmark_mode = False
+        return self
+
 
 @dataclass
 class DerivedPPOLearnerConfig:
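A minimal sketch of the two new learner options, using only the field names shown in this hunk and assuming the fields not shown keep their defaults; per the mode="after" validator, cudnn_benchmark_mode is forced off whenever the resolved device is not CUDA:

    from rlgym_learn_algos.ppo.ppo_learner import PPOLearnerConfigModel

    cfg = PPOLearnerConfigModel(
        advantage_normalization=True,  # new in 0.2.2: normalize advantages per minibatch
        cudnn_benchmark_mode=True,     # new in 0.2.2: opt in to cuDNN benchmark mode
        device="cpu",
    )
    print(cfg.cudnn_benchmark_mode)  # False: forced off because the device is not CUDA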
@@ -107,6 +115,12 @@ class PPOLearner(
     def load(self, config: DerivedPPOLearnerConfig):
         self.config = config
 
+        if (
+            config.learner_config.cudnn_benchmark_mode
+            and config.learner_config.device.type == "cuda"
+        ):
+            torch.backends.cudnn.benchmark = True
+
         self.actor = self.actor_factory(
             config.obs_space, config.action_space, config.learner_config.device
         )
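For reference, the block above amounts to the following standalone helper (the function name is illustrative, not part of the library). cuDNN benchmark mode autotunes convolution algorithms for the observed input shapes, which tends to help with fixed-size batches and can hurt when shapes vary between iterations:

    import torch

    def maybe_enable_cudnn_benchmark(device: torch.device, cudnn_benchmark_mode: bool) -> None:
        # Mirror of the load() logic: only honor the flag on CUDA devices.
        if cudnn_benchmark_mode and device.type == "cuda":
            torch.backends.cudnn.benchmark = True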
@@ -292,6 +306,10 @@ class PPOLearner(
                 advantages = batch_advantages[start:stop].to(
                     self.config.learner_config.device
                 )
+                if self.config.learner_config.advantage_normalization:
+                    advantages = (advantages - torch.mean(advantages)) / (
+                        torch.std(advantages) + 1e-8
+                    )
                 old_probs = batch_old_probs[start:stop].to(
                     self.config.learner_config.device
                 )
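The normalization added above is standard per-minibatch advantage normalization: subtract the mean and divide by the standard deviation plus a small epsilon for numerical stability. A self-contained illustration of the same formula (not the library's own helper):

    import torch

    def normalize_advantages(advantages: torch.Tensor) -> torch.Tensor:
        return (advantages - torch.mean(advantages)) / (torch.std(advantages) + 1e-8)

    adv = torch.tensor([1.0, 2.0, 3.0, 4.0])
    normed = normalize_advantages(adv)
    print(normed.mean().item(), normed.std().item())  # approximately 0.0 and 1.0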
rlgym_learn_algos/rlgym_learn_algos.cp311-win32.pyd

Binary file
{rlgym_learn_algos-0.2.1.dist-info → rlgym_learn_algos-0.2.2.dist-info}/RECORD

@@ -1,6 +1,6 @@
-rlgym_learn_algos-0.2.…
-rlgym_learn_algos-0.2.…
-rlgym_learn_algos-0.2.…
+rlgym_learn_algos-0.2.2.dist-info/METADATA,sha256=4wwr9xqqVWvZ7HYM4cumHiRdz79gkixfpe11b4MyvSU,2431
+rlgym_learn_algos-0.2.2.dist-info/WHEEL,sha256=zfc_r7GoDDc6Hz8pkKR77dEQzJMJDgktYOCKUorzovQ,92
+rlgym_learn_algos-0.2.2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
 rlgym_learn_algos/__init__.py,sha256=C7cRdL4lZrpk3ge_4_lGAbGodqWJXM56FfgO0keRPAY,207
 rlgym_learn_algos/conversion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rlgym_learn_algos/conversion/convert_rlgym_ppo_checkpoint.py,sha256=A9nvzjp3DQNRNL5TAt-u3xE80JDIpYEDqAGNReHvFG0,908
@@ -15,18 +15,18 @@ rlgym_learn_algos/ppo/continuous_actor.py,sha256=1vdBUw2mQNFNu6A6ZrAztBjd4DmwjGk
 rlgym_learn_algos/ppo/critic.py,sha256=RB89WtiN52BEq5QCpGAPrASUnasac-Bpg7B0lM3UXHw,689
 rlgym_learn_algos/ppo/discrete_actor.py,sha256=Nuc3EndIQud3NGrkBIQgy-Z-okhXVrj6p6okSGD1KNY,2620
 rlgym_learn_algos/ppo/env_trajectories.py,sha256=gzQBRkzwZhlZeSvWL50cc8AOgBfsg5zUys0aTJj6aZU,3775
-rlgym_learn_algos/ppo/experience_buffer.py,sha256=…
+rlgym_learn_algos/ppo/experience_buffer.py,sha256=xDm8NIMdErpv3GyWUBcTvzkLBQa8tW1TXb7OrKRDIu4,11059
 rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=Apk4x-pfRnitKJPW6LBZyOPIhgeJs_5EG7BbTCqMwjk,4761
 rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=JK958vasIIiuf3ALcFNlvBgGNhFshK8MhQJjwvxhrAM,5453
 rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=RpyDR6GQ1JXvwtoKkx5V3z3WvU9ElJdzfNtpPiZDaTc,6831
 rlgym_learn_algos/ppo/multi_discrete_actor.py,sha256=zSYeBBirjguSv_wO-peo06hioHiVhZQjnd-NYwJxmag,3127
 rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=h0UR-o2k-_LyeFTzvII3HQHHWyeMJewqLlca8ThtyfA,25105
-rlgym_learn_algos/ppo/ppo_learner.py,sha256=…
+rlgym_learn_algos/ppo/ppo_learner.py,sha256=Cbbuz0AMwPCmkQ1YPDdZLkbgZOdyrOLEx89Camn-nGE,15942
 rlgym_learn_algos/ppo/ppo_metrics_logger.py,sha256=niW8xgQLEBCGgTaVyiE_JqsU6RTjV6h-JzM-7c3JT38,2868
 rlgym_learn_algos/ppo/trajectory.py,sha256=IIH_IG8B_HkyxRPf-YsCyF1jQqNjDx752hgzAehG25I,719
 rlgym_learn_algos/ppo/trajectory_processor.py,sha256=5eY_mNGjqIkhqnbKeaqDvqIWPdg6wD6Ai3fXH2WoXbw,2091
 rlgym_learn_algos/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rlgym_learn_algos/rlgym_learn_algos.cp311-win32.pyd,sha256=…
+rlgym_learn_algos/rlgym_learn_algos.cp311-win32.pyd,sha256=wSNbyQLW2s8XROJJ3CwpPBEPaYWqpAsZJ8syylePLNE,337920
 rlgym_learn_algos/rlgym_learn_algos.pyi,sha256=NwY-sDZWM06TUiKPzxpfH1Td6G6E8TdxtRPgBSh-PPE,1203
 rlgym_learn_algos/stateful_functions/__init__.py,sha256=QS0KYjuzagNkYiYllXQmjoJn14-G7KZawq1Zvwh8alY,236
 rlgym_learn_algos/stateful_functions/batch_reward_type_numpy_converter.py,sha256=1yte5qYyl9LWdClHZ_YsF7R9dJqQeYfINMdgNF_59Gs,767
@@ -34,6 +34,6 @@ rlgym_learn_algos/stateful_functions/numpy_obs_standardizer.py,sha256=OgtwCaxBGT
 rlgym_learn_algos/stateful_functions/obs_standardizer.py,sha256=qPPc3--J_3mpJJ-QHJjta6dbWWBobL7SYdK5MUP-XMw,606
 rlgym_learn_algos/util/__init__.py,sha256=VPM6SN4T_625H9t30s9EiLeXiEEWgcyRVHa-LLVNrn4,47
 rlgym_learn_algos/util/running_stats.py,sha256=0tiGFpKtHWzMa1CxM_ueBzd_ryX4bJBriC8MXcSLg8w,4479
-rlgym_learn_algos/util/torch_functions.py,sha256=…
-rlgym_learn_algos/util/torch_pydantic.py,sha256=…
-rlgym_learn_algos-0.2.…
+rlgym_learn_algos/util/torch_functions.py,sha256=_uAXhq1YYPneWI3_XXRYsSA3Hn1a8wGjUnI3m9UojdU,3411
+rlgym_learn_algos/util/torch_pydantic.py,sha256=5AbXQcfQtVgLRBSgCj0Hvi_H42WHLu4Oty4l_i22nAo,3531
+rlgym_learn_algos-0.2.2.dist-info/RECORD,,
{rlgym_learn_algos-0.2.1.dist-info → rlgym_learn_algos-0.2.2.dist-info}/licenses/LICENSE

File without changes