rlgym-learn-algos: rlgym_learn_algos-0.1.5-cp39-cp39-win_amd64.whl → rlgym_learn_algos-0.2.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rlgym_learn_algos/conversion/__init__.py +0 -0
- rlgym_learn_algos/conversion/convert_rlgym_ppo_checkpoint.py +27 -0
- rlgym_learn_algos/logging/metrics_logger.py +1 -1
- rlgym_learn_algos/logging/wandb_metrics_logger.py +27 -22
- rlgym_learn_algos/ppo/experience_buffer.py +60 -42
- rlgym_learn_algos/ppo/experience_buffer_numpy.py +14 -12
- rlgym_learn_algos/ppo/gae_trajectory_processor.py +14 -17
- rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py +0 -8
- rlgym_learn_algos/ppo/ppo_agent_controller.py +68 -54
- rlgym_learn_algos/ppo/ppo_learner.py +101 -45
- rlgym_learn_algos/ppo/trajectory_processor.py +4 -3
- rlgym_learn_algos/rlgym_learn_algos.cp39-win_amd64.pyd +0 -0
- rlgym_learn_algos/util/torch_pydantic.py +118 -0
- {rlgym_learn_algos-0.1.5.dist-info → rlgym_learn_algos-0.2.1.dist-info}/METADATA +1 -1
- {rlgym_learn_algos-0.1.5.dist-info → rlgym_learn_algos-0.2.1.dist-info}/RECORD +22 -19
- {rlgym_learn_algos-0.1.5.dist-info → rlgym_learn_algos-0.2.1.dist-info}/WHEEL +1 -1
- {rlgym_learn_algos-0.1.5.dist-info → rlgym_learn_algos-0.2.1.dist-info}/licenses/LICENSE +0 -0
rlgym_learn_algos/conversion/__init__.py

File without changes
rlgym_learn_algos/conversion/convert_rlgym_ppo_checkpoint.py

@@ -0,0 +1,27 @@
+import json
+import os
+import time
+from typing import Optional
+
+
+def convert_rlgym_ppo_checkpoint(
+    rlgym_ppo_checkpoint_folder: str, out_folder: Optional[str]
+):
+
+    if out_folder is None:
+        out_folder = f"rlgym_ppo_converted_checkpoint_{time.time_ns()}"
+    print(f"Saving converted checkpoint to folder {out_folder}")
+
+    os.makedirs(out_folder, exist_ok=True)
+
+    PPO_FILES = [
+        ("PPO_POLICY_OPTIMIZER.pt", "actor_optimizer.pt"),
+        ("PPO_POLICY.pt", "actor.pt"),
+        ("PPO_VALUE_NET_OPTIMIZER.pt", "critic_optimizer.pt"),
+        ("PPO_VALUE_NET.pt", "critic.pt"),
+    ]
+    os.makedirs(f"{out_folder}/ppo_learner", exist_ok=True)
+    for file in PPO_FILES:
+        with open(f"{rlgym_ppo_checkpoint_folder}/{file[0]}", "rb") as fin:
+            with open(f"{out_folder}/ppo_learner/{file[1]}", "wb") as fout:
+                fout.write(fin.read())
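A minimal usage sketch of the conversion helper added above; the checkpoint folder path is hypothetical, and passing out_folder=None produces the auto-named output folder shown in the function body.

```python
from rlgym_learn_algos.conversion.convert_rlgym_ppo_checkpoint import (
    convert_rlgym_ppo_checkpoint,
)

# Hypothetical rlgym-ppo checkpoint folder containing PPO_POLICY.pt,
# PPO_VALUE_NET.pt, and their optimizer files.
convert_rlgym_ppo_checkpoint(
    rlgym_ppo_checkpoint_folder="data/checkpoints/rlgym-ppo-run/12345",
    out_folder=None,  # None -> "rlgym_ppo_converted_checkpoint_<time_ns>"
)
```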
rlgym_learn_algos/logging/metrics_logger.py

@@ -12,9 +12,9 @@ MetricsLoggerAdditionalDerivedConfig = TypeVar("MetricsLoggerAdditionalDerivedCo
 class DerivedMetricsLoggerConfig(
     Generic[MetricsLoggerConfig, MetricsLoggerAdditionalDerivedConfig]
 ):
+    metrics_logger_config: MetricsLoggerConfig = None
     checkpoint_load_folder: Optional[str] = None
     agent_controller_name: str = ""
-    metrics_logger_config: MetricsLoggerConfig = None
     additional_derived_config: MetricsLoggerAdditionalDerivedConfig = None
 
 
rlgym_learn_algos/logging/wandb_metrics_logger.py

@@ -29,7 +29,7 @@ def convert_nested_dict(d):
     return new
 
 
-class WandbMetricsLoggerConfigModel(BaseModel):
+class WandbMetricsLoggerConfigModel(BaseModel, extra="forbid"):
     enable: bool = True
     project: str = "rlgym-learn"
     group: str = "unnamed-runs"
@@ -37,6 +37,7 @@ class WandbMetricsLoggerConfigModel(BaseModel):
     id: Optional[str] = None
     new_run_with_timestamp_suffix: bool = False
     additional_wandb_run_config: Dict[str, Any] = Field(default_factory=dict)
+    settings_kwargs: Dict[str, Any] = Field(default_factory=dict)
 
 
 @dataclass
@@ -76,6 +77,7 @@ class WandbMetricsLogger(
     ):
         self.inner_metrics_logger = inner_metrics_logger
         self.checkpoint_file_name = checkpoint_file_name
+        self.run_id = None
 
     def collect_env_metrics(self, data: List[Dict[str, Any]]):
         self.inner_metrics_logger.collect_env_metrics(data)
@@ -107,17 +109,11 @@ class WandbMetricsLogger(
             self.run_id = None
             return
 
-        if
-
-
-
-        if self.run_id is not None:
-            print(
-                f"{self.config.agent_controller_name}: Wandb run id from checkpoint ({self.run_id}) is being overridden by wandb run id from config: {self.config.metrics_logger_config.id}"
-            )
+        if self.run_id is not None and self.config.metrics_logger_config.id is not None:
+            print(
+                f"{self.config.agent_controller_name}: Wandb run id from checkpoint ({self.run_id}) is being overridden by wandb run id from config: {self.config.metrics_logger_config.id}"
+            )
             self.run_id = self.config.metrics_logger_config.id
-        else:
-            self.run_id = None
 
         wandb_config = {
             **self.config.additional_derived_config.derived_wandb_run_config,
@@ -145,22 +141,31 @@ class WandbMetricsLogger(
             id=self.run_id,
             resume="allow",
             reinit=True,
+            settings=wandb.Settings(
+                **self.config.metrics_logger_config.settings_kwargs
+            ),
         )
         self.run_id = self.wandb_run.id
         print(f"{self.config.agent_controller_name}: Created wandb run! {self.run_id}")
 
     def _load_from_checkpoint(self):
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            with open(
+                os.path.join(
+                    self.config.checkpoint_load_folder,
+                    self.checkpoint_file_name,
+                ),
+                "rt",
+            ) as f:
+                state = json.load(f)
+            if "run_id" in state:
+                self.run_id = state["run_id"]
+            else:
+                self.run_id = None
+        except FileNotFoundError:
+            print(
+                f"{self.config.agent_controller_name}: Tried to load wandb run from checkpoint using the file at location {str(os.path.join(self.config.checkpoint_load_folder, self.checkpoint_file_name))}, but there is no such file! A new run will be created based on the config values instead."
+            )
             self.run_id = None
 
     def save_checkpoint(self, folder_path):
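The new settings_kwargs field above is unpacked directly into wandb.Settings when the run is created. A hedged sketch of the equivalent call; start_method is just one illustrative wandb setting:

```python
import wandb

# settings_kwargs from WandbMetricsLoggerConfigModel is forwarded as
# wandb.Settings(**settings_kwargs) in wandb.init (see the hunk above).
# Any keyword accepted by wandb.Settings can be supplied here.
settings_kwargs = {"start_method": "spawn"}

run = wandb.init(
    project="rlgym-learn",
    group="unnamed-runs",
    settings=wandb.Settings(**settings_kwargs),
)
```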
rlgym_learn_algos/ppo/experience_buffer.py

@@ -7,6 +7,8 @@ import numpy as np
 import torch
 from pydantic import BaseModel, Field, model_validator
 from rlgym.api import ActionType, AgentID, ObsType, RewardType
+from rlgym_learn_algos.util.torch_functions import get_device
+from rlgym_learn_algos.util.torch_pydantic import PydanticTorchDevice
 
 from .trajectory import Trajectory
 from .trajectory_processor import (
@@ -19,8 +21,9 @@ from .trajectory_processor import (
 EXPERIENCE_BUFFER_FILE = "experience_buffer.pkl"
 
 
-class ExperienceBufferConfigModel(BaseModel):
+class ExperienceBufferConfigModel(BaseModel, extra="forbid"):
     max_size: int = 100000
+    device: PydanticTorchDevice = "auto"
     trajectory_processor_config: Dict[str, Any] = Field(default_factory=dict)
 
     @model_validator(mode="before")
@@ -31,21 +34,24 @@ class ExperienceBufferConfigModel(BaseModel):
             data.trajectory_processor_config = (
                 data.trajectory_processor_config.model_dump()
             )
-        elif isinstance(data, dict)
-        if
-            data["trajectory_processor_config"]
-                "trajectory_processor_config"
-
+        elif isinstance(data, dict):
+            if "trajectory_processor_config" in data:
+                if isinstance(data["trajectory_processor_config"], BaseModel):
+                    data["trajectory_processor_config"] = data[
+                        "trajectory_processor_config"
+                    ].model_dump()
+            if "device" not in data or data["device"] == "auto":
+                data["device"] = get_device("auto")
         return data
 
 
 @dataclass
 class DerivedExperienceBufferConfig:
-
+    experience_buffer_config: ExperienceBufferConfigModel
+    agent_controller_name: str
     seed: int
-    dtype:
-
-    trajectory_processor_config: Dict[str, Any]
+    dtype: torch.dtype
+    learner_device: torch.device
     checkpoint_load_folder: Optional[str] = None
 
 
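Both the experience buffer above and the PPO learner below resolve device="auto" eagerly in a before-validator via get_device. That helper lives in rlgym_learn_algos/util/torch_functions.py and its body is not part of this diff; a sketch of the usual "auto" convention such helpers implement:

```python
import torch

# Sketch only: the real get_device is defined in
# rlgym_learn_algos/util/torch_functions.py and is not shown in this
# diff. The common "auto" convention is to prefer CUDA when available
# and fall back to CPU otherwise.
def get_device(device: str) -> torch.device:
    if device == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)
```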
@@ -111,42 +117,51 @@ class ExperienceBuffer(
         self.agent_ids: List[AgentID] = []
         self.observations: List[ObsType] = []
         self.actions: List[ActionType] = []
-        self.log_probs = torch.FloatTensor()
-        self.values = torch.FloatTensor()
-        self.advantages = torch.FloatTensor()
 
     def load(self, config: DerivedExperienceBufferConfig):
         self.config = config
         self.rng = np.random.RandomState(config.seed)
         trajectory_processor_config = self.trajectory_processor.validate_config(
-            config.trajectory_processor_config
+            config.experience_buffer_config.trajectory_processor_config
         )
         self.trajectory_processor.load(
             DerivedTrajectoryProcessorConfig(
                 trajectory_processor_config=trajectory_processor_config,
+                agent_controller_name=config.agent_controller_name,
                 dtype=config.dtype,
-                device=config.
+                device=config.learner_device,
             )
         )
+        self.log_probs = torch.tensor([], dtype=config.dtype)
+        self.values = torch.tensor([], dtype=config.dtype)
+        self.advantages = torch.tensor([], dtype=config.dtype)
         if self.config.checkpoint_load_folder is not None:
             self._load_from_checkpoint()
-            self.log_probs = self.log_probs.to(config.
-            self.values = self.values.to(config.
-            self.advantages = self.advantages.to(config.
+        self.log_probs = self.log_probs.to(config.learner_device)
+        self.values = self.values.to(config.learner_device)
+        self.advantages = self.advantages.to(config.learner_device)
 
     def _load_from_checkpoint(self):
         # lazy way
-
-
-
-
-
-
-
-
-
-
-
+        # TODO: don't use pickle for torch things, use torch.load because of map_location. Or maybe define a custom unpickler for this? Or maybe one already exists?
+        try:
+            with open(
+                os.path.join(
+                    self.config.checkpoint_load_folder, EXPERIENCE_BUFFER_FILE
+                ),
+                "rb",
+            ) as f:
+                state_dict = pickle.load(f)
+            self.agent_ids = state_dict["agent_ids"]
+            self.observations = state_dict["observations"]
+            self.actions = state_dict["actions"]
+            self.log_probs = state_dict["log_probs"]
+            self.values = state_dict["values"]
+            self.advantages = state_dict["advantages"]
+        except FileNotFoundError:
+            print(
+                f"{self.config.agent_controller_name}: Tried to load experience buffer from checkpoint using the file at location {str(os.path.join(self.config.checkpoint_load_folder, EXPERIENCE_BUFFER_FILE))}, but there is no such file! A blank experience buffer will be used instead."
+            )
 
     def save_checkpoint(self, folder_path):
         os.makedirs(folder_path, exist_ok=True)
@@ -195,29 +210,36 @@ class ExperienceBuffer(
             exp_buffer_data
         )
 
-        self.agent_ids = _cat_list(
+        self.agent_ids = _cat_list(
+            self.agent_ids, agent_ids, self.config.experience_buffer_config.max_size
+        )
         self.observations = _cat_list(
-            self.observations,
+            self.observations,
+            observations,
+            self.config.experience_buffer_config.max_size,
+        )
+        self.actions = _cat_list(
+            self.actions, actions, self.config.experience_buffer_config.max_size
         )
-        self.actions = _cat_list(self.actions, actions, self.config.max_size)
         self.log_probs = _cat(
             self.log_probs,
             log_probs,
-            self.config.max_size,
+            self.config.experience_buffer_config.max_size,
         )
         self.values = _cat(
             self.values,
             values,
-            self.config.max_size,
+            self.config.experience_buffer_config.max_size,
        )
         self.advantages = _cat(
             self.advantages,
             advantages,
-            self.config.max_size,
+            self.config.experience_buffer_config.max_size,
         )
 
         return trajectory_processor_data
 
+    # TODO: tensordict?
     def _get_samples(self, indices) -> Tuple[
         Iterable[AgentID],
         Iterable[ObsType],
@@ -242,18 +264,14 @@ class ExperienceBuffer(
         :param batch_size: size of each batch yielded by the generator.
         :return:
         """
-        if self.config.
+        if self.config.learner_device.type != "cpu":
             torch.cuda.current_stream().synchronize()
         total_samples = self.values.shape[0]
         indices = self.rng.permutation(total_samples)
         start_idx = 0
-        batches = []
         while start_idx + batch_size <= total_samples:
-
-                self._get_samples(indices[start_idx : start_idx + batch_size])
-            )
+            yield self._get_samples(indices[start_idx : start_idx + batch_size])
             start_idx += batch_size
-        return batches
 
     def clear(self):
         """
@@ -265,4 +283,4 @@ class ExperienceBuffer(
         del self.log_probs
         del self.values
         del self.advantages
-        self.__init__(self.
+        self.__init__(self.trajectory_processor)
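get_all_batches_shuffled is now a generator: instead of collecting every batch into a list and returning it, it yields each batch as it is sliced, so only one batch of indices is materialized at a time. A toy illustration of the pattern:

```python
import numpy as np

def batches_lazy(indices, batch_size):
    # Mirrors the new control flow above: yield one shuffled batch at a
    # time instead of appending every batch to a list first.
    start = 0
    while start + batch_size <= len(indices):
        yield indices[start : start + batch_size]
        start += batch_size

for batch in batches_lazy(np.random.permutation(10), 3):
    print(batch)  # three batches of 3; the trailing sample is dropped
```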
rlgym_learn_algos/ppo/experience_buffer_numpy.py

@@ -76,25 +76,31 @@ class NumpyExperienceBuffer(
             exp_buffer_data
         )
 
-        self.agent_ids = _cat_list(
+        self.agent_ids = _cat_list(
+            self.agent_ids, agent_ids, self.config.experience_buffer_config.max_size
+        )
         self.observations = _cat_numpy(
-            self.observations,
+            self.observations,
+            observations,
+            self.config.experience_buffer_config.max_size,
+        )
+        self.actions = _cat_numpy(
+            self.actions, actions, self.config.experience_buffer_config.max_size
         )
-        self.actions = _cat_numpy(self.actions, actions, self.config.max_size)
         self.log_probs = _cat(
             self.log_probs,
             log_probs,
-            self.config.max_size,
+            self.config.experience_buffer_config.max_size,
         )
         self.values = _cat(
             self.values,
             values,
-            self.config.max_size,
+            self.config.experience_buffer_config.max_size,
         )
         self.advantages = _cat(
             self.advantages,
             advantages,
-            self.config.max_size,
+            self.config.experience_buffer_config.max_size,
         )
 
         return trajectory_processor_data
@@ -116,18 +122,14 @@ class NumpyExperienceBuffer(
         :param batch_size: size of each batch yielded by the generator.
         :return:
         """
-        if self.config.device != "cpu":
+        if self.config.experience_buffer_config.device.type != "cpu":
             torch.cuda.current_stream().synchronize()
         total_samples = self.values.shape[0]
         indices = self.rng.permutation(total_samples)
         start_idx = 0
-        batches = []
         while start_idx + batch_size <= total_samples:
-
-                self._get_samples(indices[start_idx : start_idx + batch_size])
-            )
+            yield self._get_samples(indices[start_idx : start_idx + batch_size])
             start_idx += batch_size
-        return batches
 
     def clear(self):
         """
rlgym_learn_algos/ppo/gae_trajectory_processor.py

@@ -20,7 +20,7 @@ from ..ppo import RustDerivedGAETrajectoryProcessorConfig, RustGAETrajectoryProc
 from .trajectory_processor import TRAJECTORY_PROCESSOR_FILE, TrajectoryProcessor
 
 
-class GAETrajectoryProcessorConfigModel(BaseModel):
+class GAETrajectoryProcessorConfigModel(BaseModel, extra="forbid"):
     gamma: float = 0.99
     lmbda: float = 0.95
     standardize_returns: bool = True
@@ -115,6 +115,7 @@ class GAETrajectoryProcessor(
         self.max_returns_per_stats_increment = (
             config.trajectory_processor_config.max_returns_per_stats_increment
         )
+        self.agent_controller_name = config.agent_controller_name
         self.dtype = config.dtype
         self.device = config.device
         self.checkpoint_load_folder = config.checkpoint_load_folder
@@ -122,29 +123,25 @@ class GAETrajectoryProcessor(
             self._load_from_checkpoint()
         self.rust_gae_trajectory_processor.load(
             RustDerivedGAETrajectoryProcessorConfig(
-                self.gamma, self.lmbda, np.dtype(self.dtype)
+                self.gamma, self.lmbda, np.dtype(str(self.dtype)[6:])
             )
         )
 
     def _load_from_checkpoint(self):
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            with open(
+                os.path.join(self.checkpoint_load_folder, TRAJECTORY_PROCESSOR_FILE),
+                "rt",
+            ) as f:
+                state = json.load(f)
+            self.return_stats.load_state_dict(state["return_running_stats"])
+        except FileNotFoundError:
+            print(
+                f"{self.agent_controller_name}: Tried to load trajectory processor from checkpoint using the trajectory processor file at location {str(os.path.join(self.checkpoint_load_folder, TRAJECTORY_PROCESSOR_FILE))}, but there is no such file! Running stats will be initialized as if this were a new run instead."
+            )
 
     def save_checkpoint(self, folder_path):
         state = {
-            "gamma": self.gamma,
-            "lambda": self.lmbda,
-            "standardize_returns": self.standardize_returns,
-            "max_returns_per_stats_increment": self.max_returns_per_stats_increment,
             "return_running_stats": self.return_stats.state_dict(),
         }
         with open(
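The trajectory processor's dtype is now a real torch.dtype (see PPOLearnerConfigModel below), so it has to be mapped to a numpy dtype before reaching the Rust GAE processor; str(self.dtype)[6:] strips the "torch." prefix:

```python
import numpy as np
import torch

# str(torch.float32) == "torch.float32"; slicing off the first six
# characters ("torch.") leaves a name numpy understands.
dtype = torch.float32
assert str(dtype)[6:] == "float32"
np_dtype = np.dtype(str(dtype)[6:])  # dtype('float32')
```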
rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py

@@ -161,18 +161,10 @@ class GAETrajectoryProcessorPurePython(
             "rt",
         ) as f:
             state = json.load(f)
-        self.gamma = state["gamma"]
-        self.lmbda = state["lambda"]
-        self.standardize_returns = state["standardize_returns"]
-        self.max_returns_per_stats_increment = state["max_returns_per_stats_increment"]
         self.return_stats.load_state_dict(state["return_running_stats"])
 
     def save_checkpoint(self, folder_path):
         state = {
-            "gamma": self.gamma,
-            "lambda": self.lmbda,
-            "standardize_returns": self.standardize_returns,
-            "max_returns_per_stats_increment": self.max_returns_per_stats_increment,
             "return_running_stats": self.return_stats.state_dict(),
         }
         with open(
rlgym_learn_algos/ppo/ppo_agent_controller.py

@@ -24,8 +24,6 @@ from rlgym.api import (
 )
 from rlgym_learn import EnvActionResponse, EnvActionResponseType, Timestep
 from rlgym_learn.api.agent_controller import AgentController
-from torch import device as _device
-
 from rlgym_learn_algos.logging import (
     DerivedMetricsLoggerConfig,
     MetricsLogger,
@@ -36,6 +34,7 @@ from rlgym_learn_algos.logging import (
 )
 from rlgym_learn_algos.stateful_functions import ObsStandardizer
 from rlgym_learn_algos.util.torch_functions import get_device
+from torch import device as _device
 
 from .actor import Actor
 from .critic import Critic
@@ -62,15 +61,13 @@ ITERATION_SHARED_INFOS_FILE = "iteration_shared_infos.pkl"
 CURRENT_TRAJECTORIES_FILE = "current_trajectories.pkl"
 
 
-class PPOAgentControllerConfigModel(BaseModel):
+class PPOAgentControllerConfigModel(BaseModel, extra="forbid"):
     timesteps_per_iteration: int = 50000
     save_every_ts: int = 1_000_000
     add_unix_timestamp: bool = True
     checkpoint_load_folder: Optional[str] = None
     n_checkpoints_to_keep: int = 5
     random_seed: int = 123
-    dtype: str = "float32"
-    device: Optional[str] = None
     learner_config: PPOLearnerConfigModel = Field(default_factory=PPOLearnerConfigModel)
     experience_buffer_config: ExperienceBufferConfigModel = Field(
         default_factory=ExperienceBufferConfigModel
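With dtype and device removed from PPOAgentControllerConfigModel, they now live on the learner config, and the experience buffer carries its own device. A hedged sketch of the resulting config nesting (field names are taken from the hunks in this diff; the values are illustrative defaults):

```python
# Illustrative dict form of the new nesting; "auto" is resolved to a
# concrete torch.device by the model validators shown in this diff.
agent_controller_config = {
    "timesteps_per_iteration": 50000,
    "save_every_ts": 1_000_000,
    "random_seed": 123,
    "learner_config": {
        "dtype": "float32",
        "device": "auto",
        "batch_size": 50000,
    },
    "experience_buffer_config": {
        "max_size": 100000,
        "device": "auto",
    },
}
```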
@@ -190,11 +187,9 @@ class PPOAgentController(
 
     def load(self, config):
         self.config = config
-
-
-
-        self.device = get_device(device)
-        print(f"{self.config.agent_controller_name}: Using device {self.device}")
+        print(
+            f"{self.config.agent_controller_name}: Using device {config.agent_controller_config.learner_config.device}"
+        )
         agent_controller_config = config.agent_controller_config
         learner_config = config.agent_controller_config.learner_config
         experience_buffer_config = (
@@ -234,14 +229,14 @@ class PPOAgentController(
         # TODO: this doesn't seem to be working
         if abs_save_folder == loaded_checkpoint_runs_folder:
             print(
-                "Using the loaded checkpoint's run folder as the checkpoints save folder."
+                f"{config.agent_controller_name}: Using the loaded checkpoint's run folder as the checkpoints save folder."
             )
             checkpoints_save_folder = os.path.abspath(
                 os.path.join(agent_controller_config.checkpoint_load_folder, "..")
             )
         else:
             print(
-                "Runs folder in config does not align with loaded checkpoint's runs folder. Creating new run in the config-based runs folder."
+                f"{config.agent_controller_name}: Runs folder in config does not align with loaded checkpoint's runs folder. Creating new run in the config-based runs folder."
             )
             checkpoints_save_folder = os.path.join(
                 config.save_folder, agent_controller_config.run_name + run_suffix
@@ -257,26 +252,20 @@ class PPOAgentController(
 
         self.learner.load(
             DerivedPPOLearnerConfig(
+                learner_config=learner_config,
+                agent_controller_name=config.agent_controller_name,
                 obs_space=self.obs_space,
                 action_space=self.action_space,
-                n_epochs=learner_config.n_epochs,
-                batch_size=learner_config.batch_size,
-                n_minibatches=learner_config.n_minibatches,
-                ent_coef=learner_config.ent_coef,
-                clip_range=learner_config.clip_range,
-                actor_lr=learner_config.actor_lr,
-                critic_lr=learner_config.critic_lr,
-                device=self.device,
                 checkpoint_load_folder=learner_checkpoint_load_folder,
             )
         )
         self.experience_buffer.load(
             DerivedExperienceBufferConfig(
-
-
-
-
-
+                experience_buffer_config=experience_buffer_config,
+                agent_controller_name=config.agent_controller_name,
+                seed=config.base_config.random_seed,
+                dtype=agent_controller_config.learner_config.dtype,
+                learner_device=agent_controller_config.learner_config.device,
                 checkpoint_load_folder=experience_buffer_checkpoint_load_folder,
             )
         )
@@ -301,9 +290,9 @@ class PPOAgentController(
         additional_derived_config = None
         self.metrics_logger.load(
             DerivedMetricsLoggerConfig(
+                metrics_logger_config=metrics_logger_config,
                 checkpoint_load_folder=metrics_logger_checkpoint_load_folder,
                 agent_controller_name=config.agent_controller_name,
-                metrics_logger_config=metrics_logger_config,
                 additional_derived_config=additional_derived_config,
             )
         )
@@ -316,33 +305,57 @@ class PPOAgentController(
         random.seed(self.config.base_config.random_seed)
 
     def _load_from_checkpoint(self):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            with open(
+                os.path.join(
+                    self.config.agent_controller_config.checkpoint_load_folder,
+                    CURRENT_TRAJECTORIES_FILE,
+                ),
+                "rb",
+            ) as f:
+                current_trajectories: Dict[
+                    int,
+                    EnvTrajectories[AgentID, ActionType, ObsType, RewardType],
+                ] = pickle.load(f)
+        except FileNotFoundError:
+            print(
+                f"{self.config.agent_controller_name}: Tried to load current trajectories from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, CURRENT_TRAJECTORIES_FILE))}, but there is no such file! Current trajectories will be initialized as an empty dict instead."
+            )
+            current_trajectories = {}
+        try:
+            with open(
+                os.path.join(
+                    self.config.agent_controller_config.checkpoint_load_folder,
+                    ITERATION_SHARED_INFOS_FILE,
+                ),
+                "rb",
+            ) as f:
+                iteration_shared_infos: List[Dict[str, Any]] = pickle.load(f)
+        except FileNotFoundError:
+            print(
+                f"{self.config.agent_controller_name}: Tried to load iteration shared info data from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, ITERATION_SHARED_INFOS_FILE))}, but there is no such file! Iteration shared info data will be initialized as an empty list instead."
+            )
+            current_trajectories = {}
+        try:
+            with open(
+                os.path.join(
+                    self.config.agent_controller_config.checkpoint_load_folder,
+                    PPO_AGENT_FILE,
+                ),
+                "rt",
+            ) as f:
+                state = json.load(f)
+        except FileNotFoundError:
+            print(
+                f"{self.config.agent_controller_name}: Tried to load PPO agent miscellaneous state data from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, PPO_AGENT_FILE))}, but there is no such file! This state data will be initialized as if this were a new run instead."
+            )
+            state = {
+                "cur_iteration": 0,
+                "iteration_timesteps": 0,
+                "cumulative_timesteps": 0,
+                "iteration_start_time": time.perf_counter(),
+                "timestep_collection_start_time": time.perf_counter(),
+            }
 
         self.current_trajectories = current_trajectories
         self.iteration_shared_infos = iteration_shared_infos
@@ -465,6 +478,7 @@ class PPOAgentController(
     ):
         self.timestep_collection_end_time = time.perf_counter()
         self._learn()
+        self.cur_iteration += 1
        if self.ts_since_last_save >= self.config.agent_controller_config.save_every_ts:
             self.save_checkpoint()
             self.ts_since_last_save = 0
@@ -563,5 +577,5 @@ class PPOAgentController(
         for idx, (start, stop) in enumerate(traj_timestep_idx_ranges):
             self.current_trajectories[idx].val_preds = val_preds[start : stop - 1]
             self.current_trajectories[idx].final_val_pred = val_preds[stop - 1]
-        if self.device != "cpu":
+        if self.config.agent_controller_config.learner_config.device.type != "cpu":
             torch.cuda.current_stream().synchronize()
rlgym_learn_algos/ppo/ppo_learner.py

@@ -7,7 +7,7 @@ from typing import Generic, Optional
 
 import numpy as np
 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, field_serializer, model_validator
 from rlgym.api import (
     ActionSpaceType,
     ActionType,
@@ -18,13 +18,20 @@ from rlgym.api (
 )
 from torch import nn as nn
 
+from rlgym_learn_algos.util.torch_functions import get_device
+from rlgym_learn_algos.util.torch_pydantic import (
+    PydanticTorchDevice,
+    PydanticTorchDtype,
+)
+
 from .actor import Actor
 from .critic import Critic
 from .experience_buffer import ExperienceBuffer
 from .trajectory_processor import TrajectoryProcessorConfig, TrajectoryProcessorData
 
 
-class PPOLearnerConfigModel(BaseModel):
+class PPOLearnerConfigModel(BaseModel, extra="forbid"):
+    dtype: PydanticTorchDtype = torch.float32
     n_epochs: int = 1
     batch_size: int = 50000
     n_minibatches: int = 1
@@ -32,20 +39,24 @@ class PPOLearnerConfigModel(BaseModel):
     clip_range: float = 0.2
     actor_lr: float = 3e-4
     critic_lr: float = 3e-4
+    device: PydanticTorchDevice = "auto"
+
+    @model_validator(mode="before")
+    @classmethod
+    def set_device(cls, data):
+        if isinstance(data, dict) and (
+            "device" not in data or data["device"] == "auto"
+        ):
+            data["device"] = get_device("auto")
+        return data
 
 
 @dataclass
 class DerivedPPOLearnerConfig:
+    learner_config: PPOLearnerConfigModel
+    agent_controller_name: str
     obs_space: ObsSpaceType
     action_space: ActionSpaceType
-    n_epochs: int = 10
-    batch_size: int = 50000
-    n_minibatches: int = 1
-    ent_coef: float = 0.005
-    clip_range: float = 0.2
-    actor_lr: float = 3e-4
-    critic_lr: float = 3e-4
-    device: str = "auto"
     checkpoint_load_folder: Optional[str] = None
 
 
@@ -97,15 +108,17 @@ class PPOLearner(
         self.config = config
 
         self.actor = self.actor_factory(
-            config.obs_space, config.action_space, config.device
+            config.obs_space, config.action_space, config.learner_config.device
+        )
+        self.critic = self.critic_factory(
+            config.obs_space, config.learner_config.device
         )
-        self.critic = self.critic_factory(config.obs_space, config.device)
 
         self.actor_optimizer = torch.optim.Adam(
-            self.actor.parameters(), lr=self.config.actor_lr
+            self.actor.parameters(), lr=self.config.learner_config.actor_lr
         )
         self.critic_optimizer = torch.optim.Adam(
-            self.critic.parameters(), lr=self.config.critic_lr
+            self.critic.parameters(), lr=self.config.learner_config.critic_lr
         )
         self.critic_loss_fn = torch.nn.MSELoss()
 
@@ -122,51 +135,78 @@ class PPOLearner(
         total_parameters = actor_params_count + critic_params_count
 
         # Display in a structured manner
-        print("Trainable Parameters:")
-        print(f"{'Component':<10} {'Count':<10}")
+        print(f"{self.config.agent_controller_name}: Trainable Parameters:")
+        print(f"{self.config.agent_controller_name}: {'Component':<10} {'Count':<10}")
         print("-" * 20)
-        print(
-
+        print(
+            f"{self.config.agent_controller_name}: {'Policy':<10} {actor_params_count:<10}"
+        )
+        print(
+            f"{self.config.agent_controller_name}: {'Critic':<10} {critic_params_count:<10}"
+        )
         print("-" * 20)
-        print(
+        print(
+            f"{self.config.agent_controller_name}: {'Total':<10} {total_parameters:<10}"
+        )
 
-        print(
-
+        print(
+            f"{self.config.agent_controller_name}: Current Policy Learning Rate: {self.config.learner_config.actor_lr}"
+        )
+        print(
+            f"{self.config.agent_controller_name}: Current Critic Learning Rate: {self.config.learner_config.critic_lr}"
+        )
         self.cumulative_model_updates = 0
 
         if self.config.checkpoint_load_folder is not None:
             self._load_from_checkpoint()
         self.minibatch_size = int(
-            np.ceil(
+            np.ceil(
+                self.config.learner_config.batch_size
+                / self.config.learner_config.n_minibatches
+            )
         )
 
     def _load_from_checkpoint(self):
 
         assert os.path.exists(
             self.config.checkpoint_load_folder
-        ), f"PPO Learner cannot find folder: {self.config.checkpoint_load_folder}"
+        ), f"{self.config.agent_controller_name}: PPO Learner cannot find folder: {self.config.checkpoint_load_folder}"
 
         self.actor.load_state_dict(
-            torch.load(
+            torch.load(
+                os.path.join(self.config.checkpoint_load_folder, ACTOR_FILE),
+                map_location=self.config.learner_config.device,
+            )
         )
         self.critic.load_state_dict(
-            torch.load(
+            torch.load(
+                os.path.join(self.config.checkpoint_load_folder, CRITIC_FILE),
+                map_location=self.config.learner_config.device,
+            )
         )
         self.actor_optimizer.load_state_dict(
             torch.load(
-                os.path.join(self.config.checkpoint_load_folder, ACTOR_OPTIMIZER_FILE)
+                os.path.join(self.config.checkpoint_load_folder, ACTOR_OPTIMIZER_FILE),
+                map_location=self.config.learner_config.device,
             )
         )
         self.critic_optimizer.load_state_dict(
             torch.load(
-                os.path.join(self.config.checkpoint_load_folder, CRITIC_OPTIMIZER_FILE)
+                os.path.join(self.config.checkpoint_load_folder, CRITIC_OPTIMIZER_FILE),
+                map_location=self.config.learner_config.device,
             )
         )
-
-
-
-
-
+        try:
+            with open(
+                os.path.join(self.config.checkpoint_load_folder, MISC_STATE), "rt"
+            ) as f:
+                misc_state = json.load(f)
+            self.cumulative_model_updates = misc_state["cumulative_model_updates"]
+        except FileNotFoundError:
+            print(
+                f"{self.config.agent_controller_name}: Tried to load the PPO learner's misc state from the file at location {str(os.path.join(self.config.checkpoint_load_folder, MISC_STATE))}, but there is no such file! Miscellaneous stats will be initialized as if this were a new run instead."
+            )
+            self.cumulative_model_updates = 0
 
     def save_checkpoint(self, folder_path):
         os.makedirs(folder_path, exist_ok=True)
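Each torch.load above now passes map_location, so a checkpoint written on one device can be restored on another. A short illustration; the path below is hypothetical:

```python
import torch

# Without map_location, torch.load tries to restore tensors onto the
# device they were saved from (e.g. cuda:0), which fails on a CPU-only
# machine. Remapping at load time avoids that.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load("checkpoint/ppo_learner/actor.pt", map_location=device)
```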
@@ -215,9 +255,11 @@ class PPOLearner(
         critic_before = torch.nn.utils.parameters_to_vector(self.critic.parameters())
 
         t1 = time.time()
-        for epoch in range(self.config.n_epochs):
+        for epoch in range(self.config.learner_config.n_epochs):
             # Get all shuffled batches from the experience buffer.
-            batches = exp.get_all_batches_shuffled(
+            batches = exp.get_all_batches_shuffled(
+                self.config.learner_config.batch_size
+            )
             for batch in batches:
                 (
                     batch_agent_ids,
|
|
232
274
|
self.critic_optimizer.zero_grad()
|
233
275
|
|
234
276
|
for minibatch_slice in range(
|
235
|
-
0, self.config.batch_size, self.minibatch_size
|
277
|
+
0, self.config.learner_config.batch_size, self.minibatch_size
|
236
278
|
):
|
237
279
|
# Send everything to the device and enforce correct shapes.
|
238
280
|
start = minibatch_slice
|
239
|
-
stop = min(
|
240
|
-
|
281
|
+
stop = min(
|
282
|
+
start + self.minibatch_size,
|
283
|
+
self.config.learner_config.batch_size,
|
284
|
+
)
|
285
|
+
minibatch_ratio = (
|
286
|
+
stop - start
|
287
|
+
) / self.config.learner_config.batch_size
|
241
288
|
|
242
289
|
agent_ids = batch_agent_ids[start:stop]
|
243
290
|
obs = batch_obs[start:stop]
|
244
291
|
acts = batch_acts[start:stop]
|
245
|
-
advantages = batch_advantages[start:stop].to(
|
246
|
-
|
292
|
+
advantages = batch_advantages[start:stop].to(
|
293
|
+
self.config.learner_config.device
|
294
|
+
)
|
295
|
+
old_probs = batch_old_probs[start:stop].to(
|
296
|
+
self.config.learner_config.device
|
297
|
+
)
|
247
298
|
target_values = batch_target_values[start:stop].to(
|
248
|
-
self.config.device
|
299
|
+
self.config.learner_config.device
|
249
300
|
)
|
250
301
|
|
251
302
|
# Compute value estimates.
|
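minibatch_ratio, introduced above, weights each minibatch's loss by its share of the full batch; since loss.backward() accumulates gradients across minibatches before the optimizer step, the weighted sum reproduces a single full-batch update. A worked example:

```python
# With batch_size=50000 and minibatch_size=20000 the slices are 20000,
# 20000, and 10000 samples, giving weights 0.4, 0.4, and 0.2 that sum
# to 1 — so the accumulated gradient matches the full-batch mean loss.
batch_size, minibatch_size = 50000, 20000
for start in range(0, batch_size, minibatch_size):
    stop = min(start + minibatch_size, batch_size)
    print(start, stop, (stop - start) / batch_size)
```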
@@ -262,8 +313,8 @@ class PPOLearner(
                     ratio = torch.exp(log_probs - old_probs)
                     clipped = torch.clamp(
                         ratio,
-                        1.0 - self.config.clip_range,
-                        1.0 + self.config.clip_range,
+                        1.0 - self.config.learner_config.clip_range,
+                        1.0 + self.config.learner_config.clip_range,
                     )
 
                     # Compute KL divergence & clip fraction using SB3 method for reporting.
@@ -274,7 +325,10 @@ class PPOLearner(
 
                     # From the stable-baselines3 implementation of PPO.
                     clip_fraction = torch.mean(
-                        (
+                        (
+                            torch.abs(ratio - 1)
+                            > self.config.learner_config.clip_range
+                        ).float()
                     ).to(device="cpu", non_blocking=True)
                     clip_fractions.append((clip_fraction, minibatch_ratio))
 
@@ -285,7 +339,9 @@ class PPOLearner(
                     value_loss = (
                         self.critic_loss_fn(vals, target_values) * minibatch_ratio
                     )
-                    ppo_loss =
+                    ppo_loss = (
+                        actor_loss - entropy * self.config.learner_config.ent_coef
+                    )
 
                     ppo_loss.backward()
                     value_loss.backward()
@@ -312,7 +368,7 @@ class PPOLearner(
         actor_update_magnitude = (actor_before - actor_after).norm().cpu().item()
         critic_update_magnitude = (critic_before - critic_after).norm().cpu().item()
 
-        if self.config.device != "cpu":
+        if self.config.learner_config.device.type != "cpu":
             torch.cuda.current_stream().synchronize()
 
         tot_clip = sum(
rlgym_learn_algos/ppo/trajectory_processor.py

@@ -3,7 +3,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
 
 from rlgym.api import ActionType, AgentID, ObsType, RewardType
-from torch import Tensor
+from torch import Tensor, device, dtype
 
 from .trajectory import Trajectory
 
@@ -16,8 +16,9 @@ TRAJECTORY_PROCESSOR_FILE = "trajectory_processor.json"
 @dataclass
 class DerivedTrajectoryProcessorConfig(Generic[TrajectoryProcessorConfig]):
     trajectory_processor_config: TrajectoryProcessorConfig
-
-
+    agent_controller_name: str
+    dtype: dtype
+    device: device
     checkpoint_load_folder: Optional[str] = None
 
 
rlgym_learn_algos/rlgym_learn_algos.cp39-win_amd64.pyd

Binary file
rlgym_learn_algos/util/torch_pydantic.py

@@ -0,0 +1,118 @@
+from typing import Annotated, Any
+
+import torch
+from pydantic import (
+    BaseModel,
+    GetCoreSchemaHandler,
+    GetJsonSchemaHandler,
+    ValidationError,
+)
+from pydantic.json_schema import JsonSchemaValue
+from pydantic_core import core_schema
+
+dtype_str_regex = "|".join(
+    set(
+        f"({str(v)[6:]})" for v in torch.__dict__.values() if isinstance(v, torch.dtype)
+    )
+)
+device_str_regex = (
+    "("
+    + "|".join(
+        f"({v})"
+        for v in [
+            "cpu",
+            "cuda",
+            "ipu",
+            "xpu",
+            "mkldnn",
+            "opengl",
+            "opencl",
+            "ideep",
+            "hip",
+            "ve",
+            "fpga",
+            "maia",
+            "xla",
+            "lazy",
+            "vulkan",
+            "mps",
+            "meta",
+            "hpu",
+            "mtia",
+            "privateuseone",
+        ]
+    )
+    + ")(:\d+)"
+)
+
+
+# Created using the example here: https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types
+class _TorchDtypePydanticAnnotation:
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
+    ) -> core_schema.CoreSchema:
+        from_str_schema = core_schema.chain_schema(
+            [
+                core_schema.str_schema(pattern=dtype_str_regex),
+                core_schema.no_info_plain_validator_function(
+                    lambda v: getattr(torch, v)
+                ),
+            ]
+        )
+
+        return core_schema.json_or_python_schema(
+            json_schema=from_str_schema,
+            python_schema=core_schema.union_schema(
+                [
+                    # check if it's an instance first before doing any further work
+                    core_schema.is_instance_schema(torch.dtype),
+                    from_str_schema,
+                ]
+            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda v: str(v)[6:]
+            ),
+        )
+
+
+class _TorchDevicePydanticAnnotation:
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
+    ) -> core_schema.CoreSchema:
+        from_str_schema = core_schema.chain_schema(
+            [
+                core_schema.str_schema(pattern=device_str_regex),
+                core_schema.no_info_plain_validator_function(lambda v: torch.device(v)),
+            ]
+        )
+        from_int_schema = core_schema.chain_schema(
+            [
+                core_schema.int_schema(ge=0),
+                core_schema.no_info_plain_validator_function(lambda v: torch.device(v)),
+            ]
+        )
+
+        return core_schema.json_or_python_schema(
+            json_schema=from_str_schema,
+            python_schema=core_schema.union_schema(
+                [
+                    # check if it's an instance first before doing any further work
+                    core_schema.is_instance_schema(torch.dtype),
+                    from_str_schema,
+                    from_int_schema,
+                ]
+            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda v: str(v)
+            ),
+        )
+
+
+PydanticTorchDtype = Annotated[torch.dtype, _TorchDtypePydanticAnnotation]
+PydanticTorchDevice = Annotated[torch.device, _TorchDevicePydanticAnnotation]
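A usage sketch of the two annotated types defined above (ExampleConfig is a hypothetical model, not part of the package): torch instances pass through the is_instance branch, strings are validated against the regexes and converted, and serialization goes back to plain strings.

```python
import torch
from pydantic import BaseModel

from rlgym_learn_algos.util.torch_pydantic import (
    PydanticTorchDevice,
    PydanticTorchDtype,
)

class ExampleConfig(BaseModel):
    dtype: PydanticTorchDtype = torch.float32
    device: PydanticTorchDevice = torch.device("cpu")

# "cuda:0" matches device_str_regex and becomes torch.device("cuda:0");
# torch.float16 passes the is_instance check directly.
cfg = ExampleConfig(dtype=torch.float16, device="cuda:0")
print(cfg.model_dump())  # {'dtype': 'float16', 'device': 'cuda:0'}
```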
{rlgym_learn_algos-0.1.5.dist-info → rlgym_learn_algos-0.2.1.dist-info}/RECORD

@@ -1,36 +1,39 @@
-rlgym_learn_algos-0.1.
-rlgym_learn_algos-0.1.
-rlgym_learn_algos-0.1.
-rlgym_learn_algos/
-rlgym_learn_algos/
-rlgym_learn_algos/
+rlgym_learn_algos-0.2.1.dist-info/METADATA,sha256=pRplMtq88vWNms7sVhWJfwf-W7GATsQH3617hrkNl3s,2431
+rlgym_learn_algos-0.2.1.dist-info/WHEEL,sha256=0kWP4A00z5y7dQaPOt93HDVSEu9o_Xh01gZkicHYRtk,94
+rlgym_learn_algos-0.2.1.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+rlgym_learn_algos/__init__.py,sha256=C7cRdL4lZrpk3ge_4_lGAbGodqWJXM56FfgO0keRPAY,207
+rlgym_learn_algos/conversion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rlgym_learn_algos/conversion/convert_rlgym_ppo_checkpoint.py,sha256=A9nvzjp3DQNRNL5TAt-u3xE80JDIpYEDqAGNReHvFG0,908
 rlgym_learn_algos/logging/__init__.py,sha256=ouItskWI4ItuoFdL--rt9YXCt7MasA473lYPhmJnrFA,423
+rlgym_learn_algos/logging/dict_metrics_logger.py,sha256=qmqr0HSiHpm5rjyxfAdmXOeBSbgP_t36-e-enpOccnE,1991
+rlgym_learn_algos/logging/metrics_logger.py,sha256=0l69GSSrxRcPm0xAjvF7yEIis7jGNu70unXu3hnK0XE,4122
+rlgym_learn_algos/logging/wandb_metrics_logger.py,sha256=OXyOJzGP4zz0mgy3-FAvR6LW7aZet3Ii8CsI5csw4c4,7051
+rlgym_learn_algos/ppo/__init__.py,sha256=o6B8wCRfeyipSNEGJFyB3SHYmxUytaQelX2zsted5cg,1184
 rlgym_learn_algos/ppo/actor.py,sha256=LZevg0kqRrb4PwF05ePK9b1JIBX04YkWjsPs7swZ9JY,1767
 rlgym_learn_algos/ppo/basic_critic.py,sha256=oyyo8x9K6mi2BsbA6_tRy2Av8Pimb35WspJkPpe8XdQ,1022
 rlgym_learn_algos/ppo/continuous_actor.py,sha256=1vdBUw2mQNFNu6A6ZrAztBjd4DmwjGkIIFLboMZ02lc,4417
 rlgym_learn_algos/ppo/critic.py,sha256=RB89WtiN52BEq5QCpGAPrASUnasac-Bpg7B0lM3UXHw,689
 rlgym_learn_algos/ppo/discrete_actor.py,sha256=Nuc3EndIQud3NGrkBIQgy-Z-okhXVrj6p6okSGD1KNY,2620
 rlgym_learn_algos/ppo/env_trajectories.py,sha256=gzQBRkzwZhlZeSvWL50cc8AOgBfsg5zUys0aTJj6aZU,3775
-rlgym_learn_algos/ppo/experience_buffer.py,sha256=
-rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=
-rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=
-rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=
+rlgym_learn_algos/ppo/experience_buffer.py,sha256=QdyFMMM8YpEYrmtFaeaHXvFlNT2pCZwQKBEqsrv4v2I,10838
+rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=Apk4x-pfRnitKJPW6LBZyOPIhgeJs_5EG7BbTCqMwjk,4761
+rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=JK958vasIIiuf3ALcFNlvBgGNhFshK8MhQJjwvxhrAM,5453
+rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=RpyDR6GQ1JXvwtoKkx5V3z3WvU9ElJdzfNtpPiZDaTc,6831
 rlgym_learn_algos/ppo/multi_discrete_actor.py,sha256=zSYeBBirjguSv_wO-peo06hioHiVhZQjnd-NYwJxmag,3127
-rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=
-rlgym_learn_algos/ppo/ppo_learner.py,sha256=
+rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=h0UR-o2k-_LyeFTzvII3HQHHWyeMJewqLlca8ThtyfA,25105
+rlgym_learn_algos/ppo/ppo_learner.py,sha256=3YTfs7LhjiJ0u3-k84rYWcmAQxKIf2yp1i1UVY4v8Oc,15229
 rlgym_learn_algos/ppo/ppo_metrics_logger.py,sha256=niW8xgQLEBCGgTaVyiE_JqsU6RTjV6h-JzM-7c3JT38,2868
 rlgym_learn_algos/ppo/trajectory.py,sha256=IIH_IG8B_HkyxRPf-YsCyF1jQqNjDx752hgzAehG25I,719
-rlgym_learn_algos/ppo/trajectory_processor.py,sha256=
-rlgym_learn_algos/ppo/__init__.py,sha256=o6B8wCRfeyipSNEGJFyB3SHYmxUytaQelX2zsted5cg,1184
+rlgym_learn_algos/ppo/trajectory_processor.py,sha256=5eY_mNGjqIkhqnbKeaqDvqIWPdg6wD6Ai3fXH2WoXbw,2091
 rlgym_learn_algos/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rlgym_learn_algos/rlgym_learn_algos.cp39-win_amd64.pyd,sha256=ErFmDw0lBLgYM4TabNWN-hgm_WHqN_cgFtH7XEI1FtU,416256
 rlgym_learn_algos/rlgym_learn_algos.pyi,sha256=NwY-sDZWM06TUiKPzxpfH1Td6G6E8TdxtRPgBSh-PPE,1203
+rlgym_learn_algos/stateful_functions/__init__.py,sha256=QS0KYjuzagNkYiYllXQmjoJn14-G7KZawq1Zvwh8alY,236
 rlgym_learn_algos/stateful_functions/batch_reward_type_numpy_converter.py,sha256=1yte5qYyl9LWdClHZ_YsF7R9dJqQeYfINMdgNF_59Gs,767
 rlgym_learn_algos/stateful_functions/numpy_obs_standardizer.py,sha256=OgtwCaxBGTySPMnE5D5VDKpJ0dsTEz9oHc08A96xRao,1604
 rlgym_learn_algos/stateful_functions/obs_standardizer.py,sha256=qPPc3--J_3mpJJ-QHJjta6dbWWBobL7SYdK5MUP-XMw,606
-rlgym_learn_algos/
+rlgym_learn_algos/util/__init__.py,sha256=VPM6SN4T_625H9t30s9EiLeXiEEWgcyRVHa-LLVNrn4,47
 rlgym_learn_algos/util/running_stats.py,sha256=0tiGFpKtHWzMa1CxM_ueBzd_ryX4bJBriC8MXcSLg8w,4479
 rlgym_learn_algos/util/torch_functions.py,sha256=CTTHzTIi7u1O9HyX0cVJOrnYVbAtnlVs0g1fO9s3ano,3458
-rlgym_learn_algos/util/
-rlgym_learn_algos/
-rlgym_learn_algos/rlgym_learn_algos.cp39-win_amd64.pyd,sha256=lkUelOB7xpKkM0_kK7AeMahzsksYOxSNiYvpWF3N688,416256
+rlgym_learn_algos/util/torch_pydantic.py,sha256=pgj3I-3q8iW9qtOCv1fgjNkZgA00G_Rdkb4qJPk5gxo,3530
-rlgym_learn_algos-0.1.5.dist-info/RECORD,,
+rlgym_learn_algos-0.2.1.dist-info/RECORD,,
{rlgym_learn_algos-0.1.5.dist-info → rlgym_learn_algos-0.2.1.dist-info}/licenses/LICENSE

File without changes