pyRDDLGym-jax 2.2-py3-none-any.whl → 2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyRDDLGym_jax/__init__.py +1 -1
- pyRDDLGym_jax/core/compiler.py +16 -11
- pyRDDLGym_jax/core/logic.py +233 -119
- pyRDDLGym_jax/core/planner.py +489 -218
- pyRDDLGym_jax/core/tuning.py +28 -22
- pyRDDLGym_jax/examples/run_plan.py +2 -2
- pyRDDLGym_jax/examples/run_scipy.py +2 -2
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/METADATA +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/RECORD +13 -13
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/WHEEL +1 -1
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/LICENSE +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/entry_points.txt +0 -0
- {pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/top_level.txt +0 -0
pyRDDLGym_jax/core/tuning.py
CHANGED
@@ -18,6 +18,7 @@ import datetime
 import threading
 import multiprocessing
 import os
+import termcolor
 import time
 import traceback
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple
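The new termcolor import supports the colored console messages added later in this file. A minimal sketch of the pattern (the message text is illustrative):

import termcolor

# colored() wraps the string in ANSI escape codes; capable terminals render it green
message = termcolor.colored('[INFO] Found new best average reward 1.234567.', 'green')
print(message)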
@@ -45,8 +46,7 @@ try:
     from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
 except Exception:
     raise_warning('Failed to load the dashboard visualization tool: '
-                  'please make sure you have installed the required packages.',
-                  'red')
+                  'please make sure you have installed the required packages.', 'red')
     traceback.print_exc()
     JaxPlannerDashboard = None

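The hunk above only reflows the warning, but the surrounding code is the usual optional-dependency guard: attempt the import, and on failure disable the feature by binding the name to None. A standalone sketch, with print standing in for the package's raise_warning helper:

import traceback

try:
    from pyRDDLGym_jax.core.visualization import JaxPlannerDashboard
except Exception:
    print('Failed to load the dashboard visualization tool: '
          'please make sure you have installed the required packages.')
    traceback.print_exc()
    JaxPlannerDashboard = None  # callers must check for None before using it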
@@ -159,24 +159,24 @@ class JaxParameterTuning:
         kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
         return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3

-    def summarize_hyperparameters(self) -> …
+    def summarize_hyperparameters(self) -> str:
         hyper_params_table = []
         for (_, param) in self.hyperparams_dict.items():
             hyper_params_table.append(f' {str(param)}')
         hyper_params_table = '\n'.join(hyper_params_table)
-        … (removed lines 167-179 are truncated in the source view)
+        return (f'hyperparameter optimizer parameters:\n'
+                f' tuned_hyper_parameters =\n{hyper_params_table}\n'
+                f' initialization_args ={self.gp_init_kwargs}\n'
+                f' gp_params ={self.gp_params}\n'
+                f' tuning_iterations ={self.gp_iters}\n'
+                f' tuning_timeout ={self.timeout_tuning}\n'
+                f' tuning_batch_size ={self.num_workers}\n'
+                f' mp_pool_context_type ={self.pool_context}\n'
+                f' mp_pool_poll_frequency ={self.poll_frequency}\n'
+                f'meta-objective parameters:\n'
+                f' planning_trials_per_iter ={self.eval_trials}\n'
+                f' rollouts_per_trial ={self.rollouts_per_trial}\n'
+                f' acquisition_fn ={self.acquisition}')

     @staticmethod
     def annealing_acquisition(n_samples: int, n_delay_samples: int=0,
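The method now builds and returns the summary string instead of printing it, so callers can print, log, or assert on it. A self-contained sketch of the new call pattern; Tuner is a hypothetical stand-in for JaxParameterTuning with only two of its fields:

class Tuner:
    def __init__(self) -> None:
        self.gp_iters = 20    # illustrative values
        self.num_workers = 4

    def summarize_hyperparameters(self) -> str:
        # return the report rather than printing it: the caller owns the output channel
        return (f'hyperparameter optimizer parameters:\n'
                f' tuning_iterations ={self.gp_iters}\n'
                f' tuning_batch_size ={self.num_workers}')

print(Tuner().summarize_hyperparameters())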
@@ -346,6 +346,7 @@ class JaxParameterTuning:

         # remove keywords that should not be in the tuner
         train_args.pop('dashboard', None)
+        planner_args.pop('parallel_updates', None)

         # initialize env for evaluation (need fresh copy to avoid concurrency)
         env = RDDLEnv(domain, instance, vectorized=True, enforce_action_constraints=False)
@@ -368,12 +369,12 @@ class JaxParameterTuning:

     def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
         '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
-        print(…
+        print(f'Kernel: {repr(optimizer._gp.kernel_)}.')

     def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''

-        self.summarize_hyperparameters()
+        print(self.summarize_hyperparameters())

         # clear and prepare output file
         with open(log_file, 'w', newline='') as file:
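The new diagnostic prints the fitted Gaussian-process kernel: in bayes_opt, optimizer._gp is a scikit-learn GaussianProcessRegressor, and its kernel_ attribute holds the kernel with optimized hyper-parameters once fit() has run. A sketch on synthetic data:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

rng = np.random.default_rng(0)
gp = GaussianProcessRegressor(kernel=Matern(nu=2.5))
gp.fit(rng.random((10, 2)), rng.random(10))  # fit() sets gp.kernel_
print(f'Kernel: {repr(gp.kernel_)}.')        # e.g. Matern(length_scale=..., nu=2.5)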
@@ -445,13 +446,15 @@ class JaxParameterTuning:
             # check if there is enough time left for another iteration
             elapsed = time.time() - start_time
             if elapsed >= self.timeout_tuning:
-                …
+                message = termcolor.colored(
+                    f'[INFO] Global time limit reached at iteration {it}.', 'green')
+                print(message)
                 break

             # continue with next iteration
             print('\n' + '*' * 80 +
                   f'\n[{datetime.timedelta(seconds=elapsed)}] ' +
-                  f'…
+                  f'Starting iteration {it + 1}' +
                   '\n' + '*' * 80)
             key, *subkeys = jax.random.split(key, num=num_workers + 1)
             rows = [None] * num_workers
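The last two context lines show the standard JAX PRNG discipline in the tuning loop: the running key is split each iteration into a replacement key plus one fresh subkey per parallel worker. A minimal sketch:

import jax

num_workers = 4
key = jax.random.PRNGKey(42)
for it in range(3):
    # first output becomes the new running key; the rest seed this iteration's workers
    key, *subkeys = jax.random.split(key, num=num_workers + 1)
    assert len(subkeys) == num_workers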
@@ -507,7 +510,10 @@ class JaxParameterTuning:

             # print best parameter if found
             if best_target > old_best_target:
-                …
+                message = termcolor.colored(
+                    f'[INFO] Found new best average reward {best_target:.6f}.',
+                    'green')
+                print(message)

             # tune the optimizer here
             self.tune_optimizer(optimizer)
@@ -528,7 +534,7 @@ class JaxParameterTuning:

         # print summary of results
         elapsed = time.time() - start_time
-        print(f'…
+        print(f'Summary of hyper-parameter optimization:\n'
               f' time_elapsed ={datetime.timedelta(seconds=elapsed)}\n'
               f' iterations ={it + 1}\n'
               f' best_hyper_parameters={best_params}\n'
pyRDDLGym_jax/examples/run_plan.py
CHANGED
@@ -36,8 +36,8 @@ def main(domain, instance, method, episodes=1):
     abs_path = os.path.dirname(os.path.abspath(__file__))
     config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
     if not os.path.isfile(config_path):
-        raise_warning(f'Config file {config_path} was not found, '
-                      f'using default_{method}.cfg.', '…
+        raise_warning(f'[WARN] Config file {config_path} was not found, '
+                      f'using default_{method}.cfg.', 'yellow')
         config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
     elif os.path.isfile(method):
         config_path = method
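Aside from the retagged warning, the surrounding logic is a simple config-resolution chain: prefer the domain-specific .cfg, else warn and fall back to the method's default. A sketch with print standing in for the package's raise_warning helper; resolve_config is a hypothetical name, not part of the package:

import os

def resolve_config(abs_path: str, domain: str, method: str) -> str:
    config_path = os.path.join(abs_path, 'configs', f'{domain}_{method}.cfg')
    if not os.path.isfile(config_path):
        print(f'[WARN] Config file {config_path} was not found, '
              f'using default_{method}.cfg.')
        config_path = os.path.join(abs_path, 'configs', f'default_{method}.cfg')
    return config_path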
pyRDDLGym_jax/examples/run_scipy.py
CHANGED
@@ -31,8 +31,8 @@ def main(domain, instance, method, episodes=1):
     abs_path = os.path.dirname(os.path.abspath(__file__))
     config_path = os.path.join(abs_path, 'configs', f'{domain}_slp.cfg')
     if not os.path.isfile(config_path):
-        raise_warning(f'Config file {config_path} was not found, '
-                      f'using default_slp.cfg.', '…
+        raise_warning(f'[WARN] Config file {config_path} was not found, '
+                      f'using default_slp.cfg.', 'yellow')
         config_path = os.path.join(abs_path, 'configs', 'default_slp.cfg')
     planner_args, _, train_args = load_config(config_path)

{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: pyRDDLGym-jax
-Version: 2.2
+Version: 2.4
 Summary: pyRDDLGym-jax: automatic differentiation for solving sequential planning problems in JAX.
 Home-page: https://github.com/pyrddlgym-project/pyRDDLGym-jax
 Author: Michael Gimelfarb, Ayal Taitler, Scott Sanner
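The only METADATA change is the version bump. For reference, the installed wheel's version can be read back from this same metadata at runtime:

from importlib.metadata import version

# reads the Version: field from the installed dist-info METADATA; '2.4' for this wheel
print(version('pyRDDLGym-jax'))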
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/RECORD
CHANGED
@@ -1,19 +1,19 @@
-pyRDDLGym_jax/__init__.py,sha256=…
+pyRDDLGym_jax/__init__.py,sha256=6Bd43-94X_2dH_ErGLQ0_DvlhX5cLWkVPvn31JBzFkY,19
 pyRDDLGym_jax/entry_point.py,sha256=dxDlO_5gneEEViwkLCg30Z-KVzUgdRXaKuFjoZklkA0,974
 pyRDDLGym_jax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pyRDDLGym_jax/core/compiler.py,sha256=…
-pyRDDLGym_jax/core/logic.py,sha256=…
-pyRDDLGym_jax/core/planner.py,sha256=…
+pyRDDLGym_jax/core/compiler.py,sha256=NFWfTHtGf7F-t7Qhn6X-VpSAJkTVHm-oRjujFw4O1HA,82605
+pyRDDLGym_jax/core/logic.py,sha256=lfc2ak_ap_ajMEFlB5EHCRNgJym31dNyA-5d-7N4CZA,56271
+pyRDDLGym_jax/core/planner.py,sha256=wZJiZHV0Qxi9DS3AQ9Rx1doBvsKQXc1HYziY6GXTu_A,136965
 pyRDDLGym_jax/core/simulator.py,sha256=DnPL93WVCMZqtqMUoiJdfWcH9pEvNgGfDfO4NV0wIS0,9271
-pyRDDLGym_jax/core/tuning.py,sha256=…
+pyRDDLGym_jax/core/tuning.py,sha256=Gm3YJF84_2vDIIJpOj0tK0-4rlJoEjYwxRt_JpUKAOA,24482
 pyRDDLGym_jax/core/visualization.py,sha256=4BghMp8N7qtF0tdyDSqtxAxNfP9HPrQWTiXzAMJmx7o,70365
 pyRDDLGym_jax/core/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/core/assets/favicon.ico,sha256=RMMrI9YvmF81TgYG7FO7UAre6WmYFkV3B2GmbA1l0kM,175085
 pyRDDLGym_jax/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pyRDDLGym_jax/examples/run_gradient.py,sha256=KhXvijRDZ4V7N8NOI2WV8ePGpPna5_vnET61YwS7Tco,2919
 pyRDDLGym_jax/examples/run_gym.py,sha256=rXvNWkxe4jHllvbvU_EOMji_2-2k5d4tbBKhpMm_Gaw,1526
-pyRDDLGym_jax/examples/run_plan.py,sha256=…
-pyRDDLGym_jax/examples/run_scipy.py,sha256=…
+pyRDDLGym_jax/examples/run_plan.py,sha256=TVfziHHaEC56wxwRw9llZ5iqSHe3m6yy8HxiR2TyvXE,2778
+pyRDDLGym_jax/examples/run_scipy.py,sha256=7uVnDXb7D3NTJqA2L8nrcYDJP-k0ba9dl9YqA2CD9ac,2301
 pyRDDLGym_jax/examples/run_tune.py,sha256=WbGO8RudIK-cPMAMKvI8NbFQAqkG-Blbnta3Efsep6c,3828
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_drp.cfg,sha256=mE8MqhOlkHeXIGEVrnR3QY6I-_iy4uxFYRA71P1bmtk,347
 pyRDDLGym_jax/examples/configs/Cartpole_Continuous_gym_replan.cfg,sha256=nFFYHCKQUMn8x-OpJwu2pwe1tycNSJ8iAIwSkCBn33E,370
@@ -41,9 +41,9 @@ pyRDDLGym_jax/examples/configs/default_slp.cfg,sha256=mJo0woDevhQCSQfJg30ULVy9qG…
 pyRDDLGym_jax/examples/configs/tuning_drp.cfg,sha256=CQMpSCKTkGioO7U82mHMsYWFRsutULx0V6Wrl3YzV2U,504
 pyRDDLGym_jax/examples/configs/tuning_replan.cfg,sha256=m_0nozFg_GVld0tGv92Xao_KONFJDq_vtiJKt5isqI8,501
 pyRDDLGym_jax/examples/configs/tuning_slp.cfg,sha256=KHu8II6CA-h_HblwvWHylNRjSvvGS3VHxN7JQNR4p_Q,464
-pyrddlgym_jax-2.2.dist-info/LICENSE,sha256=…
-pyrddlgym_jax-2.2.dist-info/METADATA,sha256=…
-pyrddlgym_jax-2.2.dist-info/WHEEL,sha256=…
-pyrddlgym_jax-2.2.dist-info/entry_points.txt,sha256=…
-pyrddlgym_jax-2.2.dist-info/top_level.txt,sha256=…
-pyrddlgym_jax-2.2.dist-info/RECORD,,
+pyrddlgym_jax-2.4.dist-info/LICENSE,sha256=Y0Gi6H6mLOKN-oIKGZulQkoTJyPZeAaeuZu7FXH-meg,1095
+pyrddlgym_jax-2.4.dist-info/METADATA,sha256=98Nl3EnEk-fRLeoy9orDScaikCT9M8X4zOfYtiS-WXI,17021
+pyrddlgym_jax-2.4.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+pyrddlgym_jax-2.4.dist-info/entry_points.txt,sha256=Q--z9QzqDBz1xjswPZ87PU-pib-WPXx44hUWAFoBGBA,59
+pyrddlgym_jax-2.4.dist-info/top_level.txt,sha256=n_oWkP_BoZK0VofvPKKmBZ3NPk86WFNvLhi1BktCbVQ,14
+pyrddlgym_jax-2.4.dist-info/RECORD,,
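Each RECORD line is a path,hash,size triple. Per the wheel spec (PEP 427, building on PEP 376), the hash is the urlsafe base64 encoding of the file's SHA-256 digest with trailing '=' padding stripped, which is why the values above look base64-like rather than hex. A sketch of recomputing one entry's hash:

import base64
import hashlib

def record_hash(path: str) -> str:
    # SHA-256 digest -> urlsafe base64 -> strip '=' padding, as in RECORD entries
    digest = hashlib.sha256(open(path, 'rb').read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')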
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/LICENSE
File without changes
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/entry_points.txt
File without changes
{pyrddlgym_jax-2.2.dist-info → pyrddlgym_jax-2.4.dist-info}/top_level.txt
File without changes