pyRDDLGym-jax 2.2__py3-none-any.whl → 2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,7 @@ import datetime
18
18
  import threading
19
19
  import multiprocessing
20
20
  import os
21
+ import termcolor
21
22
  import time
22
23
  import traceback
23
24
  from typing import Any, Callable, Dict, Iterable, Optional, Tuple
@@ -45,8 +46,7 @@ try:
45
46
  from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
46
47
  except Exception:
47
48
  raise_warning('Failed to load the dashboard visualization tool: '
48
- 'please make sure you have installed the required packages.',
49
- 'red')
49
+ 'please make sure you have installed the required packages.', 'red')
50
50
  traceback.print_exc()
51
51
  JaxPlannerDashboard = None
52
52
 
@@ -159,24 +159,24 @@ class JaxParameterTuning:
159
159
  kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
160
160
  return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
161
161
 
162
- def summarize_hyperparameters(self) -> None:
162
+ def summarize_hyperparameters(self) -> str:
163
163
  hyper_params_table = []
164
164
  for (_, param) in self.hyperparams_dict.items():
165
165
  hyper_params_table.append(f' {str(param)}')
166
166
  hyper_params_table = '\n'.join(hyper_params_table)
167
- print(f'hyperparameter optimizer parameters:\n'
168
- f' tuned_hyper_parameters =\n{hyper_params_table}\n'
169
- f' initialization_args ={self.gp_init_kwargs}\n'
170
- f' gp_params ={self.gp_params}\n'
171
- f' tuning_iterations ={self.gp_iters}\n'
172
- f' tuning_timeout ={self.timeout_tuning}\n'
173
- f' tuning_batch_size ={self.num_workers}\n'
174
- f' mp_pool_context_type ={self.pool_context}\n'
175
- f' mp_pool_poll_frequency ={self.poll_frequency}\n'
176
- f'meta-objective parameters:\n'
177
- f' planning_trials_per_iter ={self.eval_trials}\n'
178
- f' rollouts_per_trial ={self.rollouts_per_trial}\n'
179
- f' acquisition_fn ={self.acquisition}')
167
+ return (f'hyperparameter optimizer parameters:\n'
168
+ f' tuned_hyper_parameters =\n{hyper_params_table}\n'
169
+ f' initialization_args ={self.gp_init_kwargs}\n'
170
+ f' gp_params ={self.gp_params}\n'
171
+ f' tuning_iterations ={self.gp_iters}\n'
172
+ f' tuning_timeout ={self.timeout_tuning}\n'
173
+ f' tuning_batch_size ={self.num_workers}\n'
174
+ f' mp_pool_context_type ={self.pool_context}\n'
175
+ f' mp_pool_poll_frequency ={self.poll_frequency}\n'
176
+ f'meta-objective parameters:\n'
177
+ f' planning_trials_per_iter ={self.eval_trials}\n'
178
+ f' rollouts_per_trial ={self.rollouts_per_trial}\n'
179
+ f' acquisition_fn ={self.acquisition}')
180
180
 
181
181
  @staticmethod
182
182
  def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
@@ -346,6 +346,7 @@ class JaxParameterTuning:
346
346
 
347
347
  # remove keywords that should not be in the tuner
348
348
  train_args.pop('dashboard', None)
349
+ planner_args.pop('parallel_updates', None)
349
350
 
350
351
  # initialize env for evaluation (need fresh copy to avoid concurrency)
351
352
  env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
@@ -368,12 +369,12 @@ class JaxParameterTuning:
368
369
 
369
370
  def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
370
371
  '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
371
- print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
372
+ print(f'Kernel: {repr(optimizer._gp.kernel_)}.')
372
373
 
373
374
  def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
374
375
  '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
375
376
 
376
- self.summarize_hyperparameters()
377
+ print(self.summarize_hyperparameters())
377
378
 
378
379
  # clear and prepare output file
379
380
  with open(log_file, 'w', newline='') as file:
@@ -445,13 +446,15 @@ class JaxParameterTuning:
445
446
  # check if there is enough time left for another iteration
446
447
  elapsed = time.time() - start_time
447
448
  if elapsed >= self.timeout_tuning:
448
- print(f'global time limit reached at iteration {it}, aborting')
449
+ message = termcolor.colored(
450
+ f'[INFO] Global time limit reached at iteration {it}.', 'green')
451
+ print(message)
449
452
  break
450
453
 
451
454
  # continue with next iteration
452
455
  print('\n' + '*' * 80 +
453
456
  f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
454
- f'starting iteration {it + 1}' +
457
+ f'Starting iteration {it + 1}' +
455
458
  '\n' + '*' * 80)
456
459
  key, *subkeys = jax.random.split(key, num=num_workers + 1)
457
460
  rows = [None] * num_workers
@@ -507,7 +510,10 @@ class JaxParameterTuning:
507
510
 
508
511
  # print best parameter if found
509
512
  if best_target > old_best_target:
510
- print(f'* found new best average reward {best_target:.6f}')
513
+ message = termcolor.colored(
514
+ f'[INFO] Found new best average reward {best_target:.6f}.',
515
+ 'green')
516
+ print(message)
511
517
 
512
518
  # tune the optimizer here
513
519
  self.tune_optimizer(optimizer)
@@ -528,7 +534,7 @@ class JaxParameterTuning:
528
534
 
