omniopt2 8366__py3-none-any.whl → 9061__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omniopt2 might be problematic. Click here for more details.

Files changed (48) hide show
  1. .gitignore +2 -0
  2. .helpers.py +0 -9
  3. .omniopt.py +1369 -1011
  4. .pareto.py +134 -0
  5. .shellscript_functions +2 -2
  6. .tests/pylint.rc +0 -4
  7. .tpe.py +4 -3
  8. omniopt +56 -27
  9. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.helpers.py +0 -9
  10. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt.py +1369 -1011
  11. omniopt2-9061.data/data/bin/.pareto.py +134 -0
  12. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.shellscript_functions +2 -2
  13. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.tpe.py +4 -3
  14. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt +56 -27
  15. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/pylint.rc +0 -4
  16. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/requirements.txt +1 -2
  17. {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/METADATA +2 -3
  18. omniopt2-9061.dist-info/RECORD +73 -0
  19. omniopt2.egg-info/PKG-INFO +2 -3
  20. omniopt2.egg-info/SOURCES.txt +1 -0
  21. omniopt2.egg-info/requires.txt +1 -2
  22. pyproject.toml +1 -1
  23. requirements.txt +1 -2
  24. omniopt2-8366.dist-info/RECORD +0 -71
  25. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.colorfunctions.sh +0 -0
  26. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.general.sh +0 -0
  27. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
  28. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_general.py +0 -0
  29. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
  30. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_kde.py +0 -0
  31. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_scatter.py +0 -0
  32. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
  33. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_scatter_hex.py +0 -0
  34. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
  35. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_trial_index_result.py +0 -0
  36. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.omniopt_plot_worker.py +0 -0
  37. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/.random_generator.py +0 -0
  38. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/LICENSE +0 -0
  39. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/apt-dependencies.txt +0 -0
  40. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_docker +0 -0
  41. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_evaluate +0 -0
  42. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_plot +0 -0
  43. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/omniopt_share +0 -0
  44. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/setup.py +0 -0
  45. {omniopt2-8366.data → omniopt2-9061.data}/data/bin/test_requirements.txt +0 -0
  46. {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/WHEEL +0 -0
  47. {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/licenses/LICENSE +0 -0
  48. {omniopt2-8366.dist-info → omniopt2-9061.dist-info}/top_level.txt +0 -0
.omniopt.py CHANGED
@@ -27,11 +27,15 @@ import inspect
27
27
  import tracemalloc
28
28
  import resource
29
29
  from urllib.parse import urlencode
30
-
31
30
  import psutil
32
31
 
33
32
  FORCE_EXIT: bool = False
34
33
 
34
+ LAST_LOG_TIME: int = 0
35
+ last_msg_progressbar = ""
36
+ last_msg_raw = None
37
+ last_lock_print_debug = threading.Lock()
38
+
35
39
  def force_exit(signal_number: Any, frame: Any) -> Any:
36
40
  global FORCE_EXIT
37
41
 
@@ -61,10 +65,10 @@ _last_count_time = 0
61
65
  _last_count_result: tuple[int, str] = (0, "")
62
66
 
63
67
  _total_time = 0.0
64
- _func_times = defaultdict(float)
65
- _func_mem = defaultdict(float)
66
- _func_call_paths = defaultdict(Counter)
67
- _last_mem = defaultdict(float)
68
+ _func_times: defaultdict = defaultdict(float)
69
+ _func_mem: defaultdict = defaultdict(float)
70
+ _func_call_paths: defaultdict = defaultdict(Counter)
71
+ _last_mem: defaultdict = defaultdict(float)
68
72
  _leak_threshold_mb = 10.0
69
73
  generation_strategy_names: list = []
70
74
  default_max_range_difference: int = 1000000
@@ -84,6 +88,7 @@ log_nr_gen_jobs: list[int] = []
84
88
  generation_strategy_human_readable: str = ""
85
89
  oo_call: str = "./omniopt"
86
90
  progress_bar_length: int = 0
91
+ worker_usage_file = 'worker_usage.csv'
87
92
 
88
93
  if os.environ.get("CUSTOM_VIRTUAL_ENV") == "1":
89
94
  oo_call = "omniopt"
@@ -99,7 +104,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
99
104
  SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
100
105
  joined_supported_models: str = ", ".join(SUPPORTED_MODELS)
101
106
 
102
- special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
107
+ special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]
103
108
 
104
109
  IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names
105
110
 
@@ -153,12 +158,6 @@ try:
153
158
  message="Ax currently requires a sqlalchemy version below 2.0.*",
154
159
  )
155
160
 
156
- warnings.filterwarnings(
157
- "ignore",
158
- category=RuntimeWarning,
159
- message="coroutine 'start_logging_daemon' was never awaited"
160
- )
161
-
162
161
  with spinner("Importing argparse..."):
163
162
  import argparse
164
163
 
@@ -204,6 +203,9 @@ try:
204
203
  with spinner("Importing rich.pretty..."):
205
204
  from rich.pretty import pprint
206
205
 
206
+ with spinner("Importing pformat..."):
207
+ from pprint import pformat
208
+
207
209
  with spinner("Importing rich.prompt..."):
208
210
  from rich.prompt import Prompt, FloatPrompt, IntPrompt
209
211
 
@@ -237,9 +239,6 @@ try:
237
239
  with spinner("Importing uuid..."):
238
240
  import uuid
239
241
 
240
- #with spinner("Importing qrcode..."):
241
- # import qrcode
242
-
243
242
  with spinner("Importing cowsay..."):
244
243
  import cowsay
245
244
 
@@ -333,7 +332,7 @@ def show_func_name_wrapper(func: F) -> F:
333
332
 
334
333
  return result
335
334
 
336
- return wrapper # type: ignore
335
+ return wrapper # type: ignore
337
336
 
338
337
  def log_time_and_memory_wrapper(func: F) -> F:
339
338
  @functools.wraps(func)
@@ -360,7 +359,7 @@ def log_time_and_memory_wrapper(func: F) -> F:
360
359
 
361
360
  return result
362
361
 
363
- return wrapper # type: ignore
362
+ return wrapper # type: ignore
364
363
 
365
364
  def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: float, mem_peak: float) -> None:
366
365
  global _total_time
@@ -381,16 +380,9 @@ def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: fl
381
380
  call_path_str = " -> ".join(short_stack)
382
381
  _func_call_paths[func_name][call_path_str] += 1
383
382
 
384
- print(
385
- f"Function '{func_name}' took {elapsed:.4f}s "
386
- f"(total {percent_if_added:.1f}% of tracked time)"
387
- )
388
- print(
389
- f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, "
390
- f"diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB"
391
- )
383
+ print(f"Function '{func_name}' took {elapsed:.4f}s (total {percent_if_added:.1f}% of tracked time)")
384
+ print(f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB")
392
385
 
393
- # NEU: Runtime Stats
394
386
  runtime_stats = collect_runtime_stats()
395
387
  print("=== Runtime Stats ===")
396
388
  print(f"RSS: {runtime_stats['rss_MB']:.2f} MB, VMS: {runtime_stats['vms_MB']:.2f} MB")
@@ -449,7 +441,7 @@ RESET: str = "\033[0m"
449
441
 
450
442
  uuid_regex: Pattern = re.compile(r"^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[89aAbB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$")
451
443
 
452
- worker_generator_uuid: str = uuid.uuid4()
444
+ worker_generator_uuid: str = str(uuid.uuid4())
453
445
 
454
446
  new_uuid: str = str(uuid.uuid4())
455
447
  run_uuid: str = os.getenv("RUN_UUID", new_uuid)
@@ -478,7 +470,7 @@ def get_current_run_folder(name: Optional[str] = None) -> str:
478
470
 
479
471
  return CURRENT_RUN_FOLDER
480
472
 
481
- def get_state_file_name(name) -> str:
473
+ def get_state_file_name(name: str) -> str:
482
474
  state_files_folder = f"{get_current_run_folder()}/state_files/"
483
475
  makedirs(state_files_folder)
484
476
 
@@ -502,6 +494,24 @@ try:
502
494
  dier: FunctionType = helpers.dier
503
495
  is_equal: FunctionType = helpers.is_equal
504
496
  is_not_equal: FunctionType = helpers.is_not_equal
497
+ with spinner("Importing pareto..."):
498
+ pareto_file: str = f"{script_dir}/.pareto.py"
499
+ spec = importlib.util.spec_from_file_location(
500
+ name="pareto",
501
+ location=pareto_file,
502
+ )
503
+ if spec is not None and spec.loader is not None:
504
+ pareto = importlib.util.module_from_spec(spec)
505
+ spec.loader.exec_module(pareto)
506
+ else:
507
+ raise ImportError(f"Could not load module from {pareto_file}")
508
+
509
+ pareto_front_table_filter_rows: FunctionType = pareto.pareto_front_table_filter_rows
510
+ pareto_front_table_add_headers: FunctionType = pareto.pareto_front_table_add_headers
511
+ pareto_front_table_add_rows: FunctionType = pareto.pareto_front_table_add_rows
512
+ pareto_front_filter_complete_points: FunctionType = pareto.pareto_front_filter_complete_points
513
+ pareto_front_select_pareto_points: FunctionType = pareto.pareto_front_select_pareto_points
514
+
505
515
  except KeyboardInterrupt:
506
516
  print("You pressed CTRL-c while importing the helpers file")
507
517
  sys.exit(0)
@@ -525,15 +535,13 @@ def is_slurm_job() -> bool:
525
535
  return True
526
536
  return False
527
537
 
528
- def _sleep(t: int) -> int:
538
+ def _sleep(t: Union[float, int]) -> None:
529
539
  if args is not None and not args.no_sleep:
530
540
  try:
531
541
  time.sleep(t)
532
542
  except KeyboardInterrupt:
533
543
  pass
534
544
 
535
- return t
536
-
537
545
  LOG_DIR: str = "logs"
538
546
  makedirs(LOG_DIR)
539
547
 
@@ -546,6 +554,17 @@ logfile_worker_creation_logs: str = f'{log_uuid_dir}_worker_creation_logs'
546
554
  logfile_trial_index_to_param_logs: str = f'{log_uuid_dir}_trial_index_to_param_logs'
547
555
  LOGFILE_DEBUG_GET_NEXT_TRIALS: Union[str, None] = None
548
556
 
557
+ def error_without_print(text: str) -> None:
558
+ print_debug(text)
559
+
560
+ if get_current_run_folder():
561
+ try:
562
+ with open(get_current_run_folder("oo_errors.txt"), mode="a", encoding="utf-8") as myfile:
563
+ myfile.write(text + "\n\n")
564
+ except (OSError, FileNotFoundError) as e:
565
+ helpers.print_color("red", f"Error: {e}. This may mean that the {get_current_run_folder()} was deleted during the run. Could not write '{text} to {get_current_run_folder()}/oo_errors.txt'")
566
+ sys.exit(99)
567
+
549
568
  def print_red(text: str) -> None:
550
569
  helpers.print_color("red", text)
551
570
 
@@ -579,20 +598,22 @@ def _debug(msg: str, _lvl: int = 0, eee: Union[None, str, Exception] = None) ->
579
598
  def _get_debug_json(time_str: str, msg: str) -> str:
580
599
  function_stack = []
581
600
  try:
582
- frame = inspect.currentframe().f_back # skip _get_debug_json
583
- while frame:
584
- func_name = _function_name_cache.get(frame.f_code)
585
- if func_name is None:
586
- func_name = frame.f_code.co_name
587
- _function_name_cache[frame.f_code] = func_name
588
-
589
- if func_name not in ("<module>", "print_debug", "wrapper"):
590
- function_stack.append({
591
- "function": func_name,
592
- "line_number": frame.f_lineno
593
- })
594
-
595
- frame = frame.f_back
601
+ cf = inspect.currentframe()
602
+ if cf:
603
+ frame = cf.f_back # skip _get_debug_json
604
+ while frame:
605
+ func_name = _function_name_cache.get(frame.f_code)
606
+ if func_name is None:
607
+ func_name = frame.f_code.co_name
608
+ _function_name_cache[frame.f_code] = func_name
609
+
610
+ if func_name not in ("<module>", "print_debug", "wrapper"):
611
+ function_stack.append({
612
+ "function": func_name,
613
+ "line_number": frame.f_lineno
614
+ })
615
+
616
+ frame = frame.f_back
596
617
  except (SignalUSR, SignalINT, SignalCONT):
597
618
  print_red("\n⚠ You pressed CTRL-C. This is ignored in _get_debug_json.")
598
619
 
@@ -694,11 +715,14 @@ def my_exit(_code: int = 0) -> None:
694
715
  if is_skip_search() and os.getenv("SKIP_SEARCH_EXIT_CODE"):
695
716
  skip_search_exit_code = os.getenv("SKIP_SEARCH_EXIT_CODE")
696
717
 
718
+ skip_search_exit_code_found = None
719
+
697
720
  try:
698
- sys.exit(int(skip_search_exit_code))
721
+ if skip_search_exit_code_found is not None:
722
+ skip_search_exit_code_found = int(skip_search_exit_code)
723
+ sys.exit(skip_search_exit_code_found)
699
724
  except ValueError:
700
- print(f"Trying to look for SKIP_SEARCH_EXIT_CODE failed. Exiting with original exit code {_code}")
701
- sys.exit(_code)
725
+ print_debug(f"Trying to look for SKIP_SEARCH_EXIT_CODE failed. Exiting with original exit code {_code}")
702
726
 
703
727
  sys.exit(_code)
704
728
 
@@ -799,6 +823,7 @@ class ConfigLoader:
799
823
  parameter: Optional[List[str]]
800
824
  experiment_constraints: Optional[List[str]]
801
825
  main_process_gb: int
826
+ beartype: bool
802
827
  worker_timeout: int
803
828
  slurm_signal_delay_s: int
804
829
  gridsearch: bool
@@ -819,6 +844,7 @@ class ConfigLoader:
819
844
  run_program_once: str
820
845
  mem_gb: int
821
846
  flame_graph: bool
847
+ memray: bool
822
848
  continue_previous_job: Optional[str]
823
849
  calculate_pareto_front_of_job: Optional[List[str]]
824
850
  revert_to_random_when_seemingly_exhausted: bool
@@ -980,6 +1006,7 @@ class ConfigLoader:
980
1006
  debug.add_argument('--verbose_break_run_search_table', help='Verbose logging for break_run_search', action='store_true', default=False)
981
1007
  debug.add_argument('--debug', help='Enable debugging', action='store_true', default=False)
982
1008
  debug.add_argument('--flame_graph', help='Enable flame-graphing. Makes everything slower, but creates a flame graph', action='store_true', default=False)
1009
+ debug.add_argument('--memray', help='Use memray to show memory usage', action='store_true', default=False)
983
1010
  debug.add_argument('--no_sleep', help='Disables sleeping for fast job generation (not to be used on HPC)', action='store_true', default=False)
984
1011
  debug.add_argument('--tests', help='Run simple internal tests', action='store_true', default=False)
985
1012
  debug.add_argument('--show_worker_percentage_table_at_end', help='Show a table of percentage of usage of max worker over time', action='store_true', default=False)
@@ -993,8 +1020,8 @@ class ConfigLoader:
993
1020
  debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
994
1021
  debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
995
1022
  debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
996
-
997
1023
  debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
1024
+ debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)
998
1025
 
999
1026
  def load_config(self: Any, config_path: str, file_format: str) -> dict:
1000
1027
  if not os.path.isfile(config_path):
@@ -1232,11 +1259,15 @@ for _rn in args.result_names:
1232
1259
  _key = _rn
1233
1260
  _min_or_max = __default_min_max
1234
1261
 
1262
+ _min_or_max = re.sub(r"'", "", _min_or_max)
1263
+
1235
1264
  if _min_or_max not in ["min", "max"]:
1236
1265
  if _min_or_max:
1237
1266
  print_yellow(f"Value for determining whether to minimize or maximize was neither 'min' nor 'max' for key '{_key}', but '{_min_or_max}'. It will be set to the default, which is '{__default_min_max}' instead.")
1238
1267
  _min_or_max = __default_min_max
1239
1268
 
1269
+ _key = re.sub(r"'", "", _key)
1270
+
1240
1271
  if _key in arg_result_names:
1241
1272
  console.print(f"[red]The --result_names option '{_key}' was specified multiple times![/]")
1242
1273
  sys.exit(50)
@@ -1343,19 +1374,6 @@ try:
1343
1374
  with spinner("Importing GeneratorSpec..."):
1344
1375
  from ax.generation_strategy.generator_spec import GeneratorSpec
1345
1376
 
1346
- #except Exception:
1347
- # with spinner("Fallback: Importing ax.generation_strategy.generation_node..."):
1348
- # import ax.generation_strategy.generation_node
1349
-
1350
- # with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
1351
- # from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
1352
-
1353
- # with spinner("Fallback: Importing ExternalGenerationNode..."):
1354
- # from ax.generation_strategy.external_generation_node import ExternalGenerationNode
1355
-
1356
- # with spinner("Fallback: Importing MaxTrials..."):
1357
- # from ax.generation_strategy.transition_criterion import MaxTrials
1358
-
1359
1377
  with spinner("Importing Models from ax.generation_strategy.registry..."):
1360
1378
  from ax.adapter.registry import Models
1361
1379
 
@@ -1463,6 +1481,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1463
1481
  def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
1464
1482
  search_space = experiment.search_space
1465
1483
  parameter_names = list(search_space.parameters.keys())
1484
+ if experiment.optimization_config is None:
1485
+ print_red("Error: update_generator_state is None")
1486
+ return
1466
1487
  metric_names = list(experiment.optimization_config.metrics.keys())
1467
1488
 
1468
1489
  completed_trials = [
@@ -1474,7 +1495,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1474
1495
  y = np.zeros([num_completed_trials, 1])
1475
1496
 
1476
1497
  for t_idx, trial in enumerate(completed_trials):
1477
- trial_parameters = trial.arm.parameters
1498
+ trial_parameters = trial.arms[t_idx].parameters
1478
1499
  x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
1479
1500
  trial_df = data.df[data.df["trial_index"] == trial.index]
1480
1501
  y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
@@ -1614,10 +1635,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1614
1635
  def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
1615
1636
  for name in best_sample.keys():
1616
1637
  param = self.parameters.get(name)
1638
+ best_sample_by_name = best_sample[name]
1639
+
1617
1640
  if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
1618
- best_sample[name] = int(round(best_sample[name]))
1641
+ if best_sample_by_name is not None:
1642
+ best_sample[name] = int(round(float(best_sample_by_name)))
1643
+ else:
1644
+ print_debug("best_sample_by_name was empty")
1619
1645
  elif isinstance(param, ChoiceParameter):
1620
- best_sample[name] = str(reverse_choice_map.get(int(best_sample[name])))
1646
+ if best_sample_by_name is not None:
1647
+ best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
1648
+ else:
1649
+ print_debug("best_sample_by_name was empty")
1621
1650
 
1622
1651
  decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode
1623
1652
 
@@ -2048,9 +2077,9 @@ def run_live_share_command(force: bool = False) -> Tuple[str, str]:
2048
2077
  return str(result.stdout), str(result.stderr)
2049
2078
  except subprocess.CalledProcessError as e:
2050
2079
  if e.stderr:
2051
- original_print(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
2080
+ print_debug(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
2052
2081
  else:
2053
- original_print(f"run_live_share_command: command failed with error: {e}")
2082
+ print_debug(f"run_live_share_command: command failed with error: {e}")
2054
2083
  return "", str(e.stderr)
2055
2084
  except Exception as e:
2056
2085
  print(f"run_live_share_command: An error occurred: {e}")
@@ -2064,6 +2093,8 @@ def force_live_share() -> bool:
2064
2093
  return False
2065
2094
 
2066
2095
  def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
2096
+ log_data()
2097
+
2067
2098
  if not get_current_run_folder():
2068
2099
  print(f"live_share: get_current_run_folder was empty or false: {get_current_run_folder()}")
2069
2100
  return False
@@ -2077,23 +2108,16 @@ def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
2077
2108
  if stderr:
2078
2109
  print_green(stderr)
2079
2110
  else:
2080
- print_red("This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
2111
+ print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
2081
2112
  if stdout:
2082
2113
  print_debug(f"live_share stdout: {stdout}")
2083
2114
 
2084
2115
  return True
2085
2116
 
2086
2117
  def init_live_share() -> bool:
2087
- with spinner("Initializing live share..."):
2088
- ret = live_share(True, True)
2089
-
2090
- return ret
2118
+ ret = live_share(True, True)
2091
2119
 
2092
- async def start_periodic_live_share() -> None:
2093
- if args.live_share and not os.environ.get("CI"):
2094
- while True:
2095
- live_share(force=False)
2096
- time.sleep(30)
2120
+ return ret
2097
2121
 
2098
2122
  def init_storage(db_url: str) -> None:
2099
2123
  init_engine_and_session_factory(url=db_url, force_init=True)
@@ -2119,7 +2143,11 @@ def try_saving_to_db() -> None:
2119
2143
  else:
2120
2144
  print_red("ax_client was not defined in try_saving_to_db")
2121
2145
  my_exit(101)
2122
- save_generation_strategy(global_gs)
2146
+
2147
+ if global_gs is not None:
2148
+ save_generation_strategy(global_gs)
2149
+ else:
2150
+ print_red("Not saving generation strategy: global_gs was empty")
2123
2151
  except Exception as e:
2124
2152
  print_debug(f"Failed trying to save sqlite3-DB: {e}")
2125
2153
 
@@ -2150,6 +2178,8 @@ def merge_with_job_infos(df: pd.DataFrame) -> pd.DataFrame:
2150
2178
  return merged
2151
2179
 
2152
2180
  def save_results_csv() -> Optional[str]:
2181
+ log_data()
2182
+
2153
2183
  if args.dryrun:
2154
2184
  return None
2155
2185
 
@@ -2164,6 +2194,9 @@ def save_results_csv() -> Optional[str]:
2164
2194
 
2165
2195
  try:
2166
2196
  df = fetch_and_prepare_trials()
2197
+ if df is None:
2198
+ print_red(f"save_results_csv: fetch_and_prepare_trials returned an empty element: {df}")
2199
+ return None
2167
2200
  write_csv(df, pd_csv)
2168
2201
  write_json_snapshot(pd_json)
2169
2202
  save_experiment_to_file()
@@ -2176,43 +2209,67 @@ def save_results_csv() -> Optional[str]:
2176
2209
  except (SignalUSR, SignalCONT, SignalINT) as e:
2177
2210
  raise type(e)(str(e)) from e
2178
2211
  except Exception as e:
2179
- print_red(f"While saving all trials as a pandas-dataframe-csv, an error occurred: {e}")
2212
+ print_red(f"\nWhile saving all trials as a pandas-dataframe-csv, an error occurred: {e}")
2180
2213
 
2181
2214
  return pd_csv
2182
2215
 
2183
2216
  def get_results_paths() -> tuple[str, str]:
2184
2217
  return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))
2185
2218
 
2186
- def fetch_and_prepare_trials() -> pd.DataFrame:
2187
- ax_client.experiment.fetch_data()
2188
- df = ax_client.get_trials_data_frame()
2219
+ def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
2220
+ if not ax_client:
2221
+ my_exit(101)
2222
+
2223
+ return None
2224
+
2225
+ return ax_client.get_trials_data_frame()
2226
+
2227
+ def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
2228
+ if not ax_client:
2229
+ return None
2189
2230
 
2190
- #print("========================")
2191
- #print("BEFORE merge_with_job_infos:")
2192
- #print(df["generation_node"])
2231
+ ax_client.experiment.fetch_data()
2232
+ df = ax_client_get_trials_data_frame()
2193
2233
  df = merge_with_job_infos(df)
2194
- #print("AFTER merge_with_job_infos:")
2195
- #print(df["generation_node"])
2196
2234
 
2197
2235
  return df
2198
2236
 
2199
- def write_csv(df, path: str) -> None:
2237
+ def write_csv(df: pd.DataFrame, path: str) -> None:
2200
2238
  try:
2201
2239
  df = df.sort_values(by=["trial_index"], kind="stable").reset_index(drop=True)
2202
2240
  except KeyError:
2203
2241
  pass
2204
2242
  df.to_csv(path, index=False, float_format="%.30f")
2205
2243
 
2206
- def write_json_snapshot(path: str) -> None:
2244
+ def ax_client_to_json_snapshot() -> Optional[dict]:
2245
+ if not ax_client:
2246
+ my_exit(101)
2247
+
2248
+ return None
2249
+
2207
2250
  json_snapshot = ax_client.to_json_snapshot()
2208
- with open(path, "w", encoding="utf-8") as f:
2209
- json.dump(json_snapshot, f, indent=4)
2251
+
2252
+ return json_snapshot
2253
+
2254
+ def write_json_snapshot(path: str) -> None:
2255
+ if ax_client is not None:
2256
+ json_snapshot = ax_client_to_json_snapshot()
2257
+ if json_snapshot is not None:
2258
+ with open(path, "w", encoding="utf-8") as f:
2259
+ json.dump(json_snapshot, f, indent=4)
2260
+ else:
2261
+ print_debug('json_snapshot from ax_client_to_json_snapshot was None')
2262
+ else:
2263
+ print_red("write_json_snapshot: ax_client was None")
2210
2264
 
2211
2265
  def save_experiment_to_file() -> None:
2212
- save_experiment(
2213
- ax_client.experiment,
2214
- get_state_file_name("ax_client.experiment.json")
2215
- )
2266
+ if ax_client is not None:
2267
+ save_experiment(
2268
+ ax_client.experiment,
2269
+ get_state_file_name("ax_client.experiment.json")
2270
+ )
2271
+ else:
2272
+ print_red("save_experiment: ax_client is None")
2216
2273
 
2217
2274
  def should_save_to_database() -> bool:
2218
2275
  return args.model not in uncontinuable_models and args.save_to_database
@@ -2401,7 +2458,7 @@ def set_nr_inserted_jobs(new_nr_inserted_jobs: int) -> None:
2401
2458
 
2402
2459
  def write_worker_usage() -> None:
2403
2460
  if len(WORKER_PERCENTAGE_USAGE):
2404
- csv_filename = get_current_run_folder('worker_usage.csv')
2461
+ csv_filename = get_current_run_folder(worker_usage_file)
2405
2462
 
2406
2463
  csv_columns = ['time', 'num_parallel_jobs', 'nr_current_workers', 'percentage']
2407
2464
 
@@ -2411,35 +2468,39 @@ def write_worker_usage() -> None:
2411
2468
  csv_writer.writerow(row)
2412
2469
  else:
2413
2470
  if is_slurm_job():
2414
- print_debug("WORKER_PERCENTAGE_USAGE seems to be empty. Not writing worker_usage.csv")
2471
+ print_debug(f"WORKER_PERCENTAGE_USAGE seems to be empty. Not writing {worker_usage_file}")
2415
2472
 
2416
2473
  def log_system_usage() -> None:
2474
+ global LAST_LOG_TIME
2475
+
2476
+ now = time.time()
2477
+ if now - LAST_LOG_TIME < 30:
2478
+ return
2479
+
2480
+ LAST_LOG_TIME = int(now)
2481
+
2417
2482
  if not get_current_run_folder():
2418
2483
  return
2419
2484
 
2420
2485
  ram_cpu_csv_file_path = os.path.join(get_current_run_folder(), "cpu_ram_usage.csv")
2421
-
2422
2486
  makedirs(os.path.dirname(ram_cpu_csv_file_path))
2423
2487
 
2424
2488
  file_exists = os.path.isfile(ram_cpu_csv_file_path)
2425
2489
 
2426
- with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
2427
- writer = csv.writer(file)
2428
-
2429
- current_time = int(time.time())
2430
-
2431
- if process is not None:
2432
- mem_proc = process.memory_info()
2433
-
2434
- if mem_proc is not None:
2435
- ram_usage_mb = mem_proc.rss / (1024 * 1024)
2436
- cpu_usage_percent = psutil.cpu_percent(percpu=False)
2490
+ mem_proc = process.memory_info() if process else None
2491
+ if not mem_proc:
2492
+ return
2437
2493
 
2438
- if ram_usage_mb > 0 and cpu_usage_percent > 0:
2439
- if not file_exists:
2440
- writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
2494
+ ram_usage_mb = mem_proc.rss / (1024 * 1024)
2495
+ cpu_usage_percent = psutil.cpu_percent(percpu=False)
2496
+ if ram_usage_mb <= 0 or cpu_usage_percent <= 0:
2497
+ return
2441
2498
 
2442
- writer.writerow([current_time, ram_usage_mb, cpu_usage_percent])
2499
+ with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
2500
+ writer = csv.writer(file)
2501
+ if not file_exists:
2502
+ writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
2503
+ writer.writerow([int(now), ram_usage_mb, cpu_usage_percent])
2443
2504
 
2444
2505
  def write_process_info() -> None:
2445
2506
  try:
@@ -2789,10 +2850,20 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
2789
2850
  log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
2790
2851
 
2791
2852
  def print_debug_progressbar(msg: str) -> None:
2792
- time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
2793
- msg = f"{time_str} ({worker_generator_uuid}): {msg}"
2853
+ global last_msg_progressbar, last_msg_raw
2854
+
2855
+ try:
2856
+ with last_lock_print_debug:
2857
+ if msg != last_msg_raw:
2858
+ time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
2859
+ full_msg = f"{time_str} ({worker_generator_uuid}): {msg}"
2794
2860
 
2795
- _debug_progressbar(msg)
2861
+ _debug_progressbar(full_msg)
2862
+
2863
+ last_msg_raw = msg
2864
+ last_msg_progressbar = full_msg
2865
+ except Exception as e:
2866
+ print(f"Error in print_debug_progressbar: {e}", flush=True)
2796
2867
 
2797
2868
  def get_process_info(pid: Any) -> str:
2798
2869
  try:
@@ -3297,135 +3368,457 @@ def parse_experiment_parameters() -> None:
3297
3368
  # Remove duplicates by 'name' key preserving order
3298
3369
  params = list({p['name']: p for p in params}.values())
3299
3370
 
3300
- experiment_parameters = params
3371
+ experiment_parameters = params # type: ignore[assignment]
3301
3372
 
3302
- def check_factorial_range() -> None:
3303
- if args.model and args.model == "FACTORIAL":
3304
- _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
3373
+ def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
3374
+ pf_start_time = time.time()
3305
3375
 
3306
- def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
3307
- if value_type not in valid_value_types:
3308
- valid_value_types_string = ", ".join(valid_value_types)
3309
- _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
3376
+ if not path_to_calculate:
3377
+ return False
3310
3378
 
3311
- def check_range_params_length(this_args: Union[str, list]) -> None:
3312
- if len(this_args) != 5 and len(this_args) != 4 and len(this_args) != 6:
3313
- _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
3379
+ global CURRENT_RUN_FOLDER
3380
+ global RESULT_CSV_FILE
3381
+ global arg_result_names
3314
3382
 
3315
- def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
3316
- if upper_bound is None or lower_bound is None:
3317
- _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
3318
- if upper_bound == lower_bound:
3319
- if lower_bound == 0:
3320
- _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
3321
- print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
3322
- if upper_bound is not None:
3323
- lower_bound = -upper_bound
3383
+ if not path_to_calculate:
3384
+ print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
3385
+ return False
3324
3386
 
3325
- def format_value(value: Any, float_format: str = '.80f') -> str:
3326
- try:
3327
- if isinstance(value, float):
3328
- s = format(value, float_format)
3329
- s = s.rstrip('0').rstrip('.') if '.' in s else s
3330
- return s
3331
- return str(value)
3332
- except Exception as e:
3333
- print_red(f"⚠ Error formatting the number {value}: {e}")
3334
- return str(value)
3387
+ if not os.path.exists(path_to_calculate):
3388
+ print_red(f"Path '{path_to_calculate}' does not exist")
3389
+ return False
3335
3390
 
3336
- def replace_parameters_in_string(
3337
- parameters: dict,
3338
- input_string: str,
3339
- float_format: str = '.20f',
3340
- additional_prefixes: list[str] = [],
3341
- additional_patterns: list[str] = [],
3342
- ) -> str:
3343
- try:
3344
- prefixes = ['$', '%'] + additional_prefixes
3345
- patterns = ['{key}', '({key})'] + additional_patterns
3391
+ ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
3346
3392
 
3347
- for key, value in parameters.items():
3348
- replacement = format_value(value, float_format=float_format)
3349
- for prefix in prefixes:
3350
- for pattern in patterns:
3351
- token = prefix + pattern.format(key=key)
3352
- input_string = input_string.replace(token, replacement)
3393
+ if not os.path.exists(ax_client_json):
3394
+ print_red(f"Path '{ax_client_json}' not found")
3395
+ return False
3353
3396
 
3354
- input_string = input_string.replace('\r', ' ').replace('\n', ' ')
3355
- return input_string
3397
+ checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
3398
+ if not os.path.exists(checkpoint_file):
3399
+ print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
3400
+ return False
3356
3401
 
3357
- except Exception as e:
3358
- print_red(f"\n⚠ Error: {e}")
3359
- return ""
3402
+ RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3403
+ if not os.path.exists(RESULT_CSV_FILE):
3404
+ print_red(f"{RESULT_CSV_FILE} not found")
3405
+ return False
3360
3406
 
3361
- def get_memory_usage() -> float:
3362
- user_uid = os.getuid()
3407
+ res_names = []
3363
3408
 
3364
- memory_usage = float(sum(
3365
- p.memory_info().rss for p in psutil.process_iter(attrs=['memory_info', 'uids'])
3366
- if p.info['uids'].real == user_uid
3367
- ) / (1024 * 1024))
3409
+ res_names_file = f"{path_to_calculate}/result_names.txt"
3410
+ if not os.path.exists(res_names_file):
3411
+ print_red(f"File '{res_names_file}' does not exist")
3412
+ return False
3368
3413
 
3369
- return memory_usage
3414
+ try:
3415
+ with open(res_names_file, "r", encoding="utf-8") as file:
3416
+ lines = file.readlines()
3417
+ except Exception as e:
3418
+ print_red(f"Error reading file '{res_names_file}': {e}")
3419
+ return False
3370
3420
 
3371
- class MonitorProcess:
3372
- def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
3373
- self.pid = pid
3374
- self.interval = interval
3375
- self.running = True
3376
- self.thread = threading.Thread(target=self._monitor)
3377
- self.thread.daemon = True
3421
+ for line in lines:
3422
+ entry = line.strip()
3423
+ if entry != "":
3424
+ res_names.append(entry)
3378
3425
 
3379
- fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
3426
+ if len(res_names) < 2:
3427
+ print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
3428
+ return False
3380
3429
 
3381
- def _monitor(self: Any) -> None:
3382
- try:
3383
- _internal_process = psutil.Process(self.pid)
3384
- while self.running and _internal_process.is_running():
3385
- crf = get_current_run_folder()
3430
+ load_username_to_args(path_to_calculate)
3386
3431
 
3387
- if crf and crf != "":
3388
- log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
3432
+ CURRENT_RUN_FOLDER = path_to_calculate
3389
3433
 
3390
- os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
3434
+ arg_result_names = res_names
3391
3435
 
3392
- with open(log_file_path, mode="a", encoding="utf-8") as log_file:
3393
- hostname = socket.gethostname()
3436
+ load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
3394
3437
 
3395
- slurm_job_id = os.getenv("SLURM_JOB_ID")
3438
+ if experiment_parameters is None:
3439
+ return False
3396
3440
 
3397
- if slurm_job_id:
3398
- hostname += f"-SLURM-ID-{slurm_job_id}"
3441
+ show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
3399
3442
 
3400
- total_memory = psutil.virtual_memory().total / (1024 * 1024)
3401
- cpu_usage = psutil.cpu_percent(interval=5)
3443
+ pf_end_time = time.time()
3402
3444
 
3403
- memory_usage = get_memory_usage()
3445
+ print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
3404
3446
 
3405
- unix_timestamp = int(time.time())
3447
+ return True
3406
3448
 
3407
- log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
3408
- time.sleep(self.interval)
3409
- except psutil.NoSuchProcess:
3410
- pass
3449
+ def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
3450
+ if args.dryrun:
3451
+ print_debug("Not showing Pareto-frontier data with --dryrun")
3452
+ return None
3411
3453
 
3412
- def __enter__(self: Any) -> None:
3413
- self.thread.start()
3414
- return self
3454
+ if len(res_names) > 1:
3455
+ try:
3456
+ show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
3457
+ except Exception as e:
3458
+ inner_tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
3459
+ print_red(f"show_pareto_frontier_data() failed with exception '{e}':\n{inner_tb}")
3460
+ else:
3461
+ print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
3462
+ return None
3415
3463
 
3416
- def __exit__(self: Any, exc_type: Any, exc_value: Any, _traceback: Any) -> None:
3417
- self.running = False
3418
- self.thread.join()
3464
+ def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
3465
+ pareto_front_data: dict = {}
3419
3466
 
3420
- def execute_bash_code_log_time(code: str) -> list:
3421
- process_item = subprocess.Popen(code, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
3467
+ all_combinations = list(combinations(range(len(arg_result_names)), 2))
3422
3468
 
3423
- with MonitorProcess(process_item.pid):
3424
- try:
3425
- stdout, stderr = process_item.communicate()
3426
- result = subprocess.CompletedProcess(
3427
- args=code, returncode=process_item.returncode, stdout=stdout, stderr=stderr
3428
- )
3469
+ skip = False
3470
+
3471
+ for i, j in all_combinations:
3472
+ if not skip:
3473
+ metric_x = arg_result_names[i]
3474
+ metric_y = arg_result_names[j]
3475
+
3476
+ x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
3477
+ y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
3478
+
3479
+ try:
3480
+ if metric_x not in pareto_front_data:
3481
+ pareto_front_data[metric_x] = {}
3482
+
3483
+ pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
3484
+ except ax.exceptions.core.DataRequiredError as e:
3485
+ print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
3486
+ except SignalINT:
3487
+ print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
3488
+ skip = True
3489
+
3490
+ return pareto_front_data
3491
+
3492
+ def pareto_front_transform_objectives(
3493
+ points: List[Tuple[Any, float, float]],
3494
+ primary_name: str,
3495
+ secondary_name: str
3496
+ ) -> Tuple[np.ndarray, np.ndarray]:
3497
+ primary_idx = arg_result_names.index(primary_name)
3498
+ secondary_idx = arg_result_names.index(secondary_name)
3499
+
3500
+ x = np.array([p[1] for p in points])
3501
+ y = np.array([p[2] for p in points])
3502
+
3503
+ if arg_result_min_or_max[primary_idx] == "max":
3504
+ x = -x
3505
+ elif arg_result_min_or_max[primary_idx] != "min":
3506
+ raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
3507
+
3508
+ if arg_result_min_or_max[secondary_idx] == "max":
3509
+ y = -y
3510
+ elif arg_result_min_or_max[secondary_idx] != "min":
3511
+ raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
3512
+
3513
+ return x, y
3514
+
3515
+ def get_pareto_frontier_points(
3516
+ path_to_calculate: str,
3517
+ primary_objective: str,
3518
+ secondary_objective: str,
3519
+ x_minimize: bool,
3520
+ y_minimize: bool,
3521
+ absolute_metrics: List[str],
3522
+ num_points: int
3523
+ ) -> Optional[dict]:
3524
+ records = pareto_front_aggregate_data(path_to_calculate)
3525
+
3526
+ if records is None:
3527
+ return None
3528
+
3529
+ points = pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
3530
+ x, y = pareto_front_transform_objectives(points, primary_objective, secondary_objective)
3531
+ selected_points = pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
3532
+ result = pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
3533
+
3534
+ return result
3535
+
3536
+ def pareto_front_table_read_csv() -> List[Dict[str, str]]:
3537
+ with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
3538
+ return list(csv.DictReader(f))
3539
+
3540
+ def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
3541
+ table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
3542
+
3543
+ rows = pareto_front_table_read_csv()
3544
+ if not rows:
3545
+ table.add_column("No data found")
3546
+ return table
3547
+
3548
+ filtered_rows = pareto_front_table_filter_rows(rows, idxs)
3549
+ if not filtered_rows:
3550
+ table.add_column("No matching entries")
3551
+ return table
3552
+
3553
+ param_cols, result_cols = pareto_front_table_get_columns(filtered_rows[0])
3554
+
3555
+ pareto_front_table_add_headers(table, param_cols, result_cols)
3556
+ pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
3557
+
3558
+ return table
3559
+
3560
+ def pareto_front_build_return_structure(
3561
+ path_to_calculate: str,
3562
+ selected_points: List[Tuple[Any, float, float]],
3563
+ records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
3564
+ absolute_metrics: List[str],
3565
+ primary_name: str,
3566
+ secondary_name: str
3567
+ ) -> dict:
3568
+ results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3569
+ result_names_file = f"{path_to_calculate}/result_names.txt"
3570
+
3571
+ with open(result_names_file, mode="r", encoding="utf-8") as f:
3572
+ result_names = [line.strip() for line in f if line.strip()]
3573
+
3574
+ csv_rows = {}
3575
+ with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
3576
+ reader = csv.DictReader(csvfile)
3577
+ for row in reader:
3578
+ trial_index = int(row['trial_index'])
3579
+ csv_rows[trial_index] = row
3580
+
3581
+ ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
3582
+ ignored_columns.update(result_names)
3583
+
3584
+ param_dicts = []
3585
+ idxs = []
3586
+ means_dict = defaultdict(list)
3587
+
3588
+ for (trial_index, arm_name), _, _ in selected_points:
3589
+ row = csv_rows.get(trial_index, {})
3590
+ if row == {} or row is None or row['arm_name'] != arm_name:
3591
+ continue
3592
+
3593
+ idxs.append(int(row["trial_index"]))
3594
+
3595
+ param_dict: dict[str, int | float | str] = {}
3596
+ for key, value in row.items():
3597
+ if key not in ignored_columns:
3598
+ try:
3599
+ param_dict[key] = int(value)
3600
+ except ValueError:
3601
+ try:
3602
+ param_dict[key] = float(value)
3603
+ except ValueError:
3604
+ param_dict[key] = value
3605
+
3606
+ param_dicts.append(param_dict)
3607
+
3608
+ for metric in absolute_metrics:
3609
+ means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
3610
+
3611
+ ret = {
3612
+ primary_name: {
3613
+ secondary_name: {
3614
+ "absolute_metrics": absolute_metrics,
3615
+ "param_dicts": param_dicts,
3616
+ "means": dict(means_dict),
3617
+ "idxs": idxs
3618
+ },
3619
+ "absolute_metrics": absolute_metrics
3620
+ }
3621
+ }
3622
+
3623
+ return ret
3624
+
3625
+ def pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
3626
+ results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3627
+ result_names_file = f"{path_to_calculate}/result_names.txt"
3628
+
3629
+ if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
3630
+ return None
3631
+
3632
+ with open(result_names_file, mode="r", encoding="utf-8") as f:
3633
+ result_names = [line.strip() for line in f if line.strip()]
3634
+
3635
+ records: dict = defaultdict(lambda: {'means': {}})
3636
+
3637
+ with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
3638
+ reader = csv.DictReader(csvfile)
3639
+ for row in reader:
3640
+ trial_index = int(row['trial_index'])
3641
+ arm_name = row['arm_name']
3642
+ key = (trial_index, arm_name)
3643
+
3644
+ for metric in result_names:
3645
+ if metric in row:
3646
+ try:
3647
+ records[key]['means'][metric] = float(row[metric])
3648
+ except ValueError:
3649
+ continue
3650
+
3651
+ return records
3652
+
3653
+ def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
3654
+ if data is None:
3655
+ print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
3656
+ return
3657
+
3658
+ if not supports_sixel():
3659
+ print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
3660
+ return
3661
+
3662
+ import matplotlib.pyplot as plt
3663
+
3664
+ means = data[x_metric][y_metric]["means"]
3665
+
3666
+ x_values = means[x_metric]
3667
+ y_values = means[y_metric]
3668
+
3669
+ fig, _ax = plt.subplots()
3670
+
3671
+ _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
3672
+
3673
+ _ax.set_xlabel(x_metric)
3674
+ _ax.set_ylabel(y_metric)
3675
+
3676
+ _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
3677
+
3678
+ _ax.ticklabel_format(style='plain', axis='both', useOffset=False)
3679
+
3680
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
3681
+ plt.savefig(tmp_file.name, dpi=300)
3682
+
3683
+ print_image_to_cli(tmp_file.name, 1000)
3684
+
3685
+ plt.close(fig)
3686
+
3687
+ def pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
3688
+ all_columns = list(first_row.keys())
3689
+ ignored_cols = set(special_col_names) - {"trial_index"}
3690
+
3691
+ param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
3692
+ result_cols = [col for col in arg_result_names if col in all_columns]
3693
+ return param_cols, result_cols
3694
+
3695
+ def check_factorial_range() -> None:
3696
+ if args.model and args.model == "FACTORIAL":
3697
+ _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
3698
+
3699
+ def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
3700
+ if value_type not in valid_value_types:
3701
+ valid_value_types_string = ", ".join(valid_value_types)
3702
+ _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
3703
+
3704
+ def check_range_params_length(this_args: Union[str, list]) -> None:
3705
+ if len(this_args) != 5 and len(this_args) != 4 and len(this_args) != 6:
3706
+ _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
3707
+
3708
+ def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
3709
+ if upper_bound is None or lower_bound is None:
3710
+ _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
3711
+ if upper_bound == lower_bound:
3712
+ if lower_bound == 0:
3713
+ _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
3714
+ print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
3715
+ if upper_bound is not None:
3716
+ lower_bound = -upper_bound
3717
+
3718
+ def format_value(value: Any, float_format: str = '.80f') -> str:
3719
+ try:
3720
+ if isinstance(value, float):
3721
+ s = format(value, float_format)
3722
+ s = s.rstrip('0').rstrip('.') if '.' in s else s
3723
+ return s
3724
+ return str(value)
3725
+ except Exception as e:
3726
+ print_red(f"⚠ Error formatting the number {value}: {e}")
3727
+ return str(value)
3728
+
3729
+ def replace_parameters_in_string(
3730
+ parameters: dict,
3731
+ input_string: str,
3732
+ float_format: str = '.20f',
3733
+ additional_prefixes: list[str] = [],
3734
+ additional_patterns: list[str] = [],
3735
+ ) -> str:
3736
+ try:
3737
+ prefixes = ['$', '%'] + additional_prefixes
3738
+ patterns = ['{' + 'key' + '}', '(' + '{' + 'key' + '}' + ')'] + additional_patterns
3739
+
3740
+ for key, value in parameters.items():
3741
+ replacement = format_value(value, float_format=float_format)
3742
+ for prefix in prefixes:
3743
+ for pattern in patterns:
3744
+ token = prefix + pattern.format(key=key)
3745
+ input_string = input_string.replace(token, replacement)
3746
+
3747
+ input_string = input_string.replace('\r', ' ').replace('\n', ' ')
3748
+ return input_string
3749
+
3750
+ except Exception as e:
3751
+ print_red(f"\n⚠ Error: {e}")
3752
+ return ""
3753
+
3754
+ def get_memory_usage() -> float:
3755
+ user_uid = os.getuid()
3756
+
3757
+ memory_usage = float(sum(
3758
+ p.memory_info().rss for p in psutil.process_iter(attrs=['memory_info', 'uids'])
3759
+ if p.info['uids'].real == user_uid
3760
+ ) / (1024 * 1024))
3761
+
3762
+ return memory_usage
3763
+
3764
+ class MonitorProcess:
3765
+ def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
3766
+ self.pid = pid
3767
+ self.interval = interval
3768
+ self.running = True
3769
+ self.thread = threading.Thread(target=self._monitor)
3770
+ self.thread.daemon = True
3771
+
3772
+ fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
3773
+
3774
+ def _monitor(self: Any) -> None:
3775
+ try:
3776
+ _internal_process = psutil.Process(self.pid)
3777
+ while self.running and _internal_process.is_running():
3778
+ crf = get_current_run_folder()
3779
+
3780
+ if crf and crf != "":
3781
+ log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
3782
+
3783
+ os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
3784
+
3785
+ with open(log_file_path, mode="a", encoding="utf-8") as log_file:
3786
+ hostname = socket.gethostname()
3787
+
3788
+ slurm_job_id = os.getenv("SLURM_JOB_ID")
3789
+
3790
+ if slurm_job_id:
3791
+ hostname += f"-SLURM-ID-{slurm_job_id}"
3792
+
3793
+ total_memory = psutil.virtual_memory().total / (1024 * 1024)
3794
+ cpu_usage = psutil.cpu_percent(interval=5)
3795
+
3796
+ memory_usage = get_memory_usage()
3797
+
3798
+ unix_timestamp = int(time.time())
3799
+
3800
+ log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
3801
+ time.sleep(self.interval)
3802
+ except psutil.NoSuchProcess:
3803
+ pass
3804
+
3805
+ def __enter__(self: Any) -> None:
3806
+ self.thread.start()
3807
+ return self
3808
+
3809
+ def __exit__(self: Any, exc_type: Any, exc_value: Any, _traceback: Any) -> None:
3810
+ self.running = False
3811
+ self.thread.join()
3812
+
3813
+ def execute_bash_code_log_time(code: str) -> list:
3814
+ process_item = subprocess.Popen(code, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
3815
+
3816
+ with MonitorProcess(process_item.pid):
3817
+ try:
3818
+ stdout, stderr = process_item.communicate()
3819
+ result = subprocess.CompletedProcess(
3820
+ args=code, returncode=process_item.returncode, stdout=stdout, stderr=stderr
3821
+ )
3429
3822
  return [result.stdout, result.stderr, result.returncode, None]
3430
3823
  except subprocess.CalledProcessError as e:
3431
3824
  real_exit_code = e.returncode
@@ -3550,7 +3943,7 @@ def _add_to_csv_acquire_lock(lockfile: str, dir_path: str) -> bool:
3550
3943
  time.sleep(wait_time)
3551
3944
  max_wait -= wait_time
3552
3945
  except Exception as e:
3553
- print("Lock error:", e)
3946
+ print_red(f"Lock error: {e}")
3554
3947
  return False
3555
3948
  return False
3556
3949
 
@@ -3588,11 +3981,11 @@ def _add_to_csv_rewrite_file(file_path: str, rows: List[list], existing_heading:
3588
3981
  writer = csv.writer(tmp_file)
3589
3982
  writer.writerow(all_headings)
3590
3983
  for row in rows[1:]:
3591
- tmp_file.writerow([
3984
+ tmp_file.writerow([ # type: ignore[attr-defined]
3592
3985
  row[existing_heading.index(h)] if h in existing_heading else ""
3593
3986
  for h in all_headings
3594
3987
  ])
3595
- tmp_file.writerow([
3988
+ tmp_file.writerow([ # type: ignore[attr-defined]
3596
3989
  formatted_data[new_heading.index(h)] if h in new_heading else ""
3597
3990
  for h in all_headings
3598
3991
  ])
@@ -3623,12 +4016,12 @@ def find_file_paths(_text: str) -> List[str]:
3623
4016
  def check_file_info(file_path: str) -> str:
3624
4017
  if not os.path.exists(file_path):
3625
4018
  if not args.tests:
3626
- print(f"check_file_info: The file {file_path} does not exist.")
4019
+ print_red(f"check_file_info: The file {file_path} does not exist.")
3627
4020
  return ""
3628
4021
 
3629
4022
  if not os.access(file_path, os.R_OK):
3630
4023
  if not args.tests:
3631
- print(f"check_file_info: The file {file_path} is not readable.")
4024
+ print_red(f"check_file_info: The file {file_path} is not readable.")
3632
4025
  return ""
3633
4026
 
3634
4027
  file_stat = os.stat(file_path)
@@ -3742,7 +4135,7 @@ def count_defective_nodes(file_path: Union[str, None] = None, entry: Any = None)
3742
4135
  return sorted(set(entries))
3743
4136
 
3744
4137
  except Exception as e:
3745
- print(f"An error has occurred: {e}")
4138
+ print_red(f"An error has occurred: {e}")
3746
4139
  return []
3747
4140
 
3748
4141
  def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]:
@@ -3753,7 +4146,7 @@ def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]
3753
4146
 
3754
4147
  fool_linter(tmp)
3755
4148
  except RuntimeError:
3756
- print(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
4149
+ print_red(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
3757
4150
  count_defective_nodes(None, socket.gethostname())
3758
4151
  return return_in_case_of_error
3759
4152
  except Exception:
@@ -4224,6 +4617,8 @@ def evaluate(parameters_with_trial_index: dict) -> Optional[Union[int, float, Di
4224
4617
  trial_index = parameters_with_trial_index["trial_idx"]
4225
4618
  submit_time = parameters_with_trial_index["submit_time"]
4226
4619
 
4620
+ print(f'Trial-Index: {trial_index}')
4621
+
4227
4622
  queue_time = abs(int(time.time()) - int(submit_time))
4228
4623
 
4229
4624
  start_nvidia_smi_thread()
@@ -4461,7 +4856,7 @@ def replace_string_with_params(input_string: str, params: list) -> str:
4461
4856
  return replaced_string
4462
4857
  except AssertionError as e:
4463
4858
  error_text = f"Error in replace_string_with_params: {e}"
4464
- print(error_text)
4859
+ print_red(error_text)
4465
4860
  raise
4466
4861
 
4467
4862
  return ""
@@ -4738,9 +5133,7 @@ def get_sixel_graphics_data(_pd_csv: str, _force: bool = False) -> list:
4738
5133
  _params = [_command, plot, _tmp, plot_type, tmp_file, _width]
4739
5134
  data.append(_params)
4740
5135
  except Exception as e:
4741
- tb = traceback.format_exc()
4742
- print_red(f"Error trying to print {plot_type} to CLI: {e}, {tb}")
4743
- print_debug(f"Error trying to print {plot_type} to CLI: {e}")
5136
+ print_red(f"Error trying to print {plot_type} to CLI: {e}")
4744
5137
 
4745
5138
  return data
4746
5139
 
@@ -4931,15 +5324,17 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
4931
5324
  if job:
4932
5325
  try:
4933
5326
  if ax_client:
4934
- _trial = ax_client.get_trial(trial_index)
4935
- _trial.mark_abandoned(reason=reason)
5327
+ _trial = get_ax_client_trial(trial_index)
5328
+ if _trial is None:
5329
+ return False
5330
+
5331
+ mark_abandoned(_trial, reason, trial_index)
4936
5332
  print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
4937
5333
  global_vars["jobs"].remove((job, trial_index))
4938
5334
  else:
4939
- _fatal_error("ax_client could not be found", 9)
5335
+ _fatal_error("ax_client could not be found", 101)
4940
5336
  except Exception as e:
4941
- print(f"ERROR in line {get_line_info()}: {e}")
4942
- print_debug(f"ERROR in line {get_line_info()}: {e}")
5337
+ print_red(f"ERROR in line {get_line_info()}: {e}")
4943
5338
  return False
4944
5339
  job.cancel()
4945
5340
  return True
@@ -4952,20 +5347,6 @@ def abandon_all_jobs() -> None:
4952
5347
  if not abandoned:
4953
5348
  print_debug(f"Job {job} could not be abandoned.")
4954
5349
 
4955
- def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
4956
- if args.dryrun:
4957
- print_debug("Not showing Pareto-frontier data with --dryrun")
4958
- return None
4959
-
4960
- if len(res_names) > 1:
4961
- try:
4962
- show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
4963
- except Exception as e:
4964
- print_red(f"show_pareto_frontier_data() failed with exception '{e}'")
4965
- else:
4966
- print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
4967
- return None
4968
-
4969
5350
  def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None) -> None:
4970
5351
  global END_PROGRAM_RAN
4971
5352
 
@@ -5002,7 +5383,7 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5002
5383
  _exit = new_exit
5003
5384
  except (SignalUSR, SignalINT, SignalCONT, KeyboardInterrupt):
5004
5385
  print_red("\n⚠ You pressed CTRL+C or a signal was sent. Program execution halted while ending program.")
5005
- print("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
5386
+ print_red("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
5006
5387
  new_exit = show_end_table_and_save_end_files()
5007
5388
  if new_exit > 0:
5008
5389
  _exit = new_exit
@@ -5023,21 +5404,31 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5023
5404
 
5024
5405
  my_exit(_exit)
5025
5406
 
5407
+ def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
5408
+ if not ax_client:
5409
+ my_exit(101)
5410
+
5411
+ return None
5412
+
5413
+ ax_client.save_to_json_file(checkpoint_filepath)
5414
+
5415
+ return None
5416
+
5026
5417
  def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
5027
5418
  if trial_nr > 3:
5028
5419
  if eee:
5029
- print(f"Error during saving checkpoint: {eee}")
5420
+ print_red(f"Error during saving checkpoint: {eee}")
5030
5421
  else:
5031
- print("Error during saving checkpoint")
5422
+ print_red("Error during saving checkpoint")
5032
5423
  return
5033
5424
 
5034
5425
  try:
5035
5426
  checkpoint_filepath = get_state_file_name('checkpoint.json')
5036
5427
 
5037
5428
  if ax_client:
5038
- ax_client.save_to_json_file(filepath=checkpoint_filepath)
5429
+ save_ax_client_to_json_file(checkpoint_filepath)
5039
5430
  else:
5040
- _fatal_error("Something went wrong using the ax_client", 9)
5431
+ _fatal_error("Something went wrong using the ax_client", 101)
5041
5432
  except Exception as e:
5042
5433
  save_checkpoint(trial_nr + 1, e)
5043
5434
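save_checkpoint above retries by recursing with an incremented counter and gives up after more than three attempts. A generic sketch of that bounded-retry shape, with an illustrative file write standing in for the real ax_client JSON dump:

from typing import Optional


def save_with_retry(path: str, attempt: int = 0, last_error: Optional[Exception] = None) -> None:
    # Give up after the fourth failed attempt, mirroring the trial_nr > 3 guard above.
    if attempt > 3:
        if last_error:
            print(f"Error during saving checkpoint: {last_error}")
        else:
            print("Error during saving checkpoint")
        return
    try:
        with open(path, "w", encoding="utf-8") as f:  # stand-in for the real JSON dump
            f.write("{}")
    except Exception as e:
        save_with_retry(path, attempt + 1, e)


save_with_retry("/tmp/checkpoint.json")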
 
@@ -5198,7 +5589,7 @@ def parse_equation_item(comparer_found: bool, item: str, parsed: list, parsed_or
5198
5589
  })
5199
5590
  elif item in [">=", "<="]:
5200
5591
  if comparer_found:
5201
- print("There is already one comparison operator! Cannot have more than one in an equation!")
5592
+ print_red("There is already one comparison operator! Cannot have more than one in an equation!")
5202
5593
  return_totally = True
5203
5594
  comparer_found = True
5204
5595
 
@@ -5409,20 +5800,9 @@ def check_equation(variables: list, equation: str) -> Union[str, bool]:
5409
5800
  def set_objectives() -> dict:
5410
5801
  objectives = {}
5411
5802
 
5412
- for rn in args.result_names:
5413
- key, value = "", ""
5414
-
5415
- if "=" in rn:
5416
- key, value = rn.split('=', 1)
5417
- else:
5418
- key = rn
5419
- value = ""
5420
-
5421
- if value not in ["min", "max"]:
5422
- if value:
5423
- print_yellow(f"Value '{value}' for --result_names {rn} is not a valid value. Must be min or max. Will be set to min.")
5424
-
5425
- value = "min"
5803
+ k = 0
5804
+ for key in arg_result_names:
5805
+ value = arg_result_min_or_max[k]
5426
5806
 
5427
5807
  _min = True
5428
5808
 
@@ -5430,12 +5810,18 @@ def set_objectives() -> dict:
5430
5810
  _min = False
5431
5811
 
5432
5812
  objectives[key] = ObjectiveProperties(minimize=_min)
5813
+ k = k + 1
5433
5814
 
5434
5815
  return objectives
5435
5816
 
5436
- def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Union[dict, list]) -> dict:
5437
- if experiment_constraints and len(experiment_constraints):
5817
+ def set_experiment_constraints(experiment_constraints: Optional[list], experiment_args: dict, _experiment_parameters: Optional[Union[dict, list]]) -> dict:
5818
+ if _experiment_parameters is None:
5819
+ print_red("set_experiment_constraints: _experiment_parameters was None")
5820
+ my_exit(95)
5821
+
5822
+ return {}
5438
5823
 
5824
+ if experiment_constraints and len(experiment_constraints):
5439
5825
  experiment_args["parameter_constraints"] = []
5440
5826
 
5441
5827
  if experiment_constraints:
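The rewritten set_objectives above walks arg_result_names and arg_result_min_or_max as parallel lists with a manual counter. For reference, the same mapping can be expressed with zip; ObjectiveProperties is the Ax class this diff already imports elsewhere, and the result names below are made up:

from ax.service.ax_client import ObjectiveProperties

result_names = ["loss", "runtime"]   # illustrative
min_or_max = ["min", "max"]          # illustrative

objectives = {
    name: ObjectiveProperties(minimize=(direction != "max"))
    for name, direction in zip(result_names, min_or_max)
}
print(objectives)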
@@ -5464,11 +5850,15 @@ def set_experiment_constraints(experiment_constraints: Optional[list], experimen
5464
5850
 
5465
5851
  return experiment_args
5466
5852
 
5467
- def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[list]) -> None:
5853
+ def replace_parameters_for_continued_jobs(parameter: Optional[list], cli_params_experiment_parameters: Optional[dict | list]) -> None:
5854
+ if not experiment_parameters:
5855
+ print_red("replace_parameters_for_continued_jobs: experiment_parameters was False")
5856
+ return None
5857
+
5468
5858
  if args.worker_generator_path:
5469
5859
  return None
5470
5860
 
5471
- def get_name(obj) -> Optional[str]:
5861
+ def get_name(obj: Any) -> Optional[str]:
5472
5862
  """Extract a parameter name from dict, list, or tuple safely."""
5473
5863
  if isinstance(obj, dict):
5474
5864
  return obj.get("name")
@@ -5550,13 +5940,13 @@ def copy_continue_uuid() -> None:
5550
5940
  print_debug(f"copy_continue_uuid: Source file does not exist: {source_file}")
5551
5941
 
5552
5942
  def load_ax_client_from_experiment_parameters() -> None:
5553
- #pprint(experiment_parameters)
5554
- global ax_client
5943
+ if experiment_parameters:
5944
+ global ax_client
5555
5945
 
5556
- tmp_file_path = get_tmp_file_from_json(experiment_parameters)
5557
- ax_client = AxClient.load_from_json_file(tmp_file_path)
5558
- ax_client = cast(AxClient, ax_client)
5559
- os.unlink(tmp_file_path)
5946
+ tmp_file_path = get_tmp_file_from_json(experiment_parameters)
5947
+ ax_client = AxClient.load_from_json_file(tmp_file_path)
5948
+ ax_client = cast(AxClient, ax_client)
5949
+ os.unlink(tmp_file_path)
5560
5950
 
5561
5951
  def save_checkpoint_for_continued() -> None:
5562
5952
  checkpoint_filepath = get_state_file_name('checkpoint.json')
@@ -5568,12 +5958,15 @@ def save_checkpoint_for_continued() -> None:
5568
5958
  _fatal_error(f"{checkpoint_filepath} not found. Cannot continue_previous_job without.", 47)
5569
5959
 
5570
5960
  def load_original_generation_strategy(original_ax_client_file: str) -> None:
5571
- with open(original_ax_client_file, encoding="utf-8") as f:
5572
- loaded_original_ax_client_json = json.load(f)
5573
- original_generation_strategy = loaded_original_ax_client_json["generation_strategy"]
5961
+ if experiment_parameters:
5962
+ with open(original_ax_client_file, encoding="utf-8") as f:
5963
+ loaded_original_ax_client_json = json.load(f)
5964
+ original_generation_strategy = loaded_original_ax_client_json["generation_strategy"]
5574
5965
 
5575
- if original_generation_strategy:
5576
- experiment_parameters["generation_strategy"] = original_generation_strategy
5966
+ if original_generation_strategy:
5967
+ experiment_parameters["generation_strategy"] = original_generation_strategy
5968
+ else:
5969
+ print_red("load_original_generation_strategy: experiment_parameters was empty!")
5577
5970
 
5578
5971
  def wait_for_checkpoint_file(checkpoint_file: str) -> None:
5579
5972
  start_time = time.time()
@@ -5586,10 +5979,6 @@ def wait_for_checkpoint_file(checkpoint_file: str) -> None:
5586
5979
  elapsed = int(time.time() - start_time)
5587
5980
  console.print(f"[green]Checkpoint file found after {elapsed} seconds[/green] ")
5588
5981
 
5589
- def __get_experiment_parameters__check_ax_client() -> None:
5590
- if not ax_client:
5591
- _fatal_error("Something went wrong with the ax_client", 9)
5592
-
5593
5982
  def validate_experiment_parameters() -> None:
5594
5983
  if experiment_parameters is None:
5595
5984
  print_red("Error: experiment_parameters is None.")
@@ -5599,6 +5988,8 @@ def validate_experiment_parameters() -> None:
5599
5988
  print_red(f"Error: experiment_parameters is not a dict: {type(experiment_parameters).__name__}")
5600
5989
  my_exit(95)
5601
5990
 
5991
+ sys.exit(95)
5992
+
5602
5993
  path_checks = [
5603
5994
  ("experiment", experiment_parameters),
5604
5995
  ("search_space", experiment_parameters.get("experiment")),
@@ -5610,7 +6001,12 @@ def validate_experiment_parameters() -> None:
5610
6001
  print_red(f"Error: Missing key '{key}' at level: {current_level}")
5611
6002
  my_exit(95)
5612
6003
 
5613
- def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str, cli_params_experiment_parameters: Optional[list]) -> Tuple[Any, str, str]:
6004
+ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_parameters: Optional[dict | list]) -> Tuple[Any, str, str]:
6005
+ if not ax_client:
6006
+ print_red("load_from_checkpoint: ax_client was None")
6007
+ my_exit(101)
6008
+ return {}, "", ""
6009
+
5614
6010
  print_debug(f"Load from checkpoint: {continue_previous_job}")
5615
6011
 
5616
6012
  checkpoint_file = f"{continue_previous_job}/state_files/checkpoint.json"
@@ -5636,7 +6032,7 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
5636
6032
 
5637
6033
  replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)
5638
6034
 
5639
- ax_client.save_to_json_file(filepath=original_ax_client_file)
6035
+ save_ax_client_to_json_file(original_ax_client_file)
5640
6036
 
5641
6037
  load_original_generation_strategy(original_ax_client_file)
5642
6038
  load_ax_client_from_experiment_parameters()
@@ -5652,6 +6048,12 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
5652
6048
 
5653
6049
  experiment_constraints = get_constraints()
5654
6050
  if experiment_constraints:
6051
+
6052
+ if not experiment_parameters:
6053
+ print_red("load_from_checkpoint: experiment_parameters was None")
6054
+
6055
+ return {}, "", ""
6056
+
5655
6057
  experiment_args = set_experiment_constraints(
5656
6058
  experiment_constraints,
5657
6059
  experiment_args,
@@ -5660,7 +6062,116 @@ def __get_experiment_parameters__load_from_checkpoint(continue_previous_job: str
5660
6062
 
5661
6063
  return experiment_args, gpu_string, gpu_color
5662
6064
 
5663
- def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str]:
6065
+ def get_experiment_args_import_python_script() -> str:
6066
+
6067
+ return """from ax.service.ax_client import AxClient, ObjectiveProperties
6068
+ from ax.adapter.registry import Generators
6069
+ import random
6070
+
6071
+ """
6072
+
6073
+ def get_generate_and_test_random_function_str() -> str:
6074
+ raw_data_entries = ",\n ".join(
6075
+ f'"{name}": random.uniform(0, 1)' for name in arg_result_names
6076
+ )
6077
+
6078
+ return f"""
6079
+ def generate_and_test_random_parameters(n: int) -> None:
6080
+ for _ in range(n):
6081
+ print("======================================")
6082
+ parameters, trial_index = ax_client.get_next_trial()
6083
+ print("Trial Index:", trial_index)
6084
+ print("Suggested parameters:", parameters)
6085
+
6086
+ ax_client.complete_trial(
6087
+ trial_index=trial_index,
6088
+ raw_data={{
6089
+ {raw_data_entries}
6090
+ }}
6091
+ )
6092
+
6093
+ generate_and_test_random_parameters({args.num_random_steps + 1})
6094
+ """
6095
+
6096
+ def get_global_gs_string() -> str:
6097
+ seed_str = ""
6098
+ if args.seed is not None:
6099
+ seed_str = f"model_kwargs={{'seed': {args.seed}}},"
6100
+
6101
+ return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
6102
+
6103
+ global_gs = GenerationStrategy(
6104
+ steps=[
6105
+ GenerationStep(
6106
+ generator=Generators.SOBOL,
6107
+ num_trials={args.num_random_steps},
6108
+ max_parallelism=5,
6109
+ {seed_str}
6110
+ ),
6111
+ GenerationStep(
6112
+ generator=Generators.{args.model},
6113
+ num_trials=-1,
6114
+ max_parallelism=5,
6115
+ ),
6116
+ ]
6117
+ )
6118
+ """
6119
+
6120
+ def get_debug_ax_client_str() -> str:
6121
+ return """
6122
+ ax_client = AxClient(
6123
+ verbose_logging=True,
6124
+ enforce_sequential_optimization=False,
6125
+ generation_strategy=global_gs
6126
+ )
6127
+ """
6128
+
6129
+ def write_ax_debug_python_code(experiment_args: dict) -> None:
6130
+ if args.generation_strategy:
6131
+ print_debug("Cannot write debug code for custom generation_strategy")
6132
+ return None
6133
+
6134
+ if args.model in uncontinuable_models:
6135
+ print_debug(f"Cannot write debug code for uncontinuable mode {args.model}")
6136
+ return None
6137
+
6138
+ python_code = python_code = get_experiment_args_import_python_script() + \
6139
+ get_global_gs_string() + \
6140
+ get_debug_ax_client_str() + \
6141
+ "experiment_args = " + pformat(experiment_args, width=120, compact=False) + \
6142
+ "\nax_client.create_experiment(**experiment_args)\n" + \
6143
+ get_generate_and_test_random_function_str()
6144
+
6145
+ file_path = f"{get_current_run_folder()}/debug.py"
6146
+
6147
+ try:
6148
+ print_debug(python_code)
6149
+ with open(file_path, "w", encoding="utf-8") as f:
6150
+ f.write(python_code)
6151
+ except Exception as e:
6152
+ print_red(f"Error while writing {file_path}: {e}")
6153
+
6154
+ return None
6155
+
6156
+ def create_ax_client_experiment(experiment_args: dict) -> None:
6157
+ if not ax_client:
6158
+ my_exit(101)
6159
+
6160
+ return None
6161
+
6162
+ write_ax_debug_python_code(experiment_args)
6163
+
6164
+ ax_client.create_experiment(**experiment_args)
6165
+
6166
+ return None
6167
+
6168
+ def create_new_experiment() -> Tuple[dict, str, str]:
6169
+ if ax_client is None:
6170
+ print_red("create_new_experiment: ax_client is None")
6171
+ my_exit(101)
6172
+
6173
+ return {}, "", ""
6174
+
5664
6175
  objectives = set_objectives()
5665
6176
 
5666
6177
  experiment_args = {
@@ -5683,7 +6194,7 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
5683
6194
  experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)
5684
6195
 
5685
6196
  try:
5686
- ax_client.create_experiment(**experiment_args)
6197
+ create_ax_client_experiment(experiment_args)
5687
6198
  new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
5688
6199
  ax_client.experiment.add_tracking_metrics(new_metrics)
5689
6200
  except AssertionError as error:
@@ -5697,17 +6208,17 @@ def __get_experiment_parameters__create_new_experiment() -> Tuple[dict, str, str
5697
6208
 
5698
6209
  return experiment_args, gpu_string, gpu_color
5699
6210
 
5700
- def get_experiment_parameters(cli_params_experiment_parameters: Optional[list]) -> Optional[Tuple[AxClient, dict, str, str]]:
6211
+ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
5701
6212
  continue_previous_job = args.worker_generator_path or args.continue_previous_job
5702
6213
 
5703
- __get_experiment_parameters__check_ax_client()
6214
+ check_ax_client()
5704
6215
 
5705
6216
  if continue_previous_job:
5706
- experiment_args, gpu_string, gpu_color = __get_experiment_parameters__load_from_checkpoint(continue_previous_job, cli_params_experiment_parameters)
6217
+ experiment_args, gpu_string, gpu_color = load_from_checkpoint(continue_previous_job, cli_params_experiment_parameters)
5707
6218
  else:
5708
- experiment_args, gpu_string, gpu_color = __get_experiment_parameters__create_new_experiment()
6219
+ experiment_args, gpu_string, gpu_color = create_new_experiment()
5709
6220
 
5710
- return ax_client, experiment_args, gpu_string, gpu_color
6221
+ return experiment_args, gpu_string, gpu_color
5711
6222
 
5712
6223
  def get_type_short(typename: str) -> str:
5713
6224
  if typename == "RangeParameter":
@@ -5766,7 +6277,6 @@ def parse_single_experiment_parameter_table(classic_params: Optional[Union[list,
5766
6277
  _upper = param["bounds"][1]
5767
6278
 
5768
6279
  _possible_int_lower = str(helpers.to_int_when_possible(_lower))
5769
- #print(f"name: {_name}, _possible_int_lower: {_possible_int_lower}, lower: {_lower}")
5770
6280
  _possible_int_upper = str(helpers.to_int_when_possible(_upper))
5771
6281
 
5772
6282
  rows.append([_name, _short_type, _possible_int_lower, _possible_int_upper, "", value_type, log_scale])
@@ -5843,19 +6353,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:
5843
6353
 
5844
6354
  return None
5845
6355
 
6356
+ def check_base_for_print_overview() -> Optional[bool]:
6357
+ if args.continue_previous_job is not None and arg_result_names is not None and len(arg_result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
6358
+ print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
6359
+
6360
+ if ax_client is None:
6361
+ print_red("ax_client was None")
6362
+ return None
6363
+
6364
+ if ax_client.experiment is None:
6365
+ print_red("ax_client.experiment was None")
6366
+ return None
6367
+
6368
+ if ax_client.experiment.optimization_config is None:
6369
+ print_red("ax_client.experiment.optimization_config was None")
6370
+ return None
6371
+
6372
+ return True
6373
+
6374
+ def get_config_objectives() -> Any:
6375
+ if not ax_client:
6376
+ print_red("create_new_experiment: ax_client is None")
6377
+ my_exit(101)
6378
+
6379
+ return None
6380
+
6381
+ config_objectives = None
6382
+
6383
+ if ax_client.experiment and ax_client.experiment.optimization_config:
6384
+ opt_config = ax_client.experiment.optimization_config
6385
+ if opt_config.is_moo_problem:
6386
+ objective = getattr(opt_config, "objective", None)
6387
+ if objective and getattr(objective, "objectives", None) is not None:
6388
+ config_objectives = objective.objectives
6389
+ else:
6390
+ print_debug("ax_client.experiment.optimization_config.objective was None")
6391
+ else:
6392
+ config_objectives = [opt_config.objective]
6393
+ else:
6394
+ print_debug("ax_client.experiment or optimization_config was None")
6395
+
6396
+ return config_objectives
6397
+
5846
6398
  def print_result_names_overview_table() -> None:
5847
6399
  if not ax_client:
5848
6400
  _fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)
5849
6401
 
5850
6402
  return None
5851
6403
 
5852
- if args.continue_previous_job is not None and args.result_names is not None and len(args.result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
5853
- print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
6404
+ if check_base_for_print_overview() is None:
6405
+ return None
5854
6406
 
5855
- if ax_client.experiment.optimization_config.is_moo_problem:
5856
- config_objectives = ax_client.experiment.optimization_config.objective.objectives
5857
- else:
5858
- config_objectives = [ax_client.experiment.optimization_config.objective]
6407
+ config_objectives = get_config_objectives()
6408
+
6409
+ if config_objectives is None:
6410
+ print_red("config_objectives not found")
6411
+ return None
5859
6412
 
5860
6413
  res_names = []
5861
6414
  res_min_max = []
@@ -5950,28 +6503,31 @@ def print_overview_tables(classic_params: Optional[Union[list, dict]], experimen
5950
6503
  print_result_names_overview_table()
5951
6504
 
5952
6505
  def update_progress_bar(nr: int) -> None:
5953
- try:
5954
- progress_bar.update(nr)
5955
- except Exception as e:
5956
- print(f"Error updating progress bar: {e}")
6506
+ log_data()
6507
+
6508
+ if progress_bar is not None:
6509
+ try:
6510
+ progress_bar.update(nr)
6511
+ except Exception as e:
6512
+ print_red(f"Error updating progress bar: {e}")
6513
+ else:
6514
+ print_red("update_progress_bar: progress_bar was None")
5957
6515
 
5958
6516
  def get_current_model_name() -> str:
5959
6517
  if overwritten_to_random:
5960
6518
  return "Random*"
5961
6519
 
6520
+ gs_model = "unknown model"
6521
+
5962
6522
  if ax_client:
5963
6523
  try:
5964
6524
  if args.generation_strategy:
5965
- idx = getattr(ax_client.generation_strategy, "current_step_index", None)
6525
+ idx = getattr(global_gs, "current_step_index", None)
5966
6526
  if isinstance(idx, int):
5967
6527
  if 0 <= idx < len(generation_strategy_names):
5968
6528
  gs_model = generation_strategy_names[int(idx)]
5969
- else:
5970
- gs_model = "unknown model"
5971
- else:
5972
- gs_model = "unknown model"
5973
6529
  else:
5974
- gs_model = getattr(ax_client.generation_strategy, "current_node_name", "unknown model")
6530
+ gs_model = getattr(global_gs, "current_node_name", "unknown model")
5975
6531
 
5976
6532
  if gs_model:
5977
6533
  return str(gs_model)
@@ -6028,7 +6584,7 @@ def get_workers_string() -> str:
6028
6584
  )
6029
6585
 
6030
6586
  def _get_workers_string_collect_stats() -> dict:
6031
- stats = {}
6587
+ stats: dict = {}
6032
6588
  for job, _ in global_vars["jobs"][:]:
6033
6589
  state = state_from_job(job)
6034
6590
  stats[state] = stats.get(state, 0) + 1
@@ -6077,8 +6633,8 @@ def submitted_jobs(nr: int = 0) -> int:
6077
6633
  def count_jobs_in_squeue() -> tuple[int, str]:
6078
6634
  global _last_count_time, _last_count_result
6079
6635
 
6080
- now = time.time()
6081
- if _last_count_result != (0, "") and now - _last_count_time < 15:
6636
+ now = int(time.time())
6637
+ if _last_count_result != (0, "") and now - _last_count_time < 5:
6082
6638
  return _last_count_result
6083
6639
 
6084
6640
  _len = len(global_vars["jobs"])
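count_jobs_in_squeue above rate-limits squeue calls by returning the cached result while it is younger than five seconds. A generic sketch of that time-based caching pattern, with a stand-in for the real squeue query:

import time
from typing import Any, Callable, Tuple


def cached_for(seconds: float) -> Callable:
    # Decorator: reuse the last return value until `seconds` have elapsed.
    def wrap(fn: Callable[[], Any]) -> Callable[[], Any]:
        last_time = 0.0
        last_value: Any = None

        def inner() -> Any:
            nonlocal last_time, last_value
            now = time.time()
            if last_value is not None and now - last_time < seconds:
                return last_value
            last_value = fn()
            last_time = now
            return last_value

        return inner

    return wrap


@cached_for(5.0)
def count_jobs_in_queue() -> Tuple[int, str]:
    # Stand-in for the real squeue call.
    return (3, "")


print(count_jobs_in_queue())
print(count_jobs_in_queue())  # served from the cache within the 5-second window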
@@ -6100,6 +6656,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
6100
6656
  check=True,
6101
6657
  text=True
6102
6658
  )
6659
+
6103
6660
  if "slurm_load_jobs error" in result.stderr:
6104
6661
  _last_count_result = (_len, "Detected slurm_load_jobs error in stderr.")
6105
6662
  _last_count_time = now
@@ -6140,6 +6697,8 @@ def log_worker_numbers() -> None:
6140
6697
  if len(WORKER_PERCENTAGE_USAGE) == 0 or WORKER_PERCENTAGE_USAGE[len(WORKER_PERCENTAGE_USAGE) - 1] != this_values:
6141
6698
  WORKER_PERCENTAGE_USAGE.append(this_values)
6142
6699
 
6700
+ write_worker_usage()
6701
+
6143
6702
  def get_slurm_in_brackets(in_brackets: list) -> list:
6144
6703
  if is_slurm_job():
6145
6704
  workers_strings = get_workers_string()
@@ -6224,6 +6783,8 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
6224
6783
  global last_progress_bar_desc
6225
6784
  global last_progress_bar_refresh_time
6226
6785
 
6786
+ log_data()
6787
+
6227
6788
  if isinstance(new_msgs, str):
6228
6789
  new_msgs = [new_msgs]
6229
6790
 
@@ -6241,7 +6802,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
6241
6802
  print_red("Cannot update progress bar! It is None.")
6242
6803
 
6243
6804
  def clean_completed_jobs() -> None:
6244
- job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
6805
+ job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
6245
6806
  job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]
6246
6807
 
6247
6808
  for job, trial_index in global_vars["jobs"][:]:
@@ -6299,7 +6860,7 @@ def load_existing_job_data_into_ax_client() -> None:
6299
6860
  nr_of_imported_jobs = get_nr_of_imported_jobs()
6300
6861
  set_nr_inserted_jobs(NR_INSERTED_JOBS + nr_of_imported_jobs)
6301
6862
 
6302
- def parse_parameter_type_error(_error_message: Union[str, None]) -> Optional[dict]:
6863
+ def parse_parameter_type_error(_error_message: Union[Exception, str, None]) -> Optional[dict]:
6303
6864
  if not _error_message:
6304
6865
  return None
6305
6866
 
@@ -6366,7 +6927,7 @@ def get_generation_node_for_index(
6366
6927
  results_list: List[Dict[str, Any]],
6367
6928
  index: int,
6368
6929
  __status: Any,
6369
- base_str: str
6930
+ base_str: Optional[str]
6370
6931
  ) -> str:
6371
6932
  __status.update(f"{base_str}: Getting generation node")
6372
6933
  try:
@@ -6386,7 +6947,7 @@ def get_generation_node_for_index(
6386
6947
 
6387
6948
  return generation_node
6388
6949
  except Exception as e:
6389
- print(f"Error while get_generation_node_for_index: {e}")
6950
+ print_red(f"Error while get_generation_node_for_index: {e}")
6390
6951
  return "MANUAL"
6391
6952
 
6392
6953
  def _get_generation_node_for_index_index_valid(
@@ -6491,7 +7052,7 @@ def normalize_path(file_path: str) -> str:
6491
7052
 
6492
7053
  def insert_jobs_from_lists(csv_path: str, arm_params_list: Any, results_list: Any, __status: Any) -> None:
6493
7054
  cnt = 0
6494
- err_msgs = []
7055
+ err_msgs: list = []
6495
7056
 
6496
7057
  for i, (arm_params, result) in enumerate(zip(arm_params_list, results_list)):
6497
7058
  base_str = f"[bold green]Loading job {i}/{len(results_list)} from {csv_path} into ax_client, result: {result}"
@@ -6525,9 +7086,16 @@ def try_insert_job(csv_path: str, arm_params: Dict, result: Any, i: int, arm_par
6525
7086
  f"This can happen when the csv file has different parameters or results as the main job one's "
6526
7087
  f"or other imported jobs. Error: {e}"
6527
7088
  )
6528
- if err_msg not in err_msgs:
6529
- print_red(err_msg)
6530
- err_msgs.append(err_msg)
7089
+
7090
+ if err_msgs is None:
7091
+ print_red("try_insert_job: err_msgs was None")
7092
+ else:
7093
+ if isinstance(err_msgs, list):
7094
+ if err_msg not in err_msgs:
7095
+ print_red(err_msg)
7096
+ err_msgs.append(err_msg)
7097
+ elif isinstance(err_msgs, str):
7098
+ err_msgs += f"\n{err_msg}"
6531
7099
 
6532
7100
  return cnt
6533
7101
 
@@ -6544,33 +7112,52 @@ def update_global_job_counters(cnt: int) -> None:
6544
7112
  set_max_eval(max_eval + cnt)
6545
7113
  set_nr_inserted_jobs(NR_INSERTED_JOBS + cnt)
6546
7114
 
6547
- def __insert_job_into_ax_client__update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
7115
+ def update_status(__status: Optional[Any], base_str: Optional[str], new_text: str) -> None:
6548
7116
  if __status and base_str:
6549
7117
  __status.update(f"{base_str}: {new_text}")
6550
7118
 
6551
- def __insert_job_into_ax_client__check_ax_client() -> None:
7119
+ def check_ax_client() -> None:
6552
7120
  if ax_client is None or not ax_client:
6553
7121
  _fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)
6554
7122
 
6555
- def __insert_job_into_ax_client__attach_trial(arm_params: dict) -> Tuple[Any, int]:
7123
+ def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
7124
+ if not ax_client:
7125
+ my_exit(101)
7126
+
7127
+ return None
7128
+
6556
7129
  new_trial = ax_client.attach_trial(arm_params)
7130
+
7131
+ return new_trial
7132
+
7133
+ def attach_trial(arm_params: dict) -> Tuple[Any, int]:
7134
+ if ax_client is None:
7135
+ raise RuntimeError("attach_trial: ax_client was empty")
7136
+
7137
+ new_trial = attach_ax_client_data(arm_params)
6557
7138
  if not isinstance(new_trial, tuple) or len(new_trial) < 2:
6558
7139
  raise RuntimeError("attach_trial didn't return the expected tuple")
6559
7140
  return new_trial
6560
7141
 
6561
- def __insert_job_into_ax_client__get_trial(trial_idx: int) -> Any:
7142
+ def get_trial_by_index(trial_idx: int) -> Any:
7143
+ if ax_client is None:
7144
+ raise RuntimeError("get_trial_by_index: ax_client was empty")
7145
+
6562
7146
  trial = ax_client.experiment.trials.get(trial_idx)
6563
7147
  if trial is None:
6564
7148
  raise RuntimeError(f"Trial with index {trial_idx} not found")
6565
7149
  return trial
6566
7150
 
6567
- def __insert_job_into_ax_client__create_generator_run(arm_params: dict, trial_idx: int, new_job_type: str) -> GeneratorRun:
7151
+ def create_generator_run(arm_params: dict, trial_idx: int, new_job_type: str) -> GeneratorRun:
6568
7152
  arm = Arm(parameters=arm_params, name=f'{trial_idx}_0')
6569
7153
  return GeneratorRun(arms=[arm], generation_node_name=new_job_type)
6570
7154
 
6571
- def __insert_job_into_ax_client__complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[Any], base_str: Optional[str]) -> None:
7155
+ def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[Any], base_str: Optional[str]) -> None:
7156
+ if ax_client is None:
7157
+ raise RuntimeError("complete_trial_if_result: ax_client was empty")
7158
+
6572
7159
  if f"{result}" != "":
6573
- __insert_job_into_ax_client__update_status(__status, base_str, "Completing trial")
7160
+ update_status(__status, base_str, "Completing trial")
6574
7161
  is_ok = True
6575
7162
 
6576
7163
  for keyname in result.keys():
@@ -6578,20 +7165,20 @@ def __insert_job_into_ax_client__complete_trial_if_result(trial_idx: int, result
6578
7165
  is_ok = False
6579
7166
 
6580
7167
  if is_ok:
6581
- ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
6582
- __insert_job_into_ax_client__update_status(__status, base_str, "Completed trial")
7168
+ complete_ax_client_trial(trial_idx, result)
7169
+ update_status(__status, base_str, "Completed trial")
6583
7170
  else:
6584
7171
  print_debug("Empty job encountered")
6585
7172
  else:
6586
- __insert_job_into_ax_client__update_status(__status, base_str, "Found trial without result. Not adding it.")
7173
+ update_status(__status, base_str, "Found trial without result. Not adding it.")
6587
7174
 
6588
- def __insert_job_into_ax_client__save_results_if_needed(__status: Optional[Any], base_str: Optional[str]) -> None:
7175
+ def save_results_if_needed(__status: Optional[Any], base_str: Optional[str]) -> None:
6589
7176
  if not args.worker_generator_path:
6590
- __insert_job_into_ax_client__update_status(__status, base_str, f"Saving {RESULTS_CSV_FILENAME}")
7177
+ update_status(__status, base_str, f"Saving {RESULTS_CSV_FILENAME}")
6591
7178
  save_results_csv()
6592
- __insert_job_into_ax_client__update_status(__status, base_str, f"Saved {RESULTS_CSV_FILENAME}")
7179
+ update_status(__status, base_str, f"Saved {RESULTS_CSV_FILENAME}")
6593
7180
 
6594
- def __insert_job_into_ax_client__handle_type_error(e: Exception, arm_params: dict) -> bool:
7181
+ def handle_insert_job_error(e: Exception, arm_params: dict) -> bool:
6595
7182
  parsed_error = parse_parameter_type_error(e)
6596
7183
  if parsed_error is not None:
6597
7184
  param = parsed_error["parameter_name"]
@@ -6614,37 +7201,37 @@ def insert_job_into_ax_client(
6614
7201
  result: dict,
6615
7202
  new_job_type: str = "MANUAL",
6616
7203
  __status: Optional[Any] = None,
6617
- base_str: str = None
7204
+ base_str: Optional[str] = None
6618
7205
  ) -> bool:
6619
- __insert_job_into_ax_client__check_ax_client()
7206
+ check_ax_client()
6620
7207
 
6621
7208
  done_converting = False
6622
7209
  while not done_converting:
6623
7210
  try:
6624
- __insert_job_into_ax_client__update_status(__status, base_str, "Checking ax client")
7211
+ update_status(__status, base_str, "Checking ax client")
6625
7212
  if ax_client is None:
6626
7213
  return False
6627
7214
 
6628
- __insert_job_into_ax_client__update_status(__status, base_str, "Attaching new trial")
6629
- _, new_trial_idx = __insert_job_into_ax_client__attach_trial(arm_params)
7215
+ update_status(__status, base_str, "Attaching new trial")
7216
+ _, new_trial_idx = attach_trial(arm_params)
6630
7217
 
6631
- __insert_job_into_ax_client__update_status(__status, base_str, "Getting new trial")
6632
- trial = __insert_job_into_ax_client__get_trial(new_trial_idx)
6633
- __insert_job_into_ax_client__update_status(__status, base_str, "Got new trial")
7218
+ update_status(__status, base_str, "Getting new trial")
7219
+ trial = get_trial_by_index(new_trial_idx)
7220
+ update_status(__status, base_str, "Got new trial")
6634
7221
 
6635
- __insert_job_into_ax_client__update_status(__status, base_str, "Creating new arm")
6636
- manual_generator_run = __insert_job_into_ax_client__create_generator_run(arm_params, new_trial_idx, new_job_type)
7222
+ update_status(__status, base_str, "Creating new arm")
7223
+ manual_generator_run = create_generator_run(arm_params, new_trial_idx, new_job_type)
6637
7224
  trial._generator_run = manual_generator_run
6638
7225
  fool_linter(trial._generator_run)
6639
7226
 
6640
- __insert_job_into_ax_client__complete_trial_if_result(new_trial_idx, result, __status, base_str)
7227
+ complete_trial_if_result(new_trial_idx, result, __status, base_str)
6641
7228
  done_converting = True
6642
7229
 
6643
- __insert_job_into_ax_client__save_results_if_needed(__status, base_str)
7230
+ save_results_if_needed(__status, base_str)
6644
7231
  return True
6645
7232
 
6646
7233
  except ax.exceptions.core.UnsupportedError as e:
6647
- if not __insert_job_into_ax_client__handle_type_error(e, arm_params):
7234
+ if not handle_insert_job_error(e, arm_params):
6648
7235
  break
6649
7236
 
6650
7237
  return False
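insert_job_into_ax_client above attaches an externally evaluated parameter set as a manual trial and then completes it with its known result. A compressed, hedged sketch of that attach-and-complete flow against a bare AxClient; the experiment definition and values are illustrative only:

from ax.service.ax_client import AxClient, ObjectiveProperties

ax_client = AxClient()
ax_client.create_experiment(
    name="illustrative_experiment",
    parameters=[{"name": "x", "type": "range", "bounds": [0.0, 1.0]}],
    objectives={"loss": ObjectiveProperties(minimize=True)},
)

# Attach a point that was evaluated outside of Ax and hand over its result.
_, trial_index = ax_client.attach_trial({"x": 0.25})
ax_client.complete_trial(trial_index=trial_index, raw_data={"loss": 0.9})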
@@ -6975,7 +7562,7 @@ def get_parameters_from_outfile(stdout_path: str) -> Union[None, dict, str]:
6975
7562
  if not args.tests:
6976
7563
  original_print(f"get_parameters_from_outfile: The file '{stdout_path}' was not found.")
6977
7564
  except Exception as e:
6978
- print(f"get_parameters_from_outfile: There was an error: {e}")
7565
+ print_red(f"get_parameters_from_outfile: There was an error: {e}")
6979
7566
 
6980
7567
  return None
6981
7568
 
@@ -6993,7 +7580,7 @@ def get_hostname_from_outfile(stdout_path: Optional[str]) -> Optional[str]:
6993
7580
  original_print(f"The file '{stdout_path}' was not found.")
6994
7581
  return None
6995
7582
  except Exception as e:
6996
- print(f"There was an error: {e}")
7583
+ print_red(f"There was an error: {e}")
6997
7584
  return None
6998
7585
 
6999
7586
  def add_to_global_error_list(msg: str) -> None:
@@ -7029,22 +7616,68 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:
7029
7616
 
7030
7617
  return None
7031
7618
 
7032
- ax_client.log_trial_failure(trial_index=trial_index)
7619
+ log_ax_client_trial_failure(trial_index)
7033
7620
  _trial.mark_failed(unsafe=True)
7034
7621
  except ValueError as e:
7035
7622
  print_debug(f"mark_trial_as_failed error: {e}")
7036
7623
 
7037
7624
  return None
7038
7625
 
7039
- def _finish_job_core_helper_check_valid_result(result: Union[None, list, int, float, tuple]) -> bool:
7626
+ def check_valid_result(result: Union[None, dict]) -> bool:
7040
7627
  possible_val_not_found_values = [
7041
7628
  VAL_IF_NOTHING_FOUND,
7042
7629
  -VAL_IF_NOTHING_FOUND,
7043
7630
  -99999999999999997168788049560464200849936328366177157906432,
7044
7631
  99999999999999997168788049560464200849936328366177157906432
7045
7632
  ]
7046
- values_to_check = result if isinstance(result, list) else [result]
7047
- return result is not None and all(r not in possible_val_not_found_values for r in values_to_check)
7633
+
7634
+ def flatten_values(obj: Any) -> Any:
7635
+ values = []
7636
+ try:
7637
+ if isinstance(obj, dict):
7638
+ for v in obj.values():
7639
+ values.extend(flatten_values(v))
7640
+ elif isinstance(obj, (list, tuple, set)):
7641
+ for v in obj:
7642
+ values.extend(flatten_values(v))
7643
+ else:
7644
+ values.append(obj)
7645
+ except Exception as e:
7646
+ print_red(f"Error while flattening values: {e}")
7647
+ return values
7648
+
7649
+ if result is None:
7650
+ return False
7651
+
7652
+ try:
7653
+ all_values = flatten_values(result)
7654
+ for val in all_values:
7655
+ if val in possible_val_not_found_values:
7656
+ return False
7657
+ return True
7658
+ except Exception as e:
7659
+ print_red(f"Error while checking result validity: {e}")
7660
+ return False
7661
+
7662
+ def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
7663
+ if not ax_client:
7664
+ my_exit(101)
7665
+
7666
+ return None
7667
+
7668
+ ax_client.update_trial_data(trial_index=trial_idx, raw_data=result)
7669
+
7670
+ return None
7671
+
7672
+ def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
7673
+ if not ax_client:
7674
+ my_exit(101)
7675
+
7676
+ return None
7677
+
7678
+ ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
7679
+
7680
+ return None
7048
7681
 
7049
7682
  def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
7050
7683
  if ax_client is None:
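check_valid_result above flattens arbitrarily nested result containers before comparing every leaf against the "nothing found" sentinel values. A small self-contained sketch of that flatten-and-check idea, with the sentinel list shortened for readability:

from typing import Any, List

SENTINELS = {
    99999999999999997168788049560464200849936328366177157906432,
    -99999999999999997168788049560464200849936328366177157906432,
}


def flatten_values(obj: Any) -> List[Any]:
    # Recursively collect scalar leaves from dicts, lists, tuples and sets.
    if isinstance(obj, dict):
        return [leaf for value in obj.values() for leaf in flatten_values(value)]
    if isinstance(obj, (list, tuple, set)):
        return [leaf for value in obj for leaf in flatten_values(value)]
    return [obj]


def check_valid_result(result: Any) -> bool:
    if result is None:
        return False
    return all(leaf not in SENTINELS for leaf in flatten_values(result))


print(check_valid_result({"loss": (0.42, None)}))  # True
print(check_valid_result(None))                    # False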
@@ -7053,25 +7686,64 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -
7053
7686
 
7054
7687
  try:
7055
7688
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
7056
- ax_client.complete_trial(trial_index=trial_index, raw_data=raw_result)
7689
+ complete_ax_client_trial(trial_index, raw_result)
7057
7690
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
7058
7691
  except ax.exceptions.core.UnsupportedError as e:
7059
7692
  if f"{e}":
7060
7693
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
7061
- ax_client.update_trial_data(trial_index=trial_index, raw_data=raw_result)
7694
+ update_ax_client_trial(trial_index, raw_result)
7062
7695
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
7063
7696
  else:
7064
7697
  _fatal_error(f"Error completing trial: {e}", 234)
7065
7698
 
7066
7699
  return None
7067
7700
 
7068
- def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: Union[float, int, tuple]) -> None:
7701
+ def format_result_for_display(result: dict) -> str:
7702
+ def safe_float(v: Any) -> str:
7703
+ try:
7704
+ if v is None:
7705
+ return "None"
7706
+ if isinstance(v, (int, float)):
7707
+ if math.isnan(v):
7708
+ return "NaN"
7709
+ if math.isinf(v):
7710
+ return "∞" if v > 0 else "-∞"
7711
+ return f"{v:.6f}"
7712
+ return str(v)
7713
+ except Exception as e:
7714
+ return f"<error: {e}>"
7715
+
7716
+ try:
7717
+ if not isinstance(result, dict):
7718
+ return safe_float(result)
7719
+
7720
+ parts = []
7721
+ for key, val in result.items():
7722
+ try:
7723
+ if isinstance(val, (list, tuple)) and len(val) == 2:
7724
+ main, sem = val
7725
+ main_str = safe_float(main)
7726
+ if sem is not None:
7727
+ sem_str = safe_float(sem)
7728
+ parts.append(f"{key}: {main_str} (SEM: {sem_str})")
7729
+ else:
7730
+ parts.append(f"{key}: {main_str}")
7731
+ else:
7732
+ parts.append(f"{key}: {safe_float(val)}")
7733
+ except Exception as e:
7734
+ parts.append(f"{key}: <error: {e}>")
7735
+
7736
+ return ", ".join(parts)
7737
+ except Exception as e:
7738
+ return f"<error formatting result: {e}>"
7739
+
7740
+ def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: dict) -> None:
7069
7741
  print_debug(f"Marking trial {_trial} as completed")
7070
7742
  _trial.mark_completed(unsafe=True)
7071
7743
 
7072
7744
  succeeded_jobs(1)
7073
7745
 
7074
- progressbar_description(f"new result: {result}")
7746
+ progressbar_description(f"new result: {format_result_for_display(result)}")
7075
7747
  update_progress_bar(1)
7076
7748
 
7077
7749
  save_results_csv()
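format_result_for_display above renders every metric as "name: value" and appends the SEM when the raw value is a (mean, SEM) pair. A trimmed, illustrative re-sketch of that formatting rule with made-up numbers:

from typing import Any, Dict


def format_result(result: Dict[str, Any]) -> str:
    parts = []
    for key, val in result.items():
        if isinstance(val, (list, tuple)) and len(val) == 2:
            mean, sem = val
            if sem is not None:
                parts.append(f"{key}: {mean:.6f} (SEM: {sem:.6f})")
            else:
                parts.append(f"{key}: {mean:.6f}")
        else:
            parts.append(f"{key}: {val}")
    return ", ".join(parts)


print(format_result({"loss": (0.123456, 0.01), "runtime": (12.0, None)}))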
@@ -7085,8 +7757,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
7085
7757
  if job:
7086
7758
  try:
7087
7759
  progressbar_description("job_failed")
7088
- ax_client.log_trial_failure(trial_index=trial_index)
7089
- _trial.mark_failed(unsafe=True)
7760
+ log_ax_client_trial_failure(trial_index)
7761
+ mark_trial_as_failed(trial_index, _trial)
7090
7762
  except Exception as e:
7091
7763
  print_red(f"\nERROR while trying to mark job as failure: {e}")
7092
7764
  job.cancel()
@@ -7103,16 +7775,16 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7103
7775
  result = job.result()
7104
7776
  print_debug(f"finish_job_core: trial-index: {trial_index}, job.result(): {result}, state: {state_from_job(job)}")
7105
7777
 
7106
- raw_result = result
7107
- result_keys = list(result.keys())
7108
- result = result[result_keys[0]]
7109
7778
  this_jobs_finished += 1
7110
7779
 
7111
7780
  if ax_client:
7112
- _trial = ax_client.get_trial(trial_index)
7781
+ _trial = get_ax_client_trial(trial_index)
7782
+
7783
+ if _trial is None:
7784
+ return 0
7113
7785
 
7114
- if _finish_job_core_helper_check_valid_result(result):
7115
- _finish_job_core_helper_complete_trial(trial_index, raw_result)
7786
+ if check_valid_result(result):
7787
+ _finish_job_core_helper_complete_trial(trial_index, result)
7116
7788
 
7117
7789
  try:
7118
7790
  _finish_job_core_helper_mark_success(_trial, result)
@@ -7120,15 +7792,17 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7120
7792
  if len(arg_result_names) > 1 and count_done_jobs() > 1 and not job_calculate_pareto_front(get_current_run_folder(), True):
7121
7793
  print_red("job_calculate_pareto_front post job failed")
7122
7794
  except Exception as e:
7123
- print(f"ERROR in line {get_line_info()}: {e}")
7795
+ print_red(f"ERROR in line {get_line_info()}: {e}")
7124
7796
  else:
7125
7797
  _finish_job_core_helper_mark_failure(job, trial_index, _trial)
7126
7798
  else:
7127
- _fatal_error("ax_client could not be found or used", 9)
7799
+ _fatal_error("ax_client could not be found or used", 101)
7128
7800
 
7129
7801
  print_debug(f"finish_job_core: removing job {job}, trial_index: {trial_index}")
7130
7802
  global_vars["jobs"].remove((job, trial_index))
7131
7803
 
7804
+ log_data()
7805
+
7132
7806
  force_live_share()
7133
7807
 
7134
7808
  return this_jobs_finished
@@ -7141,11 +7815,14 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
7141
7815
  if job:
7142
7816
  try:
7143
7817
  progressbar_description("job_failed")
7144
- _trial = ax_client.get_trial(trial_index)
7145
- ax_client.log_trial_failure(trial_index=trial_index)
7818
+ _trial = get_ax_client_trial(trial_index)
7819
+ if _trial is None:
7820
+ return None
7821
+
7822
+ log_ax_client_trial_failure(trial_index)
7146
7823
  mark_trial_as_failed(trial_index, _trial)
7147
7824
  except Exception as e:
7148
- print(f"ERROR in line {get_line_info()}: {e}")
7825
+ print_debug(f"ERROR in line {get_line_info()}: {e}")
7149
7826
  job.cancel()
7150
7827
  orchestrate_job(job, trial_index)
7151
7828
 
@@ -7160,10 +7837,12 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
7160
7837
 
7161
7838
  def _finish_previous_jobs_helper_handle_exception(job: Any, trial_index: int, error: Exception) -> int:
7162
7839
  if "None for metric" in str(error):
7163
- print_red(
7164
- f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string."
7165
- f"\nError: {error}\nJob-result: {job.result()}"
7166
- )
7840
+ err_msg = f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string.\nError: {error}\nJob-result: {job.result()}"
7841
+
7842
+ if count_done_jobs() == 0:
7843
+ print_red(err_msg)
7844
+ else:
7845
+ print_debug(err_msg)
7167
7846
  else:
7168
7847
  print_red(f"\n⚠ {error}")
7169
7848
 
@@ -7182,7 +7861,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
7182
7861
  this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
7183
7862
  return this_jobs_finished
7184
7863
 
7185
- def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7864
+ def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
7865
+ job, trial_index = __args
7866
+
7867
+ this_jobs_finished = 0
7186
7868
  if job is None:
7187
7869
  print_debug(f"finish_previous_jobs: job {job} is None")
7188
7870
  return this_jobs_finished
@@ -7195,10 +7877,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t
7195
7877
 
7196
7878
  return this_jobs_finished
7197
7879
 
7198
- def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
7199
- job, trial_index = __args
7200
- return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
7201
-
7202
7880
  def finish_previous_jobs(new_msgs: List[str] = []) -> None:
7203
7881
  global JOBS_FINISHED
7204
7882
 
@@ -7211,13 +7889,10 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
7211
7889
 
7212
7890
  jobs_copy = global_vars["jobs"][:]
7213
7891
 
7214
- if len(jobs_copy) > 0:
7215
- print_debug(f"jobs in finish_previous_jobs: {jobs_copy}")
7216
-
7217
- finishing_jobs_start_time = time.time()
7892
+ #finishing_jobs_start_time = time.time()
7218
7893
 
7219
7894
  with ThreadPoolExecutor() as finish_job_executor:
7220
- futures = [finish_job_executor.submit(_finish_previous_jobs_helper_wrapper, (job, trial_index)) for job, trial_index in jobs_copy]
7895
+ futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]
7221
7896
 
7222
7897
  for future in as_completed(futures):
7223
7898
  try:
@@ -7225,11 +7900,11 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
7225
7900
  except Exception as e:
7226
7901
  print_red(f"⚠ Exception in parallel job handling: {e}")
7227
7902
 
7228
- finishing_jobs_end_time = time.time()
7903
+ #finishing_jobs_end_time = time.time()
7229
7904
 
7230
- finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
7905
+ #finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
7231
7906
 
7232
- print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
7907
+ #print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
7233
7908
 
7234
7909
  if this_jobs_finished > 0:
7235
7910
  save_results_csv()
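finish_previous_jobs above fans every (job, trial_index) pair out to a ThreadPoolExecutor and sums the per-job counts via as_completed. A minimal sketch of that fan-out/collect pattern; the worker body below is illustrative:

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Tuple


def check_and_process(job_and_index: Tuple[str, int]) -> int:
    job, trial_index = job_and_index
    # A real worker would inspect the job state and update the optimizer here.
    return 1 if job == "done" else 0


jobs = [("done", 0), ("running", 1), ("done", 2)]

finished = 0
with ThreadPoolExecutor() as finish_job_executor:
    futures = [finish_job_executor.submit(check_and_process, (job, idx)) for job, idx in jobs]
    for future in as_completed(futures):
        finished += future.result()

print(f"finished jobs: {finished}")  # -> 2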
@@ -7376,11 +8051,15 @@ def is_already_in_defective_nodes(hostname: str) -> bool:
7376
8051
  return True
7377
8052
  except Exception as e:
7378
8053
  print_red(f"is_already_in_defective_nodes: Error reading the file {file_path}: {e}")
7379
- return False
7380
8054
 
7381
8055
  return False
7382
8056
 
7383
8057
  def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
8058
+ if submitit_executor is None:
8059
+ print_red("submit_new_job: submitit_executor was None")
8060
+
8061
+ return None
8062
+
7384
8063
  print_debug(f"Submitting new job for trial_index {trial_index}, parameters {parameters}")
7385
8064
 
7386
8065
  start = time.time()
@@ -7391,25 +8070,49 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
7391
8070
 
7392
8071
  print_debug(f"Done submitting new job, took {elapsed} seconds")
7393
8072
 
8073
+ log_data()
8074
+
7394
8075
  return new_job
7395
8076
 
8077
+ def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
8078
+ if not ax_client:
8079
+ my_exit(101)
8080
+
8081
+ return None
8082
+
8083
+ try:
8084
+ log_data()
8085
+
8086
+ return ax_client.get_trial(trial_index)
8087
+ except KeyError:
8088
+ error_without_print(f"get_ax_client_trial: trial_index {trial_index} failed")
8089
+ return None
8090
+
7396
8091
  def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
7397
8092
  if submitit_executor and ax_client:
7398
8093
  new_job = submit_new_job(parameters, trial_index)
7399
- submitted_jobs(1)
8094
+ if new_job:
8095
+ submitted_jobs(1)
7400
8096
 
7401
- _trial = ax_client.get_trial(trial_index)
8097
+ _trial = get_ax_client_trial(trial_index)
7402
8098
 
7403
- try:
7404
- _trial.mark_staged(unsafe=True)
7405
- except Exception as e:
7406
- print_debug(f"orchestrator_start_trial: error {e}")
7407
- _trial.mark_running(unsafe=True, no_runner_required=True)
8099
+ if _trial is not None:
8100
+ try:
8101
+ _trial.mark_staged(unsafe=True)
8102
+ except Exception as e:
8103
+ print_debug(f"orchestrator_start_trial: error {e}")
8104
+ _trial.mark_running(unsafe=True, no_runner_required=True)
7408
8105
 
7409
- print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
7410
- global_vars["jobs"].append((new_job, trial_index))
8106
+ print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
8107
+ global_vars["jobs"].append((new_job, trial_index))
8108
+ else:
8109
+ print_red("Trial was none in orchestrator_start_trial")
8110
+ else:
8111
+ print_red("orchestrator_start_trial: Failed to start new job")
8112
+ elif ax_client:
8113
+ _fatal_error("submitit_executor could not be found properly", 9)
7411
8114
  else:
7412
- _fatal_error("submitit_executor or ax_client could not be found properly", 9)
8115
+ _fatal_error("ax_client could not be found properly", 101)
7413
8116
 
7414
8117
  def handle_exclude_node(stdout_path: str, hostname_from_out_file: Union[None, str]) -> None:
7415
8118
  stdout_path = check_alternate_path(stdout_path)
@@ -7428,7 +8131,7 @@ def handle_restart(stdout_path: str, trial_index: int) -> None:
7428
8131
  if parameters:
7429
8132
  orchestrator_start_trial(parameters, trial_index)
7430
8133
  else:
7431
- print(f"Could not determine parameters from outfile {stdout_path} for restarting job")
8134
+ print_red(f"Could not determine parameters from outfile {stdout_path} for restarting job")
7432
8135
 
7433
8136
  def check_alternate_path(path: str) -> str:
7434
8137
  if os.path.exists(path):
@@ -7515,7 +8218,7 @@ def execute_evaluation(_params: list) -> Optional[int]:
7515
8218
  print_debug(f"execute_evaluation({_params})")
7516
8219
  trial_index, parameters, trial_counter, phase = _params
7517
8220
  if not ax_client:
7518
- _fatal_error("Failed to get ax_client", 9)
8221
+ _fatal_error("Failed to get ax_client", 101)
7519
8222
 
7520
8223
  return None
7521
8224
 
@@ -7524,7 +8227,11 @@ def execute_evaluation(_params: list) -> Optional[int]:
7524
8227
 
7525
8228
  return None
7526
8229
 
7527
- _trial = ax_client.get_trial(trial_index)
8230
+ _trial = get_ax_client_trial(trial_index)
8231
+
8232
+ if _trial is None:
8233
+ error_without_print(f"execute_evaluation: _trial was not in execute_evaluation for params {_params}")
8234
+ return None
7528
8235
 
7529
8236
  def mark_trial_stage(stage: str, error_msg: str) -> None:
7530
8237
  try:
@@ -7539,15 +8246,18 @@ def execute_evaluation(_params: list) -> Optional[int]:
7539
8246
  try:
7540
8247
  initialize_job_environment()
7541
8248
  new_job = submit_new_job(parameters, trial_index)
7542
- submitted_jobs(1)
8249
+ if new_job:
8250
+ submitted_jobs(1)
7543
8251
 
7544
- print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
7545
- global_vars["jobs"].append((new_job, trial_index))
8252
+ print_debug(f"execute_evaluation: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
8253
+ global_vars["jobs"].append((new_job, trial_index))
7546
8254
 
7547
- mark_trial_stage("mark_running", "Marking the trial as running failed")
7548
- trial_counter += 1
8255
+ mark_trial_stage("mark_running", "Marking the trial as running failed")
8256
+ trial_counter += 1
7549
8257
 
7550
- progressbar_description("started new job")
8258
+ progressbar_description("started new job")
8259
+ else:
8260
+ progressbar_description("Failed to start new job")
7551
8261
  except submitit.core.utils.FailedJobError as error:
7552
8262
  handle_failed_job(error, trial_index, new_job)
7553
8263
  trial_counter += 1
@@ -7589,7 +8299,7 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
7589
8299
  my_exit(144)
7590
8300
 
7591
8301
  if new_job is None:
7592
- print_red("handle_failed_job: job is None")
8302
+ print_debug("handle_failed_job: job is None")
7593
8303
 
7594
8304
  return None
7595
8305
 
@@ -7600,16 +8310,24 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
7600
8310
 
7601
8311
  return None
7602
8312
 
8313
+ def log_ax_client_trial_failure(trial_index: int) -> None:
8314
+ if not ax_client:
8315
+ my_exit(101)
8316
+
8317
+ return
8318
+
8319
+ ax_client.log_trial_failure(trial_index=trial_index)
8320
+
7603
8321
  def cancel_failed_job(trial_index: int, new_job: Job) -> None:
7604
8322
  print_debug("Trying to cancel job that failed")
7605
8323
  if new_job:
7606
8324
  try:
7607
8325
  if ax_client:
7608
- ax_client.log_trial_failure(trial_index=trial_index)
8326
+ log_ax_client_trial_failure(trial_index)
7609
8327
  else:
7610
8328
  _fatal_error("ax_client not defined", 101)
7611
8329
  except Exception as e:
7612
- print(f"ERROR in line {get_line_info()}: {e}")
8330
+ print_red(f"ERROR in line {get_line_info()}: {e}")
7613
8331
  new_job.cancel()
7614
8332
 
7615
8333
  print_debug(f"cancel_failed_job: removing job {new_job}, trial_index: {trial_index}")
@@ -7645,10 +8363,12 @@ def show_debug_table_for_break_run_search(_name: str, _max_eval: Optional[int])
7645
8363
  ("failed_jobs()", failed_jobs()),
7646
8364
  ("count_done_jobs()", count_done_jobs()),
7647
8365
  ("_max_eval", _max_eval),
7648
- ("progress_bar.total", progress_bar.total),
7649
8366
  ("NR_INSERTED_JOBS", NR_INSERTED_JOBS)
7650
8367
  ]
7651
8368
 
8369
+ if progress_bar is not None:
8370
+ rows.append(("progress_bar.total", progress_bar.total))
8371
+
7652
8372
  for row in rows:
7653
8373
  table.add_row(str(row[0]), str(row[1]))
7654
8374
 
@@ -7661,6 +8381,8 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
7661
8381
  _submitted_jobs = submitted_jobs()
7662
8382
  _failed_jobs = failed_jobs()
7663
8383
 
8384
+ log_data()
8385
+
7664
8386
  max_failed_jobs = max_eval
7665
8387
 
7666
8388
  if args.max_failed_jobs is not None and args.max_failed_jobs > 0:
@@ -7688,11 +8410,11 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
7688
8410
  _ret = True
7689
8411
 
7690
8412
  if args.verbose_break_run_search_table:
7691
- show_debug_table_for_break_run_search(_name, _max_eval, _ret)
8413
+ show_debug_table_for_break_run_search(_name, _max_eval)
7692
8414
 
7693
8415
  return _ret
7694
8416
 
7695
- def _calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
8417
+ def calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
7696
8418
  """Calculates the number of jobs to retrieve."""
7697
8419
  return min(
7698
8420
  max_eval + simulated_jobs - count_done_jobs(),
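calculate_nr_of_jobs_to_get() (renamed from the underscore-prefixed version) returns a min() whose first bound is the remaining evaluation budget, max_eval + simulated_jobs - count_done_jobs(); the other bound is outside this hunk. A hedged sketch of that kind of budget calculation, with an assumed second bound on free worker slots:

def calculate_nr_of_jobs_to_get(
    max_eval: int,
    simulated_jobs: int,
    done_jobs: int,
    currently_running_jobs: int,
    max_parallelism: int,
) -> int:
    """Number of new trials to request, bounded by budget and (assumed) free slots."""
    remaining_budget = max_eval + simulated_jobs - done_jobs
    free_slots = max_parallelism - currently_running_jobs   # assumption, not in the hunk
    return max(0, min(remaining_budget, free_slots))


print(calculate_nr_of_jobs_to_get(100, 0, 90, 4, 8))   # -> 4
print(calculate_nr_of_jobs_to_get(100, 0, 98, 0, 8))   # -> 2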
@@ -7705,7 +8427,7 @@ def remove_extra_spaces(text: str) -> str:
7705
8427
  raise ValueError("Input must be a string")
7706
8428
  return re.sub(r'\s+', ' ', text).strip()
7707
8429
 
7708
- def _get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
8430
+ def get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
7709
8431
  """Generates the appropriate message for the number of trials being retrieved."""
7710
8432
  ret = ""
7711
8433
  if full_nr_of_jobs_to_get > 1:
@@ -7790,51 +8512,51 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
7790
8512
  print_red("get_batched_arms: ax_client was None")
7791
8513
  return []
7792
8514
 
7793
- # Load experiment state
7794
8515
  load_experiment_state()
7795
8516
 
7796
- while len(batched_arms) != nr_of_jobs_to_get:
8517
+ while len(batched_arms) < nr_of_jobs_to_get:
7797
8518
  if attempts > args.max_attempts_for_generation:
7798
8519
  print_debug(f"get_batched_arms: Stopped after {attempts} attempts: could not generate enough arms "
7799
8520
  f"(got {len(batched_arms)} out of {nr_of_jobs_to_get}).")
7800
8521
  break
7801
8522
 
7802
- remaining = nr_of_jobs_to_get - len(batched_arms)
7803
- print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
7804
-
7805
- print_debug("get pending observations")
7806
- t0 = time.time()
7807
- pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
7808
- dt = time.time() - t0
7809
- print_debug(f"got pending observations: {pending_observations} (took {dt:.2f} seconds)")
8523
+ #print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting 1 more arm")
7810
8524
 
7811
- print_debug("getting global_gs.gen()")
7812
- batched_generator_run = global_gs.gen(
8525
+ pending_observations = get_pending_observation_features(
7813
8526
  experiment=ax_client.experiment,
7814
- n=remaining,
7815
- pending_observations=pending_observations
8527
+ include_out_of_design_points=True
7816
8528
  )
7817
- print_debug(f"got global_gs.gen(): {batched_generator_run}")
8529
+
8530
+ try:
8531
+ #print_debug("getting global_gs.gen() with n=1")
8532
+ batched_generator_run: Any = global_gs.gen(
8533
+ experiment=ax_client.experiment,
8534
+ n=1,
8535
+ pending_observations=pending_observations,
8536
+ )
8537
+ print_debug(f"got global_gs.gen(): {batched_generator_run}")
8538
+ except Exception as e:
8539
+ print_debug(f"global_gs.gen failed: {e}")
8540
+ traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
8541
+ break
7818
8542
 
7819
8543
  depth = 0
7820
8544
  path = "batched_generator_run"
7821
- while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) > 0:
7822
- print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
8545
+ while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
8546
+ #print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
7823
8547
  batched_generator_run = batched_generator_run[0]
7824
8548
  path += "[0]"
7825
8549
  depth += 1
7826
8550
 
7827
- print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
8551
+ #print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
7828
8552
 
7829
- print_debug("got new arms")
7830
- new_arms = batched_generator_run.arms
7831
- print_debug(f"new_arms: {new_arms}")
8553
+ new_arms = getattr(batched_generator_run, "arms", [])
7832
8554
  if not new_arms:
7833
8555
  print_debug("get_batched_arms: No new arms were generated in this attempt.")
7834
8556
  else:
7835
- print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), wanted {nr_of_jobs_to_get}.")
8557
+ print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), now at {len(batched_arms) + len(new_arms)} of {nr_of_jobs_to_get}.")
8558
+ batched_arms.extend(new_arms)
7836
8559
 
7837
- batched_arms.extend(new_arms)
7838
8560
  attempts += 1
7839
8561
 
7840
8562
  print_debug(f"get_batched_arms: Finished with {len(batched_arms)} arm(s) after {attempts} attempt(s).")
@@ -7845,7 +8567,7 @@ def fetch_next_trials(nr_of_jobs_to_get: int, recursion: bool = False) -> Tuple[
7845
8567
  die_101_if_no_ax_client_or_experiment_or_gs()
7846
8568
 
7847
8569
  if not ax_client:
7848
- _fatal_error("ax_client was not defined", 9)
8570
+ _fatal_error("ax_client was not defined", 101)
7849
8571
 
7850
8572
  if global_gs is None:
7851
8573
  _fatal_error("Global generation strategy is not set. This is a bug in OmniOpt2.", 107)
@@ -7873,16 +8595,14 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
7873
8595
  retries += 1
7874
8596
  continue
7875
8597
 
7876
- print_debug(f"Fetching trial {cnt + 1}/{n}...")
7877
- progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
8598
+ progressbar_description(get_trials_message(cnt + 1, n, trial_durations))
7878
8599
 
7879
8600
  try:
7880
8601
  result = create_and_handle_trial(arm)
7881
8602
  if result is not None:
7882
8603
  trial_index, trial_duration, trial_successful = result
7883
-
7884
8604
  except TrialRejected as e:
7885
- print_debug(f"Trial rejected: {e}")
8605
+ print_debug(f"generate_trials: Trial rejected, error: {e}")
7886
8606
  retries += 1
7887
8607
  continue
7888
8608
 
@@ -7892,14 +8612,23 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
7892
8612
  cnt += 1
7893
8613
  trials_dict[trial_index] = arm.parameters
7894
8614
 
7895
- return _finalize_generation(trials_dict, cnt, n, start_time)
8615
+ finalized = finalize_generation(trials_dict, cnt, n, start_time)
8616
+
8617
+ return finalized
7896
8618
 
7897
8619
  except Exception as e:
7898
- return _handle_generation_failure(e, n, recursion)
8620
+ return handle_generation_failure(e, n, recursion)
7899
8621
 
7900
8622
  class TrialRejected(Exception):
7901
8623
  pass
7902
8624
 
8625
+ def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
8626
+ try:
8627
+ print_debug(f"[INFO] Marking trial {trial.index} ({trial.arm.name}) as abandoned, trial-index: {trial_index}. Reason: {reason}")
8628
+ trial.mark_abandoned(reason)
8629
+ except Exception as e:
8630
+ print_red(f"[ERROR] Could not mark trial as abandoned: {e}")
8631
+
7903
8632
  def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
7904
8633
  if ax_client is None:
7905
8634
  print_red("ax_client is None in create_and_handle_trial")
@@ -7925,7 +8654,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
7925
8654
  arm = trial.arms[0]
7926
8655
  if deduplicated_arm(arm):
7927
8656
  print_debug(f"Duplicated arm: {arm}")
7928
- trial.mark_abandoned(reason="Duplication detected")
8657
+ mark_abandoned(trial, "Duplication detected", trial_index)
7929
8658
  raise TrialRejected("Duplicate arm.")
7930
8659
 
7931
8660
  arms_by_name_for_deduplication[arm.name] = arm
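Both rejection paths (duplicate arms and, below, violated post-generation constraints) now go through the new mark_abandoned() helper, which logs the reason and swallows any exception raised by trial.mark_abandoned(). A compact sketch of that pattern with a dummy trial object; the real Ax trial API is richer than this:

from typing import Any


class DummyTrial:
    """Stand-in with the one method the helper needs."""

    def __init__(self, index: int) -> None:
        self.index = index

    def mark_abandoned(self, reason: str) -> None:
        print(f"trial {self.index} abandoned: {reason}")


def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
    try:
        print(f"[INFO] Marking trial {trial.index} as abandoned, "
              f"trial-index: {trial_index}. Reason: {reason}")
        trial.mark_abandoned(reason)
    except Exception as e:
        # Abandoning is best-effort bookkeeping; a failure here must not
        # abort the optimization run itself.
        print(f"[ERROR] Could not mark trial as abandoned: {e}")


mark_abandoned(DummyTrial(7), "Duplication detected", trial_index=7)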
@@ -7934,7 +8663,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
7934
8663
 
7935
8664
  if not has_no_post_generation_constraints_or_matches_constraints(post_generation_constraints, params):
7936
8665
  print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned. Params: {params}, constraints: {post_generation_constraints}")
7937
- trial.mark_abandoned(reason="Post-Generation-Constraint failed")
8666
+ mark_abandoned(trial, "Post-Generation-Constraint failed", trial_index)
7938
8667
  abandoned_trial_indices.append(trial_index)
7939
8668
  raise TrialRejected("Post-generation constraints not met.")
7940
8669
 
@@ -7942,7 +8671,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
7942
8671
  end = time.time()
7943
8672
  return trial_index, float(end - start), True
7944
8673
 
7945
- def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
8674
+ def finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
7946
8675
  total_time = time.time() - start_time
7947
8676
 
7948
8677
  log_gen_times.append(total_time)
@@ -7953,7 +8682,7 @@ def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int,
7953
8682
 
7954
8683
  return trials_dict, False
7955
8684
 
7956
- def _handle_generation_failure(
8685
+ def handle_generation_failure(
7957
8686
  e: Exception,
7958
8687
  requested: int,
7959
8688
  recursion: bool
@@ -7969,19 +8698,19 @@ def _handle_generation_failure(
7969
8698
  )):
7970
8699
  msg = str(e)
7971
8700
  if msg not in error_8_saved:
7972
- _print_exhaustion_warning(e, recursion)
8701
+ print_exhaustion_warning(e, recursion)
7973
8702
  error_8_saved.append(msg)
7974
8703
 
7975
8704
  if not recursion and args.revert_to_random_when_seemingly_exhausted:
7976
8705
  print_debug("Switching to random search strategy.")
7977
- set_global_gs_to_random()
8706
+ set_global_gs_to_sobol()
7978
8707
  return fetch_next_trials(requested, True)
7979
8708
 
7980
- print_red(f"_handle_generation_failure: General Exception: {e}")
8709
+ print_red(f"handle_generation_failure: General Exception: {e}")
7981
8710
 
7982
8711
  return {}, True
7983
8712
 
7984
- def _print_exhaustion_warning(e: Exception, recursion: bool) -> None:
8713
+ def print_exhaustion_warning(e: Exception, recursion: bool) -> None:
7985
8714
  if not recursion and args.revert_to_random_when_seemingly_exhausted:
7986
8715
  print_yellow(f"\n⚠Error 8: {e} From now (done jobs: {count_done_jobs()}) on, random points will be generated.")
7987
8716
  else:
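handle_generation_failure() (the underscore prefix is dropped) keeps its behaviour: exhaustion-type errors are warned about once per distinct message, and when --revert_to_random_when_seemingly_exhausted is set and this is not already the recursive retry, the strategy is swapped for Sobol and the fetch is retried. A stripped-down sketch of that warn-once-then-fall-back control flow; the exception type and function names here are simplified stand-ins:

from typing import Dict, List, Tuple

error_messages_seen: List[str] = []


class SearchSpaceExhausted(Exception):
    pass


def switch_to_random_strategy() -> None:
    print("Reverting to a random (Sobol-style) generator")


def fetch_trials(n: int, recursion: bool = False) -> Tuple[Dict[int, dict], bool]:
    # Stand-in: pretend the modeled strategy is exhausted, but random always works.
    if not recursion:
        raise SearchSpaceExhausted("no more candidates in the modeled region")
    return {i: {"x": i} for i in range(n)}, False


def handle_generation_failure(e: Exception, requested: int, recursion: bool):
    if isinstance(e, SearchSpaceExhausted):
        msg = str(e)
        if msg not in error_messages_seen:
            print(f"warning (shown once): {msg}")
            error_messages_seen.append(msg)
        if not recursion:
            switch_to_random_strategy()
            return fetch_trials(requested, recursion=True)
    print(f"general exception: {e}")
    return {}, True


try:
    trials = fetch_trials(3)
except Exception as exc:
    trials = handle_generation_failure(exc, 3, recursion=False)
print(trials)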
@@ -8026,21 +8755,24 @@ def get_model_gen_kwargs() -> dict:
8026
8755
  "fit_out_of_design": args.fit_out_of_design
8027
8756
  }
8028
8757
 
8029
- def set_global_gs_to_random() -> None:
8758
+ def set_global_gs_to_sobol() -> None:
8030
8759
  global global_gs
8031
8760
  global overwritten_to_random
8032
8761
 
8762
+ print("Reverting to SOBOL")
8763
+
8033
8764
  global_gs = GenerationStrategy(
8034
8765
  name="Random*",
8035
8766
  nodes=[
8036
8767
  GenerationNode(
8037
8768
  node_name="Sobol",
8038
- generator_specs=[
8039
- GeneratorSpec(
8040
- Models.SOBOL,
8041
- model_gen_kwargs=get_model_gen_kwargs()
8042
- )
8043
- ]
8769
+ should_deduplicate=True,
8770
+ generator_specs=[ # type: ignore[arg-type]
8771
+ GeneratorSpec( # type: ignore[arg-type]
8772
+ Models.SOBOL, # type: ignore[arg-type]
8773
+ model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
8774
+ ) # type: ignore[arg-type]
8775
+ ] # type: ignore[arg-type]
8044
8776
  )
8045
8777
  ]
8046
8778
  )
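set_global_gs_to_sobol() (formerly set_global_gs_to_random) rebuilds the global strategy as a single Sobol GenerationNode with should_deduplicate=True. To illustrate what falling back to Sobol buys — a deterministic, well-spread quasi-random sweep of the search space — here is a sketch using SciPy's Sobol sampler rather than Ax's GenerationStrategy machinery; the parameter bounds are made up:

# Illustration only: quasi-random Sobol points over a small search space.
# Requires scipy>=1.7; this is not the package's Ax-based implementation.
from scipy.stats import qmc

search_space = {"lr": (1e-5, 1e-1), "batch_size": (8, 256)}

sampler = qmc.Sobol(d=len(search_space), scramble=True, seed=42)
unit_points = sampler.random(8)                       # 8 points in [0, 1)^d
lows = [lo for lo, _ in search_space.values()]
highs = [hi for _, hi in search_space.values()]
points = qmc.scale(unit_points, lows, highs)          # map into the real bounds

for row in points:
    print(dict(zip(search_space.keys(), row)))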
@@ -8400,16 +9132,21 @@ def get_model_from_name(name: str) -> Any:
8400
9132
  return gen
8401
9133
  raise ValueError(f"Unknown or unsupported model: {name}")
8402
9134
 
8403
- def get_name_from_model(model) -> Optional[str]:
9135
+ def get_name_from_model(model: Any) -> str:
8404
9136
  if not isinstance(SUPPORTED_MODELS, (list, set, tuple)):
8405
- return None
9137
+ raise RuntimeError("get_model_from_name: SUPPORTED_MODELS was not a list, set or tuple. Cannot continue")
8406
9138
 
8407
9139
  model_str = model.value if hasattr(model, "value") else str(model)
8408
9140
 
8409
9141
  model_str_lower = model_str.lower()
8410
9142
  model_map = {m.lower(): m for m in SUPPORTED_MODELS}
8411
9143
 
8412
- return model_map.get(model_str_lower, None)
9144
+ ret = model_map.get(model_str_lower, None)
9145
+
9146
+ if ret is None:
9147
+ raise RuntimeError("get_name_from_model: failed to get Model")
9148
+
9149
+ return ret
8413
9150
 
8414
9151
  def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str, int]], int]:
8415
9152
  gen_strat_list = []
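get_name_from_model() now returns a plain str and raises instead of handing back None when the lookup fails, so callers no longer need an Optional check. A tiny sketch of that case-insensitive lookup-or-raise pattern with a hypothetical supported-model list:

SUPPORTED_MODELS = ["SOBOL", "BOTORCH_MODULAR", "SAASBO"]  # hypothetical values


def get_name_from_model(model: str) -> str:
    model_map = {m.lower(): m for m in SUPPORTED_MODELS}
    ret = model_map.get(model.lower())
    if ret is None:
        # Failing loudly here beats propagating None into the strategy setup.
        raise RuntimeError(f"get_name_from_model: unknown model '{model}'")
    return ret


print(get_name_from_model("sobol"))        # -> "SOBOL"
# get_name_from_model("gp_mes")            # would raise RuntimeError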
@@ -8420,10 +9157,10 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
8420
9157
 
8421
9158
  for s in splitted_by_comma:
8422
9159
  if "=" not in s:
8423
- print(f"'{s}' does not contain '='")
9160
+ print_red(f"'{s}' does not contain '='")
8424
9161
  my_exit(123)
8425
9162
  if s.count("=") != 1:
8426
- print(f"There can only be one '=' in the gen_strat_str's element '{s}'")
9163
+ print_red(f"There can only be one '=' in the gen_strat_str's element '{s}'")
8427
9164
  my_exit(123)
8428
9165
 
8429
9166
  model_name, nr_str = s.split("=")
@@ -8433,13 +9170,13 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
8433
9170
  _fatal_error(f"Model {matching_model} is not valid for custom generation strategy.", 56)
8434
9171
 
8435
9172
  if not matching_model:
8436
- print(f"'{model_name}' not found in SUPPORTED_MODELS")
9173
+ print_red(f"'{model_name}' not found in SUPPORTED_MODELS")
8437
9174
  my_exit(123)
8438
9175
 
8439
9176
  try:
8440
9177
  nr = int(nr_str)
8441
9178
  except ValueError:
8442
- print(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
9179
+ print_red(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
8443
9180
  my_exit(123)
8444
9181
 
8445
9182
  gen_strat_list.append({matching_model: nr})
@@ -8596,11 +9333,12 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
8596
9333
  if model_name.lower() != "sobol":
8597
9334
  kwargs["model_kwargs"] = get_model_kwargs()
8598
9335
 
8599
- model_spec = [GeneratorSpec(selected_model, **kwargs)]
9336
+ model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]
8600
9337
 
8601
9338
  res = GenerationNode(
8602
9339
  node_name=model_name,
8603
9340
  generator_specs=model_spec,
9341
+ should_deduplicate=True,
8604
9342
  transition_criteria=trans_crit
8605
9343
  )
8606
9344
 
@@ -8611,7 +9349,7 @@ def get_optimizer_kwargs() -> dict:
8611
9349
  "sequential": False
8612
9350
  }
8613
9351
 
8614
- def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = None) -> GenerationStep:
9352
+ def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
8615
9353
  model_enum = get_model_from_name(model_name)
8616
9354
 
8617
9355
  return GenerationStep(
@@ -8626,17 +9364,16 @@ def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = N
8626
9364
  )
8627
9365
 
8628
9366
  def set_global_generation_strategy() -> None:
8629
- with spinner("Setting global generation strategy"):
8630
- continue_not_supported_on_custom_generation_strategy()
9367
+ continue_not_supported_on_custom_generation_strategy()
8631
9368
 
8632
- try:
8633
- if args.generation_strategy is None:
8634
- setup_default_generation_strategy()
8635
- else:
8636
- setup_custom_generation_strategy()
8637
- except Exception as e:
8638
- print_red(f"Unexpected error in generation strategy setup: {e}")
8639
- my_exit(111)
9369
+ try:
9370
+ if args.generation_strategy is None:
9371
+ setup_default_generation_strategy()
9372
+ else:
9373
+ setup_custom_generation_strategy()
9374
+ except Exception as e:
9375
+ print_red(f"Unexpected error in generation strategy setup: {e}")
9376
+ my_exit(111)
8640
9377
 
8641
9378
  if global_gs is None:
8642
9379
  print_red("global_gs is None after setup!")
@@ -8787,10 +9524,14 @@ def execute_trials(
8787
9524
  result = future.result()
8788
9525
  print_debug(f"result in execute_trials: {result}")
8789
9526
  except Exception as exc:
8790
- print_red(f"execute_trials: Error at executing a trial: {exc}")
9527
+ failed_args = future_to_args[future]
9528
+ print_red(f"execute_trials: Error at executing a trial with args {failed_args}: {exc}")
9529
+ traceback.print_exc()
8791
9530
 
8792
9531
  end_time = time.time()
8793
9532
 
9533
+ log_data()
9534
+
8794
9535
  duration = float(end_time - start_time)
8795
9536
  job_submit_durations.append(duration)
8796
9537
  job_submit_nrs.append(cnt)
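execute_trials() now looks the failing future up in a future-to-arguments mapping so the error message can say which trial blew up, and prints the full traceback instead of only the exception text. A self-contained sketch of that future_to_args bookkeeping with concurrent.futures; the worker function is made up for the example:

import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_trial(trial_index: int, params: dict) -> str:
    if params.get("x", 0) < 0:
        raise ValueError(f"negative x for trial {trial_index}")
    return f"trial {trial_index} ok"


trial_args = [(0, {"x": 1}), (1, {"x": -3}), (2, {"x": 2})]

with ThreadPoolExecutor(max_workers=2) as pool:
    # Keep a handle on the submitted arguments so a failure can be attributed.
    future_to_args = {pool.submit(run_trial, i, p): (i, p) for i, p in trial_args}

    for future in as_completed(future_to_args):
        try:
            print(future.result())
        except Exception as exc:
            failed_args = future_to_args[future]
            print(f"Error at executing a trial with args {failed_args}: {exc}")
            traceback.print_exc()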
@@ -8824,20 +9565,20 @@ def create_and_execute_next_runs(next_nr_steps: int, phase: Optional[str], _max_
8824
9565
  done_optimizing: bool = False
8825
9566
 
8826
9567
  try:
8827
- done_optimizing, trial_index_to_param = _create_and_execute_next_runs_run_loop(_max_eval, phase)
8828
- _create_and_execute_next_runs_finish(done_optimizing)
9568
+ done_optimizing, trial_index_to_param = create_and_execute_next_runs_run_loop(_max_eval, phase)
9569
+ create_and_execute_next_runs_finish(done_optimizing)
8829
9570
  except Exception as e:
8830
9571
  stacktrace = traceback.format_exc()
8831
9572
  print_debug(f"Warning: create_and_execute_next_runs encountered an exception: {e}\n{stacktrace}")
8832
9573
  return handle_exceptions_create_and_execute_next_runs(e)
8833
9574
 
8834
- return _create_and_execute_next_runs_return_value(trial_index_to_param)
9575
+ return create_and_execute_next_runs_return_value(trial_index_to_param)
8835
9576
 
8836
- def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
9577
+ def create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
8837
9578
  done_optimizing = False
8838
9579
  trial_index_to_param: Optional[Dict] = None
8839
9580
 
8840
- nr_of_jobs_to_get = _calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
9581
+ nr_of_jobs_to_get = calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
8841
9582
 
8842
9583
  __max_eval = _max_eval if _max_eval is not None else 0
8843
9584
  new_nr_of_jobs_to_get = min(__max_eval - (submitted_jobs() - failed_jobs()), nr_of_jobs_to_get)
@@ -8851,6 +9592,7 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
8851
9592
 
8852
9593
  for _ in range(range_nr):
8853
9594
  trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
9595
+ log_data()
8854
9596
  if done_optimizing:
8855
9597
  continue
8856
9598
 
@@ -8875,13 +9617,13 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
8875
9617
 
8876
9618
  return done_optimizing, trial_index_to_param
8877
9619
 
8878
- def _create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
9620
+ def create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
8879
9621
  finish_previous_jobs(["finishing jobs"])
8880
9622
 
8881
9623
  if done_optimizing:
8882
9624
  end_program(False, 0)
8883
9625
 
8884
- def _create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
9626
+ def create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
8885
9627
  try:
8886
9628
  if trial_index_to_param:
8887
9629
  res = len(trial_index_to_param.keys())
@@ -8992,7 +9734,7 @@ def execute_nvidia_smi() -> None:
8992
9734
  if not host:
8993
9735
  print_debug("host not defined")
8994
9736
  except Exception as e:
8995
- print(f"execute_nvidia_smi: An error occurred: {e}")
9737
+ print_red(f"execute_nvidia_smi: An error occurred: {e}")
8996
9738
  if is_slurm_job() and not args.force_local_execution:
8997
9739
  _sleep(30)
8998
9740
 
@@ -9030,11 +9772,6 @@ def run_search() -> bool:
9030
9772
 
9031
9773
  return False
9032
9774
 
9033
- async def start_logging_daemon() -> None:
9034
- while True:
9035
- log_data()
9036
- time.sleep(30)
9037
-
9038
9775
  def should_break_search() -> bool:
9039
9776
  ret = False
9040
9777
 
@@ -9083,10 +9820,8 @@ def check_search_space_exhaustion(nr_of_items: int) -> bool:
9083
9820
  print_debug(_wrn)
9084
9821
  progressbar_description(_wrn)
9085
9822
 
9086
- live_share()
9087
9823
  return True
9088
9824
 
9089
- live_share()
9090
9825
  return False
9091
9826
 
9092
9827
  def finalize_jobs() -> None:
@@ -9100,7 +9835,7 @@ def finalize_jobs() -> None:
9100
9835
  handle_slurm_execution()
9101
9836
 
9102
9837
  def go_through_jobs_that_are_not_completed_yet() -> None:
9103
- print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
9838
+ #print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
9104
9839
 
9105
9840
  nr_jobs_left = len(global_vars['jobs'])
9106
9841
  if nr_jobs_left == 1:
@@ -9170,7 +9905,7 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
9170
9905
 
9171
9906
  return data
9172
9907
  except Exception as e:
9173
- print(f"Error while parse_experiment_parameters({_f}): {e}")
9908
+ print_red(f"Error while parse_experiment_parameters({_f}): {e}")
9174
9909
  else:
9175
9910
  print_red(f"{_f} could not be found")
9176
9911
 
@@ -9178,364 +9913,52 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
9178
9913
 
9179
9914
  def set_orchestrator() -> None:
9180
9915
  with spinner("Setting orchestrator..."):
9181
- global orchestrator
9182
-
9183
- if args.orchestrator_file:
9184
- if SYSTEM_HAS_SBATCH:
9185
- orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
9186
- else:
9187
- print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
9188
-
9189
- def check_if_has_random_steps() -> None:
9190
- if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
9191
- _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
9192
-
9193
- def add_exclude_to_defective_nodes() -> None:
9194
- with spinner("Adding excluded nodes..."):
9195
- if args.exclude:
9196
- entries = [entry.strip() for entry in args.exclude.split(',')]
9197
-
9198
- for entry in entries:
9199
- count_defective_nodes(None, entry)
9200
-
9201
- def check_max_eval(_max_eval: int) -> None:
9202
- with spinner("Checking max_eval..."):
9203
- if not _max_eval:
9204
- _fatal_error("--max_eval needs to be set!", 19)
9205
-
9206
- def parse_parameters() -> Any:
9207
- cli_params_experiment_parameters = None
9208
- if args.parameter:
9209
- parse_experiment_parameters()
9210
- cli_params_experiment_parameters = experiment_parameters
9211
-
9212
- return cli_params_experiment_parameters
9213
-
9214
- def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
9215
- table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
9216
-
9217
- rows = _pareto_front_table_read_csv()
9218
- if not rows:
9219
- table.add_column("No data found")
9220
- return table
9221
-
9222
- filtered_rows = _pareto_front_table_filter_rows(rows, idxs)
9223
- if not filtered_rows:
9224
- table.add_column("No matching entries")
9225
- return table
9226
-
9227
- param_cols, result_cols = _pareto_front_table_get_columns(filtered_rows[0])
9228
-
9229
- _pareto_front_table_add_headers(table, param_cols, result_cols)
9230
- _pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
9231
-
9232
- return table
9233
-
9234
- def _pareto_front_table_read_csv() -> List[Dict[str, str]]:
9235
- with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
9236
- return list(csv.DictReader(f))
9237
-
9238
- def _pareto_front_table_filter_rows(rows: List[Dict[str, str]], idxs: List[int]) -> List[Dict[str, str]]:
9239
- result = []
9240
- for row in rows:
9241
- try:
9242
- trial_index = int(row["trial_index"])
9243
- except (KeyError, ValueError):
9244
- continue
9245
-
9246
- if row.get("trial_status", "").strip().upper() == "COMPLETED" and trial_index in idxs:
9247
- result.append(row)
9248
- return result
9249
-
9250
- def _pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
9251
- all_columns = list(first_row.keys())
9252
- ignored_cols = set(special_col_names) - {"trial_index"}
9253
-
9254
- param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
9255
- result_cols = [col for col in arg_result_names if col in all_columns]
9256
- return param_cols, result_cols
9257
-
9258
- def _pareto_front_table_add_headers(table: Table, param_cols: List[str], result_cols: List[str]) -> None:
9259
- for col in param_cols:
9260
- table.add_column(col, justify="center")
9261
- for col in result_cols:
9262
- table.add_column(Text(f"{col}", style="cyan"), justify="center")
9263
-
9264
- def _pareto_front_table_add_rows(table: Table, rows: List[Dict[str, str]], param_cols: List[str], result_cols: List[str]) -> None:
9265
- for row in rows:
9266
- values = [str(helpers.to_int_when_possible(row[col])) for col in param_cols]
9267
- result_values = [Text(str(helpers.to_int_when_possible(row[col])), style="cyan") for col in result_cols]
9268
- table.add_row(*values, *result_values, style="bold green")
9269
-
9270
- def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
9271
- if not os.path.exists(RESULT_CSV_FILE):
9272
- print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
9273
- return None
9274
-
9275
- return create_pareto_front_table(idxs, metric_x, metric_y)
9276
-
9277
- def supports_sixel() -> bool:
9278
- term = os.environ.get("TERM", "").lower()
9279
- if "xterm" in term or "mlterm" in term:
9280
- return True
9281
-
9282
- try:
9283
- output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
9284
- if output.returncode == 0 and "sixel" in output.stdout.lower():
9285
- return True
9286
- except (subprocess.CalledProcessError, FileNotFoundError):
9287
- pass
9288
-
9289
- return False
9290
-
9291
- def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
9292
- if data is None:
9293
- print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
9294
- return
9295
-
9296
- if not supports_sixel():
9297
- print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
9298
- return
9299
-
9300
- import matplotlib.pyplot as plt
9301
-
9302
- means = data[x_metric][y_metric]["means"]
9303
-
9304
- x_values = means[x_metric]
9305
- y_values = means[y_metric]
9306
-
9307
- fig, _ax = plt.subplots()
9308
-
9309
- _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
9310
-
9311
- _ax.set_xlabel(x_metric)
9312
- _ax.set_ylabel(y_metric)
9313
-
9314
- _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
9315
-
9316
- _ax.ticklabel_format(style='plain', axis='both', useOffset=False)
9317
-
9318
- with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
9319
- plt.savefig(tmp_file.name, dpi=300)
9320
-
9321
- print_image_to_cli(tmp_file.name, 1000)
9322
-
9323
- plt.close(fig)
9324
-
9325
- def _pareto_front_general_validate_shapes(x: np.ndarray, y: np.ndarray) -> None:
9326
- if x.shape != y.shape:
9327
- raise ValueError("Input arrays x and y must have the same shape.")
9328
-
9329
- def _pareto_front_general_compare(
9330
- xi: float, yi: float, xj: float, yj: float,
9331
- x_minimize: bool, y_minimize: bool
9332
- ) -> bool:
9333
- x_better_eq = xj <= xi if x_minimize else xj >= xi
9334
- y_better_eq = yj <= yi if y_minimize else yj >= yi
9335
- x_strictly_better = xj < xi if x_minimize else xj > xi
9336
- y_strictly_better = yj < yi if y_minimize else yj > yi
9337
-
9338
- return bool(x_better_eq and y_better_eq and (x_strictly_better or y_strictly_better))
9339
-
9340
- def _pareto_front_general_find_dominated(
9341
- x: np.ndarray, y: np.ndarray, x_minimize: bool, y_minimize: bool
9342
- ) -> np.ndarray:
9343
- num_points = len(x)
9344
- is_dominated = np.zeros(num_points, dtype=bool)
9345
-
9346
- for i in range(num_points):
9347
- for j in range(num_points):
9348
- if i == j:
9349
- continue
9350
-
9351
- if _pareto_front_general_compare(x[i], y[i], x[j], y[j], x_minimize, y_minimize):
9352
- is_dominated[i] = True
9353
- break
9354
-
9355
- return is_dominated
9356
-
9357
- def pareto_front_general(
9358
- x: np.ndarray,
9359
- y: np.ndarray,
9360
- x_minimize: bool = True,
9361
- y_minimize: bool = True
9362
- ) -> np.ndarray:
9363
- try:
9364
- _pareto_front_general_validate_shapes(x, y)
9365
- is_dominated = _pareto_front_general_find_dominated(x, y, x_minimize, y_minimize)
9366
- return np.where(~is_dominated)[0]
9367
- except Exception as e:
9368
- print("Error in pareto_front_general:", str(e))
9369
- return np.array([], dtype=int)
9370
-
9371
- def _pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
9372
- results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
9373
- result_names_file = f"{path_to_calculate}/result_names.txt"
9374
-
9375
- if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
9376
- return None
9377
-
9378
- with open(result_names_file, mode="r", encoding="utf-8") as f:
9379
- result_names = [line.strip() for line in f if line.strip()]
9380
-
9381
- records: dict = defaultdict(lambda: {'means': {}})
9382
-
9383
- with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
9384
- reader = csv.DictReader(csvfile)
9385
- for row in reader:
9386
- trial_index = int(row['trial_index'])
9387
- arm_name = row['arm_name']
9388
- key = (trial_index, arm_name)
9389
-
9390
- for metric in result_names:
9391
- if metric in row:
9392
- try:
9393
- records[key]['means'][metric] = float(row[metric])
9394
- except ValueError:
9395
- continue
9396
-
9397
- return records
9398
-
9399
- def _pareto_front_filter_complete_points(
9400
- path_to_calculate: str,
9401
- records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
9402
- primary_name: str,
9403
- secondary_name: str
9404
- ) -> List[Tuple[Tuple[int, str], float, float]]:
9405
- points = []
9406
- for key, metrics in records.items():
9407
- means = metrics['means']
9408
- if primary_name in means and secondary_name in means:
9409
- x_val = means[primary_name]
9410
- y_val = means[secondary_name]
9411
- points.append((key, x_val, y_val))
9412
- if len(points) == 0:
9413
- raise ValueError(f"No full data points with both objectives found in {path_to_calculate}.")
9414
- return points
9415
-
9416
- def _pareto_front_transform_objectives(
9417
- points: List[Tuple[Any, float, float]],
9418
- primary_name: str,
9419
- secondary_name: str
9420
- ) -> Tuple[np.ndarray, np.ndarray]:
9421
- primary_idx = arg_result_names.index(primary_name)
9422
- secondary_idx = arg_result_names.index(secondary_name)
9423
-
9424
- x = np.array([p[1] for p in points])
9425
- y = np.array([p[2] for p in points])
9426
-
9427
- if arg_result_min_or_max[primary_idx] == "max":
9428
- x = -x
9429
- elif arg_result_min_or_max[primary_idx] != "min":
9430
- raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
9431
-
9432
- if arg_result_min_or_max[secondary_idx] == "max":
9433
- y = -y
9434
- elif arg_result_min_or_max[secondary_idx] != "min":
9435
- raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
9436
-
9437
- return x, y
9438
-
9439
- def _pareto_front_select_pareto_points(
9440
- x: np.ndarray,
9441
- y: np.ndarray,
9442
- x_minimize: bool,
9443
- y_minimize: bool,
9444
- points: List[Tuple[Any, float, float]],
9445
- num_points: int
9446
- ) -> List[Tuple[Any, float, float]]:
9447
- indices = pareto_front_general(x, y, x_minimize, y_minimize)
9448
- sorted_indices = indices[np.argsort(x[indices])]
9449
- sorted_indices = sorted_indices[:num_points]
9450
- selected_points = [points[i] for i in sorted_indices]
9451
- return selected_points
9452
-
9453
- def _pareto_front_build_return_structure(
9454
- path_to_calculate: str,
9455
- selected_points: List[Tuple[Any, float, float]],
9456
- records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
9457
- absolute_metrics: List[str],
9458
- primary_name: str,
9459
- secondary_name: str
9460
- ) -> dict:
9461
- results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
9462
- result_names_file = f"{path_to_calculate}/result_names.txt"
9463
-
9464
- with open(result_names_file, mode="r", encoding="utf-8") as f:
9465
- result_names = [line.strip() for line in f if line.strip()]
9466
-
9467
- csv_rows = {}
9468
- with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
9469
- reader = csv.DictReader(csvfile)
9470
- for row in reader:
9471
- trial_index = int(row['trial_index'])
9472
- csv_rows[trial_index] = row
9473
-
9474
- ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
9475
- ignored_columns.update(result_names)
9476
-
9477
- param_dicts = []
9478
- idxs = []
9479
- means_dict = defaultdict(list)
9480
-
9481
- for (trial_index, arm_name), _, _ in selected_points:
9482
- row = csv_rows.get(trial_index, {})
9483
- if row == {} or row is None or row['arm_name'] != arm_name:
9484
- print_debug(f"_pareto_front_build_return_structure: trial_index '{trial_index}' could not be found and row returned as None")
9485
- continue
9916
+ global orchestrator
9486
9917
 
9487
- idxs.append(int(row["trial_index"]))
9918
+ if args.orchestrator_file:
9919
+ if SYSTEM_HAS_SBATCH:
9920
+ orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
9921
+ else:
9922
+ print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
9488
9923
 
9489
- param_dict: dict[str, int | float | str] = {}
9490
- for key, value in row.items():
9491
- if key not in ignored_columns:
9492
- try:
9493
- param_dict[key] = int(value)
9494
- except ValueError:
9495
- try:
9496
- param_dict[key] = float(value)
9497
- except ValueError:
9498
- param_dict[key] = value
9924
+ def check_if_has_random_steps() -> None:
9925
+ if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
9926
+ _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
9499
9927
 
9500
- param_dicts.append(param_dict)
9928
+ def add_exclude_to_defective_nodes() -> None:
9929
+ with spinner("Adding excluded nodes..."):
9930
+ if args.exclude:
9931
+ entries = [entry.strip() for entry in args.exclude.split(',')]
9501
9932
 
9502
- for metric in absolute_metrics:
9503
- means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
9933
+ for entry in entries:
9934
+ count_defective_nodes(None, entry)
9504
9935
 
9505
- ret = {
9506
- primary_name: {
9507
- secondary_name: {
9508
- "absolute_metrics": absolute_metrics,
9509
- "param_dicts": param_dicts,
9510
- "means": dict(means_dict),
9511
- "idxs": idxs
9512
- },
9513
- "absolute_metrics": absolute_metrics
9514
- }
9515
- }
9936
+ def check_max_eval(_max_eval: int) -> None:
9937
+ with spinner("Checking max_eval..."):
9938
+ if not _max_eval:
9939
+ _fatal_error("--max_eval needs to be set!", 19)
9516
9940
 
9517
- return ret
9941
+ def parse_parameters() -> Any:
9942
+ cli_params_experiment_parameters = None
9943
+ if args.parameter:
9944
+ parse_experiment_parameters()
9945
+ cli_params_experiment_parameters = experiment_parameters
9518
9946
 
9519
- def get_pareto_frontier_points(
9520
- path_to_calculate: str,
9521
- primary_objective: str,
9522
- secondary_objective: str,
9523
- x_minimize: bool,
9524
- y_minimize: bool,
9525
- absolute_metrics: List[str],
9526
- num_points: int
9527
- ) -> Optional[dict]:
9528
- records = _pareto_front_aggregate_data(path_to_calculate)
9947
+ return cli_params_experiment_parameters
9529
9948
 
9530
- if records is None:
9531
- return None
9949
+ def supports_sixel() -> bool:
9950
+ term = os.environ.get("TERM", "").lower()
9951
+ if "xterm" in term or "mlterm" in term:
9952
+ return True
9532
9953
 
9533
- points = _pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
9534
- x, y = _pareto_front_transform_objectives(points, primary_objective, secondary_objective)
9535
- selected_points = _pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
9536
- result = _pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
9954
+ try:
9955
+ output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
9956
+ if output.returncode == 0 and "sixel" in output.stdout.lower():
9957
+ return True
9958
+ except (subprocess.CalledProcessError, FileNotFoundError):
9959
+ pass
9537
9960
 
9538
- return result
9961
+ return False
9539
9962
 
9540
9963
  def save_experiment_state() -> None:
9541
9964
  try:
@@ -9543,14 +9966,14 @@ def save_experiment_state() -> None:
9543
9966
  print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
9544
9967
  return
9545
9968
  state_path = get_current_run_folder("experiment_state.json")
9546
- ax_client.save_to_json_file(state_path)
9969
+ save_ax_client_to_json_file(state_path)
9547
9970
  except Exception as e:
9548
- print(f"Error saving experiment state: {e}")
9971
+ print_debug(f"Error saving experiment state: {e}")
9549
9972
 
9550
9973
  def wait_for_state_file(state_path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
9551
9974
  try:
9552
9975
  if not os.path.exists(state_path):
9553
- print(f"[ERROR] File '{state_path}' does not exist.")
9976
+ print_debug(f"[ERROR] File '{state_path}' does not exist.")
9554
9977
  return False
9555
9978
 
9556
9979
  i = 0
@@ -9618,7 +10041,7 @@ def load_experiment_state() -> None:
9618
10041
  return
9619
10042
 
9620
10043
  try:
9621
- arms_seen = {}
10044
+ arms_seen: dict = {}
9622
10045
  for arm in data.get("arms", []):
9623
10046
  name = arm.get("name")
9624
10047
  sig = arm.get("parameters")
@@ -9769,38 +10192,47 @@ def get_result_minimize_flag(path_to_calculate: str, resname: str) -> bool:
9769
10192
 
9770
10193
  return minmax[index] == "min"
9771
10194
 
9772
- def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
9773
- pareto_front_data: dict = {}
10195
+ def post_job_calculate_pareto_front() -> None:
10196
+ if not args.calculate_pareto_front_of_job:
10197
+ return
9774
10198
 
9775
- all_combinations = list(combinations(range(len(arg_result_names)), 2))
10199
+ failure = False
9776
10200
 
9777
- skip = False
10201
+ _paths_to_calculate = []
9778
10202
 
9779
- for i, j in all_combinations:
9780
- if not skip:
9781
- metric_x = arg_result_names[i]
9782
- metric_y = arg_result_names[j]
10203
+ for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
10204
+ try:
10205
+ found_paths = find_results_paths(_path_to_calculate)
9783
10206
 
9784
- x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
9785
- y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
10207
+ for _fp in found_paths:
10208
+ if _fp not in _paths_to_calculate:
10209
+ _paths_to_calculate.append(_fp)
10210
+ except (FileNotFoundError, NotADirectoryError) as e:
10211
+ print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
9786
10212
 
9787
- try:
9788
- if metric_x not in pareto_front_data:
9789
- pareto_front_data[metric_x] = {}
10213
+ failure = True
9790
10214
 
9791
- pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
9792
- except ax.exceptions.core.DataRequiredError as e:
9793
- print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
9794
- except SignalINT:
9795
- print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
9796
- skip = True
10215
+ for _path_to_calculate in _paths_to_calculate:
10216
+ for path_to_calculate in found_paths:
10217
+ if not job_calculate_pareto_front(path_to_calculate):
10218
+ failure = True
9797
10219
 
9798
- return pareto_front_data
10220
+ if failure:
10221
+ my_exit(24)
10222
+
10223
+ my_exit(0)
10224
+
10225
+ def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
10226
+ if not os.path.exists(RESULT_CSV_FILE):
10227
+ print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
10228
+ return None
10229
+
10230
+ return create_pareto_front_table(idxs, metric_x, metric_y)
9799
10231
 
9800
10232
  def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_sixel_and_table: bool = False) -> None:
9801
10233
  if len(res_names) <= 1:
9802
10234
  print_debug(f"--result_names (has {len(res_names)} entries) must be at least 2.")
9803
- return
10235
+ return None
9804
10236
 
9805
10237
  pareto_front_data: dict = get_pareto_front_data(path_to_calculate, res_names)
9806
10238
 
@@ -9821,8 +10253,16 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
9821
10253
  else:
9822
10254
  print(f"Not showing Pareto-front-sixel for {path_to_calculate}")
9823
10255
 
9824
- if len(calculated_frontier[metric_x][metric_y]["idxs"]):
9825
- pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
10256
+ if calculated_frontier is None:
10257
+ print_debug("ERROR: calculated_frontier is None")
10258
+ return None
10259
+
10260
+ try:
10261
+ if len(calculated_frontier[metric_x][metric_y]["idxs"]):
10262
+ pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
10263
+ except AttributeError:
10264
+ print_debug(f"ERROR: calculated_frontier structure invalid for ({metric_x}, {metric_y})")
10265
+ return None
9826
10266
 
9827
10267
  rich_table = pareto_front_as_rich_table(
9828
10268
  calculated_frontier[metric_x][metric_y]["idxs"],
@@ -9847,6 +10287,8 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
9847
10287
 
9848
10288
  live_share_after_pareto()
9849
10289
 
10290
+ return None
10291
+
9850
10292
  def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_color: str) -> None:
9851
10293
  cpu_count = os.cpu_count()
9852
10294
 
@@ -9858,9 +10300,11 @@ def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_
9858
10300
  pass
9859
10301
 
9860
10302
  if gpu_string:
9861
- console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}] [green]{gs_string}[/green]")
10303
+ console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]")
9862
10304
  else:
9863
- print_green(f"You have {cpu_count} CPUs available for the main process. {gs_string}")
10305
+ print_green(f"You have {cpu_count} CPUs available for the main process.")
10306
+
10307
+ print_green(gs_string)
9864
10308
 
9865
10309
  def write_args_overview_table() -> None:
9866
10310
  table = Table(title="Arguments Overview")
@@ -10127,112 +10571,6 @@ def find_results_paths(base_path: str) -> list:
10127
10571
 
10128
10572
  return list(set(found_paths))
10129
10573
 
10130
- def post_job_calculate_pareto_front() -> None:
10131
- if not args.calculate_pareto_front_of_job:
10132
- return
10133
-
10134
- failure = False
10135
-
10136
- _paths_to_calculate = []
10137
-
10138
- for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
10139
- try:
10140
- found_paths = find_results_paths(_path_to_calculate)
10141
-
10142
- for _fp in found_paths:
10143
- if _fp not in _paths_to_calculate:
10144
- _paths_to_calculate.append(_fp)
10145
- except (FileNotFoundError, NotADirectoryError) as e:
10146
- print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
10147
-
10148
- failure = True
10149
-
10150
- for _path_to_calculate in _paths_to_calculate:
10151
- for path_to_calculate in found_paths:
10152
- if not job_calculate_pareto_front(path_to_calculate):
10153
- failure = True
10154
-
10155
- if failure:
10156
- my_exit(24)
10157
-
10158
- my_exit(0)
10159
-
10160
- def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
10161
- pf_start_time = time.time()
10162
-
10163
- if not path_to_calculate:
10164
- return False
10165
-
10166
- global CURRENT_RUN_FOLDER
10167
- global RESULT_CSV_FILE
10168
- global arg_result_names
10169
-
10170
- if not path_to_calculate:
10171
- print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
10172
- return False
10173
-
10174
- if not os.path.exists(path_to_calculate):
10175
- print_red(f"Path '{path_to_calculate}' does not exist")
10176
- return False
10177
-
10178
- ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
10179
-
10180
- if not os.path.exists(ax_client_json):
10181
- print_red(f"Path '{ax_client_json}' not found")
10182
- return False
10183
-
10184
- checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
10185
- if not os.path.exists(checkpoint_file):
10186
- print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
10187
- return False
10188
-
10189
- RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
10190
- if not os.path.exists(RESULT_CSV_FILE):
10191
- print_red(f"{RESULT_CSV_FILE} not found")
10192
- return False
10193
-
10194
- res_names = []
10195
-
10196
- res_names_file = f"{path_to_calculate}/result_names.txt"
10197
- if not os.path.exists(res_names_file):
10198
- print_red(f"File '{res_names_file}' does not exist")
10199
- return False
10200
-
10201
- try:
10202
- with open(res_names_file, "r", encoding="utf-8") as file:
10203
- lines = file.readlines()
10204
- except Exception as e:
10205
- print_red(f"Error reading file '{res_names_file}': {e}")
10206
- return False
10207
-
10208
- for line in lines:
10209
- entry = line.strip()
10210
- if entry != "":
10211
- res_names.append(entry)
10212
-
10213
- if len(res_names) < 2:
10214
- print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
10215
- return False
10216
-
10217
- load_username_to_args(path_to_calculate)
10218
-
10219
- CURRENT_RUN_FOLDER = path_to_calculate
10220
-
10221
- arg_result_names = res_names
10222
-
10223
- load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
10224
-
10225
- if experiment_parameters is None:
10226
- return False
10227
-
10228
- show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
10229
-
10230
- pf_end_time = time.time()
10231
-
10232
- print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
10233
-
10234
- return True
10235
-
10236
10574
  def set_arg_states_from_continue() -> None:
10237
10575
  if args.continue_previous_job and not args.num_random_steps:
10238
10576
  num_random_steps_file = f"{args.continue_previous_job}/state_files/num_random_steps"
@@ -10266,7 +10604,7 @@ def write_result_names_file() -> None:
10266
10604
  except Exception as e:
10267
10605
  print_red(f"Error trying to open file '{fn}': {e}")
10268
10606
 
10269
- def run_program_once(params=None) -> None:
10607
+ def run_program_once(params: Optional[dict] = None) -> None:
10270
10608
  if not args.run_program_once:
10271
10609
  print_debug("[yellow]No setup script specified (run_program_once). Skipping setup.[/yellow]")
10272
10610
  return
@@ -10275,27 +10613,27 @@ def run_program_once(params=None) -> None:
10275
10613
  params = {}
10276
10614
 
10277
10615
  if isinstance(args.run_program_once, str):
10278
- command_str = args.run_program_once
10616
+ command_str = decode_if_base64(args.run_program_once)
10279
10617
  for k, v in params.items():
10280
10618
  placeholder = f"%({k})"
10281
10619
  command_str = command_str.replace(placeholder, str(v))
10282
10620
 
10283
- with spinner(f"Executing command: [cyan]{command_str}[/cyan]"):
10284
- result = subprocess.run(command_str, shell=True, check=True)
10285
- if result.returncode == 0:
10286
- console.log("[bold green]Setup script completed successfully ✅[/bold green]")
10287
- else:
10288
- console.log(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10621
+ print(f"Executing command: [cyan]{command_str}[/cyan]")
10622
+ result = subprocess.run(command_str, shell=True, check=True)
10623
+ if result.returncode == 0:
10624
+ print("[bold green]Setup script completed successfully ✅[/bold green]")
10625
+ else:
10626
+ print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10289
10627
 
10290
- my_exit(57)
10628
+ my_exit(57)
10291
10629
 
10292
10630
  elif isinstance(args.run_program_once, (list, tuple)):
10293
10631
  with spinner("run_program_once: Executing command list: [cyan]{args.run_program_once}[/cyan]"):
10294
10632
  result = subprocess.run(args.run_program_once, check=True)
10295
10633
  if result.returncode == 0:
10296
- console.log("[bold green]Setup script completed successfully ✅[/bold green]")
10634
+ print("[bold green]Setup script completed successfully ✅[/bold green]")
10297
10635
  else:
10298
- console.log(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10636
+ print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10299
10637
 
10300
10638
  my_exit(57)
10301
10639
 
@@ -10305,7 +10643,7 @@ def run_program_once(params=None) -> None:
10305
10643
  my_exit(57)
10306
10644
 
10307
10645
  def show_omniopt_call() -> None:
10308
- def remove_ui_url(arg_str) -> str:
10646
+ def remove_ui_url(arg_str: str) -> str:
10309
10647
  return re.sub(r'(?:--ui_url(?:=\S+)?(?:\s+\S+)?)', '', arg_str).strip()
10310
10648
 
10311
10649
  original_argv = " ".join(sys.argv[1:])
@@ -10313,8 +10651,14 @@ def show_omniopt_call() -> None:
10313
10651
 
10314
10652
  original_print(oo_call + " " + cleaned)
10315
10653
 
10654
+ if args.dependency is not None and args.dependency != "":
10655
+ print(f"Dependency: {args.dependency}")
10656
+
10657
+ if args.ui_url is not None and args.ui_url != "":
10658
+ print_yellow("--ui_url is deprecated. Do not use it anymore. It will be ignored and one day be removed.")
10659
+
10316
10660
  def main() -> None:
10317
- global RESULT_CSV_FILE, ax_client, LOGFILE_DEBUG_GET_NEXT_TRIALS
10661
+ global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS
10318
10662
 
10319
10663
  check_if_has_random_steps()
10320
10664
 
@@ -10396,15 +10740,13 @@ def main() -> None:
10396
10740
  exp_params = get_experiment_parameters(cli_params_experiment_parameters)
10397
10741
 
10398
10742
  if exp_params is not None:
10399
- ax_client, experiment_args, gpu_string, gpu_color = exp_params
10743
+ experiment_args, gpu_string, gpu_color = exp_params
10400
10744
  print_debug(f"experiment_parameters: {experiment_parameters}")
10401
10745
 
10402
10746
  set_orchestrator()
10403
10747
 
10404
10748
  init_live_share()
10405
10749
 
10406
- start_periodic_live_share()
10407
-
10408
10750
  show_available_hardware_and_generation_strategy_string(gpu_string, gpu_color)
10409
10751
 
10410
10752
  original_print(f"Run-Program: {global_vars['joined_run_program']}")
@@ -10419,6 +10761,8 @@ def main() -> None:
10419
10761
 
10420
10762
  write_files_and_show_overviews()
10421
10763
 
10764
+ live_share()
10765
+
10422
10766
  #if args.continue_previous_job:
10423
10767
  # insert_jobs_from_csv(f"{args.continue_previous_job}/{RESULTS_CSV_FILENAME}")
10424
10768
 
@@ -10496,21 +10840,37 @@ def initialize_nvidia_logs() -> None:
10496
10840
  def build_gui_url(config: argparse.Namespace) -> str:
10497
10841
  base_url = get_base_url()
10498
10842
  params = collect_params(config)
10499
- return f"{base_url}?{urlencode(params, doseq=True)}"
10843
+ ret = f"{base_url}?{urlencode(params, doseq=True)}"
10844
+
10845
+ return ret
10846
+
10847
+ def get_result_names_for_url(value: List) -> str:
10848
+ d = dict(v.split("=", 1) if "=" in v else (v, "min") for v in value)
10849
+ s = " ".join(f"{k}={v}" for k, v in d.items())
10850
+
10851
+ return s
10500
10852
 
10501
10853
  def collect_params(config: argparse.Namespace) -> dict:
10502
10854
  params = {}
10855
+ user_home = os.path.expanduser("~")
10856
+
10503
10857
  for attr, value in vars(config).items():
10504
10858
  if attr == "run_program":
10505
10859
  params[attr] = global_vars["joined_run_program"]
10860
+ elif attr == "result_names" and value:
10861
+ params[attr] = get_result_names_for_url(value)
10506
10862
  elif attr == "parameter" and value is not None:
10507
10863
  params.update(process_parameters(config.parameter))
10864
+ elif attr == "root_venv_dir":
10865
+ if value is not None and os.path.abspath(value) != os.path.abspath(user_home):
10866
+ params[attr] = value
10508
10867
  elif isinstance(value, bool):
10509
10868
  params[attr] = int(value)
10510
10869
  elif isinstance(value, list):
10511
10870
  params[attr] = value
10512
10871
  elif value is not None:
10513
10872
  params[attr] = value
10873
+
10514
10874
  return params
10515
10875
 
10516
10876
  def process_parameters(parameters: list) -> dict:
@@ -10563,7 +10923,7 @@ def get_base_url() -> str:
10563
10923
  def write_ui_url() -> None:
10564
10924
  url = build_gui_url(args)
10565
10925
  with open(get_current_run_folder("ui_url.txt"), mode="a", encoding="utf-8") as myfile:
10566
- myfile.write(decode_if_base64(url))
10926
+ myfile.write(url)
10567
10927
 
10568
10928
  def handle_random_steps() -> None:
10569
10929
  if args.parameter and args.continue_previous_job and random_steps <= 0:
@@ -10622,8 +10982,6 @@ def run_search_with_progress_bar() -> None:
10622
10982
  wait_for_jobs_to_complete()
10623
10983
 
10624
10984
  def complex_tests(_program_name: str, wanted_stderr: str, wanted_exit_code: int, wanted_signal: Union[int, None], res_is_none: bool = False) -> int:
10625
- #print_yellow(f"Test suite: {_program_name}")
10626
-
10627
10985
  nr_errors: int = 0
10628
10986
 
10629
10987
  program_path: str = f"./.tests/test_wronggoing_stuff.bin/bin/{_program_name}"
@@ -10697,7 +11055,7 @@ def test_find_paths(program_code: str) -> int:
10697
11055
  for i in files:
10698
11056
  if i not in string:
10699
11057
  if os.path.exists(i):
10700
- print("Missing {i} in find_file_paths string!")
11058
+ print(f"Missing {i} in find_file_paths string!")
10701
11059
  nr_errors += 1
10702
11060
 
10703
11061
  return nr_errors
@@ -11078,17 +11436,16 @@ Exit-Code: 159
11078
11436
 
11079
11437
  my_exit(nr_errors)
11080
11438
 
11081
- def main_outside() -> None:
11439
+ def main_wrapper() -> None:
11082
11440
  print(f"Run-UUID: {run_uuid}")
11083
11441
 
11084
11442
  auto_wrap_namespace(globals())
11085
11443
 
11086
11444
  print_logo()
11087
11445
 
11088
- start_logging_daemon()
11089
-
11090
11446
  fool_linter(args.num_cpus_main_job)
11091
11447
  fool_linter(args.flame_graph)
11448
+ fool_linter(args.memray)
11092
11449
 
11093
11450
  with warnings.catch_warnings():
11094
11451
  warnings.simplefilter("ignore")
@@ -11123,7 +11480,7 @@ def main_outside() -> None:
11123
11480
  def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
11124
11481
  pattern = re.compile(regex) if regex else None
11125
11482
 
11126
- def wrapped(*args, **kwargs):
11483
+ def wrapped(*args: Any, **kwargs: Any) -> None:
11127
11484
  if pattern and not pattern.search(func.__name__):
11128
11485
  return func(*args, **kwargs)
11129
11486
 
@@ -11145,6 +11502,9 @@ def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
11145
11502
  def auto_wrap_namespace(namespace: Any) -> Any:
11146
11503
  enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))
11147
11504
 
11505
+ if args.beartype:
11506
+ enable_beartype = True
11507
+
11148
11508
  excluded_functions = {
11149
11509
  "log_time_and_memory_wrapper",
11150
11510
  "collect_runtime_stats",
@@ -11153,8 +11513,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
11153
11513
  "_record_stats",
11154
11514
  "_open",
11155
11515
  "_check_memory_leak",
11156
- "start_periodic_live_share",
11157
- "start_logging_daemon",
11158
11516
  "get_current_run_folder",
11159
11517
  "show_func_name_wrapper"
11160
11518
  }
@@ -11180,7 +11538,7 @@ def auto_wrap_namespace(namespace: Any) -> Any:
11180
11538
 
11181
11539
  if __name__ == "__main__":
11182
11540
  try:
11183
- main_outside()
11541
+ main_wrapper()
11184
11542
  except (SignalUSR, SignalINT, SignalCONT) as e:
11185
- print_red(f"main_outside failed with exception {e}")
11543
+ print_red(f"main_wrapper failed with exception {e}")
11186
11544
  end_program(True)