omniopt2 8366__py3-none-any.whl → 9061__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omniopt2 might be problematic. Click here for more details.
- .gitignore +2 -0
- .helpers.py +0 -9
- .omniopt.py +1369 -1011
- .pareto.py +134 -0
- .shellscript_functions +2 -2
- .tests/pylint.rc +0 -4
- .tpe.py +4 -3
- omniopt +56 -27
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.helpers.py +0 -9
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt.py +1369 -1011
- omniopt2-9061.data/data/bin/.pareto.py +134 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.shellscript_functions +2 -2
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.tpe.py +4 -3
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt +56 -27
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/pylint.rc +0 -4
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/requirements.txt +1 -2
- {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/METADATA +2 -3
- omniopt2-9061.dist-info/RECORD +73 -0
- omniopt2.egg-info/PKG-INFO +2 -3
- omniopt2.egg-info/SOURCES.txt +1 -0
- omniopt2.egg-info/requires.txt +1 -2
- pyproject.toml +1 -1
- requirements.txt +1 -2
- omniopt2-8366.dist-info/RECORD +0 -71
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.colorfunctions.sh +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.general.sh +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_general.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_kde.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_scatter.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_scatter_hex.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_trial_index_result.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_worker.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.random_generator.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/LICENSE +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/apt-dependencies.txt +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_docker +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_evaluate +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_plot +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_share +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/setup.py +0 -0
- {omniopt2-8366.data → omniopt2-9061.data}/data/bin/test_requirements.txt +0 -0
- {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/WHEEL +0 -0
- {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/licenses/LICENSE +0 -0
- {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/top_level.txt +0 -0
|
@@ -27,11 +27,15 @@ import inspect
|
|
|
27
27
|
import tracemalloc
|
|
28
28
|
import resource
|
|
29
29
|
from urllib.parse import urlencode
|
|
30
|
-
|
|
31
30
|
import psutil
|
|
32
31
|
|
|
33
32
|
FORCE_EXIT: bool = False
|
|
34
33
|
|
|
34
|
+
LAST_LOG_TIME: int = 0
|
|
35
|
+
last_msg_progressbar = ""
|
|
36
|
+
last_msg_raw = None
|
|
37
|
+
last_lock_print_debug = threading.Lock()
|
|
38
|
+
|
|
35
39
|
def force_exit(signal_number: Any, frame: Any) -> Any:
|
|
36
40
|
global FORCE_EXIT
|
|
37
41
|
|
|
@@ -61,10 +65,10 @@ _last_count_time = 0
|
|
|
61
65
|
_last_count_result: tuple[int, str] = (0, "")
|
|
62
66
|
|
|
63
67
|
_total_time = 0.0
|
|
64
|
-
_func_times = defaultdict(float)
|
|
65
|
-
_func_mem = defaultdict(float)
|
|
66
|
-
_func_call_paths = defaultdict(Counter)
|
|
67
|
-
_last_mem = defaultdict(float)
|
|
68
|
+
_func_times: defaultdict = defaultdict(float)
|
|
69
|
+
_func_mem: defaultdict = defaultdict(float)
|
|
70
|
+
_func_call_paths: defaultdict = defaultdict(Counter)
|
|
71
|
+
_last_mem: defaultdict = defaultdict(float)
|
|
68
72
|
_leak_threshold_mb = 10.0
|
|
69
73
|
generation_strategy_names: list = []
|
|
70
74
|
default_max_range_difference: int = 1000000
|
|
@@ -84,6 +88,7 @@ log_nr_gen_jobs: list[int] = []
|
|
|
84
88
|
generation_strategy_human_readable: str = ""
|
|
85
89
|
oo_call: str = "./omniopt"
|
|
86
90
|
progress_bar_length: int = 0
|
|
91
|
+
worker_usage_file = 'worker_usage.csv'
|
|
87
92
|
|
|
88
93
|
if os.environ.get("CUSTOM_VIRTUAL_ENV") == "1":
|
|
89
94
|
oo_call = "omniopt"
|
|
@@ -99,7 +104,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
|
|
|
99
104
|
SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
|
|
100
105
|
joined_supported_models: str = ", ".join(SUPPORTED_MODELS)
|
|
101
106
|
|
|
102
|
-
special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
|
|
107
|
+
special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]
|
|
103
108
|
|
|
104
109
|
IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names
|
|
105
110
|
|
|
@@ -153,12 +158,6 @@ try:
|
|
|
153
158
|
message="Ax currently requires a sqlalchemy version below 2.0.*",
|
|
154
159
|
)
|
|
155
160
|
|
|
156
|
-
warnings.filterwarnings(
|
|
157
|
-
"ignore",
|
|
158
|
-
category=RuntimeWarning,
|
|
159
|
-
message="coroutine 'start_logging_daemon' was never awaited"
|
|
160
|
-
)
|
|
161
|
-
|
|
162
161
|
with spinner("Importing argparse..."):
|
|
163
162
|
import argparse
|
|
164
163
|
|
|
@@ -204,6 +203,9 @@ try:
|
|
|
204
203
|
with spinner("Importing rich.pretty..."):
|
|
205
204
|
from rich.pretty import pprint
|
|
206
205
|
|
|
206
|
+
with spinner("Importing pformat..."):
|
|
207
|
+
from pprint import pformat
|
|
208
|
+
|
|
207
209
|
with spinner("Importing rich.prompt..."):
|
|
208
210
|
from rich.prompt import Prompt, FloatPrompt, IntPrompt
|
|
209
211
|
|
|
@@ -237,9 +239,6 @@ try:
|
|
|
237
239
|
with spinner("Importing uuid..."):
|
|
238
240
|
import uuid
|
|
239
241
|
|
|
240
|
-
#with spinner("Importing qrcode..."):
|
|
241
|
-
# import qrcode
|
|
242
|
-
|
|
243
242
|
with spinner("Importing cowsay..."):
|
|
244
243
|
import cowsay
|
|
245
244
|
|
|
@@ -333,7 +332,7 @@ def show_func_name_wrapper(func: F) -> F:
|
|
|
333
332
|
|
|
334
333
|
return result
|
|
335
334
|
|
|
336
|
-
return wrapper
|
|
335
|
+
return wrapper # type: ignore
|
|
337
336
|
|
|
338
337
|
def log_time_and_memory_wrapper(func: F) -> F:
|
|
339
338
|
@functools.wraps(func)
|
|
@@ -360,7 +359,7 @@ def log_time_and_memory_wrapper(func: F) -> F:
|
|
|
360
359
|
|
|
361
360
|
return result
|
|
362
361
|
|
|
363
|
-
return wrapper
|
|
362
|
+
return wrapper # type: ignore
|
|
364
363
|
|
|
365
364
|
def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: float, mem_peak: float) -> None:
|
|
366
365
|
global _total_time
|
|
@@ -381,16 +380,9 @@ def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: fl
|
|
|
381
380
|
call_path_str = " -> ".join(short_stack)
|
|
382
381
|
_func_call_paths[func_name][call_path_str] += 1
|
|
383
382
|
|
|
384
|
-
print(
|
|
385
|
-
|
|
386
|
-
f"(total {percent_if_added:.1f}% of tracked time)"
|
|
387
|
-
)
|
|
388
|
-
print(
|
|
389
|
-
f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, "
|
|
390
|
-
f"diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB"
|
|
391
|
-
)
|
|
383
|
+
print(f"Function '{func_name}' took {elapsed:.4f}s (total {percent_if_added:.1f}% of tracked time)")
|
|
384
|
+
print(f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB")
|
|
392
385
|
|
|
393
|
-
# NEU: Runtime Stats
|
|
394
386
|
runtime_stats = collect_runtime_stats()
|
|
395
387
|
print("=== Runtime Stats ===")
|
|
396
388
|
print(f"RSS: {runtime_stats['rss_MB']:.2f} MB, VMS: {runtime_stats['vms_MB']:.2f} MB")
|
|
@@ -449,7 +441,7 @@ RESET: str = "\033[0m"
|
|
|
449
441
|
|
|
450
442
|
uuid_regex: Pattern = re.compile(r"^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[89aAbB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$")
|
|
451
443
|
|
|
452
|
-
worker_generator_uuid: str = uuid.uuid4()
|
|
444
|
+
worker_generator_uuid: str = str(uuid.uuid4())
|
|
453
445
|
|
|
454
446
|
new_uuid: str = str(uuid.uuid4())
|
|
455
447
|
run_uuid: str = os.getenv("RUN_UUID", new_uuid)
|
|
@@ -478,7 +470,7 @@ def get_current_run_folder(name: Optional[str] = None) -> str:
|
|
|
478
470
|
|
|
479
471
|
return CURRENT_RUN_FOLDER
|
|
480
472
|
|
|
481
|
-
def get_state_file_name(name) -> str:
|
|
473
|
+
def get_state_file_name(name: str) -> str:
|
|
482
474
|
state_files_folder = f"{get_current_run_folder()}/state_files/"
|
|
483
475
|
makedirs(state_files_folder)
|
|
484
476
|
|
|
@@ -502,6 +494,24 @@ try:
|
|
|
502
494
|
dier: FunctionType = helpers.dier
|
|
503
495
|
is_equal: FunctionType = helpers.is_equal
|
|
504
496
|
is_not_equal: FunctionType = helpers.is_not_equal
|
|
497
|
+
with spinner("Importing pareto..."):
|
|
498
|
+
pareto_file: str = f"{script_dir}/.pareto.py"
|
|
499
|
+
spec = importlib.util.spec_from_file_location(
|
|
500
|
+
name="pareto",
|
|
501
|
+
location=pareto_file,
|
|
502
|
+
)
|
|
503
|
+
if spec is not None and spec.loader is not None:
|
|
504
|
+
pareto = importlib.util.module_from_spec(spec)
|
|
505
|
+
spec.loader.exec_module(pareto)
|
|
506
|
+
else:
|
|
507
|
+
raise ImportError(f"Could not load module from {pareto_file}")
|
|
508
|
+
|
|
509
|
+
pareto_front_table_filter_rows: FunctionType = pareto.pareto_front_table_filter_rows
|
|
510
|
+
pareto_front_table_add_headers: FunctionType = pareto.pareto_front_table_add_headers
|
|
511
|
+
pareto_front_table_add_rows: FunctionType = pareto.pareto_front_table_add_rows
|
|
512
|
+
pareto_front_filter_complete_points: FunctionType = pareto.pareto_front_filter_complete_points
|
|
513
|
+
pareto_front_select_pareto_points: FunctionType = pareto.pareto_front_select_pareto_points
|
|
514
|
+
|
|
505
515
|
except KeyboardInterrupt:
|
|
506
516
|
print("You pressed CTRL-c while importing the helpers file")
|
|
507
517
|
sys.exit(0)
|
|
@@ -525,15 +535,13 @@ def is_slurm_job() -> bool:
|
|
|
525
535
|
return True
|
|
526
536
|
return False
|
|
527
537
|
|
|
528
|
-
def _sleep(t: int) ->
|
|
538
|
+
def _sleep(t: Union[float, int]) -> None:
|
|
529
539
|
if args is not None and not args.no_sleep:
|
|
530
540
|
try:
|
|
531
541
|
time.sleep(t)
|
|
532
542
|
except KeyboardInterrupt:
|
|
533
543
|
pass
|
|
534
544
|
|
|
535
|
-
return t
|
|
536
|
-
|
|
537
545
|
LOG_DIR: str = "logs"
|
|
538
546
|
makedirs(LOG_DIR)
|
|
539
547
|
|
|
@@ -546,6 +554,17 @@ logfile_worker_creation_logs: str = f'{log_uuid_dir}_worker_creation_logs'
|
|
|
546
554
|
logfile_trial_index_to_param_logs: str = f'{log_uuid_dir}_trial_index_to_param_logs'
|
|
547
555
|
LOGFILE_DEBUG_GET_NEXT_TRIALS: Union[str, None] = None
|
|
548
556
|
|
|
557
|
+
def error_without_print(text: str) -> None:
|
|
558
|
+
print_debug(text)
|
|
559
|
+
|
|
560
|
+
if get_current_run_folder():
|
|
561
|
+
try:
|
|
562
|
+
with open(get_current_run_folder("oo_errors.txt"), mode="a", encoding="utf-8") as myfile:
|
|
563
|
+
myfile.write(text + "\n\n")
|
|
564
|
+
except (OSError, FileNotFoundError) as e:
|
|
565
|
+
helpers.print_color("red", f"Error: {e}. This may mean that the {get_current_run_folder()} was deleted during the run. Could not write '{text} to {get_current_run_folder()}/oo_errors.txt'")
|
|
566
|
+
sys.exit(99)
|
|
567
|
+
|
|
549
568
|
def print_red(text: str) -> None:
|
|
550
569
|
helpers.print_color("red", text)
|
|
551
570
|
|
|
@@ -579,20 +598,22 @@ def _debug(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) ->
|
|
|
579
598
|
def _get_debug_json(time_str: str, msg: str) -> str:
|
|
580
599
|
function_stack = []
|
|
581
600
|
try:
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
func_name = frame.f_code
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
601
|
+
cf = inspect.currentframe()
|
|
602
|
+
if cf:
|
|
603
|
+
frame = cf.f_back # skip _get_debug_json
|
|
604
|
+
while frame:
|
|
605
|
+
func_name = _function_name_cache.get(frame.f_code)
|
|
606
|
+
if func_name is None:
|
|
607
|
+
func_name = frame.f_code.co_name
|
|
608
|
+
_function_name_cache[frame.f_code] = func_name
|
|
609
|
+
|
|
610
|
+
if func_name not in ("<module>", "print_debug", "wrapper"):
|
|
611
|
+
function_stack.append({
|
|
612
|
+
"function": func_name,
|
|
613
|
+
"line_number": frame.f_lineno
|
|
614
|
+
})
|
|
615
|
+
|
|
616
|
+
frame = frame.f_back
|
|
596
617
|
except (SignalUSR, SignalINT, SignalCONT):
|
|
597
618
|
print_red("\n⚠ You pressed CTRL-C. This is ignored in _get_debug_json.")
|
|
598
619
|
|
|
@@ -694,11 +715,14 @@ def my_exit(_code: int = 0) -> None:
|
|
|
694
715
|
if is_skip_search() and os.getenv("SKIP_SEARCH_EXIT_CODE"):
|
|
695
716
|
skip_search_exit_code = os.getenv("SKIP_SEARCH_EXIT_CODE")
|
|
696
717
|
|
|
718
|
+
skip_search_exit_code_found = None
|
|
719
|
+
|
|
697
720
|
try:
|
|
698
|
-
|
|
721
|
+
if skip_search_exit_code_found is not None:
|
|
722
|
+
skip_search_exit_code_found = int(skip_search_exit_code)
|
|
723
|
+
sys.exit(skip_search_exit_code_found)
|
|
699
724
|
except ValueError:
|
|
700
|
-
|
|
701
|
-
sys.exit(_code)
|
|
725
|
+
print_debug(f"Trying to look for SKIP_SEARCH_EXIT_CODE failed. Exiting with original exit code {_code}")
|
|
702
726
|
|
|
703
727
|
sys.exit(_code)
|
|
704
728
|
|
|
@@ -799,6 +823,7 @@ class ConfigLoader:
|
|
|
799
823
|
parameter: Optional[List[str]]
|
|
800
824
|
experiment_constraints: Optional[List[str]]
|
|
801
825
|
main_process_gb: int
|
|
826
|
+
beartype: bool
|
|
802
827
|
worker_timeout: int
|
|
803
828
|
slurm_signal_delay_s: int
|
|
804
829
|
gridsearch: bool
|
|
@@ -819,6 +844,7 @@ class ConfigLoader:
|
|
|
819
844
|
run_program_once: str
|
|
820
845
|
mem_gb: int
|
|
821
846
|
flame_graph: bool
|
|
847
|
+
memray: bool
|
|
822
848
|
continue_previous_job: Optional[str]
|
|
823
849
|
calculate_pareto_front_of_job: Optional[List[str]]
|
|
824
850
|
revert_to_random_when_seemingly_exhausted: bool
|
|
@@ -980,6 +1006,7 @@ class ConfigLoader:
|
|
|
980
1006
|
debug.add_argument('--verbose_break_run_search_table', help='Verbose logging for break_run_search', action='store_true', default=False)
|
|
981
1007
|
debug.add_argument('--debug', help='Enable debugging', action='store_true', default=False)
|
|
982
1008
|
debug.add_argument('--flame_graph', help='Enable flame-graphing. Makes everything slower, but creates a flame graph', action='store_true', default=False)
|
|
1009
|
+
debug.add_argument('--memray', help='Use memray to show memory usage', action='store_true', default=False)
|
|
983
1010
|
debug.add_argument('--no_sleep', help='Disables sleeping for fast job generation (not to be used on HPC)', action='store_true', default=False)
|
|
984
1011
|
debug.add_argument('--tests', help='Run simple internal tests', action='store_true', default=False)
|
|
985
1012
|
debug.add_argument('--show_worker_percentage_table_at_end', help='Show a table of percentage of usage of max worker over time', action='store_true', default=False)
|
|
@@ -993,8 +1020,8 @@ class ConfigLoader:
|
|
|
993
1020
|
debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
|
|
994
1021
|
debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
|
|
995
1022
|
debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
|
|
996
|
-
|
|
997
1023
|
debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
|
|
1024
|
+
debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)
|
|
998
1025
|
|
|
999
1026
|
def load_config(self: Any, config_path: str, file_format: str) -> dict:
|
|
1000
1027
|
if not os.path.isfile(config_path):
|
|
@@ -1232,11 +1259,15 @@ for _rn in args.result_names:
|
|
|
1232
1259
|
_key = _rn
|
|
1233
1260
|
_min_or_max = __default_min_max
|
|
1234
1261
|
|
|
1262
|
+
_min_or_max = re.sub(r"'", "", _min_or_max)
|
|
1263
|
+
|
|
1235
1264
|
if _min_or_max not in ["min", "max"]:
|
|
1236
1265
|
if _min_or_max:
|
|
1237
1266
|
print_yellow(f"Value for determining whether to minimize or maximize was neither 'min' nor 'max' for key '{_key}', but '{_min_or_max}'. It will be set to the default, which is '{__default_min_max}' instead.")
|
|
1238
1267
|
_min_or_max = __default_min_max
|
|
1239
1268
|
|
|
1269
|
+
_key = re.sub(r"'", "", _key)
|
|
1270
|
+
|
|
1240
1271
|
if _key in arg_result_names:
|
|
1241
1272
|
console.print(f"[red]The --result_names option '{_key}' was specified multiple times![/]")
|
|
1242
1273
|
sys.exit(50)
|
|
@@ -1343,19 +1374,6 @@ try:
|
|
|
1343
1374
|
with spinner("Importing GeneratorSpec..."):
|
|
1344
1375
|
from ax.generation_strategy.generator_spec import GeneratorSpec
|
|
1345
1376
|
|
|
1346
|
-
#except Exception:
|
|
1347
|
-
# with spinner("Fallback: Importing ax.generation_strategy.generation_node..."):
|
|
1348
|
-
# import ax.generation_strategy.generation_node
|
|
1349
|
-
|
|
1350
|
-
# with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
|
|
1351
|
-
# from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
|
|
1352
|
-
|
|
1353
|
-
# with spinner("Fallback: Importing ExternalGenerationNode..."):
|
|
1354
|
-
# from ax.generation_strategy.external_generation_node import ExternalGenerationNode
|
|
1355
|
-
|
|
1356
|
-
# with spinner("Fallback: Importing MaxTrials..."):
|
|
1357
|
-
# from ax.generation_strategy.transition_criterion import MaxTrials
|
|
1358
|
-
|
|
1359
1377
|
with spinner("Importing Models from ax.generation_strategy.registry..."):
|
|
1360
1378
|
from ax.adapter.registry import Models
|
|
1361
1379
|
|
|
@@ -1463,6 +1481,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1463
1481
|
def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
|
|
1464
1482
|
search_space = experiment.search_space
|
|
1465
1483
|
parameter_names = list(search_space.parameters.keys())
|
|
1484
|
+
if experiment.optimization_config is None:
|
|
1485
|
+
print_red("Error: update_generator_state is None")
|
|
1486
|
+
return
|
|
1466
1487
|
metric_names = list(experiment.optimization_config.metrics.keys())
|
|
1467
1488
|
|
|
1468
1489
|
completed_trials = [
|
|
@@ -1474,7 +1495,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1474
1495
|
y = np.zeros([num_completed_trials, 1])
|
|
1475
1496
|
|
|
1476
1497
|
for t_idx, trial in enumerate(completed_trials):
|
|
1477
|
-
trial_parameters = trial.
|
|
1498
|
+
trial_parameters = trial.arms[t_idx].parameters
|
|
1478
1499
|
x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
|
|
1479
1500
|
trial_df = data.df[data.df["trial_index"] == trial.index]
|
|
1480
1501
|
y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
|
|
@@ -1614,10 +1635,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1614
1635
|
def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
|
|
1615
1636
|
for name in best_sample.keys():
|
|
1616
1637
|
param = self.parameters.get(name)
|
|
1638
|
+
best_sample_by_name = best_sample[name]
|
|
1639
|
+
|
|
1617
1640
|
if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
|
|
1618
|
-
|
|
1641
|
+
if best_sample_by_name is not None:
|
|
1642
|
+
best_sample[name] = int(round(float(best_sample_by_name)))
|
|
1643
|
+
else:
|
|
1644
|
+
print_debug("best_sample_by_name was empty")
|
|
1619
1645
|
elif isinstance(param, ChoiceParameter):
|
|
1620
|
-
|
|
1646
|
+
if best_sample_by_name is not None:
|
|
1647
|
+
best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
|
|
1648
|
+
else:
|
|
1649
|
+
print_debug("best_sample_by_name was empty")
|
|
1621
1650
|
|
|
1622
1651
|
decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode
|
|
1623
1652
|
|
|
@@ -2048,9 +2077,9 @@ def run_live_share_command(force: bool = False) -> Tuple[str, str]:
|
|
|
2048
2077
|
return str(result.stdout), str(result.stderr)
|
|
2049
2078
|
except subprocess.CalledProcessError as e:
|
|
2050
2079
|
if e.stderr:
|
|
2051
|
-
|
|
2080
|
+
print_debug(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
|
|
2052
2081
|
else:
|
|
2053
|
-
|
|
2082
|
+
print_debug(f"run_live_share_command: command failed with error: {e}")
|
|
2054
2083
|
return "", str(e.stderr)
|
|
2055
2084
|
except Exception as e:
|
|
2056
2085
|
print(f"run_live_share_command: An error occurred: {e}")
|
|
@@ -2064,6 +2093,8 @@ def force_live_share() -> bool:
|
|
|
2064
2093
|
return False
|
|
2065
2094
|
|
|
2066
2095
|
def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
|
|
2096
|
+
log_data()
|
|
2097
|
+
|
|
2067
2098
|
if not get_current_run_folder():
|
|
2068
2099
|
print(f"live_share: get_current_run_folder was empty or false: {get_current_run_folder()}")
|
|
2069
2100
|
return False
|
|
@@ -2077,23 +2108,16 @@ def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
|
|
|
2077
2108
|
if stderr:
|
|
2078
2109
|
print_green(stderr)
|
|
2079
2110
|
else:
|
|
2080
|
-
print_red("This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
|
|
2111
|
+
print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
|
|
2081
2112
|
if stdout:
|
|
2082
2113
|
print_debug(f"live_share stdout: {stdout}")
|
|
2083
2114
|
|
|
2084
2115
|
return True
|
|
2085
2116
|
|
|
2086
2117
|
def init_live_share() -> bool:
|
|
2087
|
-
|
|
2088
|
-
ret = live_share(True, True)
|
|
2089
|
-
|
|
2090
|
-
return ret
|
|
2118
|
+
ret = live_share(True, True)
|
|
2091
2119
|
|
|
2092
|
-
|
|
2093
|
-
if args.live_share and not os.environ.get("CI"):
|
|
2094
|
-
while True:
|
|
2095
|
-
live_share(force=False)
|
|
2096
|
-
time.sleep(30)
|
|
2120
|
+
return ret
|
|
2097
2121
|
|
|
2098
2122
|
def init_storage(db_url: str) -> None:
|
|
2099
2123
|
init_engine_and_session_factory(url=db_url, force_init=True)
|
|
@@ -2119,7 +2143,11 @@ def try_saving_to_db() -> None:
|
|
|
2119
2143
|
else:
|
|
2120
2144
|
print_red("ax_client was not defined in try_saving_to_db")
|
|
2121
2145
|
my_exit(101)
|
|
2122
|
-
|
|
2146
|
+
|
|
2147
|
+
if global_gs is not None:
|
|
2148
|
+
save_generation_strategy(global_gs)
|
|
2149
|
+
else:
|
|
2150
|
+
print_red("Not saving generation strategy: global_gs was empty")
|
|
2123
2151
|
except Exception as e:
|
|
2124
2152
|
print_debug(f"Failed trying to save sqlite3-DB: {e}")
|
|
2125
2153
|
|
|
@@ -2150,6 +2178,8 @@ def merge_with_job_infos(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
2150
2178
|
return merged
|
|
2151
2179
|
|
|
2152
2180
|
def save_results_csv() -> Optional[str]:
|
|
2181
|
+
log_data()
|
|
2182
|
+
|
|
2153
2183
|
if args.dryrun:
|
|
2154
2184
|
return None
|
|
2155
2185
|
|
|
@@ -2164,6 +2194,9 @@ def save_results_csv() -> Optional[str]:
|
|
|
2164
2194
|
|
|
2165
2195
|
try:
|
|
2166
2196
|
df = fetch_and_prepare_trials()
|
|
2197
|
+
if df is None:
|
|
2198
|
+
print_red(f"save_results_csv: fetch_and_prepare_trials returned an empty element: {df}")
|
|
2199
|
+
return None
|
|
2167
2200
|
write_csv(df, pd_csv)
|
|
2168
2201
|
write_json_snapshot(pd_json)
|
|
2169
2202
|
save_experiment_to_file()
|
|
@@ -2176,43 +2209,67 @@ def save_results_csv() -> Optional[str]:
|
|
|
2176
2209
|
except (SignalUSR, SignalCONT, SignalINT) as e:
|
|
2177
2210
|
raise type(e)(str(e)) from e
|
|
2178
2211
|
except Exception as e:
|
|
2179
|
-
print_red(f"
|
|
2212
|
+
print_red(f"\nWhile saving all trials as a pandas-dataframe-csv, an error occurred: {e}")
|
|
2180
2213
|
|
|
2181
2214
|
return pd_csv
|
|
2182
2215
|
|
|
2183
2216
|
def get_results_paths() -> tuple[str, str]:
|
|
2184
2217
|
return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))
|
|
2185
2218
|
|
|
2186
|
-
def
|
|
2187
|
-
ax_client
|
|
2188
|
-
|
|
2219
|
+
def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
|
|
2220
|
+
if not ax_client:
|
|
2221
|
+
my_exit(101)
|
|
2222
|
+
|
|
2223
|
+
return None
|
|
2224
|
+
|
|
2225
|
+
return ax_client.get_trials_data_frame()
|
|
2226
|
+
|
|
2227
|
+
def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
|
|
2228
|
+
if not ax_client:
|
|
2229
|
+
return None
|
|
2189
2230
|
|
|
2190
|
-
|
|
2191
|
-
|
|
2192
|
-
#print(df["generation_node"])
|
|
2231
|
+
ax_client.experiment.fetch_data()
|
|
2232
|
+
df = ax_client_get_trials_data_frame()
|
|
2193
2233
|
df = merge_with_job_infos(df)
|
|
2194
|
-
#print("AFTER merge_with_job_infos:")
|
|
2195
|
-
#print(df["generation_node"])
|
|
2196
2234
|
|
|
2197
2235
|
return df
|
|
2198
2236
|
|
|
2199
|
-
def write_csv(df, path: str) -> None:
|
|
2237
|
+
def write_csv(df: pd.DataFrame, path: str) -> None:
|
|
2200
2238
|
try:
|
|
2201
2239
|
df = df.sort_values(by=["trial_index"], kind="stable").reset_index(drop=True)
|
|
2202
2240
|
except KeyError:
|
|
2203
2241
|
pass
|
|
2204
2242
|
df.to_csv(path, index=False, float_format="%.30f")
|
|
2205
2243
|
|
|
2206
|
-
def
|
|
2244
|
+
def ax_client_to_json_snapshot() -> Optional[dict]:
|
|
2245
|
+
if not ax_client:
|
|
2246
|
+
my_exit(101)
|
|
2247
|
+
|
|
2248
|
+
return None
|
|
2249
|
+
|
|
2207
2250
|
json_snapshot = ax_client.to_json_snapshot()
|
|
2208
|
-
|
|
2209
|
-
|
|
2251
|
+
|
|
2252
|
+
return json_snapshot
|
|
2253
|
+
|
|
2254
|
+
def write_json_snapshot(path: str) -> None:
|
|
2255
|
+
if ax_client is not None:
|
|
2256
|
+
json_snapshot = ax_client_to_json_snapshot()
|
|
2257
|
+
if json_snapshot is not None:
|
|
2258
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
2259
|
+
json.dump(json_snapshot, f, indent=4)
|
|
2260
|
+
else:
|
|
2261
|
+
print_debug('json_snapshot from ax_client_to_json_snapshot was None')
|
|
2262
|
+
else:
|
|
2263
|
+
print_red("write_json_snapshot: ax_client was None")
|
|
2210
2264
|
|
|
2211
2265
|
def save_experiment_to_file() -> None:
|
|
2212
|
-
|
|
2213
|
-
|
|
2214
|
-
|
|
2215
|
-
|
|
2266
|
+
if ax_client is not None:
|
|
2267
|
+
save_experiment(
|
|
2268
|
+
ax_client.experiment,
|
|
2269
|
+
get_state_file_name("ax_client.experiment.json")
|
|
2270
|
+
)
|
|
2271
|
+
else:
|
|
2272
|
+
print_red("save_experiment: ax_client is None")
|
|
2216
2273
|
|
|
2217
2274
|
def should_save_to_database() -> bool:
|
|
2218
2275
|
return args.model not in uncontinuable_models and args.save_to_database
|
|
@@ -2401,7 +2458,7 @@ def set_nr_inserted_jobs(new_nr_inserted_jobs: int) -> None:
|
|
|
2401
2458
|
|
|
2402
2459
|
def write_worker_usage() -> None:
|
|
2403
2460
|
if len(WORKER_PERCENTAGE_USAGE):
|
|
2404
|
-
csv_filename = get_current_run_folder(
|
|
2461
|
+
csv_filename = get_current_run_folder(worker_usage_file)
|
|
2405
2462
|
|
|
2406
2463
|
csv_columns = ['time', 'num_parallel_jobs', 'nr_current_workers', 'percentage']
|
|
2407
2464
|
|
|
@@ -2411,35 +2468,39 @@ def write_worker_usage() -> None:
|
|
|
2411
2468
|
csv_writer.writerow(row)
|
|
2412
2469
|
else:
|
|
2413
2470
|
if is_slurm_job():
|
|
2414
|
-
print_debug("WORKER_PERCENTAGE_USAGE seems to be empty. Not writing
|
|
2471
|
+
print_debug(f"WORKER_PERCENTAGE_USAGE seems to be empty. Not writing {worker_usage_file}")
|
|
2415
2472
|
|
|
2416
2473
|
def log_system_usage() -> None:
|
|
2474
|
+
global LAST_LOG_TIME
|
|
2475
|
+
|
|
2476
|
+
now = time.time()
|
|
2477
|
+
if now - LAST_LOG_TIME < 30:
|
|
2478
|
+
return
|
|
2479
|
+
|
|
2480
|
+
LAST_LOG_TIME = int(now)
|
|
2481
|
+
|
|
2417
2482
|
if not get_current_run_folder():
|
|
2418
2483
|
return
|
|
2419
2484
|
|
|
2420
2485
|
ram_cpu_csv_file_path = os.path.join(get_current_run_folder(), "cpu_ram_usage.csv")
|
|
2421
|
-
|
|
2422
2486
|
makedirs(os.path.dirname(ram_cpu_csv_file_path))
|
|
2423
2487
|
|
|
2424
2488
|
file_exists = os.path.isfile(ram_cpu_csv_file_path)
|
|
2425
2489
|
|
|
2426
|
-
|
|
2427
|
-
|
|
2428
|
-
|
|
2429
|
-
current_time = int(time.time())
|
|
2430
|
-
|
|
2431
|
-
if process is not None:
|
|
2432
|
-
mem_proc = process.memory_info()
|
|
2433
|
-
|
|
2434
|
-
if mem_proc is not None:
|
|
2435
|
-
ram_usage_mb = mem_proc.rss / (1024 * 1024)
|
|
2436
|
-
cpu_usage_percent = psutil.cpu_percent(percpu=False)
|
|
2490
|
+
mem_proc = process.memory_info() if process else None
|
|
2491
|
+
if not mem_proc:
|
|
2492
|
+
return
|
|
2437
2493
|
|
|
2438
|
-
|
|
2439
|
-
|
|
2440
|
-
|
|
2494
|
+
ram_usage_mb = mem_proc.rss / (1024 * 1024)
|
|
2495
|
+
cpu_usage_percent = psutil.cpu_percent(percpu=False)
|
|
2496
|
+
if ram_usage_mb <= 0 or cpu_usage_percent <= 0:
|
|
2497
|
+
return
|
|
2441
2498
|
|
|
2442
|
-
|
|
2499
|
+
with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
|
|
2500
|
+
writer = csv.writer(file)
|
|
2501
|
+
if not file_exists:
|
|
2502
|
+
writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
|
|
2503
|
+
writer.writerow([int(now), ram_usage_mb, cpu_usage_percent])
|
|
2443
2504
|
|
|
2444
2505
|
def write_process_info() -> None:
|
|
2445
2506
|
try:
|
|
@@ -2789,10 +2850,20 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
|
|
|
2789
2850
|
log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
|
|
2790
2851
|
|
|
2791
2852
|
def print_debug_progressbar(msg: str) -> None:
|
|
2792
|
-
|
|
2793
|
-
|
|
2853
|
+
global last_msg_progressbar, last_msg_raw
|
|
2854
|
+
|
|
2855
|
+
try:
|
|
2856
|
+
with last_lock_print_debug:
|
|
2857
|
+
if msg != last_msg_raw:
|
|
2858
|
+
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
2859
|
+
full_msg = f"{time_str} ({worker_generator_uuid}): {msg}"
|
|
2794
2860
|
|
|
2795
|
-
|
|
2861
|
+
_debug_progressbar(full_msg)
|
|
2862
|
+
|
|
2863
|
+
last_msg_raw = msg
|
|
2864
|
+
last_msg_progressbar = full_msg
|
|
2865
|
+
except Exception as e:
|
|
2866
|
+
print(f"Error in print_debug_progressbar: {e}", flush=True)
|
|
2796
2867
|
|
|
2797
2868
|
def get_process_info(pid: Any) -> str:
|
|
2798
2869
|
try:
|
|
@@ -3297,135 +3368,457 @@ def parse_experiment_parameters() -> None:
|
|
|
3297
3368
|
# Remove duplicates by 'name' key preserving order
|
|
3298
3369
|
params = list({p['name']: p for p in params}.values())
|
|
3299
3370
|
|
|
3300
|
-
experiment_parameters = params
|
|
3371
|
+
experiment_parameters = params # type: ignore[assignment]
|
|
3301
3372
|
|
|
3302
|
-
def
|
|
3303
|
-
|
|
3304
|
-
_fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
|
|
3373
|
+
def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
    """Calculate and display the Pareto front of an existing run folder.

    The run folder must contain ``state_files/ax_client.experiment.json``,
    ``state_files/checkpoint.json``, the results CSV and a ``result_names.txt``
    listing at least two result names. On success the module-level
    ``RESULT_CSV_FILE``, ``CURRENT_RUN_FOLDER`` and ``arg_result_names`` are
    redirected to that folder before the Pareto front is shown.

    Args:
        path_to_calculate: Run folder to analyse. An empty value makes the
            function return ``False`` without printing anything.
        disable_sixel_and_table: Forwarded to ``show_pareto_or_error_msg``.

    Returns:
        ``True`` once ``show_pareto_or_error_msg`` has been called, ``False``
        if any precondition failed.
    """
    global CURRENT_RUN_FOLDER
    global RESULT_CSV_FILE
    global arg_result_names

    pf_start_time = time.time()

    # Nothing requested: return silently. (The original repeated this test
    # further down with an error message; that second copy was unreachable.)
    if not path_to_calculate:
        return False

    if not os.path.exists(path_to_calculate):
        print_red(f"Path '{path_to_calculate}' does not exist")
        return False

    ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
    if not os.path.exists(ax_client_json):
        print_red(f"Path '{ax_client_json}' not found")
        return False

    checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
    if not os.path.exists(checkpoint_file):
        print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
        return False

    # Validate with a local name first so that a failed call does not leave
    # the global RESULT_CSV_FILE pointing at a file that does not exist.
    result_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    if not os.path.exists(result_csv_file):
        print_red(f"{result_csv_file} not found")
        return False

    res_names_file = f"{path_to_calculate}/result_names.txt"
    if not os.path.exists(res_names_file):
        print_red(f"File '{res_names_file}' does not exist")
        return False

    try:
        with open(res_names_file, "r", encoding="utf-8") as file:
            lines = file.readlines()
    except Exception as e:
        print_red(f"Error reading file '{res_names_file}': {e}")
        return False

    # One result name per non-blank line.
    res_names = [entry for entry in (line.strip() for line in lines) if entry != ""]

    if len(res_names) < 2:
        print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
        return False

    # All file-level preconditions hold: switch the module state to this run.
    RESULT_CSV_FILE = result_csv_file

    load_username_to_args(path_to_calculate)

    CURRENT_RUN_FOLDER = path_to_calculate

    arg_result_names = res_names

    load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)

    if experiment_parameters is None:
        return False

    show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)

    pf_end_time = time.time()

    print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")

    return True
|
|
3406
3448
|
|
|
3407
|
-
|
|
3408
|
-
|
|
3409
|
-
|
|
3410
|
-
|
|
3449
|
+
def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
    """Show the Pareto-frontier data for a run, or log why it is skipped.

    Nothing is shown with ``--dryrun`` or when fewer than two result names are
    given. Exceptions raised by ``show_pareto_frontier_data`` are caught and
    printed together with their traceback.

    Args:
        path_to_calculate: Run folder passed on to ``show_pareto_frontier_data``.
        res_names: Result names to consider. The default is the object bound
            to ``arg_result_names`` when this function was defined.
        disable_sixel_and_table: Passed on to ``show_pareto_frontier_data``.
    """
    if args.dryrun:
        print_debug("Not showing Pareto-frontier data with --dryrun")
        return None

    if len(res_names) <= 1:
        # Report the length of the list that was actually tested; the original
        # message printed len(arg_result_names), which may be a different list.
        print_debug(f"show_pareto_frontier_data will NOT be executed because len(res_names) is {len(res_names)}")
        return None

    try:
        show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
    except Exception as e:
        inner_tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
        print_red(f"show_pareto_frontier_data() failed with exception '{e}':\n{inner_tb}")

    return None
|
|
3415
3463
|
|
|
3416
|
-
|
|
3417
|
-
|
|
3418
|
-
self.thread.join()
|
|
3464
|
+
def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
    """Compute the Pareto frontier for every pair of result metrics.

    Returns a nested dict ``{metric_x: {metric_y: frontier}}``. A pair whose
    computation raises ``DataRequiredError`` is reported and left out; a
    ``SignalINT`` stops the computation of all remaining pairs.
    """
    pareto_front_data: dict = {}

    # NOTE: the pairs are drawn from the module-level arg_result_names;
    # res_names is only forwarded to get_calculated_frontier.
    for idx_x, idx_y in combinations(range(len(arg_result_names)), 2):
        metric_x = arg_result_names[idx_x]
        metric_y = arg_result_names[idx_y]

        x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
        y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)

        try:
            frontiers_for_x = pareto_front_data.setdefault(metric_x, {})
            frontiers_for_x[metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
        except ax.exceptions.core.DataRequiredError as e:
            print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
        except SignalINT:
            print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
            break

    return pareto_front_data
|
|
3491
|
+
|
|
3492
|
+
def pareto_front_transform_objectives(
    points: List[Tuple[Any, float, float]],
    primary_name: str,
    secondary_name: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the two objective columns of ``points``, negating "max" ones.

    The second and third element of every point are collected into two arrays.
    An objective whose entry in ``arg_result_min_or_max`` is "max" is negated;
    "min" is left as is.

    Raises:
        ValueError: If a name is not in ``arg_result_names`` or its mode is
            neither "min" nor "max".
    """
    primary_pos = arg_result_names.index(primary_name)
    secondary_pos = arg_result_names.index(secondary_name)

    x = np.array([point[1] for point in points])
    y = np.array([point[2] for point in points])

    primary_mode = arg_result_min_or_max[primary_pos]
    if primary_mode not in ("min", "max"):
        raise ValueError(f"Unknown mode for {primary_name}: {primary_mode}")
    if primary_mode == "max":
        x = -x

    secondary_mode = arg_result_min_or_max[secondary_pos]
    if secondary_mode not in ("min", "max"):
        raise ValueError(f"Unknown mode for {secondary_name}: {secondary_mode}")
    if secondary_mode == "max":
        y = -y

    return x, y
|
|
3514
|
+
|
|
3515
|
+
def get_pareto_frontier_points(
    path_to_calculate: str,
    primary_objective: str,
    secondary_objective: str,
    x_minimize: bool,
    y_minimize: bool,
    absolute_metrics: List[str],
    num_points: int
) -> Optional[dict]:
    """Run the Pareto pipeline: aggregate, filter, transform, select, package.

    Returns ``None`` when ``pareto_front_aggregate_data`` yields no records,
    otherwise the structure built by ``pareto_front_build_return_structure``.
    """
    aggregated = pareto_front_aggregate_data(path_to_calculate)
    if aggregated is None:
        return None

    complete_points = pareto_front_filter_complete_points(path_to_calculate, aggregated, primary_objective, secondary_objective)
    xs, ys = pareto_front_transform_objectives(complete_points, primary_objective, secondary_objective)
    frontier = pareto_front_select_pareto_points(xs, ys, x_minimize, y_minimize, complete_points, num_points)

    return pareto_front_build_return_structure(path_to_calculate, frontier, aggregated, absolute_metrics, primary_objective, secondary_objective)
|
|
3535
|
+
|
|
3536
|
+
def pareto_front_table_read_csv() -> List[Dict[str, str]]:
    """Read the file named by ``RESULT_CSV_FILE`` and return its rows as dicts."""
    with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as csv_handle:
        rows = list(csv.DictReader(csv_handle))
    return rows
|
|
3539
|
+
|
|
3540
|
+
def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
    """Build the table for one Pareto front from the results CSV.

    The rows are filtered by ``pareto_front_table_filter_rows(rows, idxs)``.
    If the CSV has no rows, or none survive the filter, the table gets a
    single placeholder column and no data rows.
    """
    table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)

    csv_rows = pareto_front_table_read_csv()
    # Filtering only happens when there is something to filter.
    matching = pareto_front_table_filter_rows(csv_rows, idxs) if csv_rows else None

    if not csv_rows:
        table.add_column("No data found")
    elif not matching:
        table.add_column("No matching entries")
    else:
        param_cols, result_cols = pareto_front_table_get_columns(matching[0])
        pareto_front_table_add_headers(table, param_cols, result_cols)
        pareto_front_table_add_rows(table, matching, param_cols, result_cols)

    return table
|
|
3559
|
+
|
|
3560
|
+
def pareto_front_build_return_structure(
    path_to_calculate: str,
    selected_points: List[Tuple[Any, float, float]],
    records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
    absolute_metrics: List[str],
    primary_name: str,
    secondary_name: str
) -> dict:
    """Package the selected Pareto points together with their parameters.

    For every selected ``(trial_index, arm_name)`` that has a row with the same
    arm name in the results CSV, the trial index, the parameter columns and the
    mean of each absolute metric (NaN when missing) are collected.

    Returns:
        ``{primary_name: {secondary_name: {...}, "absolute_metrics": [...]}}``
        where the inner dict holds "absolute_metrics", "param_dicts", "means"
        and "idxs".
    """
    results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    result_names_file = f"{path_to_calculate}/result_names.txt"

    with open(result_names_file, mode="r", encoding="utf-8") as names_handle:
        result_names = [name for name in (raw.strip() for raw in names_handle) if name]

    # Index the CSV by trial index; a later row with the same index wins.
    with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
        csv_rows = {int(row['trial_index']): row for row in csv.DictReader(csvfile)}

    # Everything that is neither bookkeeping nor a result counts as a parameter.
    ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node', *result_names}

    def _coerce(text: Any) -> Any:
        """Return ``text`` as int if possible, else as float, else unchanged."""
        for caster in (int, float):
            try:
                return caster(text)
            except ValueError:
                continue
        return text

    param_dicts = []
    idxs = []
    means_dict = defaultdict(list)

    for (trial_index, arm_name), _, _ in selected_points:
        row = csv_rows.get(trial_index)
        if not row or row['arm_name'] != arm_name:
            continue

        idxs.append(int(row["trial_index"]))

        param_dicts.append({
            column: _coerce(cell)
            for column, cell in row.items()
            if column not in ignored_columns
        })

        for metric in absolute_metrics:
            means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))

    return {
        primary_name: {
            secondary_name: {
                "absolute_metrics": absolute_metrics,
                "param_dicts": param_dicts,
                "means": dict(means_dict),
                "idxs": idxs
            },
            "absolute_metrics": absolute_metrics
        }
    }
|
|
3624
|
+
|
|
3625
|
+
def pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
    """Collect the numeric result values of every trial in a run folder.

    Returns ``None`` if the results CSV or ``result_names.txt`` is missing.
    Otherwise returns a mapping ``(trial_index, arm_name) -> {"means":
    {metric: value}}``. Values that cannot be parsed as float are skipped.
    """
    results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    result_names_file = f"{path_to_calculate}/result_names.txt"

    if not (os.path.exists(results_csv_file) and os.path.exists(result_names_file)):
        return None

    with open(result_names_file, mode="r", encoding="utf-8") as names_handle:
        result_names = [name for name in (raw.strip() for raw in names_handle) if name]

    records: dict = defaultdict(lambda: {'means': {}})

    with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            key = (int(row['trial_index']), row['arm_name'])

            for metric in result_names:
                if metric not in row:
                    continue
                try:
                    # float() runs before records[key] is touched, so a trial
                    # without any parsable value creates no entry.
                    records[key]['means'][metric] = float(row[metric])
                except ValueError:
                    continue

    return records
|
|
3652
|
+
|
|
3653
|
+
def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
    """Render the Pareto front of two metrics as a scatter plot in the terminal.

    The plot is saved to a temporary PNG and handed to ``print_image_to_cli``.
    Nothing is drawn when ``data`` is ``None`` or ``supports_sixel()`` is false.

    Args:
        data: Nested mapping; ``data[x_metric][y_metric]["means"]`` must hold
            the value lists for both metrics.
        x_metric: Metric plotted on the x axis.
        y_metric: Metric plotted on the y axis.
    """
    if data is None:
        print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
        return

    if not supports_sixel():
        print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
        return

    # Imported lazily so matplotlib is only loaded when a plot is drawn.
    import matplotlib.pyplot as plt

    means = data[x_metric][y_metric]["means"]

    x_values = means[x_metric]
    y_values = means[y_metric]

    fig, _ax = plt.subplots()

    try:
        _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')

        _ax.set_xlabel(x_metric)
        _ax.set_ylabel(y_metric)

        _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')

        _ax.ticklabel_format(style='plain', axis='both', useOffset=False)

        with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
            # Save this figure explicitly rather than pyplot's current figure.
            fig.savefig(tmp_file.name, dpi=300)

            # The file is deleted on leaving the with-block, so print it here.
            print_image_to_cli(tmp_file.name, 1000)
    finally:
        # Close the figure even when plotting, saving or printing failed.
        plt.close(fig)
|
|
3686
|
+
|
|
3687
|
+
def pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Split the columns of a CSV row into parameter and result columns.

    Parameter columns are all columns except the special columns (other than
    ``trial_index``), the result names and those starting with ``OO_Info_``.
    Result columns are the names in ``arg_result_names`` present in the row,
    in ``arg_result_names`` order.
    """
    present = list(first_row)
    hidden = set(special_col_names) - {"trial_index"}

    param_cols = []
    for column in present:
        if column in hidden or column in arg_result_names or column.startswith("OO_Info_"):
            continue
        param_cols.append(column)

    result_cols = [name for name in arg_result_names if name in present]

    return param_cols, result_cols
|
|
3694
|
+
|
|
3695
|
+
def check_factorial_range() -> None:
    """Raise a fatal error (code 181) when ``args.model`` is "FACTORIAL"."""
    if args.model == "FACTORIAL":
        _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
|
|
3698
|
+
|
|
3699
|
+
def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
    """Raise a fatal error (code 181) unless ``value_type`` is an allowed type."""
    if value_type in valid_value_types:
        return

    allowed = ", ".join(valid_value_types)
    _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {allowed}", 181)
|
|
3703
|
+
|
|
3704
|
+
def check_range_params_length(this_args: Union[str, list]) -> None:
    """Raise a fatal error (code 181) unless ``this_args`` has 4, 5 or 6 items."""
    if len(this_args) not in (4, 5, 6):
        _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
|
|
3707
|
+
|
|
3708
|
+
def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
    """Check a pair of range bounds for equality.

    A fatal error is raised when either bound is ``None`` (code 91) or when
    both bounds are zero (code 181). For equal, non-zero bounds a warning is
    printed.
    """
    if upper_bound is None or lower_bound is None:
        _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)

    if upper_bound != lower_bound:
        return

    if lower_bound == 0:
        _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)

    print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")

    if upper_bound is not None:
        # NOTE(review): this only rebinds the local parameter and nothing is
        # returned, so the caller's lower bound is not changed here - confirm
        # whether callers are expected to apply the correction themselves.
        lower_bound = -upper_bound
|
|
3717
|
+
|
|
3718
|
+
def format_value(value: Any, float_format: str = '.80f') -> str:
    """Render ``value`` as a string, writing floats in plain decimal notation.

    Floats are formatted with ``float_format``; if the result contains a
    decimal point, trailing zeros and a trailing point are removed. Every
    other value is returned as ``str(value)``. If formatting fails, the error
    is printed and ``str(value)`` is returned.
    """
    try:
        if not isinstance(value, float):
            return str(value)

        rendered = format(value, float_format)
        if '.' in rendered:
            rendered = rendered.rstrip('0').rstrip('.')
        return rendered
    except Exception as e:
        print_red(f"⚠ Error formatting the number {value}: {e}")
        return str(value)
|
|
3728
|
+
|
|
3729
|
+
def replace_parameters_in_string(
    parameters: dict,
    input_string: str,
    float_format: str = '.20f',
    additional_prefixes: list[str] = [],
    additional_patterns: list[str] = [],
) -> str:
    """Substitute parameter placeholders in ``input_string`` with their values.

    For every parameter ``name`` the tokens ``$name``, ``$(name)``, ``%name``
    and ``%(name)`` (plus any extra prefix/pattern combinations) are replaced
    by ``format_value(value, float_format)``. Afterwards carriage returns and
    newlines are turned into spaces.

    Args:
        parameters: Mapping of parameter name to value.
        input_string: Text containing the placeholders.
        float_format: Format spec used for float values.
        additional_prefixes: Extra prefixes besides ``$`` and ``%``. Only
            read, never mutated, so the shared default list is harmless.
        additional_patterns: Extra ``str.format`` patterns using ``{key}``.
            Only read, never mutated.

    Returns:
        The substituted string, or ``""`` if any error occurred (the error is
        printed).
    """
    try:
        prefixes = ['$', '%'] + additional_prefixes
        # The "{key}" pattern is assembled from pieces exactly as in the
        # original; presumably to keep a linter quiet - TODO confirm.
        patterns = ['{' + 'key' + '}', '(' + '{' + 'key' + '}' + ')'] + additional_patterns

        # Replace longer names first. Otherwise a parameter "x" would eat the
        # start of "$xy" before "xy" is substituted, and the outcome would
        # depend on dict order. sorted() is stable, so names of equal length
        # keep their original order.
        ordered_keys = sorted(parameters, key=lambda name: len(str(name)), reverse=True)

        for key in ordered_keys:
            replacement = format_value(parameters[key], float_format=float_format)
            for prefix in prefixes:
                for pattern in patterns:
                    token = prefix + pattern.format(key=key)
                    input_string = input_string.replace(token, replacement)

        input_string = input_string.replace('\r', ' ').replace('\n', ' ')
        return input_string

    except Exception as e:
        print_red(f"\n⚠ Error: {e}")
        return ""
|
|
3753
|
+
|
|
3754
|
+
def get_memory_usage() -> float:
    """Return the summed resident set size, in MiB, of the current user's processes.

    Processes whose memory info or uids could not be read are skipped.
    """
    user_uid = os.getuid()

    total_rss = 0
    for proc in psutil.process_iter(attrs=['memory_info', 'uids']):
        # Use the values process_iter already fetched. Calling
        # proc.memory_info() again can raise if the process has exited or is
        # not accessible in the meantime.
        uids = proc.info.get('uids')
        mem = proc.info.get('memory_info')

        # psutil stores None for attributes it was not allowed to read.
        if uids is None or mem is None:
            continue

        if uids.real == user_uid:
            total_rss += mem.rss

    return float(total_rss / (1024 * 1024))
|
|
3763
|
+
|
|
3764
|
+
class MonitorProcess:
    """Context manager that logs host CPU and RAM usage while a process runs.

    On entering, a daemon thread is started. While ``running`` is true and the
    watched process is alive, it appends one line per iteration to
    ``eval_nodes_cpu_ram_logs.txt`` in the current run folder. Leaving the
    context clears ``running`` and joins the thread.
    """

    def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
        """Prepare, but do not start, the monitor thread.

        Args:
            pid: PID of the process whose lifetime bounds the monitoring.
            interval: Seconds to sleep between two iterations.
        """
        self.pid = pid
        self.interval = interval
        self.running = True
        self.thread = threading.Thread(target=self._monitor)
        self.thread.daemon = True

        fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")

    def _monitor(self: Any) -> None:
        """Thread body: sample and log until stopped or the process is gone."""
        try:
            _internal_process = psutil.Process(self.pid)
            while self.running and _internal_process.is_running():
                crf = get_current_run_folder()

                # Without a run folder there is nowhere to log to.
                if crf:
                    log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")

                    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

                    hostname = socket.gethostname()

                    slurm_job_id = os.getenv("SLURM_JOB_ID")

                    if slurm_job_id:
                        hostname += f"-SLURM-ID-{slurm_job_id}"

                    total_memory = psutil.virtual_memory().total / (1024 * 1024)
                    # With interval=5 this call blocks for five seconds, so
                    # all sampling is done before the log file is opened.
                    cpu_usage = psutil.cpu_percent(interval=5)

                    memory_usage = get_memory_usage()

                    unix_timestamp = int(time.time())

                    with open(log_file_path, mode="a", encoding="utf-8") as log_file:
                        log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
                time.sleep(self.interval)
        except psutil.NoSuchProcess:
            # The watched process has already exited; nothing left to monitor.
            pass

    def __enter__(self: Any) -> "MonitorProcess":
        """Start the monitor thread and return this instance."""
        self.thread.start()
        return self

    def __exit__(self: Any, exc_type: Any, exc_value: Any, _traceback: Any) -> None:
        """Ask the monitor thread to stop and wait until it has finished."""
        self.running = False
        self.thread.join()
|
|
3812
|
+
|
|
3813
|
+
def execute_bash_code_log_time(code: str) -> list:
|
|
3814
|
+
process_item = subprocess.Popen(code, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
|
3815
|
+
|
|
3816
|
+
with MonitorProcess(process_item.pid):
|
|
3817
|
+
try:
|
|
3818
|
+
stdout, stderr = process_item.communicate()
|
|
3819
|
+
result = subprocess.CompletedProcess(
|
|
3820
|
+
args=code, returncode=process_item.returncode, stdout=stdout, stderr=stderr
|
|
3821
|
+
)
|
|
3429
3822
|
return [result.stdout, result.stderr, result.returncode, None]
|
|
3430
3823
|
except subprocess.CalledProcessError as e:
|
|
3431
3824
|
real_exit_code = e.returncode
|
|
@@ -3550,7 +3943,7 @@ def _add_to_csv_acquire_lock(lockfile: str, dir_path: str) -> bool:
|
|
|
3550
3943
|
time.sleep(wait_time)
|
|
3551
3944
|
max_wait -= wait_time
|
|
3552
3945
|
except Exception as e:
|
|
3553
|
-
|
|
3946
|
+
print_red(f"Lock error: {e}")
|
|
3554
3947
|
return False
|
|
3555
3948
|
return False
|
|
3556
3949
|
|
|
@@ -3588,11 +3981,11 @@ def _add_to_csv_rewrite_file(file_path: str, rows: List[list], existing_heading:
|
|
|
3588
3981
|
writer = csv.writer(tmp_file)
|
|
3589
3982
|
writer.writerow(all_headings)
|
|
3590
3983
|
for row in rows[1:]:
|
|
3591
|
-
tmp_file.writerow([
|
|
3984
|
+
tmp_file.writerow([ # type: ignore[attr-defined]
|
|
3592
3985
|
row[existing_heading.index(h)] if h in existing_heading else ""
|
|
3593
3986
|
for h in all_headings
|
|
3594
3987
|
])
|
|
3595
|
-
tmp_file.writerow([
|
|
3988
|
+
tmp_file.writerow([ # type: ignore[attr-defined]
|
|
3596
3989
|
formatted_data[new_heading.index(h)] if h in new_heading else ""
|
|
3597
3990
|
for h in all_headings
|
|
3598
3991
|
])
|
|
@@ -3623,12 +4016,12 @@ def find_file_paths(_text: str) -> List[str]:
|
|
|
3623
4016
|
def check_file_info(file_path: str) -> str:
|
|
3624
4017
|
if not os.path.exists(file_path):
|
|
3625
4018
|
if not args.tests:
|
|
3626
|
-
|
|
4019
|
+
print_red(f"check_file_info: The file {file_path} does not exist.")
|
|
3627
4020
|
return ""
|
|
3628
4021
|
|
|
3629
4022
|
if not os.access(file_path, os.R_OK):
|
|
3630
4023
|
if not args.tests:
|
|
3631
|
-
|
|
4024
|
+
print_red(f"check_file_info: The file {file_path} is not readable.")
|
|
3632
4025
|
return ""
|
|
3633
4026
|
|
|
3634
4027
|
file_stat = os.stat(file_path)
|
|
@@ -3742,7 +4135,7 @@ def count_defective_nodes(file_path: Union[str, None] = None, entry: Any = None)
|
|
|
3742
4135
|
return sorted(set(entries))
|
|
3743
4136
|
|
|
3744
4137
|
except Exception as e:
|
|
3745
|
-
|
|
4138
|
+
print_red(f"An error has occurred: {e}")
|
|
3746
4139
|
return []
|
|
3747
4140
|
|
|
3748
4141
|
def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]:
|
|
@@ -3753,7 +4146,7 @@ def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]
|
|
|
3753
4146
|
|
|
3754
4147
|
fool_linter(tmp)
|
|
3755
4148
|
except RuntimeError:
|
|
3756
|
-
|
|
4149
|
+
print_red(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
|
|
3757
4150
|
count_defective_nodes(None, socket.gethostname())
|
|
3758
4151
|
return return_in_case_of_error
|
|
3759
4152
|
except Exception:
|
|
@@ -4224,6 +4617,8 @@ def evaluate(parameters_with_trial_index: dict) -> Optional[Union[int, float, Di
|
|
|
4224
4617
|
trial_index = parameters_with_trial_index["trial_idx"]
|
|
4225
4618
|
submit_time = parameters_with_trial_index["submit_time"]
|
|
4226
4619
|
|
|
4620
|
+
print(f'Trial-Index: {trial_index}')
|
|
4621
|
+
|
|
4227
4622
|
queue_time = abs(int(time.time()) - int(submit_time))
|
|
4228
4623
|
|
|
4229
4624
|
start_nvidia_smi_thread()
|
|
@@ -4461,7 +4856,7 @@ def replace_string_with_params(input_string: str, params: list) -> str:
|
|
|
4461
4856
|
return replaced_string
|
|
4462
4857
|
except AssertionError as e:
|
|
4463
4858
|
error_text = f"Error in replace_string_with_params: {e}"
|
|
4464
|
-
|
|
4859
|
+
print_red(error_text)
|
|
4465
4860
|
raise
|
|
4466
4861
|
|
|
4467
4862
|
return ""
|
|
@@ -4738,9 +5133,7 @@ def get_sixel_graphics_data(_pd_csv: str, _force: bool = False) -> list:
|
|
|
4738
5133
|
_params = [_command, plot, _tmp, plot_type, tmp_file, _width]
|
|
4739
5134
|
data.append(_params)
|
|
4740
5135
|
except Exception as e:
|
|
4741
|
-
|
|
4742
|
-
print_red(f"Error trying to print {plot_type} to CLI: {e}, {tb}")
|
|
4743
|
-
print_debug(f"Error trying to print {plot_type} to CLI: {e}")
|
|
5136
|
+
print_red(f"Error trying to print {plot_type} to CLI: {e}")
|
|
4744
5137
|
|
|
4745
5138
|
return data
|
|
4746
5139
|
|
|
@@ -4931,15 +5324,17 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
|
|
|
4931
5324
|
if job:
|
|
4932
5325
|
try:
|
|
4933
5326
|
if ax_client:
|
|
4934
|
-
_trial =
|
|
4935
|
-
_trial
|
|
5327
|
+
_trial = get_ax_client_trial(trial_index)
|
|
5328
|
+
if _trial is None:
|
|
5329
|
+
return False
|
|
5330
|
+
|
|
5331
|
+
mark_abandoned(_trial, reason, trial_index)
|
|
4936
5332
|
print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
|
|
4937
5333
|
global_vars["jobs"].remove((job, trial_index))
|
|
4938
5334
|
else:
|
|
4939
|
-
_fatal_error("ax_client could not be found",
|
|
5335
|
+
_fatal_error("ax_client could not be found", 101)
|
|
4940
5336
|
except Exception as e:
|
|
4941
|
-
|
|
4942
|
-
print_debug(f"ERROR in line {get_line_info()}: {e}")
|
|
5337
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
4943
5338
|
return False
|
|
4944
5339
|
job.cancel()
|
|
4945
5340
|
return True
|
|
@@ -4952,20 +5347,6 @@ def abandon_all_jobs() -> None:
|
|
|
4952
5347
|
if not abandoned:
|
|
4953
5348
|
print_debug(f"Job {job} could not be abandoned.")
|
|
4954
5349
|
|
|
4955
|
-
def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
|
|
4956
|
-
if args.dryrun:
|
|
4957
|
-
print_debug("Not showing Pareto-frontier data with --dryrun")
|
|
4958
|
-
return None
|
|
4959
|
-
|
|
4960
|
-
if len(res_names) > 1:
|
|
4961
|
-
try:
|
|
4962
|
-
show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
|
|
4963
|
-
except Exception as e:
|
|
4964
|
-
print_red(f"show_pareto_frontier_data() failed with exception '{e}'")
|
|
4965
|
-
else:
|
|
4966
|
-
print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
|
|
4967
|
-
return None
|
|
4968
|
-
|
|
4969
5350
|
def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None) -> None:
|
|
4970
5351
|
global END_PROGRAM_RAN
|
|
4971
5352
|
|
|
@@ -5002,7 +5383,7 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5002
5383
|
_exit = new_exit
|
|
5003
5384
|
except (SignalUSR, SignalINT, SignalCONT, KeyboardInterrupt):
|
|
5004
5385
|
print_red("\n⚠ You pressed CTRL+C or a signal was sent. Program execution halted while ending program.")
|
|
5005
|
-
|
|
5386
|
+
print_red("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
|
|
5006
5387
|
new_exit = show_end_table_and_save_end_files()
|
|
5007
5388
|
if new_exit > 0:
|
|
5008
5389
|
_exit = new_exit
|
|
@@ -5023,21 +5404,31 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5023
5404
|
|
|
5024
5405
|
my_exit(_exit)
|
|
5025
5406
|
|
|
5407
|
+
def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
    """Save the global ``ax_client`` to ``checkpoint_filepath`` as JSON.

    If ``ax_client`` is falsy, ``my_exit(101)`` is called instead.
    """
    if ax_client:
        ax_client.save_to_json_file(checkpoint_filepath)
        return None

    my_exit(101)

    return None
|
|
5416
|
+
|
|
5026
5417
|
def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
|
|
5027
5418
|
if trial_nr > 3:
|
|
5028
5419
|
if eee:
|
|
5029
|
-
|
|
5420
|
+
print_red(f"Error during saving checkpoint: {eee}")
|
|
5030
5421
|
else:
|
|
5031
|
-
|
|
5422
|
+
print_red("Error during saving checkpoint")
|
|
5032
5423
|
return
|
|
5033
5424
|
|
|
5034
5425
|
try:
|
|
5035
5426
|
checkpoint_filepath = get_state_file_name('checkpoint.json')
|
|
5036
5427
|
|
|
5037
5428
|
if ax_client:
|
|
5038
|
-
|
|
5429
|
+
save_ax_client_to_json_file(checkpoint_filepath)
|
|
5039
5430
|
else:
|
|
5040
|
-
_fatal_error("Something went wrong using the ax_client",
|
|
5431
|
+
_fatal_error("Something went wrong using the ax_client", 101)
|
|
5041
5432
|
except Exception as e:
|
|
5042
5433
|
save_checkpoint(trial_nr + 1, e)
|
|
5043
5434
|
|
|
@@ -5198,7 +5589,7 @@ def parse_equation_item(comparer_found: bool, item: str, parsed: list, parsed_or
|
|
|
5198
5589
|
})
|
|
5199
5590
|
elif item in [">=", "<="]:
|
|
5200
5591
|
if comparer_found:
|
|
5201
|
-
|
|
5592
|
+
print_red("There is already one comparison operator! Cannot have more than one in an equation!")
|
|
5202
5593
|
return_totally = True
|
|
5203
5594
|
comparer_found = True
|
|
5204
5595
|
|
|
@@ -5409,20 +5800,9 @@ def check_equation(variables: list, equation: str) -> Union[str, bool]:
|
|
|
5409
5800
|
def set_objectives() -> dict:
|
|
5410
5801
|
objectives = {}
|
|
5411
5802
|
|
|
5412
|
-
|
|
5413
|
-
|
|
5414
|
-
|
|
5415
|
-
if "=" in rn:
|
|
5416
|
-
key, value = rn.split('=', 1)
|
|
5417
|
-
else:
|
|
5418
|
-
key = rn
|
|
5419
|
-
value = ""
|
|
5420
|
-
|
|
5421
|
-
if value not in ["min", "max"]:
|
|
5422
|
-
if value:
|
|
5423
|
-
print_yellow(f"Value '{value}' for --result_names {rn} is not a valid value. Must be min or max. Will be set to min.")
|
|
5424
|
-
|
|
5425
|
-
value = "min"
|
|
5803
|
+
k = 0
|
|
5804
|
+
for key in arg_result_names:
|
|
5805
|
+
value = arg_result_min_or_max[k]
|
|
5426
5806
|
|
|
5427
5807
|
_min = True
|
|
5428
5808
|
|
|
@@ -5430,12 +5810,18 @@ def set_objectives() -> dict:
|
|
|
5430
5810
|
_min = False
|
|
5431
5811
|
|
|
5432
5812
|
objectives[key] = ObjectiveProperties(minimize=_min)
|
|
5813
|
+
k = k + 1
|
|
5433
5814
|
|
|
5434
5815
|
return objectives
|
|
5435
5816
|
|
|
5436
|
-
def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Union[dict, list]) -> dict:
|
|
5437
|
-
if
|
|
5817
|
+
def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Optional[Union[dict, list]]) -> dict:
|
|
5818
|
+
if _experiment_parameters is None:
|
|
5819
|
+
print_red("set_experiment_constraints: _experiment_parameters was None")
|
|
5820
|
+
my_exit(95)
|
|
5821
|
+
|
|
5822
|
+
return {}
|
|
5438
5823
|
|
|
5824
|
+
if experiment_constraints and len(experiment_constraints):
|
|
5439
5825
|
experiment_args["parameter_constraints"] = []
|
|
5440
5826
|
|
|
5441
5827
|
if experiment_constraints:
|
|
@@ -5464,11 +5850,15 @@ def set_experiment_constraints(experiment_constraints: Optional[list], experimen
|
|
|
5464
5850
|
|
|
5465
5851
|
return experiment_args
|
|
5466
5852
|
|
|
5467
|
-
def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[list]) -> None:
|
|
5853
|
+
def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[dict | list]) -> None:
|
|
5854
|
+
if not experiment_parameters:
|
|
5855
|
+
print_red("replace_parameters_for_continued_jobs: experiment_parameters was False")
|
|
5856
|
+
return None
|
|
5857
|
+
|
|
5468
5858
|
if args.worker_generator_path:
|
|
5469
5859
|
return None
|
|
5470
5860
|
|
|
5471
|
-
def get_name(obj) -> Optional[str]:
|
|
5861
|
+
def get_name(obj: Any) -> Optional[str]:
|
|
5472
5862
|
"""Extract a parameter name from dict, list, or tuple safely."""
|
|
5473
5863
|
if isinstance(obj, dict):
|
|
5474
5864
|
return obj.get("name")
|
|
@@ -5550,13 +5940,13 @@ def copy_continue_uuid() -> None:
|
|
|
5550
5940
|
print_debug(f"copy_continue_uuid: Source file does not exist: {source_file}")
|
|
5551
5941
|
|
|
5552
5942
|
def load_ax_client_from_experiment_parameters() -> None:
|
|
5553
|
-
|
|
5554
|
-
|
|
5943
|
+
if experiment_parameters:
|
|
5944
|
+
global ax_client
|
|
5555
5945
|
|
|
5556
|
-
|
|
5557
|
-
|
|
5558
|
-
|
|
5559
|
-
|
|
5946
|
+
tmp_file_path = get_tmp_file_from_json(experiment_parameters)
|
|
5947
|
+
ax_client = AxClient.load_from_json_file(tmp_file_path)
|
|
5948
|
+
ax_client = cast(AxClient, ax_client)
|
|
5949
|
+
os.unlink(tmp_file_path)
|
|
5560
5950
|
|
|
5561
5951
|
def save_checkpoint_for_continued() -> None:
|
|
5562
5952
|
checkpoint_filepath = get_state_file_name('checkpoint.json')
|
|
@@ -5568,12 +5958,15 @@ def save_checkpoint_for_continued() -> None:
|
|
|
5568
5958
|
_fatal_error(f"{checkpoint_filepath} not found. Cannot continue_previous_job without.", 47)
|
|
5569
5959
|
|
|
5570
5960
|
def load_original_generation_strategy(original_ax_client_file: str) -> None:
|
|
5571
|
-
|
|
5572
|
-
|
|
5573
|
-
|
|
5961
|
+
if experiment_parameters:
|
|
5962
|
+
with open(original_ax_client_file, encoding="utf-8") as f:
|
|
5963
|
+
loaded_original_ax_client_json = json.load(f)
|
|
5964
|
+
original_generation_strategy = loaded_original_ax_client_json["generation_strategy"]
|
|
5574
5965
|
|
|
5575
|
-
|
|
5576
|
-
|
|
5966
|
+
if original_generation_strategy:
|
|
5967
|
+
experiment_parameters["generation_strategy"] = original_generation_strategy
|
|
5968
|
+
else:
|
|
5969
|
+
print_red("load_original_generation_strategy: experiment_parameters was empty!")
|
|
5577
5970
|
|
|
5578
5971
|
def wait_for_checkpoint_file(checkpoint_file: str) -> None:
|
|
5579
5972
|
start_time = time.time()
|
|
@@ -5586,10 +5979,6 @@ def wait_for_checkpoint_file(checkpoint_file: str) -> None:
|
|
|
5586
5979
|
elapsed = int(time.time() - start_time)
|
|
5587
5980
|
console.print(f"[green]Checkpoint file found after {elapsed} seconds[/green] ")
|
|
5588
5981
|
|
|
5589
|
-
def __get_experiment_parameters__check_ax_client() -> None:
|
|
5590
|
-
if not ax_client:
|
|
5591
|
-
_fatal_error("Something went wrong with the ax_client", 9)
|
|
5592
|
-
|
|
5593
5982
|
def validate_experiment_parameters() -> None:
|
|
5594
5983
|
if experiment_parameters is None:
|
|
5595
5984
|
print_red("Error: experiment_parameters is None.")
|
|
@@ -5599,6 +5988,8 @@ def validate_experiment_parameters() -> None:
|
|
|
5599
5988
|
print_red(f"Error: experiment_parameters is not a dict: {type(experiment_parameters).__name__}")
|
|
5600
5989
|
my_exit(95)
|
|
5601
5990
|
|
|
5991
|
+
sys.exit(95)
|
|
5992
|
+
|
|
5602
5993
|
path_checks = [
|
|
5603
5994
|
("experiment", experiment_parameters),
|
|
5604
5995
|
("search_space", experiment_parameters.get("experiment")),
|
|
@@ -5610,7 +6001,12 @@ def validate_experiment_parameters() -> None:
|
|
|
5610
6001
|
print_red(f"Error: Missing key '{key}' at level: {current_level}")
|
|
5611
6002
|
my_exit(95)
|
|
5612
6003
|
|
|
5613
|
-
def
|
|
6004
|
+
def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_parameters: Optional[dict | list]) -> Tuple[Any, str, str]:
|
|
6005
|
+
if not ax_client:
|
|
6006
|
+
print_red("load_from_checkpoint: ax_client was None")
|
|
6007
|
+
my_exit(101)
|
|
6008
|
+
return {}, "", ""
|
|
6009
|
+
|
|
5614
6010
|
print_debug(f"Load from checkpoint: {continue_previous_job}")
|
|
5615
6011
|
|
|
5616
6012
|
checkpoint_file = f"{continue_previous_job}/state_files/checkpoint.json"
|
|
@@ -5636,7 +6032,7 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
|
|
|
5636
6032
|
|
|
5637
6033
|
replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)
|
|
5638
6034
|
|
|
5639
|
-
|
|
6035
|
+
save_ax_client_to_json_file(original_ax_client_file)
|
|
5640
6036
|
|
|
5641
6037
|
load_original_generation_strategy(original_ax_client_file)
|
|
5642
6038
|
load_ax_client_from_experiment_parameters()
|
|
@@ -5652,6 +6048,12 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
|
|
|
5652
6048
|
|
|
5653
6049
|
experiment_constraints = get_constraints()
|
|
5654
6050
|
if experiment_constraints:
|
|
6051
|
+
|
|
6052
|
+
if not experiment_parameters:
|
|
6053
|
+
print_red("load_from_checkpoint: experiment_parameters was None")
|
|
6054
|
+
|
|
6055
|
+
return {}, "", ""
|
|
6056
|
+
|
|
5655
6057
|
experiment_args = set_experiment_constraints(
|
|
5656
6058
|
experiment_constraints,
|
|
5657
6059
|
experiment_args,
|
|
@@ -5660,7 +6062,116 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
|
|
|
5660
6062
|
|
|
5661
6063
|
return experiment_args, gpu_string, gpu_color
|
|
5662
6064
|
|
|
5663
|
-
def
|
|
6065
|
+
def get_experiment_args_import_python_script() -> str:
|
|
6066
|
+
|
|
6067
|
+
return """from ax.service.ax_client import AxClient, ObjectiveProperties
|
|
6068
|
+
from ax.adapter.registry import Generators
|
|
6069
|
+
import random
|
|
6070
|
+
|
|
6071
|
+
"""
|
|
6072
|
+
|
|
6073
|
+
def get_generate_and_test_random_function_str() -> str:
|
|
6074
|
+
raw_data_entries = ",\n ".join(
|
|
6075
|
+
f'"{name}": random.uniform(0, 1)' for name in arg_result_names
|
|
6076
|
+
)
|
|
6077
|
+
|
|
6078
|
+
return f"""
|
|
6079
|
+
def generate_and_test_random_parameters(n: int) -> None:
|
|
6080
|
+
for _ in range(n):
|
|
6081
|
+
print("======================================")
|
|
6082
|
+
parameters, trial_index = ax_client.get_next_trial()
|
|
6083
|
+
print("Trial Index:", trial_index)
|
|
6084
|
+
print("Suggested parameters:", parameters)
|
|
6085
|
+
|
|
6086
|
+
ax_client.complete_trial(
|
|
6087
|
+
trial_index=trial_index,
|
|
6088
|
+
raw_data={{
|
|
6089
|
+
{raw_data_entries}
|
|
6090
|
+
}}
|
|
6091
|
+
)
|
|
6092
|
+
|
|
6093
|
+
generate_and_test_random_parameters({args.num_random_steps + 1})
|
|
6094
|
+
"""
|
|
6095
|
+
|
|
6096
|
+
def get_global_gs_string() -> str:
|
|
6097
|
+
seed_str = ""
|
|
6098
|
+
if args.seed is not None:
|
|
6099
|
+
seed_str = f"model_kwargs={{'seed': {args.seed}}},"
|
|
6100
|
+
|
|
6101
|
+
return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
|
|
6102
|
+
|
|
6103
|
+
global_gs = GenerationStrategy(
|
|
6104
|
+
steps=[
|
|
6105
|
+
GenerationStep(
|
|
6106
|
+
generator=Generators.SOBOL,
|
|
6107
|
+
num_trials={args.num_random_steps},
|
|
6108
|
+
max_parallelism=5,
|
|
6109
|
+
{seed_str}
|
|
6110
|
+
),
|
|
6111
|
+
GenerationStep(
|
|
6112
|
+
generator=Generators.{args.model},
|
|
6113
|
+
num_trials=-1,
|
|
6114
|
+
max_parallelism=5,
|
|
6115
|
+
),
|
|
6116
|
+
]
|
|
6117
|
+
)
|
|
6118
|
+
"""
|
|
6119
|
+
|
|
6120
|
+
def get_debug_ax_client_str() -> str:
|
|
6121
|
+
return """
|
|
6122
|
+
ax_client = AxClient(
|
|
6123
|
+
verbose_logging=True,
|
|
6124
|
+
enforce_sequential_optimization=False,
|
|
6125
|
+
generation_strategy=global_gs
|
|
6126
|
+
)
|
|
6127
|
+
"""
|
|
6128
|
+
|
|
6129
|
+
def write_ax_debug_python_code(experiment_args: dict) -> None:
|
|
6130
|
+
if args.generation_strategy:
|
|
6131
|
+
print_debug("Cannot write debug code for custom generation_strategy")
|
|
6132
|
+
return None
|
|
6133
|
+
|
|
6134
|
+
if args.model in uncontinuable_models:
|
|
6135
|
+
print_debug(f"Cannot write debug code for uncontinuable mode {args.model}")
|
|
6136
|
+
return None
|
|
6137
|
+
|
|
6138
|
+
python_code = python_code = get_experiment_args_import_python_script() + \
|
|
6139
|
+
get_global_gs_string() + \
|
|
6140
|
+
get_debug_ax_client_str() + \
|
|
6141
|
+
"experiment_args = " + pformat(experiment_args, width=120, compact=False) + \
|
|
6142
|
+
"\nax_client.create_experiment(**experiment_args)\n" + \
|
|
6143
|
+
get_generate_and_test_random_function_str()
|
|
6144
|
+
|
|
6145
|
+
file_path = f"{get_current_run_folder()}/debug.py"
|
|
6146
|
+
|
|
6147
|
+
try:
|
|
6148
|
+
print_debug(python_code)
|
|
6149
|
+
with open(file_path, "w", encoding="utf-8") as f:
|
|
6150
|
+
f.write(python_code)
|
|
6151
|
+
except Exception as e:
|
|
6152
|
+
print_red(f"Error while writing {file_path}: {e}")
|
|
6153
|
+
|
|
6154
|
+
return None
|
|
6155
|
+
|
|
6156
|
+
def create_ax_client_experiment(experiment_args: dict) -> None:
|
|
6157
|
+
if not ax_client:
|
|
6158
|
+
my_exit(101)
|
|
6159
|
+
|
|
6160
|
+
return None
|
|
6161
|
+
|
|
6162
|
+
write_ax_debug_python_code(experiment_args)
|
|
6163
|
+
|
|
6164
|
+
ax_client.create_experiment(**experiment_args)
|
|
6165
|
+
|
|
6166
|
+
return None
|
|
6167
|
+
|
|
6168
|
+
def create_new_experiment() -> Tuple[dict, str, str]:
|
|
6169
|
+
if ax_client is None:
|
|
6170
|
+
print_red("create_new_experiment: ax_client is None")
|
|
6171
|
+
my_exit(101)
|
|
6172
|
+
|
|
6173
|
+
return {}, "", ""
|
|
6174
|
+
|
|
5664
6175
|
objectives = set_objectives()
|
|
5665
6176
|
|
|
5666
6177
|
experiment_args = {
|
|
@@ -5683,7 +6194,7 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
|
|
|
5683
6194
|
experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)
|
|
5684
6195
|
|
|
5685
6196
|
try:
|
|
5686
|
-
|
|
6197
|
+
create_ax_client_experiment(experiment_args)
|
|
5687
6198
|
new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
|
|
5688
6199
|
ax_client.experiment.add_tracking_metrics(new_metrics)
|
|
5689
6200
|
except AssertionError as error:
|
|
@@ -5697,17 +6208,17 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
|
|
|
5697
6208
|
|
|
5698
6209
|
return experiment_args, gpu_string, gpu_color
|
|
5699
6210
|
|
|
5700
|
-
def get_experiment_parameters(cli_params_experiment_parameters: Optional[list]) -> Optional[Tuple[
|
|
6211
|
+
def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
|
|
5701
6212
|
continue_previous_job = args.worker_generator_path or args.continue_previous_job
|
|
5702
6213
|
|
|
5703
|
-
|
|
6214
|
+
check_ax_client()
|
|
5704
6215
|
|
|
5705
6216
|
if continue_previous_job:
|
|
5706
|
-
experiment_args, gpu_string, gpu_color =
|
|
6217
|
+
experiment_args, gpu_string, gpu_color = load_from_checkpoint(continue_previous_job, cli_params_experiment_parameters)
|
|
5707
6218
|
else:
|
|
5708
|
-
experiment_args, gpu_string, gpu_color =
|
|
6219
|
+
experiment_args, gpu_string, gpu_color = create_new_experiment()
|
|
5709
6220
|
|
|
5710
|
-
return
|
|
6221
|
+
return experiment_args, gpu_string, gpu_color
|
|
5711
6222
|
|
|
5712
6223
|
def get_type_short(typename: str) -> str:
|
|
5713
6224
|
if typename == "RangeParameter":
|
|
@@ -5766,7 +6277,6 @@ def parse_single_experiment_parameter_table(classic_params: Optional[Union[list,
|
|
|
5766
6277
|
_upper = param["bounds"][1]
|
|
5767
6278
|
|
|
5768
6279
|
_possible_int_lower = str(helpers.to_int_when_possible(_lower))
|
|
5769
|
-
#print(f"name: {_name}, _possible_int_lower: {_possible_int_lower}, lower: {_lower}")
|
|
5770
6280
|
_possible_int_upper = str(helpers.to_int_when_possible(_upper))
|
|
5771
6281
|
|
|
5772
6282
|
rows.append([_name, _short_type, _possible_int_lower, _possible_int_upper, "", value_type, log_scale])
|
|
@@ -5843,19 +6353,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:
|
|
|
5843
6353
|
|
|
5844
6354
|
return None
|
|
5845
6355
|
|
|
6356
|
+
def check_base_for_print_overview() -> Optional[bool]:
|
|
6357
|
+
if args.continue_previous_job is not None and arg_result_names is not None and len(arg_result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
|
|
6358
|
+
print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
|
|
6359
|
+
|
|
6360
|
+
if ax_client is None:
|
|
6361
|
+
print_red("ax_client was None")
|
|
6362
|
+
return None
|
|
6363
|
+
|
|
6364
|
+
if ax_client.experiment is None:
|
|
6365
|
+
print_red("ax_client.experiment was None")
|
|
6366
|
+
return None
|
|
6367
|
+
|
|
6368
|
+
if ax_client.experiment.optimization_config is None:
|
|
6369
|
+
print_red("ax_client.experiment.optimization_config was None")
|
|
6370
|
+
return None
|
|
6371
|
+
|
|
6372
|
+
return True
|
|
6373
|
+
|
|
6374
|
+
def get_config_objectives() -> Any:
|
|
6375
|
+
if not ax_client:
|
|
6376
|
+
print_red("create_new_experiment: ax_client is None")
|
|
6377
|
+
my_exit(101)
|
|
6378
|
+
|
|
6379
|
+
return None
|
|
6380
|
+
|
|
6381
|
+
config_objectives = None
|
|
6382
|
+
|
|
6383
|
+
if ax_client.experiment and ax_client.experiment.optimization_config:
|
|
6384
|
+
opt_config = ax_client.experiment.optimization_config
|
|
6385
|
+
if opt_config.is_moo_problem:
|
|
6386
|
+
objective = getattr(opt_config, "objective", None)
|
|
6387
|
+
if objective and getattr(objective, "objectives", None) is not None:
|
|
6388
|
+
config_objectives = objective.objectives
|
|
6389
|
+
else:
|
|
6390
|
+
print_debug("ax_client.experiment.optimization_config.objective was None")
|
|
6391
|
+
else:
|
|
6392
|
+
config_objectives = [opt_config.objective]
|
|
6393
|
+
else:
|
|
6394
|
+
print_debug("ax_client.experiment or optimization_config was None")
|
|
6395
|
+
|
|
6396
|
+
return config_objectives
|
|
6397
|
+
|
|
5846
6398
|
def print_result_names_overview_table() -> None:
|
|
5847
6399
|
if not ax_client:
|
|
5848
6400
|
_fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)
|
|
5849
6401
|
|
|
5850
6402
|
return None
|
|
5851
6403
|
|
|
5852
|
-
if
|
|
5853
|
-
|
|
6404
|
+
if check_base_for_print_overview() is None:
|
|
6405
|
+
return None
|
|
5854
6406
|
|
|
5855
|
-
|
|
5856
|
-
|
|
5857
|
-
|
|
5858
|
-
config_objectives
|
|
6407
|
+
config_objectives = get_config_objectives()
|
|
6408
|
+
|
|
6409
|
+
if config_objectives is None:
|
|
6410
|
+
print_red("config_objectives not found")
|
|
6411
|
+
return None
|
|
5859
6412
|
|
|
5860
6413
|
res_names = []
|
|
5861
6414
|
res_min_max = []
|
|
@@ -5950,28 +6503,31 @@ def print_overview_tables(classic_params: Optional[Union[list, dict]], experimen
|
|
|
5950
6503
|
print_result_names_overview_table()
|
|
5951
6504
|
|
|
5952
6505
|
def update_progress_bar(nr: int) -> None:
|
|
5953
|
-
|
|
5954
|
-
|
|
5955
|
-
|
|
5956
|
-
|
|
6506
|
+
log_data()
|
|
6507
|
+
|
|
6508
|
+
if progress_bar is not None:
|
|
6509
|
+
try:
|
|
6510
|
+
progress_bar.update(nr)
|
|
6511
|
+
except Exception as e:
|
|
6512
|
+
print_red(f"Error updating progress bar: {e}")
|
|
6513
|
+
else:
|
|
6514
|
+
print_red("update_progress_bar: progress_bar was None")
|
|
5957
6515
|
|
|
5958
6516
|
def get_current_model_name() -> str:
|
|
5959
6517
|
if overwritten_to_random:
|
|
5960
6518
|
return "Random*"
|
|
5961
6519
|
|
|
6520
|
+
gs_model = "unknown model"
|
|
6521
|
+
|
|
5962
6522
|
if ax_client:
|
|
5963
6523
|
try:
|
|
5964
6524
|
if args.generation_strategy:
|
|
5965
|
-
idx = getattr(
|
|
6525
|
+
idx = getattr(global_gs, "current_step_index", None)
|
|
5966
6526
|
if isinstance(idx, int):
|
|
5967
6527
|
if 0 <= idx < len(generation_strategy_names):
|
|
5968
6528
|
gs_model = generation_strategy_names[int(idx)]
|
|
5969
|
-
else:
|
|
5970
|
-
gs_model = "unknown model"
|
|
5971
|
-
else:
|
|
5972
|
-
gs_model = "unknown model"
|
|
5973
6529
|
else:
|
|
5974
|
-
gs_model = getattr(
|
|
6530
|
+
gs_model = getattr(global_gs, "current_node_name", "unknown model")
|
|
5975
6531
|
|
|
5976
6532
|
if gs_model:
|
|
5977
6533
|
return str(gs_model)
|
|
@@ -6028,7 +6584,7 @@ def get_workers_string() -> str:
|
|
|
6028
6584
|
)
|
|
6029
6585
|
|
|
6030
6586
|
def _get_workers_string_collect_stats() -> dict:
|
|
6031
|
-
stats = {}
|
|
6587
|
+
stats: dict = {}
|
|
6032
6588
|
for job, _ in global_vars["jobs"][:]:
|
|
6033
6589
|
state = state_from_job(job)
|
|
6034
6590
|
stats[state] = stats.get(state, 0) + 1
|
|
@@ -6077,8 +6633,8 @@ def submitted_jobs(nr: int = 0) -> int:
|
|
|
6077
6633
|
def count_jobs_in_squeue() -> tuple[int, str]:
|
|
6078
6634
|
global _last_count_time, _last_count_result
|
|
6079
6635
|
|
|
6080
|
-
now = time.time()
|
|
6081
|
-
if _last_count_result != (0, "") and now - _last_count_time <
|
|
6636
|
+
now = int(time.time())
|
|
6637
|
+
if _last_count_result != (0, "") and now - _last_count_time < 5:
|
|
6082
6638
|
return _last_count_result
|
|
6083
6639
|
|
|
6084
6640
|
_len = len(global_vars["jobs"])
|
|
@@ -6100,6 +6656,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
|
|
|
6100
6656
|
check=True,
|
|
6101
6657
|
text=True
|
|
6102
6658
|
)
|
|
6659
|
+
|
|
6103
6660
|
if "slurm_load_jobs error" in result.stderr:
|
|
6104
6661
|
_last_count_result = (_len, "Detected slurm_load_jobs error in stderr.")
|
|
6105
6662
|
_last_count_time = now
|
|
@@ -6140,6 +6697,8 @@ def log_worker_numbers() -> None:
|
|
|
6140
6697
|
if len(WORKER_PERCENTAGE_USAGE) == 0 or WORKER_PERCENTAGE_USAGE[len(WORKER_PERCENTAGE_USAGE) - 1] != this_values:
|
|
6141
6698
|
WORKER_PERCENTAGE_USAGE.append(this_values)
|
|
6142
6699
|
|
|
6700
|
+
write_worker_usage()
|
|
6701
|
+
|
|
6143
6702
|
def get_slurm_in_brackets(in_brackets: list) -> list:
|
|
6144
6703
|
if is_slurm_job():
|
|
6145
6704
|
workers_strings = get_workers_string()
|
|
@@ -6224,6 +6783,8 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
|
|
|
6224
6783
|
global last_progress_bar_desc
|
|
6225
6784
|
global last_progress_bar_refresh_time
|
|
6226
6785
|
|
|
6786
|
+
log_data()
|
|
6787
|
+
|
|
6227
6788
|
if isinstance(new_msgs, str):
|
|
6228
6789
|
new_msgs = [new_msgs]
|
|
6229
6790
|
|
|
@@ -6241,7 +6802,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
|
|
|
6241
6802
|
print_red("Cannot update progress bar! It is None.")
|
|
6242
6803
|
|
|
6243
6804
|
def clean_completed_jobs() -> None:
|
|
6244
|
-
job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
|
|
6805
|
+
job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
|
|
6245
6806
|
job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]
|
|
6246
6807
|
|
|
6247
6808
|
for job, trial_index in global_vars["jobs"][:]:
|
|
@@ -6299,7 +6860,7 @@ def load_existing_job_data_into_ax_client() -> None:
|
|
|
6299
6860
|
nr_of_imported_jobs = get_nr_of_imported_jobs()
|
|
6300
6861
|
set_nr_inserted_jobs(NR_INSERTED_JOBS + nr_of_imported_jobs)
|
|
6301
6862
|
|
|
6302
|
-
def parse_parameter_type_error(_error_message: Union[str, None]) -> Optional[dict]:
|
|
6863
|
+
def parse_parameter_type_error(_error_message: Union[Exception, str, None]) -> Optional[dict]:
|
|
6303
6864
|
if not _error_message:
|
|
6304
6865
|
return None
|
|
6305
6866
|
|
|
@@ -6366,7 +6927,7 @@ def get_generation_node_for_index(
|
|
|
6366
6927
|
results_list: List[Dict[str, Any]],
|
|
6367
6928
|
index: int,
|
|
6368
6929
|
__status: Any,
|
|
6369
|
-
base_str: str
|
|
6930
|
+
base_str: Optional[str]
|
|
6370
6931
|
) -> str:
|
|
6371
6932
|
__status.update(f"{base_str}: Getting generation node")
|
|
6372
6933
|
try:
|
|
@@ -6386,7 +6947,7 @@ def get_generation_node_for_index(
|
|
|
6386
6947
|
|
|
6387
6948
|
return generation_node
|
|
6388
6949
|
except Exception as e:
|
|
6389
|
-
|
|
6950
|
+
print_red(f"Error while get_generation_node_for_index: {e}")
|
|
6390
6951
|
return "MANUAL"
|
|
6391
6952
|
|
|
6392
6953
|
def _get_generation_node_for_index_index_valid(
|
|
@@ -6491,7 +7052,7 @@ def normalize_path(file_path: str) -> str:
|
|
|
6491
7052
|
|
|
6492
7053
|
def insert_jobs_from_lists(csv_path: str, arm_params_list: Any, results_list: Any, __status: Any) -> None:
|
|
6493
7054
|
cnt = 0
|
|
6494
|
-
err_msgs = []
|
|
7055
|
+
err_msgs: list = []
|
|
6495
7056
|
|
|
6496
7057
|
for i, (arm_params, result) in enumerate(zip(arm_params_list, results_list)):
|
|
6497
7058
|
base_str = f"[bold green]Loading job {i}/{len(results_list)} from {csv_path} into ax_client, result: {result}"
|
|
@@ -6525,9 +7086,16 @@ def try_insert_job(csv_path: str, arm_params: Dict, result: Any, i: int, arm_par
|
|
|
6525
7086
|
f"This can happen when the csv file has different parameters or results as the main job one's "
|
|
6526
7087
|
f"or other imported jobs. Error: {e}"
|
|
6527
7088
|
)
|
|
6528
|
-
|
|
6529
|
-
|
|
6530
|
-
err_msgs
|
|
7089
|
+
|
|
7090
|
+
if err_msgs is None:
|
|
7091
|
+
print_red("try_insert_job: err_msgs was None")
|
|
7092
|
+
else:
|
|
7093
|
+
if isinstance(err_msgs, list):
|
|
7094
|
+
if err_msg not in err_msgs:
|
|
7095
|
+
print_red(err_msg)
|
|
7096
|
+
err_msgs.append(err_msg)
|
|
7097
|
+
elif isinstance(err_msgs, str):
|
|
7098
|
+
err_msgs += f"\n{err_msg}"
|
|
6531
7099
|
|
|
6532
7100
|
return cnt
|
|
6533
7101
|
|
|
@@ -6544,33 +7112,52 @@ def update_global_job_counters(cnt: int) -> None:
|
|
|
6544
7112
|
set_max_eval(max_eval + cnt)
|
|
6545
7113
|
set_nr_inserted_jobs(NR_INSERTED_JOBS + cnt)
|
|
6546
7114
|
|
|
6547
|
-
def
|
|
7115
|
+
def update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
|
|
6548
7116
|
if __status and base_str:
|
|
6549
7117
|
__status.update(f"{base_str}: {new_text}")
|
|
6550
7118
|
|
|
6551
|
-
def
|
|
7119
|
+
def check_ax_client() -> None:
|
|
6552
7120
|
if ax_client is None or not ax_client:
|
|
6553
7121
|
_fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)
|
|
6554
7122
|
|
|
6555
|
-
def
|
|
7123
|
+
def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
|
|
7124
|
+
if not ax_client:
|
|
7125
|
+
my_exit(101)
|
|
7126
|
+
|
|
7127
|
+
return None
|
|
7128
|
+
|
|
6556
7129
|
new_trial = ax_client.attach_trial(arm_params)
|
|
7130
|
+
|
|
7131
|
+
return new_trial
|
|
7132
|
+
|
|
7133
|
+
def attach_trial(arm_params: dict) -> Tuple[Any, int]:
|
|
7134
|
+
if ax_client is None:
|
|
7135
|
+
raise RuntimeError("attach_trial: ax_client was empty")
|
|
7136
|
+
|
|
7137
|
+
new_trial = attach_ax_client_data(arm_params)
|
|
6557
7138
|
if not isinstance(new_trial, tuple) or len(new_trial) < 2:
|
|
6558
7139
|
raise RuntimeError("attach_trial didn't return the expected tuple")
|
|
6559
7140
|
return new_trial
|
|
6560
7141
|
|
|
6561
|
-
def
|
|
7142
|
+
def get_trial_by_index(trial_idx: int) -> Any:
|
|
7143
|
+
if ax_client is None:
|
|
7144
|
+
raise RuntimeError("get_trial_by_index: ax_client was empty")
|
|
7145
|
+
|
|
6562
7146
|
trial = ax_client.experiment.trials.get(trial_idx)
|
|
6563
7147
|
if trial is None:
|
|
6564
7148
|
raise RuntimeError(f"Trial with index {trial_idx} not found")
|
|
6565
7149
|
return trial
|
|
6566
7150
|
|
|
6567
|
-
def
|
|
7151
|
+
def create_generator_run(arm_params: dict, trial_idx: int, new_job_type: str) -> GeneratorRun:
|
|
6568
7152
|
arm = Arm(parameters=arm_params, name=f'{trial_idx}_0')
|
|
6569
7153
|
return GeneratorRun(arms=[arm], generation_node_name=new_job_type)
|
|
6570
7154
|
|
|
6571
|
-
def
|
|
7155
|
+
def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[Any], base_str: Optional[str]) -> None:
|
|
7156
|
+
if ax_client is None:
|
|
7157
|
+
raise RuntimeError("complete_trial_if_result: ax_client was empty")
|
|
7158
|
+
|
|
6572
7159
|
if f"{result}" != "":
|
|
6573
|
-
|
|
7160
|
+
update_status(__status, base_str, "Completing trial")
|
|
6574
7161
|
is_ok = True
|
|
6575
7162
|
|
|
6576
7163
|
for keyname in result.keys():
|
|
@@ -6578,20 +7165,20 @@ def __insert_job_into_ax_client__complete_trial_if_result(trial_idx: int, result
|
|
|
6578
7165
|
is_ok = False
|
|
6579
7166
|
|
|
6580
7167
|
if is_ok:
|
|
6581
|
-
|
|
6582
|
-
|
|
7168
|
+
complete_ax_client_trial(trial_idx, result)
|
|
7169
|
+
update_status(__status, base_str, "Completed trial")
|
|
6583
7170
|
else:
|
|
6584
7171
|
print_debug("Empty job encountered")
|
|
6585
7172
|
else:
|
|
6586
|
-
|
|
7173
|
+
update_status(__status, base_str, "Found trial without result. Not adding it.")
|
|
6587
7174
|
|
|
6588
|
-
def
|
|
7175
|
+
def save_results_if_needed(__status: Optional[Any], base_str: Optional[str]) -> None:
|
|
6589
7176
|
if not args.worker_generator_path:
|
|
6590
|
-
|
|
7177
|
+
update_status(__status, base_str, f"Saving {RESULTS_CSV_FILENAME}")
|
|
6591
7178
|
save_results_csv()
|
|
6592
|
-
|
|
7179
|
+
update_status(__status, base_str, f"Saved {RESULTS_CSV_FILENAME}")
|
|
6593
7180
|
|
|
6594
|
-
def
|
|
7181
|
+
def handle_insert_job_error(e: Exception, arm_params: dict) -> bool:
|
|
6595
7182
|
parsed_error = parse_parameter_type_error(e)
|
|
6596
7183
|
if parsed_error is not None:
|
|
6597
7184
|
param = parsed_error["parameter_name"]
|
|
@@ -6614,37 +7201,37 @@ def insert_job_into_ax_client(
|
|
|
6614
7201
|
result: dict,
|
|
6615
7202
|
new_job_type: str = "MANUAL",
|
|
6616
7203
|
__status: Optional[Any] = None,
|
|
6617
|
-
base_str: str = None
|
|
7204
|
+
base_str: Optional[str] = None
|
|
6618
7205
|
) -> bool:
|
|
6619
|
-
|
|
7206
|
+
check_ax_client()
|
|
6620
7207
|
|
|
6621
7208
|
done_converting = False
|
|
6622
7209
|
while not done_converting:
|
|
6623
7210
|
try:
|
|
6624
|
-
|
|
7211
|
+
update_status(__status, base_str, "Checking ax client")
|
|
6625
7212
|
if ax_client is None:
|
|
6626
7213
|
return False
|
|
6627
7214
|
|
|
6628
|
-
|
|
6629
|
-
_, new_trial_idx =
|
|
7215
|
+
update_status(__status, base_str, "Attaching new trial")
|
|
7216
|
+
_, new_trial_idx = attach_trial(arm_params)
|
|
6630
7217
|
|
|
6631
|
-
|
|
6632
|
-
trial =
|
|
6633
|
-
|
|
7218
|
+
update_status(__status, base_str, "Getting new trial")
|
|
7219
|
+
trial = get_trial_by_index(new_trial_idx)
|
|
7220
|
+
update_status(__status, base_str, "Got new trial")
|
|
6634
7221
|
|
|
6635
|
-
|
|
6636
|
-
manual_generator_run =
|
|
7222
|
+
update_status(__status, base_str, "Creating new arm")
|
|
7223
|
+
manual_generator_run = create_generator_run(arm_params, new_trial_idx, new_job_type)
|
|
6637
7224
|
trial._generator_run = manual_generator_run
|
|
6638
7225
|
fool_linter(trial._generator_run)
|
|
6639
7226
|
|
|
6640
|
-
|
|
7227
|
+
complete_trial_if_result(new_trial_idx, result, __status, base_str)
|
|
6641
7228
|
done_converting = True
|
|
6642
7229
|
|
|
6643
|
-
|
|
7230
|
+
save_results_if_needed(__status, base_str)
|
|
6644
7231
|
return True
|
|
6645
7232
|
|
|
6646
7233
|
except ax.exceptions.core.UnsupportedError as e:
|
|
6647
|
-
if not
|
|
7234
|
+
if not handle_insert_job_error(e, arm_params):
|
|
6648
7235
|
break
|
|
6649
7236
|
|
|
6650
7237
|
return False
|
|
@@ -6975,7 +7562,7 @@ def get_parameters_from_outfile(stdout_path: str) -> Union[None, dict, str]:
|
|
|
6975
7562
|
if not args.tests:
|
|
6976
7563
|
original_print(f"get_parameters_from_outfile: The file '{stdout_path}' was not found.")
|
|
6977
7564
|
except Exception as e:
|
|
6978
|
-
|
|
7565
|
+
print_red(f"get_parameters_from_outfile: There was an error: {e}")
|
|
6979
7566
|
|
|
6980
7567
|
return None
|
|
6981
7568
|
|
|
@@ -6993,7 +7580,7 @@ def get_hostname_from_outfile(stdout_path: Optional[str]) -> Optional[str]:
|
|
|
6993
7580
|
original_print(f"The file '{stdout_path}' was not found.")
|
|
6994
7581
|
return None
|
|
6995
7582
|
except Exception as e:
|
|
6996
|
-
|
|
7583
|
+
print_red(f"There was an error: {e}")
|
|
6997
7584
|
return None
|
|
6998
7585
|
|
|
6999
7586
|
def add_to_global_error_list(msg: str) -> None:
|
|
@@ -7029,22 +7616,68 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:
|
|
|
7029
7616
|
|
|
7030
7617
|
return None
|
|
7031
7618
|
|
|
7032
|
-
|
|
7619
|
+
log_ax_client_trial_failure(trial_index)
|
|
7033
7620
|
_trial.mark_failed(unsafe=True)
|
|
7034
7621
|
except ValueError as e:
|
|
7035
7622
|
print_debug(f"mark_trial_as_failed error: {e}")
|
|
7036
7623
|
|
|
7037
7624
|
return None
|
|
7038
7625
|
|
|
7039
|
-
def
|
|
7626
|
+
def check_valid_result(result: Union[None, dict]) -> bool:
|
|
7040
7627
|
possible_val_not_found_values = [
|
|
7041
7628
|
VAL_IF_NOTHING_FOUND,
|
|
7042
7629
|
-VAL_IF_NOTHING_FOUND,
|
|
7043
7630
|
-99999999999999997168788049560464200849936328366177157906432,
|
|
7044
7631
|
99999999999999997168788049560464200849936328366177157906432
|
|
7045
7632
|
]
|
|
7046
|
-
|
|
7047
|
-
|
|
7633
|
+
|
|
7634
|
+
def flatten_values(obj: Any) -> Any:
|
|
7635
|
+
values = []
|
|
7636
|
+
try:
|
|
7637
|
+
if isinstance(obj, dict):
|
|
7638
|
+
for v in obj.values():
|
|
7639
|
+
values.extend(flatten_values(v))
|
|
7640
|
+
elif isinstance(obj, (list, tuple, set)):
|
|
7641
|
+
for v in obj:
|
|
7642
|
+
values.extend(flatten_values(v))
|
|
7643
|
+
else:
|
|
7644
|
+
values.append(obj)
|
|
7645
|
+
except Exception as e:
|
|
7646
|
+
print_red(f"Error while flattening values: {e}")
|
|
7647
|
+
return values
|
|
7648
|
+
|
|
7649
|
+
if result is None:
|
|
7650
|
+
return False
|
|
7651
|
+
|
|
7652
|
+
try:
|
|
7653
|
+
all_values = flatten_values(result)
|
|
7654
|
+
for val in all_values:
|
|
7655
|
+
if val in possible_val_not_found_values:
|
|
7656
|
+
return False
|
|
7657
|
+
return True
|
|
7658
|
+
except Exception as e:
|
|
7659
|
+
print_red(f"Error while checking result validity: {e}")
|
|
7660
|
+
return False
|
|
7661
|
+
|
|
7662
|
+
def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
    """Pass ``result`` as raw data for trial ``trial_idx`` to ``ax_client.update_trial_data``.

    When the module-level ``ax_client`` is not set, ``my_exit(101)`` is called
    instead (presumably terminating the program -- confirm against ``my_exit``).
    """
    if ax_client:
        ax_client.update_trial_data(trial_index=trial_idx, raw_data=result)
        return None

    my_exit(101)

    return None
|
|
7671
|
+
|
|
7672
|
+
def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
    """Complete trial ``trial_idx`` via ``ax_client.complete_trial`` using ``result`` as raw data.

    When the module-level ``ax_client`` is not set, ``my_exit(101)`` is called
    instead (presumably terminating the program -- confirm against ``my_exit``).
    """
    if ax_client:
        ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
        return None

    my_exit(101)

    return None
|
|
7048
7681
|
|
|
7049
7682
|
def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
|
|
7050
7683
|
if ax_client is None:
|
|
@@ -7053,25 +7686,64 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -
|
|
|
7053
7686
|
|
|
7054
7687
|
try:
|
|
7055
7688
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
|
|
7056
|
-
|
|
7689
|
+
complete_ax_client_trial(trial_index, raw_result)
|
|
7057
7690
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
|
|
7058
7691
|
except ax.exceptions.core.UnsupportedError as e:
|
|
7059
7692
|
if f"{e}":
|
|
7060
7693
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
|
|
7061
|
-
|
|
7694
|
+
update_ax_client_trial(trial_index, raw_result)
|
|
7062
7695
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
|
|
7063
7696
|
else:
|
|
7064
7697
|
_fatal_error(f"Error completing trial: {e}", 234)
|
|
7065
7698
|
|
|
7066
7699
|
return None
|
|
7067
7700
|
|
|
7068
|
-
def
|
|
7701
|
+
def format_result_for_display(result: dict) -> str:
    """Render a job result as a short, human-readable string.

    Args:
        result: Usually a dict mapping result names to values. A value may be
            a plain value or a 2-element list/tuple ``(value, sem)``. Anything
            that is not a dict is formatted as a single value.

    Returns:
        ``"name: value"`` entries joined by ``", "``; ``(SEM: ...)`` is
        appended to an entry when a non-``None`` SEM is present. Numbers are
        shown with six decimals, NaN as ``NaN`` and infinities as ``∞``/``-∞``.
        This function never raises: failures are embedded in the returned
        text as ``<error: ...>``.
    """
    def safe_float(v: Any) -> str:
        """Format one value; never raises."""
        try:
            if v is None:
                return "None"
            # bool is a subclass of int; show it as True/False rather than
            # as 1.000000/0.000000.
            if isinstance(v, bool):
                return str(v)
            if isinstance(v, (int, float)):
                try:
                    if math.isnan(v):
                        return "NaN"
                    if math.isinf(v):
                        return "∞" if v > 0 else "-∞"
                    return f"{v:.6f}"
                except OverflowError:
                    # Ints too large for a float cannot be checked or
                    # formatted with ".6f"; show them verbatim instead of an
                    # error marker.
                    return str(v)
            return str(v)
        except Exception as e:
            return f"<error: {e}>"

    try:
        if not isinstance(result, dict):
            return safe_float(result)

        parts = []
        for key, val in result.items():
            try:
                if isinstance(val, (list, tuple)) and len(val) == 2:
                    # Two-element sequences are treated as (value, SEM).
                    main, sem = val
                    main_str = safe_float(main)
                    if sem is not None:
                        parts.append(f"{key}: {main_str} (SEM: {safe_float(sem)})")
                    else:
                        parts.append(f"{key}: {main_str}")
                else:
                    parts.append(f"{key}: {safe_float(val)}")
            except Exception as e:
                parts.append(f"{key}: <error: {e}>")

        return ", ".join(parts)
    except Exception as e:
        return f"<error formatting result: {e}>"
|
|
7739
|
+
|
|
7740
|
+
def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: dict) -> None:
|
|
7069
7741
|
print_debug(f"Marking trial {_trial} as completed")
|
|
7070
7742
|
_trial.mark_completed(unsafe=True)
|
|
7071
7743
|
|
|
7072
7744
|
succeeded_jobs(1)
|
|
7073
7745
|
|
|
7074
|
-
progressbar_description(f"new result: {result}")
|
|
7746
|
+
progressbar_description(f"new result: {format_result_for_display(result)}")
|
|
7075
7747
|
update_progress_bar(1)
|
|
7076
7748
|
|
|
7077
7749
|
save_results_csv()
|
|
@@ -7085,8 +7757,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
|
|
|
7085
7757
|
if job:
|
|
7086
7758
|
try:
|
|
7087
7759
|
progressbar_description("job_failed")
|
|
7088
|
-
|
|
7089
|
-
_trial
|
|
7760
|
+
log_ax_client_trial_failure(trial_index)
|
|
7761
|
+
mark_trial_as_failed(trial_index, _trial)
|
|
7090
7762
|
except Exception as e:
|
|
7091
7763
|
print_red(f"\nERROR while trying to mark job as failure: {e}")
|
|
7092
7764
|
job.cancel()
|
|
@@ -7103,16 +7775,16 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7103
7775
|
result = job.result()
|
|
7104
7776
|
print_debug(f"finish_job_core: trial-index: {trial_index}, job.result(): {result}, state: {state_from_job(job)}")
|
|
7105
7777
|
|
|
7106
|
-
raw_result = result
|
|
7107
|
-
result_keys = list(result.keys())
|
|
7108
|
-
result = result[result_keys[0]]
|
|
7109
7778
|
this_jobs_finished += 1
|
|
7110
7779
|
|
|
7111
7780
|
if ax_client:
|
|
7112
|
-
_trial =
|
|
7781
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7782
|
+
|
|
7783
|
+
if _trial is None:
|
|
7784
|
+
return 0
|
|
7113
7785
|
|
|
7114
|
-
if
|
|
7115
|
-
_finish_job_core_helper_complete_trial(trial_index,
|
|
7786
|
+
if check_valid_result(result):
|
|
7787
|
+
_finish_job_core_helper_complete_trial(trial_index, result)
|
|
7116
7788
|
|
|
7117
7789
|
try:
|
|
7118
7790
|
_finish_job_core_helper_mark_success(_trial, result)
|
|
@@ -7120,15 +7792,17 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7120
7792
|
if len(arg_result_names) > 1 and count_done_jobs() > 1 and not job_calculate_pareto_front(get_current_run_folder(), True):
|
|
7121
7793
|
print_red("job_calculate_pareto_front post job failed")
|
|
7122
7794
|
except Exception as e:
|
|
7123
|
-
|
|
7795
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
7124
7796
|
else:
|
|
7125
7797
|
_finish_job_core_helper_mark_failure(job, trial_index, _trial)
|
|
7126
7798
|
else:
|
|
7127
|
-
_fatal_error("ax_client could not be found or used",
|
|
7799
|
+
_fatal_error("ax_client could not be found or used", 101)
|
|
7128
7800
|
|
|
7129
7801
|
print_debug(f"finish_job_core: removing job {job}, trial_index: {trial_index}")
|
|
7130
7802
|
global_vars["jobs"].remove((job, trial_index))
|
|
7131
7803
|
|
|
7804
|
+
log_data()
|
|
7805
|
+
|
|
7132
7806
|
force_live_share()
|
|
7133
7807
|
|
|
7134
7808
|
return this_jobs_finished
|
|
@@ -7141,11 +7815,14 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
|
|
|
7141
7815
|
if job:
|
|
7142
7816
|
try:
|
|
7143
7817
|
progressbar_description("job_failed")
|
|
7144
|
-
_trial =
|
|
7145
|
-
|
|
7818
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7819
|
+
if _trial is None:
|
|
7820
|
+
return None
|
|
7821
|
+
|
|
7822
|
+
log_ax_client_trial_failure(trial_index)
|
|
7146
7823
|
mark_trial_as_failed(trial_index, _trial)
|
|
7147
7824
|
except Exception as e:
|
|
7148
|
-
|
|
7825
|
+
print_debug(f"ERROR in line {get_line_info()}: {e}")
|
|
7149
7826
|
job.cancel()
|
|
7150
7827
|
orchestrate_job(job, trial_index)
|
|
7151
7828
|
|
|
@@ -7160,10 +7837,12 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
|
|
|
7160
7837
|
|
|
7161
7838
|
def _finish_previous_jobs_helper_handle_exception(job: Any, trial_index: int, error: Exception) -> int:
|
|
7162
7839
|
if "None for metric" in str(error):
|
|
7163
|
-
|
|
7164
|
-
|
|
7165
|
-
|
|
7166
|
-
|
|
7840
|
+
err_msg = f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string.\nError: {error}\nJob-result: {job.result()}"
|
|
7841
|
+
|
|
7842
|
+
if count_done_jobs() == 0:
|
|
7843
|
+
print_red(err_msg)
|
|
7844
|
+
else:
|
|
7845
|
+
print_debug(err_msg)
|
|
7167
7846
|
else:
|
|
7168
7847
|
print_red(f"\n⚠ {error}")
|
|
7169
7848
|
|
|
@@ -7182,7 +7861,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
|
|
|
7182
7861
|
this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
|
|
7183
7862
|
return this_jobs_finished
|
|
7184
7863
|
|
|
7185
|
-
def _finish_previous_jobs_helper_check_and_process(
|
|
7864
|
+
def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
|
|
7865
|
+
job, trial_index = __args
|
|
7866
|
+
|
|
7867
|
+
this_jobs_finished = 0
|
|
7186
7868
|
if job is None:
|
|
7187
7869
|
print_debug(f"finish_previous_jobs: job {job} is None")
|
|
7188
7870
|
return this_jobs_finished
|
|
@@ -7195,10 +7877,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t
|
|
|
7195
7877
|
|
|
7196
7878
|
return this_jobs_finished
|
|
7197
7879
|
|
|
7198
|
-
def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
|
|
7199
|
-
job, trial_index = __args
|
|
7200
|
-
return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
|
|
7201
|
-
|
|
7202
7880
|
def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
7203
7881
|
global JOBS_FINISHED
|
|
7204
7882
|
|
|
@@ -7211,13 +7889,10 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
|
7211
7889
|
|
|
7212
7890
|
jobs_copy = global_vars["jobs"][:]
|
|
7213
7891
|
|
|
7214
|
-
|
|
7215
|
-
print_debug(f"jobs in finish_previous_jobs: {jobs_copy}")
|
|
7216
|
-
|
|
7217
|
-
finishing_jobs_start_time = time.time()
|
|
7892
|
+
#finishing_jobs_start_time = time.time()
|
|
7218
7893
|
|
|
7219
7894
|
with ThreadPoolExecutor() as finish_job_executor:
|
|
7220
|
-
futures = [finish_job_executor.submit(
|
|
7895
|
+
futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]
|
|
7221
7896
|
|
|
7222
7897
|
for future in as_completed(futures):
|
|
7223
7898
|
try:
|
|
@@ -7225,11 +7900,11 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
|
7225
7900
|
except Exception as e:
|
|
7226
7901
|
print_red(f"⚠ Exception in parallel job handling: {e}")
|
|
7227
7902
|
|
|
7228
|
-
finishing_jobs_end_time = time.time()
|
|
7903
|
+
#finishing_jobs_end_time = time.time()
|
|
7229
7904
|
|
|
7230
|
-
finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
|
|
7905
|
+
#finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
|
|
7231
7906
|
|
|
7232
|
-
print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
|
|
7907
|
+
#print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
|
|
7233
7908
|
|
|
7234
7909
|
if this_jobs_finished > 0:
|
|
7235
7910
|
save_results_csv()
|
|
@@ -7376,11 +8051,15 @@ def is_already_in_defective_nodes(hostname: str) -> bool:
|
|
|
7376
8051
|
return True
|
|
7377
8052
|
except Exception as e:
|
|
7378
8053
|
print_red(f"is_already_in_defective_nodes: Error reading the file {file_path}: {e}")
|
|
7379
|
-
return False
|
|
7380
8054
|
|
|
7381
8055
|
return False
|
|
7382
8056
|
|
|
7383
8057
|
def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
|
|
8058
|
+
if submitit_executor is None:
|
|
8059
|
+
print_red("submit_new_job: submitit_executor was None")
|
|
8060
|
+
|
|
8061
|
+
return None
|
|
8062
|
+
|
|
7384
8063
|
print_debug(f"Submitting new job for trial_index {trial_index}, parameters {parameters}")
|
|
7385
8064
|
|
|
7386
8065
|
start = time.time()
|
|
@@ -7391,25 +8070,49 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
|
|
|
7391
8070
|
|
|
7392
8071
|
print_debug(f"Done submitting new job, took {elapsed} seconds")
|
|
7393
8072
|
|
|
8073
|
+
log_data()
|
|
8074
|
+
|
|
7394
8075
|
return new_job
|
|
7395
8076
|
|
|
8077
|
+
def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
    """Look up trial ``trial_index`` on the module-level ``ax_client``.

    Returns the trial from ``ax_client.get_trial``. ``None`` is returned when
    the lookup raises ``KeyError`` (reported through ``error_without_print``)
    or when ``ax_client`` is not set, in which case ``my_exit(101)`` is called
    first. ``log_data()`` is invoked before every lookup.
    """
    if not ax_client:
        my_exit(101)

        return None

    found_trial = None

    try:
        log_data()

        found_trial = ax_client.get_trial(trial_index)
    except KeyError:
        error_without_print(f"get_ax_client_trial: trial_index {trial_index} failed")

    return found_trial
|
|
8090
|
+
|
|
7396
8091
|
def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
    """Submit a job for ``parameters`` and register it for trial ``trial_index``.

    On success the submitted-jobs counter is incremented, the trial is marked
    staged and then running, and ``(job, trial_index)`` is appended to
    ``global_vars["jobs"]``. A missing ``ax_client`` or ``submitit_executor``
    is reported through ``_fatal_error`` (codes 101 and 9 respectively).
    """
    if not ax_client:
        _fatal_error("ax_client could not be found properly", 101)
        return

    if not submitit_executor:
        _fatal_error("submitit_executor could not be found properly", 9)
        return

    new_job = submit_new_job(parameters, trial_index)
    if not new_job:
        print_red("orchestrator_start_trial: Failed to start new job")
        return

    submitted_jobs(1)

    _trial = get_ax_client_trial(trial_index)

    if _trial is None:
        print_red("Trial was none in orchestrator_start_trial")
        return

    try:
        _trial.mark_staged(unsafe=True)
    except Exception as e:
        # Staging is best effort; a failure here is only logged.
        print_debug(f"orchestrator_start_trial: error {e}")
    # NOTE(review): mark_running is assumed to run whether or not
    # mark_staged succeeded -- confirm against the upstream file.
    _trial.mark_running(unsafe=True, no_runner_required=True)

    print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
    global_vars["jobs"].append((new_job, trial_index))
|
|
7413
8116
|
|
|
7414
8117
|
def handle_exclude_node(stdout_path: str, hostname_from_out_file: Union[None, str]) -> None:
|
|
7415
8118
|
stdout_path = check_alternate_path(stdout_path)
|
|
@@ -7428,7 +8131,7 @@ def handle_restart(stdout_path: str, trial_index: int) -> None:
|
|
|
7428
8131
|
if parameters:
|
|
7429
8132
|
orchestrator_start_trial(parameters, trial_index)
|
|
7430
8133
|
else:
|
|
7431
|
-
|
|
8134
|
+
print_red(f"Could not determine parameters from outfile {stdout_path} for restarting job")
|
|
7432
8135
|
|
|
7433
8136
|
def check_alternate_path(path: str) -> str:
|
|
7434
8137
|
if os.path.exists(path):
|
|
@@ -7515,7 +8218,7 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7515
8218
|
print_debug(f"execute_evaluation({_params})")
|
|
7516
8219
|
trial_index, parameters, trial_counter, phase = _params
|
|
7517
8220
|
if not ax_client:
|
|
7518
|
-
_fatal_error("Failed to get ax_client",
|
|
8221
|
+
_fatal_error("Failed to get ax_client", 101)
|
|
7519
8222
|
|
|
7520
8223
|
return None
|
|
7521
8224
|
|
|
@@ -7524,7 +8227,11 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7524
8227
|
|
|
7525
8228
|
return None
|
|
7526
8229
|
|
|
7527
|
-
_trial =
|
|
8230
|
+
_trial = get_ax_client_trial(trial_index)
|
|
8231
|
+
|
|
8232
|
+
if _trial is None:
|
|
8233
|
+
error_without_print(f"execute_evaluation: _trial was not in execute_evaluation for params {_params}")
|
|
8234
|
+
return None
|
|
7528
8235
|
|
|
7529
8236
|
def mark_trial_stage(stage: str, error_msg: str) -> None:
|
|
7530
8237
|
try:
|
|
@@ -7539,15 +8246,18 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7539
8246
|
try:
|
|
7540
8247
|
initialize_job_environment()
|
|
7541
8248
|
new_job = submit_new_job(parameters, trial_index)
|
|
7542
|
-
|
|
8249
|
+
if new_job:
|
|
8250
|
+
submitted_jobs(1)
|
|
7543
8251
|
|
|
7544
|
-
|
|
7545
|
-
|
|
8252
|
+
print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
|
|
8253
|
+
global_vars["jobs"].append((new_job, trial_index))
|
|
7546
8254
|
|
|
7547
|
-
|
|
7548
|
-
|
|
8255
|
+
mark_trial_stage("mark_running", "Marking the trial as running failed")
|
|
8256
|
+
trial_counter += 1
|
|
7549
8257
|
|
|
7550
|
-
|
|
8258
|
+
progressbar_description("started new job")
|
|
8259
|
+
else:
|
|
8260
|
+
progressbar_description("Failed to start new job")
|
|
7551
8261
|
except submitit.core.utils.FailedJobError as error:
|
|
7552
8262
|
handle_failed_job(error, trial_index, new_job)
|
|
7553
8263
|
trial_counter += 1
|
|
@@ -7589,7 +8299,7 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
|
|
|
7589
8299
|
my_exit(144)
|
|
7590
8300
|
|
|
7591
8301
|
if new_job is None:
|
|
7592
|
-
|
|
8302
|
+
print_debug("handle_failed_job: job is None")
|
|
7593
8303
|
|
|
7594
8304
|
return None
|
|
7595
8305
|
|
|
@@ -7600,16 +8310,24 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
|
|
|
7600
8310
|
|
|
7601
8311
|
return None
|
|
7602
8312
|
|
|
8313
|
+
def log_ax_client_trial_failure(trial_index: int) -> None:
    """Report trial ``trial_index`` as failed through ``ax_client.log_trial_failure``.

    When the module-level ``ax_client`` is not set, ``my_exit(101)`` is called
    instead (presumably terminating the program -- confirm against ``my_exit``).
    """
    if ax_client:
        ax_client.log_trial_failure(trial_index=trial_index)
        return

    my_exit(101)
|
|
8320
|
+
|
|
7603
8321
|
def cancel_failed_job(trial_index: int, new_job: Job) -> None:
|
|
7604
8322
|
print_debug("Trying to cancel job that failed")
|
|
7605
8323
|
if new_job:
|
|
7606
8324
|
try:
|
|
7607
8325
|
if ax_client:
|
|
7608
|
-
|
|
8326
|
+
log_ax_client_trial_failure(trial_index)
|
|
7609
8327
|
else:
|
|
7610
8328
|
_fatal_error("ax_client not defined", 101)
|
|
7611
8329
|
except Exception as e:
|
|
7612
|
-
|
|
8330
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
7613
8331
|
new_job.cancel()
|
|
7614
8332
|
|
|
7615
8333
|
print_debug(f"cancel_failed_job: removing job {new_job}, trial_index: {trial_index}")
|
|
@@ -7645,10 +8363,12 @@ def show_debug_table_for_break_run_search(_name: str, _max_eval: Optional[int])
|
|
|
7645
8363
|
("failed_jobs()", failed_jobs()),
|
|
7646
8364
|
("count_done_jobs()", count_done_jobs()),
|
|
7647
8365
|
("_max_eval", _max_eval),
|
|
7648
|
-
("progress_bar.total", progress_bar.total),
|
|
7649
8366
|
("NR_INSERTED_JOBS", NR_INSERTED_JOBS)
|
|
7650
8367
|
]
|
|
7651
8368
|
|
|
8369
|
+
if progress_bar is not None:
|
|
8370
|
+
rows.append(("progress_bar.total", progress_bar.total))
|
|
8371
|
+
|
|
7652
8372
|
for row in rows:
|
|
7653
8373
|
table.add_row(str(row[0]), str(row[1]))
|
|
7654
8374
|
|
|
@@ -7661,6 +8381,8 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
|
|
|
7661
8381
|
_submitted_jobs = submitted_jobs()
|
|
7662
8382
|
_failed_jobs = failed_jobs()
|
|
7663
8383
|
|
|
8384
|
+
log_data()
|
|
8385
|
+
|
|
7664
8386
|
max_failed_jobs = max_eval
|
|
7665
8387
|
|
|
7666
8388
|
if args.max_failed_jobs is not None and args.max_failed_jobs > 0:
|
|
@@ -7688,11 +8410,11 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
|
|
|
7688
8410
|
_ret = True
|
|
7689
8411
|
|
|
7690
8412
|
if args.verbose_break_run_search_table:
|
|
7691
|
-
show_debug_table_for_break_run_search(_name, _max_eval
|
|
8413
|
+
show_debug_table_for_break_run_search(_name, _max_eval)
|
|
7692
8414
|
|
|
7693
8415
|
return _ret
|
|
7694
8416
|
|
|
7695
|
-
def
|
|
8417
|
+
def calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
|
|
7696
8418
|
"""Calculates the number of jobs to retrieve."""
|
|
7697
8419
|
return min(
|
|
7698
8420
|
max_eval + simulated_jobs - count_done_jobs(),
|
|
@@ -7705,7 +8427,7 @@ def remove_extra_spaces(text: str) -> str:
|
|
|
7705
8427
|
raise ValueError("Input must be a string")
|
|
7706
8428
|
return re.sub(r'\s+', ' ', text).strip()
|
|
7707
8429
|
|
|
7708
|
-
def
|
|
8430
|
+
def get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
|
|
7709
8431
|
"""Generates the appropriate message for the number of trials being retrieved."""
|
|
7710
8432
|
ret = ""
|
|
7711
8433
|
if full_nr_of_jobs_to_get > 1:
|
|
@@ -7790,51 +8512,51 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
|
|
|
7790
8512
|
print_red("get_batched_arms: ax_client was None")
|
|
7791
8513
|
return []
|
|
7792
8514
|
|
|
7793
|
-
# Experiment-Status laden
|
|
7794
8515
|
load_experiment_state()
|
|
7795
8516
|
|
|
7796
|
-
while len(batched_arms)
|
|
8517
|
+
while len(batched_arms) < nr_of_jobs_to_get:
|
|
7797
8518
|
if attempts > args.max_attempts_for_generation:
|
|
7798
8519
|
print_debug(f"get_batched_arms: Stopped after {attempts} attempts: could not generate enough arms "
|
|
7799
8520
|
f"(got {len(batched_arms)} out of {nr_of_jobs_to_get}).")
|
|
7800
8521
|
break
|
|
7801
8522
|
|
|
7802
|
-
|
|
7803
|
-
print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
|
|
7804
|
-
|
|
7805
|
-
print_debug("get pending observations")
|
|
7806
|
-
t0 = time.time()
|
|
7807
|
-
pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
|
|
7808
|
-
dt = time.time() - t0
|
|
7809
|
-
print_debug(f"got pending observations: {pending_observations} (took {dt:.2f} seconds)")
|
|
8523
|
+
#print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting 1 more arm")
|
|
7810
8524
|
|
|
7811
|
-
|
|
7812
|
-
batched_generator_run = global_gs.gen(
|
|
8525
|
+
pending_observations = get_pending_observation_features(
|
|
7813
8526
|
experiment=ax_client.experiment,
|
|
7814
|
-
|
|
7815
|
-
pending_observations=pending_observations
|
|
8527
|
+
include_out_of_design_points=True
|
|
7816
8528
|
)
|
|
7817
|
-
|
|
8529
|
+
|
|
8530
|
+
try:
|
|
8531
|
+
#print_debug("getting global_gs.gen() with n=1")
|
|
8532
|
+
batched_generator_run: Any = global_gs.gen(
|
|
8533
|
+
experiment=ax_client.experiment,
|
|
8534
|
+
n=1,
|
|
8535
|
+
pending_observations=pending_observations,
|
|
8536
|
+
)
|
|
8537
|
+
print_debug(f"got global_gs.gen(): {batched_generator_run}")
|
|
8538
|
+
except Exception as e:
|
|
8539
|
+
print_debug(f"global_gs.gen failed: {e}")
|
|
8540
|
+
traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
|
|
8541
|
+
break
|
|
7818
8542
|
|
|
7819
8543
|
depth = 0
|
|
7820
8544
|
path = "batched_generator_run"
|
|
7821
|
-
while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run)
|
|
7822
|
-
print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
|
|
8545
|
+
while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
|
|
8546
|
+
#print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
|
|
7823
8547
|
batched_generator_run = batched_generator_run[0]
|
|
7824
8548
|
path += "[0]"
|
|
7825
8549
|
depth += 1
|
|
7826
8550
|
|
|
7827
|
-
print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
|
|
8551
|
+
#print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
|
|
7828
8552
|
|
|
7829
|
-
|
|
7830
|
-
new_arms = batched_generator_run.arms
|
|
7831
|
-
print_debug(f"new_arms: {new_arms}")
|
|
8553
|
+
new_arms = getattr(batched_generator_run, "arms", [])
|
|
7832
8554
|
if not new_arms:
|
|
7833
8555
|
print_debug("get_batched_arms: No new arms were generated in this attempt.")
|
|
7834
8556
|
else:
|
|
7835
|
-
print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s),
|
|
8557
|
+
print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), now at {len(batched_arms) + len(new_arms)} of {nr_of_jobs_to_get}.")
|
|
8558
|
+
batched_arms.extend(new_arms)
|
|
7836
8559
|
|
|
7837
|
-
batched_arms.extend(new_arms)
|
|
7838
8560
|
attempts += 1
|
|
7839
8561
|
|
|
7840
8562
|
print_debug(f"get_batched_arms: Finished with {len(batched_arms)} arm(s) after {attempts} attempt(s).")
|
|
@@ -7845,7 +8567,7 @@ def fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple[
|
|
|
7845
8567
|
die_101_if_no_ax_client_or_experiment_or_gs()
|
|
7846
8568
|
|
|
7847
8569
|
if not ax_client:
|
|
7848
|
-
_fatal_error("ax_client was not defined",
|
|
8570
|
+
_fatal_error("ax_client was not defined", 101)
|
|
7849
8571
|
|
|
7850
8572
|
if global_gs is None:
|
|
7851
8573
|
_fatal_error("Global generation strategy is not set. This is a bug in OmniOpt2.", 107)
|
|
@@ -7873,16 +8595,14 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7873
8595
|
retries += 1
|
|
7874
8596
|
continue
|
|
7875
8597
|
|
|
7876
|
-
|
|
7877
|
-
progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
|
|
8598
|
+
progressbar_description(get_trials_message(cnt + 1, n, trial_durations))
|
|
7878
8599
|
|
|
7879
8600
|
try:
|
|
7880
8601
|
result = create_and_handle_trial(arm)
|
|
7881
8602
|
if result is not None:
|
|
7882
8603
|
trial_index, trial_duration, trial_successful = result
|
|
7883
|
-
|
|
7884
8604
|
except TrialRejected as e:
|
|
7885
|
-
print_debug(f"Trial rejected: {e}")
|
|
8605
|
+
print_debug(f"generate_trials: Trial rejected, error: {e}")
|
|
7886
8606
|
retries += 1
|
|
7887
8607
|
continue
|
|
7888
8608
|
|
|
@@ -7892,14 +8612,23 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7892
8612
|
cnt += 1
|
|
7893
8613
|
trials_dict[trial_index] = arm.parameters
|
|
7894
8614
|
|
|
7895
|
-
|
|
8615
|
+
finalized = finalize_generation(trials_dict, cnt, n, start_time)
|
|
8616
|
+
|
|
8617
|
+
return finalized
|
|
7896
8618
|
|
|
7897
8619
|
except Exception as e:
|
|
7898
|
-
return
|
|
8620
|
+
return handle_generation_failure(e, n, recursion)
|
|
7899
8621
|
|
|
7900
8622
|
class TrialRejected(Exception):
    """Signals that a freshly generated trial was discarded and must not be run."""
|
|
7902
8624
|
|
|
8625
|
+
def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
    """Mark ``trial`` as abandoned with ``reason``, logging the action first.

    ``trial_index`` is only used in the debug message. Any exception raised
    while logging or marking is caught and reported via ``print_red``; nothing
    is re-raised.
    """
    try:
        announcement = f"[INFO] Marking trial {trial.index} ({trial.arm.name}) as abandoned, trial-index: {trial_index}. Reason: {reason}"
        print_debug(announcement)
        trial.mark_abandoned(reason)
    except Exception as err:
        print_red(f"[ERROR] Could not mark trial as abandoned: {err}")
|
|
8631
|
+
|
|
7903
8632
|
def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
7904
8633
|
if ax_client is None:
|
|
7905
8634
|
print_red("ax_client is None in create_and_handle_trial")
|
|
@@ -7925,7 +8654,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7925
8654
|
arm = trial.arms[0]
|
|
7926
8655
|
if deduplicated_arm(arm):
|
|
7927
8656
|
print_debug(f"Duplicated arm: {arm}")
|
|
7928
|
-
|
|
8657
|
+
mark_abandoned(trial, "Duplication detected", trial_index)
|
|
7929
8658
|
raise TrialRejected("Duplicate arm.")
|
|
7930
8659
|
|
|
7931
8660
|
arms_by_name_for_deduplication[arm.name] = arm
|
|
@@ -7934,7 +8663,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7934
8663
|
|
|
7935
8664
|
if not has_no_post_generation_constraints_or_matches_constraints(post_generation_constraints, params):
|
|
7936
8665
|
print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned. Params: {params}, constraints: {post_generation_constraints}")
|
|
7937
|
-
|
|
8666
|
+
mark_abandoned(trial, "Post-Generation-Constraint failed", trial_index)
|
|
7938
8667
|
abandoned_trial_indices.append(trial_index)
|
|
7939
8668
|
raise TrialRejected("Post-generation constraints not met.")
|
|
7940
8669
|
|
|
@@ -7942,7 +8671,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7942
8671
|
end = time.time()
|
|
7943
8672
|
return trial_index, float(end - start), True
|
|
7944
8673
|
|
|
7945
|
-
def
|
|
8674
|
+
def finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
|
|
7946
8675
|
total_time = time.time() - start_time
|
|
7947
8676
|
|
|
7948
8677
|
log_gen_times.append(total_time)
|
|
@@ -7953,7 +8682,7 @@ def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int,
|
|
|
7953
8682
|
|
|
7954
8683
|
return trials_dict, False
|
|
7955
8684
|
|
|
7956
|
-
def
|
|
8685
|
+
def handle_generation_failure(
|
|
7957
8686
|
e: Exception,
|
|
7958
8687
|
requested: int,
|
|
7959
8688
|
recursion: bool
|
|
@@ -7969,19 +8698,19 @@ def _handle_generation_failure(
|
|
|
7969
8698
|
)):
|
|
7970
8699
|
msg = str(e)
|
|
7971
8700
|
if msg not in error_8_saved:
|
|
7972
|
-
|
|
8701
|
+
print_exhaustion_warning(e, recursion)
|
|
7973
8702
|
error_8_saved.append(msg)
|
|
7974
8703
|
|
|
7975
8704
|
if not recursion and args.revert_to_random_when_seemingly_exhausted:
|
|
7976
8705
|
print_debug("Switching to random search strategy.")
|
|
7977
|
-
|
|
8706
|
+
set_global_gs_to_sobol()
|
|
7978
8707
|
return fetch_next_trials(requested, True)
|
|
7979
8708
|
|
|
7980
|
-
print_red(f"
|
|
8709
|
+
print_red(f"handle_generation_failure: General Exception: {e}")
|
|
7981
8710
|
|
|
7982
8711
|
return {}, True
|
|
7983
8712
|
|
|
7984
|
-
def
|
|
8713
|
+
def print_exhaustion_warning(e: Exception, recursion: bool) -> None:
|
|
7985
8714
|
if not recursion and args.revert_to_random_when_seemingly_exhausted:
|
|
7986
8715
|
print_yellow(f"\n⚠Error 8: {e} From now (done jobs: {count_done_jobs()}) on, random points will be generated.")
|
|
7987
8716
|
else:
|
|
@@ -8026,21 +8755,24 @@ def get_model_gen_kwargs() -> dict:
|
|
|
8026
8755
|
"fit_out_of_design": args.fit_out_of_design
|
|
8027
8756
|
}
|
|
8028
8757
|
|
|
8029
|
-
def
|
|
8758
|
+
def set_global_gs_to_sobol() -> None:
|
|
8030
8759
|
global global_gs
|
|
8031
8760
|
global overwritten_to_random
|
|
8032
8761
|
|
|
8762
|
+
print("Reverting to SOBOL")
|
|
8763
|
+
|
|
8033
8764
|
global_gs = GenerationStrategy(
|
|
8034
8765
|
name="Random*",
|
|
8035
8766
|
nodes=[
|
|
8036
8767
|
GenerationNode(
|
|
8037
8768
|
node_name="Sobol",
|
|
8038
|
-
|
|
8039
|
-
|
|
8040
|
-
|
|
8041
|
-
|
|
8042
|
-
|
|
8043
|
-
|
|
8769
|
+
should_deduplicate=True,
|
|
8770
|
+
generator_specs=[ # type: ignore[arg-type]
|
|
8771
|
+
GeneratorSpec( # type: ignore[arg-type]
|
|
8772
|
+
Models.SOBOL, # type: ignore[arg-type]
|
|
8773
|
+
model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
|
|
8774
|
+
) # type: ignore[arg-type]
|
|
8775
|
+
] # type: ignore[arg-type]
|
|
8044
8776
|
)
|
|
8045
8777
|
]
|
|
8046
8778
|
)
|
|
@@ -8400,16 +9132,21 @@ def get_model_from_name(name: str) -> Any:
|
|
|
8400
9132
|
return gen
|
|
8401
9133
|
raise ValueError(f"Unknown or unsupported model: {name}")
|
|
8402
9134
|
|
|
8403
|
-
def get_name_from_model(model) ->
|
|
9135
|
+
def get_name_from_model(model: Any) -> str:
    """Return the ``SUPPORTED_MODELS`` entry that matches ``model``.

    ``model`` is reduced to a string (its ``.value`` attribute when present,
    otherwise ``str(model)``) and compared case-insensitively against the
    entries of ``SUPPORTED_MODELS``.

    Raises:
        RuntimeError: if ``SUPPORTED_MODELS`` is not a list, set or tuple, or
            if no entry matches ``model``.
    """
    if not isinstance(SUPPORTED_MODELS, (list, set, tuple)):
        # The message previously named get_model_from_name, which sent
        # readers of the traceback to the wrong function.
        raise RuntimeError("get_name_from_model: SUPPORTED_MODELS was not a list, set or tuple. Cannot continue")

    model_str = model.value if hasattr(model, "value") else str(model)

    # Lower-cased name -> canonical spelling, for a case-insensitive lookup.
    model_map = {m.lower(): m for m in SUPPORTED_MODELS}

    ret = model_map.get(model_str.lower())

    if ret is None:
        raise RuntimeError(f"get_name_from_model: failed to get Model for '{model_str}'")

    return ret
|
|
8413
9150
|
|
|
8414
9151
|
def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str, int]], int]:
|
|
8415
9152
|
gen_strat_list = []
|
|
@@ -8420,10 +9157,10 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
|
|
|
8420
9157
|
|
|
8421
9158
|
for s in splitted_by_comma:
|
|
8422
9159
|
if "=" not in s:
|
|
8423
|
-
|
|
9160
|
+
print_red(f"'{s}' does not contain '='")
|
|
8424
9161
|
my_exit(123)
|
|
8425
9162
|
if s.count("=") != 1:
|
|
8426
|
-
|
|
9163
|
+
print_red(f"There can only be one '=' in the gen_strat_str's element '{s}'")
|
|
8427
9164
|
my_exit(123)
|
|
8428
9165
|
|
|
8429
9166
|
model_name, nr_str = s.split("=")
|
|
@@ -8433,13 +9170,13 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
|
|
|
8433
9170
|
_fatal_error(f"Model {matching_model} is not valid for custom generation strategy.", 56)
|
|
8434
9171
|
|
|
8435
9172
|
if not matching_model:
|
|
8436
|
-
|
|
9173
|
+
print_red(f"'{model_name}' not found in SUPPORTED_MODELS")
|
|
8437
9174
|
my_exit(123)
|
|
8438
9175
|
|
|
8439
9176
|
try:
|
|
8440
9177
|
nr = int(nr_str)
|
|
8441
9178
|
except ValueError:
|
|
8442
|
-
|
|
9179
|
+
print_red(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
|
|
8443
9180
|
my_exit(123)
|
|
8444
9181
|
|
|
8445
9182
|
gen_strat_list.append({matching_model: nr})
|
|
@@ -8596,11 +9333,12 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8596
9333
|
if model_name.lower() != "sobol":
|
|
8597
9334
|
kwargs["model_kwargs"] = get_model_kwargs()
|
|
8598
9335
|
|
|
8599
|
-
model_spec = [GeneratorSpec(selected_model, **kwargs)]
|
|
9336
|
+
model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]
|
|
8600
9337
|
|
|
8601
9338
|
res = GenerationNode(
|
|
8602
9339
|
node_name=model_name,
|
|
8603
9340
|
generator_specs=model_spec,
|
|
9341
|
+
should_deduplicate=True,
|
|
8604
9342
|
transition_criteria=trans_crit
|
|
8605
9343
|
)
|
|
8606
9344
|
|
|
@@ -8611,7 +9349,7 @@ def get_optimizer_kwargs() -> dict:
|
|
|
8611
9349
|
"sequential": False
|
|
8612
9350
|
}
|
|
8613
9351
|
|
|
8614
|
-
def create_step(model_name: str, _num_trials: int
|
|
9352
|
+
def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
|
|
8615
9353
|
model_enum = get_model_from_name(model_name)
|
|
8616
9354
|
|
|
8617
9355
|
return GenerationStep(
|
|
@@ -8626,17 +9364,16 @@ def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = N
|
|
|
8626
9364
|
)
|
|
8627
9365
|
|
|
8628
9366
|
def set_global_generation_strategy() -> None:
|
|
8629
|
-
|
|
8630
|
-
continue_not_supported_on_custom_generation_strategy()
|
|
9367
|
+
continue_not_supported_on_custom_generation_strategy()
|
|
8631
9368
|
|
|
8632
|
-
|
|
8633
|
-
|
|
8634
|
-
|
|
8635
|
-
|
|
8636
|
-
|
|
8637
|
-
|
|
8638
|
-
|
|
8639
|
-
|
|
9369
|
+
try:
|
|
9370
|
+
if args.generation_strategy is None:
|
|
9371
|
+
setup_default_generation_strategy()
|
|
9372
|
+
else:
|
|
9373
|
+
setup_custom_generation_strategy()
|
|
9374
|
+
except Exception as e:
|
|
9375
|
+
print_red(f"Unexpected error in generation strategy setup: {e}")
|
|
9376
|
+
my_exit(111)
|
|
8640
9377
|
|
|
8641
9378
|
if global_gs is None:
|
|
8642
9379
|
print_red("global_gs is None after setup!")
|
|
@@ -8787,10 +9524,14 @@ def execute_trials(
|
|
|
8787
9524
|
result = future.result()
|
|
8788
9525
|
print_debug(f"result in execute_trials: {result}")
|
|
8789
9526
|
except Exception as exc:
|
|
8790
|
-
|
|
9527
|
+
failed_args = future_to_args[future]
|
|
9528
|
+
print_red(f"execute_trials: Error at executing a trial with args {failed_args}: {exc}")
|
|
9529
|
+
traceback.print_exc()
|
|
8791
9530
|
|
|
8792
9531
|
end_time = time.time()
|
|
8793
9532
|
|
|
9533
|
+
log_data()
|
|
9534
|
+
|
|
8794
9535
|
duration = float(end_time - start_time)
|
|
8795
9536
|
job_submit_durations.append(duration)
|
|
8796
9537
|
job_submit_nrs.append(cnt)
|
|
@@ -8824,20 +9565,20 @@ def create_and_execute_next_runs(next_nr_steps: int, phase: Optional[str], _max_
|
|
|
8824
9565
|
done_optimizing: bool = False
|
|
8825
9566
|
|
|
8826
9567
|
try:
|
|
8827
|
-
done_optimizing, trial_index_to_param =
|
|
8828
|
-
|
|
9568
|
+
done_optimizing, trial_index_to_param = create_and_execute_next_runs_run_loop(_max_eval, phase)
|
|
9569
|
+
create_and_execute_next_runs_finish(done_optimizing)
|
|
8829
9570
|
except Exception as e:
|
|
8830
9571
|
stacktrace = traceback.format_exc()
|
|
8831
9572
|
print_debug(f"Warning: create_and_execute_next_runs encountered an exception: {e}\n{stacktrace}")
|
|
8832
9573
|
return handle_exceptions_create_and_execute_next_runs(e)
|
|
8833
9574
|
|
|
8834
|
-
return
|
|
9575
|
+
return create_and_execute_next_runs_return_value(trial_index_to_param)
|
|
8835
9576
|
|
|
8836
|
-
def
|
|
9577
|
+
def create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
|
|
8837
9578
|
done_optimizing = False
|
|
8838
9579
|
trial_index_to_param: Optional[Dict] = None
|
|
8839
9580
|
|
|
8840
|
-
nr_of_jobs_to_get =
|
|
9581
|
+
nr_of_jobs_to_get = calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
|
|
8841
9582
|
|
|
8842
9583
|
__max_eval = _max_eval if _max_eval is not None else 0
|
|
8843
9584
|
new_nr_of_jobs_to_get = min(__max_eval - (submitted_jobs() - failed_jobs()), nr_of_jobs_to_get)
|
|
@@ -8851,6 +9592,7 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
|
|
|
8851
9592
|
|
|
8852
9593
|
for _ in range(range_nr):
|
|
8853
9594
|
trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
|
|
9595
|
+
log_data()
|
|
8854
9596
|
if done_optimizing:
|
|
8855
9597
|
continue
|
|
8856
9598
|
|
|
@@ -8875,13 +9617,13 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
|
|
|
8875
9617
|
|
|
8876
9618
|
return done_optimizing, trial_index_to_param
|
|
8877
9619
|
|
|
8878
|
-
def
|
|
9620
|
+
def create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
|
|
8879
9621
|
finish_previous_jobs(["finishing jobs"])
|
|
8880
9622
|
|
|
8881
9623
|
if done_optimizing:
|
|
8882
9624
|
end_program(False, 0)
|
|
8883
9625
|
|
|
8884
|
-
def
|
|
9626
|
+
def create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
|
|
8885
9627
|
try:
|
|
8886
9628
|
if trial_index_to_param:
|
|
8887
9629
|
res = len(trial_index_to_param.keys())
|
|
@@ -8992,7 +9734,7 @@ def execute_nvidia_smi() -> None:
|
|
|
8992
9734
|
if not host:
|
|
8993
9735
|
print_debug("host not defined")
|
|
8994
9736
|
except Exception as e:
|
|
8995
|
-
|
|
9737
|
+
print_red(f"execute_nvidia_smi: An error occurred: {e}")
|
|
8996
9738
|
if is_slurm_job() and not args.force_local_execution:
|
|
8997
9739
|
_sleep(30)
|
|
8998
9740
|
|
|
@@ -9030,11 +9772,6 @@ def run_search() -> bool:
|
|
|
9030
9772
|
|
|
9031
9773
|
return False
|
|
9032
9774
|
|
|
9033
|
-
async def start_logging_daemon() -> None:
|
|
9034
|
-
while True:
|
|
9035
|
-
log_data()
|
|
9036
|
-
time.sleep(30)
|
|
9037
|
-
|
|
9038
9775
|
def should_break_search() -> bool:
|
|
9039
9776
|
ret = False
|
|
9040
9777
|
|
|
@@ -9083,10 +9820,8 @@ def check_search_space_exhaustion(nr_of_items: int) -> bool:
|
|
|
9083
9820
|
print_debug(_wrn)
|
|
9084
9821
|
progressbar_description(_wrn)
|
|
9085
9822
|
|
|
9086
|
-
live_share()
|
|
9087
9823
|
return True
|
|
9088
9824
|
|
|
9089
|
-
live_share()
|
|
9090
9825
|
return False
|
|
9091
9826
|
|
|
9092
9827
|
def finalize_jobs() -> None:
|
|
@@ -9100,7 +9835,7 @@ def finalize_jobs() -> None:
|
|
|
9100
9835
|
handle_slurm_execution()
|
|
9101
9836
|
|
|
9102
9837
|
def go_through_jobs_that_are_not_completed_yet() -> None:
|
|
9103
|
-
print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
|
|
9838
|
+
#print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
|
|
9104
9839
|
|
|
9105
9840
|
nr_jobs_left = len(global_vars['jobs'])
|
|
9106
9841
|
if nr_jobs_left == 1:
|
|
@@ -9170,7 +9905,7 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
|
|
|
9170
9905
|
|
|
9171
9906
|
return data
|
|
9172
9907
|
except Exception as e:
|
|
9173
|
-
|
|
9908
|
+
print_red(f"Error while parse_experiment_parameters({_f}): {e}")
|
|
9174
9909
|
else:
|
|
9175
9910
|
print_red(f"{_f} could not be found")
|
|
9176
9911
|
|
|
@@ -9178,364 +9913,52 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
|
|
|
9178
9913
|
|
|
9179
9914
|
def set_orchestrator() -> None:
|
|
9180
9915
|
with spinner("Setting orchestrator..."):
|
|
9181
|
-
global orchestrator
|
|
9182
|
-
|
|
9183
|
-
if args.orchestrator_file:
|
|
9184
|
-
if SYSTEM_HAS_SBATCH:
|
|
9185
|
-
orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
|
|
9186
|
-
else:
|
|
9187
|
-
print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
|
|
9188
|
-
|
|
9189
|
-
def check_if_has_random_steps() -> None:
|
|
9190
|
-
if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
|
|
9191
|
-
_fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
|
|
9192
|
-
|
|
9193
|
-
def add_exclude_to_defective_nodes() -> None:
|
|
9194
|
-
with spinner("Adding excluded nodes..."):
|
|
9195
|
-
if args.exclude:
|
|
9196
|
-
entries = [entry.strip() for entry in args.exclude.split(',')]
|
|
9197
|
-
|
|
9198
|
-
for entry in entries:
|
|
9199
|
-
count_defective_nodes(None, entry)
|
|
9200
|
-
|
|
9201
|
-
def check_max_eval(_max_eval: int) -> None:
|
|
9202
|
-
with spinner("Checking max_eval..."):
|
|
9203
|
-
if not _max_eval:
|
|
9204
|
-
_fatal_error("--max_eval needs to be set!", 19)
|
|
9205
|
-
|
|
9206
|
-
def parse_parameters() -> Any:
|
|
9207
|
-
cli_params_experiment_parameters = None
|
|
9208
|
-
if args.parameter:
|
|
9209
|
-
parse_experiment_parameters()
|
|
9210
|
-
cli_params_experiment_parameters = experiment_parameters
|
|
9211
|
-
|
|
9212
|
-
return cli_params_experiment_parameters
|
|
9213
|
-
|
|
9214
|
-
def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
|
|
9215
|
-
table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
|
|
9216
|
-
|
|
9217
|
-
rows = _pareto_front_table_read_csv()
|
|
9218
|
-
if not rows:
|
|
9219
|
-
table.add_column("No data found")
|
|
9220
|
-
return table
|
|
9221
|
-
|
|
9222
|
-
filtered_rows = _pareto_front_table_filter_rows(rows, idxs)
|
|
9223
|
-
if not filtered_rows:
|
|
9224
|
-
table.add_column("No matching entries")
|
|
9225
|
-
return table
|
|
9226
|
-
|
|
9227
|
-
param_cols, result_cols = _pareto_front_table_get_columns(filtered_rows[0])
|
|
9228
|
-
|
|
9229
|
-
_pareto_front_table_add_headers(table, param_cols, result_cols)
|
|
9230
|
-
_pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
|
|
9231
|
-
|
|
9232
|
-
return table
|
|
9233
|
-
|
|
9234
|
-
def _pareto_front_table_read_csv() -> List[Dict[str, str]]:
|
|
9235
|
-
with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
|
|
9236
|
-
return list(csv.DictReader(f))
|
|
9237
|
-
|
|
9238
|
-
def _pareto_front_table_filter_rows(rows: List[Dict[str, str]], idxs: List[int]) -> List[Dict[str, str]]:
|
|
9239
|
-
result = []
|
|
9240
|
-
for row in rows:
|
|
9241
|
-
try:
|
|
9242
|
-
trial_index = int(row["trial_index"])
|
|
9243
|
-
except (KeyError, ValueError):
|
|
9244
|
-
continue
|
|
9245
|
-
|
|
9246
|
-
if row.get("trial_status", "").strip().upper() == "COMPLETED" and trial_index in idxs:
|
|
9247
|
-
result.append(row)
|
|
9248
|
-
return result
|
|
9249
|
-
|
|
9250
|
-
def _pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
|
|
9251
|
-
all_columns = list(first_row.keys())
|
|
9252
|
-
ignored_cols = set(special_col_names) - {"trial_index"}
|
|
9253
|
-
|
|
9254
|
-
param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
|
|
9255
|
-
result_cols = [col for col in arg_result_names if col in all_columns]
|
|
9256
|
-
return param_cols, result_cols
|
|
9257
|
-
|
|
9258
|
-
def _pareto_front_table_add_headers(table: Table, param_cols: List[str], result_cols: List[str]) -> None:
|
|
9259
|
-
for col in param_cols:
|
|
9260
|
-
table.add_column(col, justify="center")
|
|
9261
|
-
for col in result_cols:
|
|
9262
|
-
table.add_column(Text(f"{col}", style="cyan"), justify="center")
|
|
9263
|
-
|
|
9264
|
-
def _pareto_front_table_add_rows(table: Table, rows: List[Dict[str, str]], param_cols: List[str], result_cols: List[str]) -> None:
|
|
9265
|
-
for row in rows:
|
|
9266
|
-
values = [str(helpers.to_int_when_possible(row[col])) for col in param_cols]
|
|
9267
|
-
result_values = [Text(str(helpers.to_int_when_possible(row[col])), style="cyan") for col in result_cols]
|
|
9268
|
-
table.add_row(*values, *result_values, style="bold green")
|
|
9269
|
-
|
|
9270
|
-
def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
|
|
9271
|
-
if not os.path.exists(RESULT_CSV_FILE):
|
|
9272
|
-
print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
|
|
9273
|
-
return None
|
|
9274
|
-
|
|
9275
|
-
return create_pareto_front_table(idxs, metric_x, metric_y)
|
|
9276
|
-
|
|
9277
|
-
def supports_sixel() -> bool:
|
|
9278
|
-
term = os.environ.get("TERM", "").lower()
|
|
9279
|
-
if "xterm" in term or "mlterm" in term:
|
|
9280
|
-
return True
|
|
9281
|
-
|
|
9282
|
-
try:
|
|
9283
|
-
output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
|
|
9284
|
-
if output.returncode == 0 and "sixel" in output.stdout.lower():
|
|
9285
|
-
return True
|
|
9286
|
-
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
9287
|
-
pass
|
|
9288
|
-
|
|
9289
|
-
return False
|
|
9290
|
-
|
|
9291
|
-
def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
|
|
9292
|
-
if data is None:
|
|
9293
|
-
print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
|
|
9294
|
-
return
|
|
9295
|
-
|
|
9296
|
-
if not supports_sixel():
|
|
9297
|
-
print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
|
|
9298
|
-
return
|
|
9299
|
-
|
|
9300
|
-
import matplotlib.pyplot as plt
|
|
9301
|
-
|
|
9302
|
-
means = data[x_metric][y_metric]["means"]
|
|
9303
|
-
|
|
9304
|
-
x_values = means[x_metric]
|
|
9305
|
-
y_values = means[y_metric]
|
|
9306
|
-
|
|
9307
|
-
fig, _ax = plt.subplots()
|
|
9308
|
-
|
|
9309
|
-
_ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
|
|
9310
|
-
|
|
9311
|
-
_ax.set_xlabel(x_metric)
|
|
9312
|
-
_ax.set_ylabel(y_metric)
|
|
9313
|
-
|
|
9314
|
-
_ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
|
|
9315
|
-
|
|
9316
|
-
_ax.ticklabel_format(style='plain', axis='both', useOffset=False)
|
|
9317
|
-
|
|
9318
|
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
|
|
9319
|
-
plt.savefig(tmp_file.name, dpi=300)
|
|
9320
|
-
|
|
9321
|
-
print_image_to_cli(tmp_file.name, 1000)
|
|
9322
|
-
|
|
9323
|
-
plt.close(fig)
|
|
9324
|
-
|
|
9325
|
-
def _pareto_front_general_validate_shapes(x: np.ndarray, y: np.ndarray) -> None:
|
|
9326
|
-
if x.shape != y.shape:
|
|
9327
|
-
raise ValueError("Input arrays x and y must have the same shape.")
|
|
9328
|
-
|
|
9329
|
-
def _pareto_front_general_compare(
|
|
9330
|
-
xi: float, yi: float, xj: float, yj: float,
|
|
9331
|
-
x_minimize: bool, y_minimize: bool
|
|
9332
|
-
) -> bool:
|
|
9333
|
-
x_better_eq = xj <= xi if x_minimize else xj >= xi
|
|
9334
|
-
y_better_eq = yj <= yi if y_minimize else yj >= yi
|
|
9335
|
-
x_strictly_better = xj < xi if x_minimize else xj > xi
|
|
9336
|
-
y_strictly_better = yj < yi if y_minimize else yj > yi
|
|
9337
|
-
|
|
9338
|
-
return bool(x_better_eq and y_better_eq and (x_strictly_better or y_strictly_better))
|
|
9339
|
-
|
|
9340
|
-
def _pareto_front_general_find_dominated(
|
|
9341
|
-
x: np.ndarray, y: np.ndarray, x_minimize: bool, y_minimize: bool
|
|
9342
|
-
) -> np.ndarray:
|
|
9343
|
-
num_points = len(x)
|
|
9344
|
-
is_dominated = np.zeros(num_points, dtype=bool)
|
|
9345
|
-
|
|
9346
|
-
for i in range(num_points):
|
|
9347
|
-
for j in range(num_points):
|
|
9348
|
-
if i == j:
|
|
9349
|
-
continue
|
|
9350
|
-
|
|
9351
|
-
if _pareto_front_general_compare(x[i], y[i], x[j], y[j], x_minimize, y_minimize):
|
|
9352
|
-
is_dominated[i] = True
|
|
9353
|
-
break
|
|
9354
|
-
|
|
9355
|
-
return is_dominated
|
|
9356
|
-
|
|
9357
|
-
def pareto_front_general(
|
|
9358
|
-
x: np.ndarray,
|
|
9359
|
-
y: np.ndarray,
|
|
9360
|
-
x_minimize: bool = True,
|
|
9361
|
-
y_minimize: bool = True
|
|
9362
|
-
) -> np.ndarray:
|
|
9363
|
-
try:
|
|
9364
|
-
_pareto_front_general_validate_shapes(x, y)
|
|
9365
|
-
is_dominated = _pareto_front_general_find_dominated(x, y, x_minimize, y_minimize)
|
|
9366
|
-
return np.where(~is_dominated)[0]
|
|
9367
|
-
except Exception as e:
|
|
9368
|
-
print("Error in pareto_front_general:", str(e))
|
|
9369
|
-
return np.array([], dtype=int)
|
|
9370
|
-
|
|
9371
|
-
def _pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
|
|
9372
|
-
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
9373
|
-
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
9374
|
-
|
|
9375
|
-
if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
|
|
9376
|
-
return None
|
|
9377
|
-
|
|
9378
|
-
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
9379
|
-
result_names = [line.strip() for line in f if line.strip()]
|
|
9380
|
-
|
|
9381
|
-
records: dict = defaultdict(lambda: {'means': {}})
|
|
9382
|
-
|
|
9383
|
-
with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
|
|
9384
|
-
reader = csv.DictReader(csvfile)
|
|
9385
|
-
for row in reader:
|
|
9386
|
-
trial_index = int(row['trial_index'])
|
|
9387
|
-
arm_name = row['arm_name']
|
|
9388
|
-
key = (trial_index, arm_name)
|
|
9389
|
-
|
|
9390
|
-
for metric in result_names:
|
|
9391
|
-
if metric in row:
|
|
9392
|
-
try:
|
|
9393
|
-
records[key]['means'][metric] = float(row[metric])
|
|
9394
|
-
except ValueError:
|
|
9395
|
-
continue
|
|
9396
|
-
|
|
9397
|
-
return records
|
|
9398
|
-
|
|
9399
|
-
def _pareto_front_filter_complete_points(
|
|
9400
|
-
path_to_calculate: str,
|
|
9401
|
-
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
9402
|
-
primary_name: str,
|
|
9403
|
-
secondary_name: str
|
|
9404
|
-
) -> List[Tuple[Tuple[int, str], float, float]]:
|
|
9405
|
-
points = []
|
|
9406
|
-
for key, metrics in records.items():
|
|
9407
|
-
means = metrics['means']
|
|
9408
|
-
if primary_name in means and secondary_name in means:
|
|
9409
|
-
x_val = means[primary_name]
|
|
9410
|
-
y_val = means[secondary_name]
|
|
9411
|
-
points.append((key, x_val, y_val))
|
|
9412
|
-
if len(points) == 0:
|
|
9413
|
-
raise ValueError(f"No full data points with both objectives found in {path_to_calculate}.")
|
|
9414
|
-
return points
|
|
9415
|
-
|
|
9416
|
-
def _pareto_front_transform_objectives(
|
|
9417
|
-
points: List[Tuple[Any, float, float]],
|
|
9418
|
-
primary_name: str,
|
|
9419
|
-
secondary_name: str
|
|
9420
|
-
) -> Tuple[np.ndarray, np.ndarray]:
|
|
9421
|
-
primary_idx = arg_result_names.index(primary_name)
|
|
9422
|
-
secondary_idx = arg_result_names.index(secondary_name)
|
|
9423
|
-
|
|
9424
|
-
x = np.array([p[1] for p in points])
|
|
9425
|
-
y = np.array([p[2] for p in points])
|
|
9426
|
-
|
|
9427
|
-
if arg_result_min_or_max[primary_idx] == "max":
|
|
9428
|
-
x = -x
|
|
9429
|
-
elif arg_result_min_or_max[primary_idx] != "min":
|
|
9430
|
-
raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
|
|
9431
|
-
|
|
9432
|
-
if arg_result_min_or_max[secondary_idx] == "max":
|
|
9433
|
-
y = -y
|
|
9434
|
-
elif arg_result_min_or_max[secondary_idx] != "min":
|
|
9435
|
-
raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
|
|
9436
|
-
|
|
9437
|
-
return x, y
|
|
9438
|
-
|
|
9439
|
-
def _pareto_front_select_pareto_points(
|
|
9440
|
-
x: np.ndarray,
|
|
9441
|
-
y: np.ndarray,
|
|
9442
|
-
x_minimize: bool,
|
|
9443
|
-
y_minimize: bool,
|
|
9444
|
-
points: List[Tuple[Any, float, float]],
|
|
9445
|
-
num_points: int
|
|
9446
|
-
) -> List[Tuple[Any, float, float]]:
|
|
9447
|
-
indices = pareto_front_general(x, y, x_minimize, y_minimize)
|
|
9448
|
-
sorted_indices = indices[np.argsort(x[indices])]
|
|
9449
|
-
sorted_indices = sorted_indices[:num_points]
|
|
9450
|
-
selected_points = [points[i] for i in sorted_indices]
|
|
9451
|
-
return selected_points
|
|
9452
|
-
|
|
9453
|
-
def _pareto_front_build_return_structure(
|
|
9454
|
-
path_to_calculate: str,
|
|
9455
|
-
selected_points: List[Tuple[Any, float, float]],
|
|
9456
|
-
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
9457
|
-
absolute_metrics: List[str],
|
|
9458
|
-
primary_name: str,
|
|
9459
|
-
secondary_name: str
|
|
9460
|
-
) -> dict:
|
|
9461
|
-
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
9462
|
-
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
9463
|
-
|
|
9464
|
-
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
9465
|
-
result_names = [line.strip() for line in f if line.strip()]
|
|
9466
|
-
|
|
9467
|
-
csv_rows = {}
|
|
9468
|
-
with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
|
|
9469
|
-
reader = csv.DictReader(csvfile)
|
|
9470
|
-
for row in reader:
|
|
9471
|
-
trial_index = int(row['trial_index'])
|
|
9472
|
-
csv_rows[trial_index] = row
|
|
9473
|
-
|
|
9474
|
-
ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
|
|
9475
|
-
ignored_columns.update(result_names)
|
|
9476
|
-
|
|
9477
|
-
param_dicts = []
|
|
9478
|
-
idxs = []
|
|
9479
|
-
means_dict = defaultdict(list)
|
|
9480
|
-
|
|
9481
|
-
for (trial_index, arm_name), _, _ in selected_points:
|
|
9482
|
-
row = csv_rows.get(trial_index, {})
|
|
9483
|
-
if row == {} or row is None or row['arm_name'] != arm_name:
|
|
9484
|
-
print_debug(f"_pareto_front_build_return_structure: trial_index '{trial_index}' could not be found and row returned as None")
|
|
9485
|
-
continue
|
|
9916
|
+
global orchestrator
|
|
9486
9917
|
|
|
9487
|
-
|
|
9918
|
+
if args.orchestrator_file:
|
|
9919
|
+
if SYSTEM_HAS_SBATCH:
|
|
9920
|
+
orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
|
|
9921
|
+
else:
|
|
9922
|
+
print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
|
|
9488
9923
|
|
|
9489
|
-
|
|
9490
|
-
|
|
9491
|
-
|
|
9492
|
-
try:
|
|
9493
|
-
param_dict[key] = int(value)
|
|
9494
|
-
except ValueError:
|
|
9495
|
-
try:
|
|
9496
|
-
param_dict[key] = float(value)
|
|
9497
|
-
except ValueError:
|
|
9498
|
-
param_dict[key] = value
|
|
9924
|
+
def check_if_has_random_steps() -> None:
|
|
9925
|
+
if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
|
|
9926
|
+
_fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
|
|
9499
9927
|
|
|
9500
|
-
|
|
9928
|
+
def add_exclude_to_defective_nodes() -> None:
|
|
9929
|
+
with spinner("Adding excluded nodes..."):
|
|
9930
|
+
if args.exclude:
|
|
9931
|
+
entries = [entry.strip() for entry in args.exclude.split(',')]
|
|
9501
9932
|
|
|
9502
|
-
|
|
9503
|
-
|
|
9933
|
+
for entry in entries:
|
|
9934
|
+
count_defective_nodes(None, entry)
|
|
9504
9935
|
|
|
9505
|
-
|
|
9506
|
-
|
|
9507
|
-
|
|
9508
|
-
|
|
9509
|
-
"param_dicts": param_dicts,
|
|
9510
|
-
"means": dict(means_dict),
|
|
9511
|
-
"idxs": idxs
|
|
9512
|
-
},
|
|
9513
|
-
"absolute_metrics": absolute_metrics
|
|
9514
|
-
}
|
|
9515
|
-
}
|
|
9936
|
+
def check_max_eval(_max_eval: int) -> None:
|
|
9937
|
+
with spinner("Checking max_eval..."):
|
|
9938
|
+
if not _max_eval:
|
|
9939
|
+
_fatal_error("--max_eval needs to be set!", 19)
|
|
9516
9940
|
|
|
9517
|
-
|
|
9941
|
+
def parse_parameters() -> Any:
|
|
9942
|
+
cli_params_experiment_parameters = None
|
|
9943
|
+
if args.parameter:
|
|
9944
|
+
parse_experiment_parameters()
|
|
9945
|
+
cli_params_experiment_parameters = experiment_parameters
|
|
9518
9946
|
|
|
9519
|
-
|
|
9520
|
-
path_to_calculate: str,
|
|
9521
|
-
primary_objective: str,
|
|
9522
|
-
secondary_objective: str,
|
|
9523
|
-
x_minimize: bool,
|
|
9524
|
-
y_minimize: bool,
|
|
9525
|
-
absolute_metrics: List[str],
|
|
9526
|
-
num_points: int
|
|
9527
|
-
) -> Optional[dict]:
|
|
9528
|
-
records = _pareto_front_aggregate_data(path_to_calculate)
|
|
9947
|
+
return cli_params_experiment_parameters
|
|
9529
9948
|
|
|
9530
|
-
|
|
9531
|
-
|
|
9949
|
+
def supports_sixel() -> bool:
|
|
9950
|
+
term = os.environ.get("TERM", "").lower()
|
|
9951
|
+
if "xterm" in term or "mlterm" in term:
|
|
9952
|
+
return True
|
|
9532
9953
|
|
|
9533
|
-
|
|
9534
|
-
|
|
9535
|
-
|
|
9536
|
-
|
|
9954
|
+
try:
|
|
9955
|
+
output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
|
|
9956
|
+
if output.returncode == 0 and "sixel" in output.stdout.lower():
|
|
9957
|
+
return True
|
|
9958
|
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
9959
|
+
pass
|
|
9537
9960
|
|
|
9538
|
-
return
|
|
9961
|
+
return False
|
|
9539
9962
|
|
|
9540
9963
|
def save_experiment_state() -> None:
|
|
9541
9964
|
try:
|
|
@@ -9543,14 +9966,14 @@ def save_experiment_state() -> None:
|
|
|
9543
9966
|
print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
|
|
9544
9967
|
return
|
|
9545
9968
|
state_path = get_current_run_folder("experiment_state.json")
|
|
9546
|
-
|
|
9969
|
+
save_ax_client_to_json_file(state_path)
|
|
9547
9970
|
except Exception as e:
|
|
9548
|
-
|
|
9971
|
+
print_debug(f"Error saving experiment state: {e}")
|
|
9549
9972
|
|
|
9550
9973
|
def wait_for_state_file(state_path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
|
|
9551
9974
|
try:
|
|
9552
9975
|
if not os.path.exists(state_path):
|
|
9553
|
-
|
|
9976
|
+
print_debug(f"[ERROR] File '{state_path}' does not exist.")
|
|
9554
9977
|
return False
|
|
9555
9978
|
|
|
9556
9979
|
i = 0
|
|
@@ -9618,7 +10041,7 @@ def load_experiment_state() -> None:
|
|
|
9618
10041
|
return
|
|
9619
10042
|
|
|
9620
10043
|
try:
|
|
9621
|
-
arms_seen = {}
|
|
10044
|
+
arms_seen: dict = {}
|
|
9622
10045
|
for arm in data.get("arms", []):
|
|
9623
10046
|
name = arm.get("name")
|
|
9624
10047
|
sig = arm.get("parameters")
|
|
@@ -9769,38 +10192,47 @@ def get_result_minimize_flag(path_to_calculate: str, resname: str) -> bool:
|
|
|
9769
10192
|
|
|
9770
10193
|
return minmax[index] == "min"
|
|
9771
10194
|
|
|
9772
|
-
def
|
|
9773
|
-
|
|
10195
|
+
def post_job_calculate_pareto_front() -> None:
|
|
10196
|
+
if not args.calculate_pareto_front_of_job:
|
|
10197
|
+
return
|
|
9774
10198
|
|
|
9775
|
-
|
|
10199
|
+
failure = False
|
|
9776
10200
|
|
|
9777
|
-
|
|
10201
|
+
_paths_to_calculate = []
|
|
9778
10202
|
|
|
9779
|
-
for
|
|
9780
|
-
|
|
9781
|
-
|
|
9782
|
-
metric_y = arg_result_names[j]
|
|
10203
|
+
for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
|
|
10204
|
+
try:
|
|
10205
|
+
found_paths = find_results_paths(_path_to_calculate)
|
|
9783
10206
|
|
|
9784
|
-
|
|
9785
|
-
|
|
10207
|
+
for _fp in found_paths:
|
|
10208
|
+
if _fp not in _paths_to_calculate:
|
|
10209
|
+
_paths_to_calculate.append(_fp)
|
|
10210
|
+
except (FileNotFoundError, NotADirectoryError) as e:
|
|
10211
|
+
print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
|
|
9786
10212
|
|
|
9787
|
-
|
|
9788
|
-
if metric_x not in pareto_front_data:
|
|
9789
|
-
pareto_front_data[metric_x] = {}
|
|
10213
|
+
failure = True
|
|
9790
10214
|
|
|
9791
|
-
|
|
9792
|
-
|
|
9793
|
-
|
|
9794
|
-
|
|
9795
|
-
print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
|
|
9796
|
-
skip = True
|
|
10215
|
+
for _path_to_calculate in _paths_to_calculate:
|
|
10216
|
+
for path_to_calculate in found_paths:
|
|
10217
|
+
if not job_calculate_pareto_front(path_to_calculate):
|
|
10218
|
+
failure = True
|
|
9797
10219
|
|
|
9798
|
-
|
|
10220
|
+
if failure:
|
|
10221
|
+
my_exit(24)
|
|
10222
|
+
|
|
10223
|
+
my_exit(0)
|
|
10224
|
+
|
|
10225
|
+
def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
|
|
10226
|
+
if not os.path.exists(RESULT_CSV_FILE):
|
|
10227
|
+
print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
|
|
10228
|
+
return None
|
|
10229
|
+
|
|
10230
|
+
return create_pareto_front_table(idxs, metric_x, metric_y)
|
|
9799
10231
|
|
|
9800
10232
|
def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_sixel_and_table: bool = False) -> None:
|
|
9801
10233
|
if len(res_names) <= 1:
|
|
9802
10234
|
print_debug(f"--result_names (has {len(res_names)} entries) must be at least 2.")
|
|
9803
|
-
return
|
|
10235
|
+
return None
|
|
9804
10236
|
|
|
9805
10237
|
pareto_front_data: dict = get_pareto_front_data(path_to_calculate, res_names)
|
|
9806
10238
|
|
|
@@ -9821,8 +10253,16 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
|
|
|
9821
10253
|
else:
|
|
9822
10254
|
print(f"Not showing Pareto-front-sixel for {path_to_calculate}")
|
|
9823
10255
|
|
|
9824
|
-
if
|
|
9825
|
-
|
|
10256
|
+
if calculated_frontier is None:
|
|
10257
|
+
print_debug("ERROR: calculated_frontier is None")
|
|
10258
|
+
return None
|
|
10259
|
+
|
|
10260
|
+
try:
|
|
10261
|
+
if len(calculated_frontier[metric_x][metric_y]["idxs"]):
|
|
10262
|
+
pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
|
|
10263
|
+
except AttributeError:
|
|
10264
|
+
print_debug(f"ERROR: calculated_frontier structure invalid for ({metric_x}, {metric_y})")
|
|
10265
|
+
return None
|
|
9826
10266
|
|
|
9827
10267
|
rich_table = pareto_front_as_rich_table(
|
|
9828
10268
|
calculated_frontier[metric_x][metric_y]["idxs"],
|
|
@@ -9847,6 +10287,8 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
|
|
|
9847
10287
|
|
|
9848
10288
|
live_share_after_pareto()
|
|
9849
10289
|
|
|
10290
|
+
return None
|
|
10291
|
+
|
|
9850
10292
|
def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_color: str) -> None:
|
|
9851
10293
|
cpu_count = os.cpu_count()
|
|
9852
10294
|
|
|
@@ -9858,9 +10300,11 @@ def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_
|
|
|
9858
10300
|
pass
|
|
9859
10301
|
|
|
9860
10302
|
if gpu_string:
|
|
9861
|
-
console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]
|
|
10303
|
+
console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]")
|
|
9862
10304
|
else:
|
|
9863
|
-
print_green(f"You have {cpu_count} CPUs available for the main process.
|
|
10305
|
+
print_green(f"You have {cpu_count} CPUs available for the main process.")
|
|
10306
|
+
|
|
10307
|
+
print_green(gs_string)
|
|
9864
10308
|
|
|
9865
10309
|
def write_args_overview_table() -> None:
|
|
9866
10310
|
table = Table(title="Arguments Overview")
|
|
@@ -10127,112 +10571,6 @@ def find_results_paths(base_path: str) -> list:
|
|
|
10127
10571
|
|
|
10128
10572
|
return list(set(found_paths))
|
|
10129
10573
|
|
|
10130
|
-
def post_job_calculate_pareto_front() -> None:
|
|
10131
|
-
if not args.calculate_pareto_front_of_job:
|
|
10132
|
-
return
|
|
10133
|
-
|
|
10134
|
-
failure = False
|
|
10135
|
-
|
|
10136
|
-
_paths_to_calculate = []
|
|
10137
|
-
|
|
10138
|
-
for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
|
|
10139
|
-
try:
|
|
10140
|
-
found_paths = find_results_paths(_path_to_calculate)
|
|
10141
|
-
|
|
10142
|
-
for _fp in found_paths:
|
|
10143
|
-
if _fp not in _paths_to_calculate:
|
|
10144
|
-
_paths_to_calculate.append(_fp)
|
|
10145
|
-
except (FileNotFoundError, NotADirectoryError) as e:
|
|
10146
|
-
print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
|
|
10147
|
-
|
|
10148
|
-
failure = True
|
|
10149
|
-
|
|
10150
|
-
for _path_to_calculate in _paths_to_calculate:
|
|
10151
|
-
for path_to_calculate in found_paths:
|
|
10152
|
-
if not job_calculate_pareto_front(path_to_calculate):
|
|
10153
|
-
failure = True
|
|
10154
|
-
|
|
10155
|
-
if failure:
|
|
10156
|
-
my_exit(24)
|
|
10157
|
-
|
|
10158
|
-
my_exit(0)
|
|
10159
|
-
|
|
10160
|
-
def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
|
|
10161
|
-
pf_start_time = time.time()
|
|
10162
|
-
|
|
10163
|
-
if not path_to_calculate:
|
|
10164
|
-
return False
|
|
10165
|
-
|
|
10166
|
-
global CURRENT_RUN_FOLDER
|
|
10167
|
-
global RESULT_CSV_FILE
|
|
10168
|
-
global arg_result_names
|
|
10169
|
-
|
|
10170
|
-
if not path_to_calculate:
|
|
10171
|
-
print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
|
|
10172
|
-
return False
|
|
10173
|
-
|
|
10174
|
-
if not os.path.exists(path_to_calculate):
|
|
10175
|
-
print_red(f"Path '{path_to_calculate}' does not exist")
|
|
10176
|
-
return False
|
|
10177
|
-
|
|
10178
|
-
ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
|
|
10179
|
-
|
|
10180
|
-
if not os.path.exists(ax_client_json):
|
|
10181
|
-
print_red(f"Path '{ax_client_json}' not found")
|
|
10182
|
-
return False
|
|
10183
|
-
|
|
10184
|
-
checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
|
|
10185
|
-
if not os.path.exists(checkpoint_file):
|
|
10186
|
-
print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
|
|
10187
|
-
return False
|
|
10188
|
-
|
|
10189
|
-
RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
10190
|
-
if not os.path.exists(RESULT_CSV_FILE):
|
|
10191
|
-
print_red(f"{RESULT_CSV_FILE} not found")
|
|
10192
|
-
return False
|
|
10193
|
-
|
|
10194
|
-
res_names = []
|
|
10195
|
-
|
|
10196
|
-
res_names_file = f"{path_to_calculate}/result_names.txt"
|
|
10197
|
-
if not os.path.exists(res_names_file):
|
|
10198
|
-
print_red(f"File '{res_names_file}' does not exist")
|
|
10199
|
-
return False
|
|
10200
|
-
|
|
10201
|
-
try:
|
|
10202
|
-
with open(res_names_file, "r", encoding="utf-8") as file:
|
|
10203
|
-
lines = file.readlines()
|
|
10204
|
-
except Exception as e:
|
|
10205
|
-
print_red(f"Error reading file '{res_names_file}': {e}")
|
|
10206
|
-
return False
|
|
10207
|
-
|
|
10208
|
-
for line in lines:
|
|
10209
|
-
entry = line.strip()
|
|
10210
|
-
if entry != "":
|
|
10211
|
-
res_names.append(entry)
|
|
10212
|
-
|
|
10213
|
-
if len(res_names) < 2:
|
|
10214
|
-
print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
|
|
10215
|
-
return False
|
|
10216
|
-
|
|
10217
|
-
load_username_to_args(path_to_calculate)
|
|
10218
|
-
|
|
10219
|
-
CURRENT_RUN_FOLDER = path_to_calculate
|
|
10220
|
-
|
|
10221
|
-
arg_result_names = res_names
|
|
10222
|
-
|
|
10223
|
-
load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
|
|
10224
|
-
|
|
10225
|
-
if experiment_parameters is None:
|
|
10226
|
-
return False
|
|
10227
|
-
|
|
10228
|
-
show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
|
|
10229
|
-
|
|
10230
|
-
pf_end_time = time.time()
|
|
10231
|
-
|
|
10232
|
-
print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
|
|
10233
|
-
|
|
10234
|
-
return True
|
|
10235
|
-
|
|
10236
10574
|
def set_arg_states_from_continue() -> None:
|
|
10237
10575
|
if args.continue_previous_job and not args.num_random_steps:
|
|
10238
10576
|
num_random_steps_file = f"{args.continue_previous_job}/state_files/num_random_steps"
|
|
@@ -10266,7 +10604,7 @@ def write_result_names_file() -> None:
|
|
|
10266
10604
|
except Exception as e:
|
|
10267
10605
|
print_red(f"Error trying to open file '{fn}': {e}")
|
|
10268
10606
|
|
|
10269
|
-
def run_program_once(params=None) -> None:
|
|
10607
|
+
def run_program_once(params: Optional[dict] = None) -> None:
|
|
10270
10608
|
if not args.run_program_once:
|
|
10271
10609
|
print_debug("[yellow]No setup script specified (run_program_once). Skipping setup.[/yellow]")
|
|
10272
10610
|
return
|
|
@@ -10275,27 +10613,27 @@ def run_program_once(params=None) -> None:
|
|
|
10275
10613
|
params = {}
|
|
10276
10614
|
|
|
10277
10615
|
if isinstance(args.run_program_once, str):
|
|
10278
|
-
command_str = args.run_program_once
|
|
10616
|
+
command_str = decode_if_base64(args.run_program_once)
|
|
10279
10617
|
for k, v in params.items():
|
|
10280
10618
|
placeholder = f"%({k})"
|
|
10281
10619
|
command_str = command_str.replace(placeholder, str(v))
|
|
10282
10620
|
|
|
10283
|
-
|
|
10284
|
-
|
|
10285
|
-
|
|
10286
|
-
|
|
10287
|
-
|
|
10288
|
-
|
|
10621
|
+
print(f"Executing command: [cyan]{command_str}[/cyan]")
|
|
10622
|
+
result = subprocess.run(command_str, shell=True, check=True)
|
|
10623
|
+
if result.returncode == 0:
|
|
10624
|
+
print("[bold green]Setup script completed successfully ✅[/bold green]")
|
|
10625
|
+
else:
|
|
10626
|
+
print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
|
|
10289
10627
|
|
|
10290
|
-
|
|
10628
|
+
my_exit(57)
|
|
10291
10629
|
|
|
10292
10630
|
elif isinstance(args.run_program_once, (list, tuple)):
|
|
10293
10631
|
with spinner("run_program_once: Executing command list: [cyan]{args.run_program_once}[/cyan]"):
|
|
10294
10632
|
result = subprocess.run(args.run_program_once, check=True)
|
|
10295
10633
|
if result.returncode == 0:
|
|
10296
|
-
|
|
10634
|
+
print("[bold green]Setup script completed successfully ✅[/bold green]")
|
|
10297
10635
|
else:
|
|
10298
|
-
|
|
10636
|
+
print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
|
|
10299
10637
|
|
|
10300
10638
|
my_exit(57)
|
|
10301
10639
|
|
|
@@ -10305,7 +10643,7 @@ def run_program_once(params=None) -> None:
|
|
|
10305
10643
|
my_exit(57)
|
|
10306
10644
|
|
|
10307
10645
|
def show_omniopt_call() -> None:
|
|
10308
|
-
def remove_ui_url(arg_str) -> str:
|
|
10646
|
+
def remove_ui_url(arg_str: str) -> str:
|
|
10309
10647
|
return re.sub(r'(?:--ui_url(?:=\S+)?(?:\s+\S+)?)', '', arg_str).strip()
|
|
10310
10648
|
|
|
10311
10649
|
original_argv = " ".join(sys.argv[1:])
|
|
@@ -10313,8 +10651,14 @@ def show_omniopt_call() -> None:
|
|
|
10313
10651
|
|
|
10314
10652
|
original_print(oo_call + " " + cleaned)
|
|
10315
10653
|
|
|
10654
|
+
if args.dependency is not None and args.dependency != "":
|
|
10655
|
+
print(f"Dependency: {args.dependency}")
|
|
10656
|
+
|
|
10657
|
+
if args.ui_url is not None and args.ui_url != "":
|
|
10658
|
+
print_yellow("--ui_url is deprecated. Do not use it anymore. It will be ignored and one day be removed.")
|
|
10659
|
+
|
|
10316
10660
|
def main() -> None:
|
|
10317
|
-
global RESULT_CSV_FILE,
|
|
10661
|
+
global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS
|
|
10318
10662
|
|
|
10319
10663
|
check_if_has_random_steps()
|
|
10320
10664
|
|
|
@@ -10396,15 +10740,13 @@ def main() -> None:
|
|
|
10396
10740
|
exp_params = get_experiment_parameters(cli_params_experiment_parameters)
|
|
10397
10741
|
|
|
10398
10742
|
if exp_params is not None:
|
|
10399
|
-
|
|
10743
|
+
experiment_args, gpu_string, gpu_color = exp_params
|
|
10400
10744
|
print_debug(f"experiment_parameters: {experiment_parameters}")
|
|
10401
10745
|
|
|
10402
10746
|
set_orchestrator()
|
|
10403
10747
|
|
|
10404
10748
|
init_live_share()
|
|
10405
10749
|
|
|
10406
|
-
start_periodic_live_share()
|
|
10407
|
-
|
|
10408
10750
|
show_available_hardware_and_generation_strategy_string(gpu_string, gpu_color)
|
|
10409
10751
|
|
|
10410
10752
|
original_print(f"Run-Program: {global_vars['joined_run_program']}")
|
|
@@ -10419,6 +10761,8 @@ def main() -> None:
|
|
|
10419
10761
|
|
|
10420
10762
|
write_files_and_show_overviews()
|
|
10421
10763
|
|
|
10764
|
+
live_share()
|
|
10765
|
+
|
|
10422
10766
|
#if args.continue_previous_job:
|
|
10423
10767
|
# insert_jobs_from_csv(f"{args.continue_previous_job}/{RESULTS_CSV_FILENAME}")
|
|
10424
10768
|
|
|
@@ -10496,21 +10840,37 @@ def initialize_nvidia_logs() -> None:
|
|
|
10496
10840
|
def build_gui_url(config: argparse.Namespace) -> str:
|
|
10497
10841
|
base_url = get_base_url()
|
|
10498
10842
|
params = collect_params(config)
|
|
10499
|
-
|
|
10843
|
+
ret = f"{base_url}?{urlencode(params, doseq=True)}"
|
|
10844
|
+
|
|
10845
|
+
return ret
|
|
10846
|
+
|
|
10847
|
+
def get_result_names_for_url(value: List) -> str:
|
|
10848
|
+
d = dict(v.split("=", 1) if "=" in v else (v, "min") for v in value)
|
|
10849
|
+
s = " ".join(f"{k}={v}" for k, v in d.items())
|
|
10850
|
+
|
|
10851
|
+
return s
|
|
10500
10852
|
|
|
10501
10853
|
def collect_params(config: argparse.Namespace) -> dict:
|
|
10502
10854
|
params = {}
|
|
10855
|
+
user_home = os.path.expanduser("~")
|
|
10856
|
+
|
|
10503
10857
|
for attr, value in vars(config).items():
|
|
10504
10858
|
if attr == "run_program":
|
|
10505
10859
|
params[attr] = global_vars["joined_run_program"]
|
|
10860
|
+
elif attr == "result_names" and value:
|
|
10861
|
+
params[attr] = get_result_names_for_url(value)
|
|
10506
10862
|
elif attr == "parameter" and value is not None:
|
|
10507
10863
|
params.update(process_parameters(config.parameter))
|
|
10864
|
+
elif attr == "root_venv_dir":
|
|
10865
|
+
if value is not None and os.path.abspath(value) != os.path.abspath(user_home):
|
|
10866
|
+
params[attr] = value
|
|
10508
10867
|
elif isinstance(value, bool):
|
|
10509
10868
|
params[attr] = int(value)
|
|
10510
10869
|
elif isinstance(value, list):
|
|
10511
10870
|
params[attr] = value
|
|
10512
10871
|
elif value is not None:
|
|
10513
10872
|
params[attr] = value
|
|
10873
|
+
|
|
10514
10874
|
return params
|
|
10515
10875
|
|
|
10516
10876
|
def process_parameters(parameters: list) -> dict:
|
|
@@ -10563,7 +10923,7 @@ def get_base_url() -> str:
|
|
|
10563
10923
|
def write_ui_url() -> None:
|
|
10564
10924
|
url = build_gui_url(args)
|
|
10565
10925
|
with open(get_current_run_folder("ui_url.txt"), mode="a", encoding="utf-8") as myfile:
|
|
10566
|
-
myfile.write(
|
|
10926
|
+
myfile.write(url)
|
|
10567
10927
|
|
|
10568
10928
|
def handle_random_steps() -> None:
|
|
10569
10929
|
if args.parameter and args.continue_previous_job and random_steps <= 0:
|
|
@@ -10622,8 +10982,6 @@ def run_search_with_progress_bar() -> None:
|
|
|
10622
10982
|
wait_for_jobs_to_complete()
|
|
10623
10983
|
|
|
10624
10984
|
def complex_tests(_program_name: str, wanted_stderr: str, wanted_exit_code: int, wanted_signal: Union[int, None], res_is_none: bool = False) -> int:
|
|
10625
|
-
#print_yellow(f"Test suite: {_program_name}")
|
|
10626
|
-
|
|
10627
10985
|
nr_errors: int = 0
|
|
10628
10986
|
|
|
10629
10987
|
program_path: str = f"./.tests/test_wronggoing_stuff.bin/bin/{_program_name}"
|
|
@@ -10697,7 +11055,7 @@ def test_find_paths(program_code: str) -> int:
|
|
|
10697
11055
|
for i in files:
|
|
10698
11056
|
if i not in string:
|
|
10699
11057
|
if os.path.exists(i):
|
|
10700
|
-
print("Missing {i} in find_file_paths string!")
|
|
11058
|
+
print(f"Missing {i} in find_file_paths string!")
|
|
10701
11059
|
nr_errors += 1
|
|
10702
11060
|
|
|
10703
11061
|
return nr_errors
|
|
@@ -11078,17 +11436,16 @@ Exit-Code: 159
|
|
|
11078
11436
|
|
|
11079
11437
|
my_exit(nr_errors)
|
|
11080
11438
|
|
|
11081
|
-
def
|
|
11439
|
+
def main_wrapper() -> None:
|
|
11082
11440
|
print(f"Run-UUID: {run_uuid}")
|
|
11083
11441
|
|
|
11084
11442
|
auto_wrap_namespace(globals())
|
|
11085
11443
|
|
|
11086
11444
|
print_logo()
|
|
11087
11445
|
|
|
11088
|
-
start_logging_daemon()
|
|
11089
|
-
|
|
11090
11446
|
fool_linter(args.num_cpus_main_job)
|
|
11091
11447
|
fool_linter(args.flame_graph)
|
|
11448
|
+
fool_linter(args.memray)
|
|
11092
11449
|
|
|
11093
11450
|
with warnings.catch_warnings():
|
|
11094
11451
|
warnings.simplefilter("ignore")
|
|
@@ -11123,7 +11480,7 @@ def main_outside() -> None:
|
|
|
11123
11480
|
def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
|
|
11124
11481
|
pattern = re.compile(regex) if regex else None
|
|
11125
11482
|
|
|
11126
|
-
def wrapped(*args, **kwargs):
|
|
11483
|
+
def wrapped(*args: Any, **kwargs: Any) -> None:
|
|
11127
11484
|
if pattern and not pattern.search(func.__name__):
|
|
11128
11485
|
return func(*args, **kwargs)
|
|
11129
11486
|
|
|
@@ -11145,6 +11502,9 @@ def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
|
|
|
11145
11502
|
def auto_wrap_namespace(namespace: Any) -> Any:
|
|
11146
11503
|
enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))
|
|
11147
11504
|
|
|
11505
|
+
if args.beartype:
|
|
11506
|
+
enable_beartype = True
|
|
11507
|
+
|
|
11148
11508
|
excluded_functions = {
|
|
11149
11509
|
"log_time_and_memory_wrapper",
|
|
11150
11510
|
"collect_runtime_stats",
|
|
@@ -11153,8 +11513,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
|
|
|
11153
11513
|
"_record_stats",
|
|
11154
11514
|
"_open",
|
|
11155
11515
|
"_check_memory_leak",
|
|
11156
|
-
"start_periodic_live_share",
|
|
11157
|
-
"start_logging_daemon",
|
|
11158
11516
|
"get_current_run_folder",
|
|
11159
11517
|
"show_func_name_wrapper"
|
|
11160
11518
|
}
|
|
@@ -11180,7 +11538,7 @@ def auto_wrap_namespace(namespace: Any) -> Any:
|
|
|
11180
11538
|
|
|
11181
11539
|
if __name__ == "__main__":
|
|
11182
11540
|
try:
|
|
11183
|
-
|
|
11541
|
+
main_wrapper()
|
|
11184
11542
|
except (SignalUSR, SignalINT, SignalCONT) as e:
|
|
11185
|
-
print_red(f"
|
|
11543
|
+
print_red(f"main_wrapper failed with exception {e}")
|
|
11186
11544
|
end_program(True)
|