rlgym-learn-algos 0.2.2__cp39-cp39-musllinux_1_2_armv7l.whl → 0.2.4__cp39-cp39-musllinux_1_2_armv7l.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

rlgym_learn_algos/ppo/ppo_agent_controller.py

@@ -24,6 +24,8 @@ from rlgym.api import (
  )
  from rlgym_learn import EnvActionResponse, EnvActionResponseType, Timestep
  from rlgym_learn.api.agent_controller import AgentController
+ from torch import device as _device
+
  from rlgym_learn_algos.logging import (
      DerivedMetricsLoggerConfig,
      MetricsLogger,
@@ -34,7 +36,6 @@ from rlgym_learn_algos.logging import (
  )
  from rlgym_learn_algos.stateful_functions import ObsStandardizer
  from rlgym_learn_algos.util.torch_functions import get_device
- from torch import device as _device

  from .actor import Actor
  from .critic import Critic
@@ -57,8 +58,8 @@ EXPERIENCE_BUFFER_FOLDER = "experience_buffer"
  PPO_LEARNER_FOLDER = "ppo_learner"
  METRICS_LOGGER_FOLDER = "metrics_logger"
  PPO_AGENT_FILE = "ppo_agent.json"
+ ITERATION_TRAJECTORIES_FILE = "current_trajectories.pkl"  # this should be renamed, but it would be a breaking change so I'm leaving it until I happen to make one of those and remember to update this at the same time
  ITERATION_SHARED_INFOS_FILE = "iteration_shared_infos.pkl"
- CURRENT_TRAJECTORIES_FILE = "current_trajectories.pkl"


  class PPOAgentControllerConfigModel(BaseModel, extra="forbid"):
@@ -68,6 +69,7 @@ class PPOAgentControllerConfigModel(BaseModel, extra="forbid"):
      checkpoint_load_folder: Optional[str] = None
      n_checkpoints_to_keep: int = 5
      random_seed: int = 123
+     save_mid_iteration_data_in_checkpoint: bool = True
      learner_config: PPOLearnerConfigModel = Field(default_factory=PPOLearnerConfigModel)
      experience_buffer_config: ExperienceBufferConfigModel = Field(
          default_factory=ExperienceBufferConfigModel
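
The save_mid_iteration_data_in_checkpoint field above is the new user-facing config option this diff introduces. A minimal usage sketch follows (hypothetical, not part of the diff; it only sets fields visible in this hunk and assumes the fields not shown here keep their defaults):

# Hypothetical sketch: toggling the new flag on the agent controller config model.
from rlgym_learn_algos.ppo.ppo_agent_controller import PPOAgentControllerConfigModel

config = PPOAgentControllerConfigModel.model_validate(
    {
        "n_checkpoints_to_keep": 5,
        "random_seed": 123,
        # New in 0.2.4: False skips pickling mid-iteration trajectories and shared
        # infos into each checkpoint; True (the default) matches the 0.2.2 behavior.
        "save_mid_iteration_data_in_checkpoint": False,
    }
)
print(config.save_mid_iteration_data_in_checkpoint)  # False
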
@@ -166,7 +168,7 @@ class PPOAgentController(
              str,
              EnvTrajectories[AgentID, ActionType, ObsType, RewardType],
          ] = {}
-         self.current_trajectories: List[
+         self.iteration_trajectories: List[
              Trajectory[AgentID, ActionType, ObsType, RewardType]
          ] = []
          self.iteration_shared_infos: List[Dict[str, Any]] = []
@@ -309,19 +311,18 @@ class PPOAgentController(
              with open(
                  os.path.join(
                      self.config.agent_controller_config.checkpoint_load_folder,
-                     CURRENT_TRAJECTORIES_FILE,
+                     ITERATION_TRAJECTORIES_FILE,
                  ),
                  "rb",
              ) as f:
-                 current_trajectories: Dict[
-                     int,
-                     EnvTrajectories[AgentID, ActionType, ObsType, RewardType],
+                 iteration_trajectories: List[
+                     Trajectory[AgentID, ObsType, ActionType, RewardType]
                  ] = pickle.load(f)
          except FileNotFoundError:
              print(
-                 f"{self.config.agent_controller_name}: Tried to load current trajectories from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, CURRENT_TRAJECTORIES_FILE))}, but there is no such file! Current trajectories will be initialized as an empty dict instead."
+                 f"{self.config.agent_controller_name}: Tried to load current trajectories from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, ITERATION_TRAJECTORIES_FILE))}, but there is no such file! Current trajectories will be initialized as an empty list instead."
              )
-             current_trajectories = {}
+             iteration_trajectories = []
          try:
              with open(
                  os.path.join(
@@ -335,7 +336,7 @@ class PPOAgentController(
              print(
                  f"{self.config.agent_controller_name}: Tried to load iteration shared info data from checkpoint using the file at location {str(os.path.join(self.config.agent_controller_config.checkpoint_load_folder, ITERATION_SHARED_INFOS_FILE))}, but there is no such file! Iteration shared info data will be initialized as an empty list instead."
              )
-             current_trajectories = {}
+             iteration_shared_infos = []
          try:
              with open(
                  os.path.join(
@@ -357,7 +358,7 @@ class PPOAgentController(
                  "timestep_collection_start_time": time.perf_counter(),
              }

-         self.current_trajectories = current_trajectories
+         self.iteration_trajectories = iteration_trajectories
          self.iteration_shared_infos = iteration_shared_infos
          self.cur_iteration = state["cur_iteration"]
          self.iteration_timesteps = state["iteration_timesteps"]
@@ -379,20 +380,22 @@ class PPOAgentController(
          self.experience_buffer.save_checkpoint(
              os.path.join(checkpoint_save_folder, EXPERIENCE_BUFFER_FOLDER)
          )
-         self.metrics_logger.save_checkpoint(
-             os.path.join(checkpoint_save_folder, METRICS_LOGGER_FOLDER)
-         )
+         if self.metrics_logger is not None:
+             self.metrics_logger.save_checkpoint(
+                 os.path.join(checkpoint_save_folder, METRICS_LOGGER_FOLDER)
+             )

-         with open(
-             os.path.join(checkpoint_save_folder, CURRENT_TRAJECTORIES_FILE),
-             "wb",
-         ) as f:
-             pickle.dump(self.current_trajectories, f)
-         with open(
-             os.path.join(checkpoint_save_folder, ITERATION_SHARED_INFOS_FILE),
-             "wb",
-         ) as f:
-             pickle.dump(self.iteration_shared_infos, f)
+         if self.config.agent_controller_config.save_mid_iteration_data_in_checkpoint:
+             with open(
+                 os.path.join(checkpoint_save_folder, ITERATION_TRAJECTORIES_FILE),
+                 "wb",
+             ) as f:
+                 pickle.dump(self.iteration_trajectories, f)
+             with open(
+                 os.path.join(checkpoint_save_folder, ITERATION_SHARED_INFOS_FILE),
+                 "wb",
+             ) as f:
+                 pickle.dump(self.iteration_shared_infos, f)
          with open(os.path.join(checkpoint_save_folder, PPO_AGENT_FILE), "wt") as f:
              state = {
                  "cur_iteration": self.cur_iteration,
@@ -403,6 +406,7 @@ class PPOAgentController(
              }
              json.dump(state, f, indent=4)

+         # TODO: does this actually work? I'm not sure the file structure I'm using actually works with this assumption
          # Prune old checkpoints
          existing_checkpoints = [
              int(arg) for arg in os.listdir(self.checkpoints_save_folder)
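
Together with the renamed constants near the top of the file, this makes the mid-iteration pickles optional in a checkpoint. The sketch below is illustrative only (the folder layout is inferred from the constants and save calls visible in this diff; the checkpoint directory naming and whether a learner folder is always written are assumptions): it lists which entries one checkpoint folder is expected to contain for a given flag value.

# Illustrative sketch, not library code: expected contents of one checkpoint folder.
EXPERIENCE_BUFFER_FOLDER = "experience_buffer"
PPO_LEARNER_FOLDER = "ppo_learner"
METRICS_LOGGER_FOLDER = "metrics_logger"
PPO_AGENT_FILE = "ppo_agent.json"
ITERATION_TRAJECTORIES_FILE = "current_trajectories.pkl"
ITERATION_SHARED_INFOS_FILE = "iteration_shared_infos.pkl"

def expected_checkpoint_entries(save_mid_iteration_data: bool, has_metrics_logger: bool) -> list:
    # Assumed always written: learner folder, experience buffer folder, ppo_agent.json.
    entries = [PPO_LEARNER_FOLDER, EXPERIENCE_BUFFER_FOLDER, PPO_AGENT_FILE]
    if has_metrics_logger:  # per this diff, skipped when self.metrics_logger is None
        entries.append(METRICS_LOGGER_FOLDER)
    if save_mid_iteration_data:  # per this diff, both pickles are skipped when False
        entries += [ITERATION_TRAJECTORIES_FILE, ITERATION_SHARED_INFOS_FILE]
    return entries

print(expected_checkpoint_entries(save_mid_iteration_data=False, has_metrics_logger=True))
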
@@ -506,7 +510,7 @@ class PPOAgentController(
              elif enum_type == EnvActionResponseType.RESET:
                  env_trajectories = self.current_env_trajectories.pop(env_id)
                  env_trajectories.finalize()
-                 self.current_trajectories += env_trajectories.get_trajectories()
+                 self.iteration_trajectories += env_trajectories.get_trajectories()
              elif enum_type == EnvActionResponseType.SET_STATE:
                  # Can get the desired_state using env_action.desired_state and the prev_timestep_id_dict using env_action.prev_timestep_id_dict, but I'll leave that to you
                  raise NotImplementedError
@@ -517,10 +521,10 @@ class PPOAgentController(
          env_trajectories_list = list(self.current_env_trajectories.values())
          for env_trajectories in env_trajectories_list:
              env_trajectories.finalize()
-             self.current_trajectories += env_trajectories.get_trajectories()
+             self.iteration_trajectories += env_trajectories.get_trajectories()
          self._update_value_predictions()
          trajectory_processor_data = self.experience_buffer.submit_experience(
-             self.current_trajectories
+             self.iteration_trajectories
          )
          ppo_data = self.learner.learn(self.experience_buffer)

@@ -540,9 +544,9 @@ class PPOAgentController(
              self.metrics_logger.collect_env_metrics(self.iteration_shared_infos)
              self.metrics_logger.report_metrics()

-         self.iteration_shared_infos = []
+         self.iteration_trajectories.clear()
+         self.iteration_shared_infos.clear()
          self.current_env_trajectories.clear()
-         self.current_trajectories.clear()
          self.ts_since_last_save += self.iteration_timesteps
          self.iteration_timesteps = 0
          self.iteration_start_time = cur_time
@@ -551,14 +555,14 @@ class PPOAgentController(
      @torch.no_grad()
      def _update_value_predictions(self):
          """
-         Function to update the value predictions inside the Trajectory instances of self.current_trajectories
+         Function to update the value predictions inside the Trajectory instances of self.iteration_trajectories
          """
          traj_timestep_idx_ranges: List[Tuple[int, int]] = []
          start = 0
          stop = 0
          critic_agent_id_input: List[AgentID] = []
          critic_obs_input: List[ObsType] = []
-         for trajectory in self.current_trajectories:
+         for trajectory in self.iteration_trajectories:
              obs_list = trajectory.obs_list + [trajectory.final_obs]
              traj_len = len(obs_list)
              agent_id_list = [trajectory.agent_id] * traj_len
@@ -575,7 +579,7 @@ class PPOAgentController(
          )
          torch.cuda.empty_cache()
          for idx, (start, stop) in enumerate(traj_timestep_idx_ranges):
-             self.current_trajectories[idx].val_preds = val_preds[start : stop - 1]
-             self.current_trajectories[idx].final_val_pred = val_preds[stop - 1]
+             self.iteration_trajectories[idx].val_preds = val_preds[start : stop - 1]
+             self.iteration_trajectories[idx].final_val_pred = val_preds[stop - 1]
          if self.config.agent_controller_config.learner_config.device.type != "cpu":
              torch.cuda.current_stream().synchronize()

rlgym_learn_algos-0.2.2.dist-info/METADATA → rlgym_learn_algos-0.2.4.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rlgym-learn-algos
- Version: 0.2.2
+ Version: 0.2.4
  Classifier: Programming Language :: Rust
  Classifier: Programming Language :: Python :: Implementation :: CPython
  Requires-Dist: pydantic>=2.8.2

rlgym_learn_algos-0.2.2.dist-info/RECORD → rlgym_learn_algos-0.2.4.dist-info/RECORD

@@ -1,6 +1,6 @@
- rlgym_learn_algos-0.2.2.dist-info/METADATA,sha256=9m49XNjxVc1H0Hj27PV8BszTk9DnOBNBNIjdxb0uwcQ,2403
- rlgym_learn_algos-0.2.2.dist-info/WHEEL,sha256=ZaFOerxeFPIN7Ome868VEyCSJvdCwZINrGvZscX7-b8,105
- rlgym_learn_algos-0.2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ rlgym_learn_algos-0.2.4.dist-info/METADATA,sha256=KldIto2nUjijheVI6OpfvsBKYxgWYCbOFclqfwp13ys,2403
+ rlgym_learn_algos-0.2.4.dist-info/WHEEL,sha256=ZaFOerxeFPIN7Ome868VEyCSJvdCwZINrGvZscX7-b8,105
+ rlgym_learn_algos-0.2.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
  rlgym_learn_algos.libs/libgcc_s-5b5488a6.so.1,sha256=HGKUsVmTeNAxEdSy7Ua5Vh_I9FN3RCbPWzvZ7H_TrwE,2749061
  rlgym_learn_algos/__init__.py,sha256=dZeTgNro6qG1Hu0l0UBhgHOYiyeCwPWndC84dJAp__U,203
  rlgym_learn_algos/conversion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,13 +21,13 @@ rlgym_learn_algos/ppo/experience_buffer_numpy.py,sha256=6AOGQjDn_dHLS9bmxJW_cGEj
  rlgym_learn_algos/ppo/gae_trajectory_processor.py,sha256=gv5kxvvPnK7SyQIAq6MbOFILIMdPlzoLZwM8TRmtNWw,5302
  rlgym_learn_algos/ppo/gae_trajectory_processor_pure_python.py,sha256=cq7qbK0mcLDXRzA6-pKW0OC50X52XhT5himcOTD6Ei4,6657
  rlgym_learn_algos/ppo/multi_discrete_actor.py,sha256=Mik0X79dUy2ZRIMol4RMTZE9qzsOk6f_6bDaOl5ghxs,3039
- rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=iVmCvN7H1IcKX7VrZnMParQ148EnPqs9yZ9CCgtdsq8,24524
+ rlgym_learn_algos/ppo/ppo_agent_controller.py,sha256=to807i7Nm7FMA0zT8m9VWTBZz7pxhL-W8JLBM4OFuc0,25051
  rlgym_learn_algos/ppo/ppo_learner.py,sha256=utEWkikXCpC6Xc1D3asohO0HsIaq3tLyoTlb7fXLOw4,15522
  rlgym_learn_algos/ppo/ppo_metrics_logger.py,sha256=iUyUc2GPwDIIjZeJPZWxoeRrzUWV_qLOac0vApQBkp0,2803
  rlgym_learn_algos/ppo/trajectory.py,sha256=_xyS9ueU6iVvqMUpFr-kb42wEHHZy4zCse7_r660n5E,690
  rlgym_learn_algos/ppo/trajectory_processor.py,sha256=3XRsXXexHWp6UV5nAeBLYvWqvQ9EbNHSN3Yooi4cezo,2031
  rlgym_learn_algos/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rlgym_learn_algos/rlgym_learn_algos.cpython-39-arm-linux-gnueabihf.so,sha256=U0IsbCcmaF2TxVJxfjuICaeu9iHELIxOFn7rK-qhLs0,739981
+ rlgym_learn_algos/rlgym_learn_algos.cpython-39-arm-linux-gnueabihf.so,sha256=uvZb1wCJqEBZozUM2MxXf4OcgqDLlVZiwHOQ1Z6fXSs,739981
  rlgym_learn_algos/rlgym_learn_algos.pyi,sha256=B9Kt9uK8xCqASRxWvzLdV501TSCMO4vTNqvZ0MhOHyo,1164
  rlgym_learn_algos/stateful_functions/__init__.py,sha256=OAVy6cQIS85Utyp18jjHgdmascX_8nkwk3A0OpFJxT4,230
  rlgym_learn_algos/stateful_functions/batch_reward_type_numpy_converter.py,sha256=a3q2l5SIgDI36ImF_kYoa684pghnFnlV2vGYvV2zcV0,743
@@ -37,4 +37,4 @@ rlgym_learn_algos/util/__init__.py,sha256=hq7M00Q7zAfyQmIGmXOif0vI40aj_FQ5SqI5dn
  rlgym_learn_algos/util/running_stats.py,sha256=KtzdKKT75-5ZC58JRqaDXk6sBqa3ZSjQQZrRajAw3Yk,4339
  rlgym_learn_algos/util/torch_functions.py,sha256=ImgDw4I3ZixGDi17YRkW6UbaiaQTbvOCUCS7N0QVSsU,3320
  rlgym_learn_algos/util/torch_pydantic.py,sha256=khPGA6kWh4_WHoploDkl_SCIGX8SkKkFT40RE06PImc,3413
- rlgym_learn_algos-0.2.2.dist-info/RECORD,,
+ rlgym_learn_algos-0.2.4.dist-info/RECORD,,