omniopt2 8178__py3-none-any.whl → 9171__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- .gitignore +2 -0
- .helpers.py +0 -9
- .omniopt.py +1717 -1151
- .omniopt_plot_scatter.py +1 -1
- .omniopt_plot_scatter_hex.py +1 -1
- .omniopt_plot_trial_index_result.py +1 -0
- .pareto.py +134 -0
- .shellscript_functions +24 -15
- .tests/pylint.rc +0 -4
- .tpe.py +4 -3
- README.md +1 -1
- omniopt +92 -55
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.helpers.py +0 -9
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt.py +1717 -1151
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter.py +1 -1
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_hex.py +1 -1
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_trial_index_result.py +1 -0
- omniopt2-9171.data/data/bin/.pareto.py +134 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.shellscript_functions +24 -15
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.tpe.py +4 -3
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt +92 -55
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_docker +60 -60
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_plot +1 -1
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/pylint.rc +0 -4
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/requirements.txt +3 -4
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/test_requirements.txt +1 -0
- {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/METADATA +6 -6
- omniopt2-9171.dist-info/RECORD +73 -0
- omniopt2.egg-info/PKG-INFO +6 -6
- omniopt2.egg-info/SOURCES.txt +1 -0
- omniopt2.egg-info/requires.txt +4 -4
- omniopt_docker +60 -60
- omniopt_plot +1 -1
- pyproject.toml +1 -1
- requirements.txt +3 -4
- test_requirements.txt +1 -0
- omniopt2-8178.dist-info/RECORD +0 -71
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.colorfunctions.sh +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.general.sh +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_general.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_kde.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_worker.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.random_generator.py +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/LICENSE +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/apt-dependencies.txt +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_evaluate +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_share +0 -0
- {omniopt2-8178.data → omniopt2-9171.data}/data/bin/setup.py +0 -0
- {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/WHEEL +0 -0
- {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/licenses/LICENSE +0 -0
- {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/top_level.txt +0 -0
.omniopt.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
#from mayhemmonkey import MayhemMonkey
|
|
4
4
|
#mayhemmonkey = MayhemMonkey()
|
|
5
|
-
#mayhemmonkey.set_function_fail_after_count("open",
|
|
5
|
+
#mayhemmonkey.set_function_fail_after_count("open", 10)
|
|
6
6
|
#mayhemmonkey.set_function_error_rate("open", 0.1)
|
|
7
7
|
#mayhemmonkey.set_function_group_error_rate(["io", "math"], 0.8)
|
|
8
8
|
#mayhemmonkey.install_faulty()
|
|
@@ -26,10 +26,16 @@ import traceback
|
|
|
26
26
|
import inspect
|
|
27
27
|
import tracemalloc
|
|
28
28
|
import resource
|
|
29
|
+
from urllib.parse import urlencode
|
|
29
30
|
import psutil
|
|
30
31
|
|
|
31
32
|
FORCE_EXIT: bool = False
|
|
32
33
|
|
|
34
|
+
LAST_LOG_TIME: int = 0
|
|
35
|
+
last_msg_progressbar = ""
|
|
36
|
+
last_msg_raw = None
|
|
37
|
+
last_lock_print_debug = threading.Lock()
|
|
38
|
+
|
|
33
39
|
def force_exit(signal_number: Any, frame: Any) -> Any:
|
|
34
40
|
global FORCE_EXIT
|
|
35
41
|
|
|
@@ -59,10 +65,10 @@ _last_count_time = 0
|
|
|
59
65
|
_last_count_result: tuple[int, str] = (0, "")
|
|
60
66
|
|
|
61
67
|
_total_time = 0.0
|
|
62
|
-
_func_times = defaultdict(float)
|
|
63
|
-
_func_mem = defaultdict(float)
|
|
64
|
-
_func_call_paths = defaultdict(Counter)
|
|
65
|
-
_last_mem = defaultdict(float)
|
|
68
|
+
_func_times: defaultdict = defaultdict(float)
|
|
69
|
+
_func_mem: defaultdict = defaultdict(float)
|
|
70
|
+
_func_call_paths: defaultdict = defaultdict(Counter)
|
|
71
|
+
_last_mem: defaultdict = defaultdict(float)
|
|
66
72
|
_leak_threshold_mb = 10.0
|
|
67
73
|
generation_strategy_names: list = []
|
|
68
74
|
default_max_range_difference: int = 1000000
|
|
@@ -82,6 +88,7 @@ log_nr_gen_jobs: list[int] = []
|
|
|
82
88
|
generation_strategy_human_readable: str = ""
|
|
83
89
|
oo_call: str = "./omniopt"
|
|
84
90
|
progress_bar_length: int = 0
|
|
91
|
+
worker_usage_file = 'worker_usage.csv'
|
|
85
92
|
|
|
86
93
|
if os.environ.get("CUSTOM_VIRTUAL_ENV") == "1":
|
|
87
94
|
oo_call = "omniopt"
|
|
@@ -97,7 +104,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
|
|
|
97
104
|
SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
|
|
98
105
|
joined_supported_models: str = ", ".join(SUPPORTED_MODELS)
|
|
99
106
|
|
|
100
|
-
special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
|
|
107
|
+
special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]
|
|
101
108
|
|
|
102
109
|
IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names
|
|
103
110
|
|
|
@@ -151,12 +158,6 @@ try:
|
|
|
151
158
|
message="Ax currently requires a sqlalchemy version below 2.0.*",
|
|
152
159
|
)
|
|
153
160
|
|
|
154
|
-
warnings.filterwarnings(
|
|
155
|
-
"ignore",
|
|
156
|
-
category=RuntimeWarning,
|
|
157
|
-
message="coroutine 'start_logging_daemon' was never awaited"
|
|
158
|
-
)
|
|
159
|
-
|
|
160
161
|
with spinner("Importing argparse..."):
|
|
161
162
|
import argparse
|
|
162
163
|
|
|
@@ -202,6 +203,9 @@ try:
|
|
|
202
203
|
with spinner("Importing rich.pretty..."):
|
|
203
204
|
from rich.pretty import pprint
|
|
204
205
|
|
|
206
|
+
with spinner("Importing pformat..."):
|
|
207
|
+
from pprint import pformat
|
|
208
|
+
|
|
205
209
|
with spinner("Importing rich.prompt..."):
|
|
206
210
|
from rich.prompt import Prompt, FloatPrompt, IntPrompt
|
|
207
211
|
|
|
@@ -235,9 +239,6 @@ try:
|
|
|
235
239
|
with spinner("Importing uuid..."):
|
|
236
240
|
import uuid
|
|
237
241
|
|
|
238
|
-
#with spinner("Importing qrcode..."):
|
|
239
|
-
# import qrcode
|
|
240
|
-
|
|
241
242
|
with spinner("Importing cowsay..."):
|
|
242
243
|
import cowsay
|
|
243
244
|
|
|
@@ -268,6 +269,9 @@ try:
|
|
|
268
269
|
with spinner("Importing beartype..."):
|
|
269
270
|
from beartype import beartype
|
|
270
271
|
|
|
272
|
+
with spinner("Importing rendering stuff..."):
|
|
273
|
+
from ax.plot.base import AxPlotConfig
|
|
274
|
+
|
|
271
275
|
with spinner("Importing statistics..."):
|
|
272
276
|
import statistics
|
|
273
277
|
|
|
@@ -331,7 +335,7 @@ def show_func_name_wrapper(func: F) -> F:
|
|
|
331
335
|
|
|
332
336
|
return result
|
|
333
337
|
|
|
334
|
-
return wrapper
|
|
338
|
+
return wrapper # type: ignore
|
|
335
339
|
|
|
336
340
|
def log_time_and_memory_wrapper(func: F) -> F:
|
|
337
341
|
@functools.wraps(func)
|
|
@@ -358,7 +362,7 @@ def log_time_and_memory_wrapper(func: F) -> F:
|
|
|
358
362
|
|
|
359
363
|
return result
|
|
360
364
|
|
|
361
|
-
return wrapper
|
|
365
|
+
return wrapper # type: ignore
|
|
362
366
|
|
|
363
367
|
def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: float, mem_peak: float) -> None:
|
|
364
368
|
global _total_time
|
|
@@ -379,16 +383,9 @@ def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: fl
|
|
|
379
383
|
call_path_str = " -> ".join(short_stack)
|
|
380
384
|
_func_call_paths[func_name][call_path_str] += 1
|
|
381
385
|
|
|
382
|
-
print(
|
|
383
|
-
|
|
384
|
-
f"(total {percent_if_added:.1f}% of tracked time)"
|
|
385
|
-
)
|
|
386
|
-
print(
|
|
387
|
-
f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, "
|
|
388
|
-
f"diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB"
|
|
389
|
-
)
|
|
386
|
+
print(f"Function '{func_name}' took {elapsed:.4f}s (total {percent_if_added:.1f}% of tracked time)")
|
|
387
|
+
print(f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB")
|
|
390
388
|
|
|
391
|
-
# NEU: Runtime Stats
|
|
392
389
|
runtime_stats = collect_runtime_stats()
|
|
393
390
|
print("=== Runtime Stats ===")
|
|
394
391
|
print(f"RSS: {runtime_stats['rss_MB']:.2f} MB, VMS: {runtime_stats['vms_MB']:.2f} MB")
|
|
@@ -447,7 +444,7 @@ RESET: str = "\033[0m"
|
|
|
447
444
|
|
|
448
445
|
uuid_regex: Pattern = re.compile(r"^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[89aAbB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$")
|
|
449
446
|
|
|
450
|
-
worker_generator_uuid: str = uuid.uuid4()
|
|
447
|
+
worker_generator_uuid: str = str(uuid.uuid4())
|
|
451
448
|
|
|
452
449
|
new_uuid: str = str(uuid.uuid4())
|
|
453
450
|
run_uuid: str = os.getenv("RUN_UUID", new_uuid)
|
|
@@ -476,7 +473,7 @@ def get_current_run_folder(name: Optional[str] = None) -> str:
|
|
|
476
473
|
|
|
477
474
|
return CURRENT_RUN_FOLDER
|
|
478
475
|
|
|
479
|
-
def get_state_file_name(name) -> str:
|
|
476
|
+
def get_state_file_name(name: str) -> str:
|
|
480
477
|
state_files_folder = f"{get_current_run_folder()}/state_files/"
|
|
481
478
|
makedirs(state_files_folder)
|
|
482
479
|
|
|
@@ -500,6 +497,24 @@ try:
|
|
|
500
497
|
dier: FunctionType = helpers.dier
|
|
501
498
|
is_equal: FunctionType = helpers.is_equal
|
|
502
499
|
is_not_equal: FunctionType = helpers.is_not_equal
|
|
500
|
+
with spinner("Importing pareto..."):
|
|
501
|
+
pareto_file: str = f"{script_dir}/.pareto.py"
|
|
502
|
+
spec = importlib.util.spec_from_file_location(
|
|
503
|
+
name="pareto",
|
|
504
|
+
location=pareto_file,
|
|
505
|
+
)
|
|
506
|
+
if spec is not None and spec.loader is not None:
|
|
507
|
+
pareto = importlib.util.module_from_spec(spec)
|
|
508
|
+
spec.loader.exec_module(pareto)
|
|
509
|
+
else:
|
|
510
|
+
raise ImportError(f"Could not load module from {pareto_file}")
|
|
511
|
+
|
|
512
|
+
pareto_front_table_filter_rows: FunctionType = pareto.pareto_front_table_filter_rows
|
|
513
|
+
pareto_front_table_add_headers: FunctionType = pareto.pareto_front_table_add_headers
|
|
514
|
+
pareto_front_table_add_rows: FunctionType = pareto.pareto_front_table_add_rows
|
|
515
|
+
pareto_front_filter_complete_points: FunctionType = pareto.pareto_front_filter_complete_points
|
|
516
|
+
pareto_front_select_pareto_points: FunctionType = pareto.pareto_front_select_pareto_points
|
|
517
|
+
|
|
503
518
|
except KeyboardInterrupt:
|
|
504
519
|
print("You pressed CTRL-c while importing the helpers file")
|
|
505
520
|
sys.exit(0)
|
|
@@ -523,15 +538,13 @@ def is_slurm_job() -> bool:
|
|
|
523
538
|
return True
|
|
524
539
|
return False
|
|
525
540
|
|
|
526
|
-
def _sleep(t: int) ->
|
|
541
|
+
def _sleep(t: Union[float, int]) -> None:
|
|
527
542
|
if args is not None and not args.no_sleep:
|
|
528
543
|
try:
|
|
529
544
|
time.sleep(t)
|
|
530
545
|
except KeyboardInterrupt:
|
|
531
546
|
pass
|
|
532
547
|
|
|
533
|
-
return t
|
|
534
|
-
|
|
535
548
|
LOG_DIR: str = "logs"
|
|
536
549
|
makedirs(LOG_DIR)
|
|
537
550
|
|
|
@@ -544,6 +557,17 @@ logfile_worker_creation_logs: str = f'{log_uuid_dir}_worker_creation_logs'
|
|
|
544
557
|
logfile_trial_index_to_param_logs: str = f'{log_uuid_dir}_trial_index_to_param_logs'
|
|
545
558
|
LOGFILE_DEBUG_GET_NEXT_TRIALS: Union[str, None] = None
|
|
546
559
|
|
|
560
|
+
def error_without_print(text: str) -> None:
|
|
561
|
+
print_debug(text)
|
|
562
|
+
|
|
563
|
+
if get_current_run_folder():
|
|
564
|
+
try:
|
|
565
|
+
with open(get_current_run_folder("oo_errors.txt"), mode="a", encoding="utf-8") as myfile:
|
|
566
|
+
myfile.write(text + "\n\n")
|
|
567
|
+
except (OSError, FileNotFoundError) as e:
|
|
568
|
+
helpers.print_color("red", f"Error: {e}. This may mean that the {get_current_run_folder()} was deleted during the run. Could not write '{text} to {get_current_run_folder()}/oo_errors.txt'")
|
|
569
|
+
sys.exit(99)
|
|
570
|
+
|
|
547
571
|
def print_red(text: str) -> None:
|
|
548
572
|
helpers.print_color("red", text)
|
|
549
573
|
|
|
@@ -577,20 +601,22 @@ def _debug(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) ->
|
|
|
577
601
|
def _get_debug_json(time_str: str, msg: str) -> str:
|
|
578
602
|
function_stack = []
|
|
579
603
|
try:
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
func_name = frame.f_code
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
604
|
+
cf = inspect.currentframe()
|
|
605
|
+
if cf:
|
|
606
|
+
frame = cf.f_back # skip _get_debug_json
|
|
607
|
+
while frame:
|
|
608
|
+
func_name = _function_name_cache.get(frame.f_code)
|
|
609
|
+
if func_name is None:
|
|
610
|
+
func_name = frame.f_code.co_name
|
|
611
|
+
_function_name_cache[frame.f_code] = func_name
|
|
612
|
+
|
|
613
|
+
if func_name not in ("<module>", "print_debug", "wrapper"):
|
|
614
|
+
function_stack.append({
|
|
615
|
+
"function": func_name,
|
|
616
|
+
"line_number": frame.f_lineno
|
|
617
|
+
})
|
|
618
|
+
|
|
619
|
+
frame = frame.f_back
|
|
594
620
|
except (SignalUSR, SignalINT, SignalCONT):
|
|
595
621
|
print_red("\n⚠ You pressed CTRL-C. This is ignored in _get_debug_json.")
|
|
596
622
|
|
|
@@ -599,20 +625,46 @@ def _get_debug_json(time_str: str, msg: str) -> str:
|
|
|
599
625
|
separators=(",", ":") # no pretty indent → smaller, faster
|
|
600
626
|
).replace('\r', '').replace('\n', '')
|
|
601
627
|
|
|
602
|
-
def
|
|
603
|
-
|
|
628
|
+
def print_stack_paths() -> None:
|
|
629
|
+
stack = inspect.stack()[1:] # skip current frame
|
|
630
|
+
stack.reverse()
|
|
604
631
|
|
|
605
|
-
|
|
632
|
+
last_filename = None
|
|
633
|
+
for depth, frame_info in enumerate(stack):
|
|
634
|
+
filename = frame_info.filename
|
|
635
|
+
lineno = frame_info.lineno
|
|
636
|
+
func_name = frame_info.function
|
|
606
637
|
|
|
607
|
-
|
|
638
|
+
if func_name in ["<module>", "print_debug"]:
|
|
639
|
+
continue
|
|
608
640
|
|
|
609
|
-
|
|
641
|
+
if filename != last_filename:
|
|
642
|
+
print(filename)
|
|
643
|
+
last_filename = filename
|
|
644
|
+
indent = ""
|
|
645
|
+
else:
|
|
646
|
+
indent = " " * 4 * depth
|
|
647
|
+
|
|
648
|
+
print(f"{indent}↳ {func_name}:{lineno}")
|
|
649
|
+
|
|
650
|
+
def print_debug(msg: str) -> None:
|
|
651
|
+
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
652
|
+
|
|
653
|
+
stack = traceback.extract_stack()[:-1]
|
|
654
|
+
stack_funcs = [frame.name for frame in stack]
|
|
610
655
|
|
|
611
|
-
|
|
656
|
+
if "args" in globals() and args and hasattr(args, "debug_stack_regex") and args.debug_stack_regex:
|
|
657
|
+
matched = any(any(re.match(regex, func) for regex in args.debug_stack_regex) for func in stack_funcs)
|
|
658
|
+
if matched:
|
|
659
|
+
print(f"DEBUG: {msg}")
|
|
660
|
+
print_stack_paths()
|
|
661
|
+
|
|
662
|
+
stack_trace_element = _get_debug_json(time_str, msg)
|
|
663
|
+
_debug(stack_trace_element)
|
|
612
664
|
|
|
613
665
|
try:
|
|
614
666
|
with open(logfile_bare, mode='a', encoding="utf-8") as f:
|
|
615
|
-
original_print(
|
|
667
|
+
original_print(msg, file=f)
|
|
616
668
|
except FileNotFoundError:
|
|
617
669
|
print_red("It seems like the run's folder was deleted during the run. Cannot continue.")
|
|
618
670
|
sys.exit(99)
|
|
@@ -666,11 +718,14 @@ def my_exit(_code: int = 0) -> None:
|
|
|
666
718
|
if is_skip_search() and os.getenv("SKIP_SEARCH_EXIT_CODE"):
|
|
667
719
|
skip_search_exit_code = os.getenv("SKIP_SEARCH_EXIT_CODE")
|
|
668
720
|
|
|
721
|
+
skip_search_exit_code_found = None
|
|
722
|
+
|
|
669
723
|
try:
|
|
670
|
-
|
|
724
|
+
if skip_search_exit_code_found is not None:
|
|
725
|
+
skip_search_exit_code_found = int(skip_search_exit_code)
|
|
726
|
+
sys.exit(skip_search_exit_code_found)
|
|
671
727
|
except ValueError:
|
|
672
|
-
|
|
673
|
-
sys.exit(_code)
|
|
728
|
+
print_debug(f"Trying to look for SKIP_SEARCH_EXIT_CODE failed. Exiting with original exit code {_code}")
|
|
674
729
|
|
|
675
730
|
sys.exit(_code)
|
|
676
731
|
|
|
@@ -730,6 +785,7 @@ _DEFAULT_SPECIALS: Dict[str, Any] = {
|
|
|
730
785
|
class ConfigLoader:
|
|
731
786
|
runtime_debug: bool
|
|
732
787
|
show_func_name: bool
|
|
788
|
+
debug_stack_regex: str
|
|
733
789
|
number_of_generators: int
|
|
734
790
|
disable_previous_job_constraint: bool
|
|
735
791
|
save_to_database: bool
|
|
@@ -770,11 +826,13 @@ class ConfigLoader:
|
|
|
770
826
|
parameter: Optional[List[str]]
|
|
771
827
|
experiment_constraints: Optional[List[str]]
|
|
772
828
|
main_process_gb: int
|
|
829
|
+
beartype: bool
|
|
773
830
|
worker_timeout: int
|
|
774
831
|
slurm_signal_delay_s: int
|
|
775
832
|
gridsearch: bool
|
|
776
833
|
auto_exclude_defective_hosts: bool
|
|
777
834
|
debug: bool
|
|
835
|
+
debug_stack_trace_regex: Optional[str]
|
|
778
836
|
num_restarts: int
|
|
779
837
|
raw_samples: int
|
|
780
838
|
show_generate_time_table: bool
|
|
@@ -789,6 +847,7 @@ class ConfigLoader:
|
|
|
789
847
|
run_program_once: str
|
|
790
848
|
mem_gb: int
|
|
791
849
|
flame_graph: bool
|
|
850
|
+
memray: bool
|
|
792
851
|
continue_previous_job: Optional[str]
|
|
793
852
|
calculate_pareto_front_of_job: Optional[List[str]]
|
|
794
853
|
revert_to_random_when_seemingly_exhausted: bool
|
|
@@ -950,6 +1009,7 @@ class ConfigLoader:
|
|
|
950
1009
|
debug.add_argument('--verbose_break_run_search_table', help='Verbose logging for break_run_search', action='store_true', default=False)
|
|
951
1010
|
debug.add_argument('--debug', help='Enable debugging', action='store_true', default=False)
|
|
952
1011
|
debug.add_argument('--flame_graph', help='Enable flame-graphing. Makes everything slower, but creates a flame graph', action='store_true', default=False)
|
|
1012
|
+
debug.add_argument('--memray', help='Use memray to show memory usage', action='store_true', default=False)
|
|
953
1013
|
debug.add_argument('--no_sleep', help='Disables sleeping for fast job generation (not to be used on HPC)', action='store_true', default=False)
|
|
954
1014
|
debug.add_argument('--tests', help='Run simple internal tests', action='store_true', default=False)
|
|
955
1015
|
debug.add_argument('--show_worker_percentage_table_at_end', help='Show a table of percentage of usage of max worker over time', action='store_true', default=False)
|
|
@@ -961,7 +1021,10 @@ class ConfigLoader:
|
|
|
961
1021
|
debug.add_argument('--just_return_defaults', help='Just return defaults in dryrun', action='store_true', default=False)
|
|
962
1022
|
debug.add_argument('--prettyprint', help='Shows stdout and stderr in a pretty printed format', action='store_true', default=False)
|
|
963
1023
|
debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
|
|
1024
|
+
debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
|
|
1025
|
+
debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
|
|
964
1026
|
debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
|
|
1027
|
+
debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)
|
|
965
1028
|
|
|
966
1029
|
def load_config(self: Any, config_path: str, file_format: str) -> dict:
|
|
967
1030
|
if not os.path.isfile(config_path):
|
|
@@ -1199,11 +1262,15 @@ for _rn in args.result_names:
|
|
|
1199
1262
|
_key = _rn
|
|
1200
1263
|
_min_or_max = __default_min_max
|
|
1201
1264
|
|
|
1265
|
+
_min_or_max = re.sub(r"'", "", _min_or_max)
|
|
1266
|
+
|
|
1202
1267
|
if _min_or_max not in ["min", "max"]:
|
|
1203
1268
|
if _min_or_max:
|
|
1204
1269
|
print_yellow(f"Value for determining whether to minimize or maximize was neither 'min' nor 'max' for key '{_key}', but '{_min_or_max}'. It will be set to the default, which is '{__default_min_max}' instead.")
|
|
1205
1270
|
_min_or_max = __default_min_max
|
|
1206
1271
|
|
|
1272
|
+
_key = re.sub(r"'", "", _key)
|
|
1273
|
+
|
|
1207
1274
|
if _key in arg_result_names:
|
|
1208
1275
|
console.print(f"[red]The --result_names option '{_key}' was specified multiple times![/]")
|
|
1209
1276
|
sys.exit(50)
|
|
@@ -1304,27 +1371,14 @@ try:
|
|
|
1304
1371
|
with spinner("Importing ExternalGenerationNode..."):
|
|
1305
1372
|
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
|
|
1306
1373
|
|
|
1307
|
-
with spinner("Importing
|
|
1308
|
-
from ax.generation_strategy.transition_criterion import
|
|
1374
|
+
with spinner("Importing MinTrials..."):
|
|
1375
|
+
from ax.generation_strategy.transition_criterion import MinTrials
|
|
1309
1376
|
|
|
1310
1377
|
with spinner("Importing GeneratorSpec..."):
|
|
1311
1378
|
from ax.generation_strategy.generator_spec import GeneratorSpec
|
|
1312
1379
|
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
# import ax.generation_strategy.generation_node
|
|
1316
|
-
|
|
1317
|
-
# with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
|
|
1318
|
-
# from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
|
|
1319
|
-
|
|
1320
|
-
# with spinner("Fallback: Importing ExternalGenerationNode..."):
|
|
1321
|
-
# from ax.generation_strategy.external_generation_node import ExternalGenerationNode
|
|
1322
|
-
|
|
1323
|
-
# with spinner("Fallback: Importing MaxTrials..."):
|
|
1324
|
-
# from ax.generation_strategy.transition_criterion import MaxTrials
|
|
1325
|
-
|
|
1326
|
-
with spinner("Importing Models from ax.generation_strategy.registry..."):
|
|
1327
|
-
from ax.adapter.registry import Models
|
|
1380
|
+
with spinner("Importing Generators from ax.generation_strategy.registry..."):
|
|
1381
|
+
from ax.adapter.registry import Generators
|
|
1328
1382
|
|
|
1329
1383
|
with spinner("Importing get_pending_observation_features..."):
|
|
1330
1384
|
from ax.core.utils import get_pending_observation_features
|
|
@@ -1410,7 +1464,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1410
1464
|
def __init__(self: Any, regressor_options: Dict[str, Any] = {}, seed: Optional[int] = None, num_samples: int = 1) -> None:
|
|
1411
1465
|
print_debug("Initializing RandomForestGenerationNode...")
|
|
1412
1466
|
t_init_start = time.monotonic()
|
|
1413
|
-
super().__init__(
|
|
1467
|
+
super().__init__(name="RANDOMFOREST")
|
|
1414
1468
|
self.num_samples: int = num_samples
|
|
1415
1469
|
self.seed: int = seed
|
|
1416
1470
|
|
|
@@ -1430,6 +1484,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1430
1484
|
def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
|
|
1431
1485
|
search_space = experiment.search_space
|
|
1432
1486
|
parameter_names = list(search_space.parameters.keys())
|
|
1487
|
+
if experiment.optimization_config is None:
|
|
1488
|
+
print_red("Error: update_generator_state is None")
|
|
1489
|
+
return
|
|
1433
1490
|
metric_names = list(experiment.optimization_config.metrics.keys())
|
|
1434
1491
|
|
|
1435
1492
|
completed_trials = [
|
|
@@ -1441,7 +1498,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1441
1498
|
y = np.zeros([num_completed_trials, 1])
|
|
1442
1499
|
|
|
1443
1500
|
for t_idx, trial in enumerate(completed_trials):
|
|
1444
|
-
trial_parameters = trial.
|
|
1501
|
+
trial_parameters = trial.arms[t_idx].parameters
|
|
1445
1502
|
x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
|
|
1446
1503
|
trial_df = data.df[data.df["trial_index"] == trial.index]
|
|
1447
1504
|
y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
|
|
@@ -1581,10 +1638,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1581
1638
|
def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
|
|
1582
1639
|
for name in best_sample.keys():
|
|
1583
1640
|
param = self.parameters.get(name)
|
|
1641
|
+
best_sample_by_name = best_sample[name]
|
|
1642
|
+
|
|
1584
1643
|
if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
|
|
1585
|
-
|
|
1644
|
+
if best_sample_by_name is not None:
|
|
1645
|
+
best_sample[name] = int(round(float(best_sample_by_name)))
|
|
1646
|
+
else:
|
|
1647
|
+
print_debug("best_sample_by_name was empty")
|
|
1586
1648
|
elif isinstance(param, ChoiceParameter):
|
|
1587
|
-
|
|
1649
|
+
if best_sample_by_name is not None:
|
|
1650
|
+
best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
|
|
1651
|
+
else:
|
|
1652
|
+
print_debug("best_sample_by_name was empty")
|
|
1588
1653
|
|
|
1589
1654
|
decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode
|
|
1590
1655
|
|
|
@@ -1632,10 +1697,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
|
|
|
1632
1697
|
|
|
1633
1698
|
def __init__(
|
|
1634
1699
|
self: Any,
|
|
1635
|
-
|
|
1700
|
+
name: str = "INTERACTIVE_GENERATOR",
|
|
1636
1701
|
) -> None:
|
|
1637
1702
|
t0 = time.monotonic()
|
|
1638
|
-
super().__init__(
|
|
1703
|
+
super().__init__(name=name)
|
|
1639
1704
|
self.parameters = None
|
|
1640
1705
|
self.minimize = None
|
|
1641
1706
|
self.data = None
|
|
@@ -1791,10 +1856,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
|
|
|
1791
1856
|
|
|
1792
1857
|
@dataclass(init=False)
|
|
1793
1858
|
class ExternalProgramGenerationNode(ExternalGenerationNode):
|
|
1794
|
-
def __init__(self: Any, external_generator: str = args.external_generator,
|
|
1859
|
+
def __init__(self: Any, external_generator: str = args.external_generator, name: str = "EXTERNAL_GENERATOR") -> None:
|
|
1795
1860
|
print_debug("Initializing ExternalProgramGenerationNode...")
|
|
1796
1861
|
t_init_start = time.monotonic()
|
|
1797
|
-
super().__init__(
|
|
1862
|
+
super().__init__(name=name)
|
|
1798
1863
|
self.seed: int = args.seed
|
|
1799
1864
|
self.external_generator: str = decode_if_base64(external_generator)
|
|
1800
1865
|
self.constraints = None
|
|
@@ -2015,24 +2080,15 @@ def run_live_share_command(force: bool = False) -> Tuple[str, str]:
|
|
|
2015
2080
|
return str(result.stdout), str(result.stderr)
|
|
2016
2081
|
except subprocess.CalledProcessError as e:
|
|
2017
2082
|
if e.stderr:
|
|
2018
|
-
|
|
2083
|
+
print_debug(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
|
|
2019
2084
|
else:
|
|
2020
|
-
|
|
2085
|
+
print_debug(f"run_live_share_command: command failed with error: {e}")
|
|
2021
2086
|
return "", str(e.stderr)
|
|
2022
2087
|
except Exception as e:
|
|
2023
2088
|
print(f"run_live_share_command: An error occurred: {e}")
|
|
2024
2089
|
|
|
2025
2090
|
return "", ""
|
|
2026
2091
|
|
|
2027
|
-
#def extract_and_print_qr(text: str) -> None:
|
|
2028
|
-
# match = re.search(r"(https?://\S+|\b[\w.-]+@[\w.-]+\.\w+\b|\b\d{10,}\b)", text)
|
|
2029
|
-
# if match:
|
|
2030
|
-
# data = match.group(0)
|
|
2031
|
-
# qr = qrcode.QRCode(box_size=1, error_correction=qrcode.constants.ERROR_CORRECT_L, border=0)
|
|
2032
|
-
# qr.add_data(data)
|
|
2033
|
-
# qr.make()
|
|
2034
|
-
# qr.print_ascii(out=sys.stdout)
|
|
2035
|
-
|
|
2036
2092
|
def force_live_share() -> bool:
|
|
2037
2093
|
if args.live_share:
|
|
2038
2094
|
return live_share(True)
|
|
@@ -2040,6 +2096,8 @@ def force_live_share() -> bool:
|
|
|
2040
2096
|
return False
|
|
2041
2097
|
|
|
2042
2098
|
def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
|
|
2099
|
+
log_data()
|
|
2100
|
+
|
|
2043
2101
|
if not get_current_run_folder():
|
|
2044
2102
|
print(f"live_share: get_current_run_folder was empty or false: {get_current_run_folder()}")
|
|
2045
2103
|
return False
|
|
@@ -2052,25 +2110,24 @@ def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
|
|
|
2052
2110
|
if text_and_qr:
|
|
2053
2111
|
if stderr:
|
|
2054
2112
|
print_green(stderr)
|
|
2055
|
-
#extract_and_print_qr(stderr)
|
|
2056
2113
|
else:
|
|
2057
|
-
|
|
2114
|
+
if stderr and stdout:
|
|
2115
|
+
print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
|
|
2116
|
+
elif stderr:
|
|
2117
|
+
print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}")
|
|
2118
|
+
elif stdout:
|
|
2119
|
+
print_red(f"This call should have shown the CURL, but didnt. Stdout: {stdout}")
|
|
2120
|
+
else:
|
|
2121
|
+
print_red("This call should have shown the CURL, but didnt.")
|
|
2058
2122
|
if stdout:
|
|
2059
2123
|
print_debug(f"live_share stdout: {stdout}")
|
|
2060
2124
|
|
|
2061
2125
|
return True
|
|
2062
2126
|
|
|
2063
2127
|
def init_live_share() -> bool:
|
|
2064
|
-
|
|
2065
|
-
ret = live_share(True, True)
|
|
2066
|
-
|
|
2067
|
-
return ret
|
|
2128
|
+
ret = live_share(True, True)
|
|
2068
2129
|
|
|
2069
|
-
|
|
2070
|
-
if args.live_share and not os.environ.get("CI"):
|
|
2071
|
-
while True:
|
|
2072
|
-
live_share(force=False)
|
|
2073
|
-
time.sleep(30)
|
|
2130
|
+
return ret
|
|
2074
2131
|
|
|
2075
2132
|
def init_storage(db_url: str) -> None:
|
|
2076
2133
|
init_engine_and_session_factory(url=db_url, force_init=True)
|
|
@@ -2096,29 +2153,33 @@ def try_saving_to_db() -> None:
|
|
|
2096
2153
|
else:
|
|
2097
2154
|
print_red("ax_client was not defined in try_saving_to_db")
|
|
2098
2155
|
my_exit(101)
|
|
2099
|
-
|
|
2156
|
+
|
|
2157
|
+
if global_gs is not None:
|
|
2158
|
+
save_generation_strategy(global_gs)
|
|
2159
|
+
else:
|
|
2160
|
+
print_red("Not saving generation strategy: global_gs was empty")
|
|
2100
2161
|
except Exception as e:
|
|
2101
2162
|
print_debug(f"Failed trying to save sqlite3-DB: {e}")
|
|
2102
2163
|
|
|
2103
|
-
def merge_with_job_infos(
|
|
2164
|
+
def merge_with_job_infos(df: pd.DataFrame) -> pd.DataFrame:
|
|
2104
2165
|
job_infos_path = os.path.join(get_current_run_folder(), "job_infos.csv")
|
|
2105
2166
|
if not os.path.exists(job_infos_path):
|
|
2106
|
-
return
|
|
2167
|
+
return df
|
|
2107
2168
|
|
|
2108
2169
|
job_df = pd.read_csv(job_infos_path)
|
|
2109
2170
|
|
|
2110
|
-
if 'trial_index' not in
|
|
2171
|
+
if 'trial_index' not in df.columns or 'trial_index' not in job_df.columns:
|
|
2111
2172
|
raise ValueError("Both DataFrames must contain a 'trial_index' column.")
|
|
2112
2173
|
|
|
2113
|
-
job_df_filtered = job_df[job_df['trial_index'].isin(
|
|
2174
|
+
job_df_filtered = job_df[job_df['trial_index'].isin(df['trial_index'])]
|
|
2114
2175
|
|
|
2115
|
-
new_cols = [col for col in job_df_filtered.columns if col != 'trial_index' and col not in
|
|
2176
|
+
new_cols = [col for col in job_df_filtered.columns if col != 'trial_index' and col not in df.columns]
|
|
2116
2177
|
|
|
2117
2178
|
job_df_reduced = job_df_filtered[['trial_index'] + new_cols]
|
|
2118
2179
|
|
|
2119
|
-
merged = pd.merge(
|
|
2180
|
+
merged = pd.merge(df, job_df_reduced, on='trial_index', how='left')
|
|
2120
2181
|
|
|
2121
|
-
old_cols = [col for col in
|
|
2182
|
+
old_cols = [col for col in df.columns if col != 'trial_index']
|
|
2122
2183
|
|
|
2123
2184
|
new_order = ['trial_index'] + new_cols + old_cols
|
|
2124
2185
|
|
|
@@ -2126,48 +2187,9 @@ def merge_with_job_infos(pd_frame: pd.DataFrame) -> pd.DataFrame:
|
|
|
2126
2187
|
|
|
2127
2188
|
return merged
|
|
2128
2189
|
|
|
2129
|
-
def reindex_trials(df: pd.DataFrame) -> pd.DataFrame:
|
|
2130
|
-
"""
|
|
2131
|
-
Ensure trial_index is sequential and arm_name unique.
|
|
2132
|
-
Keep arm_name unless all parameters except 'order', 'hostname', 'queue_time' match.
|
|
2133
|
-
"""
|
|
2134
|
-
if "trial_index" not in df.columns or "arm_name" not in df.columns:
|
|
2135
|
-
return df
|
|
2136
|
-
|
|
2137
|
-
# Sort by something stable (queue_time if available)
|
|
2138
|
-
sort_cols = ["queue_time"] if "queue_time" in df.columns else df.columns.tolist()
|
|
2139
|
-
df = df.sort_values(by=sort_cols, ignore_index=True)
|
|
2140
|
-
|
|
2141
|
-
# Mapping from "parameter signature" to assigned arm_name
|
|
2142
|
-
seen_signatures = {}
|
|
2143
|
-
new_arm_names = []
|
|
2144
|
-
|
|
2145
|
-
for new_idx, row in df.iterrows():
|
|
2146
|
-
# Create signature without 'order', 'hostname', 'queue_time', 'trial_index', 'arm_name'
|
|
2147
|
-
ignore_cols = {"order", "hostname", "queue_time", "trial_index", "arm_name"}
|
|
2148
|
-
signature = tuple((col, row[col]) for col in df.columns if col not in ignore_cols)
|
|
2149
|
-
|
|
2150
|
-
if signature in seen_signatures:
|
|
2151
|
-
# Collision → make a unique name
|
|
2152
|
-
base_name = seen_signatures[signature]
|
|
2153
|
-
suffix = 1
|
|
2154
|
-
new_name = f"{base_name}_{suffix}"
|
|
2155
|
-
while new_name in new_arm_names:
|
|
2156
|
-
suffix += 1
|
|
2157
|
-
new_name = f"{base_name}_{suffix}"
|
|
2158
|
-
new_arm_names.append(new_name)
|
|
2159
|
-
else:
|
|
2160
|
-
# First occurrence → use new_idx as trial index in name
|
|
2161
|
-
new_name = f"{new_idx}_0"
|
|
2162
|
-
seen_signatures[signature] = f"{new_idx}_0"
|
|
2163
|
-
new_arm_names.append(new_name)
|
|
2164
|
-
|
|
2165
|
-
df.at[new_idx, "trial_index"] = new_idx
|
|
2166
|
-
df.at[new_idx, "arm_name"] = new_name
|
|
2167
|
-
|
|
2168
|
-
return df
|
|
2169
|
-
|
|
2170
2190
|
def save_results_csv() -> Optional[str]:
|
|
2191
|
+
log_data()
|
|
2192
|
+
|
|
2171
2193
|
if args.dryrun:
|
|
2172
2194
|
return None
|
|
2173
2195
|
|
|
@@ -2181,8 +2203,11 @@ def save_results_csv() -> Optional[str]:
|
|
|
2181
2203
|
save_checkpoint()
|
|
2182
2204
|
|
|
2183
2205
|
try:
|
|
2184
|
-
|
|
2185
|
-
|
|
2206
|
+
df = fetch_and_prepare_trials()
|
|
2207
|
+
if df is None:
|
|
2208
|
+
print_red(f"save_results_csv: fetch_and_prepare_trials returned an empty element: {df}")
|
|
2209
|
+
return None
|
|
2210
|
+
write_csv(df, pd_csv)
|
|
2186
2211
|
write_json_snapshot(pd_json)
|
|
2187
2212
|
save_experiment_to_file()
|
|
2188
2213
|
|
|
@@ -2194,32 +2219,67 @@ def save_results_csv() -> Optional[str]:
|
|
|
2194
2219
|
except (SignalUSR, SignalCONT, SignalINT) as e:
|
|
2195
2220
|
raise type(e)(str(e)) from e
|
|
2196
2221
|
except Exception as e:
|
|
2197
|
-
print_red(f"
|
|
2222
|
+
print_red(f"\nWhile saving all trials as a pandas-dataframe-csv, an error occurred: {e}")
|
|
2198
2223
|
|
|
2199
2224
|
return pd_csv
|
|
2200
2225
|
|
|
2201
2226
|
def get_results_paths() -> tuple[str, str]:
|
|
2202
2227
|
return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))
|
|
2203
2228
|
|
|
2204
|
-
def
|
|
2229
|
+
def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
|
|
2230
|
+
if not ax_client:
|
|
2231
|
+
my_exit(101)
|
|
2232
|
+
|
|
2233
|
+
return None
|
|
2234
|
+
|
|
2235
|
+
return ax_client.get_trials_data_frame()
|
|
2236
|
+
|
|
2237
|
+
def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
|
|
2238
|
+
if not ax_client:
|
|
2239
|
+
return None
|
|
2240
|
+
|
|
2205
2241
|
ax_client.experiment.fetch_data()
|
|
2206
|
-
df =
|
|
2242
|
+
df = ax_client_get_trials_data_frame()
|
|
2207
2243
|
df = merge_with_job_infos(df)
|
|
2208
|
-
return reindex_trials(df)
|
|
2209
2244
|
|
|
2210
|
-
|
|
2245
|
+
return df
|
|
2246
|
+
|
|
2247
|
+
def write_csv(df: pd.DataFrame, path: str) -> None:
|
|
2248
|
+
try:
|
|
2249
|
+
df = df.sort_values(by=["trial_index"], kind="stable").reset_index(drop=True)
|
|
2250
|
+
except KeyError:
|
|
2251
|
+
pass
|
|
2211
2252
|
df.to_csv(path, index=False, float_format="%.30f")
|
|
2212
2253
|
|
|
2213
|
-
def
|
|
2254
|
+
def ax_client_to_json_snapshot() -> Optional[dict]:
|
|
2255
|
+
if not ax_client:
|
|
2256
|
+
my_exit(101)
|
|
2257
|
+
|
|
2258
|
+
return None
|
|
2259
|
+
|
|
2214
2260
|
json_snapshot = ax_client.to_json_snapshot()
|
|
2215
|
-
|
|
2216
|
-
|
|
2261
|
+
|
|
2262
|
+
return json_snapshot
|
|
2263
|
+
|
|
2264
|
+
def write_json_snapshot(path: str) -> None:
|
|
2265
|
+
if ax_client is not None:
|
|
2266
|
+
json_snapshot = ax_client_to_json_snapshot()
|
|
2267
|
+
if json_snapshot is not None:
|
|
2268
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
2269
|
+
json.dump(json_snapshot, f, indent=4)
|
|
2270
|
+
else:
|
|
2271
|
+
print_debug('json_snapshot from ax_client_to_json_snapshot was None')
|
|
2272
|
+
else:
|
|
2273
|
+
print_red("write_json_snapshot: ax_client was None")
|
|
2217
2274
|
|
|
2218
2275
|
def save_experiment_to_file() -> None:
|
|
2219
|
-
|
|
2220
|
-
|
|
2221
|
-
|
|
2222
|
-
|
|
2276
|
+
if ax_client is not None:
|
|
2277
|
+
save_experiment(
|
|
2278
|
+
ax_client.experiment,
|
|
2279
|
+
get_state_file_name("ax_client.experiment.json")
|
|
2280
|
+
)
|
|
2281
|
+
else:
|
|
2282
|
+
print_red("save_experiment: ax_client is None")
|
|
2223
2283
|
|
|
2224
2284
|
def should_save_to_database() -> bool:
|
|
2225
2285
|
return args.model not in uncontinuable_models and args.save_to_database
|
|
@@ -2408,7 +2468,7 @@ def set_nr_inserted_jobs(new_nr_inserted_jobs: int) -> None:
|
|
|
2408
2468
|
|
|
2409
2469
|
def write_worker_usage() -> None:
|
|
2410
2470
|
if len(WORKER_PERCENTAGE_USAGE):
|
|
2411
|
-
csv_filename = get_current_run_folder(
|
|
2471
|
+
csv_filename = get_current_run_folder(worker_usage_file)
|
|
2412
2472
|
|
|
2413
2473
|
csv_columns = ['time', 'num_parallel_jobs', 'nr_current_workers', 'percentage']
|
|
2414
2474
|
|
|
@@ -2418,35 +2478,39 @@ def write_worker_usage() -> None:
|
|
|
2418
2478
|
csv_writer.writerow(row)
|
|
2419
2479
|
else:
|
|
2420
2480
|
if is_slurm_job():
|
|
2421
|
-
print_debug("WORKER_PERCENTAGE_USAGE seems to be empty. Not writing
|
|
2481
|
+
print_debug(f"WORKER_PERCENTAGE_USAGE seems to be empty. Not writing {worker_usage_file}")
|
|
2422
2482
|
|
|
2423
2483
|
def log_system_usage() -> None:
|
|
2484
|
+
global LAST_LOG_TIME
|
|
2485
|
+
|
|
2486
|
+
now = time.time()
|
|
2487
|
+
if now - LAST_LOG_TIME < 30:
|
|
2488
|
+
return
|
|
2489
|
+
|
|
2490
|
+
LAST_LOG_TIME = int(now)
|
|
2491
|
+
|
|
2424
2492
|
if not get_current_run_folder():
|
|
2425
2493
|
return
|
|
2426
2494
|
|
|
2427
2495
|
ram_cpu_csv_file_path = os.path.join(get_current_run_folder(), "cpu_ram_usage.csv")
|
|
2428
|
-
|
|
2429
2496
|
makedirs(os.path.dirname(ram_cpu_csv_file_path))
|
|
2430
2497
|
|
|
2431
2498
|
file_exists = os.path.isfile(ram_cpu_csv_file_path)
|
|
2432
2499
|
|
|
2433
|
-
|
|
2434
|
-
|
|
2435
|
-
|
|
2436
|
-
current_time = int(time.time())
|
|
2437
|
-
|
|
2438
|
-
if process is not None:
|
|
2439
|
-
mem_proc = process.memory_info()
|
|
2440
|
-
|
|
2441
|
-
if mem_proc is not None:
|
|
2442
|
-
ram_usage_mb = mem_proc.rss / (1024 * 1024)
|
|
2443
|
-
cpu_usage_percent = psutil.cpu_percent(percpu=False)
|
|
2500
|
+
mem_proc = process.memory_info() if process else None
|
|
2501
|
+
if not mem_proc:
|
|
2502
|
+
return
|
|
2444
2503
|
|
|
2445
|
-
|
|
2446
|
-
|
|
2447
|
-
|
|
2504
|
+
ram_usage_mb = mem_proc.rss / (1024 * 1024)
|
|
2505
|
+
cpu_usage_percent = psutil.cpu_percent(percpu=False)
|
|
2506
|
+
if ram_usage_mb <= 0 or cpu_usage_percent <= 0:
|
|
2507
|
+
return
|
|
2448
2508
|
|
|
2449
|
-
|
|
2509
|
+
with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
|
|
2510
|
+
writer = csv.writer(file)
|
|
2511
|
+
if not file_exists:
|
|
2512
|
+
writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
|
|
2513
|
+
writer.writerow([int(now), ram_usage_mb, cpu_usage_percent])
|
|
2450
2514
|
|
|
2451
2515
|
def write_process_info() -> None:
|
|
2452
2516
|
try:
|
|
@@ -2598,9 +2662,6 @@ def _debug_worker_creation(msg: str, _lvl: int = 0, eee: Union[None, str, Except
|
|
|
2598
2662
|
def append_to_nvidia_smi_logs(_file: str, _host: str, result: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
|
|
2599
2663
|
log_message_to_file(_file, result, _lvl, str(eee))
|
|
2600
2664
|
|
|
2601
|
-
def _debug_get_next_trials(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
|
|
2602
|
-
log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, _lvl, str(eee))
|
|
2603
|
-
|
|
2604
2665
|
def _debug_progressbar(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
|
|
2605
2666
|
log_message_to_file(logfile_progressbar, msg, _lvl, str(eee))
|
|
2606
2667
|
|
|
@@ -2796,13 +2857,23 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
|
|
|
2796
2857
|
time_str: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
2797
2858
|
msg: str = f"{time_str}, {got}, {requested}"
|
|
2798
2859
|
|
|
2799
|
-
|
|
2860
|
+
log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
|
|
2800
2861
|
|
|
2801
2862
|
def print_debug_progressbar(msg: str) -> None:
|
|
2802
|
-
|
|
2803
|
-
|
|
2863
|
+
global last_msg_progressbar, last_msg_raw
|
|
2864
|
+
|
|
2865
|
+
try:
|
|
2866
|
+
with last_lock_print_debug:
|
|
2867
|
+
if msg != last_msg_raw:
|
|
2868
|
+
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
2869
|
+
full_msg = f"{time_str} ({worker_generator_uuid}): {msg}"
|
|
2870
|
+
|
|
2871
|
+
_debug_progressbar(full_msg)
|
|
2804
2872
|
|
|
2805
|
-
|
|
2873
|
+
last_msg_raw = msg
|
|
2874
|
+
last_msg_progressbar = full_msg
|
|
2875
|
+
except Exception as e:
|
|
2876
|
+
print(f"Error in print_debug_progressbar: {e}", flush=True)
|
|
2806
2877
|
|
|
2807
2878
|
def get_process_info(pid: Any) -> str:
|
|
2808
2879
|
try:
|
|
@@ -3307,116 +3378,438 @@ def parse_experiment_parameters() -> None:
|
|
|
3307
3378
|
# Remove duplicates by 'name' key preserving order
|
|
3308
3379
|
params = list({p['name']: p for p in params}.values())
|
|
3309
3380
|
|
|
3310
|
-
experiment_parameters = params
|
|
3381
|
+
experiment_parameters = params # type: ignore[assignment]
|
|
3311
3382
|
|
|
3312
|
-
def
|
|
3313
|
-
|
|
3314
|
-
_fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
|
|
3383
|
+
def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
|
|
3384
|
+
pf_start_time = time.time()
|
|
3315
3385
|
|
|
3316
|
-
|
|
3317
|
-
|
|
3318
|
-
valid_value_types_string = ", ".join(valid_value_types)
|
|
3319
|
-
_fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
|
|
3386
|
+
if not path_to_calculate:
|
|
3387
|
+
return False
|
|
3320
3388
|
|
|
3321
|
-
|
|
3322
|
-
|
|
3323
|
-
|
|
3389
|
+
global CURRENT_RUN_FOLDER
|
|
3390
|
+
global RESULT_CSV_FILE
|
|
3391
|
+
global arg_result_names
|
|
3324
3392
|
|
|
3325
|
-
|
|
3326
|
-
|
|
3327
|
-
|
|
3328
|
-
if upper_bound == lower_bound:
|
|
3329
|
-
if lower_bound == 0:
|
|
3330
|
-
_fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
|
|
3331
|
-
print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
|
|
3332
|
-
if upper_bound is not None:
|
|
3333
|
-
lower_bound = -upper_bound
|
|
3393
|
+
if not path_to_calculate:
|
|
3394
|
+
print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
|
|
3395
|
+
return False
|
|
3334
3396
|
|
|
3335
|
-
|
|
3336
|
-
|
|
3337
|
-
|
|
3338
|
-
s = format(value, float_format)
|
|
3339
|
-
s = s.rstrip('0').rstrip('.') if '.' in s else s
|
|
3340
|
-
return s
|
|
3341
|
-
return str(value)
|
|
3342
|
-
except Exception as e:
|
|
3343
|
-
print_red(f"⚠ Error formatting the number {value}: {e}")
|
|
3344
|
-
return str(value)
|
|
3397
|
+
if not os.path.exists(path_to_calculate):
|
|
3398
|
+
print_red(f"Path '{path_to_calculate}' does not exist")
|
|
3399
|
+
return False
|
|
3345
3400
|
|
|
3346
|
-
|
|
3347
|
-
parameters: dict,
|
|
3348
|
-
input_string: str,
|
|
3349
|
-
float_format: str = '.20f',
|
|
3350
|
-
additional_prefixes: list[str] = [],
|
|
3351
|
-
additional_patterns: list[str] = [],
|
|
3352
|
-
) -> str:
|
|
3353
|
-
try:
|
|
3354
|
-
prefixes = ['$', '%'] + additional_prefixes
|
|
3355
|
-
patterns = ['{key}', '({key})'] + additional_patterns
|
|
3401
|
+
ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
|
|
3356
3402
|
|
|
3357
|
-
|
|
3358
|
-
|
|
3359
|
-
|
|
3360
|
-
for pattern in patterns:
|
|
3361
|
-
token = prefix + pattern.format(key=key)
|
|
3362
|
-
input_string = input_string.replace(token, replacement)
|
|
3403
|
+
if not os.path.exists(ax_client_json):
|
|
3404
|
+
print_red(f"Path '{ax_client_json}' not found")
|
|
3405
|
+
return False
|
|
3363
3406
|
|
|
3364
|
-
|
|
3365
|
-
|
|
3407
|
+
checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
|
|
3408
|
+
if not os.path.exists(checkpoint_file):
|
|
3409
|
+
print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
|
|
3410
|
+
return False
|
|
3366
3411
|
|
|
3367
|
-
|
|
3368
|
-
|
|
3369
|
-
|
|
3412
|
+
RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
3413
|
+
if not os.path.exists(RESULT_CSV_FILE):
|
|
3414
|
+
print_red(f"{RESULT_CSV_FILE} not found")
|
|
3415
|
+
return False
|
|
3370
3416
|
|
|
3371
|
-
|
|
3372
|
-
user_uid = os.getuid()
|
|
3417
|
+
res_names = []
|
|
3373
3418
|
|
|
3374
|
-
|
|
3375
|
-
|
|
3376
|
-
|
|
3377
|
-
|
|
3419
|
+
res_names_file = f"{path_to_calculate}/result_names.txt"
|
|
3420
|
+
if not os.path.exists(res_names_file):
|
|
3421
|
+
print_red(f"File '{res_names_file}' does not exist")
|
|
3422
|
+
return False
|
|
3378
3423
|
|
|
3379
|
-
|
|
3424
|
+
try:
|
|
3425
|
+
with open(res_names_file, "r", encoding="utf-8") as file:
|
|
3426
|
+
lines = file.readlines()
|
|
3427
|
+
except Exception as e:
|
|
3428
|
+
print_red(f"Error reading file '{res_names_file}': {e}")
|
|
3429
|
+
return False
|
|
3380
3430
|
|
|
3381
|
-
|
|
3382
|
-
|
|
3383
|
-
|
|
3384
|
-
|
|
3385
|
-
self.running = True
|
|
3386
|
-
self.thread = threading.Thread(target=self._monitor)
|
|
3387
|
-
self.thread.daemon = True
|
|
3431
|
+
for line in lines:
|
|
3432
|
+
entry = line.strip()
|
|
3433
|
+
if entry != "":
|
|
3434
|
+
res_names.append(entry)
|
|
3388
3435
|
|
|
3389
|
-
|
|
3436
|
+
if len(res_names) < 2:
|
|
3437
|
+
print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
|
|
3438
|
+
return False
|
|
3390
3439
|
|
|
3391
|
-
|
|
3392
|
-
try:
|
|
3393
|
-
_internal_process = psutil.Process(self.pid)
|
|
3394
|
-
while self.running and _internal_process.is_running():
|
|
3395
|
-
crf = get_current_run_folder()
|
|
3440
|
+
load_username_to_args(path_to_calculate)
|
|
3396
3441
|
|
|
3397
|
-
|
|
3398
|
-
log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
|
|
3442
|
+
CURRENT_RUN_FOLDER = path_to_calculate
|
|
3399
3443
|
|
|
3400
|
-
|
|
3444
|
+
arg_result_names = res_names
|
|
3401
3445
|
|
|
3402
|
-
|
|
3403
|
-
hostname = socket.gethostname()
|
|
3446
|
+
load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
|
|
3404
3447
|
|
|
3405
|
-
|
|
3448
|
+
if experiment_parameters is None:
|
|
3449
|
+
return False
|
|
3406
3450
|
|
|
3407
|
-
|
|
3408
|
-
hostname += f"-SLURM-ID-{slurm_job_id}"
|
|
3451
|
+
show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
|
|
3409
3452
|
|
|
3410
|
-
|
|
3411
|
-
cpu_usage = psutil.cpu_percent(interval=5)
|
|
3453
|
+
pf_end_time = time.time()
|
|
3412
3454
|
|
|
3413
|
-
|
|
3455
|
+
print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
|
|
3414
3456
|
|
|
3415
|
-
|
|
3457
|
+
return True
|
|
3416
3458
|
|
|
3417
|
-
|
|
3418
|
-
|
|
3419
|
-
|
|
3459
|
+
def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
|
|
3460
|
+
if args.dryrun:
|
|
3461
|
+
print_debug("Not showing Pareto-frontier data with --dryrun")
|
|
3462
|
+
return None
|
|
3463
|
+
|
|
3464
|
+
if len(res_names) > 1:
|
|
3465
|
+
try:
|
|
3466
|
+
show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
|
|
3467
|
+
except Exception as e:
|
|
3468
|
+
inner_tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
3469
|
+
print_red(f"show_pareto_frontier_data() failed with exception '{e}':\n{inner_tb}")
|
|
3470
|
+
else:
|
|
3471
|
+
print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
|
|
3472
|
+
return None
|
|
3473
|
+
|
|
3474
|
+
def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
|
|
3475
|
+
pareto_front_data: dict = {}
|
|
3476
|
+
|
|
3477
|
+
all_combinations = list(combinations(range(len(arg_result_names)), 2))
|
|
3478
|
+
|
|
3479
|
+
skip = False
|
|
3480
|
+
|
|
3481
|
+
for i, j in all_combinations:
|
|
3482
|
+
if not skip:
|
|
3483
|
+
metric_x = arg_result_names[i]
|
|
3484
|
+
metric_y = arg_result_names[j]
|
|
3485
|
+
|
|
3486
|
+
x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
|
|
3487
|
+
y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
|
|
3488
|
+
|
|
3489
|
+
try:
|
|
3490
|
+
if metric_x not in pareto_front_data:
|
|
3491
|
+
pareto_front_data[metric_x] = {}
|
|
3492
|
+
|
|
3493
|
+
pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
|
|
3494
|
+
except ax.exceptions.core.DataRequiredError as e:
|
|
3495
|
+
print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
|
|
3496
|
+
except SignalINT:
|
|
3497
|
+
print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
|
|
3498
|
+
skip = True
|
|
3499
|
+
|
|
3500
|
+
return pareto_front_data
|
|
3501
|
+
|
|
3502
|
+
def pareto_front_transform_objectives(
|
|
3503
|
+
points: List[Tuple[Any, float, float]],
|
|
3504
|
+
primary_name: str,
|
|
3505
|
+
secondary_name: str
|
|
3506
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
3507
|
+
primary_idx = arg_result_names.index(primary_name)
|
|
3508
|
+
secondary_idx = arg_result_names.index(secondary_name)
|
|
3509
|
+
|
|
3510
|
+
x = np.array([p[1] for p in points])
|
|
3511
|
+
y = np.array([p[2] for p in points])
|
|
3512
|
+
|
|
3513
|
+
if arg_result_min_or_max[primary_idx] == "max":
|
|
3514
|
+
x = -x
|
|
3515
|
+
elif arg_result_min_or_max[primary_idx] != "min":
|
|
3516
|
+
raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
|
|
3517
|
+
|
|
3518
|
+
if arg_result_min_or_max[secondary_idx] == "max":
|
|
3519
|
+
y = -y
|
|
3520
|
+
elif arg_result_min_or_max[secondary_idx] != "min":
|
|
3521
|
+
raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
|
|
3522
|
+
|
|
3523
|
+
return x, y
|
|
3524
|
+
|
|
3525
|
+
def get_pareto_frontier_points(
|
|
3526
|
+
path_to_calculate: str,
|
|
3527
|
+
primary_objective: str,
|
|
3528
|
+
secondary_objective: str,
|
|
3529
|
+
x_minimize: bool,
|
|
3530
|
+
y_minimize: bool,
|
|
3531
|
+
absolute_metrics: List[str],
|
|
3532
|
+
num_points: int
|
|
3533
|
+
) -> Optional[dict]:
|
|
3534
|
+
records = pareto_front_aggregate_data(path_to_calculate)
|
|
3535
|
+
|
|
3536
|
+
if records is None:
|
|
3537
|
+
return None
|
|
3538
|
+
|
|
3539
|
+
points = pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
|
|
3540
|
+
x, y = pareto_front_transform_objectives(points, primary_objective, secondary_objective)
|
|
3541
|
+
selected_points = pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
|
|
3542
|
+
result = pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
|
|
3543
|
+
|
|
3544
|
+
return result
|
|
3545
|
+
|
|
3546
|
+
def pareto_front_table_read_csv() -> List[Dict[str, str]]:
|
|
3547
|
+
with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
|
|
3548
|
+
return list(csv.DictReader(f))
|
|
3549
|
+
|
|
3550
|
+
def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
|
|
3551
|
+
table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
|
|
3552
|
+
|
|
3553
|
+
rows = pareto_front_table_read_csv()
|
|
3554
|
+
if not rows:
|
|
3555
|
+
table.add_column("No data found")
|
|
3556
|
+
return table
|
|
3557
|
+
|
|
3558
|
+
filtered_rows = pareto_front_table_filter_rows(rows, idxs)
|
|
3559
|
+
if not filtered_rows:
|
|
3560
|
+
table.add_column("No matching entries")
|
|
3561
|
+
return table
|
|
3562
|
+
|
|
3563
|
+
param_cols, result_cols = pareto_front_table_get_columns(filtered_rows[0])
|
|
3564
|
+
|
|
3565
|
+
pareto_front_table_add_headers(table, param_cols, result_cols)
|
|
3566
|
+
pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
|
|
3567
|
+
|
|
3568
|
+
return table
|
|
3569
|
+
|
|
3570
|
+
def pareto_front_build_return_structure(
|
|
3571
|
+
path_to_calculate: str,
|
|
3572
|
+
selected_points: List[Tuple[Any, float, float]],
|
|
3573
|
+
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
3574
|
+
absolute_metrics: List[str],
|
|
3575
|
+
primary_name: str,
|
|
3576
|
+
secondary_name: str
|
|
3577
|
+
) -> dict:
|
|
3578
|
+
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
3579
|
+
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
3580
|
+
|
|
3581
|
+
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
3582
|
+
result_names = [line.strip() for line in f if line.strip()]
|
|
3583
|
+
|
|
3584
|
+
csv_rows = {}
|
|
3585
|
+
with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
|
|
3586
|
+
reader = csv.DictReader(csvfile)
|
|
3587
|
+
for row in reader:
|
|
3588
|
+
trial_index = int(row['trial_index'])
|
|
3589
|
+
csv_rows[trial_index] = row
|
|
3590
|
+
|
|
3591
|
+
ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
|
|
3592
|
+
ignored_columns.update(result_names)
|
|
3593
|
+
|
|
3594
|
+
param_dicts = []
|
|
3595
|
+
idxs = []
|
|
3596
|
+
means_dict = defaultdict(list)
|
|
3597
|
+
|
|
3598
|
+
for (trial_index, arm_name), _, _ in selected_points:
|
|
3599
|
+
row = csv_rows.get(trial_index, {})
|
|
3600
|
+
if row == {} or row is None or row['arm_name'] != arm_name:
|
|
3601
|
+
continue
|
|
3602
|
+
|
|
3603
|
+
idxs.append(int(row["trial_index"]))
|
|
3604
|
+
|
|
3605
|
+
param_dict: dict[str, int | float | str] = {}
|
|
3606
|
+
for key, value in row.items():
|
|
3607
|
+
if key not in ignored_columns:
|
|
3608
|
+
try:
|
|
3609
|
+
param_dict[key] = int(value)
|
|
3610
|
+
except ValueError:
|
|
3611
|
+
try:
|
|
3612
|
+
param_dict[key] = float(value)
|
|
3613
|
+
except ValueError:
|
|
3614
|
+
param_dict[key] = value
|
|
3615
|
+
|
|
3616
|
+
param_dicts.append(param_dict)
|
|
3617
|
+
|
|
3618
|
+
for metric in absolute_metrics:
|
|
3619
|
+
means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
|
|
3620
|
+
|
|
3621
|
+
ret = {
|
|
3622
|
+
primary_name: {
|
|
3623
|
+
secondary_name: {
|
|
3624
|
+
"absolute_metrics": absolute_metrics,
|
|
3625
|
+
"param_dicts": param_dicts,
|
|
3626
|
+
"means": dict(means_dict),
|
|
3627
|
+
"idxs": idxs
|
|
3628
|
+
},
|
|
3629
|
+
"absolute_metrics": absolute_metrics
|
|
3630
|
+
}
|
|
3631
|
+
}
|
|
3632
|
+
|
|
3633
|
+
return ret
|
|
3634
|
+
|
|
3635
|
+
def pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
|
|
3636
|
+
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
3637
|
+
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
3638
|
+
|
|
3639
|
+
if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
|
|
3640
|
+
return None
|
|
3641
|
+
|
|
3642
|
+
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
3643
|
+
result_names = [line.strip() for line in f if line.strip()]
|
|
3644
|
+
|
|
3645
|
+
records: dict = defaultdict(lambda: {'means': {}})
|
|
3646
|
+
|
|
3647
|
+
with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
|
|
3648
|
+
reader = csv.DictReader(csvfile)
|
|
3649
|
+
for row in reader:
|
|
3650
|
+
trial_index = int(row['trial_index'])
|
|
3651
|
+
arm_name = row['arm_name']
|
|
3652
|
+
key = (trial_index, arm_name)
|
|
3653
|
+
|
|
3654
|
+
for metric in result_names:
|
|
3655
|
+
if metric in row:
|
|
3656
|
+
try:
|
|
3657
|
+
records[key]['means'][metric] = float(row[metric])
|
|
3658
|
+
except ValueError:
|
|
3659
|
+
continue
|
|
3660
|
+
|
|
3661
|
+
return records
|
|
3662
|
+
|
|
3663
|
+
def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
|
|
3664
|
+
if data is None:
|
|
3665
|
+
print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
|
|
3666
|
+
return
|
|
3667
|
+
|
|
3668
|
+
if not supports_sixel():
|
|
3669
|
+
print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
|
|
3670
|
+
return
|
|
3671
|
+
|
|
3672
|
+
import matplotlib.pyplot as plt
|
|
3673
|
+
|
|
3674
|
+
means = data[x_metric][y_metric]["means"]
|
|
3675
|
+
|
|
3676
|
+
x_values = means[x_metric]
|
|
3677
|
+
y_values = means[y_metric]
|
|
3678
|
+
|
|
3679
|
+
fig, _ax = plt.subplots()
|
|
3680
|
+
|
|
3681
|
+
_ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
|
|
3682
|
+
|
|
3683
|
+
_ax.set_xlabel(x_metric)
|
|
3684
|
+
_ax.set_ylabel(y_metric)
|
|
3685
|
+
|
|
3686
|
+
_ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
|
|
3687
|
+
|
|
3688
|
+
_ax.ticklabel_format(style='plain', axis='both', useOffset=False)
|
|
3689
|
+
|
|
3690
|
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
|
|
3691
|
+
plt.savefig(tmp_file.name, dpi=300)
|
|
3692
|
+
|
|
3693
|
+
print_image_to_cli(tmp_file.name, 1000)
|
|
3694
|
+
|
|
3695
|
+
plt.close(fig)
|
|
3696
|
+
|
|
3697
|
+
def pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
|
|
3698
|
+
all_columns = list(first_row.keys())
|
|
3699
|
+
ignored_cols = set(special_col_names) - {"trial_index"}
|
|
3700
|
+
|
|
3701
|
+
param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
|
|
3702
|
+
result_cols = [col for col in arg_result_names if col in all_columns]
|
|
3703
|
+
return param_cols, result_cols
|
|
3704
|
+
|
|
3705
|
+
def check_factorial_range() -> None:
|
|
3706
|
+
if args.model and args.model == "FACTORIAL":
|
|
3707
|
+
_fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
|
|
3708
|
+
|
|
3709
|
+
def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
|
|
3710
|
+
if value_type not in valid_value_types:
|
|
3711
|
+
valid_value_types_string = ", ".join(valid_value_types)
|
|
3712
|
+
_fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
|
|
3713
|
+
|
|
3714
|
+
def check_range_params_length(this_args: Union[str, list]) -> None:
|
|
3715
|
+
if len(this_args) != 5 and len(this_args) != 4 and len(this_args) != 6:
|
|
3716
|
+
_fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
|
|
3717
|
+
|
|
3718
|
+
def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
|
|
3719
|
+
if upper_bound is None or lower_bound is None:
|
|
3720
|
+
_fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
|
|
3721
|
+
if upper_bound == lower_bound:
|
|
3722
|
+
if lower_bound == 0:
|
|
3723
|
+
_fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
|
|
3724
|
+
print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
|
|
3725
|
+
if upper_bound is not None:
|
|
3726
|
+
lower_bound = -upper_bound
|
|
3727
|
+
|
|
3728
|
+
def format_value(value: Any, float_format: str = '.80f') -> str:
|
|
3729
|
+
try:
|
|
3730
|
+
if isinstance(value, float):
|
|
3731
|
+
s = format(value, float_format)
|
|
3732
|
+
s = s.rstrip('0').rstrip('.') if '.' in s else s
|
|
3733
|
+
return s
|
|
3734
|
+
return str(value)
|
|
3735
|
+
except Exception as e:
|
|
3736
|
+
print_red(f"⚠ Error formatting the number {value}: {e}")
|
|
3737
|
+
return str(value)
|
|
3738
|
+
|
|
3739
|
+
def replace_parameters_in_string(
|
|
3740
|
+
parameters: dict,
|
|
3741
|
+
input_string: str,
|
|
3742
|
+
float_format: str = '.20f',
|
|
3743
|
+
additional_prefixes: list[str] = [],
|
|
3744
|
+
additional_patterns: list[str] = [],
|
|
3745
|
+
) -> str:
|
|
3746
|
+
try:
|
|
3747
|
+
prefixes = ['$', '%'] + additional_prefixes
|
|
3748
|
+
patterns = ['{' + 'key' + '}', '(' + '{' + 'key' + '}' + ')'] + additional_patterns
|
|
3749
|
+
|
|
3750
|
+
for key, value in parameters.items():
|
|
3751
|
+
replacement = format_value(value, float_format=float_format)
|
|
3752
|
+
for prefix in prefixes:
|
|
3753
|
+
for pattern in patterns:
|
|
3754
|
+
token = prefix + pattern.format(key=key)
|
|
3755
|
+
input_string = input_string.replace(token, replacement)
|
|
3756
|
+
|
|
3757
|
+
input_string = input_string.replace('\r', ' ').replace('\n', ' ')
|
|
3758
|
+
return input_string
|
|
3759
|
+
|
|
3760
|
+
except Exception as e:
|
|
3761
|
+
print_red(f"\n⚠ Error: {e}")
|
|
3762
|
+
return ""
|
|
3763
|
+
|
|
3764
|
+
def get_memory_usage() -> float:
    """Return the summed RSS (in MiB) of all processes owned by the current user.

    Fixes two issues in the original:
    - it called ``p.memory_info()`` again even though the value was already
      prefetched via ``attrs=['memory_info', 'uids']`` (wasted syscalls);
    - a process exiting between iteration and the re-query raised
      ``psutil.NoSuchProcess`` and aborted the whole sum. Vanished or
      inaccessible processes are now simply skipped (best-effort total).
    """
    user_uid = os.getuid()

    total_rss = 0
    for proc in psutil.process_iter(attrs=['memory_info', 'uids']):
        try:
            info = proc.info
            if info['uids'] is not None and info['uids'].real == user_uid and info['memory_info'] is not None:
                total_rss += info['memory_info'].rss
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

    return float(total_rss / (1024 * 1024))
|
|
3773
|
+
|
|
3774
|
+
class MonitorProcess:
|
|
3775
|
+
def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
|
|
3776
|
+
self.pid = pid
|
|
3777
|
+
self.interval = interval
|
|
3778
|
+
self.running = True
|
|
3779
|
+
self.thread = threading.Thread(target=self._monitor)
|
|
3780
|
+
self.thread.daemon = True
|
|
3781
|
+
|
|
3782
|
+
fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
|
|
3783
|
+
|
|
3784
|
+
    def _monitor(self: Any) -> None:
        """Poll the watched process and append CPU/RAM usage to the run folder's log.

        Loops until ``self.running`` is cleared or the process exits. Each
        pass blocks ~5s for CPU sampling plus ``self.interval`` of sleep.
        """
        try:
            _internal_process = psutil.Process(self.pid)
            while self.running and _internal_process.is_running():
                crf = get_current_run_folder()

                # Only log once a run folder exists to write into.
                if crf and crf != "":
                    log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")

                    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

                    with open(log_file_path, mode="a", encoding="utf-8") as log_file:
                        hostname = socket.gethostname()

                        slurm_job_id = os.getenv("SLURM_JOB_ID")

                        # Disambiguate hosts that run several SLURM jobs.
                        if slurm_job_id:
                            hostname += f"-SLURM-ID-{slurm_job_id}"

                        total_memory = psutil.virtual_memory().total / (1024 * 1024)  # MiB
                        cpu_usage = psutil.cpu_percent(interval=5)  # blocks 5s to sample

                        memory_usage = get_memory_usage()

                        unix_timestamp = int(time.time())

                        log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
                time.sleep(self.interval)
        except psutil.NoSuchProcess:
            # Watched process ended between checks — monitoring simply stops.
            pass
|
|
3421
3814
|
|
|
3422
3815
|
def __enter__(self: Any) -> None:
|
|
@@ -3560,7 +3953,7 @@ def _add_to_csv_acquire_lock(lockfile: str, dir_path: str) -> bool:
|
|
|
3560
3953
|
time.sleep(wait_time)
|
|
3561
3954
|
max_wait -= wait_time
|
|
3562
3955
|
except Exception as e:
|
|
3563
|
-
|
|
3956
|
+
print_red(f"Lock error: {e}")
|
|
3564
3957
|
return False
|
|
3565
3958
|
return False
|
|
3566
3959
|
|
|
@@ -3598,11 +3991,11 @@ def _add_to_csv_rewrite_file(file_path: str, rows: List[list], existing_heading:
|
|
|
3598
3991
|
writer = csv.writer(tmp_file)
|
|
3599
3992
|
writer.writerow(all_headings)
|
|
3600
3993
|
for row in rows[1:]:
|
|
3601
|
-
tmp_file.writerow([
|
|
3994
|
+
tmp_file.writerow([ # type: ignore[attr-defined]
|
|
3602
3995
|
row[existing_heading.index(h)] if h in existing_heading else ""
|
|
3603
3996
|
for h in all_headings
|
|
3604
3997
|
])
|
|
3605
|
-
tmp_file.writerow([
|
|
3998
|
+
tmp_file.writerow([ # type: ignore[attr-defined]
|
|
3606
3999
|
formatted_data[new_heading.index(h)] if h in new_heading else ""
|
|
3607
4000
|
for h in all_headings
|
|
3608
4001
|
])
|
|
@@ -3633,12 +4026,12 @@ def find_file_paths(_text: str) -> List[str]:
|
|
|
3633
4026
|
def check_file_info(file_path: str) -> str:
|
|
3634
4027
|
if not os.path.exists(file_path):
|
|
3635
4028
|
if not args.tests:
|
|
3636
|
-
|
|
4029
|
+
print_red(f"check_file_info: The file {file_path} does not exist.")
|
|
3637
4030
|
return ""
|
|
3638
4031
|
|
|
3639
4032
|
if not os.access(file_path, os.R_OK):
|
|
3640
4033
|
if not args.tests:
|
|
3641
|
-
|
|
4034
|
+
print_red(f"check_file_info: The file {file_path} is not readable.")
|
|
3642
4035
|
return ""
|
|
3643
4036
|
|
|
3644
4037
|
file_stat = os.stat(file_path)
|
|
@@ -3698,7 +4091,7 @@ def write_failed_logs(data_dict: Optional[dict], error_description: str = "") ->
|
|
|
3698
4091
|
data = [list(data_dict.values())]
|
|
3699
4092
|
else:
|
|
3700
4093
|
print_debug("No data_dict provided, writing only error description.")
|
|
3701
|
-
data = [[]]
|
|
4094
|
+
data = [[]]
|
|
3702
4095
|
|
|
3703
4096
|
if error_description:
|
|
3704
4097
|
headers.append('error_description')
|
|
@@ -3752,7 +4145,7 @@ def count_defective_nodes(file_path: Union[str, None] = None, entry: Any = None)
|
|
|
3752
4145
|
return sorted(set(entries))
|
|
3753
4146
|
|
|
3754
4147
|
except Exception as e:
|
|
3755
|
-
|
|
4148
|
+
print_red(f"An error has occurred: {e}")
|
|
3756
4149
|
return []
|
|
3757
4150
|
|
|
3758
4151
|
def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]:
|
|
@@ -3763,7 +4156,7 @@ def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]
|
|
|
3763
4156
|
|
|
3764
4157
|
fool_linter(tmp)
|
|
3765
4158
|
except RuntimeError:
|
|
3766
|
-
|
|
4159
|
+
print_red(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
|
|
3767
4160
|
count_defective_nodes(None, socket.gethostname())
|
|
3768
4161
|
return return_in_case_of_error
|
|
3769
4162
|
except Exception:
|
|
@@ -3937,7 +4330,7 @@ def _write_job_infos_csv_build_headline(parameters_keys: List[str], extra_vars_n
|
|
|
3937
4330
|
"run_time",
|
|
3938
4331
|
"program_string",
|
|
3939
4332
|
*parameters_keys,
|
|
3940
|
-
*arg_result_names,
|
|
4333
|
+
*arg_result_names,
|
|
3941
4334
|
"exit_code",
|
|
3942
4335
|
"signal",
|
|
3943
4336
|
"hostname",
|
|
@@ -4234,6 +4627,8 @@ def evaluate(parameters_with_trial_index: dict) -> Optional[Union[int, float, Di
|
|
|
4234
4627
|
trial_index = parameters_with_trial_index["trial_idx"]
|
|
4235
4628
|
submit_time = parameters_with_trial_index["submit_time"]
|
|
4236
4629
|
|
|
4630
|
+
print(f'Trial-Index: {trial_index}')
|
|
4631
|
+
|
|
4237
4632
|
queue_time = abs(int(time.time()) - int(submit_time))
|
|
4238
4633
|
|
|
4239
4634
|
start_nvidia_smi_thread()
|
|
@@ -4471,7 +4866,7 @@ def replace_string_with_params(input_string: str, params: list) -> str:
|
|
|
4471
4866
|
return replaced_string
|
|
4472
4867
|
except AssertionError as e:
|
|
4473
4868
|
error_text = f"Error in replace_string_with_params: {e}"
|
|
4474
|
-
|
|
4869
|
+
print_red(error_text)
|
|
4475
4870
|
raise
|
|
4476
4871
|
|
|
4477
4872
|
return ""
|
|
@@ -4748,9 +5143,7 @@ def get_sixel_graphics_data(_pd_csv: str, _force: bool = False) -> list:
|
|
|
4748
5143
|
_params = [_command, plot, _tmp, plot_type, tmp_file, _width]
|
|
4749
5144
|
data.append(_params)
|
|
4750
5145
|
except Exception as e:
|
|
4751
|
-
|
|
4752
|
-
print_red(f"Error trying to print {plot_type} to CLI: {e}, {tb}")
|
|
4753
|
-
print_debug(f"Error trying to print {plot_type} to CLI: {e}")
|
|
5146
|
+
print_red(f"Error trying to print {plot_type} to CLI: {e}")
|
|
4754
5147
|
|
|
4755
5148
|
return data
|
|
4756
5149
|
|
|
@@ -4858,7 +5251,7 @@ def process_best_result(res_name: str, print_to_file: bool) -> int:
|
|
|
4858
5251
|
|
|
4859
5252
|
if str(best_result) in [NO_RESULT, None, "None"]:
|
|
4860
5253
|
print_red(f"Best {res_name} could not be determined")
|
|
4861
|
-
return 87
|
|
5254
|
+
return 87 # exit-code: 87
|
|
4862
5255
|
|
|
4863
5256
|
total_str = f"total: {_count_done_jobs(RESULT_CSV_FILE) - NR_INSERTED_JOBS}"
|
|
4864
5257
|
if NR_INSERTED_JOBS:
|
|
@@ -4941,15 +5334,17 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
|
|
|
4941
5334
|
if job:
|
|
4942
5335
|
try:
|
|
4943
5336
|
if ax_client:
|
|
4944
|
-
_trial =
|
|
4945
|
-
_trial
|
|
5337
|
+
_trial = get_ax_client_trial(trial_index)
|
|
5338
|
+
if _trial is None:
|
|
5339
|
+
return False
|
|
5340
|
+
|
|
5341
|
+
mark_abandoned(_trial, reason, trial_index)
|
|
4946
5342
|
print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
|
|
4947
5343
|
global_vars["jobs"].remove((job, trial_index))
|
|
4948
5344
|
else:
|
|
4949
|
-
_fatal_error("ax_client could not be found",
|
|
5345
|
+
_fatal_error("ax_client could not be found", 101)
|
|
4950
5346
|
except Exception as e:
|
|
4951
|
-
|
|
4952
|
-
print_debug(f"ERROR in line {get_line_info()}: {e}")
|
|
5347
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
4953
5348
|
return False
|
|
4954
5349
|
job.cancel()
|
|
4955
5350
|
return True
|
|
@@ -4962,23 +5357,120 @@ def abandon_all_jobs() -> None:
|
|
|
4962
5357
|
if not abandoned:
|
|
4963
5358
|
print_debug(f"Job {job} could not be abandoned.")
|
|
4964
5359
|
|
|
4965
|
-
def
|
|
4966
|
-
if
|
|
4967
|
-
|
|
5360
|
+
def write_result_to_trace_file(res: Optional[str]) -> bool:
    """Write *res* into <run_folder>/optimization_trace.html.

    Returns True only when the content was written and the file closed
    cleanly; all failures are reported on stderr and yield False instead of
    raising.

    Fixed: the annotation said ``res: str`` although the body explicitly
    handles ``res is None`` — it is now ``Optional[str]``.
    """
    if res is None:
        sys.stderr.write("Provided result is None, nothing to write\n")
        return False

    target_folder = get_current_run_folder()
    target_file = os.path.join(target_folder, "optimization_trace.html")

    try:
        file_handle = open(target_file, "w", encoding="utf-8")
    except OSError as error:
        sys.stderr.write("Unable to open target file for writing\n")
        sys.stderr.write(str(error) + "\n")
        return False

    try:
        written = file_handle.write(str(res))
        file_handle.flush()

        # A zero-byte write is treated as a failure so callers can retry.
        if written == 0:
            sys.stderr.write("No data was written to the file\n")
            file_handle.close()
            return False
    except Exception as error:
        sys.stderr.write("Error occurred while writing to file\n")
        sys.stderr.write(str(error) + "\n")
        file_handle.close()
        return False

    try:
        file_handle.close()
    except Exception as error:
        sys.stderr.write("Failed to properly close file\n")
        sys.stderr.write(str(error) + "\n")
        return False

    return True
|
|
5397
|
+
|
|
5398
|
+
def render(plot_config: AxPlotConfig) -> None:
    """Turn an Ax optimization-trace plot config into a standalone HTML snippet and persist it.

    Newer Plotly/Ax output may base64-encode trace arrays as {bdata, dtype}
    objects; the embedded JS decodes those back into plain arrays before
    calling Plotly.newPlot. The result is written via
    write_result_to_trace_file.
    """
    # assumes AxPlotConfig supports "in" membership and has a .data attribute — TODO confirm
    if plot_config is None or "data" not in plot_config:
        return None

    res: str = plot_config.data  # type: ignore

    repair_funcs = """
    function decodeBData(obj) {
        if (!obj || typeof obj !== "object") {
            return obj;
        }

        if (obj.bdata && obj.dtype) {
            var binary_string = atob(obj.bdata);
            var len = binary_string.length;
            var bytes = new Uint8Array(len);

            for (var i = 0; i < len; i++) {
                bytes[i] = binary_string.charCodeAt(i);
            }

            switch (obj.dtype) {
                case "i1": return Array.from(new Int8Array(bytes.buffer));
                case "i2": return Array.from(new Int16Array(bytes.buffer));
                case "i4": return Array.from(new Int32Array(bytes.buffer));
                case "f4": return Array.from(new Float32Array(bytes.buffer));
                case "f8": return Array.from(new Float64Array(bytes.buffer));
                default:
                    console.error("Unknown dtype:", obj.dtype);
                    return [];
            }
        }

        return obj;
    }

    function repairTraces(traces) {
        var fixed = [];

        for (var i = 0; i < traces.length; i++) {
            var t = traces[i];

            if (t.x) {
                t.x = decodeBData(t.x);
            }

            if (t.y) {
                t.y = decodeBData(t.y);
            }

            fixed.push(t);
        }

        return fixed;
    }
    """

    res = str(res)

    # 'const True/False' bridges Python-style booleans that may appear in the JSON blob.
    res = f"<div id='plot' style='width:100%;height:600px;'></div>\n<script type='text/javascript' src='https://cdn.plot.ly/plotly-latest.min.js'></script><script>{repair_funcs}\nconst True = true;\nconst False = false;\nconst data = {res};\ndata.data = repairTraces(data.data);\nPlotly.newPlot(document.getElementById('plot'), data.data, data.layout);</script>"

    write_result_to_trace_file(res)

    return None
|
|
4978
5462
|
|
|
4979
5463
|
def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None) -> None:
|
|
4980
5464
|
global END_PROGRAM_RAN
|
|
4981
5465
|
|
|
5466
|
+
#dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].X)
|
|
5467
|
+
#dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].Y)
|
|
5468
|
+
#dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.outcomes)
|
|
5469
|
+
|
|
5470
|
+
if ax_client is not None:
|
|
5471
|
+
if len(arg_result_names) == 1:
|
|
5472
|
+
render(ax_client.get_optimization_trace())
|
|
5473
|
+
|
|
4982
5474
|
wait_for_jobs_to_complete()
|
|
4983
5475
|
|
|
4984
5476
|
show_pareto_or_error_msg(get_current_run_folder(), arg_result_names)
|
|
@@ -5012,7 +5504,7 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5012
5504
|
_exit = new_exit
|
|
5013
5505
|
except (SignalUSR, SignalINT, SignalCONT, KeyboardInterrupt):
|
|
5014
5506
|
print_red("\n⚠ You pressed CTRL+C or a signal was sent. Program execution halted while ending program.")
|
|
5015
|
-
|
|
5507
|
+
print_red("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
|
|
5016
5508
|
new_exit = show_end_table_and_save_end_files()
|
|
5017
5509
|
if new_exit > 0:
|
|
5018
5510
|
_exit = new_exit
|
|
@@ -5021,8 +5513,6 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5021
5513
|
|
|
5022
5514
|
abandon_all_jobs()
|
|
5023
5515
|
|
|
5024
|
-
save_results_csv()
|
|
5025
|
-
|
|
5026
5516
|
if exit_code:
|
|
5027
5517
|
_exit = exit_code
|
|
5028
5518
|
|
|
@@ -5035,21 +5525,31 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5035
5525
|
|
|
5036
5526
|
my_exit(_exit)
|
|
5037
5527
|
|
|
5528
|
+
def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
    """Persist the global ax_client to *checkpoint_filepath*; exits with code 101 when no ax_client exists."""
    if ax_client:
        ax_client.save_to_json_file(checkpoint_filepath)
        return None

    my_exit(101)
    return None
|
|
5537
|
+
|
|
5038
5538
|
def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
    """Save the ax_client checkpoint, retrying via recursion up to three times.

    *trial_nr* counts retries so far; *eee* carries the last error so the
    final give-up message can include it.
    """
    if trial_nr > 3:
        # Too many failed attempts — report the last error (if any) and stop.
        if eee:
            print_red(f"Error during saving checkpoint: {eee}")
        else:
            print_red("Error during saving checkpoint")
        return

    try:
        checkpoint_filepath = get_state_file_name('checkpoint.json')

        if not ax_client:
            _fatal_error("Something went wrong using the ax_client", 101)
        else:
            save_ax_client_to_json_file(checkpoint_filepath)
    except Exception as e:
        # Retry with the retry counter bumped and the error recorded.
        save_checkpoint(trial_nr + 1, e)
|
|
5055
5555
|
|
|
@@ -5210,7 +5710,7 @@ def parse_equation_item(comparer_found: bool, item: str, parsed: list, parsed_or
|
|
|
5210
5710
|
})
|
|
5211
5711
|
elif item in [">=", "<="]:
|
|
5212
5712
|
if comparer_found:
|
|
5213
|
-
|
|
5713
|
+
print_red("There is already one comparison operator! Cannot have more than one in an equation!")
|
|
5214
5714
|
return_totally = True
|
|
5215
5715
|
comparer_found = True
|
|
5216
5716
|
|
|
@@ -5421,20 +5921,9 @@ def check_equation(variables: list, equation: str) -> Union[str, bool]:
|
|
|
5421
5921
|
def set_objectives() -> dict:
|
|
5422
5922
|
objectives = {}
|
|
5423
5923
|
|
|
5424
|
-
|
|
5425
|
-
|
|
5426
|
-
|
|
5427
|
-
if "=" in rn:
|
|
5428
|
-
key, value = rn.split('=', 1)
|
|
5429
|
-
else:
|
|
5430
|
-
key = rn
|
|
5431
|
-
value = ""
|
|
5432
|
-
|
|
5433
|
-
if value not in ["min", "max"]:
|
|
5434
|
-
if value:
|
|
5435
|
-
print_yellow(f"Value '{value}' for --result_names {rn} is not a valid value. Must be min or max. Will be set to min.")
|
|
5436
|
-
|
|
5437
|
-
value = "min"
|
|
5924
|
+
k = 0
|
|
5925
|
+
for key in arg_result_names:
|
|
5926
|
+
value = arg_result_min_or_max[k]
|
|
5438
5927
|
|
|
5439
5928
|
_min = True
|
|
5440
5929
|
|
|
@@ -5442,12 +5931,18 @@ def set_objectives() -> dict:
|
|
|
5442
5931
|
_min = False
|
|
5443
5932
|
|
|
5444
5933
|
objectives[key] = ObjectiveProperties(minimize=_min)
|
|
5934
|
+
k = k + 1
|
|
5445
5935
|
|
|
5446
5936
|
return objectives
|
|
5447
5937
|
|
|
5448
|
-
def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Union[dict, list]) -> dict:
|
|
5449
|
-
if
|
|
5938
|
+
def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Optional[Union[dict, list]]) -> dict:
|
|
5939
|
+
if _experiment_parameters is None:
|
|
5940
|
+
print_red("set_experiment_constraints: _experiment_parameters was None")
|
|
5941
|
+
my_exit(95)
|
|
5942
|
+
|
|
5943
|
+
return {}
|
|
5450
5944
|
|
|
5945
|
+
if experiment_constraints and len(experiment_constraints):
|
|
5451
5946
|
experiment_args["parameter_constraints"] = []
|
|
5452
5947
|
|
|
5453
5948
|
if experiment_constraints:
|
|
@@ -5476,11 +5971,15 @@ def set_experiment_constraints(experiment_constraints: Optional[list], experimen
|
|
|
5476
5971
|
|
|
5477
5972
|
return experiment_args
|
|
5478
5973
|
|
|
5479
|
-
def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[list]) -> None:
|
|
5974
|
+
def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[dict | list]) -> None:
|
|
5975
|
+
if not experiment_parameters:
|
|
5976
|
+
print_red("replace_parameters_for_continued_jobs: experiment_parameters was False")
|
|
5977
|
+
return None
|
|
5978
|
+
|
|
5480
5979
|
if args.worker_generator_path:
|
|
5481
5980
|
return None
|
|
5482
5981
|
|
|
5483
|
-
def get_name(obj) -> Optional[str]:
|
|
5982
|
+
def get_name(obj: Any) -> Optional[str]:
|
|
5484
5983
|
"""Extract a parameter name from dict, list, or tuple safely."""
|
|
5485
5984
|
if isinstance(obj, dict):
|
|
5486
5985
|
return obj.get("name")
|
|
@@ -5562,13 +6061,13 @@ def copy_continue_uuid() -> None:
|
|
|
5562
6061
|
print_debug(f"copy_continue_uuid: Source file does not exist: {source_file}")
|
|
5563
6062
|
|
|
5564
6063
|
def load_ax_client_from_experiment_parameters() -> None:
    """Rebuild the global ax_client from experiment_parameters via a temporary JSON file; no-op when empty."""
    global ax_client

    if not experiment_parameters:
        return

    json_path = get_tmp_file_from_json(experiment_parameters)
    ax_client = cast(AxClient, AxClient.load_from_json_file(json_path))
    os.unlink(json_path)
|
|
5572
6071
|
|
|
5573
6072
|
def save_checkpoint_for_continued() -> None:
|
|
5574
6073
|
checkpoint_filepath = get_state_file_name('checkpoint.json')
|
|
@@ -5580,12 +6079,15 @@ def save_checkpoint_for_continued() -> None:
|
|
|
5580
6079
|
_fatal_error(f"{checkpoint_filepath} not found. Cannot continue_previous_job without.", 47)
|
|
5581
6080
|
|
|
5582
6081
|
def load_original_generation_strategy(original_ax_client_file: str) -> None:
    """Copy the generation_strategy from the original checkpoint JSON into experiment_parameters.

    Does nothing when experiment_parameters is empty/None.

    Fixed: the failure message previously claimed "experiment_parameters was
    empty", but that branch is only reachable when experiment_parameters is
    truthy — the actual problem is a missing generation_strategy in the
    loaded checkpoint.
    """
    if experiment_parameters:
        with open(original_ax_client_file, encoding="utf-8") as f:
            loaded_original_ax_client_json = json.load(f)

        original_generation_strategy = loaded_original_ax_client_json["generation_strategy"]

        if original_generation_strategy:
            experiment_parameters["generation_strategy"] = original_generation_strategy
        else:
            print_red("load_original_generation_strategy: checkpoint contained no generation_strategy!")
|
|
5589
6091
|
|
|
5590
6092
|
def wait_for_checkpoint_file(checkpoint_file: str) -> None:
|
|
5591
6093
|
start_time = time.time()
|
|
@@ -5598,10 +6100,6 @@ def wait_for_checkpoint_file(checkpoint_file: str) -> None:
|
|
|
5598
6100
|
elapsed = int(time.time() - start_time)
|
|
5599
6101
|
console.print(f"[green]Checkpoint file found after {elapsed} seconds[/green] ")
|
|
5600
6102
|
|
|
5601
|
-
def __get_experiment_parameters__check_ax_client() -> None:
|
|
5602
|
-
if not ax_client:
|
|
5603
|
-
_fatal_error("Something went wrong with the ax_client", 9)
|
|
5604
|
-
|
|
5605
6103
|
def validate_experiment_parameters() -> None:
|
|
5606
6104
|
if experiment_parameters is None:
|
|
5607
6105
|
print_red("Error: experiment_parameters is None.")
|
|
@@ -5611,6 +6109,8 @@ def validate_experiment_parameters() -> None:
|
|
|
5611
6109
|
print_red(f"Error: experiment_parameters is not a dict: {type(experiment_parameters).__name__}")
|
|
5612
6110
|
my_exit(95)
|
|
5613
6111
|
|
|
6112
|
+
sys.exit(95)
|
|
6113
|
+
|
|
5614
6114
|
path_checks = [
|
|
5615
6115
|
("experiment", experiment_parameters),
|
|
5616
6116
|
("search_space", experiment_parameters.get("experiment")),
|
|
@@ -5622,7 +6122,12 @@ def validate_experiment_parameters() -> None:
|
|
|
5622
6122
|
print_red(f"Error: Missing key '{key}' at level: {current_level}")
|
|
5623
6123
|
my_exit(95)
|
|
5624
6124
|
|
|
5625
|
-
def
|
|
6125
|
+
def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_parameters: Optional[dict | list]) -> Tuple[Any, str, str]:
|
|
6126
|
+
if not ax_client:
|
|
6127
|
+
print_red("load_from_checkpoint: ax_client was None")
|
|
6128
|
+
my_exit(101)
|
|
6129
|
+
return {}, "", ""
|
|
6130
|
+
|
|
5626
6131
|
print_debug(f"Load from checkpoint: {continue_previous_job}")
|
|
5627
6132
|
|
|
5628
6133
|
checkpoint_file = f"{continue_previous_job}/state_files/checkpoint.json"
|
|
@@ -5648,7 +6153,7 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
|
|
|
5648
6153
|
|
|
5649
6154
|
replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)
|
|
5650
6155
|
|
|
5651
|
-
|
|
6156
|
+
save_ax_client_to_json_file(original_ax_client_file)
|
|
5652
6157
|
|
|
5653
6158
|
load_original_generation_strategy(original_ax_client_file)
|
|
5654
6159
|
load_ax_client_from_experiment_parameters()
|
|
@@ -5664,6 +6169,12 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
|
|
|
5664
6169
|
|
|
5665
6170
|
experiment_constraints = get_constraints()
|
|
5666
6171
|
if experiment_constraints:
|
|
6172
|
+
|
|
6173
|
+
if not experiment_parameters:
|
|
6174
|
+
print_red("load_from_checkpoint: experiment_parameters was None")
|
|
6175
|
+
|
|
6176
|
+
return {}, "", ""
|
|
6177
|
+
|
|
5667
6178
|
experiment_args = set_experiment_constraints(
|
|
5668
6179
|
experiment_constraints,
|
|
5669
6180
|
experiment_args,
|
|
@@ -5672,7 +6183,116 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
|
|
|
5672
6183
|
|
|
5673
6184
|
return experiment_args, gpu_string, gpu_color
|
|
5674
6185
|
|
|
5675
|
-
def
|
|
6186
|
+
def get_experiment_args_import_python_script() -> str:
    """Return the import header used by the generated standalone debug script."""
    header = """from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.adapter.registry import Generators
import random

"""
    return header
|
|
6193
|
+
|
|
6194
|
+
def get_generate_and_test_random_function_str() -> str:
    """Build source for a helper that fetches trials from ax_client and completes them with random data.

    One ``"<name>": random.uniform(0, 1)`` entry is emitted per configured
    result name; the generated call loops ``args.num_random_steps + 1`` times.
    """
    raw_data_entries = ",\n ".join(
        f'"{name}": random.uniform(0, 1)' for name in arg_result_names
    )

    return f"""
def generate_and_test_random_parameters(n: int) -> None:
    for _ in range(n):
        print("======================================")
        parameters, trial_index = ax_client.get_next_trial()
        print("Trial Index:", trial_index)
        print("Suggested parameters:", parameters)

        ax_client.complete_trial(
            trial_index=trial_index,
            raw_data={{
                {raw_data_entries}
            }}
        )

generate_and_test_random_parameters({args.num_random_steps + 1})
"""
|
|
6216
|
+
|
|
6217
|
+
def get_global_gs_string() -> str:
    """Build source for a two-step GenerationStrategy: a SOBOL warmup of num_random_steps trials (optionally seeded), then the configured model for the remaining trials."""
    seed_str = ""
    if args.seed is not None:
        # Only pin the Sobol seed when the user asked for reproducibility.
        seed_str = f"model_kwargs={{'seed': {args.seed}}},"

    return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy

global_gs = GenerationStrategy(
    steps=[
        GenerationStep(
            generator=Generators.SOBOL,
            num_trials={args.num_random_steps},
            max_parallelism=5,
            {seed_str}
        ),
        GenerationStep(
            generator=Generators.{args.model},
            num_trials=-1,
            max_parallelism=5,
        ),
    ]
)
"""
|
|
6240
|
+
|
|
6241
|
+
def get_debug_ax_client_str() -> str:
    """Return the source that instantiates a verbose AxClient wired to global_gs."""
    snippet = """
ax_client = AxClient(
    verbose_logging=True,
    enforce_sequential_optimization=False,
    generation_strategy=global_gs
)
"""
    return snippet
|
|
6249
|
+
|
|
6250
|
+
def write_ax_debug_python_code(experiment_args: dict) -> None:
    """Write a standalone debug.py into the run folder that reproduces the experiment setup.

    Skipped for custom generation strategies and for models that cannot be
    reconstructed outside the main run. Write failures are reported, not
    raised.

    Fixed: the original contained a duplicated assignment
    ``python_code = python_code = ...``.
    """
    if args.generation_strategy:
        print_debug("Cannot write debug code for custom generation_strategy")
        return None

    if args.model in uncontinuable_models:
        print_debug(f"Cannot write debug code for uncontinuable mode {args.model}")
        return None

    python_code = get_experiment_args_import_python_script() + \
        get_global_gs_string() + \
        get_debug_ax_client_str() + \
        "experiment_args = " + pformat(experiment_args, width=120, compact=False) + \
        "\nax_client.create_experiment(**experiment_args)\n" + \
        get_generate_and_test_random_function_str()

    file_path = f"{get_current_run_folder()}/debug.py"

    try:
        print_debug(python_code)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(python_code)
    except Exception as e:
        print_red(f"Error while writing {file_path}: {e}")

    return None
|
|
6276
|
+
|
|
6277
|
+
def create_ax_client_experiment(experiment_args: dict) -> None:
    """Dump the reproduction debug script, then create the Ax experiment; exits with 101 when ax_client is missing."""
    if ax_client:
        write_ax_debug_python_code(experiment_args)

        ax_client.create_experiment(**experiment_args)
        return None

    my_exit(101)
    return None
|
|
6288
|
+
|
|
6289
|
+
def create_new_experiment() -> Tuple[dict, str, str]:
|
|
6290
|
+
if ax_client is None:
|
|
6291
|
+
print_red("create_new_experiment: ax_client is None")
|
|
6292
|
+
my_exit(101)
|
|
6293
|
+
|
|
6294
|
+
return {}, "", ""
|
|
6295
|
+
|
|
5676
6296
|
objectives = set_objectives()
|
|
5677
6297
|
|
|
5678
6298
|
experiment_args = {
|
|
@@ -5695,7 +6315,7 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
|
|
|
5695
6315
|
experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)
|
|
5696
6316
|
|
|
5697
6317
|
try:
|
|
5698
|
-
|
|
6318
|
+
create_ax_client_experiment(experiment_args)
|
|
5699
6319
|
new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
|
|
5700
6320
|
ax_client.experiment.add_tracking_metrics(new_metrics)
|
|
5701
6321
|
except AssertionError as error:
|
|
@@ -5709,17 +6329,17 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
|
|
|
5709
6329
|
|
|
5710
6330
|
return experiment_args, gpu_string, gpu_color
|
|
5711
6331
|
|
|
5712
|
-
def get_experiment_parameters(cli_params_experiment_parameters: Optional[list]) -> Optional[Tuple[
|
|
6332
|
+
def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
    """Produce the experiment setup, either by resuming a previous job's checkpoint or by creating a fresh experiment.

    Returns a (experiment_args, gpu_string, gpu_color) tuple.
    """
    resume_path = args.worker_generator_path or args.continue_previous_job

    check_ax_client()

    if resume_path:
        result = load_from_checkpoint(resume_path, cli_params_experiment_parameters)
    else:
        result = create_new_experiment()

    experiment_args, gpu_string, gpu_color = result
    return experiment_args, gpu_string, gpu_color
|
|
5723
6343
|
|
|
5724
6344
|
def get_type_short(typename: str) -> str:
|
|
5725
6345
|
if typename == "RangeParameter":
|
|
@@ -5778,7 +6398,6 @@ def parse_single_experiment_parameter_table(classic_params: Optional[Union[list,
|
|
|
5778
6398
|
_upper = param["bounds"][1]
|
|
5779
6399
|
|
|
5780
6400
|
_possible_int_lower = str(helpers.to_int_when_possible(_lower))
|
|
5781
|
-
#print(f"name: {_name}, _possible_int_lower: {_possible_int_lower}, lower: {_lower}")
|
|
5782
6401
|
_possible_int_upper = str(helpers.to_int_when_possible(_upper))
|
|
5783
6402
|
|
|
5784
6403
|
rows.append([_name, _short_type, _possible_int_lower, _possible_int_upper, "", value_type, log_scale])
|
|
@@ -5855,19 +6474,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:
|
|
|
5855
6474
|
|
|
5856
6475
|
return None
|
|
5857
6476
|
|
|
6477
|
+
def check_base_for_print_overview() -> Optional[bool]:
    """Validate that ax_client is fully initialized before printing overviews.

    Warns when --result_names would be ignored for a continued job. Returns
    True when everything needed exists, None (with a red message) otherwise.
    """
    is_continued = args.continue_previous_job is not None
    has_result_names = arg_result_names is not None and len(arg_result_names) != 0
    has_original_names = original_result_names is not None and len(original_result_names) != 0

    if is_continued and has_result_names and has_original_names:
        print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")

    if ax_client is None:
        print_red("ax_client was None")
        return None

    experiment = ax_client.experiment
    if experiment is None:
        print_red("ax_client.experiment was None")
        return None

    if experiment.optimization_config is None:
        print_red("ax_client.experiment.optimization_config was None")
        return None

    return True
|
|
6494
|
+
|
|
6495
|
+
def get_config_objectives() -> Any:
    """Return the objectives from the ax_client optimization config, or None when unavailable.

    Multi-objective configs yield the list of sub-objectives; single-objective
    configs are wrapped in a one-element list for uniform handling. Exits with
    code 101 when no ax_client exists at all.

    Fixed: the missing-ax_client message was copy-pasted from
    create_new_experiment and named the wrong function.
    """
    if not ax_client:
        print_red("get_config_objectives: ax_client is None")
        my_exit(101)

        return None

    config_objectives = None

    if ax_client.experiment and ax_client.experiment.optimization_config:
        opt_config = ax_client.experiment.optimization_config
        if opt_config.is_moo_problem:
            objective = getattr(opt_config, "objective", None)
            if objective and getattr(objective, "objectives", None) is not None:
                config_objectives = objective.objectives
            else:
                print_debug("ax_client.experiment.optimization_config.objective was None")
        else:
            config_objectives = [opt_config.objective]
    else:
        print_debug("ax_client.experiment or optimization_config was None")

    return config_objectives
|
|
6518
|
+
|
|
5858
6519
|
def print_result_names_overview_table() -> None:
|
|
5859
6520
|
if not ax_client:
|
|
5860
6521
|
_fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)
|
|
5861
6522
|
|
|
5862
6523
|
return None
|
|
5863
6524
|
|
|
5864
|
-
if
|
|
5865
|
-
|
|
6525
|
+
if check_base_for_print_overview() is None:
|
|
6526
|
+
return None
|
|
5866
6527
|
|
|
5867
|
-
|
|
5868
|
-
|
|
5869
|
-
|
|
5870
|
-
config_objectives
|
|
6528
|
+
config_objectives = get_config_objectives()
|
|
6529
|
+
|
|
6530
|
+
if config_objectives is None:
|
|
6531
|
+
print_red("config_objectives not found")
|
|
6532
|
+
return None
|
|
5871
6533
|
|
|
5872
6534
|
res_names = []
|
|
5873
6535
|
res_min_max = []
|
|
@@ -5962,28 +6624,31 @@ def print_overview_tables(classic_params: Optional[Union[list, dict]], experimen
|
|
|
5962
6624
|
print_result_names_overview_table()
|
|
5963
6625
|
|
|
5964
6626
|
def update_progress_bar(nr: int) -> None:
|
|
5965
|
-
|
|
5966
|
-
|
|
5967
|
-
|
|
5968
|
-
|
|
6627
|
+
log_data()
|
|
6628
|
+
|
|
6629
|
+
if progress_bar is not None:
|
|
6630
|
+
try:
|
|
6631
|
+
progress_bar.update(nr)
|
|
6632
|
+
except Exception as e:
|
|
6633
|
+
print_red(f"Error updating progress bar: {e}")
|
|
6634
|
+
else:
|
|
6635
|
+
print_red("update_progress_bar: progress_bar was None")
|
|
5969
6636
|
|
|
5970
6637
|
def get_current_model_name() -> str:
|
|
5971
6638
|
if overwritten_to_random:
|
|
5972
6639
|
return "Random*"
|
|
5973
6640
|
|
|
6641
|
+
gs_model = "unknown model"
|
|
6642
|
+
|
|
5974
6643
|
if ax_client:
|
|
5975
6644
|
try:
|
|
5976
6645
|
if args.generation_strategy:
|
|
5977
|
-
idx = getattr(
|
|
6646
|
+
idx = getattr(global_gs, "current_step_index", None)
|
|
5978
6647
|
if isinstance(idx, int):
|
|
5979
6648
|
if 0 <= idx < len(generation_strategy_names):
|
|
5980
6649
|
gs_model = generation_strategy_names[int(idx)]
|
|
5981
|
-
else:
|
|
5982
|
-
gs_model = "unknown model"
|
|
5983
|
-
else:
|
|
5984
|
-
gs_model = "unknown model"
|
|
5985
6650
|
else:
|
|
5986
|
-
gs_model = getattr(
|
|
6651
|
+
gs_model = getattr(global_gs, "current_node_name", "unknown model")
|
|
5987
6652
|
|
|
5988
6653
|
if gs_model:
|
|
5989
6654
|
return str(gs_model)
|
|
@@ -6040,7 +6705,7 @@ def get_workers_string() -> str:
|
|
|
6040
6705
|
)
|
|
6041
6706
|
|
|
6042
6707
|
def _get_workers_string_collect_stats() -> dict:
|
|
6043
|
-
stats = {}
|
|
6708
|
+
stats: dict = {}
|
|
6044
6709
|
for job, _ in global_vars["jobs"][:]:
|
|
6045
6710
|
state = state_from_job(job)
|
|
6046
6711
|
stats[state] = stats.get(state, 0) + 1
|
|
@@ -6089,8 +6754,8 @@ def submitted_jobs(nr: int = 0) -> int:
|
|
|
6089
6754
|
def count_jobs_in_squeue() -> tuple[int, str]:
|
|
6090
6755
|
global _last_count_time, _last_count_result
|
|
6091
6756
|
|
|
6092
|
-
now = time.time()
|
|
6093
|
-
if _last_count_result != (0, "") and now - _last_count_time <
|
|
6757
|
+
now = int(time.time())
|
|
6758
|
+
if _last_count_result != (0, "") and now - _last_count_time < 5:
|
|
6094
6759
|
return _last_count_result
|
|
6095
6760
|
|
|
6096
6761
|
_len = len(global_vars["jobs"])
|
|
@@ -6112,6 +6777,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
|
|
|
6112
6777
|
check=True,
|
|
6113
6778
|
text=True
|
|
6114
6779
|
)
|
|
6780
|
+
|
|
6115
6781
|
if "slurm_load_jobs error" in result.stderr:
|
|
6116
6782
|
_last_count_result = (_len, "Detected slurm_load_jobs error in stderr.")
|
|
6117
6783
|
_last_count_time = now
|
|
@@ -6152,6 +6818,8 @@ def log_worker_numbers() -> None:
|
|
|
6152
6818
|
if len(WORKER_PERCENTAGE_USAGE) == 0 or WORKER_PERCENTAGE_USAGE[len(WORKER_PERCENTAGE_USAGE) - 1] != this_values:
|
|
6153
6819
|
WORKER_PERCENTAGE_USAGE.append(this_values)
|
|
6154
6820
|
|
|
6821
|
+
write_worker_usage()
|
|
6822
|
+
|
|
6155
6823
|
def get_slurm_in_brackets(in_brackets: list) -> list:
|
|
6156
6824
|
if is_slurm_job():
|
|
6157
6825
|
workers_strings = get_workers_string()
|
|
@@ -6236,6 +6904,8 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
|
|
|
6236
6904
|
global last_progress_bar_desc
|
|
6237
6905
|
global last_progress_bar_refresh_time
|
|
6238
6906
|
|
|
6907
|
+
log_data()
|
|
6908
|
+
|
|
6239
6909
|
if isinstance(new_msgs, str):
|
|
6240
6910
|
new_msgs = [new_msgs]
|
|
6241
6911
|
|
|
@@ -6253,7 +6923,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
|
|
|
6253
6923
|
print_red("Cannot update progress bar! It is None.")
|
|
6254
6924
|
|
|
6255
6925
|
def clean_completed_jobs() -> None:
|
|
6256
|
-
job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
|
|
6926
|
+
job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
|
|
6257
6927
|
job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]
|
|
6258
6928
|
|
|
6259
6929
|
for job, trial_index in global_vars["jobs"][:]:
|
|
@@ -6311,7 +6981,7 @@ def load_existing_job_data_into_ax_client() -> None:
|
|
|
6311
6981
|
nr_of_imported_jobs = get_nr_of_imported_jobs()
|
|
6312
6982
|
set_nr_inserted_jobs(NR_INSERTED_JOBS + nr_of_imported_jobs)
|
|
6313
6983
|
|
|
6314
|
-
def parse_parameter_type_error(_error_message: Union[str, None]) -> Optional[dict]:
|
|
6984
|
+
def parse_parameter_type_error(_error_message: Union[Exception, str, None]) -> Optional[dict]:
|
|
6315
6985
|
if not _error_message:
|
|
6316
6986
|
return None
|
|
6317
6987
|
|
|
@@ -6378,7 +7048,7 @@ def get_generation_node_for_index(
|
|
|
6378
7048
|
results_list: List[Dict[str, Any]],
|
|
6379
7049
|
index: int,
|
|
6380
7050
|
__status: Any,
|
|
6381
|
-
base_str: str
|
|
7051
|
+
base_str: Optional[str]
|
|
6382
7052
|
) -> str:
|
|
6383
7053
|
__status.update(f"{base_str}: Getting generation node")
|
|
6384
7054
|
try:
|
|
@@ -6398,7 +7068,7 @@ def get_generation_node_for_index(
|
|
|
6398
7068
|
|
|
6399
7069
|
return generation_node
|
|
6400
7070
|
except Exception as e:
|
|
6401
|
-
|
|
7071
|
+
print_red(f"Error while get_generation_node_for_index: {e}")
|
|
6402
7072
|
return "MANUAL"
|
|
6403
7073
|
|
|
6404
7074
|
def _get_generation_node_for_index_index_valid(
|
|
@@ -6461,111 +7131,154 @@ def _get_generation_node_for_index_floats_match(
|
|
|
6461
7131
|
return False
|
|
6462
7132
|
return abs(row_val_num - val) <= tolerance
|
|
6463
7133
|
|
|
7134
|
+
def validate_and_convert_params_for_jobs_from_csv(arm_params: Dict) -> Dict:
|
|
7135
|
+
corrected_params: Dict[Any, Any] = {}
|
|
7136
|
+
|
|
7137
|
+
if experiment_parameters is not None:
|
|
7138
|
+
for param in experiment_parameters:
|
|
7139
|
+
name = param["name"]
|
|
7140
|
+
expected_type = param.get("value_type", "str")
|
|
7141
|
+
|
|
7142
|
+
if name not in arm_params:
|
|
7143
|
+
continue
|
|
7144
|
+
|
|
7145
|
+
value = arm_params[name]
|
|
7146
|
+
|
|
7147
|
+
try:
|
|
7148
|
+
if param["type"] == "range":
|
|
7149
|
+
if expected_type == "int":
|
|
7150
|
+
corrected_params[name] = int(value)
|
|
7151
|
+
elif expected_type == "float":
|
|
7152
|
+
corrected_params[name] = float(value)
|
|
7153
|
+
elif param["type"] == "choice":
|
|
7154
|
+
corrected_params[name] = str(value)
|
|
7155
|
+
except (ValueError, TypeError):
|
|
7156
|
+
corrected_params[name] = None
|
|
7157
|
+
|
|
7158
|
+
return corrected_params
|
|
7159
|
+
|
|
6464
7160
|
def insert_jobs_from_csv(this_csv_file_path: str) -> None:
|
|
6465
7161
|
with spinner(f"Inserting job into CSV from {this_csv_file_path}") as __status:
|
|
6466
|
-
this_csv_file_path = this_csv_file_path
|
|
7162
|
+
this_csv_file_path = normalize_path(this_csv_file_path)
|
|
6467
7163
|
|
|
6468
|
-
if not
|
|
7164
|
+
if not helpers.file_exists(this_csv_file_path):
|
|
6469
7165
|
print_red(f"--load_data_from_existing_jobs: Cannot find {this_csv_file_path}")
|
|
6470
|
-
|
|
6471
7166
|
return
|
|
6472
7167
|
|
|
6473
|
-
|
|
6474
|
-
|
|
7168
|
+
arm_params_list, results_list = parse_csv(this_csv_file_path)
|
|
7169
|
+
insert_jobs_from_lists(this_csv_file_path, arm_params_list, results_list, __status)
|
|
6475
7170
|
|
|
6476
|
-
|
|
6477
|
-
|
|
6478
|
-
name = param["name"]
|
|
6479
|
-
expected_type = param.get("value_type", "str")
|
|
7171
|
+
def normalize_path(file_path: str) -> str:
|
|
7172
|
+
return file_path.replace("//", "/")
|
|
6480
7173
|
|
|
6481
|
-
|
|
6482
|
-
|
|
7174
|
+
def insert_jobs_from_lists(csv_path: str, arm_params_list: Any, results_list: Any, __status: Any) -> None:
|
|
7175
|
+
cnt = 0
|
|
7176
|
+
err_msgs: list = []
|
|
6483
7177
|
|
|
6484
|
-
|
|
7178
|
+
for i, (arm_params, result) in enumerate(zip(arm_params_list, results_list)):
|
|
7179
|
+
base_str = f"[bold green]Loading job {i}/{len(results_list)} from {csv_path} into ax_client, result: {result}"
|
|
7180
|
+
__status.update(base_str)
|
|
6485
7181
|
|
|
6486
|
-
|
|
6487
|
-
|
|
6488
|
-
if expected_type == "int":
|
|
6489
|
-
corrected_params[name] = int(value)
|
|
6490
|
-
elif expected_type == "float":
|
|
6491
|
-
corrected_params[name] = float(value)
|
|
6492
|
-
elif param["type"] == "choice":
|
|
6493
|
-
corrected_params[name] = str(value)
|
|
6494
|
-
except (ValueError, TypeError):
|
|
6495
|
-
corrected_params[name] = None
|
|
6496
|
-
|
|
6497
|
-
return corrected_params
|
|
7182
|
+
if not args.worker_generator_path:
|
|
7183
|
+
arm_params = validate_and_convert_params_for_jobs_from_csv(arm_params)
|
|
6498
7184
|
|
|
6499
|
-
arm_params_list, results_list
|
|
7185
|
+
cnt = try_insert_job(csv_path, arm_params, result, i, arm_params_list, results_list, __status, base_str, cnt, err_msgs)
|
|
6500
7186
|
|
|
6501
|
-
|
|
7187
|
+
summarize_insertions(csv_path, cnt)
|
|
7188
|
+
update_global_job_counters(cnt)
|
|
6502
7189
|
|
|
6503
|
-
|
|
7190
|
+
def try_insert_job(csv_path: str, arm_params: Dict, result: Any, i: int, arm_params_list: Any, results_list: Any, __status: Any, base_str: Optional[str], cnt: int, err_msgs: Optional[Union[str, list[str]]]) -> int:
|
|
7191
|
+
try:
|
|
7192
|
+
gen_node_name = get_generation_node_for_index(csv_path, arm_params_list, results_list, i, __status, base_str)
|
|
6504
7193
|
|
|
6505
|
-
|
|
6506
|
-
|
|
6507
|
-
|
|
6508
|
-
__status.update(base_str)
|
|
6509
|
-
if not args.worker_generator_path:
|
|
6510
|
-
arm_params = validate_and_convert_params(arm_params)
|
|
7194
|
+
if not result:
|
|
7195
|
+
print_yellow("Encountered job without a result")
|
|
7196
|
+
return cnt
|
|
6511
7197
|
|
|
6512
|
-
|
|
6513
|
-
|
|
7198
|
+
if insert_job_into_ax_client(arm_params, result, gen_node_name, __status, base_str):
|
|
7199
|
+
cnt += 1
|
|
7200
|
+
print_debug(f"Inserted one job from {csv_path}, arm_params: {arm_params}, results: {result}")
|
|
7201
|
+
else:
|
|
7202
|
+
print_red(f"Failed to insert one job from {csv_path}, arm_params: {arm_params}, results: {result}")
|
|
6514
7203
|
|
|
6515
|
-
|
|
6516
|
-
|
|
6517
|
-
|
|
7204
|
+
except ValueError as e:
|
|
7205
|
+
err_msg = (
|
|
7206
|
+
f"Failed to insert job(s) from {csv_path} into ax_client. "
|
|
7207
|
+
f"This can happen when the csv file has different parameters or results as the main job one's "
|
|
7208
|
+
f"or other imported jobs. Error: {e}"
|
|
7209
|
+
)
|
|
6518
7210
|
|
|
6519
|
-
|
|
6520
|
-
|
|
6521
|
-
|
|
6522
|
-
|
|
6523
|
-
print_yellow("Encountered job without a result")
|
|
6524
|
-
except ValueError as e:
|
|
6525
|
-
err_msg = f"Failed to insert job(s) from {this_csv_file_path} into ax_client. This can happen when the csv file has different parameters or results as the main job one's or other imported jobs. Error: {e}"
|
|
7211
|
+
if err_msgs is None:
|
|
7212
|
+
print_red("try_insert_job: err_msgs was None")
|
|
7213
|
+
else:
|
|
7214
|
+
if isinstance(err_msgs, list):
|
|
6526
7215
|
if err_msg not in err_msgs:
|
|
6527
7216
|
print_red(err_msg)
|
|
6528
7217
|
err_msgs.append(err_msg)
|
|
7218
|
+
elif isinstance(err_msgs, str):
|
|
7219
|
+
err_msgs += f"\n{err_msg}"
|
|
6529
7220
|
|
|
6530
|
-
|
|
7221
|
+
return cnt
|
|
6531
7222
|
|
|
6532
|
-
|
|
6533
|
-
|
|
6534
|
-
|
|
6535
|
-
|
|
6536
|
-
|
|
7223
|
+
def summarize_insertions(csv_path: str, cnt: int) -> None:
|
|
7224
|
+
if cnt == 0:
|
|
7225
|
+
return
|
|
7226
|
+
if cnt == 1:
|
|
7227
|
+
print_yellow(f"Inserted one job from {csv_path}")
|
|
7228
|
+
else:
|
|
7229
|
+
print_yellow(f"Inserted {cnt} jobs from {csv_path}")
|
|
6537
7230
|
|
|
6538
|
-
|
|
6539
|
-
|
|
6540
|
-
|
|
7231
|
+
def update_global_job_counters(cnt: int) -> None:
|
|
7232
|
+
if not args.worker_generator_path:
|
|
7233
|
+
set_max_eval(max_eval + cnt)
|
|
7234
|
+
set_nr_inserted_jobs(NR_INSERTED_JOBS + cnt)
|
|
6541
7235
|
|
|
6542
|
-
def
|
|
7236
|
+
def update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
|
|
6543
7237
|
if __status and base_str:
|
|
6544
7238
|
__status.update(f"{base_str}: {new_text}")
|
|
6545
7239
|
|
|
6546
|
-
def
|
|
7240
|
+
def check_ax_client() -> None:
|
|
6547
7241
|
if ax_client is None or not ax_client:
|
|
6548
7242
|
_fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)
|
|
6549
7243
|
|
|
6550
|
-
def
|
|
7244
|
+
def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
|
|
7245
|
+
if not ax_client:
|
|
7246
|
+
my_exit(101)
|
|
7247
|
+
|
|
7248
|
+
return None
|
|
7249
|
+
|
|
6551
7250
|
new_trial = ax_client.attach_trial(arm_params)
|
|
7251
|
+
|
|
7252
|
+
return new_trial
|
|
7253
|
+
|
|
7254
|
+
def attach_trial(arm_params: dict) -> Tuple[Any, int]:
|
|
7255
|
+
if ax_client is None:
|
|
7256
|
+
raise RuntimeError("attach_trial: ax_client was empty")
|
|
7257
|
+
|
|
7258
|
+
new_trial = attach_ax_client_data(arm_params)
|
|
6552
7259
|
if not isinstance(new_trial, tuple) or len(new_trial) < 2:
|
|
6553
7260
|
raise RuntimeError("attach_trial didn't return the expected tuple")
|
|
6554
7261
|
return new_trial
|
|
6555
7262
|
|
|
6556
|
-
def
|
|
7263
|
+
def get_trial_by_index(trial_idx: int) -> Any:
|
|
7264
|
+
if ax_client is None:
|
|
7265
|
+
raise RuntimeError("get_trial_by_index: ax_client was empty")
|
|
7266
|
+
|
|
6557
7267
|
trial = ax_client.experiment.trials.get(trial_idx)
|
|
6558
7268
|
if trial is None:
|
|
6559
7269
|
raise RuntimeError(f"Trial with index {trial_idx} not found")
|
|
6560
7270
|
return trial
|
|
6561
7271
|
|
|
6562
|
-
def
|
|
7272
|
+
def create_generator_run(arm_params: dict, trial_idx: int, new_job_type: str) -> GeneratorRun:
|
|
6563
7273
|
arm = Arm(parameters=arm_params, name=f'{trial_idx}_0')
|
|
6564
7274
|
return GeneratorRun(arms=[arm], generation_node_name=new_job_type)
|
|
6565
7275
|
|
|
6566
|
-
def
|
|
7276
|
+
def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[Any], base_str: Optional[str]) -> None:
|
|
7277
|
+
if ax_client is None:
|
|
7278
|
+
raise RuntimeError("complete_trial_if_result: ax_client was empty")
|
|
7279
|
+
|
|
6567
7280
|
if f"{result}" != "":
|
|
6568
|
-
|
|
7281
|
+
update_status(__status, base_str, "Completing trial")
|
|
6569
7282
|
is_ok = True
|
|
6570
7283
|
|
|
6571
7284
|
for keyname in result.keys():
|
|
@@ -6573,20 +7286,20 @@ def __insert_job_into_ax_client__complete_trial_if_result(trial_idx: int, result
|
|
|
6573
7286
|
is_ok = False
|
|
6574
7287
|
|
|
6575
7288
|
if is_ok:
|
|
6576
|
-
|
|
6577
|
-
|
|
7289
|
+
complete_ax_client_trial(trial_idx, result)
|
|
7290
|
+
update_status(__status, base_str, "Completed trial")
|
|
6578
7291
|
else:
|
|
6579
7292
|
print_debug("Empty job encountered")
|
|
6580
7293
|
else:
|
|
6581
|
-
|
|
7294
|
+
update_status(__status, base_str, "Found trial without result. Not adding it.")
|
|
6582
7295
|
|
|
6583
|
-
def
|
|
7296
|
+
def save_results_if_needed(__status: Optional[Any], base_str: Optional[str]) -> None:
|
|
6584
7297
|
if not args.worker_generator_path:
|
|
6585
|
-
|
|
7298
|
+
update_status(__status, base_str, f"Saving {RESULTS_CSV_FILENAME}")
|
|
6586
7299
|
save_results_csv()
|
|
6587
|
-
|
|
7300
|
+
update_status(__status, base_str, f"Saved {RESULTS_CSV_FILENAME}")
|
|
6588
7301
|
|
|
6589
|
-
def
|
|
7302
|
+
def handle_insert_job_error(e: Exception, arm_params: dict) -> bool:
|
|
6590
7303
|
parsed_error = parse_parameter_type_error(e)
|
|
6591
7304
|
if parsed_error is not None:
|
|
6592
7305
|
param = parsed_error["parameter_name"]
|
|
@@ -6609,37 +7322,37 @@ def insert_job_into_ax_client(
|
|
|
6609
7322
|
result: dict,
|
|
6610
7323
|
new_job_type: str = "MANUAL",
|
|
6611
7324
|
__status: Optional[Any] = None,
|
|
6612
|
-
base_str: str = None
|
|
7325
|
+
base_str: Optional[str] = None
|
|
6613
7326
|
) -> bool:
|
|
6614
|
-
|
|
7327
|
+
check_ax_client()
|
|
6615
7328
|
|
|
6616
7329
|
done_converting = False
|
|
6617
7330
|
while not done_converting:
|
|
6618
7331
|
try:
|
|
6619
|
-
|
|
7332
|
+
update_status(__status, base_str, "Checking ax client")
|
|
6620
7333
|
if ax_client is None:
|
|
6621
7334
|
return False
|
|
6622
7335
|
|
|
6623
|
-
|
|
6624
|
-
_, new_trial_idx =
|
|
7336
|
+
update_status(__status, base_str, "Attaching new trial")
|
|
7337
|
+
_, new_trial_idx = attach_trial(arm_params)
|
|
6625
7338
|
|
|
6626
|
-
|
|
6627
|
-
trial =
|
|
6628
|
-
|
|
7339
|
+
update_status(__status, base_str, "Getting new trial")
|
|
7340
|
+
trial = get_trial_by_index(new_trial_idx)
|
|
7341
|
+
update_status(__status, base_str, "Got new trial")
|
|
6629
7342
|
|
|
6630
|
-
|
|
6631
|
-
manual_generator_run =
|
|
7343
|
+
update_status(__status, base_str, "Creating new arm")
|
|
7344
|
+
manual_generator_run = create_generator_run(arm_params, new_trial_idx, new_job_type)
|
|
6632
7345
|
trial._generator_run = manual_generator_run
|
|
6633
7346
|
fool_linter(trial._generator_run)
|
|
6634
7347
|
|
|
6635
|
-
|
|
7348
|
+
complete_trial_if_result(new_trial_idx, result, __status, base_str)
|
|
6636
7349
|
done_converting = True
|
|
6637
7350
|
|
|
6638
|
-
|
|
7351
|
+
save_results_if_needed(__status, base_str)
|
|
6639
7352
|
return True
|
|
6640
7353
|
|
|
6641
7354
|
except ax.exceptions.core.UnsupportedError as e:
|
|
6642
|
-
if not
|
|
7355
|
+
if not handle_insert_job_error(e, arm_params):
|
|
6643
7356
|
break
|
|
6644
7357
|
|
|
6645
7358
|
return False
|
|
@@ -6970,7 +7683,7 @@ def get_parameters_from_outfile(stdout_path: str) -> Union[None, dict, str]:
|
|
|
6970
7683
|
if not args.tests:
|
|
6971
7684
|
original_print(f"get_parameters_from_outfile: The file '{stdout_path}' was not found.")
|
|
6972
7685
|
except Exception as e:
|
|
6973
|
-
|
|
7686
|
+
print_red(f"get_parameters_from_outfile: There was an error: {e}")
|
|
6974
7687
|
|
|
6975
7688
|
return None
|
|
6976
7689
|
|
|
@@ -6988,7 +7701,7 @@ def get_hostname_from_outfile(stdout_path: Optional[str]) -> Optional[str]:
|
|
|
6988
7701
|
original_print(f"The file '{stdout_path}' was not found.")
|
|
6989
7702
|
return None
|
|
6990
7703
|
except Exception as e:
|
|
6991
|
-
|
|
7704
|
+
print_red(f"There was an error: {e}")
|
|
6992
7705
|
return None
|
|
6993
7706
|
|
|
6994
7707
|
def add_to_global_error_list(msg: str) -> None:
|
|
@@ -7024,22 +7737,70 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:
|
|
|
7024
7737
|
|
|
7025
7738
|
return None
|
|
7026
7739
|
|
|
7027
|
-
|
|
7740
|
+
log_ax_client_trial_failure(trial_index)
|
|
7028
7741
|
_trial.mark_failed(unsafe=True)
|
|
7029
7742
|
except ValueError as e:
|
|
7030
7743
|
print_debug(f"mark_trial_as_failed error: {e}")
|
|
7031
7744
|
|
|
7032
7745
|
return None
|
|
7033
7746
|
|
|
7034
|
-
def
|
|
7747
|
+
def check_valid_result(result: Union[None, dict]) -> bool:
|
|
7035
7748
|
possible_val_not_found_values = [
|
|
7036
7749
|
VAL_IF_NOTHING_FOUND,
|
|
7037
7750
|
-VAL_IF_NOTHING_FOUND,
|
|
7038
7751
|
-99999999999999997168788049560464200849936328366177157906432,
|
|
7039
7752
|
99999999999999997168788049560464200849936328366177157906432
|
|
7040
7753
|
]
|
|
7041
|
-
|
|
7042
|
-
|
|
7754
|
+
|
|
7755
|
+
def flatten_values(obj: Any) -> Any:
|
|
7756
|
+
values = []
|
|
7757
|
+
try:
|
|
7758
|
+
if isinstance(obj, dict):
|
|
7759
|
+
for v in obj.values():
|
|
7760
|
+
values.extend(flatten_values(v))
|
|
7761
|
+
elif isinstance(obj, (list, tuple, set)):
|
|
7762
|
+
for v in obj:
|
|
7763
|
+
values.extend(flatten_values(v))
|
|
7764
|
+
else:
|
|
7765
|
+
values.append(obj)
|
|
7766
|
+
except Exception as e:
|
|
7767
|
+
print_red(f"Error while flattening values: {e}")
|
|
7768
|
+
return values
|
|
7769
|
+
|
|
7770
|
+
if result is None:
|
|
7771
|
+
return False
|
|
7772
|
+
|
|
7773
|
+
try:
|
|
7774
|
+
all_values = flatten_values(result)
|
|
7775
|
+
for val in all_values:
|
|
7776
|
+
if val in possible_val_not_found_values:
|
|
7777
|
+
return False
|
|
7778
|
+
return True
|
|
7779
|
+
except Exception as e:
|
|
7780
|
+
print_red(f"Error while checking result validity: {e}")
|
|
7781
|
+
return False
|
|
7782
|
+
|
|
7783
|
+
def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
|
|
7784
|
+
if not ax_client:
|
|
7785
|
+
my_exit(101)
|
|
7786
|
+
|
|
7787
|
+
return None
|
|
7788
|
+
|
|
7789
|
+
trial = get_trial_by_index(trial_idx)
|
|
7790
|
+
|
|
7791
|
+
trial.update_trial_data(raw_data=result)
|
|
7792
|
+
|
|
7793
|
+
return None
|
|
7794
|
+
|
|
7795
|
+
def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
|
|
7796
|
+
if not ax_client:
|
|
7797
|
+
my_exit(101)
|
|
7798
|
+
|
|
7799
|
+
return None
|
|
7800
|
+
|
|
7801
|
+
ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
|
|
7802
|
+
|
|
7803
|
+
return None
|
|
7043
7804
|
|
|
7044
7805
|
def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
|
|
7045
7806
|
if ax_client is None:
|
|
@@ -7048,25 +7809,64 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -
|
|
|
7048
7809
|
|
|
7049
7810
|
try:
|
|
7050
7811
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
|
|
7051
|
-
|
|
7812
|
+
complete_ax_client_trial(trial_index, raw_result)
|
|
7052
7813
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
|
|
7053
7814
|
except ax.exceptions.core.UnsupportedError as e:
|
|
7054
7815
|
if f"{e}":
|
|
7055
7816
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
|
|
7056
|
-
|
|
7817
|
+
update_ax_client_trial(trial_index, raw_result)
|
|
7057
7818
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
|
|
7058
7819
|
else:
|
|
7059
7820
|
_fatal_error(f"Error completing trial: {e}", 234)
|
|
7060
7821
|
|
|
7061
7822
|
return None
|
|
7062
7823
|
|
|
7063
|
-
def
|
|
7824
|
+
def format_result_for_display(result: dict) -> str:
|
|
7825
|
+
def safe_float(v: Any) -> str:
|
|
7826
|
+
try:
|
|
7827
|
+
if v is None:
|
|
7828
|
+
return "None"
|
|
7829
|
+
if isinstance(v, (int, float)):
|
|
7830
|
+
if math.isnan(v):
|
|
7831
|
+
return "NaN"
|
|
7832
|
+
if math.isinf(v):
|
|
7833
|
+
return "∞" if v > 0 else "-∞"
|
|
7834
|
+
return f"{v:.6f}"
|
|
7835
|
+
return str(v)
|
|
7836
|
+
except Exception as e:
|
|
7837
|
+
return f"<error: {e}>"
|
|
7838
|
+
|
|
7839
|
+
try:
|
|
7840
|
+
if not isinstance(result, dict):
|
|
7841
|
+
return safe_float(result)
|
|
7842
|
+
|
|
7843
|
+
parts = []
|
|
7844
|
+
for key, val in result.items():
|
|
7845
|
+
try:
|
|
7846
|
+
if isinstance(val, (list, tuple)) and len(val) == 2:
|
|
7847
|
+
main, sem = val
|
|
7848
|
+
main_str = safe_float(main)
|
|
7849
|
+
if sem is not None:
|
|
7850
|
+
sem_str = safe_float(sem)
|
|
7851
|
+
parts.append(f"{key}: {main_str} (SEM: {sem_str})")
|
|
7852
|
+
else:
|
|
7853
|
+
parts.append(f"{key}: {main_str}")
|
|
7854
|
+
else:
|
|
7855
|
+
parts.append(f"{key}: {safe_float(val)}")
|
|
7856
|
+
except Exception as e:
|
|
7857
|
+
parts.append(f"{key}: <error: {e}>")
|
|
7858
|
+
|
|
7859
|
+
return ", ".join(parts)
|
|
7860
|
+
except Exception as e:
|
|
7861
|
+
return f"<error formatting result: {e}>"
|
|
7862
|
+
|
|
7863
|
+
def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: dict) -> None:
|
|
7064
7864
|
print_debug(f"Marking trial {_trial} as completed")
|
|
7065
7865
|
_trial.mark_completed(unsafe=True)
|
|
7066
7866
|
|
|
7067
7867
|
succeeded_jobs(1)
|
|
7068
7868
|
|
|
7069
|
-
progressbar_description(f"new result: {result}")
|
|
7869
|
+
progressbar_description(f"new result: {format_result_for_display(result)}")
|
|
7070
7870
|
update_progress_bar(1)
|
|
7071
7871
|
|
|
7072
7872
|
save_results_csv()
|
|
@@ -7080,8 +7880,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
|
|
|
7080
7880
|
if job:
|
|
7081
7881
|
try:
|
|
7082
7882
|
progressbar_description("job_failed")
|
|
7083
|
-
|
|
7084
|
-
_trial
|
|
7883
|
+
log_ax_client_trial_failure(trial_index)
|
|
7884
|
+
mark_trial_as_failed(trial_index, _trial)
|
|
7085
7885
|
except Exception as e:
|
|
7086
7886
|
print_red(f"\nERROR while trying to mark job as failure: {e}")
|
|
7087
7887
|
job.cancel()
|
|
@@ -7098,16 +7898,16 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7098
7898
|
result = job.result()
|
|
7099
7899
|
print_debug(f"finish_job_core: trial-index: {trial_index}, job.result(): {result}, state: {state_from_job(job)}")
|
|
7100
7900
|
|
|
7101
|
-
raw_result = result
|
|
7102
|
-
result_keys = list(result.keys())
|
|
7103
|
-
result = result[result_keys[0]]
|
|
7104
7901
|
this_jobs_finished += 1
|
|
7105
7902
|
|
|
7106
7903
|
if ax_client:
|
|
7107
|
-
_trial =
|
|
7904
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7905
|
+
|
|
7906
|
+
if _trial is None:
|
|
7907
|
+
return 0
|
|
7108
7908
|
|
|
7109
|
-
if
|
|
7110
|
-
_finish_job_core_helper_complete_trial(trial_index,
|
|
7909
|
+
if check_valid_result(result):
|
|
7910
|
+
_finish_job_core_helper_complete_trial(trial_index, result)
|
|
7111
7911
|
|
|
7112
7912
|
try:
|
|
7113
7913
|
_finish_job_core_helper_mark_success(_trial, result)
|
|
@@ -7115,15 +7915,17 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7115
7915
|
if len(arg_result_names) > 1 and count_done_jobs() > 1 and not job_calculate_pareto_front(get_current_run_folder(), True):
|
|
7116
7916
|
print_red("job_calculate_pareto_front post job failed")
|
|
7117
7917
|
except Exception as e:
|
|
7118
|
-
|
|
7918
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
7119
7919
|
else:
|
|
7120
7920
|
_finish_job_core_helper_mark_failure(job, trial_index, _trial)
|
|
7121
7921
|
else:
|
|
7122
|
-
_fatal_error("ax_client could not be found or used",
|
|
7922
|
+
_fatal_error("ax_client could not be found or used", 101)
|
|
7123
7923
|
|
|
7124
7924
|
print_debug(f"finish_job_core: removing job {job}, trial_index: {trial_index}")
|
|
7125
7925
|
global_vars["jobs"].remove((job, trial_index))
|
|
7126
7926
|
|
|
7927
|
+
log_data()
|
|
7928
|
+
|
|
7127
7929
|
force_live_share()
|
|
7128
7930
|
|
|
7129
7931
|
return this_jobs_finished
|
|
@@ -7136,11 +7938,14 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
|
|
|
7136
7938
|
if job:
|
|
7137
7939
|
try:
|
|
7138
7940
|
progressbar_description("job_failed")
|
|
7139
|
-
_trial =
|
|
7140
|
-
|
|
7941
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7942
|
+
if _trial is None:
|
|
7943
|
+
return None
|
|
7944
|
+
|
|
7945
|
+
log_ax_client_trial_failure(trial_index)
|
|
7141
7946
|
mark_trial_as_failed(trial_index, _trial)
|
|
7142
7947
|
except Exception as e:
|
|
7143
|
-
|
|
7948
|
+
print_debug(f"ERROR in line {get_line_info()}: {e}")
|
|
7144
7949
|
job.cancel()
|
|
7145
7950
|
orchestrate_job(job, trial_index)
|
|
7146
7951
|
|
|
@@ -7155,10 +7960,12 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
|
|
|
7155
7960
|
|
|
7156
7961
|
def _finish_previous_jobs_helper_handle_exception(job: Any, trial_index: int, error: Exception) -> int:
|
|
7157
7962
|
if "None for metric" in str(error):
|
|
7158
|
-
|
|
7159
|
-
|
|
7160
|
-
|
|
7161
|
-
|
|
7963
|
+
err_msg = f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string.\nError: {error}\nJob-result: {job.result()}"
|
|
7964
|
+
|
|
7965
|
+
if count_done_jobs() == 0:
|
|
7966
|
+
print_red(err_msg)
|
|
7967
|
+
else:
|
|
7968
|
+
print_debug(err_msg)
|
|
7162
7969
|
else:
|
|
7163
7970
|
print_red(f"\n⚠ {error}")
|
|
7164
7971
|
|
|
@@ -7177,7 +7984,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
|
|
|
7177
7984
|
this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
|
|
7178
7985
|
return this_jobs_finished
|
|
7179
7986
|
|
|
7180
|
-
def _finish_previous_jobs_helper_check_and_process(
|
|
7987
|
+
def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
|
|
7988
|
+
job, trial_index = __args
|
|
7989
|
+
|
|
7990
|
+
this_jobs_finished = 0
|
|
7181
7991
|
if job is None:
|
|
7182
7992
|
print_debug(f"finish_previous_jobs: job {job} is None")
|
|
7183
7993
|
return this_jobs_finished
|
|
@@ -7190,10 +8000,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t
|
|
|
7190
8000
|
|
|
7191
8001
|
return this_jobs_finished
|
|
7192
8002
|
|
|
7193
|
-
def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
|
|
7194
|
-
job, trial_index = __args
|
|
7195
|
-
return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
|
|
7196
|
-
|
|
7197
8003
|
def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
7198
8004
|
global JOBS_FINISHED
|
|
7199
8005
|
|
|
@@ -7206,13 +8012,10 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
|
7206
8012
|
|
|
7207
8013
|
jobs_copy = global_vars["jobs"][:]
|
|
7208
8014
|
|
|
7209
|
-
|
|
7210
|
-
print_debug(f"jobs in finish_previous_jobs: {jobs_copy}")
|
|
7211
|
-
|
|
7212
|
-
finishing_jobs_start_time = time.time()
|
|
8015
|
+
#finishing_jobs_start_time = time.time()
|
|
7213
8016
|
|
|
7214
8017
|
with ThreadPoolExecutor() as finish_job_executor:
|
|
7215
|
-
futures = [finish_job_executor.submit(
|
|
8018
|
+
futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]
|
|
7216
8019
|
|
|
7217
8020
|
for future in as_completed(futures):
|
|
7218
8021
|
try:
|
|
@@ -7220,17 +8023,15 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
|
7220
8023
|
except Exception as e:
|
|
7221
8024
|
print_red(f"⚠ Exception in parallel job handling: {e}")
|
|
7222
8025
|
|
|
7223
|
-
finishing_jobs_end_time = time.time()
|
|
7224
|
-
|
|
7225
|
-
finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
|
|
8026
|
+
#finishing_jobs_end_time = time.time()
|
|
7226
8027
|
|
|
7227
|
-
|
|
7228
|
-
|
|
7229
|
-
save_results_csv()
|
|
8028
|
+
#finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
|
|
7230
8029
|
|
|
7231
|
-
|
|
8030
|
+
#print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
|
|
7232
8031
|
|
|
7233
8032
|
if this_jobs_finished > 0:
|
|
8033
|
+
save_results_csv()
|
|
8034
|
+
save_checkpoint()
|
|
7234
8035
|
progressbar_description([*new_msgs, f"finished {this_jobs_finished} {'job' if this_jobs_finished == 1 else 'jobs'}"])
|
|
7235
8036
|
|
|
7236
8037
|
JOBS_FINISHED += this_jobs_finished
|
|
@@ -7373,11 +8174,15 @@ def is_already_in_defective_nodes(hostname: str) -> bool:
|
|
|
7373
8174
|
return True
|
|
7374
8175
|
except Exception as e:
|
|
7375
8176
|
print_red(f"is_already_in_defective_nodes: Error reading the file {file_path}: {e}")
|
|
7376
|
-
return False
|
|
7377
8177
|
|
|
7378
8178
|
return False
|
|
7379
8179
|
|
|
7380
8180
|
def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
|
|
8181
|
+
if submitit_executor is None:
|
|
8182
|
+
print_red("submit_new_job: submitit_executor was None")
|
|
8183
|
+
|
|
8184
|
+
return None
|
|
8185
|
+
|
|
7381
8186
|
print_debug(f"Submitting new job for trial_index {trial_index}, parameters {parameters}")
|
|
7382
8187
|
|
|
7383
8188
|
start = time.time()
|
|
@@ -7388,25 +8193,49 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
|
|
|
7388
8193
|
|
|
7389
8194
|
print_debug(f"Done submitting new job, took {elapsed} seconds")
|
|
7390
8195
|
|
|
8196
|
+
log_data()
|
|
8197
|
+
|
|
7391
8198
|
return new_job
|
|
7392
8199
|
|
|
8200
|
+
def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
|
|
8201
|
+
if not ax_client:
|
|
8202
|
+
my_exit(101)
|
|
8203
|
+
|
|
8204
|
+
return None
|
|
8205
|
+
|
|
8206
|
+
try:
|
|
8207
|
+
log_data()
|
|
8208
|
+
|
|
8209
|
+
return ax_client.get_trial(trial_index)
|
|
8210
|
+
except KeyError:
|
|
8211
|
+
error_without_print(f"get_ax_client_trial: trial_index {trial_index} failed")
|
|
8212
|
+
return None
|
|
8213
|
+
|
|
7393
8214
|
def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
|
|
7394
8215
|
if submitit_executor and ax_client:
|
|
7395
8216
|
new_job = submit_new_job(parameters, trial_index)
|
|
7396
|
-
|
|
8217
|
+
if new_job:
|
|
8218
|
+
submitted_jobs(1)
|
|
7397
8219
|
|
|
7398
|
-
|
|
8220
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7399
8221
|
|
|
7400
|
-
|
|
7401
|
-
|
|
7402
|
-
|
|
7403
|
-
|
|
7404
|
-
|
|
8222
|
+
if _trial is not None:
|
|
8223
|
+
try:
|
|
8224
|
+
_trial.mark_staged(unsafe=True)
|
|
8225
|
+
except Exception as e:
|
|
8226
|
+
print_debug(f"orchestrator_start_trial: error {e}")
|
|
8227
|
+
_trial.mark_running(unsafe=True, no_runner_required=True)
|
|
7405
8228
|
|
|
7406
|
-
|
|
7407
|
-
|
|
8229
|
+
print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
|
|
8230
|
+
global_vars["jobs"].append((new_job, trial_index))
|
|
8231
|
+
else:
|
|
8232
|
+
print_red("Trial was none in orchestrator_start_trial")
|
|
8233
|
+
else:
|
|
8234
|
+
print_red("orchestrator_start_trial: Failed to start new job")
|
|
8235
|
+
elif ax_client:
|
|
8236
|
+
_fatal_error("submitit_executor could not be found properly", 9)
|
|
7408
8237
|
else:
|
|
7409
|
-
_fatal_error("
|
|
8238
|
+
_fatal_error("ax_client could not be found properly", 101)
|
|
7410
8239
|
|
|
7411
8240
|
def handle_exclude_node(stdout_path: str, hostname_from_out_file: Union[None, str]) -> None:
|
|
7412
8241
|
stdout_path = check_alternate_path(stdout_path)
|
|
@@ -7425,7 +8254,7 @@ def handle_restart(stdout_path: str, trial_index: int) -> None:
|
|
|
7425
8254
|
if parameters:
|
|
7426
8255
|
orchestrator_start_trial(parameters, trial_index)
|
|
7427
8256
|
else:
|
|
7428
|
-
|
|
8257
|
+
print_red(f"Could not determine parameters from outfile {stdout_path} for restarting job")
|
|
7429
8258
|
|
|
7430
8259
|
def check_alternate_path(path: str) -> str:
|
|
7431
8260
|
if os.path.exists(path):
|
|
@@ -7512,7 +8341,7 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7512
8341
|
print_debug(f"execute_evaluation({_params})")
|
|
7513
8342
|
trial_index, parameters, trial_counter, phase = _params
|
|
7514
8343
|
if not ax_client:
|
|
7515
|
-
_fatal_error("Failed to get ax_client",
|
|
8344
|
+
_fatal_error("Failed to get ax_client", 101)
|
|
7516
8345
|
|
|
7517
8346
|
return None
|
|
7518
8347
|
|
|
@@ -7521,7 +8350,11 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7521
8350
|
|
|
7522
8351
|
return None
|
|
7523
8352
|
|
|
7524
|
-
_trial =
|
|
8353
|
+
_trial = get_ax_client_trial(trial_index)
|
|
8354
|
+
|
|
8355
|
+
if _trial is None:
|
|
8356
|
+
error_without_print(f"execute_evaluation: _trial was not in execute_evaluation for params {_params}")
|
|
8357
|
+
return None
|
|
7525
8358
|
|
|
7526
8359
|
def mark_trial_stage(stage: str, error_msg: str) -> None:
|
|
7527
8360
|
try:
|
|
@@ -7536,17 +8369,18 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7536
8369
|
try:
|
|
7537
8370
|
initialize_job_environment()
|
|
7538
8371
|
new_job = submit_new_job(parameters, trial_index)
|
|
7539
|
-
|
|
7540
|
-
|
|
7541
|
-
print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
|
|
7542
|
-
global_vars["jobs"].append((new_job, trial_index))
|
|
8372
|
+
if new_job:
|
|
8373
|
+
submitted_jobs(1)
|
|
7543
8374
|
|
|
7544
|
-
|
|
7545
|
-
|
|
8375
|
+
print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
|
|
8376
|
+
global_vars["jobs"].append((new_job, trial_index))
|
|
7546
8377
|
|
|
7547
|
-
|
|
8378
|
+
mark_trial_stage("mark_running", "Marking the trial as running failed")
|
|
8379
|
+
trial_counter += 1
|
|
7548
8380
|
|
|
7549
|
-
|
|
8381
|
+
progressbar_description("started new job")
|
|
8382
|
+
else:
|
|
8383
|
+
progressbar_description("Failed to start new job")
|
|
7550
8384
|
except submitit.core.utils.FailedJobError as error:
|
|
7551
8385
|
handle_failed_job(error, trial_index, new_job)
|
|
7552
8386
|
trial_counter += 1
|
|
@@ -7588,7 +8422,7 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
|
|
|
7588
8422
|
my_exit(144)
|
|
7589
8423
|
|
|
7590
8424
|
if new_job is None:
|
|
7591
|
-
|
|
8425
|
+
print_debug("handle_failed_job: job is None")
|
|
7592
8426
|
|
|
7593
8427
|
return None
|
|
7594
8428
|
|
|
@@ -7599,16 +8433,24 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
|
|
|
7599
8433
|
|
|
7600
8434
|
return None
|
|
7601
8435
|
|
|
8436
|
+
def log_ax_client_trial_failure(trial_index: int) -> None:
|
|
8437
|
+
if not ax_client:
|
|
8438
|
+
my_exit(101)
|
|
8439
|
+
|
|
8440
|
+
return
|
|
8441
|
+
|
|
8442
|
+
ax_client.log_trial_failure(trial_index=trial_index)
|
|
8443
|
+
|
|
7602
8444
|
def cancel_failed_job(trial_index: int, new_job: Job) -> None:
|
|
7603
8445
|
print_debug("Trying to cancel job that failed")
|
|
7604
8446
|
if new_job:
|
|
7605
8447
|
try:
|
|
7606
8448
|
if ax_client:
|
|
7607
|
-
|
|
8449
|
+
log_ax_client_trial_failure(trial_index)
|
|
7608
8450
|
else:
|
|
7609
8451
|
_fatal_error("ax_client not defined", 101)
|
|
7610
8452
|
except Exception as e:
|
|
7611
|
-
|
|
8453
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
7612
8454
|
new_job.cancel()
|
|
7613
8455
|
|
|
7614
8456
|
print_debug(f"cancel_failed_job: removing job {new_job}, trial_index: {trial_index}")
|
|
@@ -7644,10 +8486,12 @@ def show_debug_table_for_break_run_search(_name: str, _max_eval: Optional[int])
|
|
|
7644
8486
|
("failed_jobs()", failed_jobs()),
|
|
7645
8487
|
("count_done_jobs()", count_done_jobs()),
|
|
7646
8488
|
("_max_eval", _max_eval),
|
|
7647
|
-
("progress_bar.total", progress_bar.total),
|
|
7648
8489
|
("NR_INSERTED_JOBS", NR_INSERTED_JOBS)
|
|
7649
8490
|
]
|
|
7650
8491
|
|
|
8492
|
+
if progress_bar is not None:
|
|
8493
|
+
rows.append(("progress_bar.total", progress_bar.total))
|
|
8494
|
+
|
|
7651
8495
|
for row in rows:
|
|
7652
8496
|
table.add_row(str(row[0]), str(row[1]))
|
|
7653
8497
|
|
|
@@ -7660,6 +8504,8 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
|
|
|
7660
8504
|
_submitted_jobs = submitted_jobs()
|
|
7661
8505
|
_failed_jobs = failed_jobs()
|
|
7662
8506
|
|
|
8507
|
+
log_data()
|
|
8508
|
+
|
|
7663
8509
|
max_failed_jobs = max_eval
|
|
7664
8510
|
|
|
7665
8511
|
if args.max_failed_jobs is not None and args.max_failed_jobs > 0:
|
|
@@ -7687,11 +8533,11 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
|
|
|
7687
8533
|
_ret = True
|
|
7688
8534
|
|
|
7689
8535
|
if args.verbose_break_run_search_table:
|
|
7690
|
-
show_debug_table_for_break_run_search(_name, _max_eval
|
|
8536
|
+
show_debug_table_for_break_run_search(_name, _max_eval)
|
|
7691
8537
|
|
|
7692
8538
|
return _ret
|
|
7693
8539
|
|
|
7694
|
-
def
|
|
8540
|
+
def calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
|
|
7695
8541
|
"""Calculates the number of jobs to retrieve."""
|
|
7696
8542
|
return min(
|
|
7697
8543
|
max_eval + simulated_jobs - count_done_jobs(),
|
|
@@ -7704,7 +8550,7 @@ def remove_extra_spaces(text: str) -> str:
|
|
|
7704
8550
|
raise ValueError("Input must be a string")
|
|
7705
8551
|
return re.sub(r'\s+', ' ', text).strip()
|
|
7706
8552
|
|
|
7707
|
-
def
|
|
8553
|
+
def get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
|
|
7708
8554
|
"""Generates the appropriate message for the number of trials being retrieved."""
|
|
7709
8555
|
ret = ""
|
|
7710
8556
|
if full_nr_of_jobs_to_get > 1:
|
|
@@ -7789,70 +8635,70 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
|
|
|
7789
8635
|
print_red("get_batched_arms: ax_client was None")
|
|
7790
8636
|
return []
|
|
7791
8637
|
|
|
7792
|
-
# Experiment-Status laden
|
|
7793
8638
|
load_experiment_state()
|
|
7794
8639
|
|
|
7795
|
-
while len(batched_arms)
|
|
8640
|
+
while len(batched_arms) < nr_of_jobs_to_get:
|
|
7796
8641
|
if attempts > args.max_attempts_for_generation:
|
|
7797
8642
|
print_debug(f"get_batched_arms: Stopped after {attempts} attempts: could not generate enough arms "
|
|
7798
8643
|
f"(got {len(batched_arms)} out of {nr_of_jobs_to_get}).")
|
|
7799
8644
|
break
|
|
7800
8645
|
|
|
7801
|
-
|
|
7802
|
-
print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
|
|
7803
|
-
|
|
7804
|
-
#print("get pending observations")
|
|
7805
|
-
pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
|
|
7806
|
-
#print("got pending observations")
|
|
8646
|
+
#print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting 1 more arm")
|
|
7807
8647
|
|
|
7808
|
-
|
|
7809
|
-
batched_generator_run = global_gs.gen(
|
|
8648
|
+
pending_observations = get_pending_observation_features(
|
|
7810
8649
|
experiment=ax_client.experiment,
|
|
7811
|
-
|
|
7812
|
-
pending_observations=pending_observations
|
|
8650
|
+
include_out_of_design_points=True
|
|
7813
8651
|
)
|
|
7814
|
-
#print(f"got global_gs.gen(): {batched_generator_run}")
|
|
7815
8652
|
|
|
7816
|
-
|
|
8653
|
+
try:
|
|
8654
|
+
#print_debug("getting global_gs.gen() with n=1")
|
|
8655
|
+
|
|
8656
|
+
batched_generator_run: Any = global_gs.gen(
|
|
8657
|
+
experiment=ax_client.experiment,
|
|
8658
|
+
n=1,
|
|
8659
|
+
pending_observations=pending_observations,
|
|
8660
|
+
)
|
|
8661
|
+
print_debug(f"got global_gs.gen(): {batched_generator_run}")
|
|
8662
|
+
except Exception as e:
|
|
8663
|
+
print_debug(f"global_gs.gen failed: {e}")
|
|
8664
|
+
traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
|
|
8665
|
+
break
|
|
8666
|
+
|
|
7817
8667
|
depth = 0
|
|
7818
8668
|
path = "batched_generator_run"
|
|
7819
|
-
while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run)
|
|
7820
|
-
#
|
|
8669
|
+
while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
|
|
8670
|
+
#print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
|
|
7821
8671
|
batched_generator_run = batched_generator_run[0]
|
|
7822
8672
|
path += "[0]"
|
|
7823
8673
|
depth += 1
|
|
7824
8674
|
|
|
7825
|
-
#
|
|
8675
|
+
#print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
|
|
7826
8676
|
|
|
7827
|
-
|
|
7828
|
-
new_arms = batched_generator_run.arms
|
|
7829
|
-
#print(f"new_arms: {new_arms}")
|
|
8677
|
+
new_arms = getattr(batched_generator_run, "arms", [])
|
|
7830
8678
|
if not new_arms:
|
|
7831
8679
|
print_debug("get_batched_arms: No new arms were generated in this attempt.")
|
|
7832
8680
|
else:
|
|
7833
|
-
print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s),
|
|
8681
|
+
print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), now at {len(batched_arms) + len(new_arms)} of {nr_of_jobs_to_get}.")
|
|
8682
|
+
batched_arms.extend(new_arms)
|
|
7834
8683
|
|
|
7835
|
-
batched_arms.extend(new_arms)
|
|
7836
8684
|
attempts += 1
|
|
7837
8685
|
|
|
7838
8686
|
print_debug(f"get_batched_arms: Finished with {len(batched_arms)} arm(s) after {attempts} attempt(s).")
|
|
7839
8687
|
|
|
7840
|
-
save_results_csv()
|
|
7841
|
-
|
|
7842
8688
|
return batched_arms
|
|
7843
8689
|
|
|
7844
|
-
def
|
|
8690
|
+
def fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple[Dict[int, Any], bool]:
|
|
7845
8691
|
die_101_if_no_ax_client_or_experiment_or_gs()
|
|
7846
8692
|
|
|
7847
8693
|
if not ax_client:
|
|
7848
|
-
_fatal_error("ax_client was not defined",
|
|
8694
|
+
_fatal_error("ax_client was not defined", 101)
|
|
7849
8695
|
|
|
7850
8696
|
if global_gs is None:
|
|
7851
8697
|
_fatal_error("Global generation strategy is not set. This is a bug in OmniOpt2.", 107)
|
|
7852
8698
|
|
|
7853
|
-
return
|
|
8699
|
+
return generate_trials(nr_of_jobs_to_get, recursion)
|
|
7854
8700
|
|
|
7855
|
-
def
|
|
8701
|
+
def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
7856
8702
|
trials_dict: Dict[int, Any] = {}
|
|
7857
8703
|
trial_durations: List[float] = []
|
|
7858
8704
|
|
|
@@ -7866,7 +8712,6 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7866
8712
|
if cnt >= n:
|
|
7867
8713
|
break
|
|
7868
8714
|
|
|
7869
|
-
# 🔹 Erzeuge einen komplett neuen Arm, damit Ax den Namen vergibt
|
|
7870
8715
|
try:
|
|
7871
8716
|
arm = Arm(parameters=arm.parameters)
|
|
7872
8717
|
except Exception as arm_err:
|
|
@@ -7874,16 +8719,14 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7874
8719
|
retries += 1
|
|
7875
8720
|
continue
|
|
7876
8721
|
|
|
7877
|
-
|
|
7878
|
-
progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
|
|
8722
|
+
progressbar_description(get_trials_message(cnt + 1, n, trial_durations))
|
|
7879
8723
|
|
|
7880
8724
|
try:
|
|
7881
|
-
result =
|
|
8725
|
+
result = create_and_handle_trial(arm)
|
|
7882
8726
|
if result is not None:
|
|
7883
8727
|
trial_index, trial_duration, trial_successful = result
|
|
7884
|
-
|
|
7885
8728
|
except TrialRejected as e:
|
|
7886
|
-
print_debug(f"Trial rejected: {e}")
|
|
8729
|
+
print_debug(f"generate_trials: Trial rejected, error: {e}")
|
|
7887
8730
|
retries += 1
|
|
7888
8731
|
continue
|
|
7889
8732
|
|
|
@@ -7893,19 +8736,26 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7893
8736
|
cnt += 1
|
|
7894
8737
|
trials_dict[trial_index] = arm.parameters
|
|
7895
8738
|
|
|
7896
|
-
|
|
8739
|
+
finalized = finalize_generation(trials_dict, cnt, n, start_time)
|
|
7897
8740
|
|
|
7898
|
-
return
|
|
8741
|
+
return finalized
|
|
7899
8742
|
|
|
7900
8743
|
except Exception as e:
|
|
7901
|
-
return
|
|
8744
|
+
return handle_generation_failure(e, n, recursion)
|
|
7902
8745
|
|
|
7903
8746
|
class TrialRejected(Exception):
|
|
7904
8747
|
pass
|
|
7905
8748
|
|
|
7906
|
-
def
|
|
8749
|
+
def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
|
|
8750
|
+
try:
|
|
8751
|
+
print_debug(f"[INFO] Marking trial {trial.index} ({trial.arm.name}) as abandoned, trial-index: {trial_index}. Reason: {reason}")
|
|
8752
|
+
trial.mark_abandoned(reason)
|
|
8753
|
+
except Exception as e:
|
|
8754
|
+
print_red(f"[ERROR] Could not mark trial as abandoned: {e}")
|
|
8755
|
+
|
|
8756
|
+
def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
7907
8757
|
if ax_client is None:
|
|
7908
|
-
print_red("ax_client is None in
|
|
8758
|
+
print_red("ax_client is None in create_and_handle_trial")
|
|
7909
8759
|
return None
|
|
7910
8760
|
|
|
7911
8761
|
start = time.time()
|
|
@@ -7928,7 +8778,7 @@ def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7928
8778
|
arm = trial.arms[0]
|
|
7929
8779
|
if deduplicated_arm(arm):
|
|
7930
8780
|
print_debug(f"Duplicated arm: {arm}")
|
|
7931
|
-
|
|
8781
|
+
mark_abandoned(trial, "Duplication detected", trial_index)
|
|
7932
8782
|
raise TrialRejected("Duplicate arm.")
|
|
7933
8783
|
|
|
7934
8784
|
arms_by_name_for_deduplication[arm.name] = arm
|
|
@@ -7936,8 +8786,8 @@ def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7936
8786
|
params = arm.parameters
|
|
7937
8787
|
|
|
7938
8788
|
if not has_no_post_generation_constraints_or_matches_constraints(post_generation_constraints, params):
|
|
7939
|
-
print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned.")
|
|
7940
|
-
|
|
8789
|
+
print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned. Params: {params}, constraints: {post_generation_constraints}")
|
|
8790
|
+
mark_abandoned(trial, "Post-Generation-Constraint failed", trial_index)
|
|
7941
8791
|
abandoned_trial_indices.append(trial_index)
|
|
7942
8792
|
raise TrialRejected("Post-generation constraints not met.")
|
|
7943
8793
|
|
|
@@ -7945,7 +8795,7 @@ def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7945
8795
|
end = time.time()
|
|
7946
8796
|
return trial_index, float(end - start), True
|
|
7947
8797
|
|
|
7948
|
-
def
|
|
8798
|
+
def finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
|
|
7949
8799
|
total_time = time.time() - start_time
|
|
7950
8800
|
|
|
7951
8801
|
log_gen_times.append(total_time)
|
|
@@ -7956,7 +8806,7 @@ def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int,
|
|
|
7956
8806
|
|
|
7957
8807
|
return trials_dict, False
|
|
7958
8808
|
|
|
7959
|
-
def
|
|
8809
|
+
def handle_generation_failure(
|
|
7960
8810
|
e: Exception,
|
|
7961
8811
|
requested: int,
|
|
7962
8812
|
recursion: bool
|
|
@@ -7972,19 +8822,19 @@ def _handle_generation_failure(
|
|
|
7972
8822
|
)):
|
|
7973
8823
|
msg = str(e)
|
|
7974
8824
|
if msg not in error_8_saved:
|
|
7975
|
-
|
|
8825
|
+
print_exhaustion_warning(e, recursion)
|
|
7976
8826
|
error_8_saved.append(msg)
|
|
7977
8827
|
|
|
7978
8828
|
if not recursion and args.revert_to_random_when_seemingly_exhausted:
|
|
7979
8829
|
print_debug("Switching to random search strategy.")
|
|
7980
|
-
|
|
7981
|
-
return
|
|
8830
|
+
set_global_gs_to_sobol()
|
|
8831
|
+
return fetch_next_trials(requested, True)
|
|
7982
8832
|
|
|
7983
|
-
print_red(f"
|
|
8833
|
+
print_red(f"handle_generation_failure: General Exception: {e}")
|
|
7984
8834
|
|
|
7985
8835
|
return {}, True
|
|
7986
8836
|
|
|
7987
|
-
def
|
|
8837
|
+
def print_exhaustion_warning(e: Exception, recursion: bool) -> None:
|
|
7988
8838
|
if not recursion and args.revert_to_random_when_seemingly_exhausted:
|
|
7989
8839
|
print_yellow(f"\n⚠Error 8: {e} From now (done jobs: {count_done_jobs()}) on, random points will be generated.")
|
|
7990
8840
|
else:
|
|
@@ -8029,21 +8879,24 @@ def get_model_gen_kwargs() -> dict:
|
|
|
8029
8879
|
"fit_out_of_design": args.fit_out_of_design
|
|
8030
8880
|
}
|
|
8031
8881
|
|
|
8032
|
-
def
|
|
8882
|
+
def set_global_gs_to_sobol() -> None:
|
|
8033
8883
|
global global_gs
|
|
8034
8884
|
global overwritten_to_random
|
|
8035
8885
|
|
|
8886
|
+
print("Reverting to SOBOL")
|
|
8887
|
+
|
|
8036
8888
|
global_gs = GenerationStrategy(
|
|
8037
8889
|
name="Random*",
|
|
8038
8890
|
nodes=[
|
|
8039
8891
|
GenerationNode(
|
|
8040
|
-
|
|
8041
|
-
|
|
8042
|
-
|
|
8043
|
-
|
|
8044
|
-
|
|
8045
|
-
|
|
8046
|
-
|
|
8892
|
+
name="Sobol",
|
|
8893
|
+
should_deduplicate=True,
|
|
8894
|
+
generator_specs=[ # type: ignore[arg-type]
|
|
8895
|
+
GeneratorSpec( # type: ignore[arg-type]
|
|
8896
|
+
Generators.SOBOL, # type: ignore[arg-type]
|
|
8897
|
+
model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
|
|
8898
|
+
) # type: ignore[arg-type]
|
|
8899
|
+
] # type: ignore[arg-type]
|
|
8047
8900
|
)
|
|
8048
8901
|
]
|
|
8049
8902
|
)
|
|
@@ -8261,14 +9114,14 @@ def _handle_linalg_error(error: Union[None, str, Exception]) -> None:
|
|
|
8261
9114
|
"""Handles the np.linalg.LinAlgError based on the model being used."""
|
|
8262
9115
|
print_red(f"Error: {error}")
|
|
8263
9116
|
|
|
8264
|
-
def
|
|
8265
|
-
finish_previous_jobs(["finishing jobs (
|
|
9117
|
+
def get_next_trials(nr_of_jobs_to_get: int) -> Tuple[Union[None, dict], bool]:
|
|
9118
|
+
finish_previous_jobs(["finishing jobs (get_next_trials)"])
|
|
8266
9119
|
|
|
8267
|
-
if break_run_search("
|
|
9120
|
+
if break_run_search("get_next_trials", max_eval) or nr_of_jobs_to_get == 0:
|
|
8268
9121
|
return {}, True
|
|
8269
9122
|
|
|
8270
9123
|
try:
|
|
8271
|
-
trial_index_to_param, optimization_complete =
|
|
9124
|
+
trial_index_to_param, optimization_complete = fetch_next_trials(nr_of_jobs_to_get)
|
|
8272
9125
|
|
|
8273
9126
|
cf = currentframe()
|
|
8274
9127
|
if cf:
|
|
@@ -8403,16 +9256,21 @@ def get_model_from_name(name: str) -> Any:
|
|
|
8403
9256
|
return gen
|
|
8404
9257
|
raise ValueError(f"Unknown or unsupported model: {name}")
|
|
8405
9258
|
|
|
8406
|
-
def get_name_from_model(model) ->
|
|
9259
|
+
def get_name_from_model(model: Any) -> str:
|
|
8407
9260
|
if not isinstance(SUPPORTED_MODELS, (list, set, tuple)):
|
|
8408
|
-
|
|
9261
|
+
raise RuntimeError("get_model_from_name: SUPPORTED_MODELS was not a list, set or tuple. Cannot continue")
|
|
8409
9262
|
|
|
8410
9263
|
model_str = model.value if hasattr(model, "value") else str(model)
|
|
8411
9264
|
|
|
8412
9265
|
model_str_lower = model_str.lower()
|
|
8413
9266
|
model_map = {m.lower(): m for m in SUPPORTED_MODELS}
|
|
8414
9267
|
|
|
8415
|
-
|
|
9268
|
+
ret = model_map.get(model_str_lower, None)
|
|
9269
|
+
|
|
9270
|
+
if ret is None:
|
|
9271
|
+
raise RuntimeError("get_name_from_model: failed to get Model")
|
|
9272
|
+
|
|
9273
|
+
return ret
|
|
8416
9274
|
|
|
8417
9275
|
def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str, int]], int]:
|
|
8418
9276
|
gen_strat_list = []
|
|
@@ -8423,10 +9281,10 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
|
|
|
8423
9281
|
|
|
8424
9282
|
for s in splitted_by_comma:
|
|
8425
9283
|
if "=" not in s:
|
|
8426
|
-
|
|
9284
|
+
print_red(f"'{s}' does not contain '='")
|
|
8427
9285
|
my_exit(123)
|
|
8428
9286
|
if s.count("=") != 1:
|
|
8429
|
-
|
|
9287
|
+
print_red(f"There can only be one '=' in the gen_strat_str's element '{s}'")
|
|
8430
9288
|
my_exit(123)
|
|
8431
9289
|
|
|
8432
9290
|
model_name, nr_str = s.split("=")
|
|
@@ -8436,13 +9294,13 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
|
|
|
8436
9294
|
_fatal_error(f"Model {matching_model} is not valid for custom generation strategy.", 56)
|
|
8437
9295
|
|
|
8438
9296
|
if not matching_model:
|
|
8439
|
-
|
|
9297
|
+
print_red(f"'{model_name}' not found in SUPPORTED_MODELS")
|
|
8440
9298
|
my_exit(123)
|
|
8441
9299
|
|
|
8442
9300
|
try:
|
|
8443
9301
|
nr = int(nr_str)
|
|
8444
9302
|
except ValueError:
|
|
8445
|
-
|
|
9303
|
+
print_red(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
|
|
8446
9304
|
my_exit(123)
|
|
8447
9305
|
|
|
8448
9306
|
gen_strat_list.append({matching_model: nr})
|
|
@@ -8566,7 +9424,7 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8566
9424
|
if model_name == "TPE":
|
|
8567
9425
|
if len(arg_result_names) != 1:
|
|
8568
9426
|
_fatal_error(f"Has {len(arg_result_names)} results. TPE currently only supports single-objective-optimization.", 108)
|
|
8569
|
-
return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py",
|
|
9427
|
+
return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py", name="EXTERNAL_GENERATOR")
|
|
8570
9428
|
|
|
8571
9429
|
external_generators = {
|
|
8572
9430
|
"PSEUDORANDOM": f"python3 {script_dir}/.random_generator.py",
|
|
@@ -8580,10 +9438,10 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8580
9438
|
cmd = external_generators[model_name]
|
|
8581
9439
|
if model_name == "EXTERNAL_GENERATOR" and not cmd:
|
|
8582
9440
|
_fatal_error("--external_generator is missing. Cannot create points for EXTERNAL_GENERATOR without it.", 204)
|
|
8583
|
-
return ExternalProgramGenerationNode(external_generator=cmd,
|
|
9441
|
+
return ExternalProgramGenerationNode(external_generator=cmd, name="EXTERNAL_GENERATOR")
|
|
8584
9442
|
|
|
8585
9443
|
trans_crit = [
|
|
8586
|
-
|
|
9444
|
+
MinTrials(
|
|
8587
9445
|
threshold=threshold,
|
|
8588
9446
|
block_transition_if_unmet=True,
|
|
8589
9447
|
transition_to=target_model,
|
|
@@ -8599,11 +9457,12 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8599
9457
|
if model_name.lower() != "sobol":
|
|
8600
9458
|
kwargs["model_kwargs"] = get_model_kwargs()
|
|
8601
9459
|
|
|
8602
|
-
model_spec = [GeneratorSpec(selected_model, **kwargs)]
|
|
9460
|
+
model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]
|
|
8603
9461
|
|
|
8604
9462
|
res = GenerationNode(
|
|
8605
|
-
|
|
9463
|
+
name=model_name,
|
|
8606
9464
|
generator_specs=model_spec,
|
|
9465
|
+
should_deduplicate=True,
|
|
8607
9466
|
transition_criteria=trans_crit
|
|
8608
9467
|
)
|
|
8609
9468
|
|
|
@@ -8614,11 +9473,11 @@ def get_optimizer_kwargs() -> dict:
|
|
|
8614
9473
|
"sequential": False
|
|
8615
9474
|
}
|
|
8616
9475
|
|
|
8617
|
-
def create_step(model_name: str, _num_trials: int
|
|
9476
|
+
def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
|
|
8618
9477
|
model_enum = get_model_from_name(model_name)
|
|
8619
9478
|
|
|
8620
9479
|
return GenerationStep(
|
|
8621
|
-
generator=model_enum,
|
|
9480
|
+
generator=model_enum,
|
|
8622
9481
|
num_trials=_num_trials,
|
|
8623
9482
|
max_parallelism=1000 * max_eval + 1000,
|
|
8624
9483
|
model_kwargs=get_model_kwargs(),
|
|
@@ -8629,17 +9488,16 @@ def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = N
|
|
|
8629
9488
|
)
|
|
8630
9489
|
|
|
8631
9490
|
def set_global_generation_strategy() -> None:
|
|
8632
|
-
|
|
8633
|
-
continue_not_supported_on_custom_generation_strategy()
|
|
9491
|
+
continue_not_supported_on_custom_generation_strategy()
|
|
8634
9492
|
|
|
8635
|
-
|
|
8636
|
-
|
|
8637
|
-
|
|
8638
|
-
|
|
8639
|
-
|
|
8640
|
-
|
|
8641
|
-
|
|
8642
|
-
|
|
9493
|
+
try:
|
|
9494
|
+
if args.generation_strategy is None:
|
|
9495
|
+
setup_default_generation_strategy()
|
|
9496
|
+
else:
|
|
9497
|
+
setup_custom_generation_strategy()
|
|
9498
|
+
except Exception as e:
|
|
9499
|
+
print_red(f"Unexpected error in generation strategy setup: {e}")
|
|
9500
|
+
my_exit(111)
|
|
8643
9501
|
|
|
8644
9502
|
if global_gs is None:
|
|
8645
9503
|
print_red("global_gs is None after setup!")
|
|
@@ -8775,10 +9633,6 @@ def execute_trials(
|
|
|
8775
9633
|
index_param_list.append(_args)
|
|
8776
9634
|
i += 1
|
|
8777
9635
|
|
|
8778
|
-
save_results_csv()
|
|
8779
|
-
|
|
8780
|
-
save_results_csv()
|
|
8781
|
-
|
|
8782
9636
|
start_time = time.time()
|
|
8783
9637
|
|
|
8784
9638
|
cnt = 0
|
|
@@ -8794,10 +9648,14 @@ def execute_trials(
|
|
|
8794
9648
|
result = future.result()
|
|
8795
9649
|
print_debug(f"result in execute_trials: {result}")
|
|
8796
9650
|
except Exception as exc:
|
|
8797
|
-
|
|
9651
|
+
failed_args = future_to_args[future]
|
|
9652
|
+
print_red(f"execute_trials: Error at executing a trial with args {failed_args}: {exc}")
|
|
9653
|
+
traceback.print_exc()
|
|
8798
9654
|
|
|
8799
9655
|
end_time = time.time()
|
|
8800
9656
|
|
|
9657
|
+
log_data()
|
|
9658
|
+
|
|
8801
9659
|
duration = float(end_time - start_time)
|
|
8802
9660
|
job_submit_durations.append(duration)
|
|
8803
9661
|
job_submit_nrs.append(cnt)
|
|
@@ -8831,20 +9689,20 @@ def create_and_execute_next_runs(next_nr_steps: int, phase: Optional[str], _max_
|
|
|
8831
9689
|
done_optimizing: bool = False
|
|
8832
9690
|
|
|
8833
9691
|
try:
|
|
8834
|
-
done_optimizing, trial_index_to_param =
|
|
8835
|
-
|
|
9692
|
+
done_optimizing, trial_index_to_param = create_and_execute_next_runs_run_loop(_max_eval, phase)
|
|
9693
|
+
create_and_execute_next_runs_finish(done_optimizing)
|
|
8836
9694
|
except Exception as e:
|
|
8837
9695
|
stacktrace = traceback.format_exc()
|
|
8838
9696
|
print_debug(f"Warning: create_and_execute_next_runs encountered an exception: {e}\n{stacktrace}")
|
|
8839
9697
|
return handle_exceptions_create_and_execute_next_runs(e)
|
|
8840
9698
|
|
|
8841
|
-
return
|
|
9699
|
+
return create_and_execute_next_runs_return_value(trial_index_to_param)
|
|
8842
9700
|
|
|
8843
|
-
def
|
|
9701
|
+
def create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
|
|
8844
9702
|
done_optimizing = False
|
|
8845
9703
|
trial_index_to_param: Optional[Dict] = None
|
|
8846
9704
|
|
|
8847
|
-
nr_of_jobs_to_get =
|
|
9705
|
+
nr_of_jobs_to_get = calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
|
|
8848
9706
|
|
|
8849
9707
|
__max_eval = _max_eval if _max_eval is not None else 0
|
|
8850
9708
|
new_nr_of_jobs_to_get = min(__max_eval - (submitted_jobs() - failed_jobs()), nr_of_jobs_to_get)
|
|
@@ -8857,7 +9715,8 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
|
|
|
8857
9715
|
get_next_trials_nr = new_nr_of_jobs_to_get
|
|
8858
9716
|
|
|
8859
9717
|
for _ in range(range_nr):
|
|
8860
|
-
trial_index_to_param, done_optimizing =
|
|
9718
|
+
trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
|
|
9719
|
+
log_data()
|
|
8861
9720
|
if done_optimizing:
|
|
8862
9721
|
continue
|
|
8863
9722
|
|
|
@@ -8882,13 +9741,13 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
|
|
|
8882
9741
|
|
|
8883
9742
|
return done_optimizing, trial_index_to_param
|
|
8884
9743
|
|
|
8885
|
-
def
|
|
9744
|
+
def create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
|
|
8886
9745
|
finish_previous_jobs(["finishing jobs"])
|
|
8887
9746
|
|
|
8888
9747
|
if done_optimizing:
|
|
8889
9748
|
end_program(False, 0)
|
|
8890
9749
|
|
|
8891
|
-
def
|
|
9750
|
+
def create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
|
|
8892
9751
|
try:
|
|
8893
9752
|
if trial_index_to_param:
|
|
8894
9753
|
res = len(trial_index_to_param.keys())
|
|
@@ -8999,7 +9858,7 @@ def execute_nvidia_smi() -> None:
|
|
|
8999
9858
|
if not host:
|
|
9000
9859
|
print_debug("host not defined")
|
|
9001
9860
|
except Exception as e:
|
|
9002
|
-
|
|
9861
|
+
print_red(f"execute_nvidia_smi: An error occurred: {e}")
|
|
9003
9862
|
if is_slurm_job() and not args.force_local_execution:
|
|
9004
9863
|
_sleep(30)
|
|
9005
9864
|
|
|
@@ -9037,11 +9896,6 @@ def run_search() -> bool:
|
|
|
9037
9896
|
|
|
9038
9897
|
return False
|
|
9039
9898
|
|
|
9040
|
-
async def start_logging_daemon() -> None:
|
|
9041
|
-
while True:
|
|
9042
|
-
log_data()
|
|
9043
|
-
time.sleep(30)
|
|
9044
|
-
|
|
9045
9899
|
def should_break_search() -> bool:
|
|
9046
9900
|
ret = False
|
|
9047
9901
|
|
|
@@ -9105,7 +9959,7 @@ def finalize_jobs() -> None:
|
|
|
9105
9959
|
handle_slurm_execution()
|
|
9106
9960
|
|
|
9107
9961
|
def go_through_jobs_that_are_not_completed_yet() -> None:
|
|
9108
|
-
print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
|
|
9962
|
+
#print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
|
|
9109
9963
|
|
|
9110
9964
|
nr_jobs_left = len(global_vars['jobs'])
|
|
9111
9965
|
if nr_jobs_left == 1:
|
|
@@ -9175,7 +10029,7 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
|
|
|
9175
10029
|
|
|
9176
10030
|
return data
|
|
9177
10031
|
except Exception as e:
|
|
9178
|
-
|
|
10032
|
+
print_red(f"Error while parse_experiment_parameters({_f}): {e}")
|
|
9179
10033
|
else:
|
|
9180
10034
|
print_red(f"{_f} could not be found")
|
|
9181
10035
|
|
|
@@ -9192,355 +10046,43 @@ def set_orchestrator() -> None:
|
|
|
9192
10046
|
print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
|
|
9193
10047
|
|
|
9194
10048
|
def check_if_has_random_steps() -> None:
|
|
9195
|
-
if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
|
|
9196
|
-
_fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
|
|
9197
|
-
|
|
9198
|
-
def add_exclude_to_defective_nodes() -> None:
|
|
9199
|
-
with spinner("Adding excluded nodes..."):
|
|
9200
|
-
if args.exclude:
|
|
9201
|
-
entries = [entry.strip() for entry in args.exclude.split(',')]
|
|
9202
|
-
|
|
9203
|
-
for entry in entries:
|
|
9204
|
-
count_defective_nodes(None, entry)
|
|
9205
|
-
|
|
9206
|
-
def check_max_eval(_max_eval: int) -> None:
|
|
9207
|
-
with spinner("Checking max_eval..."):
|
|
9208
|
-
if not _max_eval:
|
|
9209
|
-
_fatal_error("--max_eval needs to be set!", 19)
|
|
9210
|
-
|
|
9211
|
-
def parse_parameters() -> Any:
|
|
9212
|
-
cli_params_experiment_parameters = None
|
|
9213
|
-
if args.parameter:
|
|
9214
|
-
parse_experiment_parameters()
|
|
9215
|
-
cli_params_experiment_parameters = experiment_parameters
|
|
9216
|
-
|
|
9217
|
-
return cli_params_experiment_parameters
|
|
9218
|
-
|
|
9219
|
-
def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
|
|
9220
|
-
table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
|
|
9221
|
-
|
|
9222
|
-
rows = _pareto_front_table_read_csv()
|
|
9223
|
-
if not rows:
|
|
9224
|
-
table.add_column("No data found")
|
|
9225
|
-
return table
|
|
9226
|
-
|
|
9227
|
-
filtered_rows = _pareto_front_table_filter_rows(rows, idxs)
|
|
9228
|
-
if not filtered_rows:
|
|
9229
|
-
table.add_column("No matching entries")
|
|
9230
|
-
return table
|
|
9231
|
-
|
|
9232
|
-
param_cols, result_cols = _pareto_front_table_get_columns(filtered_rows[0])
|
|
9233
|
-
|
|
9234
|
-
_pareto_front_table_add_headers(table, param_cols, result_cols)
|
|
9235
|
-
_pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
|
|
9236
|
-
|
|
9237
|
-
return table
|
|
9238
|
-
|
|
9239
|
-
def _pareto_front_table_read_csv() -> List[Dict[str, str]]:
|
|
9240
|
-
with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
|
|
9241
|
-
return list(csv.DictReader(f))
|
|
9242
|
-
|
|
9243
|
-
def _pareto_front_table_filter_rows(rows: List[Dict[str, str]], idxs: List[int]) -> List[Dict[str, str]]:
|
|
9244
|
-
result = []
|
|
9245
|
-
for row in rows:
|
|
9246
|
-
try:
|
|
9247
|
-
trial_index = int(row["trial_index"])
|
|
9248
|
-
except (KeyError, ValueError):
|
|
9249
|
-
continue
|
|
9250
|
-
|
|
9251
|
-
if row.get("trial_status", "").strip().upper() == "COMPLETED" and trial_index in idxs:
|
|
9252
|
-
result.append(row)
|
|
9253
|
-
return result
|
|
9254
|
-
|
|
9255
|
-
def _pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
|
|
9256
|
-
all_columns = list(first_row.keys())
|
|
9257
|
-
ignored_cols = set(special_col_names) - {"trial_index"}
|
|
9258
|
-
|
|
9259
|
-
param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
|
|
9260
|
-
result_cols = [col for col in arg_result_names if col in all_columns]
|
|
9261
|
-
return param_cols, result_cols
|
|
9262
|
-
|
|
9263
|
-
def _pareto_front_table_add_headers(table: Table, param_cols: List[str], result_cols: List[str]) -> None:
|
|
9264
|
-
for col in param_cols:
|
|
9265
|
-
table.add_column(col, justify="center")
|
|
9266
|
-
for col in result_cols:
|
|
9267
|
-
table.add_column(Text(f"{col}", style="cyan"), justify="center")
|
|
9268
|
-
|
|
9269
|
-
def _pareto_front_table_add_rows(table: Table, rows: List[Dict[str, str]], param_cols: List[str], result_cols: List[str]) -> None:
|
|
9270
|
-
for row in rows:
|
|
9271
|
-
values = [str(helpers.to_int_when_possible(row[col])) for col in param_cols]
|
|
9272
|
-
result_values = [Text(str(helpers.to_int_when_possible(row[col])), style="cyan") for col in result_cols]
|
|
9273
|
-
table.add_row(*values, *result_values, style="bold green")
|
|
9274
|
-
|
|
9275
|
-
def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
|
|
9276
|
-
if not os.path.exists(RESULT_CSV_FILE):
|
|
9277
|
-
print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
|
|
9278
|
-
return None
|
|
9279
|
-
|
|
9280
|
-
return create_pareto_front_table(idxs, metric_x, metric_y)
|
|
9281
|
-
|
|
9282
|
-
def supports_sixel() -> bool:
|
|
9283
|
-
term = os.environ.get("TERM", "").lower()
|
|
9284
|
-
if "xterm" in term or "mlterm" in term:
|
|
9285
|
-
return True
|
|
9286
|
-
|
|
9287
|
-
try:
|
|
9288
|
-
output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
|
|
9289
|
-
if output.returncode == 0 and "sixel" in output.stdout.lower():
|
|
9290
|
-
return True
|
|
9291
|
-
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
9292
|
-
pass
|
|
9293
|
-
|
|
9294
|
-
return False
|
|
9295
|
-
|
|
9296
|
-
def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
|
|
9297
|
-
if data is None:
|
|
9298
|
-
print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
|
|
9299
|
-
return
|
|
9300
|
-
|
|
9301
|
-
if not supports_sixel():
|
|
9302
|
-
print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
|
|
9303
|
-
return
|
|
9304
|
-
|
|
9305
|
-
import matplotlib.pyplot as plt
|
|
9306
|
-
|
|
9307
|
-
means = data[x_metric][y_metric]["means"]
|
|
9308
|
-
|
|
9309
|
-
x_values = means[x_metric]
|
|
9310
|
-
y_values = means[y_metric]
|
|
9311
|
-
|
|
9312
|
-
fig, _ax = plt.subplots()
|
|
9313
|
-
|
|
9314
|
-
_ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
|
|
9315
|
-
|
|
9316
|
-
_ax.set_xlabel(x_metric)
|
|
9317
|
-
_ax.set_ylabel(y_metric)
|
|
9318
|
-
|
|
9319
|
-
_ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
|
|
9320
|
-
|
|
9321
|
-
_ax.ticklabel_format(style='plain', axis='both', useOffset=False)
|
|
9322
|
-
|
|
9323
|
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
|
|
9324
|
-
plt.savefig(tmp_file.name, dpi=300)
|
|
9325
|
-
|
|
9326
|
-
print_image_to_cli(tmp_file.name, 1000)
|
|
9327
|
-
|
|
9328
|
-
plt.close(fig)
|
|
9329
|
-
|
|
9330
|
-
def _pareto_front_general_validate_shapes(x: np.ndarray, y: np.ndarray) -> None:
|
|
9331
|
-
if x.shape != y.shape:
|
|
9332
|
-
raise ValueError("Input arrays x and y must have the same shape.")
|
|
9333
|
-
|
|
9334
|
-
def _pareto_front_general_compare(
|
|
9335
|
-
xi: float, yi: float, xj: float, yj: float,
|
|
9336
|
-
x_minimize: bool, y_minimize: bool
|
|
9337
|
-
) -> bool:
|
|
9338
|
-
x_better_eq = xj <= xi if x_minimize else xj >= xi
|
|
9339
|
-
y_better_eq = yj <= yi if y_minimize else yj >= yi
|
|
9340
|
-
x_strictly_better = xj < xi if x_minimize else xj > xi
|
|
9341
|
-
y_strictly_better = yj < yi if y_minimize else yj > yi
|
|
9342
|
-
|
|
9343
|
-
return bool(x_better_eq and y_better_eq and (x_strictly_better or y_strictly_better))
|
|
9344
|
-
|
|
9345
|
-
def _pareto_front_general_find_dominated(
|
|
9346
|
-
x: np.ndarray, y: np.ndarray, x_minimize: bool, y_minimize: bool
|
|
9347
|
-
) -> np.ndarray:
|
|
9348
|
-
num_points = len(x)
|
|
9349
|
-
is_dominated = np.zeros(num_points, dtype=bool)
|
|
9350
|
-
|
|
9351
|
-
for i in range(num_points):
|
|
9352
|
-
for j in range(num_points):
|
|
9353
|
-
if i == j:
|
|
9354
|
-
continue
|
|
9355
|
-
|
|
9356
|
-
if _pareto_front_general_compare(x[i], y[i], x[j], y[j], x_minimize, y_minimize):
|
|
9357
|
-
is_dominated[i] = True
|
|
9358
|
-
break
|
|
9359
|
-
|
|
9360
|
-
return is_dominated
|
|
9361
|
-
|
|
9362
|
-
def pareto_front_general(
|
|
9363
|
-
x: np.ndarray,
|
|
9364
|
-
y: np.ndarray,
|
|
9365
|
-
x_minimize: bool = True,
|
|
9366
|
-
y_minimize: bool = True
|
|
9367
|
-
) -> np.ndarray:
|
|
9368
|
-
try:
|
|
9369
|
-
_pareto_front_general_validate_shapes(x, y)
|
|
9370
|
-
is_dominated = _pareto_front_general_find_dominated(x, y, x_minimize, y_minimize)
|
|
9371
|
-
return np.where(~is_dominated)[0]
|
|
9372
|
-
except Exception as e:
|
|
9373
|
-
print("Error in pareto_front_general:", str(e))
|
|
9374
|
-
return np.array([], dtype=int)
|
|
9375
|
-
|
|
9376
|
-
def _pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
|
|
9377
|
-
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
9378
|
-
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
9379
|
-
|
|
9380
|
-
if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
|
|
9381
|
-
return None
|
|
9382
|
-
|
|
9383
|
-
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
9384
|
-
result_names = [line.strip() for line in f if line.strip()]
|
|
9385
|
-
|
|
9386
|
-
records: dict = defaultdict(lambda: {'means': {}})
|
|
9387
|
-
|
|
9388
|
-
with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
|
|
9389
|
-
reader = csv.DictReader(csvfile)
|
|
9390
|
-
for row in reader:
|
|
9391
|
-
trial_index = int(row['trial_index'])
|
|
9392
|
-
arm_name = row['arm_name']
|
|
9393
|
-
key = (trial_index, arm_name)
|
|
9394
|
-
|
|
9395
|
-
for metric in result_names:
|
|
9396
|
-
if metric in row:
|
|
9397
|
-
try:
|
|
9398
|
-
records[key]['means'][metric] = float(row[metric])
|
|
9399
|
-
except ValueError:
|
|
9400
|
-
continue
|
|
9401
|
-
|
|
9402
|
-
return records
|
|
9403
|
-
|
|
9404
|
-
def _pareto_front_filter_complete_points(
|
|
9405
|
-
path_to_calculate: str,
|
|
9406
|
-
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
9407
|
-
primary_name: str,
|
|
9408
|
-
secondary_name: str
|
|
9409
|
-
) -> List[Tuple[Tuple[int, str], float, float]]:
|
|
9410
|
-
points = []
|
|
9411
|
-
for key, metrics in records.items():
|
|
9412
|
-
means = metrics['means']
|
|
9413
|
-
if primary_name in means and secondary_name in means:
|
|
9414
|
-
x_val = means[primary_name]
|
|
9415
|
-
y_val = means[secondary_name]
|
|
9416
|
-
points.append((key, x_val, y_val))
|
|
9417
|
-
if len(points) == 0:
|
|
9418
|
-
raise ValueError(f"No full data points with both objectives found in {path_to_calculate}.")
|
|
9419
|
-
return points
|
|
9420
|
-
|
|
9421
|
-
def _pareto_front_transform_objectives(
|
|
9422
|
-
points: List[Tuple[Any, float, float]],
|
|
9423
|
-
primary_name: str,
|
|
9424
|
-
secondary_name: str
|
|
9425
|
-
) -> Tuple[np.ndarray, np.ndarray]:
|
|
9426
|
-
primary_idx = arg_result_names.index(primary_name)
|
|
9427
|
-
secondary_idx = arg_result_names.index(secondary_name)
|
|
9428
|
-
|
|
9429
|
-
x = np.array([p[1] for p in points])
|
|
9430
|
-
y = np.array([p[2] for p in points])
|
|
9431
|
-
|
|
9432
|
-
if arg_result_min_or_max[primary_idx] == "max":
|
|
9433
|
-
x = -x
|
|
9434
|
-
elif arg_result_min_or_max[primary_idx] != "min":
|
|
9435
|
-
raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
|
|
9436
|
-
|
|
9437
|
-
if arg_result_min_or_max[secondary_idx] == "max":
|
|
9438
|
-
y = -y
|
|
9439
|
-
elif arg_result_min_or_max[secondary_idx] != "min":
|
|
9440
|
-
raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
|
|
9441
|
-
|
|
9442
|
-
return x, y
|
|
9443
|
-
|
|
9444
|
-
def _pareto_front_select_pareto_points(
|
|
9445
|
-
x: np.ndarray,
|
|
9446
|
-
y: np.ndarray,
|
|
9447
|
-
x_minimize: bool,
|
|
9448
|
-
y_minimize: bool,
|
|
9449
|
-
points: List[Tuple[Any, float, float]],
|
|
9450
|
-
num_points: int
|
|
9451
|
-
) -> List[Tuple[Any, float, float]]:
|
|
9452
|
-
indices = pareto_front_general(x, y, x_minimize, y_minimize)
|
|
9453
|
-
sorted_indices = indices[np.argsort(x[indices])]
|
|
9454
|
-
sorted_indices = sorted_indices[:num_points]
|
|
9455
|
-
selected_points = [points[i] for i in sorted_indices]
|
|
9456
|
-
return selected_points
|
|
9457
|
-
|
|
9458
|
-
def _pareto_front_build_return_structure(
|
|
9459
|
-
path_to_calculate: str,
|
|
9460
|
-
selected_points: List[Tuple[Any, float, float]],
|
|
9461
|
-
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
9462
|
-
absolute_metrics: List[str],
|
|
9463
|
-
primary_name: str,
|
|
9464
|
-
secondary_name: str
|
|
9465
|
-
) -> dict:
|
|
9466
|
-
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
9467
|
-
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
9468
|
-
|
|
9469
|
-
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
9470
|
-
result_names = [line.strip() for line in f if line.strip()]
|
|
9471
|
-
|
|
9472
|
-
csv_rows = {}
|
|
9473
|
-
with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
|
|
9474
|
-
reader = csv.DictReader(csvfile)
|
|
9475
|
-
for row in reader:
|
|
9476
|
-
trial_index = int(row['trial_index'])
|
|
9477
|
-
csv_rows[trial_index] = row
|
|
9478
|
-
|
|
9479
|
-
ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
|
|
9480
|
-
ignored_columns.update(result_names)
|
|
9481
|
-
|
|
9482
|
-
param_dicts = []
|
|
9483
|
-
idxs = []
|
|
9484
|
-
means_dict = defaultdict(list)
|
|
9485
|
-
|
|
9486
|
-
for (trial_index, arm_name), _, _ in selected_points:
|
|
9487
|
-
row = csv_rows.get(trial_index, {})
|
|
9488
|
-
if row == {} or row is None or row['arm_name'] != arm_name:
|
|
9489
|
-
print_debug(f"_pareto_front_build_return_structure: trial_index '{trial_index}' could not be found and row returned as None")
|
|
9490
|
-
continue
|
|
9491
|
-
|
|
9492
|
-
idxs.append(int(row["trial_index"]))
|
|
9493
|
-
|
|
9494
|
-
param_dict: dict[str, int | float | str] = {}
|
|
9495
|
-
for key, value in row.items():
|
|
9496
|
-
if key not in ignored_columns:
|
|
9497
|
-
try:
|
|
9498
|
-
param_dict[key] = int(value)
|
|
9499
|
-
except ValueError:
|
|
9500
|
-
try:
|
|
9501
|
-
param_dict[key] = float(value)
|
|
9502
|
-
except ValueError:
|
|
9503
|
-
param_dict[key] = value
|
|
10049
|
+
if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
|
|
10050
|
+
_fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
|
|
9504
10051
|
|
|
9505
|
-
|
|
10052
|
+
def add_exclude_to_defective_nodes() -> None:
|
|
10053
|
+
with spinner("Adding excluded nodes..."):
|
|
10054
|
+
if args.exclude:
|
|
10055
|
+
entries = [entry.strip() for entry in args.exclude.split(',')]
|
|
9506
10056
|
|
|
9507
|
-
|
|
9508
|
-
|
|
10057
|
+
for entry in entries:
|
|
10058
|
+
count_defective_nodes(None, entry)
|
|
9509
10059
|
|
|
9510
|
-
|
|
9511
|
-
|
|
9512
|
-
|
|
9513
|
-
|
|
9514
|
-
"param_dicts": param_dicts,
|
|
9515
|
-
"means": dict(means_dict),
|
|
9516
|
-
"idxs": idxs
|
|
9517
|
-
},
|
|
9518
|
-
"absolute_metrics": absolute_metrics
|
|
9519
|
-
}
|
|
9520
|
-
}
|
|
10060
|
+
def check_max_eval(_max_eval: int) -> None:
|
|
10061
|
+
with spinner("Checking max_eval..."):
|
|
10062
|
+
if not _max_eval:
|
|
10063
|
+
_fatal_error("--max_eval needs to be set!", 19)
|
|
9521
10064
|
|
|
9522
|
-
|
|
10065
|
+
def parse_parameters() -> Any:
|
|
10066
|
+
cli_params_experiment_parameters = None
|
|
10067
|
+
if args.parameter:
|
|
10068
|
+
parse_experiment_parameters()
|
|
10069
|
+
cli_params_experiment_parameters = experiment_parameters
|
|
9523
10070
|
|
|
9524
|
-
|
|
9525
|
-
path_to_calculate: str,
|
|
9526
|
-
primary_objective: str,
|
|
9527
|
-
secondary_objective: str,
|
|
9528
|
-
x_minimize: bool,
|
|
9529
|
-
y_minimize: bool,
|
|
9530
|
-
absolute_metrics: List[str],
|
|
9531
|
-
num_points: int
|
|
9532
|
-
) -> Optional[dict]:
|
|
9533
|
-
records = _pareto_front_aggregate_data(path_to_calculate)
|
|
10071
|
+
return cli_params_experiment_parameters
|
|
9534
10072
|
|
|
9535
|
-
|
|
9536
|
-
|
|
10073
|
+
def supports_sixel() -> bool:
|
|
10074
|
+
term = os.environ.get("TERM", "").lower()
|
|
10075
|
+
if "xterm" in term or "mlterm" in term:
|
|
10076
|
+
return True
|
|
9537
10077
|
|
|
9538
|
-
|
|
9539
|
-
|
|
9540
|
-
|
|
9541
|
-
|
|
10078
|
+
try:
|
|
10079
|
+
output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
|
|
10080
|
+
if output.returncode == 0 and "sixel" in output.stdout.lower():
|
|
10081
|
+
return True
|
|
10082
|
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
10083
|
+
pass
|
|
9542
10084
|
|
|
9543
|
-
return
|
|
10085
|
+
return False
|
|
9544
10086
|
|
|
9545
10087
|
def save_experiment_state() -> None:
|
|
9546
10088
|
try:
|
|
@@ -9548,14 +10090,14 @@ def save_experiment_state() -> None:
|
|
|
9548
10090
|
print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
|
|
9549
10091
|
return
|
|
9550
10092
|
state_path = get_current_run_folder("experiment_state.json")
|
|
9551
|
-
|
|
10093
|
+
save_ax_client_to_json_file(state_path)
|
|
9552
10094
|
except Exception as e:
|
|
9553
|
-
|
|
10095
|
+
print_debug(f"Error saving experiment state: {e}")
|
|
9554
10096
|
|
|
9555
10097
|
def wait_for_state_file(state_path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
|
|
9556
10098
|
try:
|
|
9557
10099
|
if not os.path.exists(state_path):
|
|
9558
|
-
|
|
10100
|
+
print_debug(f"[ERROR] File '{state_path}' does not exist.")
|
|
9559
10101
|
return False
|
|
9560
10102
|
|
|
9561
10103
|
i = 0
|
|
@@ -9610,7 +10152,6 @@ def load_experiment_state() -> None:
|
|
|
9610
10152
|
state_path = get_current_run_folder("experiment_state.json")
|
|
9611
10153
|
|
|
9612
10154
|
if not os.path.exists(state_path):
|
|
9613
|
-
print(f"State file {state_path} does not exist, starting fresh")
|
|
9614
10155
|
return
|
|
9615
10156
|
|
|
9616
10157
|
if args.worker_generator_path:
|
|
@@ -9624,10 +10165,10 @@ def load_experiment_state() -> None:
|
|
|
9624
10165
|
return
|
|
9625
10166
|
|
|
9626
10167
|
try:
|
|
9627
|
-
arms_seen = {}
|
|
10168
|
+
arms_seen: dict = {}
|
|
9628
10169
|
for arm in data.get("arms", []):
|
|
9629
10170
|
name = arm.get("name")
|
|
9630
|
-
sig = arm.get("parameters")
|
|
10171
|
+
sig = arm.get("parameters")
|
|
9631
10172
|
if not name:
|
|
9632
10173
|
continue
|
|
9633
10174
|
if name in arms_seen and arms_seen[name] != sig:
|
|
@@ -9636,7 +10177,6 @@ def load_experiment_state() -> None:
|
|
|
9636
10177
|
arm["name"] = new_name
|
|
9637
10178
|
arms_seen[name] = sig
|
|
9638
10179
|
|
|
9639
|
-
# Gefilterten Zustand speichern und laden
|
|
9640
10180
|
temp_path = state_path + ".no_conflicts.json"
|
|
9641
10181
|
with open(temp_path, encoding="utf-8", mode="w") as f:
|
|
9642
10182
|
json.dump(data, f)
|
|
@@ -9776,38 +10316,47 @@ def get_result_minimize_flag(path_to_calculate: str, resname: str) -> bool:
|
|
|
9776
10316
|
|
|
9777
10317
|
return minmax[index] == "min"
|
|
9778
10318
|
|
|
9779
|
-
def
|
|
9780
|
-
|
|
10319
|
+
def post_job_calculate_pareto_front() -> None:
|
|
10320
|
+
if not args.calculate_pareto_front_of_job:
|
|
10321
|
+
return
|
|
9781
10322
|
|
|
9782
|
-
|
|
10323
|
+
failure = False
|
|
9783
10324
|
|
|
9784
|
-
|
|
10325
|
+
_paths_to_calculate = []
|
|
9785
10326
|
|
|
9786
|
-
for
|
|
9787
|
-
|
|
9788
|
-
|
|
9789
|
-
metric_y = arg_result_names[j]
|
|
10327
|
+
for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
|
|
10328
|
+
try:
|
|
10329
|
+
found_paths = find_results_paths(_path_to_calculate)
|
|
9790
10330
|
|
|
9791
|
-
|
|
9792
|
-
|
|
10331
|
+
for _fp in found_paths:
|
|
10332
|
+
if _fp not in _paths_to_calculate:
|
|
10333
|
+
_paths_to_calculate.append(_fp)
|
|
10334
|
+
except (FileNotFoundError, NotADirectoryError) as e:
|
|
10335
|
+
print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
|
|
9793
10336
|
|
|
9794
|
-
|
|
9795
|
-
if metric_x not in pareto_front_data:
|
|
9796
|
-
pareto_front_data[metric_x] = {}
|
|
10337
|
+
failure = True
|
|
9797
10338
|
|
|
9798
|
-
|
|
9799
|
-
|
|
9800
|
-
|
|
9801
|
-
|
|
9802
|
-
print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
|
|
9803
|
-
skip = True
|
|
10339
|
+
for _path_to_calculate in _paths_to_calculate:
|
|
10340
|
+
for path_to_calculate in found_paths:
|
|
10341
|
+
if not job_calculate_pareto_front(path_to_calculate):
|
|
10342
|
+
failure = True
|
|
9804
10343
|
|
|
9805
|
-
|
|
10344
|
+
if failure:
|
|
10345
|
+
my_exit(24)
|
|
10346
|
+
|
|
10347
|
+
my_exit(0)
|
|
10348
|
+
|
|
10349
|
+
def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
|
|
10350
|
+
if not os.path.exists(RESULT_CSV_FILE):
|
|
10351
|
+
print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
|
|
10352
|
+
return None
|
|
10353
|
+
|
|
10354
|
+
return create_pareto_front_table(idxs, metric_x, metric_y)
|
|
9806
10355
|
|
|
9807
10356
|
def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_sixel_and_table: bool = False) -> None:
|
|
9808
10357
|
if len(res_names) <= 1:
|
|
9809
10358
|
print_debug(f"--result_names (has {len(res_names)} entries) must be at least 2.")
|
|
9810
|
-
return
|
|
10359
|
+
return None
|
|
9811
10360
|
|
|
9812
10361
|
pareto_front_data: dict = get_pareto_front_data(path_to_calculate, res_names)
|
|
9813
10362
|
|
|
@@ -9828,8 +10377,16 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
|
|
|
9828
10377
|
else:
|
|
9829
10378
|
print(f"Not showing Pareto-front-sixel for {path_to_calculate}")
|
|
9830
10379
|
|
|
9831
|
-
if
|
|
9832
|
-
|
|
10380
|
+
if calculated_frontier is None:
|
|
10381
|
+
print_debug("ERROR: calculated_frontier is None")
|
|
10382
|
+
return None
|
|
10383
|
+
|
|
10384
|
+
try:
|
|
10385
|
+
if len(calculated_frontier[metric_x][metric_y]["idxs"]):
|
|
10386
|
+
pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
|
|
10387
|
+
except AttributeError:
|
|
10388
|
+
print_debug(f"ERROR: calculated_frontier structure invalid for ({metric_x}, {metric_y})")
|
|
10389
|
+
return None
|
|
9833
10390
|
|
|
9834
10391
|
rich_table = pareto_front_as_rich_table(
|
|
9835
10392
|
calculated_frontier[metric_x][metric_y]["idxs"],
|
|
@@ -9854,6 +10411,8 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
|
|
|
9854
10411
|
|
|
9855
10412
|
live_share_after_pareto()
|
|
9856
10413
|
|
|
10414
|
+
return None
|
|
10415
|
+
|
|
9857
10416
|
def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_color: str) -> None:
|
|
9858
10417
|
cpu_count = os.cpu_count()
|
|
9859
10418
|
|
|
@@ -9865,9 +10424,11 @@ def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_
|
|
|
9865
10424
|
pass
|
|
9866
10425
|
|
|
9867
10426
|
if gpu_string:
|
|
9868
|
-
console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]
|
|
10427
|
+
console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]")
|
|
9869
10428
|
else:
|
|
9870
|
-
print_green(f"You have {cpu_count} CPUs available for the main process.
|
|
10429
|
+
print_green(f"You have {cpu_count} CPUs available for the main process.")
|
|
10430
|
+
|
|
10431
|
+
print_green(gs_string)
|
|
9871
10432
|
|
|
9872
10433
|
def write_args_overview_table() -> None:
|
|
9873
10434
|
table = Table(title="Arguments Overview")
|
|
@@ -10134,112 +10695,6 @@ def find_results_paths(base_path: str) -> list:
|
|
|
10134
10695
|
|
|
10135
10696
|
return list(set(found_paths))
|
|
10136
10697
|
|
|
10137
|
-
def post_job_calculate_pareto_front() -> None:
|
|
10138
|
-
if not args.calculate_pareto_front_of_job:
|
|
10139
|
-
return
|
|
10140
|
-
|
|
10141
|
-
failure = False
|
|
10142
|
-
|
|
10143
|
-
_paths_to_calculate = []
|
|
10144
|
-
|
|
10145
|
-
for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
|
|
10146
|
-
try:
|
|
10147
|
-
found_paths = find_results_paths(_path_to_calculate)
|
|
10148
|
-
|
|
10149
|
-
for _fp in found_paths:
|
|
10150
|
-
if _fp not in _paths_to_calculate:
|
|
10151
|
-
_paths_to_calculate.append(_fp)
|
|
10152
|
-
except (FileNotFoundError, NotADirectoryError) as e:
|
|
10153
|
-
print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
|
|
10154
|
-
|
|
10155
|
-
failure = True
|
|
10156
|
-
|
|
10157
|
-
for _path_to_calculate in _paths_to_calculate:
|
|
10158
|
-
for path_to_calculate in found_paths:
|
|
10159
|
-
if not job_calculate_pareto_front(path_to_calculate):
|
|
10160
|
-
failure = True
|
|
10161
|
-
|
|
10162
|
-
if failure:
|
|
10163
|
-
my_exit(24)
|
|
10164
|
-
|
|
10165
|
-
my_exit(0)
|
|
10166
|
-
|
|
10167
|
-
def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
|
|
10168
|
-
pf_start_time = time.time()
|
|
10169
|
-
|
|
10170
|
-
if not path_to_calculate:
|
|
10171
|
-
return False
|
|
10172
|
-
|
|
10173
|
-
global CURRENT_RUN_FOLDER
|
|
10174
|
-
global RESULT_CSV_FILE
|
|
10175
|
-
global arg_result_names
|
|
10176
|
-
|
|
10177
|
-
if not path_to_calculate:
|
|
10178
|
-
print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
|
|
10179
|
-
return False
|
|
10180
|
-
|
|
10181
|
-
if not os.path.exists(path_to_calculate):
|
|
10182
|
-
print_red(f"Path '{path_to_calculate}' does not exist")
|
|
10183
|
-
return False
|
|
10184
|
-
|
|
10185
|
-
ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
|
|
10186
|
-
|
|
10187
|
-
if not os.path.exists(ax_client_json):
|
|
10188
|
-
print_red(f"Path '{ax_client_json}' not found")
|
|
10189
|
-
return False
|
|
10190
|
-
|
|
10191
|
-
checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
|
|
10192
|
-
if not os.path.exists(checkpoint_file):
|
|
10193
|
-
print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
|
|
10194
|
-
return False
|
|
10195
|
-
|
|
10196
|
-
RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
10197
|
-
if not os.path.exists(RESULT_CSV_FILE):
|
|
10198
|
-
print_red(f"{RESULT_CSV_FILE} not found")
|
|
10199
|
-
return False
|
|
10200
|
-
|
|
10201
|
-
res_names = []
|
|
10202
|
-
|
|
10203
|
-
res_names_file = f"{path_to_calculate}/result_names.txt"
|
|
10204
|
-
if not os.path.exists(res_names_file):
|
|
10205
|
-
print_red(f"File '{res_names_file}' does not exist")
|
|
10206
|
-
return False
|
|
10207
|
-
|
|
10208
|
-
try:
|
|
10209
|
-
with open(res_names_file, "r", encoding="utf-8") as file:
|
|
10210
|
-
lines = file.readlines()
|
|
10211
|
-
except Exception as e:
|
|
10212
|
-
print_red(f"Error reading file '{res_names_file}': {e}")
|
|
10213
|
-
return False
|
|
10214
|
-
|
|
10215
|
-
for line in lines:
|
|
10216
|
-
entry = line.strip()
|
|
10217
|
-
if entry != "":
|
|
10218
|
-
res_names.append(entry)
|
|
10219
|
-
|
|
10220
|
-
if len(res_names) < 2:
|
|
10221
|
-
print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
|
|
10222
|
-
return False
|
|
10223
|
-
|
|
10224
|
-
load_username_to_args(path_to_calculate)
|
|
10225
|
-
|
|
10226
|
-
CURRENT_RUN_FOLDER = path_to_calculate
|
|
10227
|
-
|
|
10228
|
-
arg_result_names = res_names
|
|
10229
|
-
|
|
10230
|
-
load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
|
|
10231
|
-
|
|
10232
|
-
if experiment_parameters is None:
|
|
10233
|
-
return False
|
|
10234
|
-
|
|
10235
|
-
show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
|
|
10236
|
-
|
|
10237
|
-
pf_end_time = time.time()
|
|
10238
|
-
|
|
10239
|
-
print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
|
|
10240
|
-
|
|
10241
|
-
return True
|
|
10242
|
-
|
|
10243
10698
|
def set_arg_states_from_continue() -> None:
|
|
10244
10699
|
if args.continue_previous_job and not args.num_random_steps:
|
|
10245
10700
|
num_random_steps_file = f"{args.continue_previous_job}/state_files/num_random_steps"
|
|
@@ -10273,7 +10728,7 @@ def write_result_names_file() -> None:
|
|
|
10273
10728
|
except Exception as e:
|
|
10274
10729
|
print_red(f"Error trying to open file '{fn}': {e}")
|
|
10275
10730
|
|
|
10276
|
-
def run_program_once(params=None) -> None:
|
|
10731
|
+
def run_program_once(params: Optional[dict] = None) -> None:
|
|
10277
10732
|
if not args.run_program_once:
|
|
10278
10733
|
print_debug("[yellow]No setup script specified (run_program_once). Skipping setup.[/yellow]")
|
|
10279
10734
|
return
|
|
@@ -10282,27 +10737,27 @@ def run_program_once(params=None) -> None:
|
|
|
10282
10737
|
params = {}
|
|
10283
10738
|
|
|
10284
10739
|
if isinstance(args.run_program_once, str):
|
|
10285
|
-
command_str = args.run_program_once
|
|
10740
|
+
command_str = decode_if_base64(args.run_program_once)
|
|
10286
10741
|
for k, v in params.items():
|
|
10287
10742
|
placeholder = f"%({k})"
|
|
10288
10743
|
command_str = command_str.replace(placeholder, str(v))
|
|
10289
10744
|
|
|
10290
|
-
|
|
10291
|
-
|
|
10292
|
-
|
|
10293
|
-
|
|
10294
|
-
|
|
10295
|
-
|
|
10745
|
+
print(f"Executing command: [cyan]{command_str}[/cyan]")
|
|
10746
|
+
result = subprocess.run(command_str, shell=True, check=True)
|
|
10747
|
+
if result.returncode == 0:
|
|
10748
|
+
print("[bold green]Setup script completed successfully ✅[/bold green]")
|
|
10749
|
+
else:
|
|
10750
|
+
print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
|
|
10296
10751
|
|
|
10297
|
-
|
|
10752
|
+
my_exit(57)
|
|
10298
10753
|
|
|
10299
10754
|
elif isinstance(args.run_program_once, (list, tuple)):
|
|
10300
10755
|
with spinner("run_program_once: Executing command list: [cyan]{args.run_program_once}[/cyan]"):
|
|
10301
10756
|
result = subprocess.run(args.run_program_once, check=True)
|
|
10302
10757
|
if result.returncode == 0:
|
|
10303
|
-
|
|
10758
|
+
print("[bold green]Setup script completed successfully ✅[/bold green]")
|
|
10304
10759
|
else:
|
|
10305
|
-
|
|
10760
|
+
print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
|
|
10306
10761
|
|
|
10307
10762
|
my_exit(57)
|
|
10308
10763
|
|
|
@@ -10312,7 +10767,7 @@ def run_program_once(params=None) -> None:
|
|
|
10312
10767
|
my_exit(57)
|
|
10313
10768
|
|
|
10314
10769
|
def show_omniopt_call() -> None:
|
|
10315
|
-
def remove_ui_url(arg_str) -> str:
|
|
10770
|
+
def remove_ui_url(arg_str: str) -> str:
|
|
10316
10771
|
return re.sub(r'(?:--ui_url(?:=\S+)?(?:\s+\S+)?)', '', arg_str).strip()
|
|
10317
10772
|
|
|
10318
10773
|
original_argv = " ".join(sys.argv[1:])
|
|
@@ -10320,8 +10775,14 @@ def show_omniopt_call() -> None:
|
|
|
10320
10775
|
|
|
10321
10776
|
original_print(oo_call + " " + cleaned)
|
|
10322
10777
|
|
|
10778
|
+
if args.dependency is not None and args.dependency != "":
|
|
10779
|
+
print(f"Dependency: {args.dependency}")
|
|
10780
|
+
|
|
10781
|
+
if args.ui_url is not None and args.ui_url != "":
|
|
10782
|
+
print_yellow("--ui_url is deprecated. Do not use it anymore. It will be ignored and one day be removed.")
|
|
10783
|
+
|
|
10323
10784
|
def main() -> None:
|
|
10324
|
-
global RESULT_CSV_FILE,
|
|
10785
|
+
global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS
|
|
10325
10786
|
|
|
10326
10787
|
check_if_has_random_steps()
|
|
10327
10788
|
|
|
@@ -10372,7 +10833,7 @@ def main() -> None:
|
|
|
10372
10833
|
print_run_info()
|
|
10373
10834
|
|
|
10374
10835
|
initialize_nvidia_logs()
|
|
10375
|
-
|
|
10836
|
+
write_ui_url()
|
|
10376
10837
|
|
|
10377
10838
|
LOGFILE_DEBUG_GET_NEXT_TRIALS = get_current_run_folder('get_next_trials.csv')
|
|
10378
10839
|
cli_params_experiment_parameters = parse_parameters()
|
|
@@ -10403,15 +10864,13 @@ def main() -> None:
|
|
|
10403
10864
|
exp_params = get_experiment_parameters(cli_params_experiment_parameters)
|
|
10404
10865
|
|
|
10405
10866
|
if exp_params is not None:
|
|
10406
|
-
|
|
10867
|
+
experiment_args, gpu_string, gpu_color = exp_params
|
|
10407
10868
|
print_debug(f"experiment_parameters: {experiment_parameters}")
|
|
10408
10869
|
|
|
10409
10870
|
set_orchestrator()
|
|
10410
10871
|
|
|
10411
10872
|
init_live_share()
|
|
10412
10873
|
|
|
10413
|
-
start_periodic_live_share()
|
|
10414
|
-
|
|
10415
10874
|
show_available_hardware_and_generation_strategy_string(gpu_string, gpu_color)
|
|
10416
10875
|
|
|
10417
10876
|
original_print(f"Run-Program: {global_vars['joined_run_program']}")
|
|
@@ -10426,6 +10885,8 @@ def main() -> None:
|
|
|
10426
10885
|
|
|
10427
10886
|
write_files_and_show_overviews()
|
|
10428
10887
|
|
|
10888
|
+
live_share()
|
|
10889
|
+
|
|
10429
10890
|
#if args.continue_previous_job:
|
|
10430
10891
|
# insert_jobs_from_csv(f"{args.continue_previous_job}/{RESULTS_CSV_FILENAME}")
|
|
10431
10892
|
|
|
@@ -10434,7 +10895,7 @@ def main() -> None:
|
|
|
10434
10895
|
|
|
10435
10896
|
set_global_generation_strategy()
|
|
10436
10897
|
|
|
10437
|
-
start_worker_generators()
|
|
10898
|
+
#start_worker_generators()
|
|
10438
10899
|
|
|
10439
10900
|
try:
|
|
10440
10901
|
run_search_with_progress_bar()
|
|
@@ -10500,10 +10961,93 @@ def initialize_nvidia_logs() -> None:
|
|
|
10500
10961
|
global NVIDIA_SMI_LOGS_BASE
|
|
10501
10962
|
NVIDIA_SMI_LOGS_BASE = get_current_run_folder('gpu_usage_')
|
|
10502
10963
|
|
|
10503
|
-
def
|
|
10504
|
-
|
|
10505
|
-
|
|
10506
|
-
|
|
10964
|
+
def build_gui_url(config: argparse.Namespace) -> str:
|
|
10965
|
+
base_url = get_base_url()
|
|
10966
|
+
params = collect_params(config)
|
|
10967
|
+
ret = f"{base_url}?{urlencode(params, doseq=True)}"
|
|
10968
|
+
|
|
10969
|
+
return ret
|
|
10970
|
+
|
|
10971
|
+
def get_result_names_for_url(value: List) -> str:
|
|
10972
|
+
d = dict(v.split("=", 1) if "=" in v else (v, "min") for v in value)
|
|
10973
|
+
s = " ".join(f"{k}={v}" for k, v in d.items())
|
|
10974
|
+
|
|
10975
|
+
return s
|
|
10976
|
+
|
|
10977
|
+
def collect_params(config: argparse.Namespace) -> dict:
|
|
10978
|
+
params = {}
|
|
10979
|
+
user_home = os.path.expanduser("~")
|
|
10980
|
+
|
|
10981
|
+
for attr, value in vars(config).items():
|
|
10982
|
+
if attr == "run_program":
|
|
10983
|
+
params[attr] = global_vars["joined_run_program"]
|
|
10984
|
+
elif attr == "result_names" and value:
|
|
10985
|
+
params[attr] = get_result_names_for_url(value)
|
|
10986
|
+
elif attr == "parameter" and value is not None:
|
|
10987
|
+
params.update(process_parameters(config.parameter))
|
|
10988
|
+
elif attr == "root_venv_dir":
|
|
10989
|
+
if value is not None and os.path.abspath(value) != os.path.abspath(user_home):
|
|
10990
|
+
params[attr] = value
|
|
10991
|
+
elif isinstance(value, bool):
|
|
10992
|
+
params[attr] = int(value)
|
|
10993
|
+
elif isinstance(value, list):
|
|
10994
|
+
params[attr] = value
|
|
10995
|
+
elif value is not None:
|
|
10996
|
+
params[attr] = value
|
|
10997
|
+
|
|
10998
|
+
return params
|
|
10999
|
+
|
|
11000
|
+
def process_parameters(parameters: list) -> dict:
|
|
11001
|
+
params = {}
|
|
11002
|
+
for i, param in enumerate(parameters):
|
|
11003
|
+
if isinstance(param, dict):
|
|
11004
|
+
name = param.get("name", f"param_{i}")
|
|
11005
|
+
ptype = param.get("type", "unknown")
|
|
11006
|
+
else:
|
|
11007
|
+
name = param[0] if len(param) > 0 else f"param_{i}"
|
|
11008
|
+
ptype = param[1] if len(param) > 1 else "unknown"
|
|
11009
|
+
|
|
11010
|
+
params[f"parameter_{i}_name"] = name
|
|
11011
|
+
params[f"parameter_{i}_type"] = ptype
|
|
11012
|
+
|
|
11013
|
+
if ptype == "range":
|
|
11014
|
+
params.update(process_range_parameter(i, param))
|
|
11015
|
+
elif ptype == "choice":
|
|
11016
|
+
params.update(process_choice_parameter(i, param))
|
|
11017
|
+
elif ptype == "fixed":
|
|
11018
|
+
params.update(process_fixed_parameter(i, param))
|
|
11019
|
+
|
|
11020
|
+
params["num_parameters"] = len(parameters)
|
|
11021
|
+
return params
|
|
11022
|
+
|
|
11023
|
+
def process_range_parameter(i: int, param: list) -> dict:
|
|
11024
|
+
return {
|
|
11025
|
+
f"parameter_{i}_min": param[2] if len(param) > 3 else 0,
|
|
11026
|
+
f"parameter_{i}_max": param[3] if len(param) > 3 else 1,
|
|
11027
|
+
f"parameter_{i}_number_type": param[4] if len(param) > 4 else "float",
|
|
11028
|
+
f"parameter_{i}_log_scale": "false",
|
|
11029
|
+
}
|
|
11030
|
+
|
|
11031
|
+
def process_choice_parameter(i: int, param: list) -> dict:
|
|
11032
|
+
choices = ""
|
|
11033
|
+
if len(param) > 2 and param[2]:
|
|
11034
|
+
choices = ",".join([c.strip() for c in str(param[2]).split(",")])
|
|
11035
|
+
return {f"parameter_{i}_values": choices}
|
|
11036
|
+
|
|
11037
|
+
def process_fixed_parameter(i: int, param: list) -> dict:
|
|
11038
|
+
return {f"parameter_{i}_value": param[2] if len(param) > 2 else ""}
|
|
11039
|
+
|
|
11040
|
+
def get_base_url() -> str:
|
|
11041
|
+
file_path = Path.home() / ".oo_base_url"
|
|
11042
|
+
if file_path.exists():
|
|
11043
|
+
return file_path.read_text().strip()
|
|
11044
|
+
|
|
11045
|
+
return "https://imageseg.scads.de/omniax/"
|
|
11046
|
+
|
|
11047
|
+
def write_ui_url() -> None:
|
|
11048
|
+
url = build_gui_url(args)
|
|
11049
|
+
with open(get_current_run_folder("ui_url.txt"), mode="a", encoding="utf-8") as myfile:
|
|
11050
|
+
myfile.write(url)
|
|
10507
11051
|
|
|
10508
11052
|
def handle_random_steps() -> None:
|
|
10509
11053
|
if args.parameter and args.continue_previous_job and random_steps <= 0:
|
|
@@ -10562,8 +11106,6 @@ def run_search_with_progress_bar() -> None:
|
|
|
10562
11106
|
wait_for_jobs_to_complete()
|
|
10563
11107
|
|
|
10564
11108
|
def complex_tests(_program_name: str, wanted_stderr: str, wanted_exit_code: int, wanted_signal: Union[int, None], res_is_none: bool = False) -> int:
|
|
10565
|
-
#print_yellow(f"Test suite: {_program_name}")
|
|
10566
|
-
|
|
10567
11109
|
nr_errors: int = 0
|
|
10568
11110
|
|
|
10569
11111
|
program_path: str = f"./.tests/test_wronggoing_stuff.bin/bin/{_program_name}"
|
|
@@ -10637,7 +11179,7 @@ def test_find_paths(program_code: str) -> int:
|
|
|
10637
11179
|
for i in files:
|
|
10638
11180
|
if i not in string:
|
|
10639
11181
|
if os.path.exists(i):
|
|
10640
|
-
print("Missing {i} in find_file_paths string!")
|
|
11182
|
+
print(f"Missing {i} in find_file_paths string!")
|
|
10641
11183
|
nr_errors += 1
|
|
10642
11184
|
|
|
10643
11185
|
return nr_errors
|
|
@@ -11018,17 +11560,16 @@ Exit-Code: 159
|
|
|
11018
11560
|
|
|
11019
11561
|
my_exit(nr_errors)
|
|
11020
11562
|
|
|
11021
|
-
def
|
|
11563
|
+
def main_wrapper() -> None:
|
|
11022
11564
|
print(f"Run-UUID: {run_uuid}")
|
|
11023
11565
|
|
|
11024
11566
|
auto_wrap_namespace(globals())
|
|
11025
11567
|
|
|
11026
11568
|
print_logo()
|
|
11027
11569
|
|
|
11028
|
-
start_logging_daemon()
|
|
11029
|
-
|
|
11030
11570
|
fool_linter(args.num_cpus_main_job)
|
|
11031
11571
|
fool_linter(args.flame_graph)
|
|
11572
|
+
fool_linter(args.memray)
|
|
11032
11573
|
|
|
11033
11574
|
with warnings.catch_warnings():
|
|
11034
11575
|
warnings.simplefilter("ignore")
|
|
@@ -11060,9 +11601,34 @@ def main_outside() -> None:
|
|
|
11060
11601
|
else:
|
|
11061
11602
|
end_program(True)
|
|
11062
11603
|
|
|
11604
|
+
def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
|
|
11605
|
+
pattern = re.compile(regex) if regex else None
|
|
11606
|
+
|
|
11607
|
+
def wrapped(*args: Any, **kwargs: Any) -> None:
|
|
11608
|
+
if pattern and not pattern.search(func.__name__):
|
|
11609
|
+
return func(*args, **kwargs)
|
|
11610
|
+
|
|
11611
|
+
stack = inspect.stack()
|
|
11612
|
+
chain = []
|
|
11613
|
+
for frame in stack[1:]:
|
|
11614
|
+
fn = frame.function
|
|
11615
|
+
if fn in ("wrapped", "<module>"):
|
|
11616
|
+
continue
|
|
11617
|
+
chain.append(fn)
|
|
11618
|
+
|
|
11619
|
+
if chain:
|
|
11620
|
+
sys.stderr.write(" ⇒ ".join(reversed(chain)) + "\n")
|
|
11621
|
+
|
|
11622
|
+
return func(*args, **kwargs)
|
|
11623
|
+
|
|
11624
|
+
return wrapped
|
|
11625
|
+
|
|
11063
11626
|
def auto_wrap_namespace(namespace: Any) -> Any:
|
|
11064
11627
|
enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))
|
|
11065
11628
|
|
|
11629
|
+
if args.beartype:
|
|
11630
|
+
enable_beartype = True
|
|
11631
|
+
|
|
11066
11632
|
excluded_functions = {
|
|
11067
11633
|
"log_time_and_memory_wrapper",
|
|
11068
11634
|
"collect_runtime_stats",
|
|
@@ -11071,8 +11637,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
|
|
|
11071
11637
|
"_record_stats",
|
|
11072
11638
|
"_open",
|
|
11073
11639
|
"_check_memory_leak",
|
|
11074
|
-
"start_periodic_live_share",
|
|
11075
|
-
"start_logging_daemon",
|
|
11076
11640
|
"get_current_run_folder",
|
|
11077
11641
|
"show_func_name_wrapper"
|
|
11078
11642
|
}
|
|
@@ -11089,14 +11653,16 @@ def auto_wrap_namespace(namespace: Any) -> Any:
|
|
|
11089
11653
|
if args.show_func_name:
|
|
11090
11654
|
wrapped = show_func_name_wrapper(wrapped)
|
|
11091
11655
|
|
|
11656
|
+
if args.debug_stack_trace_regex:
|
|
11657
|
+
wrapped = stack_trace_wrapper(wrapped, args.debug_stack_trace_regex)
|
|
11658
|
+
|
|
11092
11659
|
namespace[name] = wrapped
|
|
11093
11660
|
|
|
11094
11661
|
return namespace
|
|
11095
11662
|
|
|
11096
11663
|
if __name__ == "__main__":
|
|
11097
|
-
|
|
11098
11664
|
try:
|
|
11099
|
-
|
|
11665
|
+
main_wrapper()
|
|
11100
11666
|
except (SignalUSR, SignalINT, SignalCONT) as e:
|
|
11101
|
-
print_red(f"
|
|
11667
|
+
print_red(f"main_wrapper failed with exception {e}")
|
|
11102
11668
|
end_program(True)
|