529
535
  # print summary of results
530
536
  elapsed = time.time() - start_time
531
- print(f'summary of hyper-parameter optimization:\n'
537
+ print(f'Summary of hyper-parameter optimization:\n'
532
538
  f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
533
539
  f' iterations ={it + 1}\n'
534
540
  f' best_hyper_parameters={best_params}\n'
@@ -36,8 +36,8 @@ def main(domain, instance, method, episodes=1):
36
36
  abs_path = os.path.dirname(os.path.abspath(__file__))
37
37
  config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
38
38
  if not os.path.isfile(config_path):
39
- raise_warning(f'Config file {config_path} was not found, '
40
- f'using default_{method}.cfg.', 'red')
39
+ raise_warning(f'[WARN] Config file {config_path} was not found, '
40
+ f'using default_{method}.cfg.', 'yellow')
41
41
  config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
42
42
  elif os.path.isfile(method):
43
43
  config_path = method
@@ -31,8 +31,8 @@ def main(domain, instance, method, episodes=1):
31
31
  abs_path = os.path.dirname(os.path.abspath(__file__))
32
32
  config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
33
33
  if not os.path.isfile(config_path):
34
- raise_warning(f'Config file {config_path} was not found, '
35
- f'using default_slp.cfg.', 'red')
34
+ raise_warning(f'[WARN] Config file {config_path} was not found, '
35
+ f'using default_slp.cfg.', 'yellow')
36
36
  config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
37
37
  planner_args, _, train_args = load_config(config_path)
38
38
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pyRDDLGym-jax
3
- Version: 2.2
3
+ Version: 2.4
4
4
  Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
5
5
  Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
6
6
  Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
@@ -1,19 +1,19 @@
1
- pyRDDLGym_jax/__init__.py,sha256=lqo7WXKfZGHPIOxgE6EWI5fGZHP2h6XrwVNNVQAUN3Q,19
1
+ pyRDDLGym_jax/__init__.py,sha256=6Bd43-94X_2dH_ErGLQ0_DvlhX5cLWkVPvn31JBzFkY,19
2
2
  pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
3
3
  pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- pyRDDLGym_jax/core/compiler.py,sha256=_ERueJW7GQ7S8-IezreeuLs3fNCZbQZ8j7VMUVlEt1k,82306
5
- pyRDDLGym_jax/core/logic.py,sha256=ZeCwCLqC6BvXpRT06TvE2bfPNO6ALuMzPmUvXNzW6Uw,52278
6
- pyRDDLGym_jax/core/planner.py,sha256=0rluBXKGNHRPEPfegOWcx9__cJHr8KjZdDJtG7i1JjI,122793
4
+ pyRDDLGym_jax/core/compiler.py,sha256=NFWfTHtGf7F-t7Qhn6X-VpSAJkTVHm-oRjujFw4O1HA,82605
5
+ pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
6
+ pyRDDLGym_jax/core/planner.py,sha256=wZJiZHV0Qxi9DS3AQ9Rx1doBvsKQXc1HYziY6GXTu_A,136965
7
7
  pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
8
- pyRDDLGym_jax/core/tuning.py,sha256=RKKtDZp7unvfbhZEoaunZtcAn5xtzGYqXBB_Ij_Aapc,24205
8
+ pyRDDLGym_jax/core/tuning.py,sha256=Gm3YJF84_2vDIIJpOj0tK0-4rlJoEjYwxRt_JpUKAOA,24482
9
9
  pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
10
10
  pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
12
12
  pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
14
14
  pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
15
- pyRDDLGym_jax/examples/run_plan.py,sha256=v2AvwgIa4Ejr626vBOgWFJIQvay3IPKWno02ztIFCYc,2768
16
- pyRDDLGym_jax/examples/run_scipy.py,sha256=wvcpWCvdjvYHntO95a7JYfY2fuCMUTKnqjJikW0PnL4,2291
15
+ pyRDDLGym_jax/examples/run_plan.py,sha256=TVfziHHaEC56wxwRw9llZ5iqSHe3m6yy8HxiR2TyvXE,2778
16
+ pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
17
17
  pyRDDLGym_jax/examples/run_tune.py,sha256=WbGO8RudIK-cPMAMKvI8NbFQAqkG-Blbnta3Efsep6c,3828
18
18
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
19
19
  pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG
41
41
  pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
42
42
  pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
43
43
  pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
44
- pyrddlgym_jax-2.2.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
45
- pyrddlgym_jax-2.2.dist-info/METADATA,sha256=aFNUX6uUZZHS7lPbYBTmMSH6TBiWmXbEgQNxPZNWiRI,17021
46
- pyrddlgym_jax-2.2.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
47
- pyrddlgym_jax-2.2.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
48
- pyrddlgym_jax-2.2.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
49
- pyrddlgym_jax-2.2.dist-info/RECORD,,
44
+ pyrddlgym_jax-2.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
45
+ pyrddlgym_jax-2.4.dist-info/METADATA,sha256=98Nl3EnEk-fRLeoy9orDScaikCT9M8X4zOfYtiS-WXI,17021
46
+ pyrddlgym_jax-2.4.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
47
+ pyrddlgym_jax-2.4.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
48
+ pyrddlgym_jax-2.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
49
+ pyrddlgym_jax-2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (76.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5