omniopt2 8471__py3-none-any.whl → 9171__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Files changed (52)
  1. .gitignore +2 -0
  2. .helpers.py +0 -9
  3. .omniopt.py +1298 -903
  4. .omniopt_plot_scatter.py +1 -1
  5. .omniopt_plot_scatter_hex.py +1 -1
  6. .pareto.py +134 -0
  7. .shellscript_functions +24 -15
  8. .tests/pylint.rc +0 -4
  9. README.md +1 -1
  10. omniopt +33 -22
  11. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.helpers.py +0 -9
  12. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt.py +1298 -903
  13. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter.py +1 -1
  14. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_hex.py +1 -1
  15. omniopt2-9171.data/data/bin/.pareto.py +134 -0
  16. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.shellscript_functions +24 -15
  17. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt +33 -22
  18. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_plot +1 -1
  19. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/pylint.rc +0 -4
  20. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/test_requirements.txt +1 -0
  21. {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/METADATA +6 -5
  22. omniopt2-9171.dist-info/RECORD +73 -0
  23. omniopt2.egg-info/PKG-INFO +6 -5
  24. omniopt2.egg-info/SOURCES.txt +1 -0
  25. omniopt2.egg-info/requires.txt +4 -3
  26. omniopt_plot +1 -1
  27. pyproject.toml +1 -1
  28. requirements.txt +3 -3
  29. test_requirements.txt +1 -0
  30. omniopt2-8471.dist-info/RECORD +0 -71
  31. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.colorfunctions.sh +0 -0
  32. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.general.sh +0 -0
  33. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
  34. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_general.py +0 -0
  35. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
  36. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_kde.py +0 -0
  37. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
  38. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
  39. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_trial_index_result.py +0 -0
  40. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_worker.py +0 -0
  41. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.random_generator.py +0 -0
  42. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.tpe.py +0 -0
  43. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/LICENSE +0 -0
  44. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/apt-dependencies.txt +0 -0
  45. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_docker +0 -0
  46. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_evaluate +0 -0
  47. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_share +0 -0
  48. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/requirements.txt +3 -3
  49. {omniopt2-8471.data → omniopt2-9171.data}/data/bin/setup.py +0 -0
  50. {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/WHEEL +0 -0
  51. {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/licenses/LICENSE +0 -0
  52. {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/top_level.txt +0 -0
.omniopt.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  #from mayhemmonkey import MayhemMonkey
4
4
  #mayhemmonkey = MayhemMonkey()
5
- #mayhemmonkey.set_function_fail_after_count("open", 201)
5
+ #mayhemmonkey.set_function_fail_after_count("open", 10)
6
6
  #mayhemmonkey.set_function_error_rate("open", 0.1)
7
7
  #mayhemmonkey.set_function_group_error_rate(["io", "math"], 0.8)
8
8
  #mayhemmonkey.install_faulty()
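Note: the only change in this hunk is the commented-out fault-injection harness, which now fails open after 10 calls instead of 201. The sketch below is illustrative only (it is not MayhemMonkey itself); it shows the kind of monkey-patching that a call like set_function_fail_after_count("open", 10) configures.

# Illustrative sketch: fail builtins.open after n successful calls.
import builtins

def fail_open_after(n: int) -> None:
    real_open = builtins.open
    state = {"calls": 0}

    def flaky_open(*args, **kwargs):
        state["calls"] += 1
        if state["calls"] > n:
            raise OSError("injected failure for open()")
        return real_open(*args, **kwargs)

    builtins.open = flaky_open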
@@ -27,11 +27,15 @@ import inspect
27
27
  import tracemalloc
28
28
  import resource
29
29
  from urllib.parse import urlencode
30
-
31
30
  import psutil
32
31
 
33
32
  FORCE_EXIT: bool = False
34
33
 
34
+ LAST_LOG_TIME: int = 0
35
+ last_msg_progressbar = ""
36
+ last_msg_raw = None
37
+ last_lock_print_debug = threading.Lock()
38
+
35
39
  def force_exit(signal_number: Any, frame: Any) -> Any:
36
40
  global FORCE_EXIT
37
41
 
@@ -84,6 +88,7 @@ log_nr_gen_jobs: list[int] = []
84
88
  generation_strategy_human_readable: str = ""
85
89
  oo_call: str = "./omniopt"
86
90
  progress_bar_length: int = 0
91
+ worker_usage_file = 'worker_usage.csv'
87
92
 
88
93
  if os.environ.get("CUSTOM_VIRTUAL_ENV") == "1":
89
94
  oo_call = "omniopt"
@@ -99,7 +104,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
99
104
  SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
100
105
  joined_supported_models: str = ", ".join(SUPPORTED_MODELS)
101
106
 
102
- special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
107
+ special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]
103
108
 
104
109
  IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names
105
110
 
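Note: the two new entries ("runtime", "status") flow straight into IGNORABLE_COLUMNS on the line above and are later used to separate bookkeeping columns from parameter/result columns (see pareto_front_table_get_columns further down in this diff). A minimal sketch of that use, with a made-up CSV header for illustration:

# Hypothetical header; everything outside the bookkeeping set counts as a parameter or result column.
header = ["trial_index", "arm_name", "runtime", "status", "learning_rate", "batch_size", "loss"]
ignorable = {"trial_index", "arm_name", "runtime", "status"}
param_or_result_cols = [c for c in header if c not in ignorable]
print(param_or_result_cols)  # ['learning_rate', 'batch_size', 'loss']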
@@ -153,12 +158,6 @@ try:
153
158
  message="Ax currently requires a sqlalchemy version below 2.0.*",
154
159
  )
155
160
 
156
- warnings.filterwarnings(
157
- "ignore",
158
- category=RuntimeWarning,
159
- message="coroutine 'start_logging_daemon' was never awaited"
160
- )
161
-
162
161
  with spinner("Importing argparse..."):
163
162
  import argparse
164
163
 
@@ -204,6 +203,9 @@ try:
204
203
  with spinner("Importing rich.pretty..."):
205
204
  from rich.pretty import pprint
206
205
 
206
+ with spinner("Importing pformat..."):
207
+ from pprint import pformat
208
+
207
209
  with spinner("Importing rich.prompt..."):
208
210
  from rich.prompt import Prompt, FloatPrompt, IntPrompt
209
211
 
@@ -237,9 +239,6 @@ try:
237
239
  with spinner("Importing uuid..."):
238
240
  import uuid
239
241
 
240
- #with spinner("Importing qrcode..."):
241
- # import qrcode
242
-
243
242
  with spinner("Importing cowsay..."):
244
243
  import cowsay
245
244
 
@@ -270,6 +269,9 @@ try:
270
269
  with spinner("Importing beartype..."):
271
270
  from beartype import beartype
272
271
 
272
+ with spinner("Importing rendering stuff..."):
273
+ from ax.plot.base import AxPlotConfig
274
+
273
275
  with spinner("Importing statistics..."):
274
276
  import statistics
275
277
 
@@ -381,16 +383,9 @@ def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: fl
381
383
  call_path_str = " -> ".join(short_stack)
382
384
  _func_call_paths[func_name][call_path_str] += 1
383
385
 
384
- print(
385
- f"Function '{func_name}' took {elapsed:.4f}s "
386
- f"(total {percent_if_added:.1f}% of tracked time)"
387
- )
388
- print(
389
- f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, "
390
- f"diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB"
391
- )
386
+ print(f"Function '{func_name}' took {elapsed:.4f}s (total {percent_if_added:.1f}% of tracked time)")
387
+ print(f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB")
392
388
 
393
- # NEU: Runtime Stats
394
389
  runtime_stats = collect_runtime_stats()
395
390
  print("=== Runtime Stats ===")
396
391
  print(f"RSS: {runtime_stats['rss_MB']:.2f} MB, VMS: {runtime_stats['vms_MB']:.2f} MB")
@@ -502,6 +497,24 @@ try:
502
497
  dier: FunctionType = helpers.dier
503
498
  is_equal: FunctionType = helpers.is_equal
504
499
  is_not_equal: FunctionType = helpers.is_not_equal
500
+ with spinner("Importing pareto..."):
501
+ pareto_file: str = f"{script_dir}/.pareto.py"
502
+ spec = importlib.util.spec_from_file_location(
503
+ name="pareto",
504
+ location=pareto_file,
505
+ )
506
+ if spec is not None and spec.loader is not None:
507
+ pareto = importlib.util.module_from_spec(spec)
508
+ spec.loader.exec_module(pareto)
509
+ else:
510
+ raise ImportError(f"Could not load module from {pareto_file}")
511
+
512
+ pareto_front_table_filter_rows: FunctionType = pareto.pareto_front_table_filter_rows
513
+ pareto_front_table_add_headers: FunctionType = pareto.pareto_front_table_add_headers
514
+ pareto_front_table_add_rows: FunctionType = pareto.pareto_front_table_add_rows
515
+ pareto_front_filter_complete_points: FunctionType = pareto.pareto_front_filter_complete_points
516
+ pareto_front_select_pareto_points: FunctionType = pareto.pareto_front_select_pareto_points
517
+
505
518
  except KeyboardInterrupt:
506
519
  print("You pressed CTRL-c while importing the helpers file")
507
520
  sys.exit(0)
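Note: the new .pareto.py helper (added in this release, +134 lines) is loaded via importlib from a file path next to the script, mirroring how the helpers module is bound just above. A self-contained sketch of that pattern, with the path derived from __file__ for illustration:

# Minimal sketch of loading a sibling module by file path, as done for .pareto.py above.
import importlib.util
import os

script_dir = os.path.dirname(os.path.abspath(__file__))
pareto_file = os.path.join(script_dir, ".pareto.py")

spec = importlib.util.spec_from_file_location("pareto", pareto_file)
if spec is None or spec.loader is None:
    raise ImportError(f"Could not load module from {pareto_file}")
pareto = importlib.util.module_from_spec(spec)
spec.loader.exec_module(pareto)  # pareto.pareto_front_select_pareto_points etc. are now callable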
@@ -544,6 +557,17 @@ logfile_worker_creation_logs: str = f'{log_uuid_dir}_worker_creation_logs'
544
557
  logfile_trial_index_to_param_logs: str = f'{log_uuid_dir}_trial_index_to_param_logs'
545
558
  LOGFILE_DEBUG_GET_NEXT_TRIALS: Union[str, None] = None
546
559
 
560
+ def error_without_print(text: str) -> None:
561
+ print_debug(text)
562
+
563
+ if get_current_run_folder():
564
+ try:
565
+ with open(get_current_run_folder("oo_errors.txt"), mode="a", encoding="utf-8") as myfile:
566
+ myfile.write(text + "\n\n")
567
+ except (OSError, FileNotFoundError) as e:
568
+ helpers.print_color("red", f"Error: {e}. This may mean that the {get_current_run_folder()} was deleted during the run. Could not write '{text} to {get_current_run_folder()}/oo_errors.txt'")
569
+ sys.exit(99)
570
+
547
571
  def print_red(text: str) -> None:
548
572
  helpers.print_color("red", text)
549
573
 
@@ -802,6 +826,7 @@ class ConfigLoader:
802
826
  parameter: Optional[List[str]]
803
827
  experiment_constraints: Optional[List[str]]
804
828
  main_process_gb: int
829
+ beartype: bool
805
830
  worker_timeout: int
806
831
  slurm_signal_delay_s: int
807
832
  gridsearch: bool
@@ -822,6 +847,7 @@ class ConfigLoader:
822
847
  run_program_once: str
823
848
  mem_gb: int
824
849
  flame_graph: bool
850
+ memray: bool
825
851
  continue_previous_job: Optional[str]
826
852
  calculate_pareto_front_of_job: Optional[List[str]]
827
853
  revert_to_random_when_seemingly_exhausted: bool
@@ -983,6 +1009,7 @@ class ConfigLoader:
983
1009
  debug.add_argument('--verbose_break_run_search_table', help='Verbose logging for break_run_search', action='store_true', default=False)
984
1010
  debug.add_argument('--debug', help='Enable debugging', action='store_true', default=False)
985
1011
  debug.add_argument('--flame_graph', help='Enable flame-graphing. Makes everything slower, but creates a flame graph', action='store_true', default=False)
1012
+ debug.add_argument('--memray', help='Use memray to show memory usage', action='store_true', default=False)
986
1013
  debug.add_argument('--no_sleep', help='Disables sleeping for fast job generation (not to be used on HPC)', action='store_true', default=False)
987
1014
  debug.add_argument('--tests', help='Run simple internal tests', action='store_true', default=False)
988
1015
  debug.add_argument('--show_worker_percentage_table_at_end', help='Show a table of percentage of usage of max worker over time', action='store_true', default=False)
@@ -996,8 +1023,8 @@ class ConfigLoader:
996
1023
  debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
997
1024
  debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
998
1025
  debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
999
-
1000
1026
  debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
1027
+ debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)
1001
1028
 
1002
1029
  def load_config(self: Any, config_path: str, file_format: str) -> dict:
1003
1030
  if not os.path.isfile(config_path):
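Note: two new debug switches appear in this area, --memray ("Use memray to show memory usage") and --beartype ("Use beartype"), matching the memray and beartype fields added to ConfigLoader above. How they are wired up is not visible in these hunks; the sketch below only illustrates one common way an opt-in runtime type-checking flag like --beartype is applied, and is not taken from the package.

# Hedged sketch: runtime type checking behind a flag (illustrative wiring, not OmniOpt2's).
from beartype import beartype

def maybe_beartype(enabled: bool):
    def decorate(func):
        return beartype(func) if enabled else func
    return decorate

@maybe_beartype(enabled=True)
def evaluate_point(x: float) -> float:
    return x * x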
@@ -1235,11 +1262,15 @@ for _rn in args.result_names:
1235
1262
  _key = _rn
1236
1263
  _min_or_max = __default_min_max
1237
1264
 
1265
+ _min_or_max = re.sub(r"'", "", _min_or_max)
1266
+
1238
1267
  if _min_or_max not in ["min", "max"]:
1239
1268
  if _min_or_max:
1240
1269
  print_yellow(f"Value for determining whether to minimize or maximize was neither 'min' nor 'max' for key '{_key}', but '{_min_or_max}'. It will be set to the default, which is '{__default_min_max}' instead.")
1241
1270
  _min_or_max = __default_min_max
1242
1271
 
1272
+ _key = re.sub(r"'", "", _key)
1273
+
1243
1274
  if _key in arg_result_names:
1244
1275
  console.print(f"[red]The --result_names option '{_key}' was specified multiple times![/]")
1245
1276
  sys.exit(50)
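Note: the two new re.sub calls strip stray single quotes from --result_names entries such as 'loss'=min before the min/max validation runs. A condensed sketch of the same parsing; the token examples are hypothetical:

import re

def parse_result_name(token: str, default: str = "min") -> tuple[str, str]:
    # "loss=min", "'accuracy'=max" or just "runtime" -> (name, "min" or "max")
    key, _, mode = token.partition("=")
    key = re.sub(r"'", "", key)
    mode = re.sub(r"'", "", mode) or default
    if mode not in ("min", "max"):
        mode = default
    return key, mode

print(parse_result_name("'accuracy'=max"))  # ('accuracy', 'max')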
@@ -1340,27 +1371,14 @@ try:
1340
1371
  with spinner("Importing ExternalGenerationNode..."):
1341
1372
  from ax.generation_strategy.external_generation_node import ExternalGenerationNode
1342
1373
 
1343
- with spinner("Importing MaxTrials..."):
1344
- from ax.generation_strategy.transition_criterion import MaxTrials
1374
+ with spinner("Importing MinTrials..."):
1375
+ from ax.generation_strategy.transition_criterion import MinTrials
1345
1376
 
1346
1377
  with spinner("Importing GeneratorSpec..."):
1347
1378
  from ax.generation_strategy.generator_spec import GeneratorSpec
1348
1379
 
1349
- #except Exception:
1350
- # with spinner("Fallback: Importing ax.generation_strategy.generation_node..."):
1351
- # import ax.generation_strategy.generation_node
1352
-
1353
- # with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
1354
- # from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
1355
-
1356
- # with spinner("Fallback: Importing ExternalGenerationNode..."):
1357
- # from ax.generation_strategy.external_generation_node import ExternalGenerationNode
1358
-
1359
- # with spinner("Fallback: Importing MaxTrials..."):
1360
- # from ax.generation_strategy.transition_criterion import MaxTrials
1361
-
1362
- with spinner("Importing Models from ax.generation_strategy.registry..."):
1363
- from ax.adapter.registry import Models
1380
+ with spinner("Importing Generators from ax.generation_strategy.registry..."):
1381
+ from ax.adapter.registry import Generators
1364
1382
 
1365
1383
  with spinner("Importing get_pending_observation_features..."):
1366
1384
  from ax.core.utils import get_pending_observation_features
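Note: these import changes appear to track renames in newer Ax releases (MaxTrials becomes MinTrials, and Models becomes Generators in ax.adapter.registry); the constructor hunks further down likewise switch super().__init__(node_name=...) to name=. The package simply targets the new names, but if both old and new Ax versions had to be supported, a tolerant import along these lines would work (a sketch using only the names that appear in this diff):

# Version-tolerant imports for the renamed Ax symbols.
try:
    from ax.generation_strategy.transition_criterion import MinTrials as TrialsCriterion
except ImportError:
    from ax.generation_strategy.transition_criterion import MaxTrials as TrialsCriterion

try:
    from ax.adapter.registry import Generators as GeneratorRegistry
except ImportError:
    from ax.adapter.registry import Models as GeneratorRegistry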
@@ -1446,7 +1464,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1446
1464
  def __init__(self: Any, regressor_options: Dict[str, Any] = {}, seed: Optional[int] = None, num_samples: int = 1) -> None:
1447
1465
  print_debug("Initializing RandomForestGenerationNode...")
1448
1466
  t_init_start = time.monotonic()
1449
- super().__init__(node_name="RANDOMFOREST")
1467
+ super().__init__(name="RANDOMFOREST")
1450
1468
  self.num_samples: int = num_samples
1451
1469
  self.seed: int = seed
1452
1470
 
@@ -1466,6 +1484,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1466
1484
  def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
1467
1485
  search_space = experiment.search_space
1468
1486
  parameter_names = list(search_space.parameters.keys())
1487
+ if experiment.optimization_config is None:
1488
+ print_red("Error: update_generator_state is None")
1489
+ return
1469
1490
  metric_names = list(experiment.optimization_config.metrics.keys())
1470
1491
 
1471
1492
  completed_trials = [
@@ -1477,7 +1498,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1477
1498
  y = np.zeros([num_completed_trials, 1])
1478
1499
 
1479
1500
  for t_idx, trial in enumerate(completed_trials):
1480
- trial_parameters = trial.arm.parameters
1501
+ trial_parameters = trial.arms[t_idx].parameters
1481
1502
  x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
1482
1503
  trial_df = data.df[data.df["trial_index"] == trial.index]
1483
1504
  y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
@@ -1617,10 +1638,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
1617
1638
  def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
1618
1639
  for name in best_sample.keys():
1619
1640
  param = self.parameters.get(name)
1641
+ best_sample_by_name = best_sample[name]
1642
+
1620
1643
  if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
1621
- best_sample[name] = int(round(best_sample[name]))
1644
+ if best_sample_by_name is not None:
1645
+ best_sample[name] = int(round(float(best_sample_by_name)))
1646
+ else:
1647
+ print_debug("best_sample_by_name was empty")
1622
1648
  elif isinstance(param, ChoiceParameter):
1623
- best_sample[name] = str(reverse_choice_map.get(int(best_sample[name])))
1649
+ if best_sample_by_name is not None:
1650
+ best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
1651
+ else:
1652
+ print_debug("best_sample_by_name was empty")
1624
1653
 
1625
1654
  decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode
1626
1655
 
@@ -1668,10 +1697,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
1668
1697
 
1669
1698
  def __init__(
1670
1699
  self: Any,
1671
- node_name: str = "INTERACTIVE_GENERATOR",
1700
+ name: str = "INTERACTIVE_GENERATOR",
1672
1701
  ) -> None:
1673
1702
  t0 = time.monotonic()
1674
- super().__init__(node_name=node_name)
1703
+ super().__init__(name=name)
1675
1704
  self.parameters = None
1676
1705
  self.minimize = None
1677
1706
  self.data = None
@@ -1827,10 +1856,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
1827
1856
 
1828
1857
  @dataclass(init=False)
1829
1858
  class ExternalProgramGenerationNode(ExternalGenerationNode):
1830
- def __init__(self: Any, external_generator: str = args.external_generator, node_name: str = "EXTERNAL_GENERATOR") -> None:
1859
+ def __init__(self: Any, external_generator: str = args.external_generator, name: str = "EXTERNAL_GENERATOR") -> None:
1831
1860
  print_debug("Initializing ExternalProgramGenerationNode...")
1832
1861
  t_init_start = time.monotonic()
1833
- super().__init__(node_name=node_name)
1862
+ super().__init__(name=name)
1834
1863
  self.seed: int = args.seed
1835
1864
  self.external_generator: str = decode_if_base64(external_generator)
1836
1865
  self.constraints = None
@@ -2051,9 +2080,9 @@ def run_live_share_command(force: bool = False) -> Tuple[str, str]:
2051
2080
  return str(result.stdout), str(result.stderr)
2052
2081
  except subprocess.CalledProcessError as e:
2053
2082
  if e.stderr:
2054
- original_print(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
2083
+ print_debug(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
2055
2084
  else:
2056
- original_print(f"run_live_share_command: command failed with error: {e}")
2085
+ print_debug(f"run_live_share_command: command failed with error: {e}")
2057
2086
  return "", str(e.stderr)
2058
2087
  except Exception as e:
2059
2088
  print(f"run_live_share_command: An error occurred: {e}")
@@ -2067,6 +2096,8 @@ def force_live_share() -> bool:
2067
2096
  return False
2068
2097
 
2069
2098
  def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
2099
+ log_data()
2100
+
2070
2101
  if not get_current_run_folder():
2071
2102
  print(f"live_share: get_current_run_folder was empty or false: {get_current_run_folder()}")
2072
2103
  return False
@@ -2080,17 +2111,23 @@ def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
2080
2111
  if stderr:
2081
2112
  print_green(stderr)
2082
2113
  else:
2083
- print_red("This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
2114
+ if stderr and stdout:
2115
+ print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
2116
+ elif stderr:
2117
+ print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}")
2118
+ elif stdout:
2119
+ print_red(f"This call should have shown the CURL, but didnt. Stdout: {stdout}")
2120
+ else:
2121
+ print_red("This call should have shown the CURL, but didnt.")
2084
2122
  if stdout:
2085
2123
  print_debug(f"live_share stdout: {stdout}")
2086
2124
 
2087
2125
  return True
2088
2126
 
2089
2127
  def init_live_share() -> bool:
2090
- with spinner("Initializing live share..."):
2091
- ret = live_share(True, True)
2128
+ ret = live_share(True, True)
2092
2129
 
2093
- return ret
2130
+ return ret
2094
2131
 
2095
2132
  def init_storage(db_url: str) -> None:
2096
2133
  init_engine_and_session_factory(url=db_url, force_init=True)
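Note: the replaced single-line message was missing its f prefix, so {stderr} and {stdout} were printed literally; the new branching also avoids mentioning streams that are empty. The same behaviour can be expressed more compactly (a sketch, not the package's code):

def curl_warning(stdout: str, stderr: str) -> str:
    parts = [p for p in (f"Stderr: {stderr}" if stderr else "", f"Stdout: {stdout}" if stdout else "") if p]
    msg = "This call should have shown the CURL, but didn't."
    return f"{msg} {', '.join(parts)}" if parts else msg

print(curl_warning("", "permission denied"))  # only the non-empty stream is reported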
@@ -2116,7 +2153,11 @@ def try_saving_to_db() -> None:
2116
2153
  else:
2117
2154
  print_red("ax_client was not defined in try_saving_to_db")
2118
2155
  my_exit(101)
2119
- save_generation_strategy(global_gs)
2156
+
2157
+ if global_gs is not None:
2158
+ save_generation_strategy(global_gs)
2159
+ else:
2160
+ print_red("Not saving generation strategy: global_gs was empty")
2120
2161
  except Exception as e:
2121
2162
  print_debug(f"Failed trying to save sqlite3-DB: {e}")
2122
2163
 
@@ -2147,6 +2188,8 @@ def merge_with_job_infos(df: pd.DataFrame) -> pd.DataFrame:
2147
2188
  return merged
2148
2189
 
2149
2190
  def save_results_csv() -> Optional[str]:
2191
+ log_data()
2192
+
2150
2193
  if args.dryrun:
2151
2194
  return None
2152
2195
 
@@ -2183,19 +2226,21 @@ def save_results_csv() -> Optional[str]:
2183
2226
  def get_results_paths() -> tuple[str, str]:
2184
2227
  return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))
2185
2228
 
2229
+ def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
2230
+ if not ax_client:
2231
+ my_exit(101)
2232
+
2233
+ return None
2234
+
2235
+ return ax_client.get_trials_data_frame()
2236
+
2186
2237
  def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
2187
2238
  if not ax_client:
2188
2239
  return None
2189
2240
 
2190
2241
  ax_client.experiment.fetch_data()
2191
- df = ax_client.get_trials_data_frame()
2192
-
2193
- #print("========================")
2194
- #print("BEFORE merge_with_job_infos:")
2195
- #print(df["generation_node"])
2242
+ df = ax_client_get_trials_data_frame()
2196
2243
  df = merge_with_job_infos(df)
2197
- #print("AFTER merge_with_job_infos:")
2198
- #print(df["generation_node"])
2199
2244
 
2200
2245
  return df
2201
2246
 
@@ -2206,11 +2251,24 @@ def write_csv(df: pd.DataFrame, path: str) -> None:
2206
2251
  pass
2207
2252
  df.to_csv(path, index=False, float_format="%.30f")
2208
2253
 
2254
+ def ax_client_to_json_snapshot() -> Optional[dict]:
2255
+ if not ax_client:
2256
+ my_exit(101)
2257
+
2258
+ return None
2259
+
2260
+ json_snapshot = ax_client.to_json_snapshot()
2261
+
2262
+ return json_snapshot
2263
+
2209
2264
  def write_json_snapshot(path: str) -> None:
2210
2265
  if ax_client is not None:
2211
- json_snapshot = ax_client.to_json_snapshot()
2212
- with open(path, "w", encoding="utf-8") as f:
2213
- json.dump(json_snapshot, f, indent=4)
2266
+ json_snapshot = ax_client_to_json_snapshot()
2267
+ if json_snapshot is not None:
2268
+ with open(path, "w", encoding="utf-8") as f:
2269
+ json.dump(json_snapshot, f, indent=4)
2270
+ else:
2271
+ print_debug('json_snapshot from ax_client_to_json_snapshot was None')
2214
2272
  else:
2215
2273
  print_red("write_json_snapshot: ax_client was None")
2216
2274
 
@@ -2410,7 +2468,7 @@ def set_nr_inserted_jobs(new_nr_inserted_jobs: int) -> None:
2410
2468
 
2411
2469
  def write_worker_usage() -> None:
2412
2470
  if len(WORKER_PERCENTAGE_USAGE):
2413
- csv_filename = get_current_run_folder('worker_usage.csv')
2471
+ csv_filename = get_current_run_folder(worker_usage_file)
2414
2472
 
2415
2473
  csv_columns = ['time', 'num_parallel_jobs', 'nr_current_workers', 'percentage']
2416
2474
 
@@ -2420,35 +2478,39 @@ def write_worker_usage() -> None:
2420
2478
  csv_writer.writerow(row)
2421
2479
  else:
2422
2480
  if is_slurm_job():
2423
- print_debug("WORKER_PERCENTAGE_USAGE seems to be empty. Not writing worker_usage.csv")
2481
+ print_debug(f"WORKER_PERCENTAGE_USAGE seems to be empty. Not writing {worker_usage_file}")
2424
2482
 
2425
2483
  def log_system_usage() -> None:
2484
+ global LAST_LOG_TIME
2485
+
2486
+ now = time.time()
2487
+ if now - LAST_LOG_TIME < 30:
2488
+ return
2489
+
2490
+ LAST_LOG_TIME = int(now)
2491
+
2426
2492
  if not get_current_run_folder():
2427
2493
  return
2428
2494
 
2429
2495
  ram_cpu_csv_file_path = os.path.join(get_current_run_folder(), "cpu_ram_usage.csv")
2430
-
2431
2496
  makedirs(os.path.dirname(ram_cpu_csv_file_path))
2432
2497
 
2433
2498
  file_exists = os.path.isfile(ram_cpu_csv_file_path)
2434
2499
 
2435
- with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
2436
- writer = csv.writer(file)
2437
-
2438
- current_time = int(time.time())
2439
-
2440
- if process is not None:
2441
- mem_proc = process.memory_info()
2442
-
2443
- if mem_proc is not None:
2444
- ram_usage_mb = mem_proc.rss / (1024 * 1024)
2445
- cpu_usage_percent = psutil.cpu_percent(percpu=False)
2500
+ mem_proc = process.memory_info() if process else None
2501
+ if not mem_proc:
2502
+ return
2446
2503
 
2447
- if ram_usage_mb > 0 and cpu_usage_percent > 0:
2448
- if not file_exists:
2449
- writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
2504
+ ram_usage_mb = mem_proc.rss / (1024 * 1024)
2505
+ cpu_usage_percent = psutil.cpu_percent(percpu=False)
2506
+ if ram_usage_mb <= 0 or cpu_usage_percent <= 0:
2507
+ return
2450
2508
 
2451
- writer.writerow([current_time, ram_usage_mb, cpu_usage_percent])
2509
+ with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
2510
+ writer = csv.writer(file)
2511
+ if not file_exists:
2512
+ writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
2513
+ writer.writerow([int(now), ram_usage_mb, cpu_usage_percent])
2452
2514
 
2453
2515
  def write_process_info() -> None:
2454
2516
  try:
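Note: log_system_usage is now rate-limited through the module-level LAST_LOG_TIME added near the top of the file: at most one cpu_ram_usage.csv row every 30 seconds, and the CSV is only opened once the sampled values are known to be usable. The core of that pattern in isolation (interval, column names and file name taken from the hunk; the run_folder argument is illustrative):

import csv
import os
import time

import psutil

LAST_LOG_TIME = 0.0

def log_system_usage(run_folder: str, interval_s: float = 30.0) -> None:
    global LAST_LOG_TIME
    now = time.time()
    if now - LAST_LOG_TIME < interval_s:
        return  # throttle: skip if the last sample is too recent
    LAST_LOG_TIME = now

    path = os.path.join(run_folder, "cpu_ram_usage.csv")
    write_header = not os.path.isfile(path)
    ram_mb = psutil.Process().memory_info().rss / (1024 * 1024)
    cpu_percent = psutil.cpu_percent(percpu=False)

    with open(path, mode="a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
        writer.writerow([int(now), ram_mb, cpu_percent])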
@@ -2798,10 +2860,20 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
2798
2860
  log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
2799
2861
 
2800
2862
  def print_debug_progressbar(msg: str) -> None:
2801
- time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
2802
- msg = f"{time_str} ({worker_generator_uuid}): {msg}"
2863
+ global last_msg_progressbar, last_msg_raw
2864
+
2865
+ try:
2866
+ with last_lock_print_debug:
2867
+ if msg != last_msg_raw:
2868
+ time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
2869
+ full_msg = f"{time_str} ({worker_generator_uuid}): {msg}"
2803
2870
 
2804
- _debug_progressbar(msg)
2871
+ _debug_progressbar(full_msg)
2872
+
2873
+ last_msg_raw = msg
2874
+ last_msg_progressbar = full_msg
2875
+ except Exception as e:
2876
+ print(f"Error in print_debug_progressbar: {e}", flush=True)
2805
2877
 
2806
2878
  def get_process_info(pid: Any) -> str:
2807
2879
  try:
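Note: print_debug_progressbar now drops consecutive duplicates of the same raw message and serializes writers with the new last_lock_print_debug lock, so a busy progress loop no longer floods the log. Stripped of the surrounding globals, the pattern is roughly:

import datetime
import threading

_last_msg = None
_lock = threading.Lock()

def debug_once(msg: str, sink=print) -> None:
    global _last_msg
    with _lock:
        if msg == _last_msg:
            return  # drop repeats of the same raw message
        _last_msg = msg
        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        sink(f"{time_str}: {msg}")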
@@ -3308,145 +3380,467 @@ def parse_experiment_parameters() -> None:
3308
3380
 
3309
3381
  experiment_parameters = params # type: ignore[assignment]
3310
3382
 
3311
- def check_factorial_range() -> None:
3312
- if args.model and args.model == "FACTORIAL":
3313
- _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
3383
+ def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
3384
+ pf_start_time = time.time()
3314
3385
 
3315
- def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
3316
- if value_type not in valid_value_types:
3317
- valid_value_types_string = ", ".join(valid_value_types)
3318
- _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
3386
+ if not path_to_calculate:
3387
+ return False
3319
3388
 
3320
- def check_range_params_length(this_args: Union[str, list]) -> None:
3321
- if len(this_args) != 5 and len(this_args) != 4 and len(this_args) != 6:
3322
- _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
3389
+ global CURRENT_RUN_FOLDER
3390
+ global RESULT_CSV_FILE
3391
+ global arg_result_names
3323
3392
 
3324
- def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
3325
- if upper_bound is None or lower_bound is None:
3326
- _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
3327
- if upper_bound == lower_bound:
3328
- if lower_bound == 0:
3329
- _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
3330
- print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
3331
- if upper_bound is not None:
3332
- lower_bound = -upper_bound
3393
+ if not path_to_calculate:
3394
+ print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
3395
+ return False
3333
3396
 
3334
- def format_value(value: Any, float_format: str = '.80f') -> str:
3335
- try:
3336
- if isinstance(value, float):
3337
- s = format(value, float_format)
3338
- s = s.rstrip('0').rstrip('.') if '.' in s else s
3339
- return s
3340
- return str(value)
3341
- except Exception as e:
3342
- print_red(f"⚠ Error formatting the number {value}: {e}")
3343
- return str(value)
3397
+ if not os.path.exists(path_to_calculate):
3398
+ print_red(f"Path '{path_to_calculate}' does not exist")
3399
+ return False
3344
3400
 
3345
- def replace_parameters_in_string(
3346
- parameters: dict,
3347
- input_string: str,
3348
- float_format: str = '.20f',
3349
- additional_prefixes: list[str] = [],
3350
- additional_patterns: list[str] = [],
3351
- ) -> str:
3352
- try:
3353
- prefixes = ['$', '%'] + additional_prefixes
3354
- patterns = ['{key}', '({key})'] + additional_patterns
3401
+ ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
3355
3402
 
3356
- for key, value in parameters.items():
3357
- replacement = format_value(value, float_format=float_format)
3358
- for prefix in prefixes:
3359
- for pattern in patterns:
3360
- token = prefix + pattern.format(key=key)
3361
- input_string = input_string.replace(token, replacement)
3403
+ if not os.path.exists(ax_client_json):
3404
+ print_red(f"Path '{ax_client_json}' not found")
3405
+ return False
3362
3406
 
3363
- input_string = input_string.replace('\r', ' ').replace('\n', ' ')
3364
- return input_string
3407
+ checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
3408
+ if not os.path.exists(checkpoint_file):
3409
+ print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
3410
+ return False
3365
3411
 
3366
- except Exception as e:
3367
- print_red(f"\n⚠ Error: {e}")
3368
- return ""
3412
+ RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3413
+ if not os.path.exists(RESULT_CSV_FILE):
3414
+ print_red(f"{RESULT_CSV_FILE} not found")
3415
+ return False
3369
3416
 
3370
- def get_memory_usage() -> float:
3371
- user_uid = os.getuid()
3417
+ res_names = []
3372
3418
 
3373
- memory_usage = float(sum(
3374
- p.memory_info().rss for p in psutil.process_iter(attrs=['memory_info', 'uids'])
3375
- if p.info['uids'].real == user_uid
3376
- ) / (1024 * 1024))
3419
+ res_names_file = f"{path_to_calculate}/result_names.txt"
3420
+ if not os.path.exists(res_names_file):
3421
+ print_red(f"File '{res_names_file}' does not exist")
3422
+ return False
3377
3423
 
3378
- return memory_usage
3424
+ try:
3425
+ with open(res_names_file, "r", encoding="utf-8") as file:
3426
+ lines = file.readlines()
3427
+ except Exception as e:
3428
+ print_red(f"Error reading file '{res_names_file}': {e}")
3429
+ return False
3379
3430
 
3380
- class MonitorProcess:
3381
- def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
3382
- self.pid = pid
3383
- self.interval = interval
3384
- self.running = True
3385
- self.thread = threading.Thread(target=self._monitor)
3386
- self.thread.daemon = True
3431
+ for line in lines:
3432
+ entry = line.strip()
3433
+ if entry != "":
3434
+ res_names.append(entry)
3387
3435
 
3388
- fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
3436
+ if len(res_names) < 2:
3437
+ print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
3438
+ return False
3389
3439
 
3390
- def _monitor(self: Any) -> None:
3391
- try:
3392
- _internal_process = psutil.Process(self.pid)
3393
- while self.running and _internal_process.is_running():
3394
- crf = get_current_run_folder()
3440
+ load_username_to_args(path_to_calculate)
3395
3441
 
3396
- if crf and crf != "":
3397
- log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
3442
+ CURRENT_RUN_FOLDER = path_to_calculate
3398
3443
 
3399
- os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
3444
+ arg_result_names = res_names
3400
3445
 
3401
- with open(log_file_path, mode="a", encoding="utf-8") as log_file:
3402
- hostname = socket.gethostname()
3446
+ load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
3403
3447
 
3404
- slurm_job_id = os.getenv("SLURM_JOB_ID")
3448
+ if experiment_parameters is None:
3449
+ return False
3405
3450
 
3406
- if slurm_job_id:
3407
- hostname += f"-SLURM-ID-{slurm_job_id}"
3451
+ show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
3408
3452
 
3409
- total_memory = psutil.virtual_memory().total / (1024 * 1024)
3410
- cpu_usage = psutil.cpu_percent(interval=5)
3453
+ pf_end_time = time.time()
3411
3454
 
3412
- memory_usage = get_memory_usage()
3455
+ print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
3413
3456
 
3414
- unix_timestamp = int(time.time())
3457
+ return True
3415
3458
 
3416
- log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
3417
- time.sleep(self.interval)
3418
- except psutil.NoSuchProcess:
3419
- pass
3459
+ def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
3460
+ if args.dryrun:
3461
+ print_debug("Not showing Pareto-frontier data with --dryrun")
3462
+ return None
3420
3463
 
3421
- def __enter__(self: Any) -> None:
3422
- self.thread.start()
3423
- return self
3464
+ if len(res_names) > 1:
3465
+ try:
3466
+ show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
3467
+ except Exception as e:
3468
+ inner_tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
3469
+ print_red(f"show_pareto_frontier_data() failed with exception '{e}':\n{inner_tb}")
3470
+ else:
3471
+ print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
3472
+ return None
3424
3473
 
3425
- def __exit__(self: Any, exc_type: Any, exc_value: Any, _traceback: Any) -> None:
3426
- self.running = False
3427
- self.thread.join()
3474
+ def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
3475
+ pareto_front_data: dict = {}
3428
3476
 
3429
- def execute_bash_code_log_time(code: str) -> list:
3430
- process_item = subprocess.Popen(code, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
3477
+ all_combinations = list(combinations(range(len(arg_result_names)), 2))
3431
3478
 
3432
- with MonitorProcess(process_item.pid):
3433
- try:
3434
- stdout, stderr = process_item.communicate()
3435
- result = subprocess.CompletedProcess(
3436
- args=code, returncode=process_item.returncode, stdout=stdout, stderr=stderr
3437
- )
3438
- return [result.stdout, result.stderr, result.returncode, None]
3439
- except subprocess.CalledProcessError as e:
3440
- real_exit_code = e.returncode
3441
- signal_code = None
3442
- if real_exit_code < 0:
3443
- signal_code = abs(e.returncode)
3444
- real_exit_code = 1
3445
- return [e.stdout, e.stderr, real_exit_code, signal_code]
3479
+ skip = False
3446
3480
 
3447
- def execute_bash_code(code: str) -> list:
3448
- try:
3449
- result = subprocess.run(
3481
+ for i, j in all_combinations:
3482
+ if not skip:
3483
+ metric_x = arg_result_names[i]
3484
+ metric_y = arg_result_names[j]
3485
+
3486
+ x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
3487
+ y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
3488
+
3489
+ try:
3490
+ if metric_x not in pareto_front_data:
3491
+ pareto_front_data[metric_x] = {}
3492
+
3493
+ pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
3494
+ except ax.exceptions.core.DataRequiredError as e:
3495
+ print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
3496
+ except SignalINT:
3497
+ print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
3498
+ skip = True
3499
+
3500
+ return pareto_front_data
3501
+
3502
+ def pareto_front_transform_objectives(
3503
+ points: List[Tuple[Any, float, float]],
3504
+ primary_name: str,
3505
+ secondary_name: str
3506
+ ) -> Tuple[np.ndarray, np.ndarray]:
3507
+ primary_idx = arg_result_names.index(primary_name)
3508
+ secondary_idx = arg_result_names.index(secondary_name)
3509
+
3510
+ x = np.array([p[1] for p in points])
3511
+ y = np.array([p[2] for p in points])
3512
+
3513
+ if arg_result_min_or_max[primary_idx] == "max":
3514
+ x = -x
3515
+ elif arg_result_min_or_max[primary_idx] != "min":
3516
+ raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
3517
+
3518
+ if arg_result_min_or_max[secondary_idx] == "max":
3519
+ y = -y
3520
+ elif arg_result_min_or_max[secondary_idx] != "min":
3521
+ raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
3522
+
3523
+ return x, y
3524
+
3525
+ def get_pareto_frontier_points(
3526
+ path_to_calculate: str,
3527
+ primary_objective: str,
3528
+ secondary_objective: str,
3529
+ x_minimize: bool,
3530
+ y_minimize: bool,
3531
+ absolute_metrics: List[str],
3532
+ num_points: int
3533
+ ) -> Optional[dict]:
3534
+ records = pareto_front_aggregate_data(path_to_calculate)
3535
+
3536
+ if records is None:
3537
+ return None
3538
+
3539
+ points = pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
3540
+ x, y = pareto_front_transform_objectives(points, primary_objective, secondary_objective)
3541
+ selected_points = pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
3542
+ result = pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
3543
+
3544
+ return result
3545
+
3546
+ def pareto_front_table_read_csv() -> List[Dict[str, str]]:
3547
+ with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
3548
+ return list(csv.DictReader(f))
3549
+
3550
+ def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
3551
+ table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
3552
+
3553
+ rows = pareto_front_table_read_csv()
3554
+ if not rows:
3555
+ table.add_column("No data found")
3556
+ return table
3557
+
3558
+ filtered_rows = pareto_front_table_filter_rows(rows, idxs)
3559
+ if not filtered_rows:
3560
+ table.add_column("No matching entries")
3561
+ return table
3562
+
3563
+ param_cols, result_cols = pareto_front_table_get_columns(filtered_rows[0])
3564
+
3565
+ pareto_front_table_add_headers(table, param_cols, result_cols)
3566
+ pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
3567
+
3568
+ return table
3569
+
3570
+ def pareto_front_build_return_structure(
3571
+ path_to_calculate: str,
3572
+ selected_points: List[Tuple[Any, float, float]],
3573
+ records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
3574
+ absolute_metrics: List[str],
3575
+ primary_name: str,
3576
+ secondary_name: str
3577
+ ) -> dict:
3578
+ results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3579
+ result_names_file = f"{path_to_calculate}/result_names.txt"
3580
+
3581
+ with open(result_names_file, mode="r", encoding="utf-8") as f:
3582
+ result_names = [line.strip() for line in f if line.strip()]
3583
+
3584
+ csv_rows = {}
3585
+ with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
3586
+ reader = csv.DictReader(csvfile)
3587
+ for row in reader:
3588
+ trial_index = int(row['trial_index'])
3589
+ csv_rows[trial_index] = row
3590
+
3591
+ ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
3592
+ ignored_columns.update(result_names)
3593
+
3594
+ param_dicts = []
3595
+ idxs = []
3596
+ means_dict = defaultdict(list)
3597
+
3598
+ for (trial_index, arm_name), _, _ in selected_points:
3599
+ row = csv_rows.get(trial_index, {})
3600
+ if row == {} or row is None or row['arm_name'] != arm_name:
3601
+ continue
3602
+
3603
+ idxs.append(int(row["trial_index"]))
3604
+
3605
+ param_dict: dict[str, int | float | str] = {}
3606
+ for key, value in row.items():
3607
+ if key not in ignored_columns:
3608
+ try:
3609
+ param_dict[key] = int(value)
3610
+ except ValueError:
3611
+ try:
3612
+ param_dict[key] = float(value)
3613
+ except ValueError:
3614
+ param_dict[key] = value
3615
+
3616
+ param_dicts.append(param_dict)
3617
+
3618
+ for metric in absolute_metrics:
3619
+ means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
3620
+
3621
+ ret = {
3622
+ primary_name: {
3623
+ secondary_name: {
3624
+ "absolute_metrics": absolute_metrics,
3625
+ "param_dicts": param_dicts,
3626
+ "means": dict(means_dict),
3627
+ "idxs": idxs
3628
+ },
3629
+ "absolute_metrics": absolute_metrics
3630
+ }
3631
+ }
3632
+
3633
+ return ret
3634
+
3635
+ def pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
3636
+ results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
3637
+ result_names_file = f"{path_to_calculate}/result_names.txt"
3638
+
3639
+ if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
3640
+ return None
3641
+
3642
+ with open(result_names_file, mode="r", encoding="utf-8") as f:
3643
+ result_names = [line.strip() for line in f if line.strip()]
3644
+
3645
+ records: dict = defaultdict(lambda: {'means': {}})
3646
+
3647
+ with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
3648
+ reader = csv.DictReader(csvfile)
3649
+ for row in reader:
3650
+ trial_index = int(row['trial_index'])
3651
+ arm_name = row['arm_name']
3652
+ key = (trial_index, arm_name)
3653
+
3654
+ for metric in result_names:
3655
+ if metric in row:
3656
+ try:
3657
+ records[key]['means'][metric] = float(row[metric])
3658
+ except ValueError:
3659
+ continue
3660
+
3661
+ return records
3662
+
3663
+ def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
3664
+ if data is None:
3665
+ print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
3666
+ return
3667
+
3668
+ if not supports_sixel():
3669
+ print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
3670
+ return
3671
+
3672
+ import matplotlib.pyplot as plt
3673
+
3674
+ means = data[x_metric][y_metric]["means"]
3675
+
3676
+ x_values = means[x_metric]
3677
+ y_values = means[y_metric]
3678
+
3679
+ fig, _ax = plt.subplots()
3680
+
3681
+ _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
3682
+
3683
+ _ax.set_xlabel(x_metric)
3684
+ _ax.set_ylabel(y_metric)
3685
+
3686
+ _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
3687
+
3688
+ _ax.ticklabel_format(style='plain', axis='both', useOffset=False)
3689
+
3690
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
3691
+ plt.savefig(tmp_file.name, dpi=300)
3692
+
3693
+ print_image_to_cli(tmp_file.name, 1000)
3694
+
3695
+ plt.close(fig)
3696
+
3697
+ def pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
3698
+ all_columns = list(first_row.keys())
3699
+ ignored_cols = set(special_col_names) - {"trial_index"}
3700
+
3701
+ param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
3702
+ result_cols = [col for col in arg_result_names if col in all_columns]
3703
+ return param_cols, result_cols
3704
+
3705
+ def check_factorial_range() -> None:
3706
+ if args.model and args.model == "FACTORIAL":
3707
+ _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
3708
+
3709
+ def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
3710
+ if value_type not in valid_value_types:
3711
+ valid_value_types_string = ", ".join(valid_value_types)
3712
+ _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {valid_value_types_string}", 181)
3713
+
3714
+ def check_range_params_length(this_args: Union[str, list]) -> None:
3715
+ if len(this_args) != 5 and len(this_args) != 4 and len(this_args) != 6:
3716
+ _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
3717
+
3718
+ def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
3719
+ if upper_bound is None or lower_bound is None:
3720
+ _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)
3721
+ if upper_bound == lower_bound:
3722
+ if lower_bound == 0:
3723
+ _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)
3724
+ print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")
3725
+ if upper_bound is not None:
3726
+ lower_bound = -upper_bound
3727
+
3728
+ def format_value(value: Any, float_format: str = '.80f') -> str:
3729
+ try:
3730
+ if isinstance(value, float):
3731
+ s = format(value, float_format)
3732
+ s = s.rstrip('0').rstrip('.') if '.' in s else s
3733
+ return s
3734
+ return str(value)
3735
+ except Exception as e:
3736
+ print_red(f"⚠ Error formatting the number {value}: {e}")
3737
+ return str(value)
3738
+
3739
+ def replace_parameters_in_string(
3740
+ parameters: dict,
3741
+ input_string: str,
3742
+ float_format: str = '.20f',
3743
+ additional_prefixes: list[str] = [],
3744
+ additional_patterns: list[str] = [],
3745
+ ) -> str:
3746
+ try:
3747
+ prefixes = ['$', '%'] + additional_prefixes
3748
+ patterns = ['{' + 'key' + '}', '(' + '{' + 'key' + '}' + ')'] + additional_patterns
3749
+
3750
+ for key, value in parameters.items():
3751
+ replacement = format_value(value, float_format=float_format)
3752
+ for prefix in prefixes:
3753
+ for pattern in patterns:
3754
+ token = prefix + pattern.format(key=key)
3755
+ input_string = input_string.replace(token, replacement)
3756
+
3757
+ input_string = input_string.replace('\r', ' ').replace('\n', ' ')
3758
+ return input_string
3759
+
3760
+ except Exception as e:
3761
+ print_red(f"\n⚠ Error: {e}")
3762
+ return ""
3763
+
3764
+ def get_memory_usage() -> float:
3765
+ user_uid = os.getuid()
3766
+
3767
+ memory_usage = float(sum(
3768
+ p.memory_info().rss for p in psutil.process_iter(attrs=['memory_info', 'uids'])
3769
+ if p.info['uids'].real == user_uid
3770
+ ) / (1024 * 1024))
3771
+
3772
+ return memory_usage
3773
+
3774
+ class MonitorProcess:
3775
+ def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
3776
+ self.pid = pid
3777
+ self.interval = interval
3778
+ self.running = True
3779
+ self.thread = threading.Thread(target=self._monitor)
3780
+ self.thread.daemon = True
3781
+
3782
+ fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")
3783
+
3784
+ def _monitor(self: Any) -> None:
3785
+ try:
3786
+ _internal_process = psutil.Process(self.pid)
3787
+ while self.running and _internal_process.is_running():
3788
+ crf = get_current_run_folder()
3789
+
3790
+ if crf and crf != "":
3791
+ log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")
3792
+
3793
+ os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
3794
+
3795
+ with open(log_file_path, mode="a", encoding="utf-8") as log_file:
3796
+ hostname = socket.gethostname()
3797
+
3798
+ slurm_job_id = os.getenv("SLURM_JOB_ID")
3799
+
3800
+ if slurm_job_id:
3801
+ hostname += f"-SLURM-ID-{slurm_job_id}"
3802
+
3803
+ total_memory = psutil.virtual_memory().total / (1024 * 1024)
3804
+ cpu_usage = psutil.cpu_percent(interval=5)
3805
+
3806
+ memory_usage = get_memory_usage()
3807
+
3808
+ unix_timestamp = int(time.time())
3809
+
3810
+ log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
3811
+ time.sleep(self.interval)
3812
+ except psutil.NoSuchProcess:
3813
+ pass
3814
+
3815
+ def __enter__(self: Any) -> None:
3816
+ self.thread.start()
3817
+ return self
3818
+
3819
+ def __exit__(self: Any, exc_type: Any, exc_value: Any, _traceback: Any) -> None:
3820
+ self.running = False
3821
+ self.thread.join()
3822
+
3823
+ def execute_bash_code_log_time(code: str) -> list:
3824
+ process_item = subprocess.Popen(code, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
3825
+
3826
+ with MonitorProcess(process_item.pid):
3827
+ try:
3828
+ stdout, stderr = process_item.communicate()
3829
+ result = subprocess.CompletedProcess(
3830
+ args=code, returncode=process_item.returncode, stdout=stdout, stderr=stderr
3831
+ )
3832
+ return [result.stdout, result.stderr, result.returncode, None]
3833
+ except subprocess.CalledProcessError as e:
3834
+ real_exit_code = e.returncode
3835
+ signal_code = None
3836
+ if real_exit_code < 0:
3837
+ signal_code = abs(e.returncode)
3838
+ real_exit_code = 1
3839
+ return [e.stdout, e.stderr, real_exit_code, signal_code]
3840
+
3841
+ def execute_bash_code(code: str) -> list:
3842
+ try:
3843
+ result = subprocess.run(
3450
3844
  code,
3451
3845
  shell=True,
3452
3846
  check=True,
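Note: most of this large hunk moves the Pareto-front tooling into .omniopt.py (reading results.csv and result_names.txt, building the per-pair frontier data, the rich table and the sixel plot), while the row and point selection itself is delegated to the new .pareto.py. For orientation, a minimal reference implementation of the kind of 2-objective non-dominated filter such a helper performs could look like this (a sketch under the usual flip-maximized-objectives convention, not the packaged code):

import numpy as np

def pareto_mask_2d(x, y, x_minimize=True, y_minimize=True):
    # Flip maximized objectives so "smaller is better" holds on both axes,
    # then keep every point that no other point dominates.
    xs = np.asarray(x, dtype=float) * (1.0 if x_minimize else -1.0)
    ys = np.asarray(y, dtype=float) * (1.0 if y_minimize else -1.0)
    keep = np.ones(len(xs), dtype=bool)
    for i in range(len(xs)):
        # any point that is <= on both objectives and strictly better on one dominates point i
        dominators = (xs <= xs[i]) & (ys <= ys[i]) & ((xs < xs[i]) | (ys < ys[i]))
        if dominators.any():
            keep[i] = False
    return keep

x = np.array([1.0, 2.0, 3.0, 2.5])
y = np.array([3.0, 2.0, 1.0, 2.5])
print(pareto_mask_2d(x, y))  # [ True  True  True False]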
@@ -3559,7 +3953,7 @@ def _add_to_csv_acquire_lock(lockfile: str, dir_path: str) -> bool:
3559
3953
  time.sleep(wait_time)
3560
3954
  max_wait -= wait_time
3561
3955
  except Exception as e:
3562
- print("Lock error:", e)
3956
+ print_red(f"Lock error: {e}")
3563
3957
  return False
3564
3958
  return False
3565
3959
 
@@ -3632,12 +4026,12 @@ def find_file_paths(_text: str) -> List[str]:
3632
4026
  def check_file_info(file_path: str) -> str:
3633
4027
  if not os.path.exists(file_path):
3634
4028
  if not args.tests:
3635
- print(f"check_file_info: The file {file_path} does not exist.")
4029
+ print_red(f"check_file_info: The file {file_path} does not exist.")
3636
4030
  return ""
3637
4031
 
3638
4032
  if not os.access(file_path, os.R_OK):
3639
4033
  if not args.tests:
3640
- print(f"check_file_info: The file {file_path} is not readable.")
4034
+ print_red(f"check_file_info: The file {file_path} is not readable.")
3641
4035
  return ""
3642
4036
 
3643
4037
  file_stat = os.stat(file_path)
@@ -3751,7 +4145,7 @@ def count_defective_nodes(file_path: Union[str, None] = None, entry: Any = None)
3751
4145
  return sorted(set(entries))
3752
4146
 
3753
4147
  except Exception as e:
3754
- print(f"An error has occurred: {e}")
4148
+ print_red(f"An error has occurred: {e}")
3755
4149
  return []
3756
4150
 
3757
4151
  def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]:
@@ -3762,7 +4156,7 @@ def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]
3762
4156
 
3763
4157
  fool_linter(tmp)
3764
4158
  except RuntimeError:
3765
- print(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
4159
+ print_red(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
3766
4160
  count_defective_nodes(None, socket.gethostname())
3767
4161
  return return_in_case_of_error
3768
4162
  except Exception:
@@ -4233,6 +4627,8 @@ def evaluate(parameters_with_trial_index: dict) -> Optional[Union[int, float, Di
4233
4627
  trial_index = parameters_with_trial_index["trial_idx"]
4234
4628
  submit_time = parameters_with_trial_index["submit_time"]
4235
4629
 
4630
+ print(f'Trial-Index: {trial_index}')
4631
+
4236
4632
  queue_time = abs(int(time.time()) - int(submit_time))
4237
4633
 
4238
4634
  start_nvidia_smi_thread()
@@ -4470,7 +4866,7 @@ def replace_string_with_params(input_string: str, params: list) -> str:
4470
4866
  return replaced_string
4471
4867
  except AssertionError as e:
4472
4868
  error_text = f"Error in replace_string_with_params: {e}"
4473
- print(error_text)
4869
+ print_red(error_text)
4474
4870
  raise
4475
4871
 
4476
4872
  return ""
@@ -4747,9 +5143,7 @@ def get_sixel_graphics_data(_pd_csv: str, _force: bool = False) -> list:
4747
5143
  _params = [_command, plot, _tmp, plot_type, tmp_file, _width]
4748
5144
  data.append(_params)
4749
5145
  except Exception as e:
4750
- tb = traceback.format_exc()
4751
- print_red(f"Error trying to print {plot_type} to CLI: {e}, {tb}")
4752
- print_debug(f"Error trying to print {plot_type} to CLI: {e}")
5146
+ print_red(f"Error trying to print {plot_type} to CLI: {e}")
4753
5147
 
4754
5148
  return data
4755
5149
 
@@ -4940,15 +5334,17 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
4940
5334
  if job:
4941
5335
  try:
4942
5336
  if ax_client:
4943
- _trial = ax_client.get_trial(trial_index)
4944
- _trial.mark_abandoned(reason=reason)
5337
+ _trial = get_ax_client_trial(trial_index)
5338
+ if _trial is None:
5339
+ return False
5340
+
5341
+ mark_abandoned(_trial, reason, trial_index)
4945
5342
  print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
4946
5343
  global_vars["jobs"].remove((job, trial_index))
4947
5344
  else:
4948
5345
  _fatal_error("ax_client could not be found", 101)
4949
5346
  except Exception as e:
4950
- print(f"ERROR in line {get_line_info()}: {e}")
4951
- print_debug(f"ERROR in line {get_line_info()}: {e}")
5347
+ print_red(f"ERROR in line {get_line_info()}: {e}")
4952
5348
  return False
4953
5349
  job.cancel()
4954
5350
  return True
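Note: abandon_job now goes through helper wrappers (get_ax_client_trial, mark_abandoned) instead of calling ax_client.get_trial and _trial.mark_abandoned inline, so a failed lookup can return False early rather than relying on the broad except below. A generic sketch of such a wrapper (illustrative only; the real helpers live elsewhere in .omniopt.py):

from typing import Any, Optional

def get_trial_or_none(client: Any, trial_index: int) -> Optional[Any]:
    # Defensive lookup: callers can bail out cleanly when the trial is unknown.
    try:
        return client.get_trial(trial_index)
    except Exception as exc:
        print(f"Could not fetch trial {trial_index}: {exc}")
        return None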
@@ -4961,23 +5357,120 @@ def abandon_all_jobs() -> None:
4961
5357
  if not abandoned:
4962
5358
  print_debug(f"Job {job} could not be abandoned.")
4963
5359
 
4964
- def show_pareto_or_error_msg(path_to_calculate: str, res_names: list = arg_result_names, disable_sixel_and_table: bool = False) -> None:
4965
- if args.dryrun:
4966
- print_debug("Not showing Pareto-frontier data with --dryrun")
5360
+ def write_result_to_trace_file(res: str) -> bool:
5361
+ if res is None:
5362
+ sys.stderr.write("Provided result is None, nothing to write\n")
5363
+ return False
5364
+
5365
+ target_folder = get_current_run_folder()
5366
+ target_file = os.path.join(target_folder, "optimization_trace.html")
5367
+
5368
+ try:
5369
+ file_handle = open(target_file, "w", encoding="utf-8")
5370
+ except OSError as error:
5371
+ sys.stderr.write("Unable to open target file for writing\n")
5372
+ sys.stderr.write(str(error) + "\n")
5373
+ return False
5374
+
5375
+ try:
5376
+ written = file_handle.write(str(res))
5377
+ file_handle.flush()
5378
+
5379
+ if written == 0:
5380
+ sys.stderr.write("No data was written to the file\n")
5381
+ file_handle.close()
5382
+ return False
5383
+ except Exception as error:
5384
+ sys.stderr.write("Error occurred while writing to file\n")
5385
+ sys.stderr.write(str(error) + "\n")
5386
+ file_handle.close()
5387
+ return False
5388
+
5389
+ try:
5390
+ file_handle.close()
5391
+ except Exception as error:
5392
+ sys.stderr.write("Failed to properly close file\n")
5393
+ sys.stderr.write(str(error) + "\n")
5394
+ return False
5395
+
5396
+ return True
5397
+
5398
+ def render(plot_config: AxPlotConfig) -> None:
5399
+ if plot_config is None or "data" not in plot_config:
4967
5400
  return None
4968
5401
 
4969
- if len(res_names) > 1:
4970
- try:
4971
- show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
4972
- except Exception as e:
4973
- print_red(f"show_pareto_frontier_data() failed with exception '{e}'")
4974
- else:
4975
- print_debug(f"show_pareto_frontier_data will NOT be executed because len(arg_result_names) is {len(arg_result_names)}")
5402
+ res: str = plot_config.data # type: ignore
5403
+
5404
+ repair_funcs = """
5405
+ function decodeBData(obj) {
5406
+ if (!obj || typeof obj !== "object") {
5407
+ return obj;
5408
+ }
5409
+
5410
+ if (obj.bdata && obj.dtype) {
5411
+ var binary_string = atob(obj.bdata);
5412
+ var len = binary_string.length;
5413
+ var bytes = new Uint8Array(len);
5414
+
5415
+ for (var i = 0; i < len; i++) {
5416
+ bytes[i] = binary_string.charCodeAt(i);
5417
+ }
5418
+
5419
+ switch (obj.dtype) {
5420
+ case "i1": return Array.from(new Int8Array(bytes.buffer));
5421
+ case "i2": return Array.from(new Int16Array(bytes.buffer));
5422
+ case "i4": return Array.from(new Int32Array(bytes.buffer));
5423
+ case "f4": return Array.from(new Float32Array(bytes.buffer));
5424
+ case "f8": return Array.from(new Float64Array(bytes.buffer));
5425
+ default:
5426
+ console.error("Unknown dtype:", obj.dtype);
5427
+ return [];
5428
+ }
5429
+ }
5430
+
5431
+ return obj;
5432
+ }
5433
+
5434
+ function repairTraces(traces) {
5435
+ var fixed = [];
5436
+
5437
+ for (var i = 0; i < traces.length; i++) {
5438
+ var t = traces[i];
5439
+
5440
+ if (t.x) {
5441
+ t.x = decodeBData(t.x);
5442
+ }
5443
+
5444
+ if (t.y) {
5445
+ t.y = decodeBData(t.y);
5446
+ }
5447
+
5448
+ fixed.push(t);
5449
+ }
5450
+
5451
+ return fixed;
5452
+ }
5453
+ """
5454
+
5455
+ res = str(res)
5456
+
5457
+ res = f"<div id='plot' style='width:100%;height:600px;'></div>\n<script type='text/javascript' src='https://cdn.plot.ly/plotly-latest.min.js'></script><script>{repair_funcs}\nconst True = true;\nconst False = false;\nconst data = {res};\ndata.data = repairTraces(data.data);\nPlotly.newPlot(document.getElementById('plot'), data.data, data.layout);</script>"
5458
+
5459
+ write_result_to_trace_file(res)
5460
+
4976
5461
  return None
4977
5462
 
4978
5463
  def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None) -> None:
4979
5464
  global END_PROGRAM_RAN
4980
5465
 
5466
+ #dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].X)
5467
+ #dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].Y)
5468
+ #dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.outcomes)
5469
+
5470
+ if ax_client is not None:
5471
+ if len(arg_result_names) == 1:
5472
+ render(ax_client.get_optimization_trace())
5473
+
4981
5474
  wait_for_jobs_to_complete()
4982
5475
 
4983
5476
  show_pareto_or_error_msg(get_current_run_folder(), arg_result_names)
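
The render helper above embeds a JavaScript decodeBData shim because newer Plotly versions serialize numeric axes as base64 "bdata" blobs tagged with a dtype code. For reference, a minimal Python sketch of the same decoding step (assumes numpy is available; the dtype map mirrors the switch in the JavaScript, using little-endian layouts as in the browser):

    import base64
    import numpy as np

    # dtype codes handled by the JS decodeBData above, mapped to little-endian numpy dtypes
    _BDATA_DTYPES = {"i1": "<i1", "i2": "<i2", "i4": "<i4", "f4": "<f4", "f8": "<f8"}

    def decode_bdata(obj):
        # Turn {"bdata": <base64>, "dtype": <code>} payloads into plain lists; pass anything else through
        if isinstance(obj, dict) and "bdata" in obj and "dtype" in obj:
            dtype = _BDATA_DTYPES.get(obj["dtype"])
            if dtype is None:
                return []
            return np.frombuffer(base64.b64decode(obj["bdata"]), dtype=dtype).tolist()
        return obj
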
@@ -5011,7 +5504,7 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5011
5504
  _exit = new_exit
5012
5505
  except (SignalUSR, SignalINT, SignalCONT, KeyboardInterrupt):
5013
5506
  print_red("\n⚠ You pressed CTRL+C or a signal was sent. Program execution halted while ending program.")
5014
- print("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
5507
+ print_red("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
5015
5508
  new_exit = show_end_table_and_save_end_files()
5016
5509
  if new_exit > 0:
5017
5510
  _exit = new_exit
@@ -5032,19 +5525,29 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
5032
5525
 
5033
5526
  my_exit(_exit)
5034
5527
 
5528
+ def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
5529
+ if not ax_client:
5530
+ my_exit(101)
5531
+
5532
+ return None
5533
+
5534
+ ax_client.save_to_json_file(checkpoint_filepath)
5535
+
5536
+ return None
5537
+
5035
5538
  def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
5036
5539
  if trial_nr > 3:
5037
5540
  if eee:
5038
- print(f"Error during saving checkpoint: {eee}")
5541
+ print_red(f"Error during saving checkpoint: {eee}")
5039
5542
  else:
5040
- print("Error during saving checkpoint")
5543
+ print_red("Error during saving checkpoint")
5041
5544
  return
5042
5545
 
5043
5546
  try:
5044
5547
  checkpoint_filepath = get_state_file_name('checkpoint.json')
5045
5548
 
5046
5549
  if ax_client:
5047
- ax_client.save_to_json_file(filepath=checkpoint_filepath)
5550
+ save_ax_client_to_json_file(checkpoint_filepath)
5048
5551
  else:
5049
5552
  _fatal_error("Something went wrong using the ax_client", 101)
5050
5553
  except Exception as e:
@@ -5207,7 +5710,7 @@ def parse_equation_item(comparer_found: bool, item: str, parsed: list, parsed_or
5207
5710
  })
5208
5711
  elif item in [">=", "<="]:
5209
5712
  if comparer_found:
5210
- print("There is already one comparison operator! Cannot have more than one in an equation!")
5713
+ print_red("There is already one comparison operator! Cannot have more than one in an equation!")
5211
5714
  return_totally = True
5212
5715
  comparer_found = True
5213
5716
 
@@ -5418,20 +5921,9 @@ def check_equation(variables: list, equation: str) -> Union[str, bool]:
5418
5921
  def set_objectives() -> dict:
5419
5922
  objectives = {}
5420
5923
 
5421
- for rn in args.result_names:
5422
- key, value = "", ""
5423
-
5424
- if "=" in rn:
5425
- key, value = rn.split('=', 1)
5426
- else:
5427
- key = rn
5428
- value = ""
5429
-
5430
- if value not in ["min", "max"]:
5431
- if value:
5432
- print_yellow(f"Value '{value}' for --result_names {rn} is not a valid value. Must be min or max. Will be set to min.")
5433
-
5434
- value = "min"
5924
+ k = 0
5925
+ for key in arg_result_names:
5926
+ value = arg_result_min_or_max[k]
5435
5927
 
5436
5928
  _min = True
5437
5929
 
@@ -5439,6 +5931,7 @@ def set_objectives() -> dict:
5439
5931
  _min = False
5440
5932
 
5441
5933
  objectives[key] = ObjectiveProperties(minimize=_min)
5934
+ k = k + 1
5442
5935
 
5443
5936
  return objectives
5444
5937
 
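
set_objectives now walks the pre-parsed arg_result_names / arg_result_min_or_max lists in lockstep instead of re-splitting --result_names. A hedged sketch of the same mapping, assuming both lists are equally long and the directions are already normalized to "min"/"max":

    from ax.service.ax_client import ObjectiveProperties

    def build_objectives(names, directions):
        # directions holds "min" or "max" per result name, in the same order as names
        return {
            name: ObjectiveProperties(minimize=(direction == "min"))
            for name, direction in zip(names, directions)
        }

    # e.g. build_objectives(["loss", "throughput"], ["min", "max"])
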
@@ -5660,7 +6153,7 @@ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_param
5660
6153
 
5661
6154
  replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)
5662
6155
 
5663
- ax_client.save_to_json_file(filepath=original_ax_client_file)
6156
+ save_ax_client_to_json_file(original_ax_client_file)
5664
6157
 
5665
6158
  load_original_generation_strategy(original_ax_client_file)
5666
6159
  load_ax_client_from_experiment_parameters()
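
Checkpoint writes now go through save_ax_client_to_json_file in both places. A short sketch of the round trip it wraps, using the AxClient JSON API already referenced in this file (the file name is illustrative):

    from ax.service.ax_client import AxClient

    def checkpoint_roundtrip(ax_client: AxClient, path: str = "checkpoint.json") -> AxClient:
        # Persist the current experiment and generation-strategy state ...
        ax_client.save_to_json_file(filepath=path)
        # ... and restore it later, e.g. when continuing a previous job
        return AxClient.load_from_json_file(filepath=path)
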
@@ -5690,6 +6183,109 @@ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_param
5690
6183
 
5691
6184
  return experiment_args, gpu_string, gpu_color
5692
6185
 
6186
+ def get_experiment_args_import_python_script() -> str:
6187
+
6188
+ return """from ax.service.ax_client import AxClient, ObjectiveProperties
6189
+ from ax.adapter.registry import Generators
6190
+ import random
6191
+
6192
+ """
6193
+
6194
+ def get_generate_and_test_random_function_str() -> str:
6195
+ raw_data_entries = ",\n ".join(
6196
+ f'"{name}": random.uniform(0, 1)' for name in arg_result_names
6197
+ )
6198
+
6199
+ return f"""
6200
+ def generate_and_test_random_parameters(n: int) -> None:
6201
+ for _ in range(n):
6202
+ print("======================================")
6203
+ parameters, trial_index = ax_client.get_next_trial()
6204
+ print("Trial Index:", trial_index)
6205
+ print("Suggested parameters:", parameters)
6206
+
6207
+ ax_client.complete_trial(
6208
+ trial_index=trial_index,
6209
+ raw_data={{
6210
+ {raw_data_entries}
6211
+ }}
6212
+ )
6213
+
6214
+ generate_and_test_random_parameters({args.num_random_steps + 1})
6215
+ """
6216
+
6217
+ def get_global_gs_string() -> str:
6218
+ seed_str = ""
6219
+ if args.seed is not None:
6220
+ seed_str = f"model_kwargs={{'seed': {args.seed}}},"
6221
+
6222
+ return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
6223
+
6224
+ global_gs = GenerationStrategy(
6225
+ steps=[
6226
+ GenerationStep(
6227
+ generator=Generators.SOBOL,
6228
+ num_trials={args.num_random_steps},
6229
+ max_parallelism=5,
6230
+ {seed_str}
6231
+ ),
6232
+ GenerationStep(
6233
+ generator=Generators.{args.model},
6234
+ num_trials=-1,
6235
+ max_parallelism=5,
6236
+ ),
6237
+ ]
6238
+ )
6239
+ """
6240
+
6241
+ def get_debug_ax_client_str() -> str:
6242
+ return """
6243
+ ax_client = AxClient(
6244
+ verbose_logging=True,
6245
+ enforce_sequential_optimization=False,
6246
+ generation_strategy=global_gs
6247
+ )
6248
+ """
6249
+
6250
+ def write_ax_debug_python_code(experiment_args: dict) -> None:
6251
+ if args.generation_strategy:
6252
+ print_debug("Cannot write debug code for custom generation_strategy")
6253
+ return None
6254
+
6255
+ if args.model in uncontinuable_models:
6256
+ print_debug(f"Cannot write debug code for uncontinuable mode {args.model}")
6257
+ return None
6258
+
6259
+ python_code = get_experiment_args_import_python_script() + \
6260
+ get_global_gs_string() + \
6261
+ get_debug_ax_client_str() + \
6262
+ "experiment_args = " + pformat(experiment_args, width=120, compact=False) + \
6263
+ "\nax_client.create_experiment(**experiment_args)\n" + \
6264
+ get_generate_and_test_random_function_str()
6265
+
6266
+ file_path = f"{get_current_run_folder()}/debug.py"
6267
+
6268
+ try:
6269
+ print_debug(python_code)
6270
+ with open(file_path, "w", encoding="utf-8") as f:
6271
+ f.write(python_code)
6272
+ except Exception as e:
6273
+ print_red(f"Error while writing {file_path}: {e}")
6274
+
6275
+ return None
6276
+
6277
+ def create_ax_client_experiment(experiment_args: dict) -> None:
6278
+ if not ax_client:
6279
+ my_exit(101)
6280
+
6281
+ return None
6282
+
6283
+ write_ax_debug_python_code(experiment_args)
6284
+
6285
+ ax_client.create_experiment(**experiment_args)
6286
+
6287
+ return None
6288
+
5693
6289
  def create_new_experiment() -> Tuple[dict, str, str]:
5694
6290
  if ax_client is None:
5695
6291
  print_red("create_new_experiment: ax_client is None")
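
write_ax_debug_python_code concatenates the string helpers above with a pformat() dump of experiment_args and writes the result to <run_folder>/debug.py. The generated script has roughly the following shape; the parameter definitions, model name and trial counts below are placeholders, and the imports simply mirror the helper strings rather than a guaranteed-stable Ax API:

    # illustrative shape of the generated debug.py
    from ax.service.ax_client import AxClient, ObjectiveProperties
    from ax.adapter.registry import Generators
    from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
    import random

    global_gs = GenerationStrategy(steps=[
        GenerationStep(generator=Generators.SOBOL, num_trials=20, max_parallelism=5),
        GenerationStep(generator=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=5),
    ])

    ax_client = AxClient(
        verbose_logging=True,
        enforce_sequential_optimization=False,
        generation_strategy=global_gs,
    )

    experiment_args = {  # in the real file this is the pformat() dump of experiment_args
        "name": "debug",
        "parameters": [{"name": "x", "type": "range", "bounds": [0.0, 1.0]}],
        "objectives": {"RESULT": ObjectiveProperties(minimize=True)},
    }
    ax_client.create_experiment(**experiment_args)

    def generate_and_test_random_parameters(n: int) -> None:
        for _ in range(n):
            parameters, trial_index = ax_client.get_next_trial()
            ax_client.complete_trial(trial_index=trial_index,
                                     raw_data={"RESULT": random.uniform(0, 1)})

    generate_and_test_random_parameters(21)  # num_random_steps + 1
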
@@ -5719,7 +6315,7 @@ def create_new_experiment() -> Tuple[dict, str, str]:
5719
6315
  experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)
5720
6316
 
5721
6317
  try:
5722
- ax_client.create_experiment(**experiment_args)
6318
+ create_ax_client_experiment(experiment_args)
5723
6319
  new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
5724
6320
  ax_client.experiment.add_tracking_metrics(new_metrics)
5725
6321
  except AssertionError as error:
@@ -5733,7 +6329,7 @@ def create_new_experiment() -> Tuple[dict, str, str]:
5733
6329
 
5734
6330
  return experiment_args, gpu_string, gpu_color
5735
6331
 
5736
- def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[AxClient, dict, str, str]]:
6332
+ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
5737
6333
  continue_previous_job = args.worker_generator_path or args.continue_previous_job
5738
6334
 
5739
6335
  check_ax_client()
@@ -5743,7 +6339,7 @@ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict |
5743
6339
  else:
5744
6340
  experiment_args, gpu_string, gpu_color = create_new_experiment()
5745
6341
 
5746
- return ax_client, experiment_args, gpu_string, gpu_color
6342
+ return experiment_args, gpu_string, gpu_color
5747
6343
 
5748
6344
  def get_type_short(typename: str) -> str:
5749
6345
  if typename == "RangeParameter":
@@ -5802,7 +6398,6 @@ def parse_single_experiment_parameter_table(classic_params: Optional[Union[list,
5802
6398
  _upper = param["bounds"][1]
5803
6399
 
5804
6400
  _possible_int_lower = str(helpers.to_int_when_possible(_lower))
5805
- #print(f"name: {_name}, _possible_int_lower: {_possible_int_lower}, lower: {_lower}")
5806
6401
  _possible_int_upper = str(helpers.to_int_when_possible(_upper))
5807
6402
 
5808
6403
  rows.append([_name, _short_type, _possible_int_lower, _possible_int_upper, "", value_type, log_scale])
@@ -5879,19 +6474,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:
5879
6474
 
5880
6475
  return None
5881
6476
 
6477
+ def check_base_for_print_overview() -> Optional[bool]:
6478
+ if args.continue_previous_job is not None and arg_result_names is not None and len(arg_result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
6479
+ print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
6480
+
6481
+ if ax_client is None:
6482
+ print_red("ax_client was None")
6483
+ return None
6484
+
6485
+ if ax_client.experiment is None:
6486
+ print_red("ax_client.experiment was None")
6487
+ return None
6488
+
6489
+ if ax_client.experiment.optimization_config is None:
6490
+ print_red("ax_client.experiment.optimization_config was None")
6491
+ return None
6492
+
6493
+ return True
6494
+
6495
+ def get_config_objectives() -> Any:
6496
+ if not ax_client:
6497
+ print_red("create_new_experiment: ax_client is None")
6498
+ my_exit(101)
6499
+
6500
+ return None
6501
+
6502
+ config_objectives = None
6503
+
6504
+ if ax_client.experiment and ax_client.experiment.optimization_config:
6505
+ opt_config = ax_client.experiment.optimization_config
6506
+ if opt_config.is_moo_problem:
6507
+ objective = getattr(opt_config, "objective", None)
6508
+ if objective and getattr(objective, "objectives", None) is not None:
6509
+ config_objectives = objective.objectives
6510
+ else:
6511
+ print_debug("ax_client.experiment.optimization_config.objective was None")
6512
+ else:
6513
+ config_objectives = [opt_config.objective]
6514
+ else:
6515
+ print_debug("ax_client.experiment or optimization_config was None")
6516
+
6517
+ return config_objectives
6518
+
5882
6519
  def print_result_names_overview_table() -> None:
5883
6520
  if not ax_client:
5884
6521
  _fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)
5885
6522
 
5886
6523
  return None
5887
6524
 
5888
- if args.continue_previous_job is not None and args.result_names is not None and len(args.result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
5889
- print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
6525
+ if check_base_for_print_overview() is None:
6526
+ return None
5890
6527
 
5891
- if ax_client.experiment.optimization_config.is_moo_problem:
5892
- config_objectives = ax_client.experiment.optimization_config.objective.objectives
5893
- else:
5894
- config_objectives = [ax_client.experiment.optimization_config.objective]
6528
+ config_objectives = get_config_objectives()
6529
+
6530
+ if config_objectives is None:
6531
+ print_red("config_objectives not found")
6532
+ return None
5895
6533
 
5896
6534
  res_names = []
5897
6535
  res_min_max = []
@@ -5986,11 +6624,13 @@ def print_overview_tables(classic_params: Optional[Union[list, dict]], experimen
5986
6624
  print_result_names_overview_table()
5987
6625
 
5988
6626
  def update_progress_bar(nr: int) -> None:
6627
+ log_data()
6628
+
5989
6629
  if progress_bar is not None:
5990
6630
  try:
5991
6631
  progress_bar.update(nr)
5992
6632
  except Exception as e:
5993
- print(f"Error updating progress bar: {e}")
6633
+ print_red(f"Error updating progress bar: {e}")
5994
6634
  else:
5995
6635
  print_red("update_progress_bar: progress_bar was None")
5996
6636
 
@@ -5998,19 +6638,17 @@ def get_current_model_name() -> str:
5998
6638
  if overwritten_to_random:
5999
6639
  return "Random*"
6000
6640
 
6641
+ gs_model = "unknown model"
6642
+
6001
6643
  if ax_client:
6002
6644
  try:
6003
6645
  if args.generation_strategy:
6004
- idx = getattr(ax_client.generation_strategy, "current_step_index", None)
6646
+ idx = getattr(global_gs, "current_step_index", None)
6005
6647
  if isinstance(idx, int):
6006
6648
  if 0 <= idx < len(generation_strategy_names):
6007
6649
  gs_model = generation_strategy_names[int(idx)]
6008
- else:
6009
- gs_model = "unknown model"
6010
- else:
6011
- gs_model = "unknown model"
6012
6650
  else:
6013
- gs_model = getattr(ax_client.generation_strategy, "current_node_name", "unknown model")
6651
+ gs_model = getattr(global_gs, "current_node_name", "unknown model")
6014
6652
 
6015
6653
  if gs_model:
6016
6654
  return str(gs_model)
@@ -6117,7 +6755,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
6117
6755
  global _last_count_time, _last_count_result
6118
6756
 
6119
6757
  now = int(time.time())
6120
- if _last_count_result != (0, "") and now - _last_count_time < 15:
6758
+ if _last_count_result != (0, "") and now - _last_count_time < 5:
6121
6759
  return _last_count_result
6122
6760
 
6123
6761
  _len = len(global_vars["jobs"])
@@ -6139,6 +6777,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
6139
6777
  check=True,
6140
6778
  text=True
6141
6779
  )
6780
+
6142
6781
  if "slurm_load_jobs error" in result.stderr:
6143
6782
  _last_count_result = (_len, "Detected slurm_load_jobs error in stderr.")
6144
6783
  _last_count_time = now
@@ -6179,6 +6818,8 @@ def log_worker_numbers() -> None:
6179
6818
  if len(WORKER_PERCENTAGE_USAGE) == 0 or WORKER_PERCENTAGE_USAGE[len(WORKER_PERCENTAGE_USAGE) - 1] != this_values:
6180
6819
  WORKER_PERCENTAGE_USAGE.append(this_values)
6181
6820
 
6821
+ write_worker_usage()
6822
+
6182
6823
  def get_slurm_in_brackets(in_brackets: list) -> list:
6183
6824
  if is_slurm_job():
6184
6825
  workers_strings = get_workers_string()
@@ -6263,6 +6904,8 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
6263
6904
  global last_progress_bar_desc
6264
6905
  global last_progress_bar_refresh_time
6265
6906
 
6907
+ log_data()
6908
+
6266
6909
  if isinstance(new_msgs, str):
6267
6910
  new_msgs = [new_msgs]
6268
6911
 
@@ -6280,7 +6923,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
6280
6923
  print_red("Cannot update progress bar! It is None.")
6281
6924
 
6282
6925
  def clean_completed_jobs() -> None:
6283
- job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
6926
+ job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
6284
6927
  job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]
6285
6928
 
6286
6929
  for job, trial_index in global_vars["jobs"][:]:
@@ -6425,7 +7068,7 @@ def get_generation_node_for_index(
6425
7068
 
6426
7069
  return generation_node
6427
7070
  except Exception as e:
6428
- print(f"Error while get_generation_node_for_index: {e}")
7071
+ print_red(f"Error while get_generation_node_for_index: {e}")
6429
7072
  return "MANUAL"
6430
7073
 
6431
7074
  def _get_generation_node_for_index_index_valid(
@@ -6598,11 +7241,21 @@ def check_ax_client() -> None:
6598
7241
  if ax_client is None or not ax_client:
6599
7242
  _fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)
6600
7243
 
7244
+ def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
7245
+ if not ax_client:
7246
+ my_exit(101)
7247
+
7248
+ return None
7249
+
7250
+ new_trial = ax_client.attach_trial(arm_params)
7251
+
7252
+ return new_trial
7253
+
6601
7254
  def attach_trial(arm_params: dict) -> Tuple[Any, int]:
6602
7255
  if ax_client is None:
6603
7256
  raise RuntimeError("attach_trial: ax_client was empty")
6604
7257
 
6605
- new_trial = ax_client.attach_trial(arm_params)
7258
+ new_trial = attach_ax_client_data(arm_params)
6606
7259
  if not isinstance(new_trial, tuple) or len(new_trial) < 2:
6607
7260
  raise RuntimeError("attach_trial didn't return the expected tuple")
6608
7261
  return new_trial
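
attach_trial still relies on AxClient.attach_trial, which returns a (parameters, trial_index) tuple; attach_ax_client_data only puts the ax_client guard in front of it. A small usage sketch with illustrative parameter names:

    from typing import Any, Dict
    from ax.service.ax_client import AxClient

    def attach_manual_point(ax_client: AxClient, arm_params: Dict[str, Any]) -> int:
        # AxClient.attach_trial returns (parameters, trial_index)
        _parameters, trial_index = ax_client.attach_trial(arm_params)
        return trial_index
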
@@ -6633,7 +7286,7 @@ def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[An
6633
7286
  is_ok = False
6634
7287
 
6635
7288
  if is_ok:
6636
- ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
7289
+ complete_ax_client_trial(trial_idx, result)
6637
7290
  update_status(__status, base_str, "Completed trial")
6638
7291
  else:
6639
7292
  print_debug("Empty job encountered")
@@ -7030,7 +7683,7 @@ def get_parameters_from_outfile(stdout_path: str) -> Union[None, dict, str]:
7030
7683
  if not args.tests:
7031
7684
  original_print(f"get_parameters_from_outfile: The file '{stdout_path}' was not found.")
7032
7685
  except Exception as e:
7033
- print(f"get_parameters_from_outfile: There was an error: {e}")
7686
+ print_red(f"get_parameters_from_outfile: There was an error: {e}")
7034
7687
 
7035
7688
  return None
7036
7689
 
@@ -7048,7 +7701,7 @@ def get_hostname_from_outfile(stdout_path: Optional[str]) -> Optional[str]:
7048
7701
  original_print(f"The file '{stdout_path}' was not found.")
7049
7702
  return None
7050
7703
  except Exception as e:
7051
- print(f"There was an error: {e}")
7704
+ print_red(f"There was an error: {e}")
7052
7705
  return None
7053
7706
 
7054
7707
  def add_to_global_error_list(msg: str) -> None:
@@ -7084,22 +7737,70 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:
7084
7737
 
7085
7738
  return None
7086
7739
 
7087
- ax_client.log_trial_failure(trial_index=trial_index)
7740
+ log_ax_client_trial_failure(trial_index)
7088
7741
  _trial.mark_failed(unsafe=True)
7089
7742
  except ValueError as e:
7090
7743
  print_debug(f"mark_trial_as_failed error: {e}")
7091
7744
 
7092
7745
  return None
7093
7746
 
7094
- def check_valid_result(result: Union[None, list, int, float, tuple]) -> bool:
7747
+ def check_valid_result(result: Union[None, dict]) -> bool:
7095
7748
  possible_val_not_found_values = [
7096
7749
  VAL_IF_NOTHING_FOUND,
7097
7750
  -VAL_IF_NOTHING_FOUND,
7098
7751
  -99999999999999997168788049560464200849936328366177157906432,
7099
7752
  99999999999999997168788049560464200849936328366177157906432
7100
7753
  ]
7101
- values_to_check = result if isinstance(result, list) else [result]
7102
- return result is not None and all(r not in possible_val_not_found_values for r in values_to_check)
7754
+
7755
+ def flatten_values(obj: Any) -> Any:
7756
+ values = []
7757
+ try:
7758
+ if isinstance(obj, dict):
7759
+ for v in obj.values():
7760
+ values.extend(flatten_values(v))
7761
+ elif isinstance(obj, (list, tuple, set)):
7762
+ for v in obj:
7763
+ values.extend(flatten_values(v))
7764
+ else:
7765
+ values.append(obj)
7766
+ except Exception as e:
7767
+ print_red(f"Error while flattening values: {e}")
7768
+ return values
7769
+
7770
+ if result is None:
7771
+ return False
7772
+
7773
+ try:
7774
+ all_values = flatten_values(result)
7775
+ for val in all_values:
7776
+ if val in possible_val_not_found_values:
7777
+ return False
7778
+ return True
7779
+ except Exception as e:
7780
+ print_red(f"Error while checking result validity: {e}")
7781
+ return False
7782
+
7783
+ def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
7784
+ if not ax_client:
7785
+ my_exit(101)
7786
+
7787
+ return None
7788
+
7789
+ trial = get_trial_by_index(trial_idx)
7790
+
7791
+ trial.update_trial_data(raw_data=result)
7792
+
7793
+ return None
7794
+
7795
+ def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
7796
+ if not ax_client:
7797
+ my_exit(101)
7798
+
7799
+ return None
7800
+
7801
+ ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
7802
+
7803
+ return None
7103
7804
 
7104
7805
  def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
7105
7806
  if ax_client is None:
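
check_valid_result now accepts arbitrarily nested result dicts and flattens them before comparing each leaf against the "nothing found" sentinel values. The flattening idea in isolation, written without the defensive try/except used above:

    def flatten_values(obj):
        # Recursively collect scalar leaves from dicts, lists, tuples and sets
        if isinstance(obj, dict):
            return [v for value in obj.values() for v in flatten_values(value)]
        if isinstance(obj, (list, tuple, set)):
            return [v for item in obj for v in flatten_values(item)]
        return [obj]

    # flatten_values({"loss": (0.1, None), "aux": [1, 2]}) -> [0.1, None, 1, 2]
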
@@ -7108,25 +7809,64 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -
7108
7809
 
7109
7810
  try:
7110
7811
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
7111
- ax_client.complete_trial(trial_index=trial_index, raw_data=raw_result)
7812
+ complete_ax_client_trial(trial_index, raw_result)
7112
7813
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
7113
7814
  except ax.exceptions.core.UnsupportedError as e:
7114
7815
  if f"{e}":
7115
7816
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
7116
- ax_client.update_trial_data(trial_index=trial_index, raw_data=raw_result)
7817
+ update_ax_client_trial(trial_index, raw_result)
7117
7818
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
7118
7819
  else:
7119
7820
  _fatal_error(f"Error completing trial: {e}", 234)
7120
7821
 
7121
7822
  return None
7122
7823
 
7123
- def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: Union[float, int, tuple]) -> None:
7824
+ def format_result_for_display(result: dict) -> str:
7825
+ def safe_float(v: Any) -> str:
7826
+ try:
7827
+ if v is None:
7828
+ return "None"
7829
+ if isinstance(v, (int, float)):
7830
+ if math.isnan(v):
7831
+ return "NaN"
7832
+ if math.isinf(v):
7833
+ return "∞" if v > 0 else "-∞"
7834
+ return f"{v:.6f}"
7835
+ return str(v)
7836
+ except Exception as e:
7837
+ return f"<error: {e}>"
7838
+
7839
+ try:
7840
+ if not isinstance(result, dict):
7841
+ return safe_float(result)
7842
+
7843
+ parts = []
7844
+ for key, val in result.items():
7845
+ try:
7846
+ if isinstance(val, (list, tuple)) and len(val) == 2:
7847
+ main, sem = val
7848
+ main_str = safe_float(main)
7849
+ if sem is not None:
7850
+ sem_str = safe_float(sem)
7851
+ parts.append(f"{key}: {main_str} (SEM: {sem_str})")
7852
+ else:
7853
+ parts.append(f"{key}: {main_str}")
7854
+ else:
7855
+ parts.append(f"{key}: {safe_float(val)}")
7856
+ except Exception as e:
7857
+ parts.append(f"{key}: <error: {e}>")
7858
+
7859
+ return ", ".join(parts)
7860
+ except Exception as e:
7861
+ return f"<error formatting result: {e}>"
7862
+
7863
+ def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: dict) -> None:
7124
7864
  print_debug(f"Marking trial {_trial} as completed")
7125
7865
  _trial.mark_completed(unsafe=True)
7126
7866
 
7127
7867
  succeeded_jobs(1)
7128
7868
 
7129
- progressbar_description(f"new result: {result}")
7869
+ progressbar_description(f"new result: {format_result_for_display(result)}")
7130
7870
  update_progress_bar(1)
7131
7871
 
7132
7872
  save_results_csv()
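
format_result_for_display condenses a raw-result dict, where each value may be a plain number or a (mean, SEM) pair, into one progress-bar string. Roughly expected behaviour (values are illustrative):

    format_result_for_display({"loss": 0.123456789})
    # -> "loss: 0.123457"
    format_result_for_display({"loss": (0.1234, 0.01), "acc": float("nan")})
    # -> "loss: 0.123400 (SEM: 0.010000), acc: NaN"
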
@@ -7140,8 +7880,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
7140
7880
  if job:
7141
7881
  try:
7142
7882
  progressbar_description("job_failed")
7143
- ax_client.log_trial_failure(trial_index=trial_index)
7144
- _trial.mark_failed(unsafe=True)
7883
+ log_ax_client_trial_failure(trial_index)
7884
+ mark_trial_as_failed(trial_index, _trial)
7145
7885
  except Exception as e:
7146
7886
  print_red(f"\nERROR while trying to mark job as failure: {e}")
7147
7887
  job.cancel()
@@ -7158,16 +7898,16 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7158
7898
  result = job.result()
7159
7899
  print_debug(f"finish_job_core: trial-index: {trial_index}, job.result(): {result}, state: {state_from_job(job)}")
7160
7900
 
7161
- raw_result = result
7162
- result_keys = list(result.keys())
7163
- result = result[result_keys[0]]
7164
7901
  this_jobs_finished += 1
7165
7902
 
7166
7903
  if ax_client:
7167
- _trial = ax_client.get_trial(trial_index)
7904
+ _trial = get_ax_client_trial(trial_index)
7905
+
7906
+ if _trial is None:
7907
+ return 0
7168
7908
 
7169
7909
  if check_valid_result(result):
7170
- _finish_job_core_helper_complete_trial(trial_index, raw_result)
7910
+ _finish_job_core_helper_complete_trial(trial_index, result)
7171
7911
 
7172
7912
  try:
7173
7913
  _finish_job_core_helper_mark_success(_trial, result)
@@ -7175,7 +7915,7 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7175
7915
  if len(arg_result_names) > 1 and count_done_jobs() > 1 and not job_calculate_pareto_front(get_current_run_folder(), True):
7176
7916
  print_red("job_calculate_pareto_front post job failed")
7177
7917
  except Exception as e:
7178
- print(f"ERROR in line {get_line_info()}: {e}")
7918
+ print_red(f"ERROR in line {get_line_info()}: {e}")
7179
7919
  else:
7180
7920
  _finish_job_core_helper_mark_failure(job, trial_index, _trial)
7181
7921
  else:
@@ -7184,6 +7924,8 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7184
7924
  print_debug(f"finish_job_core: removing job {job}, trial_index: {trial_index}")
7185
7925
  global_vars["jobs"].remove((job, trial_index))
7186
7926
 
7927
+ log_data()
7928
+
7187
7929
  force_live_share()
7188
7930
 
7189
7931
  return this_jobs_finished
@@ -7196,11 +7938,14 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
7196
7938
  if job:
7197
7939
  try:
7198
7940
  progressbar_description("job_failed")
7199
- _trial = ax_client.get_trial(trial_index)
7200
- ax_client.log_trial_failure(trial_index=trial_index)
7941
+ _trial = get_ax_client_trial(trial_index)
7942
+ if _trial is None:
7943
+ return None
7944
+
7945
+ log_ax_client_trial_failure(trial_index)
7201
7946
  mark_trial_as_failed(trial_index, _trial)
7202
7947
  except Exception as e:
7203
- print(f"ERROR in line {get_line_info()}: {e}")
7948
+ print_debug(f"ERROR in line {get_line_info()}: {e}")
7204
7949
  job.cancel()
7205
7950
  orchestrate_job(job, trial_index)
7206
7951
 
@@ -7215,10 +7960,12 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
7215
7960
 
7216
7961
  def _finish_previous_jobs_helper_handle_exception(job: Any, trial_index: int, error: Exception) -> int:
7217
7962
  if "None for metric" in str(error):
7218
- print_red(
7219
- f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string."
7220
- f"\nError: {error}\nJob-result: {job.result()}"
7221
- )
7963
+ err_msg = f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string.\nError: {error}\nJob-result: {job.result()}"
7964
+
7965
+ if count_done_jobs() == 0:
7966
+ print_red(err_msg)
7967
+ else:
7968
+ print_debug(err_msg)
7222
7969
  else:
7223
7970
  print_red(f"\n⚠ {error}")
7224
7971
 
@@ -7237,7 +7984,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
7237
7984
  this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
7238
7985
  return this_jobs_finished
7239
7986
 
7240
- def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, this_jobs_finished: int) -> int:
7987
+ def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
7988
+ job, trial_index = __args
7989
+
7990
+ this_jobs_finished = 0
7241
7991
  if job is None:
7242
7992
  print_debug(f"finish_previous_jobs: job {job} is None")
7243
7993
  return this_jobs_finished
@@ -7250,10 +8000,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t
7250
8000
 
7251
8001
  return this_jobs_finished
7252
8002
 
7253
- def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
7254
- job, trial_index = __args
7255
- return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
7256
-
7257
8003
  def finish_previous_jobs(new_msgs: List[str] = []) -> None:
7258
8004
  global JOBS_FINISHED
7259
8005
 
@@ -7266,13 +8012,10 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
7266
8012
 
7267
8013
  jobs_copy = global_vars["jobs"][:]
7268
8014
 
7269
- if len(jobs_copy) > 0:
7270
- print_debug(f"jobs in finish_previous_jobs: {jobs_copy}")
7271
-
7272
- finishing_jobs_start_time = time.time()
8015
+ #finishing_jobs_start_time = time.time()
7273
8016
 
7274
8017
  with ThreadPoolExecutor() as finish_job_executor:
7275
- futures = [finish_job_executor.submit(_finish_previous_jobs_helper_wrapper, (job, trial_index)) for job, trial_index in jobs_copy]
8018
+ futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]
7276
8019
 
7277
8020
  for future in as_completed(futures):
7278
8021
  try:
@@ -7280,11 +8023,11 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
7280
8023
  except Exception as e:
7281
8024
  print_red(f"⚠ Exception in parallel job handling: {e}")
7282
8025
 
7283
- finishing_jobs_end_time = time.time()
8026
+ #finishing_jobs_end_time = time.time()
7284
8027
 
7285
- finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
8028
+ #finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
7286
8029
 
7287
- print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
8030
+ #print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
7288
8031
 
7289
8032
  if this_jobs_finished > 0:
7290
8033
  save_results_csv()
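
finish_previous_jobs now hands each (job, trial_index) pair to a ThreadPoolExecutor and sums the per-job results via as_completed. The underlying standard-library pattern, stripped of the omniopt-specific handling:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def process(item):
        job, trial_index = item
        # ... handle one finished job, return 1 if it counted as finished ...
        return 1

    def count_finished(items):
        finished = 0
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(process, item) for item in items]
            for future in as_completed(futures):
                try:
                    finished += future.result()
                except Exception as exc:
                    print(f"Exception in parallel job handling: {exc}")
        return finished
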
@@ -7450,24 +8193,43 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
7450
8193
 
7451
8194
  print_debug(f"Done submitting new job, took {elapsed} seconds")
7452
8195
 
8196
+ log_data()
8197
+
7453
8198
  return new_job
7454
8199
 
8200
+ def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
8201
+ if not ax_client:
8202
+ my_exit(101)
8203
+
8204
+ return None
8205
+
8206
+ try:
8207
+ log_data()
8208
+
8209
+ return ax_client.get_trial(trial_index)
8210
+ except KeyError:
8211
+ error_without_print(f"get_ax_client_trial: trial_index {trial_index} failed")
8212
+ return None
8213
+
7455
8214
  def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
7456
8215
  if submitit_executor and ax_client:
7457
8216
  new_job = submit_new_job(parameters, trial_index)
7458
8217
  if new_job:
7459
8218
  submitted_jobs(1)
7460
8219
 
7461
- _trial = ax_client.get_trial(trial_index)
8220
+ _trial = get_ax_client_trial(trial_index)
7462
8221
 
7463
- try:
7464
- _trial.mark_staged(unsafe=True)
7465
- except Exception as e:
7466
- print_debug(f"orchestrator_start_trial: error {e}")
7467
- _trial.mark_running(unsafe=True, no_runner_required=True)
8222
+ if _trial is not None:
8223
+ try:
8224
+ _trial.mark_staged(unsafe=True)
8225
+ except Exception as e:
8226
+ print_debug(f"orchestrator_start_trial: error {e}")
8227
+ _trial.mark_running(unsafe=True, no_runner_required=True)
7468
8228
 
7469
- print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
7470
- global_vars["jobs"].append((new_job, trial_index))
8229
+ print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
8230
+ global_vars["jobs"].append((new_job, trial_index))
8231
+ else:
8232
+ print_red("Trial was none in orchestrator_start_trial")
7471
8233
  else:
7472
8234
  print_red("orchestrator_start_trial: Failed to start new job")
7473
8235
  elif ax_client:
@@ -7492,7 +8254,7 @@ def handle_restart(stdout_path: str, trial_index: int) -> None:
7492
8254
  if parameters:
7493
8255
  orchestrator_start_trial(parameters, trial_index)
7494
8256
  else:
7495
- print(f"Could not determine parameters from outfile {stdout_path} for restarting job")
8257
+ print_red(f"Could not determine parameters from outfile {stdout_path} for restarting job")
7496
8258
 
7497
8259
  def check_alternate_path(path: str) -> str:
7498
8260
  if os.path.exists(path):
@@ -7588,7 +8350,11 @@ def execute_evaluation(_params: list) -> Optional[int]:
7588
8350
 
7589
8351
  return None
7590
8352
 
7591
- _trial = ax_client.get_trial(trial_index)
8353
+ _trial = get_ax_client_trial(trial_index)
8354
+
8355
+ if _trial is None:
8356
+ error_without_print(f"execute_evaluation: _trial was not in execute_evaluation for params {_params}")
8357
+ return None
7592
8358
 
7593
8359
  def mark_trial_stage(stage: str, error_msg: str) -> None:
7594
8360
  try:
@@ -7656,7 +8422,7 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
7656
8422
  my_exit(144)
7657
8423
 
7658
8424
  if new_job is None:
7659
- print_red("handle_failed_job: job is None")
8425
+ print_debug("handle_failed_job: job is None")
7660
8426
 
7661
8427
  return None
7662
8428
 
@@ -7667,16 +8433,24 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
7667
8433
 
7668
8434
  return None
7669
8435
 
8436
+ def log_ax_client_trial_failure(trial_index: int) -> None:
8437
+ if not ax_client:
8438
+ my_exit(101)
8439
+
8440
+ return
8441
+
8442
+ ax_client.log_trial_failure(trial_index=trial_index)
8443
+
7670
8444
  def cancel_failed_job(trial_index: int, new_job: Job) -> None:
7671
8445
  print_debug("Trying to cancel job that failed")
7672
8446
  if new_job:
7673
8447
  try:
7674
8448
  if ax_client:
7675
- ax_client.log_trial_failure(trial_index=trial_index)
8449
+ log_ax_client_trial_failure(trial_index)
7676
8450
  else:
7677
8451
  _fatal_error("ax_client not defined", 101)
7678
8452
  except Exception as e:
7679
- print(f"ERROR in line {get_line_info()}: {e}")
8453
+ print_red(f"ERROR in line {get_line_info()}: {e}")
7680
8454
  new_job.cancel()
7681
8455
 
7682
8456
  print_debug(f"cancel_failed_job: removing job {new_job}, trial_index: {trial_index}")
@@ -7730,6 +8504,8 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
7730
8504
  _submitted_jobs = submitted_jobs()
7731
8505
  _failed_jobs = failed_jobs()
7732
8506
 
8507
+ log_data()
8508
+
7733
8509
  max_failed_jobs = max_eval
7734
8510
 
7735
8511
  if args.max_failed_jobs is not None and args.max_failed_jobs > 0:
@@ -7761,7 +8537,7 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
7761
8537
 
7762
8538
  return _ret
7763
8539
 
7764
- def _calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
8540
+ def calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
7765
8541
  """Calculates the number of jobs to retrieve."""
7766
8542
  return min(
7767
8543
  max_eval + simulated_jobs - count_done_jobs(),
@@ -7774,7 +8550,7 @@ def remove_extra_spaces(text: str) -> str:
7774
8550
  raise ValueError("Input must be a string")
7775
8551
  return re.sub(r'\s+', ' ', text).strip()
7776
8552
 
7777
- def _get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
8553
+ def get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
7778
8554
  """Generates the appropriate message for the number of trials being retrieved."""
7779
8555
  ret = ""
7780
8556
  if full_nr_of_jobs_to_get > 1:
@@ -7859,51 +8635,52 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
7859
8635
  print_red("get_batched_arms: ax_client was None")
7860
8636
  return []
7861
8637
 
7862
- # Experiment-Status laden
7863
8638
  load_experiment_state()
7864
8639
 
7865
- while len(batched_arms) != nr_of_jobs_to_get:
8640
+ while len(batched_arms) < nr_of_jobs_to_get:
7866
8641
  if attempts > args.max_attempts_for_generation:
7867
8642
  print_debug(f"get_batched_arms: Stopped after {attempts} attempts: could not generate enough arms "
7868
8643
  f"(got {len(batched_arms)} out of {nr_of_jobs_to_get}).")
7869
8644
  break
7870
8645
 
7871
- remaining = nr_of_jobs_to_get - len(batched_arms)
7872
- print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
7873
-
7874
- print_debug("get pending observations")
7875
- t0 = time.time()
7876
- pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
7877
- dt = time.time() - t0
7878
- print_debug(f"got pending observations: {pending_observations} (took {dt:.2f} seconds)")
8646
+ #print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting 1 more arm")
7879
8647
 
7880
- print_debug("getting global_gs.gen()")
7881
- batched_generator_run = global_gs.gen(
8648
+ pending_observations = get_pending_observation_features(
7882
8649
  experiment=ax_client.experiment,
7883
- n=remaining,
7884
- pending_observations=pending_observations
8650
+ include_out_of_design_points=True
7885
8651
  )
7886
- print_debug(f"got global_gs.gen(): {batched_generator_run}")
8652
+
8653
+ try:
8654
+ #print_debug("getting global_gs.gen() with n=1")
8655
+
8656
+ batched_generator_run: Any = global_gs.gen(
8657
+ experiment=ax_client.experiment,
8658
+ n=1,
8659
+ pending_observations=pending_observations,
8660
+ )
8661
+ print_debug(f"got global_gs.gen(): {batched_generator_run}")
8662
+ except Exception as e:
8663
+ print_debug(f"global_gs.gen failed: {e}")
8664
+ traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
8665
+ break
7887
8666
 
7888
8667
  depth = 0
7889
8668
  path = "batched_generator_run"
7890
- while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) > 0:
7891
- print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
8669
+ while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
8670
+ #print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
7892
8671
  batched_generator_run = batched_generator_run[0]
7893
8672
  path += "[0]"
7894
8673
  depth += 1
7895
8674
 
7896
- print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
8675
+ #print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
7897
8676
 
7898
- print_debug("got new arms")
7899
- new_arms = batched_generator_run.arms
7900
- print_debug(f"new_arms: {new_arms}")
8677
+ new_arms = getattr(batched_generator_run, "arms", [])
7901
8678
  if not new_arms:
7902
8679
  print_debug("get_batched_arms: No new arms were generated in this attempt.")
7903
8680
  else:
7904
- print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), wanted {nr_of_jobs_to_get}.")
8681
+ print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), now at {len(batched_arms) + len(new_arms)} of {nr_of_jobs_to_get}.")
8682
+ batched_arms.extend(new_arms)
7905
8683
 
7906
- batched_arms.extend(new_arms)
7907
8684
  attempts += 1
7908
8685
 
7909
8686
  print_debug(f"get_batched_arms: Finished with {len(batched_arms)} arm(s) after {attempts} attempt(s).")
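
get_batched_arms now requests one arm per attempt from the generation strategy and tolerates generator failures, instead of asking for the whole remainder in a single call. A reduced sketch of that loop, reusing the same Ax calls as above (get_pending_observation_features is the helper already imported by .omniopt.py; error handling and logging are simplified):

    def collect_arms(gs, experiment, wanted, max_attempts):
        # gs: the global GenerationStrategy; experiment: ax_client.experiment
        arms, attempts = [], 0
        while len(arms) < wanted and attempts <= max_attempts:
            pending = get_pending_observation_features(
                experiment=experiment, include_out_of_design_points=True
            )
            try:
                run = gs.gen(experiment=experiment, n=1, pending_observations=pending)
            except Exception:
                break  # generator exhausted or failed; use what was generated so far
            while isinstance(run, (list, tuple)) and len(run) == 1:
                run = run[0]  # newer Ax may wrap the GeneratorRun in nested lists
            arms.extend(getattr(run, "arms", []))
            attempts += 1
        return arms
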
@@ -7942,16 +8719,14 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
7942
8719
  retries += 1
7943
8720
  continue
7944
8721
 
7945
- print_debug(f"Fetching trial {cnt + 1}/{n}...")
7946
- progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
8722
+ progressbar_description(get_trials_message(cnt + 1, n, trial_durations))
7947
8723
 
7948
8724
  try:
7949
8725
  result = create_and_handle_trial(arm)
7950
8726
  if result is not None:
7951
8727
  trial_index, trial_duration, trial_successful = result
7952
-
7953
8728
  except TrialRejected as e:
7954
- print_debug(f"Trial rejected: {e}")
8729
+ print_debug(f"generate_trials: Trial rejected, error: {e}")
7955
8730
  retries += 1
7956
8731
  continue
7957
8732
 
@@ -7961,14 +8736,23 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
7961
8736
  cnt += 1
7962
8737
  trials_dict[trial_index] = arm.parameters
7963
8738
 
7964
- return _finalize_generation(trials_dict, cnt, n, start_time)
8739
+ finalized = finalize_generation(trials_dict, cnt, n, start_time)
8740
+
8741
+ return finalized
7965
8742
 
7966
8743
  except Exception as e:
7967
- return _handle_generation_failure(e, n, recursion)
8744
+ return handle_generation_failure(e, n, recursion)
7968
8745
 
7969
8746
  class TrialRejected(Exception):
7970
8747
  pass
7971
8748
 
8749
+ def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
8750
+ try:
8751
+ print_debug(f"[INFO] Marking trial {trial.index} ({trial.arm.name}) as abandoned, trial-index: {trial_index}. Reason: {reason}")
8752
+ trial.mark_abandoned(reason)
8753
+ except Exception as e:
8754
+ print_red(f"[ERROR] Could not mark trial as abandoned: {e}")
8755
+
7972
8756
  def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
7973
8757
  if ax_client is None:
7974
8758
  print_red("ax_client is None in create_and_handle_trial")
@@ -7994,7 +8778,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
7994
8778
  arm = trial.arms[0]
7995
8779
  if deduplicated_arm(arm):
7996
8780
  print_debug(f"Duplicated arm: {arm}")
7997
- trial.mark_abandoned(reason="Duplication detected")
8781
+ mark_abandoned(trial, "Duplication detected", trial_index)
7998
8782
  raise TrialRejected("Duplicate arm.")
7999
8783
 
8000
8784
  arms_by_name_for_deduplication[arm.name] = arm
@@ -8003,7 +8787,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
8003
8787
 
8004
8788
  if not has_no_post_generation_constraints_or_matches_constraints(post_generation_constraints, params):
8005
8789
  print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned. Params: {params}, constraints: {post_generation_constraints}")
8006
- trial.mark_abandoned(reason="Post-Generation-Constraint failed")
8790
+ mark_abandoned(trial, "Post-Generation-Constraint failed", trial_index)
8007
8791
  abandoned_trial_indices.append(trial_index)
8008
8792
  raise TrialRejected("Post-generation constraints not met.")
8009
8793
 
@@ -8011,7 +8795,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
8011
8795
  end = time.time()
8012
8796
  return trial_index, float(end - start), True
8013
8797
 
8014
- def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
8798
+ def finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
8015
8799
  total_time = time.time() - start_time
8016
8800
 
8017
8801
  log_gen_times.append(total_time)
@@ -8022,7 +8806,7 @@ def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int,
8022
8806
 
8023
8807
  return trials_dict, False
8024
8808
 
8025
- def _handle_generation_failure(
8809
+ def handle_generation_failure(
8026
8810
  e: Exception,
8027
8811
  requested: int,
8028
8812
  recursion: bool
@@ -8038,19 +8822,19 @@ def _handle_generation_failure(
8038
8822
  )):
8039
8823
  msg = str(e)
8040
8824
  if msg not in error_8_saved:
8041
- _print_exhaustion_warning(e, recursion)
8825
+ print_exhaustion_warning(e, recursion)
8042
8826
  error_8_saved.append(msg)
8043
8827
 
8044
8828
  if not recursion and args.revert_to_random_when_seemingly_exhausted:
8045
8829
  print_debug("Switching to random search strategy.")
8046
- set_global_gs_to_random()
8830
+ set_global_gs_to_sobol()
8047
8831
  return fetch_next_trials(requested, True)
8048
8832
 
8049
- print_red(f"_handle_generation_failure: General Exception: {e}")
8833
+ print_red(f"handle_generation_failure: General Exception: {e}")
8050
8834
 
8051
8835
  return {}, True
8052
8836
 
8053
- def _print_exhaustion_warning(e: Exception, recursion: bool) -> None:
8837
+ def print_exhaustion_warning(e: Exception, recursion: bool) -> None:
8054
8838
  if not recursion and args.revert_to_random_when_seemingly_exhausted:
8055
8839
  print_yellow(f"\n⚠Error 8: {e} From now (done jobs: {count_done_jobs()}) on, random points will be generated.")
8056
8840
  else:
@@ -8095,21 +8879,24 @@ def get_model_gen_kwargs() -> dict:
8095
8879
  "fit_out_of_design": args.fit_out_of_design
8096
8880
  }
8097
8881
 
8098
- def set_global_gs_to_random() -> None:
8882
+ def set_global_gs_to_sobol() -> None:
8099
8883
  global global_gs
8100
8884
  global overwritten_to_random
8101
8885
 
8886
+ print("Reverting to SOBOL")
8887
+
8102
8888
  global_gs = GenerationStrategy(
8103
8889
  name="Random*",
8104
8890
  nodes=[
8105
8891
  GenerationNode(
8106
- node_name="Sobol",
8107
- generator_specs=[
8108
- GeneratorSpec(
8109
- Models.SOBOL,
8110
- model_gen_kwargs=get_model_gen_kwargs()
8111
- )
8112
- ]
8892
+ name="Sobol",
8893
+ should_deduplicate=True,
8894
+ generator_specs=[ # type: ignore[arg-type]
8895
+ GeneratorSpec( # type: ignore[arg-type]
8896
+ Generators.SOBOL, # type: ignore[arg-type]
8897
+ model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
8898
+ ) # type: ignore[arg-type]
8899
+ ] # type: ignore[arg-type]
8113
8900
  )
8114
8901
  ]
8115
8902
  )
@@ -8494,10 +9281,10 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
8494
9281
 
8495
9282
  for s in splitted_by_comma:
8496
9283
  if "=" not in s:
8497
- print(f"'{s}' does not contain '='")
9284
+ print_red(f"'{s}' does not contain '='")
8498
9285
  my_exit(123)
8499
9286
  if s.count("=") != 1:
8500
- print(f"There can only be one '=' in the gen_strat_str's element '{s}'")
9287
+ print_red(f"There can only be one '=' in the gen_strat_str's element '{s}'")
8501
9288
  my_exit(123)
8502
9289
 
8503
9290
  model_name, nr_str = s.split("=")
@@ -8507,13 +9294,13 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
8507
9294
  _fatal_error(f"Model {matching_model} is not valid for custom generation strategy.", 56)
8508
9295
 
8509
9296
  if not matching_model:
8510
- print(f"'{model_name}' not found in SUPPORTED_MODELS")
9297
+ print_red(f"'{model_name}' not found in SUPPORTED_MODELS")
8511
9298
  my_exit(123)
8512
9299
 
8513
9300
  try:
8514
9301
  nr = int(nr_str)
8515
9302
  except ValueError:
8516
- print(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
9303
+ print_red(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
8517
9304
  my_exit(123)
8518
9305
 
8519
9306
  gen_strat_list.append({matching_model: nr})
@@ -8637,7 +9424,7 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
8637
9424
  if model_name == "TPE":
8638
9425
  if len(arg_result_names) != 1:
8639
9426
  _fatal_error(f"Has {len(arg_result_names)} results. TPE currently only supports single-objective-optimization.", 108)
8640
- return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py", node_name="EXTERNAL_GENERATOR")
9427
+ return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py", name="EXTERNAL_GENERATOR")
8641
9428
 
8642
9429
  external_generators = {
8643
9430
  "PSEUDORANDOM": f"python3 {script_dir}/.random_generator.py",
@@ -8651,10 +9438,10 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
8651
9438
  cmd = external_generators[model_name]
8652
9439
  if model_name == "EXTERNAL_GENERATOR" and not cmd:
8653
9440
  _fatal_error("--external_generator is missing. Cannot create points for EXTERNAL_GENERATOR without it.", 204)
8654
- return ExternalProgramGenerationNode(external_generator=cmd, node_name="EXTERNAL_GENERATOR")
9441
+ return ExternalProgramGenerationNode(external_generator=cmd, name="EXTERNAL_GENERATOR")
8655
9442
 
8656
9443
  trans_crit = [
8657
- MaxTrials(
9444
+ MinTrials(
8658
9445
  threshold=threshold,
8659
9446
  block_transition_if_unmet=True,
8660
9447
  transition_to=target_model,
@@ -8670,11 +9457,12 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
8670
9457
  if model_name.lower() != "sobol":
8671
9458
  kwargs["model_kwargs"] = get_model_kwargs()
8672
9459
 
8673
- model_spec = [GeneratorSpec(selected_model, **kwargs)]
9460
+ model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]
8674
9461
 
8675
9462
  res = GenerationNode(
8676
- node_name=model_name,
9463
+ name=model_name,
8677
9464
  generator_specs=model_spec,
9465
+ should_deduplicate=True,
8678
9466
  transition_criteria=trans_crit
8679
9467
  )
8680
9468
 
@@ -8685,7 +9473,7 @@ def get_optimizer_kwargs() -> dict:
8685
9473
  "sequential": False
8686
9474
  }
8687
9475
 
8688
- def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = None) -> GenerationStep:
9476
+ def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
8689
9477
  model_enum = get_model_from_name(model_name)
8690
9478
 
8691
9479
  return GenerationStep(
@@ -8700,17 +9488,16 @@ def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = N
8700
9488
  )
8701
9489
 
8702
9490
  def set_global_generation_strategy() -> None:
8703
- with spinner("Setting global generation strategy"):
8704
- continue_not_supported_on_custom_generation_strategy()
9491
+ continue_not_supported_on_custom_generation_strategy()
8705
9492
 
8706
- try:
8707
- if args.generation_strategy is None:
8708
- setup_default_generation_strategy()
8709
- else:
8710
- setup_custom_generation_strategy()
8711
- except Exception as e:
8712
- print_red(f"Unexpected error in generation strategy setup: {e}")
8713
- my_exit(111)
9493
+ try:
9494
+ if args.generation_strategy is None:
9495
+ setup_default_generation_strategy()
9496
+ else:
9497
+ setup_custom_generation_strategy()
9498
+ except Exception as e:
9499
+ print_red(f"Unexpected error in generation strategy setup: {e}")
9500
+ my_exit(111)
8714
9501
 
8715
9502
  if global_gs is None:
8716
9503
  print_red("global_gs is None after setup!")
@@ -8861,10 +9648,14 @@ def execute_trials(
8861
9648
  result = future.result()
8862
9649
  print_debug(f"result in execute_trials: {result}")
8863
9650
  except Exception as exc:
8864
- print_red(f"execute_trials: Error at executing a trial: {exc}")
9651
+ failed_args = future_to_args[future]
9652
+ print_red(f"execute_trials: Error at executing a trial with args {failed_args}: {exc}")
9653
+ traceback.print_exc()
8865
9654
 
8866
9655
  end_time = time.time()
8867
9656
 
9657
+ log_data()
9658
+
8868
9659
  duration = float(end_time - start_time)
8869
9660
  job_submit_durations.append(duration)
8870
9661
  job_submit_nrs.append(cnt)
@@ -8898,20 +9689,20 @@ def create_and_execute_next_runs(next_nr_steps: int, phase: Optional[str], _max_
8898
9689
  done_optimizing: bool = False
8899
9690
 
8900
9691
  try:
8901
- done_optimizing, trial_index_to_param = _create_and_execute_next_runs_run_loop(_max_eval, phase)
8902
- _create_and_execute_next_runs_finish(done_optimizing)
9692
+ done_optimizing, trial_index_to_param = create_and_execute_next_runs_run_loop(_max_eval, phase)
9693
+ create_and_execute_next_runs_finish(done_optimizing)
8903
9694
  except Exception as e:
8904
9695
  stacktrace = traceback.format_exc()
8905
9696
  print_debug(f"Warning: create_and_execute_next_runs encountered an exception: {e}\n{stacktrace}")
8906
9697
  return handle_exceptions_create_and_execute_next_runs(e)
8907
9698
 
8908
- return _create_and_execute_next_runs_return_value(trial_index_to_param)
9699
+ return create_and_execute_next_runs_return_value(trial_index_to_param)
8909
9700
 
8910
- def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
9701
+ def create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
8911
9702
  done_optimizing = False
8912
9703
  trial_index_to_param: Optional[Dict] = None
8913
9704
 
8914
- nr_of_jobs_to_get = _calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
9705
+ nr_of_jobs_to_get = calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
8915
9706
 
8916
9707
  __max_eval = _max_eval if _max_eval is not None else 0
8917
9708
  new_nr_of_jobs_to_get = min(__max_eval - (submitted_jobs() - failed_jobs()), nr_of_jobs_to_get)
@@ -8925,6 +9716,7 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
8925
9716
 
8926
9717
  for _ in range(range_nr):
8927
9718
  trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
9719
+ log_data()
8928
9720
  if done_optimizing:
8929
9721
  continue
8930
9722
 
@@ -8949,13 +9741,13 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
8949
9741
 
8950
9742
  return done_optimizing, trial_index_to_param
8951
9743
 
8952
- def _create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
9744
+ def create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
8953
9745
  finish_previous_jobs(["finishing jobs"])
8954
9746
 
8955
9747
  if done_optimizing:
8956
9748
  end_program(False, 0)
8957
9749
 
8958
- def _create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
9750
+ def create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
8959
9751
  try:
8960
9752
  if trial_index_to_param:
8961
9753
  res = len(trial_index_to_param.keys())
@@ -9066,7 +9858,7 @@ def execute_nvidia_smi() -> None:
9066
9858
  if not host:
9067
9859
  print_debug("host not defined")
9068
9860
  except Exception as e:
9069
- print(f"execute_nvidia_smi: An error occurred: {e}")
9861
+ print_red(f"execute_nvidia_smi: An error occurred: {e}")
9070
9862
  if is_slurm_job() and not args.force_local_execution:
9071
9863
  _sleep(30)
9072
9864
 
@@ -9104,11 +9896,6 @@ def run_search() -> bool:
9104
9896
 
9105
9897
  return False
9106
9898
 
9107
- async def start_logging_daemon() -> None:
9108
- while True:
9109
- log_data()
9110
- time.sleep(30)
9111
-
9112
9899
  def should_break_search() -> bool:
9113
9900
  ret = False
9114
9901
 
@@ -9157,10 +9944,8 @@ def check_search_space_exhaustion(nr_of_items: int) -> bool:
9157
9944
  print_debug(_wrn)
9158
9945
  progressbar_description(_wrn)
9159
9946
 
9160
- live_share()
9161
9947
  return True
9162
9948
 
9163
- live_share()
9164
9949
  return False
9165
9950
 
9166
9951
  def finalize_jobs() -> None:
@@ -9174,7 +9959,7 @@ def finalize_jobs() -> None:
9174
9959
  handle_slurm_execution()
9175
9960
 
9176
9961
  def go_through_jobs_that_are_not_completed_yet() -> None:
9177
- print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
9962
+ #print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
9178
9963
 
9179
9964
  nr_jobs_left = len(global_vars['jobs'])
9180
9965
  if nr_jobs_left == 1:
@@ -9244,372 +10029,60 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
9244
10029
 
9245
10030
  return data
9246
10031
  except Exception as e:
9247
- print(f"Error while parse_experiment_parameters({_f}): {e}")
10032
+ print_red(f"Error while parse_experiment_parameters({_f}): {e}")
9248
10033
  else:
9249
10034
  print_red(f"{_f} could not be found")
9250
10035
 
9251
- return None
9252
-
9253
- def set_orchestrator() -> None:
9254
- with spinner("Setting orchestrator..."):
9255
- global orchestrator
9256
-
9257
- if args.orchestrator_file:
9258
- if SYSTEM_HAS_SBATCH:
9259
- orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
9260
- else:
9261
- print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
9262
-
9263
- def check_if_has_random_steps() -> None:
9264
- if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
9265
- _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
9266
-
9267
- def add_exclude_to_defective_nodes() -> None:
9268
- with spinner("Adding excluded nodes..."):
9269
- if args.exclude:
9270
- entries = [entry.strip() for entry in args.exclude.split(',')]
9271
-
9272
- for entry in entries:
9273
- count_defective_nodes(None, entry)
9274
-
9275
- def check_max_eval(_max_eval: int) -> None:
9276
- with spinner("Checking max_eval..."):
9277
- if not _max_eval:
9278
- _fatal_error("--max_eval needs to be set!", 19)
9279
-
9280
- def parse_parameters() -> Any:
9281
- cli_params_experiment_parameters = None
9282
- if args.parameter:
9283
- parse_experiment_parameters()
9284
- cli_params_experiment_parameters = experiment_parameters
9285
-
9286
- return cli_params_experiment_parameters
9287
-
9288
- def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
9289
- table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
9290
-
9291
- rows = _pareto_front_table_read_csv()
9292
- if not rows:
9293
- table.add_column("No data found")
9294
- return table
9295
-
9296
- filtered_rows = _pareto_front_table_filter_rows(rows, idxs)
9297
- if not filtered_rows:
9298
- table.add_column("No matching entries")
9299
- return table
9300
-
9301
- param_cols, result_cols = _pareto_front_table_get_columns(filtered_rows[0])
9302
-
9303
- _pareto_front_table_add_headers(table, param_cols, result_cols)
9304
- _pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
9305
-
9306
- return table
9307
-
9308
- def _pareto_front_table_read_csv() -> List[Dict[str, str]]:
9309
- with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
9310
- return list(csv.DictReader(f))
9311
-
9312
- def _pareto_front_table_filter_rows(rows: List[Dict[str, str]], idxs: List[int]) -> List[Dict[str, str]]:
9313
- result = []
9314
- for row in rows:
9315
- try:
9316
- trial_index = int(row["trial_index"])
9317
- except (KeyError, ValueError):
9318
- continue
9319
-
9320
- if row.get("trial_status", "").strip().upper() == "COMPLETED" and trial_index in idxs:
9321
- result.append(row)
9322
- return result
9323
-
9324
- def _pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
9325
- all_columns = list(first_row.keys())
9326
- ignored_cols = set(special_col_names) - {"trial_index"}
9327
-
9328
- param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
9329
- result_cols = [col for col in arg_result_names if col in all_columns]
9330
- return param_cols, result_cols
9331
-
9332
- def _pareto_front_table_add_headers(table: Table, param_cols: List[str], result_cols: List[str]) -> None:
9333
- for col in param_cols:
9334
- table.add_column(col, justify="center")
9335
- for col in result_cols:
9336
- table.add_column(Text(f"{col}", style="cyan"), justify="center")
9337
-
9338
- def _pareto_front_table_add_rows(table: Table, rows: List[Dict[str, str]], param_cols: List[str], result_cols: List[str]) -> None:
9339
- for row in rows:
9340
- values = [str(helpers.to_int_when_possible(row[col])) for col in param_cols]
9341
- result_values = [Text(str(helpers.to_int_when_possible(row[col])), style="cyan") for col in result_cols]
9342
- table.add_row(*values, *result_values, style="bold green")
9343
-
9344
- def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
9345
- if not os.path.exists(RESULT_CSV_FILE):
9346
- print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
9347
- return None
9348
-
9349
- return create_pareto_front_table(idxs, metric_x, metric_y)
9350
-
9351
- def supports_sixel() -> bool:
9352
- term = os.environ.get("TERM", "").lower()
9353
- if "xterm" in term or "mlterm" in term:
9354
- return True
9355
-
9356
- try:
9357
- output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
9358
- if output.returncode == 0 and "sixel" in output.stdout.lower():
9359
- return True
9360
- except (subprocess.CalledProcessError, FileNotFoundError):
9361
- pass
9362
-
9363
- return False
9364
-
9365
- def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
9366
- if data is None:
9367
- print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
9368
- return
9369
-
9370
- if not supports_sixel():
9371
- print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
9372
- return
9373
-
9374
- import matplotlib.pyplot as plt
9375
-
9376
- means = data[x_metric][y_metric]["means"]
9377
-
9378
- x_values = means[x_metric]
9379
- y_values = means[y_metric]
9380
-
9381
- fig, _ax = plt.subplots()
9382
-
9383
- _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')
9384
-
9385
- _ax.set_xlabel(x_metric)
9386
- _ax.set_ylabel(y_metric)
9387
-
9388
- _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')
9389
-
9390
- _ax.ticklabel_format(style='plain', axis='both', useOffset=False)
9391
-
9392
- with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
9393
- plt.savefig(tmp_file.name, dpi=300)
9394
-
9395
- print_image_to_cli(tmp_file.name, 1000)
9396
-
9397
- plt.close(fig)
9398
-
9399
- def _pareto_front_general_validate_shapes(x: np.ndarray, y: np.ndarray) -> None:
9400
- if x.shape != y.shape:
9401
- raise ValueError("Input arrays x and y must have the same shape.")
9402
-
9403
- def _pareto_front_general_compare(
9404
- xi: float, yi: float, xj: float, yj: float,
9405
- x_minimize: bool, y_minimize: bool
9406
- ) -> bool:
9407
- x_better_eq = xj <= xi if x_minimize else xj >= xi
9408
- y_better_eq = yj <= yi if y_minimize else yj >= yi
9409
- x_strictly_better = xj < xi if x_minimize else xj > xi
9410
- y_strictly_better = yj < yi if y_minimize else yj > yi
9411
-
9412
- return bool(x_better_eq and y_better_eq and (x_strictly_better or y_strictly_better))
9413
-
9414
- def _pareto_front_general_find_dominated(
9415
- x: np.ndarray, y: np.ndarray, x_minimize: bool, y_minimize: bool
9416
- ) -> np.ndarray:
9417
- num_points = len(x)
9418
- is_dominated = np.zeros(num_points, dtype=bool)
9419
-
9420
- for i in range(num_points):
9421
- for j in range(num_points):
9422
- if i == j:
9423
- continue
9424
-
9425
- if _pareto_front_general_compare(x[i], y[i], x[j], y[j], x_minimize, y_minimize):
9426
- is_dominated[i] = True
9427
- break
9428
-
9429
- return is_dominated
9430
-
9431
- def pareto_front_general(
9432
- x: np.ndarray,
9433
- y: np.ndarray,
9434
- x_minimize: bool = True,
9435
- y_minimize: bool = True
9436
- ) -> np.ndarray:
9437
- try:
9438
- _pareto_front_general_validate_shapes(x, y)
9439
- is_dominated = _pareto_front_general_find_dominated(x, y, x_minimize, y_minimize)
9440
- return np.where(~is_dominated)[0]
9441
- except Exception as e:
9442
- print("Error in pareto_front_general:", str(e))
9443
- return np.array([], dtype=int)
9444
-
9445
- def _pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
9446
- results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
9447
- result_names_file = f"{path_to_calculate}/result_names.txt"
9448
-
9449
- if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
9450
- return None
9451
-
9452
- with open(result_names_file, mode="r", encoding="utf-8") as f:
9453
- result_names = [line.strip() for line in f if line.strip()]
9454
-
9455
- records: dict = defaultdict(lambda: {'means': {}})
9456
-
9457
- with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
9458
- reader = csv.DictReader(csvfile)
9459
- for row in reader:
9460
- trial_index = int(row['trial_index'])
9461
- arm_name = row['arm_name']
9462
- key = (trial_index, arm_name)
9463
-
9464
- for metric in result_names:
9465
- if metric in row:
9466
- try:
9467
- records[key]['means'][metric] = float(row[metric])
9468
- except ValueError:
9469
- continue
9470
-
9471
- return records
9472
-
9473
- def _pareto_front_filter_complete_points(
9474
- path_to_calculate: str,
9475
- records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
9476
- primary_name: str,
9477
- secondary_name: str
9478
- ) -> List[Tuple[Tuple[int, str], float, float]]:
9479
- points = []
9480
- for key, metrics in records.items():
9481
- means = metrics['means']
9482
- if primary_name in means and secondary_name in means:
9483
- x_val = means[primary_name]
9484
- y_val = means[secondary_name]
9485
- points.append((key, x_val, y_val))
9486
- if len(points) == 0:
9487
- raise ValueError(f"No full data points with both objectives found in {path_to_calculate}.")
9488
- return points
9489
-
9490
- def _pareto_front_transform_objectives(
9491
- points: List[Tuple[Any, float, float]],
9492
- primary_name: str,
9493
- secondary_name: str
9494
- ) -> Tuple[np.ndarray, np.ndarray]:
9495
- primary_idx = arg_result_names.index(primary_name)
9496
- secondary_idx = arg_result_names.index(secondary_name)
9497
-
9498
- x = np.array([p[1] for p in points])
9499
- y = np.array([p[2] for p in points])
9500
-
9501
- if arg_result_min_or_max[primary_idx] == "max":
9502
- x = -x
9503
- elif arg_result_min_or_max[primary_idx] != "min":
9504
- raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[primary_idx]}")
9505
-
9506
- if arg_result_min_or_max[secondary_idx] == "max":
9507
- y = -y
9508
- elif arg_result_min_or_max[secondary_idx] != "min":
9509
- raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[secondary_idx]}")
9510
-
9511
- return x, y
9512
-
9513
- def _pareto_front_select_pareto_points(
9514
- x: np.ndarray,
9515
- y: np.ndarray,
9516
- x_minimize: bool,
9517
- y_minimize: bool,
9518
- points: List[Tuple[Any, float, float]],
9519
- num_points: int
9520
- ) -> List[Tuple[Any, float, float]]:
9521
- indices = pareto_front_general(x, y, x_minimize, y_minimize)
9522
- sorted_indices = indices[np.argsort(x[indices])]
9523
- sorted_indices = sorted_indices[:num_points]
9524
- selected_points = [points[i] for i in sorted_indices]
9525
- return selected_points
9526
-
9527
- def _pareto_front_build_return_structure(
9528
- path_to_calculate: str,
9529
- selected_points: List[Tuple[Any, float, float]],
9530
- records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
9531
- absolute_metrics: List[str],
9532
- primary_name: str,
9533
- secondary_name: str
9534
- ) -> dict:
9535
- results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
9536
- result_names_file = f"{path_to_calculate}/result_names.txt"
9537
-
9538
- with open(result_names_file, mode="r", encoding="utf-8") as f:
9539
- result_names = [line.strip() for line in f if line.strip()]
9540
-
9541
- csv_rows = {}
9542
- with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
9543
- reader = csv.DictReader(csvfile)
9544
- for row in reader:
9545
- trial_index = int(row['trial_index'])
9546
- csv_rows[trial_index] = row
9547
-
9548
- ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
9549
- ignored_columns.update(result_names)
9550
-
9551
- param_dicts = []
9552
- idxs = []
9553
- means_dict = defaultdict(list)
10036
+ return None
9554
10037
 
9555
- for (trial_index, arm_name), _, _ in selected_points:
9556
- row = csv_rows.get(trial_index, {})
9557
- if row == {} or row is None or row['arm_name'] != arm_name:
9558
- print_debug(f"_pareto_front_build_return_structure: trial_index '{trial_index}' could not be found and row returned as None")
9559
- continue
10038
+ def set_orchestrator() -> None:
10039
+ with spinner("Setting orchestrator..."):
10040
+ global orchestrator
9560
10041
 
9561
- idxs.append(int(row["trial_index"]))
10042
+ if args.orchestrator_file:
10043
+ if SYSTEM_HAS_SBATCH:
10044
+ orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
10045
+ else:
10046
+ print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
9562
10047
 
9563
- param_dict: dict[str, int | float | str] = {}
9564
- for key, value in row.items():
9565
- if key not in ignored_columns:
9566
- try:
9567
- param_dict[key] = int(value)
9568
- except ValueError:
9569
- try:
9570
- param_dict[key] = float(value)
9571
- except ValueError:
9572
- param_dict[key] = value
10048
+ def check_if_has_random_steps() -> None:
10049
+ if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
10050
+ _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
9573
10051
 
9574
- param_dicts.append(param_dict)
10052
+ def add_exclude_to_defective_nodes() -> None:
10053
+ with spinner("Adding excluded nodes..."):
10054
+ if args.exclude:
10055
+ entries = [entry.strip() for entry in args.exclude.split(',')]
9575
10056
 
9576
- for metric in absolute_metrics:
9577
- means_dict[metric].append(records[(trial_index, arm_name)]['means'].get(metric, float("nan")))
10057
+ for entry in entries:
10058
+ count_defective_nodes(None, entry)
9578
10059
 
9579
- ret = {
9580
- primary_name: {
9581
- secondary_name: {
9582
- "absolute_metrics": absolute_metrics,
9583
- "param_dicts": param_dicts,
9584
- "means": dict(means_dict),
9585
- "idxs": idxs
9586
- },
9587
- "absolute_metrics": absolute_metrics
9588
- }
9589
- }
10060
+ def check_max_eval(_max_eval: int) -> None:
10061
+ with spinner("Checking max_eval..."):
10062
+ if not _max_eval:
10063
+ _fatal_error("--max_eval needs to be set!", 19)
9590
10064
 
9591
- return ret
10065
+ def parse_parameters() -> Any:
10066
+ cli_params_experiment_parameters = None
10067
+ if args.parameter:
10068
+ parse_experiment_parameters()
10069
+ cli_params_experiment_parameters = experiment_parameters
9592
10070
 
9593
- def get_pareto_frontier_points(
9594
- path_to_calculate: str,
9595
- primary_objective: str,
9596
- secondary_objective: str,
9597
- x_minimize: bool,
9598
- y_minimize: bool,
9599
- absolute_metrics: List[str],
9600
- num_points: int
9601
- ) -> Optional[dict]:
9602
- records = _pareto_front_aggregate_data(path_to_calculate)
10071
+ return cli_params_experiment_parameters
9603
10072
 
9604
- if records is None:
9605
- return None
10073
+ def supports_sixel() -> bool:
10074
+ term = os.environ.get("TERM", "").lower()
10075
+ if "xterm" in term or "mlterm" in term:
10076
+ return True
9606
10077
 
9607
- points = _pareto_front_filter_complete_points(path_to_calculate, records, primary_objective, secondary_objective)
9608
- x, y = _pareto_front_transform_objectives(points, primary_objective, secondary_objective)
9609
- selected_points = _pareto_front_select_pareto_points(x, y, x_minimize, y_minimize, points, num_points)
9610
- result = _pareto_front_build_return_structure(path_to_calculate, selected_points, records, absolute_metrics, primary_objective, secondary_objective)
10078
+ try:
10079
+ output = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
10080
+ if output.returncode == 0 and "sixel" in output.stdout.lower():
10081
+ return True
10082
+ except (subprocess.CalledProcessError, FileNotFoundError):
10083
+ pass
9611
10084
 
9612
- return result
10085
+ return False
9613
10086
 
9614
10087
  def save_experiment_state() -> None:
9615
10088
  try:
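The large block removed above (`pareto_front_general` and its helpers) implemented a straightforward O(n²) two-objective dominance filter; the removal suggests the logic now lives elsewhere in the package. For reference, a compact NumPy restatement of that removed logic, with the same per-axis min/max handling. This mirrors the deleted code and is not the package's new implementation.

```python
# Compact restatement of the removed pareto_front_general: a point is kept
# unless some other point is at least as good in both objectives and strictly
# better in one, with per-axis min/max handling.
import numpy as np


def pareto_front_indices(x: np.ndarray, y: np.ndarray,
                         x_minimize: bool = True, y_minimize: bool = True) -> np.ndarray:
    if x.shape != y.shape:
        raise ValueError("Input arrays x and y must have the same shape.")
    # Flip maximized objectives so that "smaller is better" holds everywhere.
    xs = x if x_minimize else -x
    ys = y if y_minimize else -y
    n = len(xs)
    dominated = np.zeros(n, dtype=bool)
    for i in range(n):
        better_eq = (xs <= xs[i]) & (ys <= ys[i])
        strictly = (xs < xs[i]) | (ys < ys[i])
        others = np.arange(n) != i
        dominated[i] = np.any(better_eq & strictly & others)
    return np.where(~dominated)[0]


if __name__ == "__main__":
    x = np.array([1.0, 2.0, 3.0, 1.5])
    y = np.array([4.0, 3.0, 5.0, 3.5])
    print(pareto_front_indices(x, y))  # -> [0 1 3]; point (3.0, 5.0) is dominated
```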
@@ -9617,14 +10090,14 @@ def save_experiment_state() -> None:
9617
10090
  print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
9618
10091
  return
9619
10092
  state_path = get_current_run_folder("experiment_state.json")
9620
- ax_client.save_to_json_file(state_path)
10093
+ save_ax_client_to_json_file(state_path)
9621
10094
  except Exception as e:
9622
- print(f"Error saving experiment state: {e}")
10095
+ print_debug(f"Error saving experiment state: {e}")
9623
10096
 
9624
10097
  def wait_for_state_file(state_path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
9625
10098
  try:
9626
10099
  if not os.path.exists(state_path):
9627
- print(f"[ERROR] File '{state_path}' does not exist.")
10100
+ print_debug(f"[ERROR] File '{state_path}' does not exist.")
9628
10101
  return False
9629
10102
 
9630
10103
  i = 0
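The hunk above only shows the head of `wait_for_state_file(state_path, min_size=5, max_wait_seconds=60)`; judging from the signature and the early-exit shown, it bails out when the file is missing and otherwise waits until the state file reaches a minimal size or a timeout is hit. A plausible stand-alone version of such a helper, inferred from the signature and not copied from OmniOpt:

```python
# Polling helper sketch inferred from the signature shown above: return False
# if the file is absent, otherwise True once it is at least `min_size` bytes,
# or False after `max_wait_seconds`.  Illustration only, not the real code.
import os
import time


def wait_for_state_file(state_path: str, min_size: int = 5,
                        max_wait_seconds: int = 60) -> bool:
    if not os.path.exists(state_path):
        print(f"[ERROR] File '{state_path}' does not exist.")
        return False
    deadline = time.monotonic() + max_wait_seconds
    while time.monotonic() < deadline:
        try:
            if os.path.getsize(state_path) >= min_size:
                return True
        except OSError:
            pass  # file may be replaced between checks
        time.sleep(0.5)
    return False
```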
@@ -9843,38 +10316,47 @@ def get_result_minimize_flag(path_to_calculate: str, resname: str) -> bool:
9843
10316
 
9844
10317
  return minmax[index] == "min"
9845
10318
 
9846
- def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
9847
- pareto_front_data: dict = {}
10319
+ def post_job_calculate_pareto_front() -> None:
10320
+ if not args.calculate_pareto_front_of_job:
10321
+ return
9848
10322
 
9849
- all_combinations = list(combinations(range(len(arg_result_names)), 2))
10323
+ failure = False
9850
10324
 
9851
- skip = False
10325
+ _paths_to_calculate = []
9852
10326
 
9853
- for i, j in all_combinations:
9854
- if not skip:
9855
- metric_x = arg_result_names[i]
9856
- metric_y = arg_result_names[j]
10327
+ for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
10328
+ try:
10329
+ found_paths = find_results_paths(_path_to_calculate)
9857
10330
 
9858
- x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
9859
- y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)
10331
+ for _fp in found_paths:
10332
+ if _fp not in _paths_to_calculate:
10333
+ _paths_to_calculate.append(_fp)
10334
+ except (FileNotFoundError, NotADirectoryError) as e:
10335
+ print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
9860
10336
 
9861
- try:
9862
- if metric_x not in pareto_front_data:
9863
- pareto_front_data[metric_x] = {}
10337
+ failure = True
9864
10338
 
9865
- pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
9866
- except ax.exceptions.core.DataRequiredError as e:
9867
- print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
9868
- except SignalINT:
9869
- print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
9870
- skip = True
10339
+ for _path_to_calculate in _paths_to_calculate:
10340
+ for path_to_calculate in found_paths:
10341
+ if not job_calculate_pareto_front(path_to_calculate):
10342
+ failure = True
9871
10343
 
9872
- return pareto_front_data
10344
+ if failure:
10345
+ my_exit(24)
10346
+
10347
+ my_exit(0)
10348
+
10349
+ def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
10350
+ if not os.path.exists(RESULT_CSV_FILE):
10351
+ print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
10352
+ return None
10353
+
10354
+ return create_pareto_front_table(idxs, metric_x, metric_y)
9873
10355
 
9874
10356
  def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_sixel_and_table: bool = False) -> None:
9875
10357
  if len(res_names) <= 1:
9876
10358
  print_debug(f"--result_names (has {len(res_names)} entries) must be at least 2.")
9877
- return
10359
+ return None
9878
10360
 
9879
10361
  pareto_front_data: dict = get_pareto_front_data(path_to_calculate, res_names)
9880
10362
 
@@ -9895,8 +10377,16 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
9895
10377
  else:
9896
10378
  print(f"Not showing Pareto-front-sixel for {path_to_calculate}")
9897
10379
 
9898
- if len(calculated_frontier[metric_x][metric_y]["idxs"]):
9899
- pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
10380
+ if calculated_frontier is None:
10381
+ print_debug("ERROR: calculated_frontier is None")
10382
+ return None
10383
+
10384
+ try:
10385
+ if len(calculated_frontier[metric_x][metric_y]["idxs"]):
10386
+ pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
10387
+ except AttributeError:
10388
+ print_debug(f"ERROR: calculated_frontier structure invalid for ({metric_x}, {metric_y})")
10389
+ return None
9900
10390
 
9901
10391
  rich_table = pareto_front_as_rich_table(
9902
10392
  calculated_frontier[metric_x][metric_y]["idxs"],
@@ -9921,6 +10411,8 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
9921
10411
 
9922
10412
  live_share_after_pareto()
9923
10413
 
10414
+ return None
10415
+
9924
10416
  def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_color: str) -> None:
9925
10417
  cpu_count = os.cpu_count()
9926
10418
 
@@ -9932,9 +10424,11 @@ def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_
9932
10424
  pass
9933
10425
 
9934
10426
  if gpu_string:
9935
- console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}] [green]{gs_string}[/green]")
10427
+ console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]")
9936
10428
  else:
9937
- print_green(f"You have {cpu_count} CPUs available for the main process. {gs_string}")
10429
+ print_green(f"You have {cpu_count} CPUs available for the main process.")
10430
+
10431
+ print_green(gs_string)
9938
10432
 
9939
10433
  def write_args_overview_table() -> None:
9940
10434
  table = Table(title="Arguments Overview")
@@ -10201,112 +10695,6 @@ def find_results_paths(base_path: str) -> list:
10201
10695
 
10202
10696
  return list(set(found_paths))
10203
10697
 
10204
- def post_job_calculate_pareto_front() -> None:
10205
- if not args.calculate_pareto_front_of_job:
10206
- return
10207
-
10208
- failure = False
10209
-
10210
- _paths_to_calculate = []
10211
-
10212
- for _path_to_calculate in list(set(args.calculate_pareto_front_of_job)):
10213
- try:
10214
- found_paths = find_results_paths(_path_to_calculate)
10215
-
10216
- for _fp in found_paths:
10217
- if _fp not in _paths_to_calculate:
10218
- _paths_to_calculate.append(_fp)
10219
- except (FileNotFoundError, NotADirectoryError) as e:
10220
- print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")
10221
-
10222
- failure = True
10223
-
10224
- for _path_to_calculate in _paths_to_calculate:
10225
- for path_to_calculate in found_paths:
10226
- if not job_calculate_pareto_front(path_to_calculate):
10227
- failure = True
10228
-
10229
- if failure:
10230
- my_exit(24)
10231
-
10232
- my_exit(0)
10233
-
10234
- def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
10235
- pf_start_time = time.time()
10236
-
10237
- if not path_to_calculate:
10238
- return False
10239
-
10240
- global CURRENT_RUN_FOLDER
10241
- global RESULT_CSV_FILE
10242
- global arg_result_names
10243
-
10244
- if not path_to_calculate:
10245
- print_red("Can only calculate pareto front of previous job when --calculate_pareto_front_of_job is set")
10246
- return False
10247
-
10248
- if not os.path.exists(path_to_calculate):
10249
- print_red(f"Path '{path_to_calculate}' does not exist")
10250
- return False
10251
-
10252
- ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
10253
-
10254
- if not os.path.exists(ax_client_json):
10255
- print_red(f"Path '{ax_client_json}' not found")
10256
- return False
10257
-
10258
- checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
10259
- if not os.path.exists(checkpoint_file):
10260
- print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
10261
- return False
10262
-
10263
- RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
10264
- if not os.path.exists(RESULT_CSV_FILE):
10265
- print_red(f"{RESULT_CSV_FILE} not found")
10266
- return False
10267
-
10268
- res_names = []
10269
-
10270
- res_names_file = f"{path_to_calculate}/result_names.txt"
10271
- if not os.path.exists(res_names_file):
10272
- print_red(f"File '{res_names_file}' does not exist")
10273
- return False
10274
-
10275
- try:
10276
- with open(res_names_file, "r", encoding="utf-8") as file:
10277
- lines = file.readlines()
10278
- except Exception as e:
10279
- print_red(f"Error reading file '{res_names_file}': {e}")
10280
- return False
10281
-
10282
- for line in lines:
10283
- entry = line.strip()
10284
- if entry != "":
10285
- res_names.append(entry)
10286
-
10287
- if len(res_names) < 2:
10288
- print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
10289
- return False
10290
-
10291
- load_username_to_args(path_to_calculate)
10292
-
10293
- CURRENT_RUN_FOLDER = path_to_calculate
10294
-
10295
- arg_result_names = res_names
10296
-
10297
- load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)
10298
-
10299
- if experiment_parameters is None:
10300
- return False
10301
-
10302
- show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)
10303
-
10304
- pf_end_time = time.time()
10305
-
10306
- print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")
10307
-
10308
- return True
10309
-
10310
10698
  def set_arg_states_from_continue() -> None:
10311
10699
  if args.continue_previous_job and not args.num_random_steps:
10312
10700
  num_random_steps_file = f"{args.continue_previous_job}/state_files/num_random_steps"
@@ -10349,27 +10737,27 @@ def run_program_once(params: Optional[dict] = None) -> None:
10349
10737
  params = {}
10350
10738
 
10351
10739
  if isinstance(args.run_program_once, str):
10352
- command_str = args.run_program_once
10740
+ command_str = decode_if_base64(args.run_program_once)
10353
10741
  for k, v in params.items():
10354
10742
  placeholder = f"%({k})"
10355
10743
  command_str = command_str.replace(placeholder, str(v))
10356
10744
 
10357
- with spinner(f"Executing command: [cyan]{command_str}[/cyan]"):
10358
- result = subprocess.run(command_str, shell=True, check=True)
10359
- if result.returncode == 0:
10360
- console.log("[bold green]Setup script completed successfully ✅[/bold green]")
10361
- else:
10362
- console.log(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10745
+ print(f"Executing command: [cyan]{command_str}[/cyan]")
10746
+ result = subprocess.run(command_str, shell=True, check=True)
10747
+ if result.returncode == 0:
10748
+ print("[bold green]Setup script completed successfully ✅[/bold green]")
10749
+ else:
10750
+ print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10363
10751
 
10364
- my_exit(57)
10752
+ my_exit(57)
10365
10753
 
10366
10754
  elif isinstance(args.run_program_once, (list, tuple)):
10367
10755
  with spinner("run_program_once: Executing command list: [cyan]{args.run_program_once}[/cyan]"):
10368
10756
  result = subprocess.run(args.run_program_once, check=True)
10369
10757
  if result.returncode == 0:
10370
- console.log("[bold green]Setup script completed successfully ✅[/bold green]")
10758
+ print("[bold green]Setup script completed successfully ✅[/bold green]")
10371
10759
  else:
10372
- console.log(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10760
+ print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
10373
10761
 
10374
10762
  my_exit(57)
10375
10763
 
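`run_program_once` now passes the command string through `decode_if_base64` before substituting `%(name)` placeholders. The helper's real implementation is not part of this diff; below is an assumed, illustrative version of both steps (decode if the string is valid base64, otherwise keep it, then do the naive placeholder replacement shown above).

```python
# Assumed sketch of the two steps used above: decode the command if it is
# valid base64, otherwise keep it as-is, then replace %(name) placeholders.
# `decode_if_base64` here is a guess at the helper's behaviour, not its source.
import base64
import binascii


def decode_if_base64(value: str) -> str:
    try:
        return base64.b64decode(value, validate=True).decode("utf-8")
    except (binascii.Error, UnicodeDecodeError, ValueError):
        return value  # not base64 (or not UTF-8): use the string unchanged


def substitute_placeholders(command_str: str, params: dict) -> str:
    for k, v in params.items():
        command_str = command_str.replace(f"%({k})", str(v))
    return command_str


if __name__ == "__main__":
    raw = base64.b64encode(b"echo %(epochs)").decode("ascii")
    cmd = substitute_placeholders(decode_if_base64(raw), {"epochs": 10})
    print(cmd)  # -> echo 10
```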
@@ -10387,8 +10775,14 @@ def show_omniopt_call() -> None:
10387
10775
 
10388
10776
  original_print(oo_call + " " + cleaned)
10389
10777
 
10778
+ if args.dependency is not None and args.dependency != "":
10779
+ print(f"Dependency: {args.dependency}")
10780
+
10781
+ if args.ui_url is not None and args.ui_url != "":
10782
+ print_yellow("--ui_url is deprecated. Do not use it anymore. It will be ignored and one day be removed.")
10783
+
10390
10784
  def main() -> None:
10391
- global RESULT_CSV_FILE, ax_client, LOGFILE_DEBUG_GET_NEXT_TRIALS
10785
+ global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS
10392
10786
 
10393
10787
  check_if_has_random_steps()
10394
10788
 
@@ -10470,7 +10864,7 @@ def main() -> None:
10470
10864
  exp_params = get_experiment_parameters(cli_params_experiment_parameters)
10471
10865
 
10472
10866
  if exp_params is not None:
10473
- ax_client, experiment_args, gpu_string, gpu_color = exp_params
10867
+ experiment_args, gpu_string, gpu_color = exp_params
10474
10868
  print_debug(f"experiment_parameters: {experiment_parameters}")
10475
10869
 
10476
10870
  set_orchestrator()
@@ -10491,6 +10885,8 @@ def main() -> None:
10491
10885
 
10492
10886
  write_files_and_show_overviews()
10493
10887
 
10888
+ live_share()
10889
+
10494
10890
  #if args.continue_previous_job:
10495
10891
  # insert_jobs_from_csv(f"{args.continue_previous_job}/{RESULTS_CSV_FILENAME}")
10496
10892
 
@@ -10710,8 +11106,6 @@ def run_search_with_progress_bar() -> None:
10710
11106
  wait_for_jobs_to_complete()
10711
11107
 
10712
11108
  def complex_tests(_program_name: str, wanted_stderr: str, wanted_exit_code: int, wanted_signal: Union[int, None], res_is_none: bool = False) -> int:
10713
- #print_yellow(f"Test suite: {_program_name}")
10714
-
10715
11109
  nr_errors: int = 0
10716
11110
 
10717
11111
  program_path: str = f"./.tests/test_wronggoing_stuff.bin/bin/{_program_name}"
@@ -10785,7 +11179,7 @@ def test_find_paths(program_code: str) -> int:
10785
11179
  for i in files:
10786
11180
  if i not in string:
10787
11181
  if os.path.exists(i):
10788
- print("Missing {i} in find_file_paths string!")
11182
+ print(f"Missing {i} in find_file_paths string!")
10789
11183
  nr_errors += 1
10790
11184
 
10791
11185
  return nr_errors
@@ -11166,17 +11560,16 @@ Exit-Code: 159
11166
11560
 
11167
11561
  my_exit(nr_errors)
11168
11562
 
11169
- def main_outside() -> None:
11563
+ def main_wrapper() -> None:
11170
11564
  print(f"Run-UUID: {run_uuid}")
11171
11565
 
11172
11566
  auto_wrap_namespace(globals())
11173
11567
 
11174
11568
  print_logo()
11175
11569
 
11176
- start_logging_daemon() # type: ignore[unused-coroutine]
11177
-
11178
11570
  fool_linter(args.num_cpus_main_job)
11179
11571
  fool_linter(args.flame_graph)
11572
+ fool_linter(args.memray)
11180
11573
 
11181
11574
  with warnings.catch_warnings():
11182
11575
  warnings.simplefilter("ignore")
@@ -11233,6 +11626,9 @@ def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
11233
11626
  def auto_wrap_namespace(namespace: Any) -> Any:
11234
11627
  enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))
11235
11628
 
11629
+ if args.beartype:
11630
+ enable_beartype = True
11631
+
11236
11632
  excluded_functions = {
11237
11633
  "log_time_and_memory_wrapper",
11238
11634
  "collect_runtime_stats",
@@ -11241,7 +11637,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
11241
11637
  "_record_stats",
11242
11638
  "_open",
11243
11639
  "_check_memory_leak",
11244
- "start_logging_daemon",
11245
11640
  "get_current_run_folder",
11246
11641
  "show_func_name_wrapper"
11247
11642
  }
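`auto_wrap_namespace` can now also be switched on via the new `args.beartype` option, in addition to the ENABLE_BEARTYPE / CI environment variables, and `start_logging_daemon` is dropped from the exclusion set along with the function itself. The general idea (decorate every module-level function with beartype except for an exclusion set) looks roughly like the sketch below; this assumes plain beartype decoration and is not the actual OmniOpt wrapper.

```python
# Illustrative take on namespace-wide beartype wrapping with an exclusion set.
# Assumes `pip install beartype`; the real auto_wrap_namespace may differ.
import types
from beartype import beartype


def auto_wrap_namespace_sketch(namespace: dict, enable: bool,
                               excluded_functions: set) -> None:
    if not enable:
        return
    for name, obj in list(namespace.items()):
        if isinstance(obj, types.FunctionType) and name not in excluded_functions:
            namespace[name] = beartype(obj)   # runtime type-checks the annotations


def add(x: int, y: int) -> int:
    return x + y


if __name__ == "__main__":
    auto_wrap_namespace_sketch(globals(), enable=True, excluded_functions={"main"})
    print(add(1, 2))
    try:
        add("1", 2)  # type: ignore[arg-type]
    except Exception as e:   # beartype raises a type-check violation here
        print(type(e).__name__)
```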
@@ -11267,7 +11662,7 @@ def auto_wrap_namespace(namespace: Any) -> Any:
11267
11662
 
11268
11663
  if __name__ == "__main__":
11269
11664
  try:
11270
- main_outside()
11665
+ main_wrapper()
11271
11666
  except (SignalUSR, SignalINT, SignalCONT) as e:
11272
- print_red(f"main_outside failed with exception {e}")
11667
+ print_red(f"main_wrapper failed with exception {e}")
11273
11668
  end_program(True)