omniopt2 8178-py3-none-any.whl → 9171-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. .gitignore +2 -0
  2. .helpers.py +0 -9
  3. .omniopt.py +1717 -1151
  4. .omniopt_plot_scatter.py +1 -1
  5. .omniopt_plot_scatter_hex.py +1 -1
  6. .omniopt_plot_trial_index_result.py +1 -0
  7. .pareto.py +134 -0
  8. .shellscript_functions +24 -15
  9. .tests/pylint.rc +0 -4
  10. .tpe.py +4 -3
  11. README.md +1 -1
  12. omniopt +92 -55
  13. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.helpers.py +0 -9
  14. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt.py +1717 -1151
  15. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter.py +1 -1
  16. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_hex.py +1 -1
  17. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_trial_index_result.py +1 -0
  18. omniopt2-9171.data/data/bin/.pareto.py +134 -0
  19. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.shellscript_functions +24 -15
  20. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.tpe.py +4 -3
  21. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt +92 -55
  22. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_docker +60 -60
  23. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_plot +1 -1
  24. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/pylint.rc +0 -4
  25. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/requirements.txt +3 -4
  26. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/test_requirements.txt +1 -0
  27. {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/METADATA +6 -6
  28. omniopt2-9171.dist-info/RECORD +73 -0
  29. omniopt2.egg-info/PKG-INFO +6 -6
  30. omniopt2.egg-info/SOURCES.txt +1 -0
  31. omniopt2.egg-info/requires.txt +4 -4
  32. omniopt_docker +60 -60
  33. omniopt_plot +1 -1
  34. pyproject.toml +1 -1
  35. requirements.txt +3 -4
  36. test_requirements.txt +1 -0
  37. omniopt2-8178.dist-info/RECORD +0 -71
  38. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.colorfunctions.sh +0 -0
  39. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.general.sh +0 -0
  40. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
  41. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_general.py +0 -0
  42. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
  43. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_kde.py +0 -0
  44. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
  45. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
  46. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.omniopt_plot_worker.py +0 -0
  47. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/.random_generator.py +0 -0
  48. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/LICENSE +0 -0
  49. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/apt-dependencies.txt +0 -0
  50. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_evaluate +0 -0
  51. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/omniopt_share +0 -0
  52. {omniopt2-8178.data → omniopt2-9171.data}/data/bin/setup.py +0 -0
  53. {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/WHEEL +0 -0
  54. {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/licenses/LICENSE +0 -0
  55. {omniopt2-8178.dist-info → omniopt2-9171.dist-info}/top_level.txt +0 -0
.omniopt.py CHANGED
@@ -2,7 +2,7 @@
 
 #from mayhemmonkey import MayhemMonkey
 #mayhemmonkey = MayhemMonkey()
-#mayhemmonkey.set_function_fail_after_count("open", 201)
+#mayhemmonkey.set_function_fail_after_count("open", 10)
 #mayhemmonkey.set_function_error_rate("open", 0.1)
 #mayhemmonkey.set_function_group_error_rate(["io", "math"], 0.8)
 #mayhemmonkey.install_faulty()
@@ -26,10 +26,16 @@ import traceback
 import inspect
 import tracemalloc
 import resource
+from urllib.parse import urlencode
 import psutil
 
 FORCE_EXIT: bool = False
 
+LAST_LOG_TIME: int = 0
+last_msg_progressbar = ""
+last_msg_raw = None
+last_lock_print_debug = threading.Lock()
+
 def force_exit(signal_number: Any, frame: Any) -> Any:
     global FORCE_EXIT
 
@@ -59,10 +65,10 @@ _last_count_time = 0
 _last_count_result: tuple[int, str] = (0, "")
 
 _total_time = 0.0
-_func_times = defaultdict(float)
-_func_mem = defaultdict(float)
-_func_call_paths = defaultdict(Counter)
-_last_mem = defaultdict(float)
+_func_times: defaultdict = defaultdict(float)
+_func_mem: defaultdict = defaultdict(float)
+_func_call_paths: defaultdict = defaultdict(Counter)
+_last_mem: defaultdict = defaultdict(float)
 _leak_threshold_mb = 10.0
 generation_strategy_names: list = []
 default_max_range_difference: int = 1000000
@@ -82,6 +88,7 @@ log_nr_gen_jobs: list[int] = []
 generation_strategy_human_readable: str = ""
 oo_call: str = "./omniopt"
 progress_bar_length: int = 0
+worker_usage_file = 'worker_usage.csv'
 
 if os.environ.get("CUSTOM_VIRTUAL_ENV") == "1":
     oo_call = "omniopt"
@@ -97,7 +104,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
97
104
  SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
98
105
  joined_supported_models: str = ", ".join(SUPPORTED_MODELS)
99
106
 
100
- special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
107
+ special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]
101
108
 
102
109
  IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names
103
110
 
@@ -151,12 +158,6 @@ try:
         message="Ax currently requires a sqlalchemy version below 2.0.*",
     )
 
-    warnings.filterwarnings(
-        "ignore",
-        category=RuntimeWarning,
-        message="coroutine 'start_logging_daemon' was never awaited"
-    )
-
     with spinner("Importing argparse..."):
         import argparse
 
@@ -202,6 +203,9 @@ try:
     with spinner("Importing rich.pretty..."):
         from rich.pretty import pprint
 
+    with spinner("Importing pformat..."):
+        from pprint import pformat
+
     with spinner("Importing rich.prompt..."):
         from rich.prompt import Prompt, FloatPrompt, IntPrompt
 
@@ -235,9 +239,6 @@ try:
     with spinner("Importing uuid..."):
         import uuid
 
-    #with spinner("Importing qrcode..."):
-    #    import qrcode
-
     with spinner("Importing cowsay..."):
         import cowsay
 
@@ -268,6 +269,9 @@ try:
     with spinner("Importing beartype..."):
         from beartype import beartype
 
+    with spinner("Importing rendering stuff..."):
+        from ax.plot.base import AxPlotConfig
+
     with spinner("Importing statistics..."):
         import statistics
 
@@ -331,7 +335,7 @@ def show_func_name_wrapper(func: F) -> F:
 
         return result
 
-    return wrapper # type: ignore
+    return wrapper  # type: ignore
 
 def log_time_and_memory_wrapper(func: F) -> F:
     @functools.wraps(func)
@@ -358,7 +362,7 @@ def log_time_and_memory_wrapper(func: F) -> F:
 
         return result
 
-    return wrapper # type: ignore
+    return wrapper  # type: ignore
 
 def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: float, mem_peak: float) -> None:
     global _total_time
@@ -379,16 +383,9 @@ def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: fl
     call_path_str = " -> ".join(short_stack)
     _func_call_paths[func_name][call_path_str] += 1
 
-    print(
-        f"Function '{func_name}' took {elapsed:.4f}s "
-        f"(total {percent_if_added:.1f}% of tracked time)"
-    )
-    print(
-        f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, "
-        f"diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB"
-    )
+    print(f"Function '{func_name}' took {elapsed:.4f}s (total {percent_if_added:.1f}% of tracked time)")
+    print(f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB")
 
-    # NEU: Runtime Stats
     runtime_stats = collect_runtime_stats()
     print("=== Runtime Stats ===")
     print(f"RSS: {runtime_stats['rss_MB']:.2f} MB, VMS: {runtime_stats['vms_MB']:.2f} MB")
@@ -447,7 +444,7 @@ RESET: str = "\033[0m"
 
 uuid_regex: Pattern = re.compile(r"^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[89aAbB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$")
 
-worker_generator_uuid: str = uuid.uuid4()
+worker_generator_uuid: str = str(uuid.uuid4())
 
 new_uuid: str = str(uuid.uuid4())
 run_uuid: str = os.getenv("RUN_UUID", new_uuid)
@@ -476,7 +473,7 @@ def get_current_run_folder(name: Optional[str] = None) -> str:
 
     return CURRENT_RUN_FOLDER
 
-def get_state_file_name(name) -> str:
+def get_state_file_name(name: str) -> str:
     state_files_folder = f"{get_current_run_folder()}/state_files/"
     makedirs(state_files_folder)
 
@@ -500,6 +497,24 @@ try:
     dier: FunctionType = helpers.dier
     is_equal: FunctionType = helpers.is_equal
     is_not_equal: FunctionType = helpers.is_not_equal
+    with spinner("Importing pareto..."):
+        pareto_file: str = f"{script_dir}/.pareto.py"
+        spec = importlib.util.spec_from_file_location(
+            name="pareto",
+            location=pareto_file,
+        )
+        if spec is not None and spec.loader is not None:
+            pareto = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(pareto)
+        else:
+            raise ImportError(f"Could not load module from {pareto_file}")
+
+        pareto_front_table_filter_rows: FunctionType = pareto.pareto_front_table_filter_rows
+        pareto_front_table_add_headers: FunctionType = pareto.pareto_front_table_add_headers
+        pareto_front_table_add_rows: FunctionType = pareto.pareto_front_table_add_rows
+        pareto_front_filter_complete_points: FunctionType = pareto.pareto_front_filter_complete_points
+        pareto_front_select_pareto_points: FunctionType = pareto.pareto_front_select_pareto_points
+
 except KeyboardInterrupt:
     print("You pressed CTRL-c while importing the helpers file")
     sys.exit(0)
@@ -523,15 +538,13 @@ def is_slurm_job() -> bool:
         return True
     return False
 
-def _sleep(t: int) -> int:
+def _sleep(t: Union[float, int]) -> None:
     if args is not None and not args.no_sleep:
         try:
             time.sleep(t)
         except KeyboardInterrupt:
             pass
 
-    return t
-
 LOG_DIR: str = "logs"
 makedirs(LOG_DIR)
 
@@ -544,6 +557,17 @@ logfile_worker_creation_logs: str = f'{log_uuid_dir}_worker_creation_logs'
 logfile_trial_index_to_param_logs: str = f'{log_uuid_dir}_trial_index_to_param_logs'
 LOGFILE_DEBUG_GET_NEXT_TRIALS: Union[str, None] = None
 
+def error_without_print(text: str) -> None:
+    print_debug(text)
+
+    if get_current_run_folder():
+        try:
+            with open(get_current_run_folder("oo_errors.txt"), mode="a", encoding="utf-8") as myfile:
+                myfile.write(text + "\n\n")
+        except (OSError, FileNotFoundError) as e:
+            helpers.print_color("red", f"Error: {e}. This may mean that the {get_current_run_folder()} was deleted during the run. Could not write '{text} to {get_current_run_folder()}/oo_errors.txt'")
+            sys.exit(99)
+
 def print_red(text: str) -> None:
     helpers.print_color("red", text)
 
@@ -577,20 +601,22 @@ def _debug(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) ->
 def _get_debug_json(time_str: str, msg: str) -> str:
     function_stack = []
     try:
-        frame = inspect.currentframe().f_back # skip _get_debug_json
-        while frame:
-            func_name = _function_name_cache.get(frame.f_code)
-            if func_name is None:
-                func_name = frame.f_code.co_name
-                _function_name_cache[frame.f_code] = func_name
-
-            if func_name not in ("<module>", "print_debug", "wrapper"):
-                function_stack.append({
-                    "function": func_name,
-                    "line_number": frame.f_lineno
-                })
-
-            frame = frame.f_back
+        cf = inspect.currentframe()
+        if cf:
+            frame = cf.f_back # skip _get_debug_json
+            while frame:
+                func_name = _function_name_cache.get(frame.f_code)
+                if func_name is None:
+                    func_name = frame.f_code.co_name
+                    _function_name_cache[frame.f_code] = func_name
+
+                if func_name not in ("<module>", "print_debug", "wrapper"):
+                    function_stack.append({
+                        "function": func_name,
+                        "line_number": frame.f_lineno
+                    })
+
+                frame = frame.f_back
     except (SignalUSR, SignalINT, SignalCONT):
         print_red("\n⚠ You pressed CTRL-C. This is ignored in _get_debug_json.")
 
@@ -599,20 +625,46 @@ def _get_debug_json(time_str: str, msg: str) -> str:
         separators=(",", ":") # no pretty indent → smaller, faster
     ).replace('\r', '').replace('\n', '')
 
-def print_debug(msg: str) -> None:
-    original_msg = msg
+def print_stack_paths() -> None:
+    stack = inspect.stack()[1:] # skip current frame
+    stack.reverse()
 
-    time_str: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    last_filename = None
+    for depth, frame_info in enumerate(stack):
+        filename = frame_info.filename
+        lineno = frame_info.lineno
+        func_name = frame_info.function
 
-    stack_trace_element = _get_debug_json(time_str, msg)
+        if func_name in ["<module>", "print_debug"]:
+            continue
 
-    msg = f"{stack_trace_element}"
+        if filename != last_filename:
+            print(filename)
+            last_filename = filename
+            indent = ""
+        else:
+            indent = " " * 4 * depth
+
+        print(f"{indent}↳ {func_name}:{lineno}")
+
+def print_debug(msg: str) -> None:
+    time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    stack = traceback.extract_stack()[:-1]
+    stack_funcs = [frame.name for frame in stack]
 
-    _debug(msg)
+    if "args" in globals() and args and hasattr(args, "debug_stack_regex") and args.debug_stack_regex:
+        matched = any(any(re.match(regex, func) for regex in args.debug_stack_regex) for func in stack_funcs)
+        if matched:
+            print(f"DEBUG: {msg}")
+            print_stack_paths()
+
+    stack_trace_element = _get_debug_json(time_str, msg)
+    _debug(stack_trace_element)
 
     try:
         with open(logfile_bare, mode='a', encoding="utf-8") as f:
-            original_print(original_msg, file=f)
+            original_print(msg, file=f)
     except FileNotFoundError:
         print_red("It seems like the run's folder was deleted during the run. Cannot continue.")
         sys.exit(99)
@@ -666,11 +718,14 @@ def my_exit(_code: int = 0) -> None:
     if is_skip_search() and os.getenv("SKIP_SEARCH_EXIT_CODE"):
         skip_search_exit_code = os.getenv("SKIP_SEARCH_EXIT_CODE")
 
+        skip_search_exit_code_found = None
+
         try:
-            sys.exit(int(skip_search_exit_code))
+            if skip_search_exit_code_found is not None:
+                skip_search_exit_code_found = int(skip_search_exit_code)
+                sys.exit(skip_search_exit_code_found)
         except ValueError:
-            print(f"Trying to look for SKIP_SEARCH_EXIT_CODE failed. Exiting with original exit code {_code}")
-            sys.exit(_code)
+            print_debug(f"Trying to look for SKIP_SEARCH_EXIT_CODE failed. Exiting with original exit code {_code}")
 
     sys.exit(_code)
 
@@ -730,6 +785,7 @@ _DEFAULT_SPECIALS: Dict[str, Any] = {
 class ConfigLoader:
     runtime_debug: bool
     show_func_name: bool
+    debug_stack_regex: str
     number_of_generators: int
     disable_previous_job_constraint: bool
     save_to_database: bool
@@ -770,11 +826,13 @@ class ConfigLoader:
     parameter: Optional[List[str]]
     experiment_constraints: Optional[List[str]]
    main_process_gb: int
+    beartype: bool
     worker_timeout: int
     slurm_signal_delay_s: int
     gridsearch: bool
     auto_exclude_defective_hosts: bool
     debug: bool
+    debug_stack_trace_regex: Optional[str]
     num_restarts: int
     raw_samples: int
     show_generate_time_table: bool
@@ -789,6 +847,7 @@ class ConfigLoader:
     run_program_once: str
     mem_gb: int
     flame_graph: bool
+    memray: bool
     continue_previous_job: Optional[str]
     calculate_pareto_front_of_job: Optional[List[str]]
     revert_to_random_when_seemingly_exhausted: bool
@@ -950,6 +1009,7 @@ class ConfigLoader:
         debug.add_argument('--verbose_break_run_search_table', help='Verbose logging for break_run_search', action='store_true', default=False)
         debug.add_argument('--debug', help='Enable debugging', action='store_true', default=False)
         debug.add_argument('--flame_graph', help='Enable flame-graphing. Makes everything slower, but creates a flame graph', action='store_true', default=False)
+        debug.add_argument('--memray', help='Use memray to show memory usage', action='store_true', default=False)
         debug.add_argument('--no_sleep', help='Disables sleeping for fast job generation (not to be used on HPC)', action='store_true', default=False)
         debug.add_argument('--tests', help='Run simple internal tests', action='store_true', default=False)
         debug.add_argument('--show_worker_percentage_table_at_end', help='Show a table of percentage of usage of max worker over time', action='store_true', default=False)
@@ -961,7 +1021,10 @@ class ConfigLoader:
         debug.add_argument('--just_return_defaults', help='Just return defaults in dryrun', action='store_true', default=False)
         debug.add_argument('--prettyprint', help='Shows stdout and stderr in a pretty printed format', action='store_true', default=False)
         debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
+        debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
+        debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
         debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
+        debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)
 
     def load_config(self: Any, config_path: str, file_format: str) -> dict:
         if not os.path.isfile(config_path):
@@ -1199,11 +1262,15 @@ for _rn in args.result_names:
     _key = _rn
     _min_or_max = __default_min_max
 
+    _min_or_max = re.sub(r"'", "", _min_or_max)
+
     if _min_or_max not in ["min", "max"]:
         if _min_or_max:
             print_yellow(f"Value for determining whether to minimize or maximize was neither 'min' nor 'max' for key '{_key}', but '{_min_or_max}'. It will be set to the default, which is '{__default_min_max}' instead.")
         _min_or_max = __default_min_max
 
+    _key = re.sub(r"'", "", _key)
+
     if _key in arg_result_names:
         console.print(f"[red]The --result_names option '{_key}' was specified multiple times![/]")
         sys.exit(50)
@@ -1304,27 +1371,14 @@ try:
     with spinner("Importing ExternalGenerationNode..."):
         from ax.generation_strategy.external_generation_node import ExternalGenerationNode
 
-    with spinner("Importing MaxTrials..."):
-        from ax.generation_strategy.transition_criterion import MaxTrials
+    with spinner("Importing MinTrials..."):
+        from ax.generation_strategy.transition_criterion import MinTrials
 
     with spinner("Importing GeneratorSpec..."):
         from ax.generation_strategy.generator_spec import GeneratorSpec
 
-    #except Exception:
-    #    with spinner("Fallback: Importing ax.generation_strategy.generation_node..."):
-    #        import ax.generation_strategy.generation_node
-
-    #    with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
-    #        from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
-
-    #    with spinner("Fallback: Importing ExternalGenerationNode..."):
-    #        from ax.generation_strategy.external_generation_node import ExternalGenerationNode
-
-    #    with spinner("Fallback: Importing MaxTrials..."):
-    #        from ax.generation_strategy.transition_criterion import MaxTrials
-
-    with spinner("Importing Models from ax.generation_strategy.registry..."):
-        from ax.adapter.registry import Models
+    with spinner("Importing Generators from ax.generation_strategy.registry..."):
+        from ax.adapter.registry import Generators
 
     with spinner("Importing get_pending_observation_features..."):
         from ax.core.utils import get_pending_observation_features
@@ -1410,7 +1464,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
     def __init__(self: Any, regressor_options: Dict[str, Any] = {}, seed: Optional[int] = None, num_samples: int = 1) -> None:
         print_debug("Initializing RandomForestGenerationNode...")
         t_init_start = time.monotonic()
-        super().__init__(node_name="RANDOMFOREST")
+        super().__init__(name="RANDOMFOREST")
         self.num_samples: int = num_samples
         self.seed: int = seed
 
@@ -1430,6 +1484,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
     def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
         search_space = experiment.search_space
         parameter_names = list(search_space.parameters.keys())
+        if experiment.optimization_config is None:
+            print_red("Error: update_generator_state is None")
+            return
         metric_names = list(experiment.optimization_config.metrics.keys())
 
         completed_trials = [
@@ -1441,7 +1498,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
         y = np.zeros([num_completed_trials, 1])
 
         for t_idx, trial in enumerate(completed_trials):
-            trial_parameters = trial.arm.parameters
+            trial_parameters = trial.arms[t_idx].parameters
             x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
             trial_df = data.df[data.df["trial_index"] == trial.index]
             y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
@@ -1581,10 +1638,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
     def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
         for name in best_sample.keys():
             param = self.parameters.get(name)
+            best_sample_by_name = best_sample[name]
+
             if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
-                best_sample[name] = int(round(best_sample[name]))
+                if best_sample_by_name is not None:
+                    best_sample[name] = int(round(float(best_sample_by_name)))
+                else:
+                    print_debug("best_sample_by_name was empty")
             elif isinstance(param, ChoiceParameter):
-                best_sample[name] = str(reverse_choice_map.get(int(best_sample[name])))
+                if best_sample_by_name is not None:
+                    best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
+                else:
+                    print_debug("best_sample_by_name was empty")
 
 decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode
 
@@ -1632,10 +1697,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
 
     def __init__(
         self: Any,
-        node_name: str = "INTERACTIVE_GENERATOR",
+        name: str = "INTERACTIVE_GENERATOR",
     ) -> None:
         t0 = time.monotonic()
-        super().__init__(node_name=node_name)
+        super().__init__(name=name)
         self.parameters = None
         self.minimize = None
         self.data = None
@@ -1791,10 +1856,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
 
 @dataclass(init=False)
 class ExternalProgramGenerationNode(ExternalGenerationNode):
-    def __init__(self: Any, external_generator: str = args.external_generator, node_name: str = "EXTERNAL_GENERATOR") -> None:
+    def __init__(self: Any, external_generator: str = args.external_generator, name: str = "EXTERNAL_GENERATOR") -> None:
         print_debug("Initializing ExternalProgramGenerationNode...")
         t_init_start = time.monotonic()
-        super().__init__(node_name=node_name)
+        super().__init__(name=name)
         self.seed: int = args.seed
         self.external_generator: str = decode_if_base64(external_generator)
         self.constraints = None
@@ -2015,24 +2080,15 @@ def run_live_share_command(force: bool = False) -> Tuple[str, str]:
         return str(result.stdout), str(result.stderr)
     except subprocess.CalledProcessError as e:
         if e.stderr:
-            original_print(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
+            print_debug(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
         else:
-            original_print(f"run_live_share_command: command failed with error: {e}")
+            print_debug(f"run_live_share_command: command failed with error: {e}")
 
         return "", str(e.stderr)
     except Exception as e:
         print(f"run_live_share_command: An error occurred: {e}")
 
    return "", ""
 
-#def extract_and_print_qr(text: str) -> None:
-#    match = re.search(r"(https?://\S+|\b[\w.-]+@[\w.-]+\.\w+\b|\b\d{10,}\b)", text)
-#    if match:
-#        data = match.group(0)
-#        qr = qrcode.QRCode(box_size=1, error_correction=qrcode.constants.ERROR_CORRECT_L, border=0)
-#        qr.add_data(data)
-#        qr.make()
-#        qr.print_ascii(out=sys.stdout)
-
 def force_live_share() -> bool:
     if args.live_share:
         return live_share(True)
@@ -2040,6 +2096,8 @@ def force_live_share() -> bool:
     return False
 
 def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
+    log_data()
+
     if not get_current_run_folder():
         print(f"live_share: get_current_run_folder was empty or false: {get_current_run_folder()}")
         return False
@@ -2052,25 +2110,24 @@ def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
     if text_and_qr:
         if stderr:
             print_green(stderr)
-            #extract_and_print_qr(stderr)
         else:
-            print_red("This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
+            if stderr and stdout:
+                print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
+            elif stderr:
+                print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}")
+            elif stdout:
+                print_red(f"This call should have shown the CURL, but didnt. Stdout: {stdout}")
+            else:
+                print_red("This call should have shown the CURL, but didnt.")
        if stdout:
            print_debug(f"live_share stdout: {stdout}")
 
     return True
 
 def init_live_share() -> bool:
-    with spinner("Initializing live share..."):
-        ret = live_share(True, True)
-
-    return ret
+    ret = live_share(True, True)
 
-async def start_periodic_live_share() -> None:
-    if args.live_share and not os.environ.get("CI"):
-        while True:
-            live_share(force=False)
-            time.sleep(30)
+    return ret
 
 def init_storage(db_url: str) -> None:
     init_engine_and_session_factory(url=db_url, force_init=True)
@@ -2096,29 +2153,33 @@ def try_saving_to_db() -> None:
         else:
             print_red("ax_client was not defined in try_saving_to_db")
             my_exit(101)
-        save_generation_strategy(global_gs)
+
+        if global_gs is not None:
+            save_generation_strategy(global_gs)
+        else:
+            print_red("Not saving generation strategy: global_gs was empty")
     except Exception as e:
         print_debug(f"Failed trying to save sqlite3-DB: {e}")
 
-def merge_with_job_infos(pd_frame: pd.DataFrame) -> pd.DataFrame:
+def merge_with_job_infos(df: pd.DataFrame) -> pd.DataFrame:
     job_infos_path = os.path.join(get_current_run_folder(), "job_infos.csv")
     if not os.path.exists(job_infos_path):
-        return pd_frame
+        return df
 
     job_df = pd.read_csv(job_infos_path)
 
-    if 'trial_index' not in pd_frame.columns or 'trial_index' not in job_df.columns:
+    if 'trial_index' not in df.columns or 'trial_index' not in job_df.columns:
         raise ValueError("Both DataFrames must contain a 'trial_index' column.")
 
-    job_df_filtered = job_df[job_df['trial_index'].isin(pd_frame['trial_index'])]
+    job_df_filtered = job_df[job_df['trial_index'].isin(df['trial_index'])]
 
-    new_cols = [col for col in job_df_filtered.columns if col != 'trial_index' and col not in pd_frame.columns]
+    new_cols = [col for col in job_df_filtered.columns if col != 'trial_index' and col not in df.columns]
 
     job_df_reduced = job_df_filtered[['trial_index'] + new_cols]
 
-    merged = pd.merge(pd_frame, job_df_reduced, on='trial_index', how='left')
+    merged = pd.merge(df, job_df_reduced, on='trial_index', how='left')
 
-    old_cols = [col for col in pd_frame.columns if col != 'trial_index']
+    old_cols = [col for col in df.columns if col != 'trial_index']
 
     new_order = ['trial_index'] + new_cols + old_cols
 
@@ -2126,48 +2187,9 @@ def merge_with_job_infos(pd_frame: pd.DataFrame) -> pd.DataFrame:
     return merged
 
-def reindex_trials(df: pd.DataFrame) -> pd.DataFrame:
-    """
-    Ensure trial_index is sequential and arm_name unique.
-    Keep arm_name unless all parameters except 'order', 'hostname', 'queue_time' match.
-    """
-    if "trial_index" not in df.columns or "arm_name" not in df.columns:
-        return df
-
-    # Sort by something stable (queue_time if available)
-    sort_cols = ["queue_time"] if "queue_time" in df.columns else df.columns.tolist()
-    df = df.sort_values(by=sort_cols, ignore_index=True)
-
-    # Mapping from "parameter signature" to assigned arm_name
-    seen_signatures = {}
-    new_arm_names = []
-
-    for new_idx, row in df.iterrows():
-        # Create signature without 'order', 'hostname', 'queue_time', 'trial_index', 'arm_name'
-        ignore_cols = {"order", "hostname", "queue_time", "trial_index", "arm_name"}
-        signature = tuple((col, row[col]) for col in df.columns if col not in ignore_cols)
-
-        if signature in seen_signatures:
-            # Collision → make a unique name
-            base_name = seen_signatures[signature]
-            suffix = 1
-            new_name = f"{base_name}_{suffix}"
-            while new_name in new_arm_names:
-                suffix += 1
-                new_name = f"{base_name}_{suffix}"
-            new_arm_names.append(new_name)
-        else:
-            # First occurrence → use new_idx as trial index in name
-            new_name = f"{new_idx}_0"
-            seen_signatures[signature] = f"{new_idx}_0"
-            new_arm_names.append(new_name)
-
-        df.at[new_idx, "trial_index"] = new_idx
-        df.at[new_idx, "arm_name"] = new_name
-
-    return df
-
 def save_results_csv() -> Optional[str]:
+    log_data()
+
     if args.dryrun:
         return None
 
@@ -2181,8 +2203,11 @@ def save_results_csv() -> Optional[str]:
     save_checkpoint()
 
     try:
-        pd_frame = fetch_and_prepare_trials()
-        write_csv(pd_frame, pd_csv)
+        df = fetch_and_prepare_trials()
+        if df is None:
+            print_red(f"save_results_csv: fetch_and_prepare_trials returned an empty element: {df}")
+            return None
+        write_csv(df, pd_csv)
         write_json_snapshot(pd_json)
         save_experiment_to_file()
 
@@ -2194,32 +2219,67 @@ def save_results_csv() -> Optional[str]:
     except (SignalUSR, SignalCONT, SignalINT) as e:
         raise type(e)(str(e)) from e
     except Exception as e:
-        print_red(f"While saving all trials as a pandas-dataframe-csv, an error occurred: {e}")
+        print_red(f"\nWhile saving all trials as a pandas-dataframe-csv, an error occurred: {e}")
 
     return pd_csv
 
 def get_results_paths() -> tuple[str, str]:
     return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))
 
-def fetch_and_prepare_trials() -> pd.DataFrame:
+def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
+    if not ax_client:
+        my_exit(101)
+
+        return None
+
+    return ax_client.get_trials_data_frame()
+
+def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
+    if not ax_client:
+        return None
+
     ax_client.experiment.fetch_data()
-    df = ax_client.get_trials_data_frame()
+    df = ax_client_get_trials_data_frame()
     df = merge_with_job_infos(df)
-    return reindex_trials(df)
 
-def write_csv(df, path: str) -> None:
+    return df
+
+def write_csv(df: pd.DataFrame, path: str) -> None:
+    try:
+        df = df.sort_values(by=["trial_index"], kind="stable").reset_index(drop=True)
+    except KeyError:
+        pass
     df.to_csv(path, index=False, float_format="%.30f")
 
-def write_json_snapshot(path: str) -> None:
+def ax_client_to_json_snapshot() -> Optional[dict]:
+    if not ax_client:
+        my_exit(101)
+
+        return None
+
     json_snapshot = ax_client.to_json_snapshot()
-    with open(path, "w", encoding="utf-8") as f:
-        json.dump(json_snapshot, f, indent=4)
+
+    return json_snapshot
+
+def write_json_snapshot(path: str) -> None:
+    if ax_client is not None:
+        json_snapshot = ax_client_to_json_snapshot()
+        if json_snapshot is not None:
+            with open(path, "w", encoding="utf-8") as f:
+                json.dump(json_snapshot, f, indent=4)
+        else:
+            print_debug('json_snapshot from ax_client_to_json_snapshot was None')
+    else:
+        print_red("write_json_snapshot: ax_client was None")
 
 def save_experiment_to_file() -> None:
-    save_experiment(
-        ax_client.experiment,
-        get_state_file_name("ax_client.experiment.json")
-    )
+    if ax_client is not None:
+        save_experiment(
+            ax_client.experiment,
+            get_state_file_name("ax_client.experiment.json")
+        )
+    else:
+        print_red("save_experiment: ax_client is None")
 
 def should_save_to_database() -> bool:
     return args.model not in uncontinuable_models and args.save_to_database
@@ -2408,7 +2468,7 @@ def set_nr_inserted_jobs(new_nr_inserted_jobs: int) -> None:
 
 def write_worker_usage() -> None:
     if len(WORKER_PERCENTAGE_USAGE):
-        csv_filename = get_current_run_folder('worker_usage.csv')
+        csv_filename = get_current_run_folder(worker_usage_file)
 
         csv_columns = ['time', 'num_parallel_jobs', 'nr_current_workers', 'percentage']
 
@@ -2418,35 +2478,39 @@ def write_worker_usage() -> None:
             csv_writer.writerow(row)
     else:
         if is_slurm_job():
-            print_debug("WORKER_PERCENTAGE_USAGE seems to be empty. Not writing worker_usage.csv")
+            print_debug(f"WORKER_PERCENTAGE_USAGE seems to be empty. Not writing {worker_usage_file}")
 
 def log_system_usage() -> None:
+    global LAST_LOG_TIME
+
+    now = time.time()
+    if now - LAST_LOG_TIME < 30:
+        return
+
+    LAST_LOG_TIME = int(now)
+
     if not get_current_run_folder():
         return
 
     ram_cpu_csv_file_path = os.path.join(get_current_run_folder(), "cpu_ram_usage.csv")
-
     makedirs(os.path.dirname(ram_cpu_csv_file_path))
 
     file_exists = os.path.isfile(ram_cpu_csv_file_path)
 
-    with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
-        writer = csv.writer(file)
-
-        current_time = int(time.time())
-
-        if process is not None:
-            mem_proc = process.memory_info()
-
-            if mem_proc is not None:
-                ram_usage_mb = mem_proc.rss / (1024 * 1024)
-                cpu_usage_percent = psutil.cpu_percent(percpu=False)
+    mem_proc = process.memory_info() if process else None
+    if not mem_proc:
+        return
 
-                if ram_usage_mb > 0 and cpu_usage_percent > 0:
-                    if not file_exists:
-                        writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
+    ram_usage_mb = mem_proc.rss / (1024 * 1024)
+    cpu_usage_percent = psutil.cpu_percent(percpu=False)
+    if ram_usage_mb <= 0 or cpu_usage_percent <= 0:
+        return
 
-                    writer.writerow([current_time, ram_usage_mb, cpu_usage_percent])
+    with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
+        writer = csv.writer(file)
+        if not file_exists:
+            writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
+        writer.writerow([int(now), ram_usage_mb, cpu_usage_percent])
 
 def write_process_info() -> None:
     try:
@@ -2598,9 +2662,6 @@ def _debug_worker_creation(msg: str, _lvl: int = 0, eee: Union[None, str, Except
 def append_to_nvidia_smi_logs(_file: str, _host: str, result: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
     log_message_to_file(_file, result, _lvl, str(eee))
 
-def _debug_get_next_trials(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
-    log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, _lvl, str(eee))
-
 def _debug_progressbar(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) -> None:
     log_message_to_file(logfile_progressbar, msg, _lvl, str(eee))
 
@@ -2796,13 +2857,23 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
     time_str: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     msg: str = f"{time_str}, {got}, {requested}"
 
-    _debug_get_next_trials(msg)
+    log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
 
 def print_debug_progressbar(msg: str) -> None:
-    time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    msg = f"{time_str} ({worker_generator_uuid}): {msg}"
+    global last_msg_progressbar, last_msg_raw
+
+    try:
+        with last_lock_print_debug:
+            if msg != last_msg_raw:
+                time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                full_msg = f"{time_str} ({worker_generator_uuid}): {msg}"
+
+                _debug_progressbar(full_msg)
 
-    _debug_progressbar(msg)
+                last_msg_raw = msg
+                last_msg_progressbar = full_msg
+    except Exception as e:
+        print(f"Error in print_debug_progressbar: {e}", flush=True)
 
 def get_process_info(pid: Any) -> str:
     try:
@@ -3307,116 +3378,438 @@ def parse_experiment_parameters() -> None:
3307
3378
  # Remove duplicates by 'name' key preserving order
3308
3379
  params = list({p['name']: p for p in params}.values())
3309
3380
 
3310
- experiment_parameters = params
3381
+ experiment_parameters = params # type: ignore[assignment]
3311
3382
 
3312
- def check_factorial_range() -> None:
3313
- if args.model and args.model == "FACTORIAL":
3314
- _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
3383
+ def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
3384
+ pf_start_time = time.time()
3315
3385
 
3316
- def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
3317
- if value_type not in valid_value_types:
3318
- valid_value_types_string = ", ".join(valid_value_types)
3319
- _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
3386
+ if not path_to_calculate:
3387
+ return False
3320
3388
 
3321
- def check_range_params_length(this_args: Union[str, list]) -> None:
3322
- if len(this_args) != 5 and len(this_args) != 4 and len(this_args) != 6:
3323
- _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
3389
+ global CURRENT_RUN_FOLDER
3390
+ global RESULT_CSV_FILE
3391
+ global arg_result_names
3324
3392
 
3325
- def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
3326
- if upper_bound is None or lower_bound is None:
3327
- _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
3328
- if upper_bound == lower_bound:
3329
- if lower_bound == 0:
3330
- _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
3331
- print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
3332
- if upper_bound is not None:
3333
- lower_bound = -upper_bound
3393
+ if not path_to_calculate:
3394
+ print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
3395
+ return False
3334
3396
 
3335
- def format_value(value: Any, float_format: str = '.80f') -> str:
3336
- try:
3337
- if isinstance(value, float):
3338
- s = format(value, float_format)
3339
- s = s.rstrip('0').rstrip('.') if '.' in s else s
3340
- return s
3341
- return str(value)
3342
- except Exception as e:
3343
- print_red(f"⚠ Error formatting the number {value}: {e}")
3344
- return str(value)
3397
+ if not os.path.exists(path_to_calculate):
3398
+ print_red(f"Path '{path_to_calculate}' does not exist")
3399
+ return False
3345
3400
 
3346
- def replace_parameters_in_string(
3347
- parameters: dict,
3348
- input_string: str,
3349
- float_format: str = '.20f',
3350
- additional_prefixes: list[str] = [],
3351
- additional_patterns: list[str] = [],
3352
- ) -> str:
3353
- try:
3354
- prefixes = ['$', '%'] + additional_prefixes
3355
- patterns = ['{key}', '({key})'] + additional_patterns
3401
+ ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
3356
3402
 
3357
- for key, value in parameters.items():
3358
- replacement = format_value(value, float_format=float_format)
3359
- for prefix in prefixes:
3360
- for pattern in patterns:
3361
- token = prefix + pattern.format(key=key)
3362
- input_string = input_string.replace(token, replacement)
3403
+ if not os.path.exists(ax_client_json):
3404
+ print_red(f"Path '{ax_client_json}' not found")
3405
+ return False
3363
3406
 
3364
- input_string = input_string.replace('\r', ' ').replace('\n', ' ')
3365
- return input_string
3407
+ checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
3408
+ if not os.path.exists(checkpoint_file):
3409
+ print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
3410
+ return False
3366
3411
 
3367
- except Exception as e:
3368
- print_red(f"\n⚠ Error: {e}")
3369
- return ""
3412
+ RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3413
+ if not os.path.exists(RESULT_CSV_FILE):
3414
+ print_red(f"{RESULT_CSV_FILE} not found")
3415
+ return False
3370
3416
 
3371
- def get_memory_usage() -> float:
3372
- user_uid = os.getuid()
3417
+ res_names = []
3373
3418
 
3374
- memory_usage = float(sum(
3375
- p.memory_info().rss for p in psutil.process_iter(attrs=['memory_info', 'uids'])
3376
- if p.info['uids'].real == user_uid
3377
- ) / (1024 * 1024))
3419
+ res_names_file = f"{path_to_calculate}/result_names.txt"
3420
+ if not os.path.exists(res_names_file):
3421
+ print_red(f"File '{res_names_file}' does not exist")
3422
+ return False
3378
3423
 
3379
- return memory_usage
3424
+ try:
3425
+ with open(res_names_file, "r", encoding="utf-8") as file:
3426
+ lines = file.readlines()
3427
+ except Exception as e:
3428
+ print_red(f"Error reading file '{res_names_file}': {e}")
3429
+ return False
3380
3430
 
3381
- class MonitorProcess:
3382
- def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
3383
- self.pid = pid
3384
- self.interval = interval
3385
- self.running = True
3386
- self.thread = threading.Thread(target=self._monitor)
3387
- self.thread.daemon = True
3431
+ for line in lines:
3432
+ entry = line.strip()
3433
+ if entry != "":
3434
+ res_names.append(entry)
3388
3435
 
3389
- fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
3436
+ if len(res_names) < 2:
3437
+ print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
3438
+ return False
3390
3439
 
3391
- def _monitor(self: Any) -> None:
3392
- try:
3393
- _internal_process = psutil.Process(self.pid)
3394
- while self.running and _internal_process.is_running():
3395
- crf = get_current_run_folder()
3440
+ load_username_to_args(path_to_calculate)
3396
3441
 
3397
- if crf and crf != "":
3398
- log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
3442
+ CURRENT_RUN_FOLDER = path_to_calculate
3399
3443
 
3400
- os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
3444
+ arg_result_names = res_names
3401
3445
 
3402
- with open(log_file_path, mode="a", encoding="utf-8") as log_file:
3403
- hostname = socket.gethostname()
3446
+ load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
3404
3447
 
3405
- slurm_job_id = os.getenv("SLURM_JOB_ID")
3448
+ if experiment_parameters is None:
3449
+ return False
3406
3450
 
3407
- if slurm_job_id:
3408
- hostname += f"-SLURM-ID-{slurm_job_id}"
3451
+ show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
3409
3452
 
3410
- total_memory = psutil.virtual_memory().total / (1024 * 1024)
3411
- cpu_usage = psutil.cpu_percent(interval=5)
3453
+ pf_end_time = time.time()
3412
3454
 
3413
- memory_usage = get_memory_usage()
3455
+ print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
3414
3456
 
3415
- unix_timestamp = int(time.time())
3457
+ return True
3416
3458
 
3417
- log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
3418
- time.sleep(self.interval)
3419
- except psutil.NoSuchProcess:
3459
+ def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
3460
+ if args.dryrun:
3461
+ print_debug("Not showing Pareto-frontier data with --dryrun")
3462
+ return None
3463
+
3464
+ if len(res_names) > 1:
3465
+ try:
3466
+ show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
3467
+ except Exception as e:
3468
+ inner_tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
3469
+ print_red(f"show_pareto_frontier_data() failed with exception '{e}':\n{inner_tb}")
3470
+ else:
3471
+ print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
3472
+ return None
3473
+
3474
+ def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
3475
+ pareto_front_data: dict = {}
3476
+
3477
+ all_combinations = list(combinations(range(len(arg_result_names)), 2))
3478
+
3479
+ skip = False
3480
+
3481
+ for i, j in all_combinations:
3482
+ if not skip:
3483
+ metric_x = arg_result_names[i]
3484
+ metric_y = arg_result_names[j]
3485
+
3486
+ x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
3487
+ y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
3488
+
3489
+ try:
3490
+ if metric_x not in pareto_front_data:
3491
+ pareto_front_data[metric_x] = {}
3492
+
3493
+ pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
3494
+ except ax.exceptions.core.DataRequiredError as e:
3495
+ print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
3496
+ except SignalINT:
3497
+ print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
3498
+ skip = True
3499
+
3500
+ return pareto_front_data
3501
+
3502
+ def pareto_front_transform_objectives(
3503
+ points: List[Tuple[Any, float, float]],
3504
+ primary_name: str,
3505
+ secondary_name: str
3506
+ ) -> Tuple[np.ndarray, np.ndarray]:
3507
+ primary_idx = arg_result_names.index(primary_name)
3508
+ secondary_idx = arg_result_names.index(secondary_name)
3509
+
3510
+ x = np.array([p[1] for p in points])
3511
+ y = np.array([p[2] for p in points])
3512
+
3513
+ if arg_result_min_or_max[primary_idx] == "max":
3514
+ x = -x
3515
+ elif arg_result_min_or_max[primary_idx] != "min":
3516
+ raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
3517
+
3518
+ if arg_result_min_or_max[secondary_idx] == "max":
3519
+ y = -y
3520
+ elif arg_result_min_or_max[secondary_idx] != "min":
3521
+ raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
3522
+
3523
+ return x, y
3524
+
3525
+ def get_pareto_frontier_points(
3526
+ path_to_calculate: str,
3527
+ primary_objective: str,
3528
+ secondary_objective: str,
3529
+ x_minimize: bool,
3530
+ y_minimize: bool,
3531
+ absolute_metrics: List[str],
3532
+ num_points: int
3533
+ ) -> Optional[dict]:
3534
+ records = pareto_front_aggregate_data(path_to_calculate)
3535
+
3536
+ if records is None:
3537
+ return None
3538
+
3539
+ points = pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
3540
+ x, y = pareto_front_transform_objectives(points, primary_objective, secondary_objective)
3541
+ selected_points = pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
3542
+ result = pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
3543
+
3544
+ return result
3545
+
3546
+ def pareto_front_table_read_csv() -> List[Dict[str, str]]:
3547
+ with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
3548
+ return list(csv.DictReader(f))
3549
+
3550
+ def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
3551
+ table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
3552
+
3553
+ rows = pareto_front_table_read_csv()
3554
+ if not rows:
3555
+ table.add_column("No data found")
3556
+ return table
3557
+
3558
+ filtered_rows = pareto_front_table_filter_rows(rows, idxs)
3559
+ if not filtered_rows:
3560
+ table.add_column("No matching entries")
3561
+ return table
3562
+
3563
+ param_cols, result_cols = pareto_front_table_get_columns(filtered_rows[0])
3564
+
3565
+ pareto_front_table_add_headers(table, param_cols, result_cols)
3566
+ pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
3567
+
3568
+ return table
3569
+
3570
+ def pareto_front_build_return_structure(
3571
+ path_to_calculate: str,
3572
+ selected_points: List[Tuple[Any, float, float]],
3573
+ records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
3574
+ absolute_metrics: List[str],
3575
+ primary_name: str,
3576
+ secondary_name: str
3577
+ ) -> dict:
3578
+ results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3579
+ result_names_file = f"{path_to_calculate}/result_names.txt"
3580
+
3581
+ with open(result_names_file, mode="r", encoding="utf-8") as f:
3582
+ result_names = [line.strip() for line in f if line.strip()]
3583
+
3584
+ csv_rows = {}
3585
+ with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
3586
+ reader = csv.DictReader(csvfile)
3587
+ for row in reader:
3588
+ trial_index = int(row['trial_index'])
3589
+ csv_rows[trial_index] = row
3590
+
3591
+ ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
3592
+ ignored_columns.update(result_names)
3593
+
3594
+ param_dicts = []
3595
+ idxs = []
3596
+ means_dict = defaultdict(list)
3597
+
3598
+ for (trial_index, arm_name), _, _ in selected_points:
3599
+ row = csv_rows.get(trial_index, {})
3600
+ if row == {} or row is None or row['arm_name'] != arm_name:
3601
+ continue
3602
+
3603
+ idxs.append(int(row["trial_index"]))
3604
+
3605
+ param_dict: dict[str, int | float | str] = {}
3606
+ for key, value in row.items():
3607
+ if key not in ignored_columns:
3608
+ try:
3609
+ param_dict[key] = int(value)
3610
+ except ValueError:
3611
+ try:
3612
+ param_dict[key] = float(value)
3613
+ except ValueError:
3614
+ param_dict[key] = value
3615
+
3616
+ param_dicts.append(param_dict)
3617
+
3618
+ for metric in absolute_metrics:
3619
+ means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
3620
+
3621
+ ret = {
3622
+ primary_name: {
3623
+ secondary_name: {
3624
+ "absolute_metrics": absolute_metrics,
3625
+ "param_dicts": param_dicts,
3626
+ "means": dict(means_dict),
3627
+ "idxs": idxs
3628
+ },
3629
+ "absolute_metrics": absolute_metrics
3630
+ }
3631
+ }
3632
+
3633
+ return ret
3634
+
3635
+ def pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
3636
+ results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3637
+ result_names_file = f"{path_to_calculate}/result_names.txt"
3638
+
3639
+ if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
3640
+ return None
3641
+
3642
+ with open(result_names_file, mode="r", encoding="utf-8") as f:
3643
+ result_names = [line.strip() for line in f if line.strip()]
3644
+
3645
+ records: dict = defaultdict(lambda: {'means': {}})
3646
+
3647
+ with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
3648
+ reader = csv.DictReader(csvfile)
3649
+ for row in reader:
3650
+ trial_index = int(row['trial_index'])
3651
+ arm_name = row['arm_name']
3652
+ key = (trial_index, arm_name)
3653
+
3654
+ for metric in result_names:
3655
+ if metric in row:
3656
+ try:
3657
+ records[key]['means'][metric] = float(row[metric])
3658
+ except ValueError:
3659
+ continue
3660
+
3661
+ return records
3662
+
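The aggregation above keys each record by (trial_index, arm_name) and keeps only metrics that parse as floats; a sketch of the resulting mapping, again with made-up names:

records_shape = {
    (0, "0_0"): {"means": {"loss": 0.42, "runtime": 12.5}},
    (1, "1_0"): {"means": {"loss": 0.37, "runtime": 15.1}},
}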
3663
+ def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
3664
+ if data is None:
3665
+ print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
3666
+ return
3667
+
3668
+ if not supports_sixel():
3669
+ print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
3670
+ return
3671
+
3672
+ import matplotlib.pyplot as plt
3673
+
3674
+ means = data[x_metric][y_metric]["means"]
3675
+
3676
+ x_values = means[x_metric]
3677
+ y_values = means[y_metric]
3678
+
3679
+ fig, _ax = plt.subplots()
3680
+
3681
+ _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
3682
+
3683
+ _ax.set_xlabel(x_metric)
3684
+ _ax.set_ylabel(y_metric)
3685
+
3686
+ _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
3687
+
3688
+ _ax.ticklabel_format(style='plain', axis='both', useOffset=False)
3689
+
3690
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
3691
+ plt.savefig(tmp_file.name, dpi=300)
3692
+
3693
+ print_image_to_cli(tmp_file.name, 1000)
3694
+
3695
+ plt.close(fig)
3696
+
3697
+ def pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
3698
+ all_columns = list(first_row.keys())
3699
+ ignored_cols = set(special_col_names) - {"trial_index"}
3700
+
3701
+ param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
3702
+ result_cols = [col for col in arg_result_names if col in all_columns]
3703
+ return param_cols, result_cols
3704
+
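Assuming arg_result_names = ["loss"] and the usual special columns, the split above behaves as sketched below; note that "trial_index" is deliberately removed from the ignored set, so it ends up in param_cols:

first_row = {"trial_index": "0", "arm_name": "0_0", "epochs": "10",
             "OO_Info_foo": "x", "loss": "0.3"}
# pareto_front_table_get_columns(first_row)
# -> param_cols  == ["trial_index", "epochs"]
# -> result_cols == ["loss"]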
3705
+ def check_factorial_range() -> None:
3706
+ if args.model and args.model == "FACTORIAL":
3707
+ _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
3708
+
3709
+ def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
3710
+ if value_type not in valid_value_types:
3711
+ valid_value_types_string = ", ".join(valid_value_types)
3712
+ _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
3713
+
3714
+ def check_range_params_length(this_args: Union[str, list]) -> None:
3715
+ if len(this_args) not in (4, 5, 6):
3716
+ _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
3717
+
3718
+ def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
3719
+ if upper_bound is None or lower_bound is None:
3720
+ _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
3721
+ if upper_bound == lower_bound:
3722
+ if lower_bound == 0:
3723
+ _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
3724
+ print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
3725
+ if upper_bound is not None:
3726
+ lower_bound = -upper_bound
3727
+
3728
+ def format_value(value: Any, float_format: str = '.80f') -> str:
3729
+ try:
3730
+ if isinstance(value, float):
3731
+ s = format(value, float_format)
3732
+ s = s.rstrip('0').rstrip('.') if '.' in s else s
3733
+ return s
3734
+ return str(value)
3735
+ except Exception as e:
3736
+ print_red(f"⚠ Error formatting the number {value}: {e}")
3737
+ return str(value)
3738
+
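A few illustrative calls, assuming format_value from above is in scope:

assert format_value(1.0) == "1"        # trailing zeros and the dot are stripped
assert format_value(0.125) == "0.125"  # exactly representable digits survive
assert format_value(42) == "42"        # non-floats fall through to str()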
3739
+ def replace_parameters_in_string(
3740
+ parameters: dict,
3741
+ input_string: str,
3742
+ float_format: str = '.20f',
3743
+ additional_prefixes: list[str] = [],
3744
+ additional_patterns: list[str] = [],
3745
+ ) -> str:
3746
+ try:
3747
+ prefixes = ['$', '%'] + additional_prefixes
3748
+ patterns = ['{' + 'key' + '}', '(' + '{' + 'key' + '}' + ')'] + additional_patterns
3749
+
3750
+ for key, value in parameters.items():
3751
+ replacement = format_value(value, float_format=float_format)
3752
+ for prefix in prefixes:
3753
+ for pattern in patterns:
3754
+ token = prefix + pattern.format(key=key)
3755
+ input_string = input_string.replace(token, replacement)
3756
+
3757
+ input_string = input_string.replace('\r', ' ').replace('\n', ' ')
3758
+ return input_string
3759
+
3760
+ except Exception as e:
3761
+ print_red(f"\n⚠ Error: {e}")
3762
+ return ""
3763
+
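With the default prefixes ['$', '%'] and the patterns above, each parameter can be referenced as $name, %name, $(name) or %(name); a small sketch:

params = {"epochs": 10, "lr": 0.5}
cmd = "python train.py --epochs=$epochs --lr=%(lr)"
# replace_parameters_in_string(params, cmd)
# -> "python train.py --epochs=10 --lr=0.5"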
3764
+ def get_memory_usage() -> float:
3765
+ user_uid = os.getuid()
3766
+
3767
+ memory_usage = float(sum(
3768
+ p.memory_info().rss for p in psutil.process_iter(attrs=['memory_info', 'uids'])
3769
+ if p.info['uids'].real == user_uid
3770
+ ) / (1024 * 1024))
3771
+
3772
+ return memory_usage
3773
+
3774
+ class MonitorProcess:
3775
+ def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
3776
+ self.pid = pid
3777
+ self.interval = interval
3778
+ self.running = True
3779
+ self.thread = threading.Thread(target=self._monitor)
3780
+ self.thread.daemon = True
3781
+
3782
+ fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
3783
+
3784
+ def _monitor(self: Any) -> None:
3785
+ try:
3786
+ _internal_process = psutil.Process(self.pid)
3787
+ while self.running and _internal_process.is_running():
3788
+ crf = get_current_run_folder()
3789
+
3790
+ if crf and crf != "":
3791
+ log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
3792
+
3793
+ os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
3794
+
3795
+ with open(log_file_path, mode="a", encoding="utf-8") as log_file:
3796
+ hostname = socket.gethostname()
3797
+
3798
+ slurm_job_id = os.getenv("SLURM_JOB_ID")
3799
+
3800
+ if slurm_job_id:
3801
+ hostname += f"-SLURM-ID-{slurm_job_id}"
3802
+
3803
+ total_memory = psutil.virtual_memory().total / (1024 * 1024)
3804
+ cpu_usage = psutil.cpu_percent(interval=5)
3805
+
3806
+ memory_usage = get_memory_usage()
3807
+
3808
+ unix_timestamp = int(time.time())
3809
+
3810
+ log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
3811
+ time.sleep(self.interval)
3812
+ except psutil.NoSuchProcess:
3420
3813
  pass
3421
3814
 
3422
3815
  def __enter__(self: Any) -> None:
@@ -3560,7 +3953,7 @@ def _add_to_csv_acquire_lock(lockfile: str, dir_path: str) -> bool:
3560
3953
  time.sleep(wait_time)
3561
3954
  max_wait -= wait_time
3562
3955
  except Exception as e:
3563
- print("Lock error:", e)
3956
+ print_red(f"Lock error: {e}")
3564
3957
  return False
3565
3958
  return False
3566
3959
 
@@ -3598,11 +3991,11 @@ def _add_to_csv_rewrite_file(file_path: str, rows: List[list], existing_heading:
3598
3991
  writer = csv.writer(tmp_file)
3599
3992
  writer.writerow(all_headings)
3600
3993
  for row in rows[1:]:
3601
- tmp_file.writerow([
3994
+ writer.writerow([
3602
3995
  row[existing_heading.index(h)] if h in existing_heading else ""
3603
3996
  for h in all_headings
3604
3997
  ])
3605
- tmp_file.writerow([
3998
+ writer.writerow([
3606
3999
  formatted_data[new_heading.index(h)] if h in new_heading else ""
3607
4000
  for h in all_headings
3608
4001
  ])
@@ -3633,12 +4026,12 @@ def find_file_paths(_text: str) -> List[str]:
3633
4026
  def check_file_info(file_path: str) -> str:
3634
4027
  if not os.path.exists(file_path):
3635
4028
  if not args.tests:
3636
- print(f"check_file_info: The file {file_path} does not exist.")
4029
+ print_red(f"check_file_info: The file {file_path} does not exist.")
3637
4030
  return ""
3638
4031
 
3639
4032
  if not os.access(file_path, os.R_OK):
3640
4033
  if not args.tests:
3641
- print(f"check_file_info: The file {file_path} is not readable.")
4034
+ print_red(f"check_file_info: The file {file_path} is not readable.")
3642
4035
  return ""
3643
4036
 
3644
4037
  file_stat = os.stat(file_path)
@@ -3698,7 +4091,7 @@ def write_failed_logs(data_dict: Optional[dict], error_description: str = "") ->
3698
4091
  data = [list(data_dict.values())]
3699
4092
  else:
3700
4093
  print_debug("No data_dict provided, writing only error description.")
3701
- data = [[]] # leeres Datenfeld, nur error_description kommt dazu
4094
+ data = [[]]
3702
4095
 
3703
4096
  if error_description:
3704
4097
  headers.append('error_description')
@@ -3752,7 +4145,7 @@ def count_defective_nodes(file_path: Union[str, None] = None, entry: Any = None)
3752
4145
  return sorted(set(entries))
3753
4146
 
3754
4147
  except Exception as e:
3755
- print(f"An error has occurred: {e}")
4148
+ print_red(f"An error has occurred: {e}")
3756
4149
  return []
3757
4150
 
3758
4151
  def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]:
@@ -3763,7 +4156,7 @@ def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]
3763
4156
 
3764
4157
  fool_linter(tmp)
3765
4158
  except RuntimeError:
3766
- print(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
4159
+ print_red(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
3767
4160
  count_defective_nodes(None, socket.gethostname())
3768
4161
  return return_in_case_of_error
3769
4162
  except Exception:
@@ -3937,7 +4330,7 @@ def _write_job_infos_csv_build_headline(parameters_keys: List[str], extra_vars_n
3937
4330
  "run_time",
3938
4331
  "program_string",
3939
4332
  *parameters_keys,
3940
- *arg_result_names, # arg_result_names muss global definiert sein
4333
+ *arg_result_names,
3941
4334
  "exit_code",
3942
4335
  "signal",
3943
4336
  "hostname",
@@ -4234,6 +4627,8 @@ def evaluate(parameters_with_trial_index: dict) -> Optional[Union[int, float, Di
4234
4627
  trial_index = parameters_with_trial_index["trial_idx"]
4235
4628
  submit_time = parameters_with_trial_index["submit_time"]
4236
4629
 
4630
+ print(f'Trial-Index: {trial_index}')
4631
+
4237
4632
  queue_time = abs(int(time.time()) - int(submit_time))
4238
4633
 
4239
4634
  start_nvidia_smi_thread()
@@ -4471,7 +4866,7 @@ def replace_string_with_params(input_string: str, params: list) -> str:
4471
4866
  return replaced_string
4472
4867
  except AssertionError as e:
4473
4868
  error_text = f"Error in replace_string_with_params: {e}"
4474
- print(error_text)
4869
+ print_red(error_text)
4475
4870
  raise
4476
4871
 
4477
4872
  return ""
@@ -4748,9 +5143,7 @@ def get_sixel_graphics_data(_pd_csv: str, _force: bool = False) -> list:
4748
5143
  _params = [_command, plot, _tmp, plot_type, tmp_file, _width]
4749
5144
  data.append(_params)
4750
5145
  except Exception as e:
4751
- tb = traceback.format_exc()
4752
- print_red(f"Error trying to print {plot_type} to CLI: {e}, {tb}")
4753
- print_debug(f"Error trying to print {plot_type} to CLI: {e}")
5146
+ print_red(f"Error trying to print {plot_type} to CLI: {e}")
4754
5147
 
4755
5148
  return data
4756
5149
 
@@ -4858,7 +5251,7 @@ def process_best_result(res_name: str, print_to_file: bool) -> int:
4858
5251
 
4859
5252
  if str(best_result) in [NO_RESULT, None, "None"]:
4860
5253
  print_red(f"Best {res_name} could not be determined")
4861
- return 87
5254
+ return 87  # exit code 87: best result could not be determined
4862
5255
 
4863
5256
  total_str = f"total: {_count_done_jobs(RESULT_CSV_FILE) - NR_INSERTED_JOBS}"
4864
5257
  if NR_INSERTED_JOBS:
@@ -4941,15 +5334,17 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
4941
5334
  if job:
4942
5335
  try:
4943
5336
  if ax_client:
4944
- _trial = ax_client.get_trial(trial_index)
4945
- _trial.mark_abandoned(reason=reason)
5337
+ _trial = get_ax_client_trial(trial_index)
5338
+ if _trial is None:
5339
+ return False
5340
+
5341
+ mark_abandoned(_trial, reason, trial_index)
4946
5342
  print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
4947
5343
  global_vars["jobs"].remove((job, trial_index))
4948
5344
  else:
4949
- _fatal_error("ax_client could not be found", 9)
5345
+ _fatal_error("ax_client could not be found", 101)
4950
5346
  except Exception as e:
4951
- print(f"ERROR in line {get_line_info()}: {e}")
4952
- print_debug(f"ERROR in line {get_line_info()}: {e}")
5347
+ print_red(f"ERROR in line {get_line_info()}: {e}")
4953
5348
  return False
4954
5349
  job.cancel()
4955
5350
  return True
@@ -4962,23 +5357,120 @@ def abandon_all_jobs() -> None:
4962
5357
  if not abandoned:
4963
5358
  print_debug(f"Job {job} could not be abandoned.")
4964
5359
 
4965
- def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
4966
- if args.dryrun:
4967
- print_debug("Not showing Pareto-frontier data with --dryrun")
5360
+ def write_result_to_trace_file(res: str) -> bool:
5361
+ if res is None:
5362
+ sys.stderr.write("Provided result is None, nothing to write\n")
5363
+ return False
5364
+
5365
+ target_folder = get_current_run_folder()
5366
+ target_file = os.path.join(target_folder, "optimization_trace.html")
5367
+
5368
+ try:
5369
+ file_handle = open(target_file, "w", encoding="utf-8")
5370
+ except OSError as error:
5371
+ sys.stderr.write("Unable to open target file for writing\n")
5372
+ sys.stderr.write(str(error) + "\n")
5373
+ return False
5374
+
5375
+ try:
5376
+ written = file_handle.write(str(res))
5377
+ file_handle.flush()
5378
+
5379
+ if written == 0:
5380
+ sys.stderr.write("No data was written to the file\n")
5381
+ file_handle.close()
5382
+ return False
5383
+ except Exception as error:
5384
+ sys.stderr.write("Error occurred while writing to file\n")
5385
+ sys.stderr.write(str(error) + "\n")
5386
+ file_handle.close()
5387
+ return False
5388
+
5389
+ try:
5390
+ file_handle.close()
5391
+ except Exception as error:
5392
+ sys.stderr.write("Failed to properly close file\n")
5393
+ sys.stderr.write(str(error) + "\n")
5394
+ return False
5395
+
5396
+ return True
5397
+
5398
+ def render(plot_config: AxPlotConfig) -> None:
5399
+ if plot_config is None or "data" not in plot_config:
4968
5400
  return None
4969
5401
 
4970
- if len(res_names) > 1:
4971
- try:
4972
- show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
4973
- except Exception as e:
4974
- print_red(f"show_pareto_frontier_data() failed with exception '{e}'")
4975
- else:
4976
- print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
5402
+ res: str = plot_config.data # type: ignore
5403
+
5404
+ repair_funcs = """
5405
+ function decodeBData(obj) {
5406
+ if (!obj || typeof obj !== "object") {
5407
+ return obj;
5408
+ }
5409
+
5410
+ if (obj.bdata && obj.dtype) {
5411
+ var binary_string = atob(obj.bdata);
5412
+ var len = binary_string.length;
5413
+ var bytes = new Uint8Array(len);
5414
+
5415
+ for (var i = 0; i < len; i++) {
5416
+ bytes[i] = binary_string.charCodeAt(i);
5417
+ }
5418
+
5419
+ switch (obj.dtype) {
5420
+ case "i1": return Array.from(new Int8Array(bytes.buffer));
5421
+ case "i2": return Array.from(new Int16Array(bytes.buffer));
5422
+ case "i4": return Array.from(new Int32Array(bytes.buffer));
5423
+ case "f4": return Array.from(new Float32Array(bytes.buffer));
5424
+ case "f8": return Array.from(new Float64Array(bytes.buffer));
5425
+ default:
5426
+ console.error("Unknown dtype:", obj.dtype);
5427
+ return [];
5428
+ }
5429
+ }
5430
+
5431
+ return obj;
5432
+ }
5433
+
5434
+ function repairTraces(traces) {
5435
+ var fixed = [];
5436
+
5437
+ for (var i = 0; i < traces.length; i++) {
5438
+ var t = traces[i];
5439
+
5440
+ if (t.x) {
5441
+ t.x = decodeBData(t.x);
5442
+ }
5443
+
5444
+ if (t.y) {
5445
+ t.y = decodeBData(t.y);
5446
+ }
5447
+
5448
+ fixed.push(t);
5449
+ }
5450
+
5451
+ return fixed;
5452
+ }
5453
+ """
5454
+
5455
+ res = str(res)
5456
+
5457
+ res = f"<div id='plot' style='width:100%;height:600px;'></div>\n<script type='text/javascript' src='https://cdn.plot.ly/plotly-latest.min.js'></script><script>{repair_funcs}\nconst True = true;\nconst False = false;\nconst data = {res};\ndata.data = repairTraces(data.data);\nPlotly.newPlot(document.getElementById('plot'), data.data, data.layout);</script>"
5458
+
5459
+ write_result_to_trace_file(res)
5460
+
4977
5461
  return None
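decodeBData() above assumes that Plotly may serialize numeric arrays as base64-encoded raw typed arrays ({bdata, dtype}); a self-contained Python round-trip sketch of the "f8" (float64) case:

import base64
import struct

values = [1.0, 2.5, 3.0]
packed = struct.pack(f"<{len(values)}d", *values)  # "f8": little-endian float64
bdata = base64.b64encode(packed).decode()          # what atob() receives in the JS above
raw = base64.b64decode(bdata)
assert list(struct.unpack(f"<{len(raw) // 8}d", raw)) == values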
4978
5462
 
4979
5463
  def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None) -> None:
4980
5464
  global END_PROGRAM_RAN
4981
5465
 
5466
+ #dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].X)
5467
+ #dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].Y)
5468
+ #dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.outcomes)
5469
+
5470
+ if ax_client is not None:
5471
+ if len(arg_result_names) == 1:
5472
+ render(ax_client.get_optimization_trace())
5473
+
4982
5474
  wait_for_jobs_to_complete()
4983
5475
 
4984
5476
  show_pareto_or_error_msg(get_current_run_folder(), arg_result_names)
@@ -5012,7 +5504,7 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5012
5504
  _exit = new_exit
5013
5505
  except (SignalUSR, SignalINT, SignalCONT, KeyboardInterrupt):
5014
5506
  print_red("\n⚠ You pressed CTRL+C or a signal was sent. Program execution halted while ending program.")
5015
- print("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
5507
+ print_red("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
5016
5508
  new_exit = show_end_table_and_save_end_files()
5017
5509
  if new_exit > 0:
5018
5510
  _exit = new_exit
@@ -5021,8 +5513,6 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5021
5513
 
5022
5514
  abandon_all_jobs()
5023
5515
 
5024
- save_results_csv()
5025
-
5026
5516
  if exit_code:
5027
5517
  _exit = exit_code
5028
5518
 
@@ -5035,21 +5525,31 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5035
5525
 
5036
5526
  my_exit(_exit)
5037
5527
 
5528
+ def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
5529
+ if not ax_client:
5530
+ my_exit(101)
5531
+
5532
+ return None
5533
+
5534
+ ax_client.save_to_json_file(checkpoint_filepath)
5535
+
5536
+ return None
5537
+
5038
5538
  def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
5039
5539
  if trial_nr > 3:
5040
5540
  if eee:
5041
- print(f"Error during saving checkpoint: {eee}")
5541
+ print_red(f"Error during saving checkpoint: {eee}")
5042
5542
  else:
5043
- print("Error during saving checkpoint")
5543
+ print_red("Error during saving checkpoint")
5044
5544
  return
5045
5545
 
5046
5546
  try:
5047
5547
  checkpoint_filepath = get_state_file_name('checkpoint.json')
5048
5548
 
5049
5549
  if ax_client:
5050
- ax_client.save_to_json_file(filepath=checkpoint_filepath)
5550
+ save_ax_client_to_json_file(checkpoint_filepath)
5051
5551
  else:
5052
- _fatal_error("Something went wrong using the ax_client", 9)
5552
+ _fatal_error("Something went wrong using the ax_client", 101)
5053
5553
  except Exception as e:
5054
5554
  save_checkpoint(trial_nr + 1, e)
5055
5555
 
@@ -5210,7 +5710,7 @@ def parse_equation_item(comparer_found: bool, item: str, parsed: list, parsed_or
5210
5710
  })
5211
5711
  elif item in [">=", "<="]:
5212
5712
  if comparer_found:
5213
- print("There is already one comparison operator! Cannot have more than one in an equation!")
5713
+ print_red("There is already one comparison operator! Cannot have more than one in an equation!")
5214
5714
  return_totally = True
5215
5715
  comparer_found = True
5216
5716
 
@@ -5421,20 +5921,9 @@ def check_equation(variables: list, equation: str) -> Union[str, bool]:
5421
5921
  def set_objectives() -> dict:
5422
5922
  objectives = {}
5423
5923
 
5424
- for rn in args.result_names:
5425
- key, value = "", ""
5426
-
5427
- if "=" in rn:
5428
- key, value = rn.split('=', 1)
5429
- else:
5430
- key = rn
5431
- value = ""
5432
-
5433
- if value not in ["min", "max"]:
5434
- if value:
5435
- print_yellow(f"Value '{value}' for --result_names {rn} is not a valid value. Must be min or max. Will be set to min.")
5436
-
5437
- value = "min"
5924
+ k = 0
5925
+ for key in arg_result_names:
5926
+ value = arg_result_min_or_max[k]
5438
5927
 
5439
5928
  _min = True
5440
5929
 
@@ -5442,12 +5931,18 @@ def set_objectives() -> dict:
5442
5931
  _min = False
5443
5932
 
5444
5933
  objectives[key] = ObjectiveProperties(minimize=_min)
5934
+ k = k + 1
5445
5935
 
5446
5936
  return objectives
5447
5937
 
5448
- def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Union[dict, list]) -> dict:
5449
- if experiment_constraints and len(experiment_constraints):
5938
+ def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Optional[Union[dict, list]]) -> dict:
5939
+ if _experiment_parameters is None:
5940
+ print_red("set_experiment_constraints: _experiment_parameters was None")
5941
+ my_exit(95)
5942
+
5943
+ return {}
5450
5944
 
5945
+ if experiment_constraints and len(experiment_constraints):
5451
5946
  experiment_args["parameter_constraints"] = []
5452
5947
 
5453
5948
  if experiment_constraints:
@@ -5476,11 +5971,15 @@ def set_experiment_constraints(experiment_constraints: Optional[list], experimen
5476
5971
 
5477
5972
  return experiment_args
5478
5973
 
5479
- def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[list]) -> None:
5974
+ def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[dict | list]) -> None:
5975
+ if not experiment_parameters:
5976
+ print_red("replace_parameters_for_continued_jobs: experiment_parameters was False")
5977
+ return None
5978
+
5480
5979
  if args.worker_generator_path:
5481
5980
  return None
5482
5981
 
5483
- def get_name(obj) -> Optional[str]:
5982
+ def get_name(obj: Any) -> Optional[str]:
5484
5983
  """Extract a parameter name from dict, list, or tuple safely."""
5485
5984
  if isinstance(obj, dict):
5486
5985
  return obj.get("name")
@@ -5562,13 +6061,13 @@ def copy_continue_uuid() -> None:
5562
6061
  print_debug(f"copy_continue_uuid: Source file does not exist: {source_file}")
5563
6062
 
5564
6063
  def load_ax_client_from_experiment_parameters() -> None:
5565
- #pprint(experiment_parameters)
5566
- global ax_client
6064
+ if experiment_parameters:
6065
+ global ax_client
5567
6066
 
5568
- tmp_file_path = get_tmp_file_from_json(experiment_parameters)
5569
- ax_client = AxClient.load_from_json_file(tmp_file_path)
5570
- ax_client = cast(AxClient, ax_client)
5571
- os.unlink(tmp_file_path)
6067
+ tmp_file_path = get_tmp_file_from_json(experiment_parameters)
6068
+ ax_client = AxClient.load_from_json_file(tmp_file_path)
6069
+ ax_client = cast(AxClient, ax_client)
6070
+ os.unlink(tmp_file_path)
5572
6071
 
5573
6072
  def save_checkpoint_for_continued() -> None:
5574
6073
  checkpoint_filepath = get_state_file_name('checkpoint.json')
@@ -5580,12 +6079,15 @@ def save_checkpoint_for_continued() -> None:
5580
6079
  _fatal_error(f"{checkpoint_filepath} not found. Cannot continue_previous_job without.", 47)
5581
6080
 
5582
6081
  def load_original_generation_strategy(original_ax_client_file: str) -> None:
5583
- with open(original_ax_client_file, encoding="utf-8") as f:
5584
- loaded_original_ax_client_json = json.load(f)
5585
- original_generation_strategy = loaded_original_ax_client_json["generation_strategy"]
6082
+ if experiment_parameters:
6083
+ with open(original_ax_client_file, encoding="utf-8") as f:
6084
+ loaded_original_ax_client_json = json.load(f)
6085
+ original_generation_strategy = loaded_original_ax_client_json["generation_strategy"]
5586
6086
 
5587
- if original_generation_strategy:
5588
- experiment_parameters["generation_strategy"] = original_generation_strategy
6087
+ if original_generation_strategy:
6088
+ experiment_parameters["generation_strategy"] = original_generation_strategy
6089
+ else:
6090
+ print_red("load_original_generation_strategy: experiment_parameters was empty!")
5589
6091
 
5590
6092
  def wait_for_checkpoint_file(checkpoint_file: str) -> None:
5591
6093
  start_time = time.time()
@@ -5598,10 +6100,6 @@ def wait_for_checkpoint_file(checkpoint_file: str) -> None:
5598
6100
  elapsed = int(time.time() - start_time)
5599
6101
  console.print(f"[green]Checkpoint file found after {elapsed} seconds[/green] ")
5600
6102
 
5601
- def __get_experiment_parameters__check_ax_client() -> None:
5602
- if not ax_client:
5603
- _fatal_error("Something went wrong with the ax_client", 9)
5604
-
5605
6103
  def validate_experiment_parameters() -> None:
5606
6104
  if experiment_parameters is None:
5607
6105
  print_red("Error: experiment_parameters is None.")
@@ -5611,6 +6109,8 @@ def validate_experiment_parameters() -> None:
5611
6109
  print_red(f"Error: experiment_parameters is not a dict: {type(experiment_parameters).__name__}")
5612
6110
  my_exit(95)
5613
6111
 
6112
+ sys.exit(95)
6113
+
5614
6114
  path_checks = [
5615
6115
  ("experiment", experiment_parameters),
5616
6116
  ("search_space", experiment_parameters.get("experiment")),
@@ -5622,7 +6122,12 @@ def validate_experiment_parameters() -> None:
5622
6122
  print_red(f"Error: Missing key '{key}' at level: {current_level}")
5623
6123
  my_exit(95)
5624
6124
 
5625
- def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str, cli_params_experiment_parameters: Optional[list]) -> Tuple[Any, str, str]:
6125
+ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_parameters: Optional[dict | list]) -> Tuple[Any, str, str]:
6126
+ if not ax_client:
6127
+ print_red("load_from_checkpoint: ax_client was None")
6128
+ my_exit(101)
6129
+ return {}, "", ""
6130
+
5626
6131
  print_debug(f"Load from checkpoint: {continue_previous_job}")
5627
6132
 
5628
6133
  checkpoint_file = f"{continue_previous_job}/state_files/checkpoint.json"
@@ -5648,7 +6153,7 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
5648
6153
 
5649
6154
  replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)
5650
6155
 
5651
- ax_client.save_to_json_file(filepath=original_ax_client_file)
6156
+ save_ax_client_to_json_file(original_ax_client_file)
5652
6157
 
5653
6158
  load_original_generation_strategy(original_ax_client_file)
5654
6159
  load_ax_client_from_experiment_parameters()
@@ -5664,6 +6169,12 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
5664
6169
 
5665
6170
  experiment_constraints = get_constraints()
5666
6171
  if experiment_constraints:
6172
+
6173
+ if not experiment_parameters:
6174
+ print_red("load_from_checkpoint: experiment_parameters was None")
6175
+
6176
+ return {}, "", ""
6177
+
5667
6178
  experiment_args = set_experiment_constraints(
5668
6179
  experiment_constraints,
5669
6180
  experiment_args,
@@ -5672,7 +6183,116 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
5672
6183
 
5673
6184
  return experiment_args, gpu_string, gpu_color
5674
6185
 
5675
- def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str]:
6186
+ def get_experiment_args_import_python_script() -> str:
6187
+
6188
+ return """from ax.service.ax_client import AxClient, ObjectiveProperties
6189
+ from ax.adapter.registry import Generators
6190
+ import random
6191
+
6192
+ """
6193
+
6194
+ def get_generate_and_test_random_function_str() -> str:
6195
+ raw_data_entries = ",\n ".join(
6196
+ f'"{name}": random.uniform(0, 1)' for name in arg_result_names
6197
+ )
6198
+
6199
+ return f"""
6200
+ def generate_and_test_random_parameters(n: int) -> None:
6201
+ for _ in range(n):
6202
+ print("======================================")
6203
+ parameters, trial_index = ax_client.get_next_trial()
6204
+ print("Trial Index:", trial_index)
6205
+ print("Suggested parameters:", parameters)
6206
+
6207
+ ax_client.complete_trial(
6208
+ trial_index=trial_index,
6209
+ raw_data={{
6210
+ {raw_data_entries}
6211
+ }}
6212
+ )
6213
+
6214
+ generate_and_test_random_parameters({args.num_random_steps + 1})
6215
+ """
6216
+
6217
+ def get_global_gs_string() -> str:
6218
+ seed_str = ""
6219
+ if args.seed is not None:
6220
+ seed_str = f"model_kwargs={{'seed': {args.seed}}},"
6221
+
6222
+ return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
6223
+
6224
+ global_gs = GenerationStrategy(
6225
+ steps=[
6226
+ GenerationStep(
6227
+ generator=Generators.SOBOL,
6228
+ num_trials={args.num_random_steps},
6229
+ max_parallelism=5,
6230
+ {seed_str}
6231
+ ),
6232
+ GenerationStep(
6233
+ generator=Generators.{args.model},
6234
+ num_trials=-1,
6235
+ max_parallelism=5,
6236
+ ),
6237
+ ]
6238
+ )
6239
+ """
6240
+
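With hypothetical values args.num_random_steps=5, args.seed=42 and args.model="BOTORCH_MODULAR" (and Generators coming from the import block emitted by get_experiment_args_import_python_script), the template above renders roughly as:

from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy

global_gs = GenerationStrategy(
    steps=[
        GenerationStep(
            generator=Generators.SOBOL,
            num_trials=5,
            max_parallelism=5,
            model_kwargs={'seed': 42},
        ),
        GenerationStep(
            generator=Generators.BOTORCH_MODULAR,
            num_trials=-1,
            max_parallelism=5,
        ),
    ]
)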
6241
+ def get_debug_ax_client_str() -> str:
6242
+ return """
6243
+ ax_client = AxClient(
6244
+ verbose_logging=True,
6245
+ enforce_sequential_optimization=False,
6246
+ generation_strategy=global_gs
6247
+ )
6248
+ """
6249
+
6250
+ def write_ax_debug_python_code(experiment_args: dict) -> None:
6251
+ if args.generation_strategy:
6252
+ print_debug("Cannot write debug code for custom generation_strategy")
6253
+ return None
6254
+
6255
+ if args.model in uncontinuable_models:
6256
+ print_debug(f"Cannot write debug code for uncontinuable mode {args.model}")
6257
+ return None
6258
+
6259
+ python_code = get_experiment_args_import_python_script() + \
6260
+ get_global_gs_string() + \
6261
+ get_debug_ax_client_str() + \
6262
+ "experiment_args = " + pformat(experiment_args, width=120, compact=False) + \
6263
+ "\nax_client.create_experiment(**experiment_args)\n" + \
6264
+ get_generate_and_test_random_function_str()
6265
+
6266
+ file_path = f"{get_current_run_folder()}/debug.py"
6267
+
6268
+ try:
6269
+ print_debug(python_code)
6270
+ with open(file_path, "w", encoding="utf-8") as f:
6271
+ f.write(python_code)
6272
+ except Exception as e:
6273
+ print_red(f"Error while writing {file_path}: {e}")
6274
+
6275
+ return None
6276
+
6277
+ def create_ax_client_experiment(experiment_args: dict) -> None:
6278
+ if not ax_client:
6279
+ my_exit(101)
6280
+
6281
+ return None
6282
+
6283
+ write_ax_debug_python_code(experiment_args)
6284
+
6285
+ ax_client.create_experiment(**experiment_args)
6286
+
6287
+ return None
6288
+
6289
+ def create_new_experiment() -> Tuple[dict, str, str]:
6290
+ if ax_client is None:
6291
+ print_red("create_new_experiment: ax_client is None")
6292
+ my_exit(101)
6293
+
6294
+ return {}, "", ""
6295
+
5676
6296
  objectives = set_objectives()
5677
6297
 
5678
6298
  experiment_args = {
@@ -5695,7 +6315,7 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
5695
6315
  experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)
5696
6316
 
5697
6317
  try:
5698
- ax_client.create_experiment(**experiment_args)
6318
+ create_ax_client_experiment(experiment_args)
5699
6319
  new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
5700
6320
  ax_client.experiment.add_tracking_metrics(new_metrics)
5701
6321
  except AssertionError as error:
@@ -5709,17 +6329,17 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
5709
6329
 
5710
6330
  return experiment_args, gpu_string, gpu_color
5711
6331
 
5712
- def get_experiment_parameters(cli_params_experiment_parameters: Optional[list]) -> Optional[Tuple[AxClient, dict, str, str]]:
6332
+ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
5713
6333
  continue_previous_job = args.worker_generator_path or args.continue_previous_job
5714
6334
 
5715
- __get_experiment_parameters__check_ax_client()
6335
+ check_ax_client()
5716
6336
 
5717
6337
  if continue_previous_job:
5718
- experiment_args, gpu_string, gpu_color = __get_experiment_parameters__load_from_checkpoint(continue_previous_job, cli_params_experiment_parameters)
6338
+ experiment_args, gpu_string, gpu_color = load_from_checkpoint(continue_previous_job, cli_params_experiment_parameters)
5719
6339
  else:
5720
- experiment_args, gpu_string, gpu_color = __get_experiment_parameters__create_new_experiment()
6340
+ experiment_args, gpu_string, gpu_color = create_new_experiment()
5721
6341
 
5722
- return ax_client, experiment_args, gpu_string, gpu_color
6342
+ return experiment_args, gpu_string, gpu_color
5723
6343
 
5724
6344
  def get_type_short(typename: str) -> str:
5725
6345
  if typename == "RangeParameter":
@@ -5778,7 +6398,6 @@ def parse_single_experiment_parameter_table(classic_params: Optional[Union[list,
5778
6398
  _upper = param["bounds"][1]
5779
6399
 
5780
6400
  _possible_int_lower = str(helpers.to_int_when_possible(_lower))
5781
- #print(f"name: {_name}, _possible_int_lower: {_possible_int_lower}, lower: {_lower}")
5782
6401
  _possible_int_upper = str(helpers.to_int_when_possible(_upper))
5783
6402
 
5784
6403
  rows.append([_name, _short_type, _possible_int_lower, _possible_int_upper, "", value_type, log_scale])
@@ -5855,19 +6474,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:
5855
6474
 
5856
6475
  return None
5857
6476
 
6477
+ def check_base_for_print_overview() -> Optional[bool]:
6478
+ if args.continue_previous_job is not None and arg_result_names is not None and len(arg_result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
6479
+ print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
6480
+
6481
+ if ax_client is None:
6482
+ print_red("ax_client was None")
6483
+ return None
6484
+
6485
+ if ax_client.experiment is None:
6486
+ print_red("ax_client.experiment was None")
6487
+ return None
6488
+
6489
+ if ax_client.experiment.optimization_config is None:
6490
+ print_red("ax_client.experiment.optimization_config was None")
6491
+ return None
6492
+
6493
+ return True
6494
+
6495
+ def get_config_objectives() -> Any:
6496
+ if not ax_client:
6497
+ print_red("create_new_experiment: ax_client is None")
6498
+ my_exit(101)
6499
+
6500
+ return None
6501
+
6502
+ config_objectives = None
6503
+
6504
+ if ax_client.experiment and ax_client.experiment.optimization_config:
6505
+ opt_config = ax_client.experiment.optimization_config
6506
+ if opt_config.is_moo_problem:
6507
+ objective = getattr(opt_config, "objective", None)
6508
+ if objective and getattr(objective, "objectives", None) is not None:
6509
+ config_objectives = objective.objectives
6510
+ else:
6511
+ print_debug("ax_client.experiment.optimization_config.objective was None")
6512
+ else:
6513
+ config_objectives = [opt_config.objective]
6514
+ else:
6515
+ print_debug("ax_client.experiment or optimization_config was None")
6516
+
6517
+ return config_objectives
6518
+
5858
6519
  def print_result_names_overview_table() -> None:
5859
6520
  if not ax_client:
5860
6521
  _fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)
5861
6522
 
5862
6523
  return None
5863
6524
 
5864
- if args.continue_previous_job is not None and args.result_names is not None and len(args.result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
5865
- print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
6525
+ if check_base_for_print_overview() is None:
6526
+ return None
5866
6527
 
5867
- if ax_client.experiment.optimization_config.is_moo_problem:
5868
- config_objectives = ax_client.experiment.optimization_config.objective.objectives
5869
- else:
5870
- config_objectives = [ax_client.experiment.optimization_config.objective]
6528
+ config_objectives = get_config_objectives()
6529
+
6530
+ if config_objectives is None:
6531
+ print_red("config_objectives not found")
6532
+ return None
5871
6533
 
5872
6534
  res_names = []
5873
6535
  res_min_max = []
@@ -5962,28 +6624,31 @@ def print_overview_tables(classic_params: Optional[Union[list, dict]], experimen
5962
6624
  print_result_names_overview_table()
5963
6625
 
5964
6626
  def update_progress_bar(nr: int) -> None:
5965
- try:
5966
- progress_bar.update(nr)
5967
- except Exception as e:
5968
- print(f"Error updating progress bar: {e}")
6627
+ log_data()
6628
+
6629
+ if progress_bar is not None:
6630
+ try:
6631
+ progress_bar.update(nr)
6632
+ except Exception as e:
6633
+ print_red(f"Error updating progress bar: {e}")
6634
+ else:
6635
+ print_red("update_progress_bar: progress_bar was None")
5969
6636
 
5970
6637
  def get_current_model_name() -> str:
5971
6638
  if overwritten_to_random:
5972
6639
  return "Random*"
5973
6640
 
6641
+ gs_model = "unknown model"
6642
+
5974
6643
  if ax_client:
5975
6644
  try:
5976
6645
  if args.generation_strategy:
5977
- idx = getattr(ax_client.generation_strategy, "current_step_index", None)
6646
+ idx = getattr(global_gs, "current_step_index", None)
5978
6647
  if isinstance(idx, int):
5979
6648
  if 0 <= idx < len(generation_strategy_names):
5980
6649
  gs_model = generation_strategy_names[int(idx)]
5981
- else:
5982
- gs_model = "unknown model"
5983
- else:
5984
- gs_model = "unknown model"
5985
6650
  else:
5986
- gs_model = getattr(ax_client.generation_strategy, "current_node_name", "unknown model")
6651
+ gs_model = getattr(global_gs, "current_node_name", "unknown model")
5987
6652
 
5988
6653
  if gs_model:
5989
6654
  return str(gs_model)
@@ -6040,7 +6705,7 @@ def get_workers_string() -> str:
6040
6705
  )
6041
6706
 
6042
6707
  def _get_workers_string_collect_stats() -> dict:
6043
- stats = {}
6708
+ stats: dict = {}
6044
6709
  for job, _ in global_vars["jobs"][:]:
6045
6710
  state = state_from_job(job)
6046
6711
  stats[state] = stats.get(state, 0) + 1
@@ -6089,8 +6754,8 @@ def submitted_jobs(nr: int = 0) -> int:
6089
6754
  def count_jobs_in_squeue() -> tuple[int, str]:
6090
6755
  global _last_count_time, _last_count_result
6091
6756
 
6092
- now = time.time()
6093
- if _last_count_result != (0, "") and now - _last_count_time < 15:
6757
+ now = int(time.time())
6758
+ if _last_count_result != (0, "") and now - _last_count_time < 5:
6094
6759
  return _last_count_result
6095
6760
 
6096
6761
  _len = len(global_vars["jobs"])
@@ -6112,6 +6777,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
6112
6777
  check=True,
6113
6778
  text=True
6114
6779
  )
6780
+
6115
6781
  if "slurm_load_jobs error" in result.stderr:
6116
6782
  _last_count_result = (_len, "Detected slurm_load_jobs error in stderr.")
6117
6783
  _last_count_time = now
@@ -6152,6 +6818,8 @@ def log_worker_numbers() -> None:
6152
6818
  if len(WORKER_PERCENTAGE_USAGE) == 0 or WORKER_PERCENTAGE_USAGE[len(WORKER_PERCENTAGE_USAGE) - 1] != this_values:
6153
6819
  WORKER_PERCENTAGE_USAGE.append(this_values)
6154
6820
 
6821
+ write_worker_usage()
6822
+
6155
6823
  def get_slurm_in_brackets(in_brackets: list) -> list:
6156
6824
  if is_slurm_job():
6157
6825
  workers_strings = get_workers_string()
@@ -6236,6 +6904,8 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
6236
6904
  global last_progress_bar_desc
6237
6905
  global last_progress_bar_refresh_time
6238
6906
 
6907
+ log_data()
6908
+
6239
6909
  if isinstance(new_msgs, str):
6240
6910
  new_msgs = [new_msgs]
6241
6911
 
@@ -6253,7 +6923,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
6253
6923
  print_red("Cannot update progress bar! It is None.")
6254
6924
 
6255
6925
  def clean_completed_jobs() -> None:
6256
- job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
6926
+ job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
6257
6927
  job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]
6258
6928
 
6259
6929
  for job, trial_index in global_vars["jobs"][:]:
@@ -6311,7 +6981,7 @@ def load_existing_job_data_into_ax_client() -> None:
6311
6981
  nr_of_imported_jobs = get_nr_of_imported_jobs()
6312
6982
  set_nr_inserted_jobs(NR_INSERTED_JOBS + nr_of_imported_jobs)
6313
6983
 
6314
- def parse_parameter_type_error(_error_message: Union[str, None]) -> Optional[dict]:
6984
+ def parse_parameter_type_error(_error_message: Union[Exception, str, None]) -> Optional[dict]:
6315
6985
  if not _error_message:
6316
6986
  return None
6317
6987
 
@@ -6378,7 +7048,7 @@ def get_generation_node_for_index(
6378
7048
  results_list: List[Dict[str, Any]],
6379
7049
  index: int,
6380
7050
  __status: Any,
6381
- base_str: str
7051
+ base_str: Optional[str]
6382
7052
  ) -> str:
6383
7053
  __status.update(f"{base_str}: Getting generation node")
6384
7054
  try:
@@ -6398,7 +7068,7 @@ def get_generation_node_for_index(
6398
7068
 
6399
7069
  return generation_node
6400
7070
  except Exception as e:
6401
- print(f"Error while get_generation_node_for_index: {e}")
7071
+ print_red(f"Error while get_generation_node_for_index: {e}")
6402
7072
  return "MANUAL"
6403
7073
 
6404
7074
  def _get_generation_node_for_index_index_valid(
@@ -6461,111 +7131,154 @@ def _get_generation_node_for_index_floats_match(
6461
7131
  return False
6462
7132
  return abs(row_val_num - val) <= tolerance
6463
7133
 
7134
+ def validate_and_convert_params_for_jobs_from_csv(arm_params: Dict) -> Dict:
7135
+ corrected_params: Dict[Any, Any] = {}
7136
+
7137
+ if experiment_parameters is not None:
7138
+ for param in experiment_parameters:
7139
+ name = param["name"]
7140
+ expected_type = param.get("value_type", "str")
7141
+
7142
+ if name not in arm_params:
7143
+ continue
7144
+
7145
+ value = arm_params[name]
7146
+
7147
+ try:
7148
+ if param["type"] == "range":
7149
+ if expected_type == "int":
7150
+ corrected_params[name] = int(value)
7151
+ elif expected_type == "float":
7152
+ corrected_params[name] = float(value)
7153
+ elif param["type"] == "choice":
7154
+ corrected_params[name] = str(value)
7155
+ except (ValueError, TypeError):
7156
+ corrected_params[name] = None
7157
+
7158
+ return corrected_params
7159
+
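Illustrative coercion behaviour, assuming a hypothetical parameter definition in experiment_parameters:

# experiment_parameters = [{"name": "epochs", "type": "range", "value_type": "int"}]
# validate_and_convert_params_for_jobs_from_csv({"epochs": "12"}) -> {"epochs": 12}
# values that cannot be parsed are mapped to None instead of raising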
6464
7160
  def insert_jobs_from_csv(this_csv_file_path: str) -> None:
6465
7161
  with spinner(f"Inserting job into CSV from {this_csv_file_path}") as __status:
6466
- this_csv_file_path = this_csv_file_path.replace("//", "/")
7162
+ this_csv_file_path = normalize_path(this_csv_file_path)
6467
7163
 
6468
- if not os.path.exists(this_csv_file_path):
7164
+ if not helpers.file_exists(this_csv_file_path):
6469
7165
  print_red(f"--load_data_from_existing_jobs: Cannot find {this_csv_file_path}")
6470
-
6471
7166
  return
6472
7167
 
6473
- def validate_and_convert_params(arm_params: Dict) -> Dict:
6474
- corrected_params: Dict[Any, Any] = {}
7168
+ arm_params_list, results_list = parse_csv(this_csv_file_path)
7169
+ insert_jobs_from_lists(this_csv_file_path, arm_params_list, results_list, __status)
6475
7170
 
6476
- if experiment_parameters is not None:
6477
- for param in experiment_parameters:
6478
- name = param["name"]
6479
- expected_type = param.get("value_type", "str")
7171
+ def normalize_path(file_path: str) -> str:
7172
+ return file_path.replace("//", "/")
6480
7173
 
6481
- if name not in arm_params:
6482
- continue
7174
+ def insert_jobs_from_lists(csv_path: str, arm_params_list: Any, results_list: Any, __status: Any) -> None:
7175
+ cnt = 0
7176
+ err_msgs: list = []
6483
7177
 
6484
- value = arm_params[name]
7178
+ for i, (arm_params, result) in enumerate(zip(arm_params_list, results_list)):
7179
+ base_str = f"[bold green]Loading job {i}/{len(results_list)} from {csv_path} into ax_client, result: {result}"
7180
+ __status.update(base_str)
6485
7181
 
6486
- try:
6487
- if param["type"] == "range":
6488
- if expected_type == "int":
6489
- corrected_params[name] = int(value)
6490
- elif expected_type == "float":
6491
- corrected_params[name] = float(value)
6492
- elif param["type"] == "choice":
6493
- corrected_params[name] = str(value)
6494
- except (ValueError, TypeError):
6495
- corrected_params[name] = None
6496
-
6497
- return corrected_params
7182
+ if not args.worker_generator_path:
7183
+ arm_params = validate_and_convert_params_for_jobs_from_csv(arm_params)
6498
7184
 
6499
- arm_params_list, results_list = parse_csv(this_csv_file_path)
7185
+ cnt = try_insert_job(csv_path, arm_params, result, i, arm_params_list, results_list, __status, base_str, cnt, err_msgs)
6500
7186
 
6501
- cnt = 0
7187
+ summarize_insertions(csv_path, cnt)
7188
+ update_global_job_counters(cnt)
6502
7189
 
6503
- err_msgs = []
7190
+ def try_insert_job(csv_path: str, arm_params: Dict, result: Any, i: int, arm_params_list: Any, results_list: Any, __status: Any, base_str: Optional[str], cnt: int, err_msgs: Optional[Union[str, list[str]]]) -> int:
7191
+ try:
7192
+ gen_node_name = get_generation_node_for_index(csv_path, arm_params_list, results_list, i, __status, base_str)
6504
7193
 
6505
- i = 0
6506
- for arm_params, result in zip(arm_params_list, results_list):
6507
- base_str = f"[bold green]Loading job {i}/{len(results_list)} from {this_csv_file_path} into ax_client, result: {result}"
6508
- __status.update(base_str)
6509
- if not args.worker_generator_path:
6510
- arm_params = validate_and_convert_params(arm_params)
7194
+ if not result:
7195
+ print_yellow("Encountered job without a result")
7196
+ return cnt
6511
7197
 
6512
- try:
6513
- gen_node_name = get_generation_node_for_index(this_csv_file_path, arm_params_list, results_list, i, __status, base_str)
7198
+ if insert_job_into_ax_client(arm_params, result, gen_node_name, __status, base_str):
7199
+ cnt += 1
7200
+ print_debug(f"Inserted one job from {csv_path}, arm_params: {arm_params}, results: {result}")
7201
+ else:
7202
+ print_red(f"Failed to insert one job from {csv_path}, arm_params: {arm_params}, results: {result}")
6514
7203
 
6515
- if len(result):
6516
- if insert_job_into_ax_client(arm_params, result, gen_node_name, __status, base_str):
6517
- cnt += 1
7204
+ except ValueError as e:
7205
+ err_msg = (
7206
+ f"Failed to insert job(s) from {csv_path} into ax_client. "
7207
+ f"This can happen when the csv file has different parameters or results as the main job one's "
7208
+ f"or other imported jobs. Error: {e}"
7209
+ )
6518
7210
 
6519
- print_debug(f"Inserted one job from {this_csv_file_path}, arm_params: {arm_params}, results: {result}")
6520
- else:
6521
- print_red(f"Failed to insert one job from {this_csv_file_path}, arm_params: {arm_params}, results: {result}")
6522
- else:
6523
- print_yellow("Encountered job without a result")
6524
- except ValueError as e:
6525
- err_msg = f"Failed to insert job(s) from {this_csv_file_path} into ax_client. This can happen when the csv file has different parameters or results as the main job one's or other imported jobs. Error: {e}"
7211
+ if err_msgs is None:
7212
+ print_red("try_insert_job: err_msgs was None")
7213
+ else:
7214
+ if isinstance(err_msgs, list):
6526
7215
  if err_msg not in err_msgs:
6527
7216
  print_red(err_msg)
6528
7217
  err_msgs.append(err_msg)
7218
+ elif isinstance(err_msgs, str):
7219
+ err_msgs += f"\n{err_msg}"
6529
7220
 
6530
- i = i + 1
7221
+ return cnt
6531
7222
 
6532
- if cnt:
6533
- if cnt == 1:
6534
- print_yellow(f"Inserted one job from {this_csv_file_path}")
6535
- else:
6536
- print_yellow(f"Inserted {cnt} jobs from {this_csv_file_path}")
7223
+ def summarize_insertions(csv_path: str, cnt: int) -> None:
7224
+ if cnt == 0:
7225
+ return
7226
+ if cnt == 1:
7227
+ print_yellow(f"Inserted one job from {csv_path}")
7228
+ else:
7229
+ print_yellow(f"Inserted {cnt} jobs from {csv_path}")
6537
7230
 
6538
- if not args.worker_generator_path:
6539
- set_max_eval(max_eval + cnt)
6540
- set_nr_inserted_jobs(NR_INSERTED_JOBS + cnt)
7231
+ def update_global_job_counters(cnt: int) -> None:
7232
+ if not args.worker_generator_path:
7233
+ set_max_eval(max_eval + cnt)
7234
+ set_nr_inserted_jobs(NR_INSERTED_JOBS + cnt)
6541
7235
 
6542
- def __insert_job_into_ax_client__update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
7236
+ def update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
6543
7237
  if __status and base_str:
6544
7238
  __status.update(f"{base_str}: {new_text}")
6545
7239
 
6546
- def __insert_job_into_ax_client__check_ax_client() -> None:
7240
+ def check_ax_client() -> None:
6547
7241
  if ax_client is None or not ax_client:
6548
7242
  _fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)
6549
7243
 
6550
- def __insert_job_into_ax_client__attach_trial(arm_params: dict) -> Tuple[Any, int]:
7244
+ def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
7245
+ if not ax_client:
7246
+ my_exit(101)
7247
+
7248
+ return None
7249
+
6551
7250
  new_trial = ax_client.attach_trial(arm_params)
7251
+
7252
+ return new_trial
7253
+
7254
+ def attach_trial(arm_params: dict) -> Tuple[Any, int]:
7255
+ if ax_client is None:
7256
+ raise RuntimeError("attach_trial: ax_client was empty")
7257
+
7258
+ new_trial = attach_ax_client_data(arm_params)
6552
7259
  if not isinstance(new_trial, tuple) or len(new_trial) < 2:
6553
7260
  raise RuntimeError("attach_trial didn't return the expected tuple")
6554
7261
  return new_trial
6555
7262
 
6556
- def __insert_job_into_ax_client__get_trial(trial_idx: int) -> Any:
7263
+ def get_trial_by_index(trial_idx: int) -> Any:
7264
+ if ax_client is None:
7265
+ raise RuntimeError("get_trial_by_index: ax_client was empty")
7266
+
6557
7267
  trial = ax_client.experiment.trials.get(trial_idx)
6558
7268
  if trial is None:
6559
7269
  raise RuntimeError(f"Trial with index {trial_idx} not found")
6560
7270
  return trial
6561
7271
 
6562
- def __insert_job_into_ax_client__create_generator_run(arm_params: dict, trial_idx: int, new_job_type: str) -> GeneratorRun:
7272
+ def create_generator_run(arm_params: dict, trial_idx: int, new_job_type: str) -> GeneratorRun:
6563
7273
  arm = Arm(parameters=arm_params, name=f'{trial_idx}_0')
6564
7274
  return GeneratorRun(arms=[arm], generation_node_name=new_job_type)
6565
7275
 
6566
- def __insert_job_into_ax_client__complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[Any], base_str: Optional[str]) -> None:
7276
+ def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[Any], base_str: Optional[str]) -> None:
7277
+ if ax_client is None:
7278
+ raise RuntimeError("complete_trial_if_result: ax_client was empty")
7279
+
6567
7280
  if f"{result}" != "":
6568
- __insert_job_into_ax_client__update_status(__status, base_str, "Completing trial")
7281
+ update_status(__status, base_str, "Completing trial")
6569
7282
  is_ok = True
6570
7283
 
6571
7284
  for keyname in result.keys():
@@ -6573,20 +7286,20 @@ def __insert_job_into_ax_client__complete_trial_if_result(trial_idx: int, result
6573
7286
  is_ok = False
6574
7287
 
6575
7288
  if is_ok:
6576
- ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
6577
- __insert_job_into_ax_client__update_status(__status, base_str, "Completed trial")
7289
+ complete_ax_client_trial(trial_idx, result)
7290
+ update_status(__status, base_str, "Completed trial")
6578
7291
  else:
6579
7292
  print_debug("Empty job encountered")
6580
7293
  else:
6581
- __insert_job_into_ax_client__update_status(__status, base_str, "Found trial without result. Not adding it.")
7294
+ update_status(__status, base_str, "Found trial without result. Not adding it.")
6582
7295
 
6583
- def __insert_job_into_ax_client__save_results_if_needed(__status: Optional[Any], base_str: Optional[str]) -> None:
7296
+ def save_results_if_needed(__status: Optional[Any], base_str: Optional[str]) -> None:
6584
7297
  if not args.worker_generator_path:
6585
- __insert_job_into_ax_client__update_status(__status, base_str, f"Saving {RESULTS_CSV_FILENAME}")
7298
+ update_status(__status, base_str, f"Saving {RESULTS_CSV_FILENAME}")
6586
7299
  save_results_csv()
6587
- __insert_job_into_ax_client__update_status(__status, base_str, f"Saved {RESULTS_CSV_FILENAME}")
7300
+ update_status(__status, base_str, f"Saved {RESULTS_CSV_FILENAME}")

- def __insert_job_into_ax_client__handle_type_error(e: Exception, arm_params: dict) -> bool:
+ def handle_insert_job_error(e: Exception, arm_params: dict) -> bool:
  parsed_error = parse_parameter_type_error(e)
  if parsed_error is not None:
  param = parsed_error["parameter_name"]
@@ -6609,37 +7322,37 @@ def insert_job_into_ax_client(
  result: dict,
  new_job_type: str = "MANUAL",
  __status: Optional[Any] = None,
- base_str: str = None
+ base_str: Optional[str] = None
  ) -> bool:
- __insert_job_into_ax_client__check_ax_client()
+ check_ax_client()

  done_converting = False
  while not done_converting:
  try:
- __insert_job_into_ax_client__update_status(__status, base_str, "Checking ax client")
+ update_status(__status, base_str, "Checking ax client")
  if ax_client is None:
  return False

- __insert_job_into_ax_client__update_status(__status, base_str, "Attaching new trial")
- _, new_trial_idx = __insert_job_into_ax_client__attach_trial(arm_params)
+ update_status(__status, base_str, "Attaching new trial")
+ _, new_trial_idx = attach_trial(arm_params)

- __insert_job_into_ax_client__update_status(__status, base_str, "Getting new trial")
- trial = __insert_job_into_ax_client__get_trial(new_trial_idx)
- __insert_job_into_ax_client__update_status(__status, base_str, "Got new trial")
+ update_status(__status, base_str, "Getting new trial")
+ trial = get_trial_by_index(new_trial_idx)
+ update_status(__status, base_str, "Got new trial")

- __insert_job_into_ax_client__update_status(__status, base_str, "Creating new arm")
- manual_generator_run = __insert_job_into_ax_client__create_generator_run(arm_params, new_trial_idx, new_job_type)
+ update_status(__status, base_str, "Creating new arm")
+ manual_generator_run = create_generator_run(arm_params, new_trial_idx, new_job_type)
  trial._generator_run = manual_generator_run
  fool_linter(trial._generator_run)

- __insert_job_into_ax_client__complete_trial_if_result(new_trial_idx, result, __status, base_str)
+ complete_trial_if_result(new_trial_idx, result, __status, base_str)
  done_converting = True

- __insert_job_into_ax_client__save_results_if_needed(__status, base_str)
+ save_results_if_needed(__status, base_str)
  return True

  except ax.exceptions.core.UnsupportedError as e:
- if not __insert_job_into_ax_client__handle_type_error(e, arm_params):
+ if not handle_insert_job_error(e, arm_params):
  break

  return False
@@ -6970,7 +7683,7 @@ def get_parameters_from_outfile(stdout_path: str) -> Union[None, dict, str]:
  if not args.tests:
  original_print(f"get_parameters_from_outfile: The file '{stdout_path}' was not found.")
  except Exception as e:
- print(f"get_parameters_from_outfile: There was an error: {e}")
+ print_red(f"get_parameters_from_outfile: There was an error: {e}")

  return None

@@ -6988,7 +7701,7 @@ def get_hostname_from_outfile(stdout_path: Optional[str]) -> Optional[str]:
  original_print(f"The file '{stdout_path}' was not found.")
  return None
  except Exception as e:
- print(f"There was an error: {e}")
+ print_red(f"There was an error: {e}")
  return None

  def add_to_global_error_list(msg: str) -> None:
@@ -7024,22 +7737,70 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:

  return None

- ax_client.log_trial_failure(trial_index=trial_index)
+ log_ax_client_trial_failure(trial_index)
  _trial.mark_failed(unsafe=True)
  except ValueError as e:
  print_debug(f"mark_trial_as_failed error: {e}")

  return None

- def _finish_job_core_helper_check_valid_result(result: Union[None, list, int, float, tuple]) -> bool:
+ def check_valid_result(result: Union[None, dict]) -> bool:
  possible_val_not_found_values = [
  VAL_IF_NOTHING_FOUND,
  -VAL_IF_NOTHING_FOUND,
  -99999999999999997168788049560464200849936328366177157906432,
  99999999999999997168788049560464200849936328366177157906432
  ]
- values_to_check = result if isinstance(result, list) else [result]
- return result is not None and all(r not in possible_val_not_found_values for r in values_to_check)
+
+ def flatten_values(obj: Any) -> Any:
+ values = []
+ try:
+ if isinstance(obj, dict):
+ for v in obj.values():
+ values.extend(flatten_values(v))
+ elif isinstance(obj, (list, tuple, set)):
+ for v in obj:
+ values.extend(flatten_values(v))
+ else:
+ values.append(obj)
+ except Exception as e:
+ print_red(f"Error while flattening values: {e}")
+ return values
+
+ if result is None:
+ return False
+
+ try:
+ all_values = flatten_values(result)
+ for val in all_values:
+ if val in possible_val_not_found_values:
+ return False
+ return True
+ except Exception as e:
+ print_red(f"Error while checking result validity: {e}")
+ return False
+
+ def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ trial = get_trial_by_index(trial_idx)
+
+ trial.update_trial_data(raw_data=result)
+
+ return None
+
+ def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
+
+ return None

  def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
  if ax_client is None:
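Why the `check_valid_result` rewrite above matters: the old `_finish_job_core_helper_check_valid_result` only inspected flat lists, while results can now arrive as nested dicts, so "nothing found" sentinel values hidden inside nested containers were missed. A minimal, self-contained sketch of the recursive idea (the sentinel constant here is an illustrative stand-in for the package's `VAL_IF_NOTHING_FOUND`, not its real definition):

```python
from typing import Any, List

# Illustrative stand-in for the package's VAL_IF_NOTHING_FOUND sentinel.
SENTINEL = 99999999999999997168788049560464200849936328366177157906432

def flatten_values(obj: Any) -> List[Any]:
    # Recursively collect scalar leaves from dicts, lists, tuples and sets.
    values: List[Any] = []
    if isinstance(obj, dict):
        for v in obj.values():
            values.extend(flatten_values(v))
    elif isinstance(obj, (list, tuple, set)):
        for v in obj:
            values.extend(flatten_values(v))
    else:
        values.append(obj)
    return values

def check_valid_result(result: Any) -> bool:
    # A result is invalid if it is None or any leaf equals a sentinel.
    if result is None:
        return False
    return all(v not in (SENTINEL, -SENTINEL) for v in flatten_values(result))

print(check_valid_result({"loss": [0.3, 0.1]}))  # True
print(check_valid_result({"loss": SENTINEL}))    # False
```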
@@ -7048,25 +7809,64 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -

  try:
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
- ax_client.complete_trial(trial_index=trial_index, raw_data=raw_result)
+ complete_ax_client_trial(trial_index, raw_result)
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
  except ax.exceptions.core.UnsupportedError as e:
  if f"{e}":
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
- ax_client.update_trial_data(trial_index=trial_index, raw_data=raw_result)
+ update_ax_client_trial(trial_index, raw_result)
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
  else:
  _fatal_error(f"Error completing trial: {e}", 234)

  return None

- def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: Union[float, int, tuple]) -> None:
+ def format_result_for_display(result: dict) -> str:
+ def safe_float(v: Any) -> str:
+ try:
+ if v is None:
+ return "None"
+ if isinstance(v, (int, float)):
+ if math.isnan(v):
+ return "NaN"
+ if math.isinf(v):
+ return "∞" if v > 0 else "-∞"
+ return f"{v:.6f}"
+ return str(v)
+ except Exception as e:
+ return f"<error: {e}>"
+
+ try:
+ if not isinstance(result, dict):
+ return safe_float(result)
+
+ parts = []
+ for key, val in result.items():
+ try:
+ if isinstance(val, (list, tuple)) and len(val) == 2:
+ main, sem = val
+ main_str = safe_float(main)
+ if sem is not None:
+ sem_str = safe_float(sem)
+ parts.append(f"{key}: {main_str} (SEM: {sem_str})")
+ else:
+ parts.append(f"{key}: {main_str}")
+ else:
+ parts.append(f"{key}: {safe_float(val)}")
+ except Exception as e:
+ parts.append(f"{key}: <error: {e}>")
+
+ return ", ".join(parts)
+ except Exception as e:
+ return f"<error formatting result: {e}>"
+
+ def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: dict) -> None:
  print_debug(f"Marking trial {_trial} as completed")
  _trial.mark_completed(unsafe=True)

  succeeded_jobs(1)

- progressbar_description(f"new result: {result}")
+ progressbar_description(f"new result: {format_result_for_display(result)}")
  update_progress_bar(1)

  save_results_csv()
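The progress-bar message above now goes through `format_result_for_display`, which renders each metric either as a plain value or as a `(mean, SEM)` pair. A rough condensed sketch of the expected behaviour, assuming dict-shaped raw results as handled above:

```python
import math

def format_result_for_display(result):
    # Condensed re-statement of the helper shown in the diff; assumes dict input.
    def safe_float(v):
        if v is None:
            return "None"
        if isinstance(v, (int, float)):
            if math.isnan(v):
                return "NaN"
            if math.isinf(v):
                return "∞" if v > 0 else "-∞"
            return f"{v:.6f}"
        return str(v)

    parts = []
    for key, val in result.items():
        if isinstance(val, (list, tuple)) and len(val) == 2 and val[1] is not None:
            # (mean, SEM) pair with a real SEM value
            parts.append(f"{key}: {safe_float(val[0])} (SEM: {safe_float(val[1])})")
        elif isinstance(val, (list, tuple)) and len(val) == 2:
            parts.append(f"{key}: {safe_float(val[0])}")
        else:
            parts.append(f"{key}: {safe_float(val)}")
    return ", ".join(parts)

print(format_result_for_display({"loss": (0.1234567, 0.01), "acc": 0.98}))
# -> loss: 0.123457 (SEM: 0.010000), acc: 0.980000
```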
@@ -7080,8 +7880,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
  if job:
  try:
  progressbar_description("job_failed")
- ax_client.log_trial_failure(trial_index=trial_index)
- _trial.mark_failed(unsafe=True)
+ log_ax_client_trial_failure(trial_index)
+ mark_trial_as_failed(trial_index, _trial)
  except Exception as e:
  print_red(f"\nERROR while trying to mark job as failure: {e}")
  job.cancel()
@@ -7098,16 +7898,16 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
  result = job.result()
  print_debug(f"finish_job_core: trial-index: {trial_index}, job.result(): {result}, state: {state_from_job(job)}")

- raw_result = result
- result_keys = list(result.keys())
- result = result[result_keys[0]]
  this_jobs_finished += 1

  if ax_client:
- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)
+
+ if _trial is None:
+ return 0

- if _finish_job_core_helper_check_valid_result(result):
- _finish_job_core_helper_complete_trial(trial_index, raw_result)
+ if check_valid_result(result):
+ _finish_job_core_helper_complete_trial(trial_index, result)

  try:
  _finish_job_core_helper_mark_success(_trial, result)
@@ -7115,15 +7915,17 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
  if len(arg_result_names) > 1 and count_done_jobs() > 1 and not job_calculate_pareto_front(get_current_run_folder(), True):
  print_red("job_calculate_pareto_front post job failed")
  except Exception as e:
- print(f"ERROR in line {get_line_info()}: {e}")
+ print_red(f"ERROR in line {get_line_info()}: {e}")
  else:
  _finish_job_core_helper_mark_failure(job, trial_index, _trial)
  else:
- _fatal_error("ax_client could not be found or used", 9)
+ _fatal_error("ax_client could not be found or used", 101)

  print_debug(f"finish_job_core: removing job {job}, trial_index: {trial_index}")
  global_vars["jobs"].remove((job, trial_index))

+ log_data()
+
  force_live_share()

  return this_jobs_finished
@@ -7136,11 +7938,14 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
  if job:
  try:
  progressbar_description("job_failed")
- _trial = ax_client.get_trial(trial_index)
- ax_client.log_trial_failure(trial_index=trial_index)
+ _trial = get_ax_client_trial(trial_index)
+ if _trial is None:
+ return None
+
+ log_ax_client_trial_failure(trial_index)
  mark_trial_as_failed(trial_index, _trial)
  except Exception as e:
- print(f"ERROR in line {get_line_info()}: {e}")
+ print_debug(f"ERROR in line {get_line_info()}: {e}")
  job.cancel()
  orchestrate_job(job, trial_index)

@@ -7155,10 +7960,12 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -

  def _finish_previous_jobs_helper_handle_exception(job: Any, trial_index: int, error: Exception) -> int:
  if "None for metric" in str(error):
- print_red(
- f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string."
- f"\nError: {error}\nJob-result: {job.result()}"
- )
+ err_msg = f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string.\nError: {error}\nJob-result: {job.result()}"
+
+ if count_done_jobs() == 0:
+ print_red(err_msg)
+ else:
+ print_debug(err_msg)
  else:
  print_red(f"\n⚠ {error}")

@@ -7177,7 +7984,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
  this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
  return this_jobs_finished

- def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, this_jobs_finished: int) -> int:
+ def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
+ job, trial_index = __args
+
+ this_jobs_finished = 0
  if job is None:
  print_debug(f"finish_previous_jobs: job {job} is None")
  return this_jobs_finished
@@ -7190,10 +8000,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t

  return this_jobs_finished

- def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
- job, trial_index = __args
- return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
-
  def finish_previous_jobs(new_msgs: List[str] = []) -> None:
  global JOBS_FINISHED

@@ -7206,13 +8012,10 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:

  jobs_copy = global_vars["jobs"][:]

- if len(jobs_copy) > 0:
- print_debug(f"jobs in finish_previous_jobs: {jobs_copy}")
-
- finishing_jobs_start_time = time.time()
+ #finishing_jobs_start_time = time.time()

  with ThreadPoolExecutor() as finish_job_executor:
- futures = [finish_job_executor.submit(_finish_previous_jobs_helper_wrapper, (job, trial_index)) for job, trial_index in jobs_copy]
+ futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]

  for future in as_completed(futures):
  try:
@@ -7220,17 +8023,15 @@
  except Exception as e:
  print_red(f"⚠ Exception in parallel job handling: {e}")

- finishing_jobs_end_time = time.time()
-
- finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
+ #finishing_jobs_end_time = time.time()

- print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
-
- save_results_csv()
+ #finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time

- save_checkpoint()
+ #print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")

  if this_jobs_finished > 0:
+ save_results_csv()
+ save_checkpoint()
  progressbar_description([*new_msgs, f"finished {this_jobs_finished} {'job' if this_jobs_finished == 1 else 'jobs'}"])

  JOBS_FINISHED += this_jobs_finished
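The refactored `finish_previous_jobs` now submits each `(job, trial_index)` tuple straight to the worker function instead of going through a separate wrapper. A self-contained sketch of this fan-out/collect pattern (the dict-based job objects here are stand-ins, not the package's submitit jobs):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Tuple

def check_and_process(args: Tuple[Any, int]) -> int:
    # Unpack the tuple inside the worker, as the refactor does; return 1
    # for each finished job, 0 otherwise.
    job, trial_index = args
    if job is None:
        return 0
    return 1 if job.get("done") else 0

jobs = [({"done": True}, 0), ({"done": False}, 1), ({"done": True}, 2)]

finished = 0
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(check_and_process, (job, idx)) for job, idx in jobs]
    for future in as_completed(futures):
        finished += future.result()

print(f"finished {finished} job(s)")  # finished 2 job(s)
```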
@@ -7373,11 +8174,15 @@ def is_already_in_defective_nodes(hostname: str) -> bool:
  return True
  except Exception as e:
  print_red(f"is_already_in_defective_nodes: Error reading the file {file_path}: {e}")
- return False

  return False

  def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
+ if submitit_executor is None:
+ print_red("submit_new_job: submitit_executor was None")
+
+ return None
+
  print_debug(f"Submitting new job for trial_index {trial_index}, parameters {parameters}")

  start = time.time()
@@ -7388,25 +8193,49 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:

  print_debug(f"Done submitting new job, took {elapsed} seconds")

+ log_data()
+
  return new_job

+ def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ try:
+ log_data()
+
+ return ax_client.get_trial(trial_index)
+ except KeyError:
+ error_without_print(f"get_ax_client_trial: trial_index {trial_index} failed")
+ return None
+
  def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
  if submitit_executor and ax_client:
  new_job = submit_new_job(parameters, trial_index)
- submitted_jobs(1)
+ if new_job:
+ submitted_jobs(1)

- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)

- try:
- _trial.mark_staged(unsafe=True)
- except Exception as e:
- print_debug(f"orchestrator_start_trial: error {e}")
- _trial.mark_running(unsafe=True, no_runner_required=True)
+ if _trial is not None:
+ try:
+ _trial.mark_staged(unsafe=True)
+ except Exception as e:
+ print_debug(f"orchestrator_start_trial: error {e}")
+ _trial.mark_running(unsafe=True, no_runner_required=True)

- print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
- global_vars["jobs"].append((new_job, trial_index))
+ print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
+ global_vars["jobs"].append((new_job, trial_index))
+ else:
+ print_red("Trial was none in orchestrator_start_trial")
+ else:
+ print_red("orchestrator_start_trial: Failed to start new job")
+ elif ax_client:
+ _fatal_error("submitit_executor could not be found properly", 9)
  else:
- _fatal_error("submitit_executor or ax_client could not be found properly", 9)
+ _fatal_error("ax_client could not be found properly", 101)

  def handle_exclude_node(stdout_path: str, hostname_from_out_file: Union[None, str]) -> None:
  stdout_path = check_alternate_path(stdout_path)
@@ -7425,7 +8254,7 @@ def handle_restart(stdout_path: str, trial_index: int) -> None:
  if parameters:
  orchestrator_start_trial(parameters, trial_index)
  else:
- print(f"Could not determine parameters from outfile {stdout_path} for restarting job")
+ print_red(f"Could not determine parameters from outfile {stdout_path} for restarting job")

  def check_alternate_path(path: str) -> str:
  if os.path.exists(path):
@@ -7512,7 +8341,7 @@ def execute_evaluation(_params: list) -> Optional[int]:
  print_debug(f"execute_evaluation({_params})")
  trial_index, parameters, trial_counter, phase = _params
  if not ax_client:
- _fatal_error("Failed to get ax_client", 9)
+ _fatal_error("Failed to get ax_client", 101)

  return None

@@ -7521,7 +8350,11 @@

  return None

- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)
+
+ if _trial is None:
+ error_without_print(f"execute_evaluation: _trial was not in execute_evaluation for params {_params}")
+ return None

  def mark_trial_stage(stage: str, error_msg: str) -> None:
  try:
@@ -7536,17 +8369,18 @@ def execute_evaluation(_params: list) -> Optional[int]:
  try:
  initialize_job_environment()
  new_job = submit_new_job(parameters, trial_index)
- submitted_jobs(1)
-
- print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
- global_vars["jobs"].append((new_job, trial_index))
+ if new_job:
+ submitted_jobs(1)

- mark_trial_stage("mark_running", "Marking the trial as running failed")
- trial_counter += 1
+ print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
+ global_vars["jobs"].append((new_job, trial_index))

- progressbar_description("started new job")
+ mark_trial_stage("mark_running", "Marking the trial as running failed")
+ trial_counter += 1

- save_results_csv()
+ progressbar_description("started new job")
+ else:
+ progressbar_description("Failed to start new job")
  except submitit.core.utils.FailedJobError as error:
  handle_failed_job(error, trial_index, new_job)
  trial_counter += 1
@@ -7588,7 +8422,7 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
  my_exit(144)

  if new_job is None:
- print_red("handle_failed_job: job is None")
+ print_debug("handle_failed_job: job is None")

  return None

@@ -7599,16 +8433,24 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_

  return None

+ def log_ax_client_trial_failure(trial_index: int) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return
+
+ ax_client.log_trial_failure(trial_index=trial_index)
+
  def cancel_failed_job(trial_index: int, new_job: Job) -> None:
  print_debug("Trying to cancel job that failed")
  if new_job:
  try:
  if ax_client:
- ax_client.log_trial_failure(trial_index=trial_index)
+ log_ax_client_trial_failure(trial_index)
  else:
  _fatal_error("ax_client not defined", 101)
  except Exception as e:
- print(f"ERROR in line {get_line_info()}: {e}")
+ print_red(f"ERROR in line {get_line_info()}: {e}")
  new_job.cancel()

  print_debug(f"cancel_failed_job: removing job {new_job}, trial_index: {trial_index}")
@@ -7644,10 +8486,12 @@ def show_debug_table_for_break_run_search(_name: str, _max_eval: Optional[int])
  ("failed_jobs()", failed_jobs()),
  ("count_done_jobs()", count_done_jobs()),
  ("_max_eval", _max_eval),
- ("progress_bar.total", progress_bar.total),
  ("NR_INSERTED_JOBS", NR_INSERTED_JOBS)
  ]

+ if progress_bar is not None:
+ rows.append(("progress_bar.total", progress_bar.total))
+
  for row in rows:
  table.add_row(str(row[0]), str(row[1]))

@@ -7660,6 +8504,8 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
  _submitted_jobs = submitted_jobs()
  _failed_jobs = failed_jobs()

+ log_data()
+
  max_failed_jobs = max_eval

  if args.max_failed_jobs is not None and args.max_failed_jobs > 0:
@@ -7687,11 +8533,11 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
  _ret = True

  if args.verbose_break_run_search_table:
- show_debug_table_for_break_run_search(_name, _max_eval, _ret)
+ show_debug_table_for_break_run_search(_name, _max_eval)

  return _ret

- def _calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
+ def calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
  """Calculates the number of jobs to retrieve."""
  return min(
  max_eval + simulated_jobs - count_done_jobs(),
@@ -7704,7 +8550,7 @@ def remove_extra_spaces(text: str) -> str:
  raise ValueError("Input must be a string")
  return re.sub(r'\s+', ' ', text).strip()

- def _get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
+ def get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
  """Generates the appropriate message for the number of trials being retrieved."""
  ret = ""
  if full_nr_of_jobs_to_get > 1:
@@ -7789,70 +8635,70 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
  print_red("get_batched_arms: ax_client was None")
  return []

- # Load the experiment state
  load_experiment_state()

- while len(batched_arms) != nr_of_jobs_to_get:
+ while len(batched_arms) < nr_of_jobs_to_get:
  if attempts > args.max_attempts_for_generation:
  print_debug(f"get_batched_arms: Stopped after {attempts} attempts: could not generate enough arms "
  f"(got {len(batched_arms)} out of {nr_of_jobs_to_get}).")
  break

- remaining = nr_of_jobs_to_get - len(batched_arms)
- print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
-
- #print("get pending observations")
- pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
- #print("got pending observations")
+ #print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting 1 more arm")

- #print("getting global_gs.gen()")
- batched_generator_run = global_gs.gen(
+ pending_observations = get_pending_observation_features(
  experiment=ax_client.experiment,
- n=remaining,
- pending_observations=pending_observations
+ include_out_of_design_points=True
  )
- #print(f"got global_gs.gen(): {batched_generator_run}")

- # Recursively unpack inline until flat
+ try:
+ #print_debug("getting global_gs.gen() with n=1")
+
+ batched_generator_run: Any = global_gs.gen(
+ experiment=ax_client.experiment,
+ n=1,
+ pending_observations=pending_observations,
+ )
+ print_debug(f"got global_gs.gen(): {batched_generator_run}")
+ except Exception as e:
+ print_debug(f"global_gs.gen failed: {e}")
+ traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
+ break
+
  depth = 0
  path = "batched_generator_run"
- while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) > 0:
- #print(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
+ while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
+ #print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
  batched_generator_run = batched_generator_run[0]
  path += "[0]"
  depth += 1

- #print(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
+ #print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")

- #print("got new arms")
- new_arms = batched_generator_run.arms
- #print(f"new_arms: {new_arms}")
+ new_arms = getattr(batched_generator_run, "arms", [])
  if not new_arms:
  print_debug("get_batched_arms: No new arms were generated in this attempt.")
  else:
- print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), wanted {nr_of_jobs_to_get}.")
+ print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), now at {len(batched_arms) + len(new_arms)} of {nr_of_jobs_to_get}.")
+ batched_arms.extend(new_arms)

- batched_arms.extend(new_arms)
  attempts += 1

  print_debug(f"get_batched_arms: Finished with {len(batched_arms)} arm(s) after {attempts} attempt(s).")

- save_results_csv()
-
  return batched_arms
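The rewritten loop above asks the generation strategy for one arm per attempt (`n=1`) and accumulates until the target count or an attempt limit is reached, instead of requesting the whole remainder in one call. A generic sketch of that accumulate-with-retry pattern, with `generate_one()` as a hypothetical stand-in for `global_gs.gen(...)`:

```python
import random
from typing import List, Optional

def generate_one() -> Optional[str]:
    # Stand-in for global_gs.gen(..., n=1); may fail to produce an arm.
    return f"arm_{random.randint(0, 999)}" if random.random() > 0.2 else None

def get_batched_arms(target: int, max_attempts: int = 20) -> List[str]:
    arms: List[str] = []
    attempts = 0
    while len(arms) < target:
        if attempts > max_attempts:
            # Give up gracefully with a partial batch instead of looping forever.
            break
        new_arm = generate_one()
        if new_arm is not None:
            arms.append(new_arm)
        attempts += 1
    return arms

print(get_batched_arms(5))
```

Note the `<` comparison in the loop condition: unlike the old `!=`, it cannot spin forever if a single attempt ever over-delivers arms.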

- def _fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple[Dict[int, Any], bool]:
+ def fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple[Dict[int, Any], bool]:
  die_101_if_no_ax_client_or_experiment_or_gs()

  if not ax_client:
- _fatal_error("ax_client was not defined", 9)
+ _fatal_error("ax_client was not defined", 101)

  if global_gs is None:
  _fatal_error("Global generation strategy is not set. This is a bug in OmniOpt2.", 107)

- return _generate_trials(nr_of_jobs_to_get, recursion)
+ return generate_trials(nr_of_jobs_to_get, recursion)

- def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
+ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
  trials_dict: Dict[int, Any] = {}
  trial_durations: List[float] = []

@@ -7866,7 +8712,6 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
  if cnt >= n:
  break

- # 🔹 Create a completely new arm so that Ax assigns the name
  try:
  arm = Arm(parameters=arm.parameters)
  except Exception as arm_err:
@@ -7874,16 +8719,14 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
  retries += 1
  continue

- print_debug(f"Fetching trial {cnt + 1}/{n}...")
- progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
+ progressbar_description(get_trials_message(cnt + 1, n, trial_durations))

  try:
- result = _create_and_handle_trial(arm)
+ result = create_and_handle_trial(arm)
  if result is not None:
  trial_index, trial_duration, trial_successful = result
-
  except TrialRejected as e:
- print_debug(f"Trial rejected: {e}")
+ print_debug(f"generate_trials: Trial rejected, error: {e}")
  retries += 1
  continue

@@ -7893,19 +8736,26 @@ def _generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
  cnt += 1
  trials_dict[trial_index] = arm.parameters

- save_results_csv()
+ finalized = finalize_generation(trials_dict, cnt, n, start_time)

- return _finalize_generation(trials_dict, cnt, n, start_time)
+ return finalized

  except Exception as e:
- return _handle_generation_failure(e, n, recursion)
+ return handle_generation_failure(e, n, recursion)

  class TrialRejected(Exception):
  pass

- def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
+ def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
+ try:
+ print_debug(f"[INFO] Marking trial {trial.index} ({trial.arm.name}) as abandoned, trial-index: {trial_index}. Reason: {reason}")
+ trial.mark_abandoned(reason)
+ except Exception as e:
+ print_red(f"[ERROR] Could not mark trial as abandoned: {e}")
+
+ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
  if ax_client is None:
- print_red("ax_client is None in _create_and_handle_trial")
+ print_red("ax_client is None in create_and_handle_trial")
  return None

  start = time.time()
@@ -7928,7 +8778,7 @@ def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
  arm = trial.arms[0]
  if deduplicated_arm(arm):
  print_debug(f"Duplicated arm: {arm}")
- trial.mark_abandoned(reason="Duplication detected")
+ mark_abandoned(trial, "Duplication detected", trial_index)
  raise TrialRejected("Duplicate arm.")

  arms_by_name_for_deduplication[arm.name] = arm
@@ -7936,8 +8786,8 @@ def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
  params = arm.parameters

  if not has_no_post_generation_constraints_or_matches_constraints(post_generation_constraints, params):
- print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned.")
- trial.mark_abandoned(reason="Post-Generation-Constraint failed")
+ print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned. Params: {params}, constraints: {post_generation_constraints}")
+ mark_abandoned(trial, "Post-Generation-Constraint failed", trial_index)
  abandoned_trial_indices.append(trial_index)
  raise TrialRejected("Post-generation constraints not met.")

@@ -7945,7 +8795,7 @@ def _create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
  end = time.time()
  return trial_index, float(end - start), True

- def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
+ def finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
  total_time = time.time() - start_time

  log_gen_times.append(total_time)
@@ -7956,7 +8806,7 @@ def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int,

  return trials_dict, False

- def _handle_generation_failure(
+ def handle_generation_failure(
  e: Exception,
  requested: int,
  recursion: bool
@@ -7972,19 +8822,19 @@ def _handle_generation_failure(
  )):
  msg = str(e)
  if msg not in error_8_saved:
- _print_exhaustion_warning(e, recursion)
+ print_exhaustion_warning(e, recursion)
  error_8_saved.append(msg)

  if not recursion and args.revert_to_random_when_seemingly_exhausted:
  print_debug("Switching to random search strategy.")
- set_global_gs_to_random()
- return _fetch_next_trials(requested, True)
+ set_global_gs_to_sobol()
+ return fetch_next_trials(requested, True)

- print_red(f"_handle_generation_failure: General Exception: {e}")
+ print_red(f"handle_generation_failure: General Exception: {e}")

  return {}, True

- def _print_exhaustion_warning(e: Exception, recursion: bool) -> None:
+ def print_exhaustion_warning(e: Exception, recursion: bool) -> None:
  if not recursion and args.revert_to_random_when_seemingly_exhausted:
  print_yellow(f"\n⚠Error 8: {e} From now (done jobs: {count_done_jobs()}) on, random points will be generated.")
  else:
@@ -8029,21 +8879,24 @@ def get_model_gen_kwargs() -> dict:
  "fit_out_of_design": args.fit_out_of_design
  }

- def set_global_gs_to_random() -> None:
+ def set_global_gs_to_sobol() -> None:
  global global_gs
  global overwritten_to_random

+ print("Reverting to SOBOL")
+
  global_gs = GenerationStrategy(
  name="Random*",
  nodes=[
  GenerationNode(
- node_name="Sobol",
- generator_specs=[
- GeneratorSpec(
- Models.SOBOL,
- model_gen_kwargs=get_model_gen_kwargs()
- )
- ]
+ name="Sobol",
+ should_deduplicate=True,
+ generator_specs=[ # type: ignore[arg-type]
+ GeneratorSpec( # type: ignore[arg-type]
+ Generators.SOBOL, # type: ignore[arg-type]
+ model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
+ ) # type: ignore[arg-type]
+ ] # type: ignore[arg-type]
  )
  ]
  )
@@ -8261,14 +9114,14 @@ def _handle_linalg_error(error: Union[None, str, Exception]) -> None:
  """Handles the np.linalg.LinAlgError based on the model being used."""
  print_red(f"Error: {error}")

- def _get_next_trials(nr_of_jobs_to_get: int) -> Tuple[Union[None, dict], bool]:
- finish_previous_jobs(["finishing jobs (_get_next_trials)"])
+ def get_next_trials(nr_of_jobs_to_get: int) -> Tuple[Union[None, dict], bool]:
+ finish_previous_jobs(["finishing jobs (get_next_trials)"])

- if break_run_search("_get_next_trials", max_eval) or nr_of_jobs_to_get == 0:
+ if break_run_search("get_next_trials", max_eval) or nr_of_jobs_to_get == 0:
  return {}, True

  try:
- trial_index_to_param, optimization_complete = _fetch_next_trials(nr_of_jobs_to_get)
+ trial_index_to_param, optimization_complete = fetch_next_trials(nr_of_jobs_to_get)

  cf = currentframe()
  if cf:
@@ -8403,16 +9256,21 @@ def get_model_from_name(name: str) -> Any:
  return gen
  raise ValueError(f"Unknown or unsupported model: {name}")

- def get_name_from_model(model) -> Optional[str]:
+ def get_name_from_model(model: Any) -> str:
  if not isinstance(SUPPORTED_MODELS, (list, set, tuple)):
- return None
+ raise RuntimeError("get_model_from_name: SUPPORTED_MODELS was not a list, set or tuple. Cannot continue")

  model_str = model.value if hasattr(model, "value") else str(model)

  model_str_lower = model_str.lower()
  model_map = {m.lower(): m for m in SUPPORTED_MODELS}

- return model_map.get(model_str_lower, None)
+ ret = model_map.get(model_str_lower, None)
+
+ if ret is None:
+ raise RuntimeError("get_name_from_model: failed to get Model")
+
+ return ret
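The stricter `get_name_from_model` above now fails loudly instead of silently returning `None`. A standalone sketch of that case-insensitive lookup against a whitelist (the `SUPPORTED_MODELS` contents here are an illustrative subset, not the package's full list):

```python
from typing import Any

SUPPORTED_MODELS = ["SOBOL", "BOTORCH_MODULAR"]  # illustrative subset

def get_name_from_model(model: Any) -> str:
    # Normalize: enum-like objects expose .value, everything else is stringified.
    model_str = model.value if hasattr(model, "value") else str(model)
    model_map = {m.lower(): m for m in SUPPORTED_MODELS}
    ret = model_map.get(model_str.lower())
    if ret is None:
        # Raising here replaces the old "return None" behaviour, so callers
        # can no longer proceed with a missing model name unnoticed.
        raise RuntimeError("get_name_from_model: failed to get Model")
    return ret

print(get_name_from_model("sobol"))  # SOBOL
# get_name_from_model("unknown")     # would raise RuntimeError
```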

  def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str, int]], int]:
  gen_strat_list = []
@@ -8423,10 +9281,10 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,

  for s in splitted_by_comma:
  if "=" not in s:
- print(f"'{s}' does not contain '='")
+ print_red(f"'{s}' does not contain '='")
  my_exit(123)
  if s.count("=") != 1:
- print(f"There can only be one '=' in the gen_strat_str's element '{s}'")
+ print_red(f"There can only be one '=' in the gen_strat_str's element '{s}'")
  my_exit(123)

  model_name, nr_str = s.split("=")
@@ -8436,13 +9294,13 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
  _fatal_error(f"Model {matching_model} is not valid for custom generation strategy.", 56)

  if not matching_model:
- print(f"'{model_name}' not found in SUPPORTED_MODELS")
+ print_red(f"'{model_name}' not found in SUPPORTED_MODELS")
  my_exit(123)

  try:
  nr = int(nr_str)
  except ValueError:
- print(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
+ print_red(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
  my_exit(123)

  gen_strat_list.append({matching_model: nr})
@@ -8566,7 +9424,7 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
  if model_name == "TPE":
  if len(arg_result_names) != 1:
  _fatal_error(f"Has {len(arg_result_names)} results. TPE currently only supports single-objective-optimization.", 108)
- return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py", node_name="EXTERNAL_GENERATOR")
+ return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py", name="EXTERNAL_GENERATOR")

  external_generators = {
  "PSEUDORANDOM": f"python3 {script_dir}/.random_generator.py",
@@ -8580,10 +9438,10 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
  cmd = external_generators[model_name]
  if model_name == "EXTERNAL_GENERATOR" and not cmd:
  _fatal_error("--external_generator is missing. Cannot create points for EXTERNAL_GENERATOR without it.", 204)
- return ExternalProgramGenerationNode(external_generator=cmd, node_name="EXTERNAL_GENERATOR")
+ return ExternalProgramGenerationNode(external_generator=cmd, name="EXTERNAL_GENERATOR")

  trans_crit = [
- MaxTrials(
+ MinTrials(
  threshold=threshold,
  block_transition_if_unmet=True,
  transition_to=target_model,
@@ -8599,11 +9457,12 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
  if model_name.lower() != "sobol":
  kwargs["model_kwargs"] = get_model_kwargs()

- model_spec = [GeneratorSpec(selected_model, **kwargs)]
+ model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]

  res = GenerationNode(
- node_name=model_name,
+ name=model_name,
  generator_specs=model_spec,
+ should_deduplicate=True,
  transition_criteria=trans_crit
  )

@@ -8614,11 +9473,11 @@ def get_optimizer_kwargs() -> dict:
  "sequential": False
  }

- def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = None) -> GenerationStep:
+ def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
  model_enum = get_model_from_name(model_name)

  return GenerationStep(
- generator=model_enum, # ✅ new API
+ generator=model_enum,
  num_trials=_num_trials,
  max_parallelism=1000 * max_eval + 1000,
  model_kwargs=get_model_kwargs(),
@@ -8629,17 +9488,16 @@ def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = N
  )

  def set_global_generation_strategy() -> None:
- with spinner("Setting global generation strategy"):
- continue_not_supported_on_custom_generation_strategy()
+ continue_not_supported_on_custom_generation_strategy()

- try:
- if args.generation_strategy is None:
- setup_default_generation_strategy()
- else:
- setup_custom_generation_strategy()
- except Exception as e:
- print_red(f"Unexpected error in generation strategy setup: {e}")
- my_exit(111)
+ try:
+ if args.generation_strategy is None:
+ setup_default_generation_strategy()
+ else:
+ setup_custom_generation_strategy()
+ except Exception as e:
+ print_red(f"Unexpected error in generation strategy setup: {e}")
+ my_exit(111)

  if global_gs is None:
  print_red("global_gs is None after setup!")
@@ -8775,10 +9633,6 @@ def execute_trials(
  index_param_list.append(_args)
  i += 1

- save_results_csv()
-
- save_results_csv()
-
  start_time = time.time()

  cnt = 0
@@ -8794,10 +9648,14 @@ def execute_trials(
  result = future.result()
  print_debug(f"result in execute_trials: {result}")
  except Exception as exc:
- print_red(f"execute_trials: Error at executing a trial: {exc}")
+ failed_args = future_to_args[future]
+ print_red(f"execute_trials: Error at executing a trial with args {failed_args}: {exc}")
+ traceback.print_exc()

  end_time = time.time()

+ log_data()
+
  duration = float(end_time - start_time)
  job_submit_durations.append(duration)
  job_submit_nrs.append(cnt)
@@ -8831,20 +9689,20 @@ def create_and_execute_next_runs(next_nr_steps: int, phase: Optional[str], _max_
  done_optimizing: bool = False

  try:
- done_optimizing, trial_index_to_param = _create_and_execute_next_runs_run_loop(_max_eval, phase)
- _create_and_execute_next_runs_finish(done_optimizing)
+ done_optimizing, trial_index_to_param = create_and_execute_next_runs_run_loop(_max_eval, phase)
+ create_and_execute_next_runs_finish(done_optimizing)
  except Exception as e:
  stacktrace = traceback.format_exc()
  print_debug(f"Warning: create_and_execute_next_runs encountered an exception: {e}\n{stacktrace}")
  return handle_exceptions_create_and_execute_next_runs(e)

- return _create_and_execute_next_runs_return_value(trial_index_to_param)
+ return create_and_execute_next_runs_return_value(trial_index_to_param)

- def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
+ def create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
  done_optimizing = False
  trial_index_to_param: Optional[Dict] = None

- nr_of_jobs_to_get = _calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
+ nr_of_jobs_to_get = calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))

  __max_eval = _max_eval if _max_eval is not None else 0
  new_nr_of_jobs_to_get = min(__max_eval - (submitted_jobs() - failed_jobs()), nr_of_jobs_to_get)
@@ -8857,7 +9715,8 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
  get_next_trials_nr = new_nr_of_jobs_to_get

  for _ in range(range_nr):
- trial_index_to_param, done_optimizing = _get_next_trials(get_next_trials_nr)
+ trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
+ log_data()
  if done_optimizing:
  continue

@@ -8882,13 +9741,13 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti

  return done_optimizing, trial_index_to_param

- def _create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
+ def create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
  finish_previous_jobs(["finishing jobs"])

  if done_optimizing:
  end_program(False, 0)

- def _create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
+ def create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
  try:
  if trial_index_to_param:
  res = len(trial_index_to_param.keys())
@@ -8999,7 +9858,7 @@ def execute_nvidia_smi() -> None:
  if not host:
  print_debug("host not defined")
  except Exception as e:
- print(f"execute_nvidia_smi: An error occurred: {e}")
+ print_red(f"execute_nvidia_smi: An error occurred: {e}")
  if is_slurm_job() and not args.force_local_execution:
  _sleep(30)

@@ -9037,11 +9896,6 @@ def run_search() -> bool:

  return False

- async def start_logging_daemon() -> None:
- while True:
- log_data()
- time.sleep(30)
-
  def should_break_search() -> bool:
  ret = False

@@ -9105,7 +9959,7 @@ def finalize_jobs() -> None:
  handle_slurm_execution()

  def go_through_jobs_that_are_not_completed_yet() -> None:
- print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
+ #print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")

  nr_jobs_left = len(global_vars['jobs'])
  if nr_jobs_left == 1:
@@ -9175,7 +10029,7 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:

  return data
  except Exception as e:
- print(f"Error while parse_experiment_parameters({_f}): {e}")
+ print_red(f"Error while parse_experiment_parameters({_f}): {e}")
  else:
  print_red(f"{_f} could not be found")

@@ -9192,355 +10046,43 @@ def set_orchestrator() -> None:
  print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")

  def check_if_has_random_steps() -> None:
- if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
- _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
-
- def add_exclude_to_defective_nodes() -> None:
- with spinner("Adding excluded nodes..."):
- if args.exclude:
- entries = [entry.strip() for entry in args.exclude.split(',')]
-
- for entry in entries:
- count_defective_nodes(None, entry)
-
- def check_max_eval(_max_eval: int) -> None:
- with spinner("Checking max_eval..."):
- if not _max_eval:
- _fatal_error("--max_eval needs to be set!", 19)
-
- def parse_parameters() -> Any:
- cli_params_experiment_parameters = None
- if args.parameter:
- parse_experiment_parameters()
- cli_params_experiment_parameters = experiment_parameters
-
- return cli_params_experiment_parameters
-
- def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
- table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
-
- rows = _pareto_front_table_read_csv()
- if not rows:
- table.add_column("No data found")
- return table
-
- filtered_rows = _pareto_front_table_filter_rows(rows, idxs)
- if not filtered_rows:
- table.add_column("No matching entries")
- return table
-
- param_cols, result_cols = _pareto_front_table_get_columns(filtered_rows[0])
-
- _pareto_front_table_add_headers(table, param_cols, result_cols)
- _pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
-
- return table
-
- def _pareto_front_table_read_csv() -> List[Dict[str, str]]:
- with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
- return list(csv.DictReader(f))
-
- def _pareto_front_table_filter_rows(rows: List[Dict[str, str]], idxs: List[int]) -> List[Dict[str, str]]:
- result = []
- for row in rows:
- try:
- trial_index = int(row["trial_index"])
- except (KeyError, ValueError):
- continue
-
- if row.get("trial_status", "").strip().upper() == "COMPLETED" and trial_index in idxs:
- result.append(row)
- return result
-
- def _pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
- all_columns = list(first_row.keys())
- ignored_cols = set(special_col_names) - {"trial_index"}
-
- param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
- result_cols = [col for col in arg_result_names if col in all_columns]
- return param_cols, result_cols
-
- def _pareto_front_table_add_headers(table: Table, param_cols: List[str], result_cols: List[str]) -> None:
- for col in param_cols:
- table.add_column(col, justify="center")
- for col in result_cols:
- table.add_column(Text(f"{col}", style="cyan"), justify="center")
-
- def _pareto_front_table_add_rows(table: Table, rows: List[Dict[str, str]], param_cols: List[str], result_cols: List[str]) -> None:
- for row in rows:
- values = [str(helpers.to_int_when_possible(row[col])) for col in param_cols]
- result_values = [Text(str(helpers.to_int_when_possible(row[col])), style="cyan") for col in result_cols]
- table.add_row(*values, *result_values, style="bold green")
-
- def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
- if not os.path.exists(RESULT_CSV_FILE):
- print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
- return None
-
- return create_pareto_front_table(idxs, metric_x, metric_y)
-
- def supports_sixel() -> bool:
- term = os.environ.get("TERM", "").lower()
- if "xterm" in term or "mlterm" in term:
- return True
-
- try:
- output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
- if output.returncode == 0 and "sixel" in output.stdout.lower():
- return True
- except (subprocess.CalledProcessError, FileNotFoundError):
- pass
-
- return False
-
- def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
- if data is None:
- print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
- return
-
- if not supports_sixel():
- print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
- return
-
- import matplotlib.pyplot as plt
-
- means = data[x_metric][y_metric]["means"]
-
- x_values = means[x_metric]
- y_values = means[y_metric]
-
- fig, _ax = plt.subplots()
-
- _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
-
- _ax.set_xlabel(x_metric)
- _ax.set_ylabel(y_metric)
-
- _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
-
- _ax.ticklabel_format(style='plain', axis='both', useOffset=False)
-
- with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
- plt.savefig(tmp_file.name, dpi=300)
-
- print_image_to_cli(tmp_file.name, 1000)
-
- plt.close(fig)
-
- def _pareto_front_general_validate_shapes(x: np.ndarray, y: np.ndarray) -> None:
- if x.shape != y.shape:
- raise ValueError("Input arrays x and y must have the same shape.")
-
- def _pareto_front_general_compare(
- xi: float, yi: float, xj: float, yj: float,
- x_minimize: bool, y_minimize: bool
- ) -> bool:
- x_better_eq = xj <= xi if x_minimize else xj >= xi
- y_better_eq = yj <= yi if y_minimize else yj >= yi
- x_strictly_better = xj < xi if x_minimize else xj > xi
- y_strictly_better = yj < yi if y_minimize else yj > yi
-
- return bool(x_better_eq and y_better_eq and (x_strictly_better or y_strictly_better))
-
- def _pareto_front_general_find_dominated(
- x: np.ndarray, y: np.ndarray, x_minimize: bool, y_minimize: bool
- ) -> np.ndarray:
- num_points = len(x)
- is_dominated = np.zeros(num_points, dtype=bool)
-
- for i in range(num_points):
- for j in range(num_points):
- if i == j:
- continue
-
- if _pareto_front_general_compare(x[i], y[i], x[j], y[j], x_minimize, y_minimize):
- is_dominated[i] = True
- break
-
- return is_dominated
-
- def pareto_front_general(
- x: np.ndarray,
- y: np.ndarray,
- x_minimize: bool = True,
- y_minimize: bool = True
- ) -> np.ndarray:
- try:
- _pareto_front_general_validate_shapes(x, y)
- is_dominated = _pareto_front_general_find_dominated(x, y, x_minimize, y_minimize)
- return np.where(~is_dominated)[0]
- except Exception as e:
- print("Error in pareto_front_general:", str(e))
- return np.array([], dtype=int)
-
9376
- def _pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
9377
- results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
9378
- result_names_file = f"{path_to_calculate}/result_names.txt"
9379
-
9380
- if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
9381
- return None
9382
-
9383
- with open(result_names_file, mode="r", encoding="utf-8") as f:
9384
- result_names = [line.strip() for line in f if line.strip()]
9385
-
9386
- records: dict = defaultdict(lambda: {'means': {}})
9387
-
9388
- with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
9389
- reader = csv.DictReader(csvfile)
9390
- for row in reader:
9391
- trial_index = int(row['trial_index'])
9392
- arm_name = row['arm_name']
9393
- key = (trial_index, arm_name)
9394
-
9395
- for metric in result_names:
9396
- if metric in row:
9397
- try:
9398
- records[key]['means'][metric] = float(row[metric])
9399
- except ValueError:
9400
- continue
9401
-
9402
- return records
9403
-
9404
- def _pareto_front_filter_complete_points(
9405
- path_to_calculate: str,
9406
- records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
9407
- primary_name: str,
9408
- secondary_name: str
9409
- ) -> List[Tuple[Tuple[int, str], float, float]]:
9410
- points = []
9411
- for key, metrics in records.items():
9412
- means = metrics['means']
9413
- if primary_name in means and secondary_name in means:
9414
- x_val = means[primary_name]
9415
- y_val = means[secondary_name]
9416
- points.append((key, x_val, y_val))
9417
- if len(points) == 0:
9418
- raise ValueError(f"No full data points with both objectives found in {path_to_calculate}.")
9419
- return points
9420
-
9421
- def _pareto_front_transform_objectives(
9422
- points: List[Tuple[Any, float, float]],
9423
- primary_name: str,
9424
- secondary_name: str
9425
- ) -> Tuple[np.ndarray, np.ndarray]:
9426
- primary_idx = arg_result_names.index(primary_name)
9427
- secondary_idx = arg_result_names.index(secondary_name)
9428
-
9429
- x = np.array([p[1] for p in points])
9430
- y = np.array([p[2] for p in points])
9431
-
9432
- if arg_result_min_or_max[primary_idx] == "max":
9433
- x = -x
9434
- elif arg_result_min_or_max[primary_idx] != "min":
9435
- raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
9436
-
9437
- if arg_result_min_or_max[secondary_idx] == "max":
9438
- y = -y
9439
- elif arg_result_min_or_max[secondary_idx] != "min":
9440
- raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
9441
-
9442
- return x, y
9443
-
- def _pareto_front_select_pareto_points(
- x: np.ndarray,
- y: np.ndarray,
- x_minimize: bool,
- y_minimize: bool,
- points: List[Tuple[Any, float, float]],
- num_points: int
- ) -> List[Tuple[Any, float, float]]:
- indices = pareto_front_general(x, y, x_minimize, y_minimize)
- sorted_indices = indices[np.argsort(x[indices])]
- sorted_indices = sorted_indices[:num_points]
- selected_points = [points[i] for i in sorted_indices]
- return selected_points
-
- def _pareto_front_build_return_structure(
- path_to_calculate: str,
- selected_points: List[Tuple[Any, float, float]],
- records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
- absolute_metrics: List[str],
- primary_name: str,
- secondary_name: str
- ) -> dict:
- results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
- result_names_file = f"{path_to_calculate}/result_names.txt"
-
- with open(result_names_file, mode="r", encoding="utf-8") as f:
- result_names = [line.strip() for line in f if line.strip()]
-
- csv_rows = {}
- with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
- reader = csv.DictReader(csvfile)
- for row in reader:
- trial_index = int(row['trial_index'])
- csv_rows[trial_index] = row
-
- ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
- ignored_columns.update(result_names)
-
- param_dicts = []
- idxs = []
- means_dict = defaultdict(list)
-
- for (trial_index, arm_name), _, _ in selected_points:
- row = csv_rows.get(trial_index, {})
- if row == {} or row is None or row['arm_name'] != arm_name:
- print_debug(f"_pareto_front_build_return_structure: trial_index '{trial_index}' could not be found and row returned as None")
- continue
-
- idxs.append(int(row["trial_index"]))
-
- param_dict: dict[str, int | float | str] = {}
- for key, value in row.items():
- if key not in ignored_columns:
- try:
- param_dict[key] = int(value)
- except ValueError:
- try:
- param_dict[key] = float(value)
- except ValueError:
- param_dict[key] = value
+ if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
+ _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)

- param_dicts.append(param_dict)
+ def add_exclude_to_defective_nodes() -> None:
+ with spinner("Adding excluded nodes..."):
+ if args.exclude:
+ entries = [entry.strip() for entry in args.exclude.split(',')]

- for metric in absolute_metrics:
- means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
+ for entry in entries:
+ count_defective_nodes(None, entry)

- ret = {
- primary_name: {
- secondary_name: {
- "absolute_metrics": absolute_metrics,
- "param_dicts": param_dicts,
- "means": dict(means_dict),
- "idxs": idxs
- },
- "absolute_metrics": absolute_metrics
- }
- }
+ def check_max_eval(_max_eval: int) -> None:
+ with spinner("Checking max_eval..."):
+ if not _max_eval:
+ _fatal_error("--max_eval needs to be set!", 19)

- return ret
+ def parse_parameters() -> Any:
+ cli_params_experiment_parameters = None
+ if args.parameter:
+ parse_experiment_parameters()
+ cli_params_experiment_parameters = experiment_parameters

- def get_pareto_frontier_points(
- path_to_calculate: str,
- primary_objective: str,
- secondary_objective: str,
- x_minimize: bool,
- y_minimize: bool,
- absolute_metrics: List[str],
- num_points: int
- ) -> Optional[dict]:
- records = _pareto_front_aggregate_data(path_to_calculate)
+ return cli_params_experiment_parameters

- if records is None:
- return None
+ def supports_sixel() -> bool:
+ term = os.environ.get("TERM", "").lower()
+ if "xterm" in term or "mlterm" in term:
+ return True

- points = _pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
- x, y = _pareto_front_transform_objectives(points, primary_objective, secondary_objective)
- selected_points = _pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
- result = _pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
+ try:
+ output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
+ if output.returncode == 0 and "sixel" in output.stdout.lower():
+ return True
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ pass

- return result
+ return False

 def save_experiment_state() -> None:
 try:
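For context on the deleted block: `pareto_front_general` returned the indices of non-dominated points given two objective arrays. A self-contained sketch of such a 2-D dominance filter, assuming both objectives are minimized (the package's version additionally takes per-axis min/max flags and validates shapes):

```python
import numpy as np

def pareto_front_2d(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return indices of points not dominated by any other point.

    Point j dominates point i when it is no worse in both objectives
    and strictly better in at least one (both objectives minimized here).
    """
    n = len(x)
    dominated = np.zeros(n, dtype=bool)
    for i in range(n):
        better_or_equal = (x <= x[i]) & (y <= y[i])
        strictly_better = (x < x[i]) | (y < y[i])
        dominated[i] = np.any(better_or_equal & strictly_better)
    return np.where(~dominated)[0]

x = np.array([1.0, 2.0, 3.0])
y = np.array([3.0, 2.0, 1.0])
print(pareto_front_2d(x, y))  # [0 1 2] -- no point dominates another
```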
@@ -9548,14 +10090,14 @@ def save_experiment_state() -> None:
 print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
 return
 state_path = get_current_run_folder("experiment_state.json")
- ax_client.save_to_json_file(state_path)
+ save_ax_client_to_json_file(state_path)
 except Exception as e:
- print(f"Error saving experiment state: {e}")
+ print_debug(f"Error saving experiment state: {e}")

 def wait_for_state_file(state_path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
 try:
 if not os.path.exists(state_path):
- print(f"[ERROR] File '{state_path}' does not exist.")
+ print_debug(f"[ERROR] File '{state_path}' does not exist.")
 return False

 i = 0
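Only the head of `wait_for_state_file` is visible in this hunk; judging from its signature, it polls until the state file reaches a minimum size or a timeout elapses. A hedged sketch of that general pattern, not the package's exact loop:

```python
import os
import time

def wait_for_file(path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
    # Poll once per second until the file exists and holds at least
    # min_size bytes, giving up after max_wait_seconds.
    deadline = time.monotonic() + max_wait_seconds
    while time.monotonic() < deadline:
        if os.path.exists(path) and os.path.getsize(path) >= min_size:
            return True
        time.sleep(1)
    return False
```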
@@ -9610,7 +10152,6 @@ def load_experiment_state() -> None:
 state_path = get_current_run_folder("experiment_state.json")

 if not os.path.exists(state_path):
- print(f"State file {state_path} does not exist, starting fresh")
 return

 if args.worker_generator_path:
@@ -9624,10 +10165,10 @@ def load_experiment_state() -> None:
 return

 try:
- arms_seen = {}
+ arms_seen: dict = {}
 for arm in data.get("arms", []):
 name = arm.get("name")
- sig = arm.get("parameters") # rough signature
+ sig = arm.get("parameters")
 if not name:
 continue
 if name in arms_seen and arms_seen[name] != sig:
@@ -9636,7 +10177,6 @@ def load_experiment_state() -> None:
 arm["name"] = new_name
 arms_seen[name] = sig

- # save and load the filtered state
 temp_path = state_path + ".no_conflicts.json"
 with open(temp_path, encoding="utf-8", mode="w") as f:
 json.dump(data, f)
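The conflict-resolution loop above renames an arm when a previously seen arm of the same name carries a different parameter signature; how `new_name` is built falls outside this hunk. A hedged sketch of the same idea (the `_conflict_N` suffix scheme is an assumption for illustration, not the package's):

```python
def dedupe_arm_names(arms: list) -> list:
    # Rename arms whose name collides with an earlier arm that has a
    # different parameter signature; identical duplicates are left alone.
    seen: dict = {}
    for arm in arms:
        name, sig = arm.get("name"), arm.get("parameters")
        if not name:
            continue
        if name in seen and seen[name] != sig:
            suffix = 1
            while f"{name}_conflict_{suffix}" in seen:
                suffix += 1
            name = f"{name}_conflict_{suffix}"
            arm["name"] = name
        seen[name] = sig
    return arms

arms = [{"name": "a", "parameters": {"x": 1}}, {"name": "a", "parameters": {"x": 2}}]
print([a["name"] for a in dedupe_arm_names(arms)])  # ['a', 'a_conflict_1']
```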
@@ -9776,38 +10316,47 @@ def get_result_minimize_flag(path_to_calculate: str, resname: str) -> bool:

 return minmax[index] == "min"

- def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
- pareto_front_data: dict = {}
+ def post_job_calculate_pareto_front() -> None:
+ if not args.calculate_pareto_front_of_job:
+ return

- all_combinations = list(combinations(range(len(arg_result_names)), 2))
+ failure = False

- skip = False
+ _paths_to_calculate = []

- for i, j in all_combinations:
- if not skip:
- metric_x = arg_result_names[i]
- metric_y = arg_result_names[j]
+ for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
+ try:
+ found_paths = find_results_paths(_path_to_calculate)

- x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
- y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
+ for _fp in found_paths:
+ if _fp not in _paths_to_calculate:
+ _paths_to_calculate.append(_fp)
+ except (FileNotFoundError, NotADirectoryError) as e:
+ print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")

- try:
- if metric_x not in pareto_front_data:
- pareto_front_data[metric_x] = {}
+ failure = True

- pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
- except ax.exceptions.core.DataRequiredError as e:
- print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
- except SignalINT:
- print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
- skip = True
+ for _path_to_calculate in _paths_to_calculate:
+ for path_to_calculate in found_paths:
+ if not job_calculate_pareto_front(path_to_calculate):
+ failure = True

- return pareto_front_data
+ if failure:
+ my_exit(24)
+
+ my_exit(0)
+
+ def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
+ if not os.path.exists(RESULT_CSV_FILE):
+ print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
+ return None
+
+ return create_pareto_front_table(idxs, metric_x, metric_y)

 def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_sixel_and_table: bool = False) -> None:
 if len(res_names) <= 1:
 print_debug(f"--result_names (has {len(res_names)} entries) must be at least 2.")
- return
+ return None

 pareto_front_data: dict = get_pareto_front_data(path_to_calculate, res_names)
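The removed `get_pareto_front_data` walked every unordered pair of result names via `itertools.combinations` and computed one frontier per pair. A minimal illustration of that pairing (the metric names are made up):

```python
from itertools import combinations

result_names = ["loss", "latency", "energy"]

# One Pareto frontier is computed per unordered pair of objectives.
for i, j in combinations(range(len(result_names)), 2):
    print(result_names[i], "vs", result_names[j])
# loss vs latency
# loss vs energy
# latency vs energy
```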
 
@@ -9828,8 +10377,16 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
 else:
 print(f"Not showing Pareto-front-sixel for {path_to_calculate}")

- if len(calculated_frontier[metric_x][metric_y]["idxs"]):
- pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
+ if calculated_frontier is None:
+ print_debug("ERROR: calculated_frontier is None")
+ return None
+
+ try:
+ if len(calculated_frontier[metric_x][metric_y]["idxs"]):
+ pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
+ except AttributeError:
+ print_debug(f"ERROR: calculated_frontier structure invalid for ({metric_x}, {metric_y})")
+ return None

 rich_table = pareto_front_as_rich_table(
 calculated_frontier[metric_x][metric_y]["idxs"],
@@ -9854,6 +10411,8 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s

 live_share_after_pareto()

+ return None
+
 def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_color: str) -> None:
 cpu_count = os.cpu_count()

@@ -9865,9 +10424,11 @@ def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_
 pass

 if gpu_string:
- console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}] [green]{gs_string}[/green]")
+ console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]")
 else:
- print_green(f"You have {cpu_count} CPUs available for the main process. {gs_string}")
+ print_green(f"You have {cpu_count} CPUs available for the main process.")
+
+ print_green(gs_string)

 def write_args_overview_table() -> None:
 table = Table(title="Arguments Overview")
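Both `pareto_front_as_rich_table` and `write_args_overview_table` render rich tables. A minimal, self-contained example of that API (the column names and row values are invented for illustration):

```python
from rich.console import Console
from rich.table import Table

# Build a small two-column table in the style of "Arguments Overview".
table = Table(title="Arguments Overview")
table.add_column("Argument", style="cyan")
table.add_column("Value", style="green")
table.add_row("--max_eval", "100")
table.add_row("--num_random_steps", "20")

Console().print(table)
```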
@@ -10134,112 +10695,6 @@ def find_results_paths(base_path: str) -> list:

 return list(set(found_paths))

- def post_job_calculate_pareto_front() -> None:
- if not args.calculate_pareto_front_of_job:
- return
-
- failure = False
-
- _paths_to_calculate = []
-
- for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
- try:
- found_paths = find_results_paths(_path_to_calculate)
-
- for _fp in found_paths:
- if _fp not in _paths_to_calculate:
- _paths_to_calculate.append(_fp)
- except (FileNotFoundError, NotADirectoryError) as e:
- print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
-
- failure = True
-
- for _path_to_calculate in _paths_to_calculate:
- for path_to_calculate in found_paths:
- if not job_calculate_pareto_front(path_to_calculate):
- failure = True
-
- if failure:
- my_exit(24)
-
- my_exit(0)
-
- def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
- pf_start_time = time.time()
-
- if not path_to_calculate:
- return False
-
- global CURRENT_RUN_FOLDER
- global RESULT_CSV_FILE
- global arg_result_names
-
- if not path_to_calculate:
- print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
- return False
-
- if not os.path.exists(path_to_calculate):
- print_red(f"Path '{path_to_calculate}' does not exist")
- return False
-
- ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
-
- if not os.path.exists(ax_client_json):
- print_red(f"Path '{ax_client_json}' not found")
- return False
-
- checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
- if not os.path.exists(checkpoint_file):
- print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
- return False
-
- RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
- if not os.path.exists(RESULT_CSV_FILE):
- print_red(f"{RESULT_CSV_FILE} not found")
- return False
-
- res_names = []
-
- res_names_file = f"{path_to_calculate}/result_names.txt"
- if not os.path.exists(res_names_file):
- print_red(f"File '{res_names_file}' does not exist")
- return False
-
- try:
- with open(res_names_file, "r", encoding="utf-8") as file:
- lines = file.readlines()
- except Exception as e:
- print_red(f"Error reading file '{res_names_file}': {e}")
- return False
-
- for line in lines:
- entry = line.strip()
- if entry != "":
- res_names.append(entry)
-
- if len(res_names) < 2:
- print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
- return False
-
- load_username_to_args(path_to_calculate)
-
- CURRENT_RUN_FOLDER = path_to_calculate
-
- arg_result_names = res_names
-
- load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
-
- if experiment_parameters is None:
- return False
-
- show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
-
- pf_end_time = time.time()
-
- print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
-
- return True
-
 def set_arg_states_from_continue() -> None:
 if args.continue_previous_job and not args.num_random_steps:
 num_random_steps_file = f"{args.continue_previous_job}/state_files/num_random_steps"
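`job_calculate_pareto_front` (removed here, re-added earlier in the file) gates the computation on a fixed set of run artifacts. A compact sketch of that precondition check, using the artifact paths visible in the removed code and assuming `RESULTS_CSV_FILENAME` resolves to `results.csv`:

```python
import os

def required_run_artifacts(run_folder: str) -> list:
    # Paths mirror the existence checks in the removed function;
    # "results.csv" is an assumption about RESULTS_CSV_FILENAME.
    return [
        f"{run_folder}/state_files/ax_client.experiment.json",
        f"{run_folder}/state_files/checkpoint.json",
        f"{run_folder}/results.csv",
        f"{run_folder}/result_names.txt",
    ]

def missing_run_artifacts(run_folder: str) -> list:
    return [p for p in required_run_artifacts(run_folder) if not os.path.exists(p)]
```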
@@ -10273,7 +10728,7 @@ def write_result_names_file() -> None:
 except Exception as e:
 print_red(f"Error trying to open file '{fn}': {e}")

- def run_program_once(params=None) -> None:
+ def run_program_once(params: Optional[dict] = None) -> None:
 if not args.run_program_once:
 print_debug("[yellow]No setup script specified (run_program_once). Skipping setup.[/yellow]")
 return
@@ -10282,27 +10737,27 @@ def run_program_once(params=None) -> None:
 params = {}

 if isinstance(args.run_program_once, str):
- command_str = args.run_program_once
+ command_str = decode_if_base64(args.run_program_once)
 for k, v in params.items():
 placeholder = f"%({k})"
 command_str = command_str.replace(placeholder, str(v))

- with spinner(f"Executing command: [cyan]{command_str}[/cyan]"):
- result = subprocess.run(command_str, shell=True, check=True)
- if result.returncode == 0:
- console.log("[bold green]Setup script completed successfully ✅[/bold green]")
- else:
- console.log(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
+ print(f"Executing command: [cyan]{command_str}[/cyan]")
+ result = subprocess.run(command_str, shell=True, check=True)
+ if result.returncode == 0:
+ print("[bold green]Setup script completed successfully ✅[/bold green]")
+ else:
+ print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")

- my_exit(57)
+ my_exit(57)

 elif isinstance(args.run_program_once, (list, tuple)):
 with spinner("run_program_once: Executing command list: [cyan]{args.run_program_once}[/cyan]"):
 result = subprocess.run(args.run_program_once, check=True)
 if result.returncode == 0:
- console.log("[bold green]Setup script completed successfully ✅[/bold green]")
+ print("[bold green]Setup script completed successfully ✅[/bold green]")
 else:
- console.log(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
+ print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")

 my_exit(57)
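`run_program_once` substitutes `%(name)` placeholders in the setup command with parameter values before running it. (In passing: the unchanged spinner message in the list branch lacks an f-string prefix, so `{args.run_program_once}` is printed literally.) A small sketch of the substitution step, with invented command and parameters:

```python
def substitute_placeholders(command_str: str, params: dict) -> str:
    # Each %(key) token in the command is replaced by str(value),
    # mirroring the loop in run_program_once.
    for k, v in params.items():
        command_str = command_str.replace(f"%({k})", str(v))
    return command_str

print(substitute_placeholders("train.sh --lr=%(lr) --epochs=%(epochs)",
                              {"lr": 0.01, "epochs": 10}))
# train.sh --lr=0.01 --epochs=10
```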
 
@@ -10312,7 +10767,7 @@ def run_program_once(params=None) -> None:
 my_exit(57)

 def show_omniopt_call() -> None:
- def remove_ui_url(arg_str) -> str:
+ def remove_ui_url(arg_str: str) -> str:
 return re.sub(r'(?:--ui_url(?:=\S+)?(?:\s+\S+)?)', '', arg_str).strip()

 original_argv = " ".join(sys.argv[1:])
@@ -10320,8 +10775,14 @@ def show_omniopt_call() -> None:

 original_print(oo_call + " " + cleaned)

+ if args.dependency is not None and args.dependency != "":
+ print(f"Dependency: {args.dependency}")
+
+ if args.ui_url is not None and args.ui_url != "":
+ print_yellow("--ui_url is deprecated. Do not use it anymore. It will be ignored and one day be removed.")
+
 def main() -> None:
- global RESULT_CSV_FILE, ax_client, LOGFILE_DEBUG_GET_NEXT_TRIALS
+ global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS

 check_if_has_random_steps()

@@ -10372,7 +10833,7 @@ def main() -> None:
 print_run_info()

 initialize_nvidia_logs()
- write_ui_url_if_present()
+ write_ui_url()

 LOGFILE_DEBUG_GET_NEXT_TRIALS = get_current_run_folder('get_next_trials.csv')
 cli_params_experiment_parameters = parse_parameters()
@@ -10403,15 +10864,13 @@ def main() -> None:
 exp_params = get_experiment_parameters(cli_params_experiment_parameters)

 if exp_params is not None:
- ax_client, experiment_args, gpu_string, gpu_color = exp_params
+ experiment_args, gpu_string, gpu_color = exp_params
 print_debug(f"experiment_parameters: {experiment_parameters}")

 set_orchestrator()

 init_live_share()

- start_periodic_live_share()
-
 show_available_hardware_and_generation_strategy_string(gpu_string, gpu_color)

 original_print(f"Run-Program: {global_vars['joined_run_program']}")
@@ -10426,6 +10885,8 @@ def main() -> None:

 write_files_and_show_overviews()

+ live_share()
+
 #if args.continue_previous_job:
 # insert_jobs_from_csv(f"{args.continue_previous_job}/{RESULTS_CSV_FILENAME}")

@@ -10434,7 +10895,7 @@ def main() -> None:

 set_global_generation_strategy()

- start_worker_generators()
+ #start_worker_generators()

 try:
 run_search_with_progress_bar()
@@ -10500,10 +10961,93 @@ def initialize_nvidia_logs() -> None:
 global NVIDIA_SMI_LOGS_BASE
 NVIDIA_SMI_LOGS_BASE = get_current_run_folder('gpu_usage_')

- def write_ui_url_if_present() -> None:
- if args.ui_url:
- with open(get_current_run_folder("ui_url.txt"), mode="a", encoding="utf-8") as myfile:
- myfile.write(decode_if_base64(args.ui_url))
+ def build_gui_url(config: argparse.Namespace) -> str:
+ base_url = get_base_url()
+ params = collect_params(config)
+ ret = f"{base_url}?{urlencode(params, doseq=True)}"
+
+ return ret
+
+ def get_result_names_for_url(value: List) -> str:
+ d = dict(v.split("=", 1) if "=" in v else (v, "min") for v in value)
+ s = " ".join(f"{k}={v}" for k, v in d.items())
+
+ return s
+
+ def collect_params(config: argparse.Namespace) -> dict:
+ params = {}
+ user_home = os.path.expanduser("~")
+
+ for attr, value in vars(config).items():
+ if attr == "run_program":
+ params[attr] = global_vars["joined_run_program"]
+ elif attr == "result_names" and value:
+ params[attr] = get_result_names_for_url(value)
+ elif attr == "parameter" and value is not None:
+ params.update(process_parameters(config.parameter))
+ elif attr == "root_venv_dir":
+ if value is not None and os.path.abspath(value) != os.path.abspath(user_home):
+ params[attr] = value
+ elif isinstance(value, bool):
+ params[attr] = int(value)
+ elif isinstance(value, list):
+ params[attr] = value
+ elif value is not None:
+ params[attr] = value
+
+ return params
+
+ def process_parameters(parameters: list) -> dict:
+ params = {}
+ for i, param in enumerate(parameters):
+ if isinstance(param, dict):
+ name = param.get("name", f"param_{i}")
+ ptype = param.get("type", "unknown")
+ else:
+ name = param[0] if len(param) > 0 else f"param_{i}"
+ ptype = param[1] if len(param) > 1 else "unknown"
+
+ params[f"parameter_{i}_name"] = name
+ params[f"parameter_{i}_type"] = ptype
+
+ if ptype == "range":
+ params.update(process_range_parameter(i, param))
+ elif ptype == "choice":
+ params.update(process_choice_parameter(i, param))
+ elif ptype == "fixed":
+ params.update(process_fixed_parameter(i, param))
+
+ params["num_parameters"] = len(parameters)
+ return params
+
+ def process_range_parameter(i: int, param: list) -> dict:
+ return {
+ f"parameter_{i}_min": param[2] if len(param) > 3 else 0,
+ f"parameter_{i}_max": param[3] if len(param) > 3 else 1,
+ f"parameter_{i}_number_type": param[4] if len(param) > 4 else "float",
+ f"parameter_{i}_log_scale": "false",
+ }
+
+ def process_choice_parameter(i: int, param: list) -> dict:
+ choices = ""
+ if len(param) > 2 and param[2]:
+ choices = ",".join([c.strip() for c in str(param[2]).split(",")])
+ return {f"parameter_{i}_values": choices}
+
+ def process_fixed_parameter(i: int, param: list) -> dict:
+ return {f"parameter_{i}_value": param[2] if len(param) > 2 else ""}
+
+ def get_base_url() -> str:
+ file_path = Path.home() / ".oo_base_url"
+ if file_path.exists():
+ return file_path.read_text().strip()
+
+ return "https://imageseg.scads.de/omniax/"
+
+ def write_ui_url() -> None:
+ url = build_gui_url(args)
+ with open(get_current_run_folder("ui_url.txt"), mode="a", encoding="utf-8") as myfile:
+ myfile.write(url)

 def handle_random_steps() -> None:
 if args.parameter and args.continue_previous_job and random_steps <= 0:
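The new `build_gui_url` flattens the argparse namespace into query parameters and relies on `urlencode(..., doseq=True)` so that list values become repeated keys. A self-contained illustration of that encoding behavior (the base URL and parameter values here are invented; the real base comes from `get_base_url`):

```python
from urllib.parse import urlencode

params = {
    "partition": "alpha",
    "max_eval": 100,
    "calculate_pareto_front_of_job": ["runs/a", "runs/b"],  # list value
}

# doseq=True expands sequence values into one key=value pair per element.
print(f"https://example.invalid/gui?{urlencode(params, doseq=True)}")
# https://example.invalid/gui?partition=alpha&max_eval=100&calculate_pareto_front_of_job=runs%2Fa&calculate_pareto_front_of_job=runs%2Fb
```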
@@ -10562,8 +11106,6 @@ def run_search_with_progress_bar() -> None:
 wait_for_jobs_to_complete()

 def complex_tests(_program_name: str, wanted_stderr: str, wanted_exit_code: int, wanted_signal: Union[int, None], res_is_none: bool = False) -> int:
- #print_yellow(f"Test suite: {_program_name}")
-
 nr_errors: int = 0

 program_path: str = f"./.tests/test_wronggoing_stuff.bin/bin/{_program_name}"
@@ -10637,7 +11179,7 @@ def test_find_paths(program_code: str) -> int:
 for i in files:
 if i not in string:
 if os.path.exists(i):
- print("Missing {i} in find_file_paths string!")
+ print(f"Missing {i} in find_file_paths string!")
 nr_errors += 1

 return nr_errors
@@ -11018,17 +11560,16 @@ Exit-Code: 159

 my_exit(nr_errors)

- def main_outside() -> None:
+ def main_wrapper() -> None:
 print(f"Run-UUID: {run_uuid}")

 auto_wrap_namespace(globals())

 print_logo()

- start_logging_daemon()
-
 fool_linter(args.num_cpus_main_job)
 fool_linter(args.flame_graph)
+ fool_linter(args.memray)

 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
@@ -11060,9 +11601,34 @@ def main_outside() -> None:
 else:
 end_program(True)

+ def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
+ pattern = re.compile(regex) if regex else None
+
+ def wrapped(*args: Any, **kwargs: Any) -> None:
+ if pattern and not pattern.search(func.__name__):
+ return func(*args, **kwargs)
+
+ stack = inspect.stack()
+ chain = []
+ for frame in stack[1:]:
+ fn = frame.function
+ if fn in ("wrapped", "<module>"):
+ continue
+ chain.append(fn)
+
+ if chain:
+ sys.stderr.write(" ⇒ ".join(reversed(chain)) + "\n")
+
+ return func(*args, **kwargs)
+
+ return wrapped
+
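The new `stack_trace_wrapper` is a plain closure-based decorator: for functions whose name matches a regex, it writes the caller chain to stderr before delegating. A standalone sketch of the same pattern, without the regex filter and with invented function names:

```python
import inspect
import sys
from typing import Any

def trace_calls(func: Any) -> Any:
    # Print the chain of calling functions to stderr, then delegate.
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        chain = [f.function for f in inspect.stack()[1:]
                 if f.function not in ("wrapped", "<module>")]
        if chain:
            sys.stderr.write(" => ".join(reversed(chain)) + "\n")
        return func(*args, **kwargs)
    return wrapped

@trace_calls
def inner() -> int:
    return 42

def outer() -> int:
    return inner()

print(outer())  # stderr shows the caller chain ending in "outer"
```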
 def auto_wrap_namespace(namespace: Any) -> Any:
 enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))

+ if args.beartype:
+ enable_beartype = True
+
 excluded_functions = {
 "log_time_and_memory_wrapper",
 "collect_runtime_stats",
@@ -11071,8 +11637,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
 "_record_stats",
 "_open",
 "_check_memory_leak",
- "start_periodic_live_share",
- "start_logging_daemon",
 "get_current_run_folder",
 "show_func_name_wrapper"
 }
@@ -11089,14 +11653,16 @@ def auto_wrap_namespace(namespace: Any) -> Any:
 if args.show_func_name:
 wrapped = show_func_name_wrapper(wrapped)

+ if args.debug_stack_trace_regex:
+ wrapped = stack_trace_wrapper(wrapped, args.debug_stack_trace_regex)
+
 namespace[name] = wrapped

 return namespace

 if __name__ == "__main__":
-
 try:
- main_outside()
+ main_wrapper()
 except (SignalUSR, SignalINT, SignalCONT) as e:
- print_red(f"main_outside failed with exception {e}")
+ print_red(f"main_wrapper failed with exception {e}")
 end_program(True)