omniopt2 8184__py3-none-any.whl → 8285__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of omniopt2 might be problematic.
- .omniopt.py +133 -92
- .omniopt_plot_trial_index_result.py +1 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt.py +133 -92
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_trial_index_result.py +1 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/omniopt_docker +60 -60
- {omniopt2-8184.dist-info → omniopt2-8285.dist-info}/METADATA +1 -1
- {omniopt2-8184.dist-info → omniopt2-8285.dist-info}/RECORD +38 -38
- omniopt2.egg-info/PKG-INFO +1 -1
- omniopt_docker +60 -60
- pyproject.toml +1 -1
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.colorfunctions.sh +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.general.sh +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.helpers.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_general.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_kde.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_scatter.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_scatter_hex.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.omniopt_plot_worker.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.random_generator.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.shellscript_functions +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/.tpe.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/LICENSE +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/apt-dependencies.txt +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/omniopt +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/omniopt_evaluate +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/omniopt_plot +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/omniopt_share +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/pylint.rc +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/requirements.txt +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/setup.py +0 -0
- {omniopt2-8184.data → omniopt2-8285.data}/data/bin/test_requirements.txt +0 -0
- {omniopt2-8184.dist-info → omniopt2-8285.dist-info}/WHEEL +0 -0
- {omniopt2-8184.dist-info → omniopt2-8285.dist-info}/licenses/LICENSE +0 -0
- {omniopt2-8184.dist-info → omniopt2-8285.dist-info}/top_level.txt +0 -0
.omniopt.py
CHANGED
@@ -599,20 +599,46 @@ def _get_debug_json(time_str: str, msg: str) -> str:
         separators=(",", ":")  # no pretty indent → smaller, faster
     ).replace('\r', '').replace('\n', '')
 
-def
-
+def print_stack_paths() -> None:
+    stack = inspect.stack()[1:]  # skip current frame
+    stack.reverse()  # from the main program down to the deepest function
 
-
+    last_filename = None
+    for depth, frame_info in enumerate(stack):
+        filename = frame_info.filename
+        lineno = frame_info.lineno
+        func_name = frame_info.function
 
-
+        if func_name in ["<module>", "print_debug"]:
+            continue
+
+        if filename != last_filename:
+            print(filename)
+            last_filename = filename
+            indent = ""
+        else:
+            indent = " " * 4 * depth
+
+        print(f"{indent}↳ {func_name}:{lineno}")
+
+def print_debug(msg: str) -> None:
+    time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    stack = traceback.extract_stack()[:-1]
+    stack_funcs = [frame.name for frame in stack]
 
-
+    if "args" in globals() and args and hasattr(args, "debug_stack_regex") and args.debug_stack_regex:
+        matched = any(any(re.match(regex, func) for regex in args.debug_stack_regex) for func in stack_funcs)
+        if matched:
+            print(f"DEBUG: {msg}")
+            print_stack_paths()
 
-
+    stack_trace_element = _get_debug_json(time_str, msg)
+    _debug(stack_trace_element)
 
     try:
         with open(logfile_bare, mode='a', encoding="utf-8") as f:
-            original_print(
+            original_print(msg, file=f)
     except FileNotFoundError:
         print_red("It seems like the run's folder was deleted during the run. Cannot continue.")
         sys.exit(99)
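The new print_stack_paths walks the live call stack from the outermost caller down to the innermost frame and prints each function with growing indentation. A minimal standalone sketch of the same idea, using hypothetical demo functions rather than OmniOpt2's own call chain:

import inspect

def show_stack() -> None:
    # Walk from the outermost caller down to the innermost one,
    # printing each function with an arrow and growing indentation.
    frames = list(reversed(inspect.stack()[1:]))
    for depth, frame_info in enumerate(frames):
        if frame_info.function == "<module>":
            continue
        print(f"{' ' * 4 * depth}↳ {frame_info.function}:{frame_info.lineno}")

def inner() -> None:
    show_stack()

def outer() -> None:
    inner()

outer()  # prints the outer → inner chain with ↳ markers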
@@ -730,6 +756,7 @@ _DEFAULT_SPECIALS: Dict[str, Any] = {
 class ConfigLoader:
     runtime_debug: bool
     show_func_name: bool
+    debug_stack_regex: str
     number_of_generators: int
     disable_previous_job_constraint: bool
     save_to_database: bool
@@ -961,6 +988,7 @@ class ConfigLoader:
         debug.add_argument('--just_return_defaults', help='Just return defaults in dryrun', action='store_true', default=False)
         debug.add_argument('--prettyprint', help='Shows stdout and stderr in a pretty printed format', action='store_true', default=False)
         debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
+        debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
         debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
 
     def load_config(self: Any, config_path: str, file_format: str) -> dict:
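Taken together with the print_debug change above, the new --debug_stack_regex option gates debug output on the names of the functions currently on the stack. The nested any() check can be exercised in isolation; the values below are hypothetical stand-ins (note that the flag itself is declared with type=str, so treating it as a list of patterns here is only for illustration):

import re

debug_stack_regex = ["insert_.*", "fetch_next_trials"]            # stand-in for args.debug_stack_regex
stack_funcs = ["main", "insert_jobs_from_csv", "try_insert_job"]  # stand-in for the live call stack

matched = any(
    any(re.match(regex, func) for regex in debug_stack_regex)
    for func in stack_funcs
)
print(matched)  # True: "insert_.*" matches "insert_jobs_from_csv"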
@@ -2598,9 +2626,6 @@ def _debug_worker_creation(msg: str, _lvl: int = 0, eee: Union[None, str, Except
 def append_to_nvidia_smi_logs(_file: str, _host: str, result: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
     log_message_to_file(_file, result, _lvl, str(eee))
 
-def _debug_get_next_trials(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
-    log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, _lvl, str(eee))
-
 def _debug_progressbar(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
     log_message_to_file(logfile_progressbar, msg, _lvl, str(eee))
 
@@ -2796,7 +2821,7 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
     time_str: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     msg: str = f"{time_str}, {got}, {requested}"
 
-
+    log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
 
 def print_debug_progressbar(msg: str) -> None:
     time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -4858,7 +4883,7 @@ def process_best_result(res_name: str, print_to_file: bool) -> int:
 
     if str(best_result) in [NO_RESULT, None, "None"]:
         print_red(f"Best {res_name} could not be determined")
-        return 87
+        return 87  # exit-code: 87
 
     total_str = f"total: {_count_done_jobs(RESULT_CSV_FILE) - NR_INSERTED_JOBS}"
     if NR_INSERTED_JOBS:
@@ -6461,83 +6486,100 @@ def _get_generation_node_for_index_floats_match(
         return False
     return abs(row_val_num - val) <= tolerance
 
+def validate_and_convert_params_for_jobs_from_csv(arm_params: Dict) -> Dict:
+    corrected_params: Dict[Any, Any] = {}
+
+    if experiment_parameters is not None:
+        for param in experiment_parameters:
+            name = param["name"]
+            expected_type = param.get("value_type", "str")
+
+            if name not in arm_params:
+                continue
+
+            value = arm_params[name]
+
+            try:
+                if param["type"] == "range":
+                    if expected_type == "int":
+                        corrected_params[name] = int(value)
+                    elif expected_type == "float":
+                        corrected_params[name] = float(value)
+                elif param["type"] == "choice":
+                    corrected_params[name] = str(value)
+            except (ValueError, TypeError):
+                corrected_params[name] = None
+
+    return corrected_params
+
 def insert_jobs_from_csv(this_csv_file_path: str) -> None:
     with spinner(f"Inserting job into CSV from {this_csv_file_path}") as __status:
-        this_csv_file_path = this_csv_file_path
+        this_csv_file_path = normalize_path(this_csv_file_path)
 
-        if not
+        if not helpers.file_exists(this_csv_file_path):
             print_red(f"--load_data_from_existing_jobs: Cannot find {this_csv_file_path}")
-
             return
 
-
-
+        arm_params_list, results_list = parse_csv(this_csv_file_path)
+        insert_jobs_from_lists(this_csv_file_path, arm_params_list, results_list, __status)
 
-
-
-            name = param["name"]
-            expected_type = param.get("value_type", "str")
+def normalize_path(file_path: str) -> str:
+    return file_path.replace("//", "/")
 
-
-
+def insert_jobs_from_lists(csv_path, arm_params_list, results_list, __status):
+    cnt = 0
+    err_msgs = []
 
-
+    for i, (arm_params, result) in enumerate(zip(arm_params_list, results_list)):
+        base_str = f"[bold green]Loading job {i}/{len(results_list)} from {csv_path} into ax_client, result: {result}"
+        __status.update(base_str)
 
-
-
-                    if expected_type == "int":
-                        corrected_params[name] = int(value)
-                    elif expected_type == "float":
-                        corrected_params[name] = float(value)
-                elif param["type"] == "choice":
-                    corrected_params[name] = str(value)
-            except (ValueError, TypeError):
-                corrected_params[name] = None
-
-        return corrected_params
+        if not args.worker_generator_path:
+            arm_params = validate_and_convert_params_for_jobs_from_csv(arm_params)
 
-        arm_params_list, results_list
+        cnt = try_insert_job(csv_path, arm_params, result, i, arm_params_list, results_list, __status, base_str, cnt, err_msgs)
 
-
+    summarize_insertions(csv_path, cnt)
+    update_global_job_counters(cnt)
 
-
+def try_insert_job(csv_path: str, arm_params: Dict, result: Any, i: int, arm_params_list: Any, results_list: Any, __status: Any, base_str: Optional[str], cnt: int, err_msgs: Optional[Union[str, list[str]]]) -> int:
+    try:
+        gen_node_name = get_generation_node_for_index(csv_path, arm_params_list, results_list, i, __status, base_str)
 
-
-
-
-        __status.update(base_str)
-            if not args.worker_generator_path:
-                arm_params = validate_and_convert_params(arm_params)
+        if not result:
+            print_yellow("Encountered job without a result")
+            return cnt
 
-
-
+        if insert_job_into_ax_client(arm_params, result, gen_node_name, __status, base_str):
+            cnt += 1
+            print_debug(f"Inserted one job from {csv_path}, arm_params: {arm_params}, results: {result}")
+        else:
+            print_red(f"Failed to insert one job from {csv_path}, arm_params: {arm_params}, results: {result}")
 
-
-
-
+    except ValueError as e:
+        err_msg = (
+            f"Failed to insert job(s) from {csv_path} into ax_client. "
+            f"This can happen when the csv file has different parameters or results as the main job one's "
+            f"or other imported jobs. Error: {e}"
+        )
+        if err_msg not in err_msgs:
+            print_red(err_msg)
+            err_msgs.append(err_msg)
 
-
-            else:
-                print_red(f"Failed to insert one job from {this_csv_file_path}, arm_params: {arm_params}, results: {result}")
-        else:
-            print_yellow("Encountered job without a result")
-    except ValueError as e:
-        err_msg = f"Failed to insert job(s) from {this_csv_file_path} into ax_client. This can happen when the csv file has different parameters or results as the main job one's or other imported jobs. Error: {e}"
-        if err_msg not in err_msgs:
-            print_red(err_msg)
-            err_msgs.append(err_msg)
-
-        i = i + 1
-
-    if cnt:
-        if cnt == 1:
-            print_yellow(f"Inserted one job from {this_csv_file_path}")
-        else:
-            print_yellow(f"Inserted {cnt} jobs from {this_csv_file_path}")
+        return cnt
 
-
-
-
+def summarize_insertions(csv_path: str, cnt: int) -> None:
+    if cnt == 0:
+        return
+    if cnt == 1:
+        print_yellow(f"Inserted one job from {csv_path}")
+    else:
+        print_yellow(f"Inserted {cnt} jobs from {csv_path}")
+
+def update_global_job_counters(cnt: int) -> None:
+    if not args.worker_generator_path:
+        set_max_eval(max_eval + cnt)
+        set_nr_inserted_jobs(NR_INSERTED_JOBS + cnt)
 
 def __insert_job_into_ax_client__update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
     if __status and base_str:
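The CSV-import path is now split into small helpers: normalize_path cleans the file path, parse_csv yields parameter/result lists, validate_and_convert_params_for_jobs_from_csv coerces CSV strings back to the declared parameter types, and try_insert_job, summarize_insertions and update_global_job_counters handle the per-row work and bookkeeping. A self-contained sketch of the coercion rules alone, using hypothetical parameter definitions rather than OmniOpt2's experiment_parameters structure:

from typing import Any, Dict, List, Optional

def coerce_params(params: Dict[str, Any], definitions: List[Dict[str, Any]]) -> Dict[str, Optional[Any]]:
    # Range parameters are cast to their declared numeric type, choice
    # parameters become strings, and unparsable values are stored as None.
    corrected: Dict[str, Optional[Any]] = {}
    for definition in definitions:
        name = definition["name"]
        if name not in params:
            continue
        value = params[name]
        try:
            if definition["type"] == "range":
                if definition.get("value_type", "str") == "int":
                    corrected[name] = int(value)
                elif definition.get("value_type", "str") == "float":
                    corrected[name] = float(value)
            elif definition["type"] == "choice":
                corrected[name] = str(value)
        except (ValueError, TypeError):
            corrected[name] = None
    return corrected

definitions = [
    {"name": "epochs", "type": "range", "value_type": "int"},
    {"name": "lr", "type": "range", "value_type": "float"},
    {"name": "optimizer", "type": "choice"},
]
print(coerce_params({"epochs": "10", "lr": "0.01", "optimizer": "adam"}, definitions))
# {'epochs': 10, 'lr': 0.01, 'optimizer': 'adam'}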
@@ -7801,32 +7843,32 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
         remaining = nr_of_jobs_to_get - len(batched_arms)
         print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
 
-
+        print_debug("get pending observations")
         pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
-
+        print_debug("got pending observations")
 
-
+        print_debug("getting global_gs.gen()")
         batched_generator_run = global_gs.gen(
             experiment=ax_client.experiment,
             n=remaining,
             pending_observations=pending_observations
         )
-
+        print_debug(f"got global_gs.gen(): {batched_generator_run}")
 
         # Unpack recursively inline until the object is flat
         depth = 0
         path = "batched_generator_run"
         while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) > 0:
-
+            print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
             batched_generator_run = batched_generator_run[0]
             path += "[0]"
             depth += 1
 
-
+        print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
 
-
+        print_debug("got new arms")
         new_arms = batched_generator_run.arms
-
+        print_debug(f"new_arms: {new_arms}")
         if not new_arms:
             print_debug("get_batched_arms: No new arms were generated in this attempt.")
         else:
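get_batched_arms now logs each step but still flattens whatever global_gs.gen() returns by repeatedly taking the first element until the value is no longer a non-empty list or tuple. A minimal sketch of that unwrapping pattern on a made-up nested value:

from typing import Any

def unwrap_first(value: Any) -> Any:
    # Follow value[0] until the result is no longer a non-empty list/tuple.
    depth = 0
    while isinstance(value, (list, tuple)) and len(value) > 0:
        value = value[0]
        depth += 1
    print(f"flattened after {depth} step(s)")
    return value

print(unwrap_first([[("generator_run",)]]))  # flattened after 3 step(s), then 'generator_run'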
@@ -7841,7 +7883,7 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
 
     return batched_arms
 
-def
+def fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple[Dict[int, Any], bool]:
     die_101_if_no_ax_client_or_experiment_or_gs()
 
     if not ax_client:
@@ -7850,9 +7892,9 @@ def _fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple
     if global_gs is None:
         _fatal_error("Global generation strategy is not set. This is a bug in OmniOpt2.", 107)
 
-    return
+    return generate_trials(nr_of_jobs_to_get, recursion)
 
-def
+def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
     trials_dict: Dict[int, Any] = {}
     trial_durations: List[float] = []
 
@@ -7878,7 +7920,7 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
         progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
 
         try:
-            result =
+            result = create_and_handle_trial(arm)
             if result is not None:
                 trial_index, trial_duration, trial_successful = result
 
@@ -7903,9 +7945,9 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
 class TrialRejected(Exception):
     pass
 
-def
+def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
     if ax_client is None:
-        print_red("ax_client is None in
+        print_red("ax_client is None in create_and_handle_trial")
         return None
 
     start = time.time()
@@ -7978,7 +8020,7 @@ def _handle_generation_failure(
     if not recursion and args.revert_to_random_when_seemingly_exhausted:
         print_debug("Switching to random search strategy.")
         set_global_gs_to_random()
-        return
+        return fetch_next_trials(requested, True)
 
     print_red(f"_handle_generation_failure: General Exception: {e}")
 
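_handle_generation_failure now re-enters fetch_next_trials with recursion=True after switching to random search, so the fallback is attempted exactly once. A schematic sketch of that retry guard with hypothetical helpers, not OmniOpt2's actual failure handling:

from typing import Any, Dict, Tuple

def fetch(n: int, recursion: bool = False) -> Tuple[Dict[int, Any], bool]:
    try:
        raise RuntimeError("generation strategy seemingly exhausted")  # simulate a failure
    except RuntimeError:
        if not recursion:
            print("Switching to random search strategy.")
            return fetch(n, True)  # retry exactly once with the fallback
        print("Fallback also failed; stopping.")
        return {}, True

print(fetch(4))  # ({}, True) after one fallback attempt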
@@ -8261,14 +8303,14 @@ def _handle_linalg_error(error: Union[None, str, Exception]) -> None:
     """Handles the np.linalg.LinAlgError based on the model being used."""
     print_red(f"Error: {error}")
 
-def
-    finish_previous_jobs(["finishing jobs (
+def get_next_trials(nr_of_jobs_to_get: int) -> Tuple[Union[None, dict], bool]:
+    finish_previous_jobs(["finishing jobs (get_next_trials)"])
 
-    if break_run_search("
+    if break_run_search("get_next_trials", max_eval) or nr_of_jobs_to_get == 0:
         return {}, True
 
     try:
-        trial_index_to_param, optimization_complete =
+        trial_index_to_param, optimization_complete = fetch_next_trials(nr_of_jobs_to_get)
 
         cf = currentframe()
         if cf:
@@ -8857,7 +8899,7 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
             get_next_trials_nr = new_nr_of_jobs_to_get
 
         for _ in range(range_nr):
-            trial_index_to_param, done_optimizing =
+            trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
             if done_optimizing:
                 continue
 
@@ -11094,7 +11136,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
     return namespace
 
 if __name__ == "__main__":
-
     try:
         main_outside()
     except (SignalUSR, SignalINT, SignalCONT) as e: