omniopt2 8509__tar.gz → 8588__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of omniopt2 might be problematic.

Files changed (40)
  1. {omniopt2-8509 → omniopt2-8588}/.helpers.py +0 -1
  2. {omniopt2-8509 → omniopt2-8588}/.omniopt.py +324 -76
  3. {omniopt2-8509 → omniopt2-8588}/PKG-INFO +2 -2
  4. {omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/PKG-INFO +2 -2
  5. {omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/requires.txt +1 -1
  6. {omniopt2-8509 → omniopt2-8588}/pyproject.toml +1 -1
  7. {omniopt2-8509 → omniopt2-8588}/requirements.txt +1 -1
  8. {omniopt2-8509 → omniopt2-8588}/.colorfunctions.sh +0 -0
  9. {omniopt2-8509 → omniopt2-8588}/.dockerignore +0 -0
  10. {omniopt2-8509 → omniopt2-8588}/.general.sh +0 -0
  11. {omniopt2-8509 → omniopt2-8588}/.gitignore +0 -0
  12. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_cpu_ram_usage.py +0 -0
  13. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_general.py +0 -0
  14. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_gpu_usage.py +0 -0
  15. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_kde.py +0 -0
  16. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_scatter.py +0 -0
  17. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_scatter_generation_method.py +0 -0
  18. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_scatter_hex.py +0 -0
  19. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_time_and_exit_code.py +0 -0
  20. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_trial_index_result.py +0 -0
  21. {omniopt2-8509 → omniopt2-8588}/.omniopt_plot_worker.py +0 -0
  22. {omniopt2-8509 → omniopt2-8588}/.random_generator.py +0 -0
  23. {omniopt2-8509 → omniopt2-8588}/.shellscript_functions +0 -0
  24. {omniopt2-8509 → omniopt2-8588}/.tests/pylint.rc +0 -0
  25. {omniopt2-8509 → omniopt2-8588}/.tpe.py +0 -0
  26. {omniopt2-8509 → omniopt2-8588}/LICENSE +0 -0
  27. {omniopt2-8509 → omniopt2-8588}/MANIFEST.in +0 -0
  28. {omniopt2-8509 → omniopt2-8588}/README.md +0 -0
  29. {omniopt2-8509 → omniopt2-8588}/apt-dependencies.txt +0 -0
  30. {omniopt2-8509 → omniopt2-8588}/omniopt +0 -0
  31. {omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/SOURCES.txt +0 -0
  32. {omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/dependency_links.txt +0 -0
  33. {omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/top_level.txt +0 -0
  34. {omniopt2-8509 → omniopt2-8588}/omniopt_docker +0 -0
  35. {omniopt2-8509 → omniopt2-8588}/omniopt_evaluate +0 -0
  36. {omniopt2-8509 → omniopt2-8588}/omniopt_plot +0 -0
  37. {omniopt2-8509 → omniopt2-8588}/omniopt_share +0 -0
  38. {omniopt2-8509 → omniopt2-8588}/setup.cfg +0 -0
  39. {omniopt2-8509 → omniopt2-8588}/setup.py +0 -0
  40. {omniopt2-8509 → omniopt2-8588}/test_requirements.txt +0 -0

{omniopt2-8509 → omniopt2-8588}/.helpers.py +0 -1

@@ -10,7 +10,6 @@ try:
  import difflib
  import logging
  import os
- import platform
  import re
  import traceback
  import numpy as np

{omniopt2-8509 → omniopt2-8588}/.omniopt.py +324 -76

@@ -99,7 +99,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
  SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
  joined_supported_models: str = ", ".join(SUPPORTED_MODELS)

- special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
+ special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]

  IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names

@@ -204,6 +204,9 @@ try:
  with spinner("Importing rich.pretty..."):
  from rich.pretty import pprint

+ with spinner("Importing pformat..."):
+ from pprint import pformat
+
  with spinner("Importing rich.prompt..."):
  from rich.prompt import Prompt, FloatPrompt, IntPrompt

@@ -802,6 +805,7 @@ class ConfigLoader:
  parameter: Optional[List[str]]
  experiment_constraints: Optional[List[str]]
  main_process_gb: int
+ beartype: bool
  worker_timeout: int
  slurm_signal_delay_s: int
  gridsearch: bool
@@ -999,6 +1003,7 @@ class ConfigLoader:
  debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
  debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
  debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
+ debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)

  def load_config(self: Any, config_path: str, file_format: str) -> dict:
  if not os.path.isfile(config_path):
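
The new beartype config field and --beartype switch feed into auto_wrap_namespace further down in this diff (the hunk around line 11480), which wraps module-level functions in runtime type checks. A minimal, self-contained sketch of that kind of beartype wrapping, assuming only that the beartype package is installed; the function names here are illustrative, not the package's own:

```python
import types
from beartype import beartype  # runtime type checker toggled by the new flag

def auto_wrap(namespace: dict, enable: bool, excluded: set) -> None:
    # Wrap every plain function in the namespace with beartype, unless excluded.
    if not enable:
        return
    for name, obj in list(namespace.items()):
        if isinstance(obj, types.FunctionType) and name not in excluded:
            namespace[name] = beartype(obj)

def add(a: int, b: int) -> int:
    return a + b

auto_wrap(globals(), enable=True, excluded={"auto_wrap"})
add(1, 2)      # passes the runtime type check
# add("1", 2)  # would raise a beartype violation at call time
```
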
@@ -1351,19 +1356,6 @@ try:
  with spinner("Importing GeneratorSpec..."):
  from ax.generation_strategy.generator_spec import GeneratorSpec

- #except Exception:
- # with spinner("Fallback: Importing ax.generation_strategy.generation_node..."):
- # import ax.generation_strategy.generation_node
-
- # with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
- # from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
-
- # with spinner("Fallback: Importing ExternalGenerationNode..."):
- # from ax.generation_strategy.external_generation_node import ExternalGenerationNode
-
- # with spinner("Fallback: Importing MaxTrials..."):
- # from ax.generation_strategy.transition_criterion import MaxTrials
-
  with spinner("Importing Models from ax.generation_strategy.registry..."):
  from ax.adapter.registry import Models

@@ -1471,6 +1463,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
  def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
  search_space = experiment.search_space
  parameter_names = list(search_space.parameters.keys())
+ if experiment.optimization_config is None:
+ print_red("Error: update_generator_state is None")
+ return
  metric_names = list(experiment.optimization_config.metrics.keys())

  completed_trials = [
@@ -1482,7 +1477,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
  y = np.zeros([num_completed_trials, 1])

  for t_idx, trial in enumerate(completed_trials):
- trial_parameters = trial.arm.parameters
+ trial_parameters = trial.arms[t_idx].parameters
  x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
  trial_df = data.df[data.df["trial_index"] == trial.index]
  y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
@@ -1622,10 +1617,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
  def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
  for name in best_sample.keys():
  param = self.parameters.get(name)
+ best_sample_by_name = best_sample[name]
+
  if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
- best_sample[name] = int(round(best_sample[name]))
+ if best_sample_by_name is not None:
+ best_sample[name] = int(round(float(best_sample_by_name)))
+ else:
+ print_debug("best_sample_by_name was empty")
  elif isinstance(param, ChoiceParameter):
- best_sample[name] = str(reverse_choice_map.get(int(best_sample[name])))
+ if best_sample_by_name is not None:
+ best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
+ else:
+ print_debug("best_sample_by_name was empty")

  decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode

@@ -2121,7 +2124,11 @@ def try_saving_to_db() -> None:
  else:
  print_red("ax_client was not defined in try_saving_to_db")
  my_exit(101)
- save_generation_strategy(global_gs)
+
+ if global_gs is not None:
+ save_generation_strategy(global_gs)
+ else:
+ print_red("Not saving generation strategy: global_gs was empty")
  except Exception as e:
  print_debug(f"Failed trying to save sqlite3-DB: {e}")

@@ -2188,13 +2195,20 @@ def save_results_csv() -> Optional[str]:
  def get_results_paths() -> tuple[str, str]:
  return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))

+ def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ return ax_client.get_trials_data_frame()
+
  def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
  if not ax_client:
  return None

  ax_client.experiment.fetch_data()
- df = ax_client.get_trials_data_frame()
-
+ df = ax_client_get_trials_data_frame()
  #print("========================")
  #print("BEFORE merge_with_job_infos:")
  #print(df["generation_node"])
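
Several of the new helpers in this release (ax_client_get_trials_data_frame above, plus ax_client_to_json_snapshot, get_ax_client_trial, complete_ax_client_trial and log_ax_client_trial_failure further down) follow one defensive shape: put a None check in front of a direct ax_client call and return an Optional instead of letting an AttributeError escape. A minimal standalone sketch of that pattern, with illustrative names rather than the package's own:

```python
from typing import Any, Optional

ax_client: Optional[Any] = None  # populated once the experiment exists; None before that

def guarded_snapshot() -> Optional[dict]:
    # Same shape as the new helpers: bail out with None when the client is missing
    # instead of failing on ax_client.to_json_snapshot().
    if ax_client is None:
        return None
    return ax_client.to_json_snapshot()

snapshot = guarded_snapshot()
if snapshot is None:
    print("ax_client not initialised yet, nothing to snapshot")
```
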
@@ -2211,11 +2225,24 @@ def write_csv(df: pd.DataFrame, path: str) -> None:
  pass
  df.to_csv(path, index=False, float_format="%.30f")

+ def ax_client_to_json_snapshot() -> Optional[dict]:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ json_snapshot = ax_client.to_json_snapshot()
+
+ return json_snapshot
+
  def write_json_snapshot(path: str) -> None:
  if ax_client is not None:
- json_snapshot = ax_client.to_json_snapshot()
- with open(path, "w", encoding="utf-8") as f:
- json.dump(json_snapshot, f, indent=4)
+ json_snapshot = ax_client_to_json_snapshot()
+ if json_snapshot is not None:
+ with open(path, "w", encoding="utf-8") as f:
+ json.dump(json_snapshot, f, indent=4)
+ else:
+ print_debug('json_snapshot from ax_client_to_json_snapshot was None')
  else:
  print_red("write_json_snapshot: ax_client was None")

@@ -4945,7 +4972,10 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
  if job:
  try:
  if ax_client:
- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)
+ if _trial is None:
+ return False
+
  _trial.mark_abandoned(reason=reason)
  print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
  global_vars["jobs"].remove((job, trial_index))
@@ -5037,6 +5067,16 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)

  my_exit(_exit)

+ def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ ax_client.save_to_json_file(checkpoint_filepath)
+
+ return None
+
  def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
  if trial_nr > 3:
  if eee:
@@ -5049,7 +5089,7 @@ def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None)
  checkpoint_filepath = get_state_file_name('checkpoint.json')

  if ax_client:
- ax_client.save_to_json_file(filepath=checkpoint_filepath)
+ save_ax_client_to_json_file(checkpoint_filepath)
  else:
  _fatal_error("Something went wrong using the ax_client", 101)
  except Exception as e:
@@ -5665,7 +5705,7 @@ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_param

  replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)

- ax_client.save_to_json_file(filepath=original_ax_client_file)
+ save_ax_client_to_json_file(original_ax_client_file)

  load_original_generation_strategy(original_ax_client_file)
  load_ax_client_from_experiment_parameters()
@@ -5695,6 +5735,109 @@ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_param

  return experiment_args, gpu_string, gpu_color

+ def get_experiment_args_import_python_script() -> str:
+
+ return """from ax.service.ax_client import AxClient, ObjectiveProperties
+ from ax.adapter.registry import Generators
+ import random
+
+ """
+
+ def get_generate_and_test_random_function_str() -> str:
+ raw_data_entries = ",\n ".join(
+ f'"{name}": random.uniform(0, 1)' for name in arg_result_names
+ )
+
+ return f"""
+ def generate_and_test_random_parameters(n):
+ for _ in range(n):
+ print("======================================")
+ parameters, trial_index = ax_client.get_next_trial()
+ print("Trial Index:", trial_index)
+ print("Suggested parameters:", parameters)
+
+ ax_client.complete_trial(
+ trial_index=trial_index,
+ raw_data={{
+ {raw_data_entries}
+ }}
+ )
+
+ generate_and_test_random_parameters({args.num_random_steps + 1})
+ """
+
+ def get_global_gs_string() -> str:
+ seed_str = ""
+ if args.seed is not None:
+ seed_str = f"model_kwargs={{'seed': {args.seed}}},"
+
+ return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
+
+ global_gs = GenerationStrategy(
+ steps=[
+ GenerationStep(
+ generator=Generators.SOBOL,
+ num_trials={args.num_random_steps},
+ max_parallelism=5,
+ {seed_str}
+ ),
+ GenerationStep(
+ generator=Generators.{args.model},
+ num_trials=-1,
+ max_parallelism=5,
+ ),
+ ]
+ )
+ """
+
+ def get_debug_ax_client_str() -> str:
+ return """
+ ax_client = AxClient(
+ verbose_logging=True,
+ enforce_sequential_optimization=False,
+ generation_strategy=global_gs
+ )
+ """
+
+ def write_ax_debug_python_code(experiment_args: dict) -> None:
+ if args.generation_strategy:
+ print_debug("Cannot write debug code for custom generation_strategy")
+ return None
+
+ if args.model in uncontinuable_models:
+ print_debug(f"Cannot write debug code for uncontinuable mode {args.model}")
+ return None
+
+ python_code = python_code = get_experiment_args_import_python_script() + \
+ get_global_gs_string() + \
+ get_debug_ax_client_str() + \
+ "experiment_args = " + pformat(experiment_args, width=120, compact=False) + \
+ "\nax_client.create_experiment(**experiment_args)\n" + \
+ get_generate_and_test_random_function_str()
+
+ file_path = f"{get_current_run_folder()}/debug.py"
+
+ try:
+ print_debug(python_code)
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(python_code)
+ except Exception as e:
+ print_red(f"Error while writing {file_path}: {e}")
+
+ return None
+
+ def create_ax_client_experiment(experiment_args: dict) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ write_ax_debug_python_code(experiment_args)
+
+ ax_client.create_experiment(**experiment_args)
+
+ return None
+
  def create_new_experiment() -> Tuple[dict, str, str]:
  if ax_client is None:
  print_red("create_new_experiment: ax_client is None")
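
The string builders above are concatenated into a standalone debug.py in the current run folder, which replays the experiment setup and feeds it random results. A rough sketch of the kind of script they emit, with made-up parameter and result names and an assumed num_random_steps of 5; the real file dumps the actual experiment_args via pformat():

```python
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.adapter.registry import Generators
from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy
import random

# Sobol warm-up followed by the configured model, as built by get_global_gs_string()
global_gs = GenerationStrategy(steps=[
    GenerationStep(generator=Generators.SOBOL, num_trials=5, max_parallelism=5),
    GenerationStep(generator=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=5),
])

ax_client = AxClient(verbose_logging=True, enforce_sequential_optimization=False, generation_strategy=global_gs)
ax_client.create_experiment(
    name="debug",  # hypothetical; the generated script passes the real experiment_args instead
    parameters=[{"name": "x", "type": "range", "bounds": [0.0, 1.0]}],
    objectives={"RESULT": ObjectiveProperties(minimize=True)},
)

def generate_and_test_random_parameters(n):
    for _ in range(n):
        parameters, trial_index = ax_client.get_next_trial()
        ax_client.complete_trial(trial_index=trial_index, raw_data={"RESULT": random.uniform(0, 1)})

generate_and_test_random_parameters(6)  # num_random_steps + 1, as in the generated code
```
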
@@ -5724,7 +5867,7 @@ def create_new_experiment() -> Tuple[dict, str, str]:
  experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)

  try:
- ax_client.create_experiment(**experiment_args)
+ create_ax_client_experiment(experiment_args)
  new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
  ax_client.experiment.add_tracking_metrics(new_metrics)
  except AssertionError as error:
@@ -5738,7 +5881,7 @@ def create_new_experiment() -> Tuple[dict, str, str]:

  return experiment_args, gpu_string, gpu_color

- def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[AxClient, dict, str, str]]:
+ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
  continue_previous_job = args.worker_generator_path or args.continue_previous_job

  check_ax_client()
@@ -5748,7 +5891,7 @@ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict |
  else:
  experiment_args, gpu_string, gpu_color = create_new_experiment()

- return ax_client, experiment_args, gpu_string, gpu_color
+ return experiment_args, gpu_string, gpu_color

  def get_type_short(typename: str) -> str:
  if typename == "RangeParameter":
@@ -5884,19 +6027,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:

  return None

+ def check_base_for_print_overview() -> Optional[bool]:
+ if args.continue_previous_job is not None and arg_result_names is not None and len(arg_result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
+ print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
+
+ if ax_client is None:
+ print_red("ax_client was None")
+ return None
+
+ if ax_client.experiment is None:
+ print_red("ax_client.experiment was None")
+ return None
+
+ if ax_client.experiment.optimization_config is None:
+ print_red("ax_client.experiment.optimization_config was None")
+ return None
+
+ return True
+
+ def get_config_objectives() -> Any:
+ if not ax_client:
+ print_red("create_new_experiment: ax_client is None")
+ my_exit(101)
+
+ return None
+
+ config_objectives = None
+
+ if ax_client.experiment and ax_client.experiment.optimization_config:
+ opt_config = ax_client.experiment.optimization_config
+ if opt_config.is_moo_problem:
+ objective = getattr(opt_config, "objective", None)
+ if objective and getattr(objective, "objectives", None) is not None:
+ config_objectives = objective.objectives
+ else:
+ print_debug("ax_client.experiment.optimization_config.objective was None")
+ else:
+ config_objectives = [opt_config.objective]
+ else:
+ print_debug("ax_client.experiment or optimization_config was None")
+
+ return config_objectives
+
  def print_result_names_overview_table() -> None:
  if not ax_client:
  _fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)

  return None

- if args.continue_previous_job is not None and arg_result_names is not None and len(arg_result_names) != 0 and original_result_names is not None and len(original_result_names) != 0:
- print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")
+ if check_base_for_print_overview() is None:
+ return None

- if ax_client.experiment.optimization_config.is_moo_problem:
- config_objectives = ax_client.experiment.optimization_config.objective.objectives
- else:
- config_objectives = [ax_client.experiment.optimization_config.objective]
+ config_objectives = get_config_objectives()
+
+ if config_objectives is None:
+ print_red("config_objectives not found")
+ return None

  res_names = []
  res_min_max = []
@@ -6285,7 +6471,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
  print_red("Cannot update progress bar! It is None.")

  def clean_completed_jobs() -> None:
- job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
+ job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
  job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]

  for job, trial_index in global_vars["jobs"][:]:
@@ -6603,11 +6789,21 @@ def check_ax_client() -> None:
  if ax_client is None or not ax_client:
  _fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)

+ def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ new_trial = ax_client.attach_trial(arm_params)
+
+ return new_trial
+
  def attach_trial(arm_params: dict) -> Tuple[Any, int]:
  if ax_client is None:
  raise RuntimeError("attach_trial: ax_client was empty")

- new_trial = ax_client.attach_trial(arm_params)
+ new_trial = attach_ax_client_data(arm_params)
  if not isinstance(new_trial, tuple) or len(new_trial) < 2:
  raise RuntimeError("attach_trial didn't return the expected tuple")
  return new_trial
@@ -6638,7 +6834,7 @@ def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[An
  is_ok = False

  if is_ok:
- ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
+ complete_ax_client_trial(trial_idx, result)
  update_status(__status, base_str, "Completed trial")
  else:
  print_debug("Empty job encountered")
@@ -7089,7 +7285,7 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:

  return None

- ax_client.log_trial_failure(trial_index=trial_index)
+ log_ax_client_trial_failure(trial_index)
  _trial.mark_failed(unsafe=True)
  except ValueError as e:
  print_debug(f"mark_trial_as_failed error: {e}")
@@ -7106,6 +7302,26 @@ def check_valid_result(result: Union[None, list, int, float, tuple]) -> bool:
  values_to_check = result if isinstance(result, list) else [result]
  return result is not None and all(r not in possible_val_not_found_values for r in values_to_check)

+ def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ ax_client.update_trial_data(trial_index=trial_idx, raw_data=result)
+
+ return None
+
+ def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ ax_client.complete_trial(trial_index=trial_idx, raw_data=result)
+
+ return None
+
  def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
  if ax_client is None:
  print_red("ax_client is not defined in _finish_job_core_helper_complete_trial")
@@ -7113,12 +7329,12 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -

  try:
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
- ax_client.complete_trial(trial_index=trial_index, raw_data=raw_result)
+ complete_ax_client_trial(trial_index, raw_result)
  print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
  except ax.exceptions.core.UnsupportedError as e:
  if f"{e}":
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
- ax_client.update_trial_data(trial_index=trial_index, raw_data=raw_result)
+ update_ax_client_trial(trial_index, raw_result)
  print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
  else:
  _fatal_error(f"Error completing trial: {e}", 234)
@@ -7145,8 +7361,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
  if job:
  try:
  progressbar_description("job_failed")
- ax_client.log_trial_failure(trial_index=trial_index)
- _trial.mark_failed(unsafe=True)
+ log_ax_client_trial_failure(trial_index)
+ mark_trial_as_failed(trial_index, _trial)
  except Exception as e:
  print_red(f"\nERROR while trying to mark job as failure: {e}")
  job.cancel()
@@ -7169,7 +7385,10 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
  this_jobs_finished += 1

  if ax_client:
- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)
+
+ if _trial is None:
+ return 0

  if check_valid_result(result):
  _finish_job_core_helper_complete_trial(trial_index, raw_result)
@@ -7201,8 +7420,11 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
  if job:
  try:
  progressbar_description("job_failed")
- _trial = ax_client.get_trial(trial_index)
- ax_client.log_trial_failure(trial_index=trial_index)
+ _trial = get_ax_client_trial(trial_index)
+ if _trial is None:
+ return None
+
+ log_ax_client_trial_failure(trial_index)
  mark_trial_as_failed(trial_index, _trial)
  except Exception as e:
  print(f"ERROR in line {get_line_info()}: {e}")
@@ -7242,7 +7464,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
  this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
  return this_jobs_finished

- def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, this_jobs_finished: int) -> int:
+ def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
+ job, trial_index = __args
+
+ this_jobs_finished = 0
  if job is None:
  print_debug(f"finish_previous_jobs: job {job} is None")
  return this_jobs_finished
@@ -7255,10 +7480,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t

  return this_jobs_finished

- def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
- job, trial_index = __args
- return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
-
  def finish_previous_jobs(new_msgs: List[str] = []) -> None:
  global JOBS_FINISHED

@@ -7274,7 +7495,7 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
  finishing_jobs_start_time = time.time()

  with ThreadPoolExecutor() as finish_job_executor:
- futures = [finish_job_executor.submit(_finish_previous_jobs_helper_wrapper, (job, trial_index)) for job, trial_index in jobs_copy]
+ futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]

  for future in as_completed(futures):
  try:
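
With this change the former _finish_previous_jobs_helper_wrapper is gone: _finish_previous_jobs_helper_check_and_process now takes the (job, trial_index) tuple itself, so it can be handed to executor.submit as a single argument. A minimal sketch of that submission pattern, using generic names rather than the package's own:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Tuple

def check_and_process(args: Tuple[Any, int]) -> int:
    job, trial_index = args  # unpack inside the worker, as the reworked helper does
    return 1 if job is not None else 0

jobs = [("job-a", 0), (None, 1), ("job-c", 2)]

with ThreadPoolExecutor() as pool:
    futures = [pool.submit(check_and_process, (job, idx)) for job, idx in jobs]
    finished = sum(f.result() for f in as_completed(futures))

print(finished)  # 2
```
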
@@ -7454,22 +7675,33 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:

  return new_job

+ def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
+ if not ax_client:
+ my_exit(101)
+
+ return None
+
+ return ax_client.get_trial(trial_index)
+
  def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
  if submitit_executor and ax_client:
  new_job = submit_new_job(parameters, trial_index)
  if new_job:
  submitted_jobs(1)

- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)

- try:
- _trial.mark_staged(unsafe=True)
- except Exception as e:
- print_debug(f"orchestrator_start_trial: error {e}")
- _trial.mark_running(unsafe=True, no_runner_required=True)
+ if _trial is not None:
+ try:
+ _trial.mark_staged(unsafe=True)
+ except Exception as e:
+ print_debug(f"orchestrator_start_trial: error {e}")
+ _trial.mark_running(unsafe=True, no_runner_required=True)

- print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
- global_vars["jobs"].append((new_job, trial_index))
+ print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
+ global_vars["jobs"].append((new_job, trial_index))
+ else:
+ print_red("Trial was none in orchestrator_start_trial")
  else:
  print_red("orchestrator_start_trial: Failed to start new job")
  elif ax_client:
@@ -7590,7 +7822,11 @@ def execute_evaluation(_params: list) -> Optional[int]:

  return None

- _trial = ax_client.get_trial(trial_index)
+ _trial = get_ax_client_trial(trial_index)
+
+ if _trial is None:
+ print_red("_trial was not in execute_evaluation")
+ return None

  def mark_trial_stage(stage: str, error_msg: str) -> None:
  try:
@@ -7669,12 +7905,20 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_

  return None

+ def log_ax_client_trial_failure(trial_index: int) -> None:
+ if not ax_client:
+ my_exit(101)
+
+ return
+
+ ax_client.log_trial_failure(trial_index=trial_index)
+
  def cancel_failed_job(trial_index: int, new_job: Job) -> None:
  print_debug("Trying to cancel job that failed")
  if new_job:
  try:
  if ax_client:
- ax_client.log_trial_failure(trial_index=trial_index)
+ log_ax_client_trial_failure(trial_index)
  else:
  _fatal_error("ax_client not defined", 101)
  except Exception as e:
@@ -7874,11 +8118,11 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
  t0 = time.time()
  pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
  dt = time.time() - t0
- print_debug(f"got pending observations (took {dt:.2f} seconds)")
+ print_debug(f"got pending observations: {pending_observations} (took {dt:.2f} seconds)")

  try:
  print_debug("getting global_gs.gen() with n=1")
- batched_generator_run = global_gs.gen(
+ batched_generator_run: Any = global_gs.gen(
  experiment=ax_client.experiment,
  n=1,
  pending_observations=pending_observations,
@@ -7886,11 +8130,12 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
  print_debug(f"got global_gs.gen(): {batched_generator_run}")
  except Exception as e:
  print_debug(f"global_gs.gen failed: {e}")
+ traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
  break

  depth = 0
  path = "batched_generator_run"
- while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) > 0:
+ while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
  print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
  batched_generator_run = batched_generator_run[0]
  path += "[0]"
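
Tightening the unwrap condition from len(...) > 0 to len(...) == 1 means only genuinely single-element wrappers are peeled off, while a result that is a real batch of several generator runs stays intact. A small standalone illustration of the difference (not package code):

```python
def unwrap_single(value):
    # Peel off nesting only while the container holds exactly one element,
    # mirroring the new while-condition in get_batched_arms.
    while isinstance(value, (list, tuple)) and len(value) == 1:
        value = value[0]
    return value

print(unwrap_single([[("run",)]]))   # 'run'       - nested singletons are fully unwrapped
print(unwrap_single([("a", "b")]))   # ('a', 'b')  - outer singleton removed, the batch kept
print(unwrap_single(["a", "b"]))     # ['a', 'b']  - multi-element containers left untouched
```
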
@@ -8105,12 +8350,12 @@ def set_global_gs_to_random() -> None:
  nodes=[
  GenerationNode(
  node_name="Sobol",
- generator_specs=[
- GeneratorSpec(
- Models.SOBOL,
- model_gen_kwargs=get_model_gen_kwargs()
- )
- ]
+ generator_specs=[ # type: ignore[arg-type]
+ GeneratorSpec( # type: ignore[arg-type]
+ Models.SOBOL, # type: ignore[arg-type]
+ model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
+ ) # type: ignore[arg-type]
+ ] # type: ignore[arg-type]
  )
  ]
  )
@@ -8671,7 +8916,7 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
  if model_name.lower() != "sobol":
  kwargs["model_kwargs"] = get_model_kwargs()

- model_spec = [GeneratorSpec(selected_model, **kwargs)]
+ model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]

  res = GenerationNode(
  node_name=model_name,
@@ -8686,7 +8931,7 @@ def get_optimizer_kwargs() -> dict:
  "sequential": False
  }

- def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = None) -> GenerationStep:
+ def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
  model_enum = get_model_from_name(model_name)

  return GenerationStep(
@@ -9616,7 +9861,7 @@ def save_experiment_state() -> None:
  print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
  return
  state_path = get_current_run_folder("experiment_state.json")
- ax_client.save_to_json_file(state_path)
+ save_ax_client_to_json_file(state_path)
  except Exception as e:
  print(f"Error saving experiment state: {e}")

@@ -10389,7 +10634,7 @@ def show_omniopt_call() -> None:
  original_print(oo_call + " " + cleaned)

  def main() -> None:
- global RESULT_CSV_FILE, ax_client, LOGFILE_DEBUG_GET_NEXT_TRIALS
+ global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS

  check_if_has_random_steps()

@@ -10471,7 +10716,7 @@ def main() -> None:
  exp_params = get_experiment_parameters(cli_params_experiment_parameters)

  if exp_params is not None:
- ax_client, experiment_args, gpu_string, gpu_color = exp_params
+ experiment_args, gpu_string, gpu_color = exp_params
  print_debug(f"experiment_parameters: {experiment_parameters}")

  set_orchestrator()
@@ -11235,6 +11480,9 @@ def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
  def auto_wrap_namespace(namespace: Any) -> Any:
  enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))

+ if args.beartype:
+ enable_beartype = True
+
  excluded_functions = {
  "log_time_and_memory_wrapper",
  "collect_runtime_stats",

{omniopt2-8509 → omniopt2-8588}/PKG-INFO +2 -2

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: omniopt2
- Version: 8509
+ Version: 8588
  Summary: Automatic highly parallelized hyperparameter optimizer based on Ax/Botorch
  Home-page: https://scads.ai/transfer-2/verfuegbare-software-dienste-en/omniopt/
  Author: Norman Koch
@@ -15,7 +15,7 @@ Requires-Dist: wheel
  Requires-Dist: multidict
  Requires-Dist: numpy
  Requires-Dist: python-dateutil
- Requires-Dist: ax-platform
+ Requires-Dist: ax-platform==1.1.0
  Requires-Dist: art
  Requires-Dist: tzlocal
  Requires-Dist: Rich

{omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/PKG-INFO +2 -2

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: omniopt2
- Version: 8509
+ Version: 8588
  Summary: Automatic highly parallelized hyperparameter optimizer based on Ax/Botorch
  Home-page: https://scads.ai/transfer-2/verfuegbare-software-dienste-en/omniopt/
  Author: Norman Koch
@@ -15,7 +15,7 @@ Requires-Dist: wheel
  Requires-Dist: multidict
  Requires-Dist: numpy
  Requires-Dist: python-dateutil
- Requires-Dist: ax-platform
+ Requires-Dist: ax-platform==1.1.0
  Requires-Dist: art
  Requires-Dist: tzlocal
  Requires-Dist: Rich

{omniopt2-8509 → omniopt2-8588}/omniopt2.egg-info/requires.txt +1 -1

@@ -5,7 +5,7 @@ wheel
  multidict
  numpy
  python-dateutil
- ax-platform
+ ax-platform==1.1.0
  art
  tzlocal
  Rich

{omniopt2-8509 → omniopt2-8588}/pyproject.toml +1 -1

@@ -5,7 +5,7 @@ authors = [
  {email = "norman.koch@tu-dresden.de"},
  {name = "Norman Koch"}
  ]
- version = "8509"
+ version = "8588"

  readme = "README.md"
  dynamic = ["dependencies"]

{omniopt2-8509 → omniopt2-8588}/requirements.txt +1 -1

@@ -5,7 +5,7 @@ wheel
  multidict
  numpy
  python-dateutil
- ax-platform
+ ax-platform==1.1.0
  art
  tzlocal
  Rich