rlgym-learn-algos 0.2.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl → 0.2.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rlgym_learn_algos/ppo/ppo_agent_controller.py +37 -33
- rlgym_learn_algos/rlgym_learn_algos.cpython-312-arm-linux-gnueabihf.so +0 -0
- {rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/METADATA +1 -1
- {rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/RECORD +6 -6
- {rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/WHEEL +0 -0
- {rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/licenses/LICENSE +0 -0
rlgym_learn_algos/ppo/ppo_agent_controller.py
@@ -24,6 +24,8 @@ from rlgym.api import (
 )
 from rlgym_learn import EnvActionResponse, EnvActionResponseType, Timestep
 from rlgym_learn.api.agent_controller import AgentController
+from torch import device as _device
+
 from rlgym_learn_algos.logging import (
     DerivedMetricsLoggerConfig,
     MetricsLogger,
@@ -34,7 +36,6 @@ from rlgym_learn_algos.logging import (
 )
 from rlgym_learn_algos.stateful_functions import ObsStandardizer
 from rlgym_learn_algos.util.torch_functions import get_device
-from torch import device as _device
 
 from .actor import Actor
 from .critic import Critic
@@ -57,8 +58,8 @@ EXPERIENCE_BUFFER_FOLDER = "experience_buffer"
 PPO_LEARNER_FOLDER = "ppo_learner"
 METRICS_LOGGER_FOLDER = "metrics_logger"
 PPO_AGENT_FILE = "ppo_agent.json"
+ITERATION_TRAJECTORIES_FILE = "current_trajectories.pkl"  # this should be renamed, but it would be a breaking change so I'm leaving it until I happen to make one of those and remember to update this at the same time
 ITERATION_SHARED_INFOS_FILE = "iteration_shared_infos.pkl"
-CURRENT_TRAJECTORIES_FILE = "current_trajectories.pkl"
 
 
 class PPOAgentControllerConfigModel(BaseModel, extra="forbid"):
@@ -68,6 +69,7 @@ class PPOAgentControllerConfigModel(BaseModel, extra="forbid"):
     checkpoint_load_folder: Optional[str] = None
     n_checkpoints_to_keep: int = 5
     random_seed: int = 123
+    save_mid_iteration_data_in_checkpoint: bool = True
     learner_config: PPOLearnerConfigModel = Field(default_factory=PPOLearnerConfigModel)
     experience_buffer_config: ExperienceBufferConfigModel = Field(
         default_factory=ExperienceBufferConfigModel
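The new save_mid_iteration_data_in_checkpoint flag defaults to True, so 0.2.4 keeps the 0.2.2 behavior of pickling in-progress iteration data into every checkpoint unless it is switched off. Below is a minimal sketch of turning it off through the config model; it assumes pydantic v2 and that the fields not shown in this diff keep their defaults, so treat it as an illustration rather than the library's documented usage.

# Sketch only: fields other than those visible in this diff are assumed to have defaults.
from rlgym_learn_algos.ppo.ppo_agent_controller import PPOAgentControllerConfigModel

config = PPOAgentControllerConfigModel(
    n_checkpoints_to_keep=5,
    random_seed=123,
    save_mid_iteration_data_in_checkpoint=False,  # skip pickling mid-iteration data
)
print(config.model_dump_json(indent=4))  # pydantic v2 serialization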
@@ -166,7 +168,7 @@ class PPOAgentController(
             str,
             EnvTrajectories[AgentID, ActionType, ObsType, RewardType],
         ] = {}
-        self.current_trajectories: List[
+        self.iteration_trajectories: List[
             Trajectory[AgentID, ActionType, ObsType, RewardType]
         ] = []
         self.iteration_shared_infos: List[Dict[str, Any]] = []
@@ -309,19 +311,18 @@ class PPOAgentController(
             with open(
                 os.path.join(
                     self.config.agent_controller_config.checkpoint_load_folder,
-                    CURRENT_TRAJECTORIES_FILE,
+                    ITERATION_TRAJECTORIES_FILE,
                 ),
                 "rb",
             ) as f:
-
-
-                    EnvTrajectories[AgentID, ActionType, ObsType, RewardType],
+                iteration_trajectories: List[
+                    Trajectory[AgentID, ObsType, ActionType, RewardType]
                 ] = pickle.load(f)
         except FileNotFoundError:
             print(
-                f"{self.config.agent_controller_name}: Tried to load current trajectories from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, CURRENT_TRAJECTORIES_FILE))}, but there is no such file! Current trajectories will be initialized as an empty list instead."
+                f"{self.config.agent_controller_name}: Tried to load current trajectories from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, ITERATION_TRAJECTORIES_FILE))}, but there is no such file! Current trajectories will be initialized as an empty list instead."
             )
-            current_trajectories = []
+            iteration_trajectories = []
         try:
             with open(
                 os.path.join(
@@ -335,7 +336,7 @@ class PPOAgentController(
             print(
                 f"{self.config.agent_controller_name}: Tried to load iteration shared info data from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, ITERATION_SHARED_INFOS_FILE))}, but there is no such file! Iteration shared info data will be initialized as an empty list instead."
             )
-
+            iteration_shared_infos = []
         try:
             with open(
                 os.path.join(
@@ -357,7 +358,7 @@ class PPOAgentController(
             "timestep_collection_start_time": time.perf_counter(),
         }
 
-        self.current_trajectories = current_trajectories
+        self.iteration_trajectories = iteration_trajectories
         self.iteration_shared_infos = iteration_shared_infos
         self.cur_iteration = state["cur_iteration"]
         self.iteration_timesteps = state["iteration_timesteps"]
@@ -379,20 +380,22 @@
         self.experience_buffer.save_checkpoint(
             os.path.join(checkpoint_save_folder, EXPERIENCE_BUFFER_FOLDER)
         )
-        self.metrics_logger
-
-
+        if self.metrics_logger is not None:
+            self.metrics_logger.save_checkpoint(
+                os.path.join(checkpoint_save_folder, METRICS_LOGGER_FOLDER)
+            )
 
-
-
-
-
-
-
-
-
-
-
+        if self.config.agent_controller_config.save_mid_iteration_data_in_checkpoint:
+            with open(
+                os.path.join(checkpoint_save_folder, ITERATION_TRAJECTORIES_FILE),
+                "wb",
+            ) as f:
+                pickle.dump(self.iteration_trajectories, f)
+            with open(
+                os.path.join(checkpoint_save_folder, ITERATION_SHARED_INFOS_FILE),
+                "wb",
+            ) as f:
+                pickle.dump(self.iteration_shared_infos, f)
         with open(os.path.join(checkpoint_save_folder, PPO_AGENT_FILE), "wt") as f:
             state = {
                 "cur_iteration": self.cur_iteration,
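With the flag enabled, a 0.2.4 checkpoint folder now contains the two pickle files named by the constants earlier in this diff alongside ppo_agent.json. The following is a small sketch for peeking at that mid-iteration data in an existing checkpoint; the checkpoint path is hypothetical, and either file may legitimately be absent when the flag was off.

# Sketch only: file names come from the constants in this diff, the folder is hypothetical.
import os
import pickle

checkpoint_folder = "checkpoints/1712345678"  # hypothetical checkpoint directory

for name in ("current_trajectories.pkl", "iteration_shared_infos.pkl"):
    path = os.path.join(checkpoint_folder, name)
    try:
        with open(path, "rb") as f:
            data = pickle.load(f)
        print(f"{name}: {len(data)} entries")
    except FileNotFoundError:
        # Only written when save_mid_iteration_data_in_checkpoint is True (the default).
        print(f"{name}: not present in this checkpoint")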
@@ -403,6 +406,7 @@ class PPOAgentController(
             }
             json.dump(state, f, indent=4)
 
+        # TODO: does this actually work? I'm not sure the file structure I'm using actually works with this assumption
         # Prune old checkpoints
         existing_checkpoints = [
             int(arg) for arg in os.listdir(self.checkpoints_save_folder)
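The TODO added above concerns the pruning step, which assumes every entry under the checkpoints folder is an integer-named checkpoint directory. A standalone sketch of that keep-the-newest-N idea follows; the folder layout, the isdigit filter, and the rmtree call are assumptions for illustration, not the library's exact logic.

# Sketch only: prune integer-named checkpoint folders, keeping the newest N.
import os
import shutil

checkpoints_save_folder = "checkpoints"  # hypothetical root folder
n_checkpoints_to_keep = 5

existing = sorted(
    int(name)
    for name in os.listdir(checkpoints_save_folder)
    if name.isdigit()  # skip anything that is not an integer-named checkpoint dir
)
for stale in existing[:-n_checkpoints_to_keep]:
    shutil.rmtree(os.path.join(checkpoints_save_folder, str(stale)))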
@@ -506,7 +510,7 @@ class PPOAgentController(
         elif enum_type == EnvActionResponseType.RESET:
             env_trajectories = self.current_env_trajectories.pop(env_id)
             env_trajectories.finalize()
-            self.current_trajectories += env_trajectories.get_trajectories()
+            self.iteration_trajectories += env_trajectories.get_trajectories()
         elif enum_type == EnvActionResponseType.SET_STATE:
             # Can get the desired_state using env_action.desired_state and the prev_timestep_id_dict using env_action.prev_timestep_id_dict, but I'll leave that to you
             raise NotImplementedError
@@ -517,10 +521,10 @@ class PPOAgentController(
         env_trajectories_list = list(self.current_env_trajectories.values())
         for env_trajectories in env_trajectories_list:
             env_trajectories.finalize()
-            self.current_trajectories += env_trajectories.get_trajectories()
+            self.iteration_trajectories += env_trajectories.get_trajectories()
         self._update_value_predictions()
         trajectory_processor_data = self.experience_buffer.submit_experience(
-            self.current_trajectories
+            self.iteration_trajectories
        )
         ppo_data = self.learner.learn(self.experience_buffer)
 
@@ -540,9 +544,9 @@ class PPOAgentController(
             self.metrics_logger.collect_env_metrics(self.iteration_shared_infos)
             self.metrics_logger.report_metrics()
 
-        self.
+        self.iteration_trajectories.clear()
+        self.iteration_shared_infos.clear()
         self.current_env_trajectories.clear()
-        self.current_trajectories.clear()
         self.ts_since_last_save += self.iteration_timesteps
         self.iteration_timesteps = 0
         self.iteration_start_time = cur_time
@@ -551,14 +555,14 @@ class PPOAgentController(
     @torch.no_grad()
     def _update_value_predictions(self):
         """
-        Function to update the value predictions inside the Trajectory instances of self.current_trajectories
+        Function to update the value predictions inside the Trajectory instances of self.iteration_trajectories
         """
         traj_timestep_idx_ranges: List[Tuple[int, int]] = []
         start = 0
         stop = 0
         critic_agent_id_input: List[AgentID] = []
         critic_obs_input: List[ObsType] = []
-        for trajectory in self.current_trajectories:
+        for trajectory in self.iteration_trajectories:
             obs_list = trajectory.obs_list + [trajectory.final_obs]
             traj_len = len(obs_list)
             agent_id_list = [trajectory.agent_id] * traj_len
@@ -575,7 +579,7 @@ class PPOAgentController(
         )
         torch.cuda.empty_cache()
         for idx, (start, stop) in enumerate(traj_timestep_idx_ranges):
-            self.current_trajectories[idx].val_preds = val_preds[start : stop - 1]
-            self.current_trajectories[idx].final_val_pred = val_preds[stop - 1]
+            self.iteration_trajectories[idx].val_preds = val_preds[start : stop - 1]
+            self.iteration_trajectories[idx].final_val_pred = val_preds[stop - 1]
         if self.config.agent_controller_config.learner_config.device.type != "cpu":
             torch.cuda.current_stream().synchronize()
rlgym_learn_algos/rlgym_learn_algos.cpython-312-arm-linux-gnueabihf.so: Binary file
{rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/RECORD
@@ -1,6 +1,6 @@
-rlgym_learn_algos-0.2.
-rlgym_learn_algos-0.2.
-rlgym_learn_algos-0.2.
+rlgym_learn_algos-0.2.4.dist-info/METADATA,sha256=KldIto2nUjijheVI6OpfvsBKYxgWYCbOFclqfwp13ys,2403
+rlgym_learn_algos-0.2.4.dist-info/WHEEL,sha256=jlKmy-EzabauTasu_10kpmAYQKA8Q3THFJEGdKg_ucM,129
+rlgym_learn_algos-0.2.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
 rlgym_learn_algos/__init__.py,sha256=dZeTgNro6qG1Hu0l0UBhgHOYiyeCwPWndC84dJAp__U,203
 rlgym_learn_algos/conversion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rlgym_learn_algos/conversion/convert_rlgym_ppo_checkpoint.py,sha256=6Tj4KezPL1DFSjCwZPCbyaYFdp3-RHJ_ft9iDrFqg2I,881
@@ -20,13 +20,13 @@ rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=6AOGQjDn_dHLS9bmxJW_cGEj
 rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=gv5kxvvPnK7SyQIAq6MbOFILIMdPlzoLZwM8TRmtNWw,5302
 rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=cq7qbK0mcLDXRzA6-pKW0OC50X52XhT5himcOTD6Ei4,6657
 rlgym_learn_algos/ppo/multi_discrete_actor.py,sha256=Mik0X79dUy2ZRIMol4RMTZE9qzsOk6f_6bDaOl5ghxs,3039
-rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=
+rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=to807i7Nm7FMA0zT8m9VWTBZz7pxhL-W8JLBM4OFuc0,25051
 rlgym_learn_algos/ppo/ppo_learner.py,sha256=utEWkikXCpC6Xc1D3asohO0HsIaq3tLyoTlb7fXLOw4,15522
 rlgym_learn_algos/ppo/ppo_metrics_logger.py,sha256=iUyUc2GPwDIIjZeJPZWxoeRrzUWV_qLOac0vApQBkp0,2803
 rlgym_learn_algos/ppo/trajectory.py,sha256=_xyS9ueU6iVvqMUpFr-kb42wEHHZy4zCse7_r660n5E,690
 rlgym_learn_algos/ppo/trajectory_processor.py,sha256=3XRsXXexHWp6UV5nAeBLYvWqvQ9EbNHSN3Yooi4cezo,2031
 rlgym_learn_algos/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rlgym_learn_algos/rlgym_learn_algos.cpython-312-arm-linux-gnueabihf.so,sha256=
+rlgym_learn_algos/rlgym_learn_algos.cpython-312-arm-linux-gnueabihf.so,sha256=MKww29COmLuTgiD0gKZIXUhx7ML6F4WPgUf3ZhNqKys,732012
 rlgym_learn_algos/rlgym_learn_algos.pyi,sha256=B9Kt9uK8xCqASRxWvzLdV501TSCMO4vTNqvZ0MhOHyo,1164
 rlgym_learn_algos/stateful_functions/__init__.py,sha256=OAVy6cQIS85Utyp18jjHgdmascX_8nkwk3A0OpFJxT4,230
 rlgym_learn_algos/stateful_functions/batch_reward_type_numpy_converter.py,sha256=a3q2l5SIgDI36ImF_kYoa684pghnFnlV2vGYvV2zcV0,743
@@ -36,4 +36,4 @@ rlgym_learn_algos/util/__init__.py,sha256=hq7M00Q7zAfyQmIGmXOif0vI40aj_FQ5SqI5dn
 rlgym_learn_algos/util/running_stats.py,sha256=KtzdKKT75-5ZC58JRqaDXk6sBqa3ZSjQQZrRajAw3Yk,4339
 rlgym_learn_algos/util/torch_functions.py,sha256=ImgDw4I3ZixGDi17YRkW6UbaiaQTbvOCUCS7N0QVSsU,3320
 rlgym_learn_algos/util/torch_pydantic.py,sha256=khPGA6kWh4_WHoploDkl_SCIGX8SkKkFT40RE06PImc,3413
-rlgym_learn_algos-0.2.2.dist-info/RECORD,,
+rlgym_learn_algos-0.2.4.dist-info/RECORD,,
{rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/WHEEL: File without changes
{rlgym_learn_algos-0.2.2.dist-info → rlgym_learn_algos-0.2.4.dist-info}/licenses/LICENSE: File without changes
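For anyone spot-checking the RECORD changes above: each entry is the file path, then sha256= followed by the urlsafe-base64 SHA-256 digest with = padding stripped, then the file size in bytes, which is the standard wheel RECORD format. The short sketch below recomputes one entry from an unpacked wheel; the local path is hypothetical.

# Sketch only: recompute a wheel RECORD entry (sha256 in urlsafe base64, no '=' padding).
import base64
import hashlib

path = "rlgym_learn_algos/ppo/ppo_agent_controller.py"  # hypothetical extracted file
with open(path, "rb") as f:
    data = f.read()
digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
print(f"{path},sha256={digest},{len(data)}")  # should match the RECORD line for 0.2.4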