omniopt2 8471__py3-none-any.whl → 9171__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- .gitignore +2 -0
- .helpers.py +0 -9
- .omniopt.py +1298 -903
- .omniopt_plot_scatter.py +1 -1
- .omniopt_plot_scatter_hex.py +1 -1
- .pareto.py +134 -0
- .shellscript_functions +24 -15
- .tests/pylint.rc +0 -4
- README.md +1 -1
- omniopt +33 -22
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.helpers.py +0 -9
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt.py +1298 -903
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter.py +1 -1
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_hex.py +1 -1
- omniopt2-9171.data/data/bin/.pareto.py +134 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.shellscript_functions +24 -15
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt +33 -22
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_plot +1 -1
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/pylint.rc +0 -4
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/test_requirements.txt +1 -0
- {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/METADATA +6 -5
- omniopt2-9171.dist-info/RECORD +73 -0
- omniopt2.egg-info/PKG-INFO +6 -5
- omniopt2.egg-info/SOURCES.txt +1 -0
- omniopt2.egg-info/requires.txt +4 -3
- omniopt_plot +1 -1
- pyproject.toml +1 -1
- requirements.txt +3 -3
- test_requirements.txt +1 -0
- omniopt2-8471.dist-info/RECORD +0 -71
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.colorfunctions.sh +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.general.sh +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_cpu_ram_usage.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_general.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_gpu_usage.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_kde.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_scatter_generation_method.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_time_and_exit_code.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_trial_index_result.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.omniopt_plot_worker.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.random_generator.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/.tpe.py +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/LICENSE +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/apt-dependencies.txt +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_docker +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_evaluate +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/omniopt_share +0 -0
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/requirements.txt +3 -3
- {omniopt2-8471.data → omniopt2-9171.data}/data/bin/setup.py +0 -0
- {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/WHEEL +0 -0
- {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/licenses/LICENSE +0 -0
- {omniopt2-8471.dist-info → omniopt2-9171.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
#from mayhemmonkey import MayhemMonkey
|
|
4
4
|
#mayhemmonkey = MayhemMonkey()
|
|
5
|
-
#mayhemmonkey.set_function_fail_after_count("open",
|
|
5
|
+
#mayhemmonkey.set_function_fail_after_count("open", 10)
|
|
6
6
|
#mayhemmonkey.set_function_error_rate("open", 0.1)
|
|
7
7
|
#mayhemmonkey.set_function_group_error_rate(["io", "math"], 0.8)
|
|
8
8
|
#mayhemmonkey.install_faulty()
|
|
@@ -27,11 +27,15 @@ import inspect
|
|
|
27
27
|
import tracemalloc
|
|
28
28
|
import resource
|
|
29
29
|
from urllib.parse import urlencode
|
|
30
|
-
|
|
31
30
|
import psutil
|
|
32
31
|
|
|
33
32
|
FORCE_EXIT: bool = False
|
|
34
33
|
|
|
34
|
+
LAST_LOG_TIME: int = 0
|
|
35
|
+
last_msg_progressbar = ""
|
|
36
|
+
last_msg_raw = None
|
|
37
|
+
last_lock_print_debug = threading.Lock()
|
|
38
|
+
|
|
35
39
|
def force_exit(signal_number: Any, frame: Any) -> Any:
|
|
36
40
|
global FORCE_EXIT
|
|
37
41
|
|
|
@@ -84,6 +88,7 @@ log_nr_gen_jobs: list[int] = []
|
|
|
84
88
|
generation_strategy_human_readable: str = ""
|
|
85
89
|
oo_call: str = "./omniopt"
|
|
86
90
|
progress_bar_length: int = 0
|
|
91
|
+
worker_usage_file = 'worker_usage.csv'
|
|
87
92
|
|
|
88
93
|
if os.environ.get("CUSTOM_VIRTUAL_ENV") == "1":
|
|
89
94
|
oo_call = "omniopt"
|
|
@@ -99,7 +104,7 @@ joined_valid_occ_types: str = ", ".join(valid_occ_types)
|
|
|
99
104
|
SUPPORTED_MODELS: list = ["SOBOL", "FACTORIAL", "SAASBO", "BOTORCH_MODULAR", "UNIFORM", "BO_MIXED", "RANDOMFOREST", "EXTERNAL_GENERATOR", "PSEUDORANDOM", "TPE"]
|
|
100
105
|
joined_supported_models: str = ", ".join(SUPPORTED_MODELS)
|
|
101
106
|
|
|
102
|
-
special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid"]
|
|
107
|
+
special_col_names: list = ["arm_name", "generation_method", "trial_index", "trial_status", "generation_node", "idxs", "start_time", "end_time", "run_time", "exit_code", "program_string", "signal", "hostname", "submit_time", "queue_time", "metric_name", "mean", "sem", "worker_generator_uuid", "runtime", "status"]
|
|
103
108
|
|
|
104
109
|
IGNORABLE_COLUMNS: list = ["start_time", "end_time", "hostname", "signal", "exit_code", "run_time", "program_string"] + special_col_names
|
|
105
110
|
|
|
@@ -153,12 +158,6 @@ try:
|
|
|
153
158
|
message="Ax currently requires a sqlalchemy version below 2.0.*",
|
|
154
159
|
)
|
|
155
160
|
|
|
156
|
-
warnings.filterwarnings(
|
|
157
|
-
"ignore",
|
|
158
|
-
category=RuntimeWarning,
|
|
159
|
-
message="coroutine 'start_logging_daemon' was never awaited"
|
|
160
|
-
)
|
|
161
|
-
|
|
162
161
|
with spinner("Importing argparse..."):
|
|
163
162
|
import argparse
|
|
164
163
|
|
|
@@ -204,6 +203,9 @@ try:
|
|
|
204
203
|
with spinner("Importing rich.pretty..."):
|
|
205
204
|
from rich.pretty import pprint
|
|
206
205
|
|
|
206
|
+
with spinner("Importing pformat..."):
|
|
207
|
+
from pprint import pformat
|
|
208
|
+
|
|
207
209
|
with spinner("Importing rich.prompt..."):
|
|
208
210
|
from rich.prompt import Prompt, FloatPrompt, IntPrompt
|
|
209
211
|
|
|
@@ -237,9 +239,6 @@ try:
|
|
|
237
239
|
with spinner("Importing uuid..."):
|
|
238
240
|
import uuid
|
|
239
241
|
|
|
240
|
-
#with spinner("Importing qrcode..."):
|
|
241
|
-
# import qrcode
|
|
242
|
-
|
|
243
242
|
with spinner("Importing cowsay..."):
|
|
244
243
|
import cowsay
|
|
245
244
|
|
|
@@ -270,6 +269,9 @@ try:
|
|
|
270
269
|
with spinner("Importing beartype..."):
|
|
271
270
|
from beartype import beartype
|
|
272
271
|
|
|
272
|
+
with spinner("Importing rendering stuff..."):
|
|
273
|
+
from ax.plot.base import AxPlotConfig
|
|
274
|
+
|
|
273
275
|
with spinner("Importing statistics..."):
|
|
274
276
|
import statistics
|
|
275
277
|
|
|
@@ -381,16 +383,9 @@ def _record_stats(func_name: str, elapsed: float, mem_diff: float, mem_after: fl
|
|
|
381
383
|
call_path_str = " -> ".join(short_stack)
|
|
382
384
|
_func_call_paths[func_name][call_path_str] += 1
|
|
383
385
|
|
|
384
|
-
print(
|
|
385
|
-
|
|
386
|
-
f"(total {percent_if_added:.1f}% of tracked time)"
|
|
387
|
-
)
|
|
388
|
-
print(
|
|
389
|
-
f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, "
|
|
390
|
-
f"diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB"
|
|
391
|
-
)
|
|
386
|
+
print(f"Function '{func_name}' took {elapsed:.4f}s (total {percent_if_added:.1f}% of tracked time)")
|
|
387
|
+
print(f"Memory before: {mem_after - mem_diff:.2f} MB, after: {mem_after:.2f} MB, diff: {mem_diff:+.2f} MB, peak during call: {mem_peak:.2f} MB")
|
|
392
388
|
|
|
393
|
-
# NEU: Runtime Stats
|
|
394
389
|
runtime_stats = collect_runtime_stats()
|
|
395
390
|
print("=== Runtime Stats ===")
|
|
396
391
|
print(f"RSS: {runtime_stats['rss_MB']:.2f} MB, VMS: {runtime_stats['vms_MB']:.2f} MB")
|
|
@@ -502,6 +497,24 @@ try:
|
|
|
502
497
|
dier: FunctionType = helpers.dier
|
|
503
498
|
is_equal: FunctionType = helpers.is_equal
|
|
504
499
|
is_not_equal: FunctionType = helpers.is_not_equal
|
|
500
|
+
with spinner("Importing pareto..."):
|
|
501
|
+
pareto_file: str = f"{script_dir}/.pareto.py"
|
|
502
|
+
spec = importlib.util.spec_from_file_location(
|
|
503
|
+
name="pareto",
|
|
504
|
+
location=pareto_file,
|
|
505
|
+
)
|
|
506
|
+
if spec is not None and spec.loader is not None:
|
|
507
|
+
pareto = importlib.util.module_from_spec(spec)
|
|
508
|
+
spec.loader.exec_module(pareto)
|
|
509
|
+
else:
|
|
510
|
+
raise ImportError(f"Could not load module from {pareto_file}")
|
|
511
|
+
|
|
512
|
+
pareto_front_table_filter_rows: FunctionType = pareto.pareto_front_table_filter_rows
|
|
513
|
+
pareto_front_table_add_headers: FunctionType = pareto.pareto_front_table_add_headers
|
|
514
|
+
pareto_front_table_add_rows: FunctionType = pareto.pareto_front_table_add_rows
|
|
515
|
+
pareto_front_filter_complete_points: FunctionType = pareto.pareto_front_filter_complete_points
|
|
516
|
+
pareto_front_select_pareto_points: FunctionType = pareto.pareto_front_select_pareto_points
|
|
517
|
+
|
|
505
518
|
except KeyboardInterrupt:
|
|
506
519
|
print("You pressed CTRL-c while importing the helpers file")
|
|
507
520
|
sys.exit(0)
|
|
@@ -544,6 +557,17 @@ logfile_worker_creation_logs: str = f'{log_uuid_dir}_worker_creation_logs'
|
|
|
544
557
|
logfile_trial_index_to_param_logs: str = f'{log_uuid_dir}_trial_index_to_param_logs'
|
|
545
558
|
LOGFILE_DEBUG_GET_NEXT_TRIALS: Union[str, None] = None
|
|
546
559
|
|
|
560
|
+
def error_without_print(text: str) -> None:
|
|
561
|
+
print_debug(text)
|
|
562
|
+
|
|
563
|
+
if get_current_run_folder():
|
|
564
|
+
try:
|
|
565
|
+
with open(get_current_run_folder("oo_errors.txt"), mode="a", encoding="utf-8") as myfile:
|
|
566
|
+
myfile.write(text + "\n\n")
|
|
567
|
+
except (OSError, FileNotFoundError) as e:
|
|
568
|
+
helpers.print_color("red", f"Error: {e}. This may mean that the {get_current_run_folder()} was deleted during the run. Could not write '{text} to {get_current_run_folder()}/oo_errors.txt'")
|
|
569
|
+
sys.exit(99)
|
|
570
|
+
|
|
547
571
|
def print_red(text: str) -> None:
|
|
548
572
|
helpers.print_color("red", text)
|
|
549
573
|
|
|
@@ -802,6 +826,7 @@ class ConfigLoader:
|
|
|
802
826
|
parameter: Optional[List[str]]
|
|
803
827
|
experiment_constraints: Optional[List[str]]
|
|
804
828
|
main_process_gb: int
|
|
829
|
+
beartype: bool
|
|
805
830
|
worker_timeout: int
|
|
806
831
|
slurm_signal_delay_s: int
|
|
807
832
|
gridsearch: bool
|
|
@@ -822,6 +847,7 @@ class ConfigLoader:
|
|
|
822
847
|
run_program_once: str
|
|
823
848
|
mem_gb: int
|
|
824
849
|
flame_graph: bool
|
|
850
|
+
memray: bool
|
|
825
851
|
continue_previous_job: Optional[str]
|
|
826
852
|
calculate_pareto_front_of_job: Optional[List[str]]
|
|
827
853
|
revert_to_random_when_seemingly_exhausted: bool
|
|
@@ -983,6 +1009,7 @@ class ConfigLoader:
|
|
|
983
1009
|
debug.add_argument('--verbose_break_run_search_table', help='Verbose logging for break_run_search', action='store_true', default=False)
|
|
984
1010
|
debug.add_argument('--debug', help='Enable debugging', action='store_true', default=False)
|
|
985
1011
|
debug.add_argument('--flame_graph', help='Enable flame-graphing. Makes everything slower, but creates a flame graph', action='store_true', default=False)
|
|
1012
|
+
debug.add_argument('--memray', help='Use memray to show memory usage', action='store_true', default=False)
|
|
986
1013
|
debug.add_argument('--no_sleep', help='Disables sleeping for fast job generation (not to be used on HPC)', action='store_true', default=False)
|
|
987
1014
|
debug.add_argument('--tests', help='Run simple internal tests', action='store_true', default=False)
|
|
988
1015
|
debug.add_argument('--show_worker_percentage_table_at_end', help='Show a table of percentage of usage of max worker over time', action='store_true', default=False)
|
|
@@ -996,8 +1023,8 @@ class ConfigLoader:
|
|
|
996
1023
|
debug.add_argument('--runtime_debug', help='Logs which functions use most of the time', action='store_true', default=False)
|
|
997
1024
|
debug.add_argument('--debug_stack_regex', help='Only print debug messages if call stack matches any regex', type=str, default='')
|
|
998
1025
|
debug.add_argument('--debug_stack_trace_regex', help='Show compact call stack with arrows if any function in stack matches regex', type=str, default=None)
|
|
999
|
-
|
|
1000
1026
|
debug.add_argument('--show_func_name', help='Show func name before each execution and when it is done', action='store_true', default=False)
|
|
1027
|
+
debug.add_argument('--beartype', help='Use beartype', action='store_true', default=False)
|
|
1001
1028
|
|
|
1002
1029
|
def load_config(self: Any, config_path: str, file_format: str) -> dict:
|
|
1003
1030
|
if not os.path.isfile(config_path):
|
|
@@ -1235,11 +1262,15 @@ for _rn in args.result_names:
|
|
|
1235
1262
|
_key = _rn
|
|
1236
1263
|
_min_or_max = __default_min_max
|
|
1237
1264
|
|
|
1265
|
+
_min_or_max = re.sub(r"'", "", _min_or_max)
|
|
1266
|
+
|
|
1238
1267
|
if _min_or_max not in ["min", "max"]:
|
|
1239
1268
|
if _min_or_max:
|
|
1240
1269
|
print_yellow(f"Value for determining whether to minimize or maximize was neither 'min' nor 'max' for key '{_key}', but '{_min_or_max}'. It will be set to the default, which is '{__default_min_max}' instead.")
|
|
1241
1270
|
_min_or_max = __default_min_max
|
|
1242
1271
|
|
|
1272
|
+
_key = re.sub(r"'", "", _key)
|
|
1273
|
+
|
|
1243
1274
|
if _key in arg_result_names:
|
|
1244
1275
|
console.print(f"[red]The --result_names option '{_key}' was specified multiple times![/]")
|
|
1245
1276
|
sys.exit(50)
|
|
@@ -1340,27 +1371,14 @@ try:
|
|
|
1340
1371
|
with spinner("Importing ExternalGenerationNode..."):
|
|
1341
1372
|
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
|
|
1342
1373
|
|
|
1343
|
-
with spinner("Importing
|
|
1344
|
-
from ax.generation_strategy.transition_criterion import
|
|
1374
|
+
with spinner("Importing MinTrials..."):
|
|
1375
|
+
from ax.generation_strategy.transition_criterion import MinTrials
|
|
1345
1376
|
|
|
1346
1377
|
with spinner("Importing GeneratorSpec..."):
|
|
1347
1378
|
from ax.generation_strategy.generator_spec import GeneratorSpec
|
|
1348
1379
|
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
# import ax.generation_strategy.generation_node
|
|
1352
|
-
|
|
1353
|
-
# with spinner("Fallback: Importing GenerationStep, GenerationStrategy from ax.generation_strategy..."):
|
|
1354
|
-
# from ax.generation_strategy.generation_node import GenerationNode, GenerationStep
|
|
1355
|
-
|
|
1356
|
-
# with spinner("Fallback: Importing ExternalGenerationNode..."):
|
|
1357
|
-
# from ax.generation_strategy.external_generation_node import ExternalGenerationNode
|
|
1358
|
-
|
|
1359
|
-
# with spinner("Fallback: Importing MaxTrials..."):
|
|
1360
|
-
# from ax.generation_strategy.transition_criterion import MaxTrials
|
|
1361
|
-
|
|
1362
|
-
with spinner("Importing Models from ax.generation_strategy.registry..."):
|
|
1363
|
-
from ax.adapter.registry import Models
|
|
1380
|
+
with spinner("Importing Generators from ax.generation_strategy.registry..."):
|
|
1381
|
+
from ax.adapter.registry import Generators
|
|
1364
1382
|
|
|
1365
1383
|
with spinner("Importing get_pending_observation_features..."):
|
|
1366
1384
|
from ax.core.utils import get_pending_observation_features
|
|
@@ -1446,7 +1464,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1446
1464
|
def __init__(self: Any, regressor_options: Dict[str, Any] = {}, seed: Optional[int] = None, num_samples: int = 1) -> None:
|
|
1447
1465
|
print_debug("Initializing RandomForestGenerationNode...")
|
|
1448
1466
|
t_init_start = time.monotonic()
|
|
1449
|
-
super().__init__(
|
|
1467
|
+
super().__init__(name="RANDOMFOREST")
|
|
1450
1468
|
self.num_samples: int = num_samples
|
|
1451
1469
|
self.seed: int = seed
|
|
1452
1470
|
|
|
@@ -1466,6 +1484,9 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1466
1484
|
def update_generator_state(self: Any, experiment: Experiment, data: Data) -> None:
|
|
1467
1485
|
search_space = experiment.search_space
|
|
1468
1486
|
parameter_names = list(search_space.parameters.keys())
|
|
1487
|
+
if experiment.optimization_config is None:
|
|
1488
|
+
print_red("Error: update_generator_state is None")
|
|
1489
|
+
return
|
|
1469
1490
|
metric_names = list(experiment.optimization_config.metrics.keys())
|
|
1470
1491
|
|
|
1471
1492
|
completed_trials = [
|
|
@@ -1477,7 +1498,7 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1477
1498
|
y = np.zeros([num_completed_trials, 1])
|
|
1478
1499
|
|
|
1479
1500
|
for t_idx, trial in enumerate(completed_trials):
|
|
1480
|
-
trial_parameters = trial.
|
|
1501
|
+
trial_parameters = trial.arms[t_idx].parameters
|
|
1481
1502
|
x[t_idx, :] = np.array([trial_parameters[p] for p in parameter_names])
|
|
1482
1503
|
trial_df = data.df[data.df["trial_index"] == trial.index]
|
|
1483
1504
|
y[t_idx, 0] = trial_df[trial_df["metric_name"] == metric_names[0]]["mean"].item()
|
|
@@ -1617,10 +1638,18 @@ class RandomForestGenerationNode(ExternalGenerationNode):
|
|
|
1617
1638
|
def _format_best_sample(self: Any, best_sample: TParameterization, reverse_choice_map: dict) -> None:
|
|
1618
1639
|
for name in best_sample.keys():
|
|
1619
1640
|
param = self.parameters.get(name)
|
|
1641
|
+
best_sample_by_name = best_sample[name]
|
|
1642
|
+
|
|
1620
1643
|
if isinstance(param, RangeParameter) and param.parameter_type == ParameterType.INT:
|
|
1621
|
-
|
|
1644
|
+
if best_sample_by_name is not None:
|
|
1645
|
+
best_sample[name] = int(round(float(best_sample_by_name)))
|
|
1646
|
+
else:
|
|
1647
|
+
print_debug("best_sample_by_name was empty")
|
|
1622
1648
|
elif isinstance(param, ChoiceParameter):
|
|
1623
|
-
|
|
1649
|
+
if best_sample_by_name is not None:
|
|
1650
|
+
best_sample[name] = str(reverse_choice_map.get(int(best_sample_by_name)))
|
|
1651
|
+
else:
|
|
1652
|
+
print_debug("best_sample_by_name was empty")
|
|
1624
1653
|
|
|
1625
1654
|
decoder_registry["RandomForestGenerationNode"] = RandomForestGenerationNode
|
|
1626
1655
|
|
|
@@ -1668,10 +1697,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
|
|
|
1668
1697
|
|
|
1669
1698
|
def __init__(
|
|
1670
1699
|
self: Any,
|
|
1671
|
-
|
|
1700
|
+
name: str = "INTERACTIVE_GENERATOR",
|
|
1672
1701
|
) -> None:
|
|
1673
1702
|
t0 = time.monotonic()
|
|
1674
|
-
super().__init__(
|
|
1703
|
+
super().__init__(name=name)
|
|
1675
1704
|
self.parameters = None
|
|
1676
1705
|
self.minimize = None
|
|
1677
1706
|
self.data = None
|
|
@@ -1827,10 +1856,10 @@ class InteractiveCLIGenerationNode(ExternalGenerationNode):
|
|
|
1827
1856
|
|
|
1828
1857
|
@dataclass(init=False)
|
|
1829
1858
|
class ExternalProgramGenerationNode(ExternalGenerationNode):
|
|
1830
|
-
def __init__(self: Any, external_generator: str = args.external_generator,
|
|
1859
|
+
def __init__(self: Any, external_generator: str = args.external_generator, name: str = "EXTERNAL_GENERATOR") -> None:
|
|
1831
1860
|
print_debug("Initializing ExternalProgramGenerationNode...")
|
|
1832
1861
|
t_init_start = time.monotonic()
|
|
1833
|
-
super().__init__(
|
|
1862
|
+
super().__init__(name=name)
|
|
1834
1863
|
self.seed: int = args.seed
|
|
1835
1864
|
self.external_generator: str = decode_if_base64(external_generator)
|
|
1836
1865
|
self.constraints = None
|
|
@@ -2051,9 +2080,9 @@ def run_live_share_command(force: bool = False) -> Tuple[str, str]:
|
|
|
2051
2080
|
return str(result.stdout), str(result.stderr)
|
|
2052
2081
|
except subprocess.CalledProcessError as e:
|
|
2053
2082
|
if e.stderr:
|
|
2054
|
-
|
|
2083
|
+
print_debug(f"run_live_share_command: command failed with error: {e}, stderr: {e.stderr}")
|
|
2055
2084
|
else:
|
|
2056
|
-
|
|
2085
|
+
print_debug(f"run_live_share_command: command failed with error: {e}")
|
|
2057
2086
|
return "", str(e.stderr)
|
|
2058
2087
|
except Exception as e:
|
|
2059
2088
|
print(f"run_live_share_command: An error occurred: {e}")
|
|
@@ -2067,6 +2096,8 @@ def force_live_share() -> bool:
|
|
|
2067
2096
|
return False
|
|
2068
2097
|
|
|
2069
2098
|
def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
|
|
2099
|
+
log_data()
|
|
2100
|
+
|
|
2070
2101
|
if not get_current_run_folder():
|
|
2071
2102
|
print(f"live_share: get_current_run_folder was empty or false: {get_current_run_folder()}")
|
|
2072
2103
|
return False
|
|
@@ -2080,17 +2111,23 @@ def live_share(force: bool = False, text_and_qr: bool = False) -> bool:
|
|
|
2080
2111
|
if stderr:
|
|
2081
2112
|
print_green(stderr)
|
|
2082
2113
|
else:
|
|
2083
|
-
|
|
2114
|
+
if stderr and stdout:
|
|
2115
|
+
print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}, stdout: {stdout}")
|
|
2116
|
+
elif stderr:
|
|
2117
|
+
print_red(f"This call should have shown the CURL, but didnt. Stderr: {stderr}")
|
|
2118
|
+
elif stdout:
|
|
2119
|
+
print_red(f"This call should have shown the CURL, but didnt. Stdout: {stdout}")
|
|
2120
|
+
else:
|
|
2121
|
+
print_red("This call should have shown the CURL, but didnt.")
|
|
2084
2122
|
if stdout:
|
|
2085
2123
|
print_debug(f"live_share stdout: {stdout}")
|
|
2086
2124
|
|
|
2087
2125
|
return True
|
|
2088
2126
|
|
|
2089
2127
|
def init_live_share() -> bool:
|
|
2090
|
-
|
|
2091
|
-
ret = live_share(True, True)
|
|
2128
|
+
ret = live_share(True, True)
|
|
2092
2129
|
|
|
2093
|
-
|
|
2130
|
+
return ret
|
|
2094
2131
|
|
|
2095
2132
|
def init_storage(db_url: str) -> None:
|
|
2096
2133
|
init_engine_and_session_factory(url=db_url, force_init=True)
|
|
@@ -2116,7 +2153,11 @@ def try_saving_to_db() -> None:
|
|
|
2116
2153
|
else:
|
|
2117
2154
|
print_red("ax_client was not defined in try_saving_to_db")
|
|
2118
2155
|
my_exit(101)
|
|
2119
|
-
|
|
2156
|
+
|
|
2157
|
+
if global_gs is not None:
|
|
2158
|
+
save_generation_strategy(global_gs)
|
|
2159
|
+
else:
|
|
2160
|
+
print_red("Not saving generation strategy: global_gs was empty")
|
|
2120
2161
|
except Exception as e:
|
|
2121
2162
|
print_debug(f"Failed trying to save sqlite3-DB: {e}")
|
|
2122
2163
|
|
|
@@ -2147,6 +2188,8 @@ def merge_with_job_infos(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
2147
2188
|
return merged
|
|
2148
2189
|
|
|
2149
2190
|
def save_results_csv() -> Optional[str]:
|
|
2191
|
+
log_data()
|
|
2192
|
+
|
|
2150
2193
|
if args.dryrun:
|
|
2151
2194
|
return None
|
|
2152
2195
|
|
|
@@ -2183,19 +2226,21 @@ def save_results_csv() -> Optional[str]:
|
|
|
2183
2226
|
def get_results_paths() -> tuple[str, str]:
|
|
2184
2227
|
return (get_current_run_folder(RESULTS_CSV_FILENAME), get_state_file_name('pd.json'))
|
|
2185
2228
|
|
|
2229
|
+
def ax_client_get_trials_data_frame() -> Optional[pd.DataFrame]:
|
|
2230
|
+
if not ax_client:
|
|
2231
|
+
my_exit(101)
|
|
2232
|
+
|
|
2233
|
+
return None
|
|
2234
|
+
|
|
2235
|
+
return ax_client.get_trials_data_frame()
|
|
2236
|
+
|
|
2186
2237
|
def fetch_and_prepare_trials() -> Optional[pd.DataFrame]:
|
|
2187
2238
|
if not ax_client:
|
|
2188
2239
|
return None
|
|
2189
2240
|
|
|
2190
2241
|
ax_client.experiment.fetch_data()
|
|
2191
|
-
df =
|
|
2192
|
-
|
|
2193
|
-
#print("========================")
|
|
2194
|
-
#print("BEFORE merge_with_job_infos:")
|
|
2195
|
-
#print(df["generation_node"])
|
|
2242
|
+
df = ax_client_get_trials_data_frame()
|
|
2196
2243
|
df = merge_with_job_infos(df)
|
|
2197
|
-
#print("AFTER merge_with_job_infos:")
|
|
2198
|
-
#print(df["generation_node"])
|
|
2199
2244
|
|
|
2200
2245
|
return df
|
|
2201
2246
|
|
|
@@ -2206,11 +2251,24 @@ def write_csv(df: pd.DataFrame, path: str) -> None:
|
|
|
2206
2251
|
pass
|
|
2207
2252
|
df.to_csv(path, index=False, float_format="%.30f")
|
|
2208
2253
|
|
|
2254
|
+
def ax_client_to_json_snapshot() -> Optional[dict]:
|
|
2255
|
+
if not ax_client:
|
|
2256
|
+
my_exit(101)
|
|
2257
|
+
|
|
2258
|
+
return None
|
|
2259
|
+
|
|
2260
|
+
json_snapshot = ax_client.to_json_snapshot()
|
|
2261
|
+
|
|
2262
|
+
return json_snapshot
|
|
2263
|
+
|
|
2209
2264
|
def write_json_snapshot(path: str) -> None:
|
|
2210
2265
|
if ax_client is not None:
|
|
2211
|
-
json_snapshot =
|
|
2212
|
-
|
|
2213
|
-
|
|
2266
|
+
json_snapshot = ax_client_to_json_snapshot()
|
|
2267
|
+
if json_snapshot is not None:
|
|
2268
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
2269
|
+
json.dump(json_snapshot, f, indent=4)
|
|
2270
|
+
else:
|
|
2271
|
+
print_debug('json_snapshot from ax_client_to_json_snapshot was None')
|
|
2214
2272
|
else:
|
|
2215
2273
|
print_red("write_json_snapshot: ax_client was None")
|
|
2216
2274
|
|
|
@@ -2410,7 +2468,7 @@ def set_nr_inserted_jobs(new_nr_inserted_jobs: int) -> None:
|
|
|
2410
2468
|
|
|
2411
2469
|
def write_worker_usage() -> None:
|
|
2412
2470
|
if len(WORKER_PERCENTAGE_USAGE):
|
|
2413
|
-
csv_filename = get_current_run_folder(
|
|
2471
|
+
csv_filename = get_current_run_folder(worker_usage_file)
|
|
2414
2472
|
|
|
2415
2473
|
csv_columns = ['time', 'num_parallel_jobs', 'nr_current_workers', 'percentage']
|
|
2416
2474
|
|
|
@@ -2420,35 +2478,39 @@ def write_worker_usage() -> None:
|
|
|
2420
2478
|
csv_writer.writerow(row)
|
|
2421
2479
|
else:
|
|
2422
2480
|
if is_slurm_job():
|
|
2423
|
-
print_debug("WORKER_PERCENTAGE_USAGE seems to be empty. Not writing
|
|
2481
|
+
print_debug(f"WORKER_PERCENTAGE_USAGE seems to be empty. Not writing {worker_usage_file}")
|
|
2424
2482
|
|
|
2425
2483
|
def log_system_usage() -> None:
|
|
2484
|
+
global LAST_LOG_TIME
|
|
2485
|
+
|
|
2486
|
+
now = time.time()
|
|
2487
|
+
if now - LAST_LOG_TIME < 30:
|
|
2488
|
+
return
|
|
2489
|
+
|
|
2490
|
+
LAST_LOG_TIME = int(now)
|
|
2491
|
+
|
|
2426
2492
|
if not get_current_run_folder():
|
|
2427
2493
|
return
|
|
2428
2494
|
|
|
2429
2495
|
ram_cpu_csv_file_path = os.path.join(get_current_run_folder(), "cpu_ram_usage.csv")
|
|
2430
|
-
|
|
2431
2496
|
makedirs(os.path.dirname(ram_cpu_csv_file_path))
|
|
2432
2497
|
|
|
2433
2498
|
file_exists = os.path.isfile(ram_cpu_csv_file_path)
|
|
2434
2499
|
|
|
2435
|
-
|
|
2436
|
-
|
|
2437
|
-
|
|
2438
|
-
current_time = int(time.time())
|
|
2439
|
-
|
|
2440
|
-
if process is not None:
|
|
2441
|
-
mem_proc = process.memory_info()
|
|
2442
|
-
|
|
2443
|
-
if mem_proc is not None:
|
|
2444
|
-
ram_usage_mb = mem_proc.rss / (1024 * 1024)
|
|
2445
|
-
cpu_usage_percent = psutil.cpu_percent(percpu=False)
|
|
2500
|
+
mem_proc = process.memory_info() if process else None
|
|
2501
|
+
if not mem_proc:
|
|
2502
|
+
return
|
|
2446
2503
|
|
|
2447
|
-
|
|
2448
|
-
|
|
2449
|
-
|
|
2504
|
+
ram_usage_mb = mem_proc.rss / (1024 * 1024)
|
|
2505
|
+
cpu_usage_percent = psutil.cpu_percent(percpu=False)
|
|
2506
|
+
if ram_usage_mb <= 0 or cpu_usage_percent <= 0:
|
|
2507
|
+
return
|
|
2450
2508
|
|
|
2451
|
-
|
|
2509
|
+
with open(ram_cpu_csv_file_path, mode='a', newline='', encoding="utf-8") as file:
|
|
2510
|
+
writer = csv.writer(file)
|
|
2511
|
+
if not file_exists:
|
|
2512
|
+
writer.writerow(["timestamp", "ram_usage_mb", "cpu_usage_percent"])
|
|
2513
|
+
writer.writerow([int(now), ram_usage_mb, cpu_usage_percent])
|
|
2452
2514
|
|
|
2453
2515
|
def write_process_info() -> None:
|
|
2454
2516
|
try:
|
|
@@ -2798,10 +2860,20 @@ def print_debug_get_next_trials(got: int, requested: int, _line: int) -> None:
|
|
|
2798
2860
|
log_message_to_file(LOGFILE_DEBUG_GET_NEXT_TRIALS, msg, 0, "")
|
|
2799
2861
|
|
|
2800
2862
|
def print_debug_progressbar(msg: str) -> None:
|
|
2801
|
-
|
|
2802
|
-
|
|
2863
|
+
global last_msg_progressbar, last_msg_raw
|
|
2864
|
+
|
|
2865
|
+
try:
|
|
2866
|
+
with last_lock_print_debug:
|
|
2867
|
+
if msg != last_msg_raw:
|
|
2868
|
+
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
2869
|
+
full_msg = f"{time_str} ({worker_generator_uuid}): {msg}"
|
|
2803
2870
|
|
|
2804
|
-
|
|
2871
|
+
_debug_progressbar(full_msg)
|
|
2872
|
+
|
|
2873
|
+
last_msg_raw = msg
|
|
2874
|
+
last_msg_progressbar = full_msg
|
|
2875
|
+
except Exception as e:
|
|
2876
|
+
print(f"Error in print_debug_progressbar: {e}", flush=True)
|
|
2805
2877
|
|
|
2806
2878
|
def get_process_info(pid: Any) -> str:
|
|
2807
2879
|
try:
|
|
@@ -3308,145 +3380,467 @@ def parse_experiment_parameters() -> None:
|
|
|
3308
3380
|
|
|
3309
3381
|
experiment_parameters = params # type: ignore[assignment]
|
|
3310
3382
|
|
|
3311
|
-
def
|
|
3312
|
-
|
|
3313
|
-
_fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
|
|
3383
|
+
def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
    """Calculate and show the Pareto front for an already existing run folder.

    Checks that ``path_to_calculate`` contains the files this function looks
    for (ax client state, checkpoint, results CSV, result names), points the
    module globals at that run and then calls ``show_pareto_or_error_msg``.

    Args:
        path_to_calculate: Run folder to evaluate. An empty value makes the
            function return ``False`` without printing anything.
        disable_sixel_and_table: Passed through to ``show_pareto_or_error_msg``.

    Returns:
        ``True`` when the Pareto step was reached, ``False`` when any
        precondition failed (a message is printed via ``print_red``, except for
        an empty path or missing experiment parameters).
    """
    global CURRENT_RUN_FOLDER
    global RESULT_CSV_FILE
    global arg_result_names

    pf_start_time = time.time()

    # The original had a second, unreachable copy of this check further down;
    # the silent early return is the behaviour that actually ran.
    if not path_to_calculate:
        return False

    if not os.path.exists(path_to_calculate):
        print_red(f"Path '{path_to_calculate}' does not exist")
        return False

    ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"
    if not os.path.exists(ax_client_json):
        print_red(f"Path '{ax_client_json}' not found")
        return False

    checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
    if not os.path.exists(checkpoint_file):
        print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
        return False

    # Validate first, then publish: the global is only touched once the file
    # is known to exist, so a failed call does not leave it pointing nowhere.
    result_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    if not os.path.exists(result_csv_file):
        print_red(f"{result_csv_file} not found")
        return False
    RESULT_CSV_FILE = result_csv_file

    res_names_file = f"{path_to_calculate}/result_names.txt"
    if not os.path.exists(res_names_file):
        print_red(f"File '{res_names_file}' does not exist")
        return False

    try:
        with open(res_names_file, "r", encoding="utf-8") as file:
            lines = file.readlines()
    except Exception as e:
        print_red(f"Error reading file '{res_names_file}': {e}")
        return False

    # One result name per non-empty line.
    res_names = [entry for entry in (line.strip() for line in lines) if entry != ""]

    if len(res_names) < 2:
        print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
        return False

    load_username_to_args(path_to_calculate)

    CURRENT_RUN_FOLDER = path_to_calculate

    arg_result_names = res_names

    load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)

    if experiment_parameters is None:
        return False

    show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)

    pf_end_time = time.time()

    print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")

    return True
|
|
3415
3458
|
|
|
3416
|
-
|
|
3417
|
-
|
|
3418
|
-
|
|
3419
|
-
|
|
3459
|
+
def show_pareto_or_error_msg(path_to_calculate: str, res_names: Optional[list] = None, disable_sixel_and_table: bool = False) -> None:
    """Show Pareto-frontier data when there is more than one result name.

    Args:
        path_to_calculate: Run folder passed on to ``show_pareto_frontier_data``.
        res_names: Result names to use. When omitted, the module-level
            ``arg_result_names`` is read at call time (the previous default
            captured whatever list the global held when the function was
            defined, which goes stale once the global is rebound).
        disable_sixel_and_table: Passed on to ``show_pareto_frontier_data``.

    Exceptions raised by ``show_pareto_frontier_data`` are printed with their
    traceback and not re-raised.
    """
    if res_names is None:
        res_names = arg_result_names

    if args.dryrun:
        print_debug("Not showing Pareto-frontier data with --dryrun")
        return None

    if len(res_names) > 1:
        try:
            show_pareto_frontier_data(path_to_calculate, res_names, disable_sixel_and_table)
        except Exception as e:
            inner_tb = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
            print_red(f"show_pareto_frontier_data() failed with exception '{e}':\n{inner_tb}")
    else:
        # Report the length of the list that was actually checked.
        print_debug(f"show_pareto_frontier_data will NOT be executed because len(res_names) is {len(res_names)}")
    return None
|
|
3424
3473
|
|
|
3425
|
-
|
|
3426
|
-
|
|
3427
|
-
self.thread.join()
|
|
3474
|
+
def get_pareto_front_data(path_to_calculate: str, res_names: list) -> dict:
    """Compute the frontier for every unordered pair of result names.

    Args:
        path_to_calculate: Run folder handed to the helper functions.
        res_names: Result names to pair up. The pairs are now built from this
            argument; previously the global ``arg_result_names`` was used and
            the parameter was only forwarded to ``get_calculated_frontier``.

    Returns:
        Nested dict ``{metric_x: {metric_y: frontier}}``. Pairs that raise
        ``DataRequiredError`` are reported and left out; a ``SignalINT`` stops
        processing of all remaining pairs.
    """
    pareto_front_data: dict = {}

    # combinations() over the names yields the same pairs, in the same order,
    # as combinations() over their indices.
    for metric_x, metric_y in combinations(res_names, 2):
        x_minimize = get_result_minimize_flag(path_to_calculate, metric_x)
        y_minimize = get_result_minimize_flag(path_to_calculate, metric_y)

        try:
            if metric_x not in pareto_front_data:
                pareto_front_data[metric_x] = {}

            pareto_front_data[metric_x][metric_y] = get_calculated_frontier(path_to_calculate, metric_x, metric_y, x_minimize, y_minimize, res_names)
        except ax.exceptions.core.DataRequiredError as e:
            print_red(f"Error computing Pareto frontier for {metric_x} and {metric_y}: {e}")
        except SignalINT:
            print_red("Calculating Pareto-fronts was cancelled by pressing CTRL-c")
            # Same effect as the former ``skip`` flag: no further pairs are handled.
            break

    return pareto_front_data
|
|
3501
|
+
|
|
3502
|
+
def pareto_front_transform_objectives(
    points: List[Tuple[Any, float, float]],
    primary_name: str,
    secondary_name: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the two objective columns of ``points`` as arrays, negating "max" ones.

    The mode of each objective is looked up in ``arg_result_min_or_max`` at the
    position the name has in ``arg_result_names``. A "max" column is negated,
    a "min" column is returned unchanged.

    Raises:
        ValueError: If a name is not in ``arg_result_names`` or its mode is
            neither "min" nor "max".
    """
    # Resolve both positions first so a missing name fails before any array work.
    first_pos = arg_result_names.index(primary_name)
    second_pos = arg_result_names.index(secondary_name)

    x = np.array([point[1] for point in points])
    y = np.array([point[2] for point in points])

    first_mode = arg_result_min_or_max[first_pos]
    if first_mode == "max":
        x = -x
    elif first_mode != "min":
        raise ValueError(f"Unknown mode for {primary_name}: {arg_result_min_or_max[first_pos]}")

    second_mode = arg_result_min_or_max[second_pos]
    if second_mode == "max":
        y = -y
    elif second_mode != "min":
        raise ValueError(f"Unknown mode for {secondary_name}: {arg_result_min_or_max[second_pos]}")

    return x, y
|
|
3524
|
+
|
|
3525
|
+
def get_pareto_frontier_points(
    path_to_calculate: str,
    primary_objective: str,
    secondary_objective: str,
    x_minimize: bool,
    y_minimize: bool,
    absolute_metrics: List[str],
    num_points: int
) -> Optional[dict]:
    """Run the Pareto pipeline for one pair of objectives.

    Steps: aggregate the run's records, filter complete points, transform the
    objectives, select the Pareto points and build the return structure.

    Returns:
        The structure built by ``pareto_front_build_return_structure``, or
        ``None`` when ``pareto_front_aggregate_data`` returns ``None``.
    """
    aggregated = pareto_front_aggregate_data(path_to_calculate)
    if aggregated is None:
        return None

    complete_points = pareto_front_filter_complete_points(path_to_calculate, aggregated, primary_objective, secondary_objective)
    xs, ys = pareto_front_transform_objectives(complete_points, primary_objective, secondary_objective)
    chosen = pareto_front_select_pareto_points(xs, ys, x_minimize, y_minimize, complete_points, num_points)

    return pareto_front_build_return_structure(path_to_calculate, chosen, aggregated, absolute_metrics, primary_objective, secondary_objective)
|
|
3545
|
+
|
|
3546
|
+
def pareto_front_table_read_csv() -> List[Dict[str, str]]:
    """Read every row of ``RESULT_CSV_FILE`` into a list of dicts."""
    with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as csv_handle:
        reader = csv.DictReader(csv_handle)
        # Materialise while the file is still open.
        return [row for row in reader]
|
|
3549
|
+
|
|
3550
|
+
def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
    """Build the table for the Pareto front of ``metric_y``/``metric_x``.

    Rows come from ``pareto_front_table_read_csv`` and are narrowed down by
    ``pareto_front_table_filter_rows`` using ``idxs``. When there are no rows,
    or none survive filtering, the table gets a single explanatory column.
    """
    table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)

    all_rows = pareto_front_table_read_csv()
    if not all_rows:
        table.add_column("No data found")
        return table

    matching = pareto_front_table_filter_rows(all_rows, idxs)
    if not matching:
        table.add_column("No matching entries")
        return table

    # Column layout is derived from the first matching row.
    param_cols, result_cols = pareto_front_table_get_columns(matching[0])

    pareto_front_table_add_headers(table, param_cols, result_cols)
    pareto_front_table_add_rows(table, matching, param_cols, result_cols)

    return table
|
|
3569
|
+
|
|
3570
|
+
def pareto_front_build_return_structure(
    path_to_calculate: str,
    selected_points: List[Tuple[Any, float, float]],
    records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
    absolute_metrics: List[str],
    primary_name: str,
    secondary_name: str
) -> dict:
    """Assemble the nested result dict for the selected Pareto points.

    For each selected ``(trial_index, arm_name)`` that has a matching row in
    the run's results CSV, the parameter columns are collected (with values
    converted to int, then float, else kept as string), together with the
    trial index and the mean of every metric in ``absolute_metrics`` (NaN when
    the record has no value for it).

    Returns:
        ``{primary_name: {secondary_name: {...}, "absolute_metrics": [...]}}``
        where the inner dict holds ``absolute_metrics``, ``param_dicts``,
        ``means`` and ``idxs``.
    """
    def _coerce(raw: Any) -> Any:
        # int first, then float; anything else stays as read from the CSV.
        for caster in (int, float):
            try:
                return caster(raw)
            except ValueError:
                continue
        return raw

    results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    result_names_file = f"{path_to_calculate}/result_names.txt"

    with open(result_names_file, mode="r", encoding="utf-8") as names_handle:
        result_names = [stripped for stripped in (line.strip() for line in names_handle) if stripped]

    with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
        # Later rows win when a trial index occurs more than once.
        csv_rows = {int(csv_row['trial_index']): csv_row for csv_row in csv.DictReader(csvfile)}

    ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'} | set(result_names)

    param_dicts = []
    idxs = []
    means_dict = defaultdict(list)

    for (trial_index, arm_name), _, _ in selected_points:
        row = csv_rows.get(trial_index)
        if not row or row['arm_name'] != arm_name:
            continue

        idxs.append(int(row["trial_index"]))

        param_dicts.append({
            key: _coerce(value)
            for key, value in row.items()
            if key not in ignored_columns
        })

        record_means = records[(trial_index, arm_name)]['means']
        for metric in absolute_metrics:
            means_dict[metric].append(record_means.get(metric, float("nan")))

    inner = {
        "absolute_metrics": absolute_metrics,
        "param_dicts": param_dicts,
        "means": dict(means_dict),
        "idxs": idxs
    }

    return {
        primary_name: {
            secondary_name: inner,
            "absolute_metrics": absolute_metrics
        }
    }
|
|
3634
|
+
|
|
3635
|
+
def pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
    """Collect the numeric result values of a run, keyed by trial and arm.

    Reads the result names from ``result_names.txt`` and the rows from the
    results CSV inside ``path_to_calculate``.

    Returns:
        A mapping ``(trial_index, arm_name) -> {"means": {metric: value}}``
        containing only values that parse as float, or ``None`` when either
        input file is missing.
    """
    results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    result_names_file = f"{path_to_calculate}/result_names.txt"

    if not (os.path.exists(results_csv_file) and os.path.exists(result_names_file)):
        return None

    with open(result_names_file, mode="r", encoding="utf-8") as names_handle:
        result_names = [stripped for stripped in (line.strip() for line in names_handle) if stripped]

    records: dict = defaultdict(lambda: {'means': {}})

    with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            key = (int(row['trial_index']), row['arm_name'])

            for metric in result_names:
                if metric not in row:
                    continue
                try:
                    value = float(row[metric])
                except ValueError:
                    # Non-numeric cell: this metric is left out for the trial.
                    continue
                records[key]['means'][metric] = value

    return records
|
|
3662
|
+
|
|
3663
|
+
def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
    """Scatter-plot the Pareto points of two metrics and print the image to the CLI.

    Does nothing except print a notice when ``data`` is ``None`` or when
    ``supports_sixel()`` reports that the console cannot show sixel images.

    Args:
        data: Nested structure read as ``data[x_metric][y_metric]["means"]``,
            which must contain a value sequence for both metrics.
        x_metric: Metric plotted on the x axis.
        y_metric: Metric plotted on the y axis.
    """
    # NOTE(review): the "[italic yellow]...[/]" markup suggests ``print`` is a
    # markup-aware print imported elsewhere in the file; confirm.
    if data is None:
        print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
        return

    if not supports_sixel():
        print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
        return

    # Imported lazily, as in the original, so it is only loaded when plotting.
    import matplotlib.pyplot as plt

    means = data[x_metric][y_metric]["means"]

    x_values = means[x_metric]
    y_values = means[y_metric]

    fig, _ax = plt.subplots()

    try:
        _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')

        _ax.set_xlabel(x_metric)
        _ax.set_ylabel(y_metric)

        _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')

        _ax.ticklabel_format(style='plain', axis='both', useOffset=False)

        # The temporary file must still exist while the image is printed.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
            fig.savefig(tmp_file.name, dpi=300)

            print_image_to_cli(tmp_file.name, 1000)
    finally:
        # Close the figure even if saving or printing fails.
        plt.close(fig)
|
|
3696
|
+
|
|
3697
|
+
def pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Split the columns of a CSV row into parameter and result columns.

    Parameter columns are all columns that are not in ``special_col_names``
    (``trial_index`` is kept), are not result names and do not start with
    ``OO_Info_``. Result columns are the entries of ``arg_result_names`` that
    occur in the row, in ``arg_result_names`` order.
    """
    hidden = set(special_col_names) - {"trial_index"}

    param_cols: List[str] = []
    for column in first_row.keys():
        if column in hidden or column in arg_result_names or column.startswith("OO_Info_"):
            continue
        param_cols.append(column)

    result_cols = [name for name in arg_result_names if name in first_row]

    return param_cols, result_cols
|
|
3704
|
+
|
|
3705
|
+
def check_factorial_range() -> None:
    """Abort via ``_fatal_error`` (code 181) when ``--model`` is FACTORIAL."""
    selected_model = args.model
    if not selected_model:
        return
    if selected_model == "FACTORIAL":
        _fatal_error("\n⚠ --model FACTORIAL cannot be used with range parameter", 181)
|
|
3708
|
+
|
|
3709
|
+
def check_if_range_types_are_invalid(value_type: str, valid_value_types: list) -> None:
    """Abort via ``_fatal_error`` (code 181) when ``value_type`` is not allowed."""
    if value_type in valid_value_types:
        return

    allowed = ", ".join(valid_value_types)
    _fatal_error(f"⚠ {value_type} is not a valid value type. Valid types for range are: {allowed}", 181)
|
|
3713
|
+
|
|
3714
|
+
def check_range_params_length(this_args: Union[str, list]) -> None:
    """Abort via ``_fatal_error`` (code 181) unless ``this_args`` has 4, 5 or 6 items."""
    if len(this_args) not in (4, 5, 6):
        _fatal_error("\n⚠ --parameter for type range must have 4 (or 5, the last one being optional and float by default, or 6, while the last one is true or false) parameters: <NAME> range <START> <END> (<TYPE (int or float)>, <log_scale: bool>)", 181)
|
|
3717
|
+
|
|
3718
|
+
def die_if_lower_and_upper_bound_equal_zero(lower_bound: Union[int, float], upper_bound: Union[int, float]) -> None:
    """Check a pair of range bounds for the "both equal" case.

    Calls ``_fatal_error`` (code 91) when either bound is ``None`` and
    ``_fatal_error`` (code 181) when both bounds are equal to zero. When the
    bounds are equal but non-zero, a warning is printed.
    """
    if upper_bound is None or lower_bound is None:
        _fatal_error("die_if_lower_and_upper_bound_equal_zero: upper_bound or lower_bound is None. Cannot continue.", 91)

    if not (upper_bound == lower_bound):
        return

    if lower_bound == 0:
        _fatal_error(f"⚠ Lower bound and upper bound are equal: {lower_bound}, cannot automatically fix this, because they -0 = +0 (usually a quickfix would be to set lower_bound = -upper_bound)", 181)

    print_red(f"⚠ Lower bound and upper bound are equal: {lower_bound}, setting lower_bound = -upper_bound")

    # NOTE(review): this only rebinds the local name; the function returns
    # None, so the caller's lower bound is not changed by this line.
    if upper_bound is not None:
        lower_bound = -upper_bound
|
|
3727
|
+
|
|
3728
|
+
def format_value(value: Any, float_format: str = '.80f') -> str:
    """Render ``value`` as a string, trimming trailing zeros from floats.

    Floats are formatted with ``float_format``; if the result contains a
    decimal point, trailing zeros and then a trailing point are removed.
    Every other value is returned as ``str(value)``. If formatting fails, the
    error is printed and ``str(value)`` is returned.
    """
    try:
        if not isinstance(value, float):
            return str(value)

        text = format(value, float_format)
        if '.' in text:
            # Only strip when there is a fractional part, so "1000" stays "1000".
            text = text.rstrip('0').rstrip('.')
        return text
    except Exception as e:
        print_red(f"⚠ Error formatting the number {value}: {e}")
        return str(value)
|
|
3738
|
+
|
|
3739
|
+
def replace_parameters_in_string(
    parameters: dict,
    input_string: str,
    float_format: str = '.20f',
    additional_prefixes: Optional[list[str]] = None,
    additional_patterns: Optional[list[str]] = None,
) -> str:
    """Substitute parameter placeholders in ``input_string``.

    For every parameter, each combination of a prefix (``$``, ``%`` plus
    ``additional_prefixes``) and a pattern (``{key}``, ``({key})`` plus
    ``additional_patterns``) forms a token such as ``$name`` or ``%(name)``
    that is replaced by the value formatted through ``format_value``.
    Carriage returns and newlines in the result are turned into spaces.

    Keys are processed longest first, so that a key which is a prefix of
    another key (``a`` vs. ``ab``) cannot clobber the longer token ``$ab``.

    Returns:
        The substituted string, or ``""`` when an exception occurred (the
        error is printed via ``print_red``).
    """
    try:
        # None instead of mutable list defaults; semantics for callers are unchanged.
        prefixes = ['$', '%'] + (additional_prefixes or [])
        patterns = ['{' + 'key' + '}', '(' + '{' + 'key' + '}' + ')'] + (additional_patterns or [])

        # sorted() is stable, so keys of equal length keep their original order.
        ordered_items = sorted(parameters.items(), key=lambda item: len(str(item[0])), reverse=True)

        for key, value in ordered_items:
            replacement = format_value(value, float_format=float_format)
            for prefix in prefixes:
                for pattern in patterns:
                    token = prefix + pattern.format(key=key)
                    input_string = input_string.replace(token, replacement)

        input_string = input_string.replace('\r', ' ').replace('\n', ' ')
        return input_string

    except Exception as e:
        print_red(f"\n⚠ Error: {e}")
        return ""
|
|
3763
|
+
|
|
3764
|
+
def get_memory_usage() -> float:
    """Return the summed RSS, in MiB, of all processes owned by the current user.

    Uses the values ``psutil.process_iter`` already fetched into ``proc.info``
    instead of querying each process a second time. Entries whose ``uids`` or
    ``memory_info`` could not be read (psutil stores ``None`` for them) are
    skipped rather than raising.
    """
    user_uid = os.getuid()

    total_rss = 0
    for proc in psutil.process_iter(attrs=['memory_info', 'uids']):
        uids = proc.info.get('uids')
        mem = proc.info.get('memory_info')

        # None means psutil was denied access to that attribute.
        if uids is None or mem is None:
            continue

        if uids.real == user_uid:
            total_rss += mem.rss

    memory_usage = float(total_rss / (1024 * 1024))

    return memory_usage
|
|
3773
|
+
|
|
3774
|
+
class MonitorProcess:
    """Context manager that logs CPU and RAM usage while a process is running.

    A daemon thread appends one line per cycle to
    ``eval_nodes_cpu_ram_logs.txt`` in the current run folder for as long as
    the context is active and the process with ``pid`` is running.
    """

    def __init__(self: Any, pid: int, interval: float = 1.0) -> None:
        """Prepare (but do not start) the monitoring thread.

        Args:
            pid: Process id whose liveness controls the monitoring loop.
            interval: Seconds to sleep between two cycles, in addition to the
                5 second sampling window of ``psutil.cpu_percent``.
        """
        self.pid = pid
        self.interval = interval
        self.running = True
        self.thread = threading.Thread(target=self._monitor)
        self.thread.daemon = True

        fool_linter(f"self.thread.daemon was set to {self.thread.daemon}")

    def _monitor(self: Any) -> None:
        """Thread body: write one usage line per cycle until stopped."""
        try:
            _internal_process = psutil.Process(self.pid)
            while self.running and _internal_process.is_running():
                crf = get_current_run_folder()

                if crf and crf != "":
                    log_file_path = os.path.join(crf, "eval_nodes_cpu_ram_logs.txt")

                    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

                    with open(log_file_path, mode="a", encoding="utf-8") as log_file:
                        hostname = socket.gethostname()

                        slurm_job_id = os.getenv("SLURM_JOB_ID")

                        if slurm_job_id:
                            hostname += f"-SLURM-ID-{slurm_job_id}"

                        total_memory = psutil.virtual_memory().total / (1024 * 1024)
                        # Blocks for 5 seconds while sampling.
                        cpu_usage = psutil.cpu_percent(interval=5)

                        memory_usage = get_memory_usage()

                        unix_timestamp = int(time.time())

                        log_file.write(f"\nUnix-Timestamp: {unix_timestamp}, Hostname: {hostname}, CPU: {cpu_usage:.2f}%, RAM: {memory_usage:.2f} MB / {total_memory:.2f} MB\n")
                # NOTE(review): placed at loop level (indentation is not visible
                # in the reviewed text); inside the ``if`` the loop would spin
                # while no run folder is set.
                time.sleep(self.interval)
        except psutil.NoSuchProcess:
            # The process is gone; nothing left to monitor.
            pass

    def __enter__(self: Any) -> Any:
        """Start the monitoring thread and return this instance."""
        self.thread.start()
        return self

    def __exit__(self: Any, exc_type: Any, exc_value: Any, _traceback: Any) -> None:
        """Ask the thread to stop and wait until its current cycle has finished."""
        self.running = False
        self.thread.join()
|
|
3822
|
+
|
|
3823
|
+
def execute_bash_code_log_time(code: str) -> list:
    """Run ``code`` in a shell while ``MonitorProcess`` logs resource usage.

    Args:
        code: Shell command line, executed with ``shell=True``.

    Returns:
        ``[stdout, stderr, exit_code, signal_code]``. When the process was
        terminated by a signal (negative return code), ``exit_code`` is 1 and
        ``signal_code`` is the signal number; otherwise ``signal_code`` is
        ``None``.
    """
    # NOTE(review): the command goes through the shell unescaped; callers must
    # only pass trusted command strings.
    process_item = subprocess.Popen(code, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    with MonitorProcess(process_item.pid):
        stdout, stderr = process_item.communicate()

    returncode = process_item.returncode

    # communicate() never raises CalledProcessError, so the signal translation
    # that used to live in an unreachable except branch is applied here.
    if returncode is not None and returncode < 0:
        return [stdout, stderr, 1, abs(returncode)]

    return [stdout, stderr, returncode, None]
|
|
3840
|
+
|
|
3841
|
+
def execute_bash_code(code: str) -> list:
|
|
3842
|
+
try:
|
|
3843
|
+
result = subprocess.run(
|
|
3450
3844
|
code,
|
|
3451
3845
|
shell=True,
|
|
3452
3846
|
check=True,
|
|
@@ -3559,7 +3953,7 @@ def _add_to_csv_acquire_lock(lockfile: str, dir_path: str) -> bool:
|
|
|
3559
3953
|
time.sleep(wait_time)
|
|
3560
3954
|
max_wait -= wait_time
|
|
3561
3955
|
except Exception as e:
|
|
3562
|
-
|
|
3956
|
+
print_red(f"Lock error: {e}")
|
|
3563
3957
|
return False
|
|
3564
3958
|
return False
|
|
3565
3959
|
|
|
@@ -3632,12 +4026,12 @@ def find_file_paths(_text: str) -> List[str]:
|
|
|
3632
4026
|
def check_file_info(file_path: str) -> str:
|
|
3633
4027
|
if not os.path.exists(file_path):
|
|
3634
4028
|
if not args.tests:
|
|
3635
|
-
|
|
4029
|
+
print_red(f"check_file_info: The file {file_path} does not exist.")
|
|
3636
4030
|
return ""
|
|
3637
4031
|
|
|
3638
4032
|
if not os.access(file_path, os.R_OK):
|
|
3639
4033
|
if not args.tests:
|
|
3640
|
-
|
|
4034
|
+
print_red(f"check_file_info: The file {file_path} is not readable.")
|
|
3641
4035
|
return ""
|
|
3642
4036
|
|
|
3643
4037
|
file_stat = os.stat(file_path)
|
|
@@ -3751,7 +4145,7 @@ def count_defective_nodes(file_path: Union[str, None] = None, entry: Any = None)
|
|
|
3751
4145
|
return sorted(set(entries))
|
|
3752
4146
|
|
|
3753
4147
|
except Exception as e:
|
|
3754
|
-
|
|
4148
|
+
print_red(f"An error has occurred: {e}")
|
|
3755
4149
|
return []
|
|
3756
4150
|
|
|
3757
4151
|
def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]:
|
|
@@ -3762,7 +4156,7 @@ def test_gpu_before_evaluate(return_in_case_of_error: dict) -> Union[None, dict]
|
|
|
3762
4156
|
|
|
3763
4157
|
fool_linter(tmp)
|
|
3764
4158
|
except RuntimeError:
|
|
3765
|
-
|
|
4159
|
+
print_red(f"Node {socket.gethostname()} was detected as faulty. It should have had a GPU, but there is an error initializing the CUDA driver. Adding this node to the --exclude list.")
|
|
3766
4160
|
count_defective_nodes(None, socket.gethostname())
|
|
3767
4161
|
return return_in_case_of_error
|
|
3768
4162
|
except Exception:
|
|
@@ -4233,6 +4627,8 @@ def evaluate(parameters_with_trial_index: dict) -> Optional[Union[int, float, Di
|
|
|
4233
4627
|
trial_index = parameters_with_trial_index["trial_idx"]
|
|
4234
4628
|
submit_time = parameters_with_trial_index["submit_time"]
|
|
4235
4629
|
|
|
4630
|
+
print(f'Trial-Index: {trial_index}')
|
|
4631
|
+
|
|
4236
4632
|
queue_time = abs(int(time.time()) - int(submit_time))
|
|
4237
4633
|
|
|
4238
4634
|
start_nvidia_smi_thread()
|
|
@@ -4470,7 +4866,7 @@ def replace_string_with_params(input_string: str, params: list) -> str:
|
|
|
4470
4866
|
return replaced_string
|
|
4471
4867
|
except AssertionError as e:
|
|
4472
4868
|
error_text = f"Error in replace_string_with_params: {e}"
|
|
4473
|
-
|
|
4869
|
+
print_red(error_text)
|
|
4474
4870
|
raise
|
|
4475
4871
|
|
|
4476
4872
|
return ""
|
|
@@ -4747,9 +5143,7 @@ def get_sixel_graphics_data(_pd_csv: str, _force: bool = False) -> list:
|
|
|
4747
5143
|
_params = [_command, plot, _tmp, plot_type, tmp_file, _width]
|
|
4748
5144
|
data.append(_params)
|
|
4749
5145
|
except Exception as e:
|
|
4750
|
-
|
|
4751
|
-
print_red(f"Error trying to print {plot_type} to CLI: {e}, {tb}")
|
|
4752
|
-
print_debug(f"Error trying to print {plot_type} to CLI: {e}")
|
|
5146
|
+
print_red(f"Error trying to print {plot_type} to CLI: {e}")
|
|
4753
5147
|
|
|
4754
5148
|
return data
|
|
4755
5149
|
|
|
@@ -4940,15 +5334,17 @@ def abandon_job(job: Job, trial_index: int, reason: str) -> bool:
|
|
|
4940
5334
|
if job:
|
|
4941
5335
|
try:
|
|
4942
5336
|
if ax_client:
|
|
4943
|
-
_trial =
|
|
4944
|
-
_trial
|
|
5337
|
+
_trial = get_ax_client_trial(trial_index)
|
|
5338
|
+
if _trial is None:
|
|
5339
|
+
return False
|
|
5340
|
+
|
|
5341
|
+
mark_abandoned(_trial, reason, trial_index)
|
|
4945
5342
|
print_debug(f"abandon_job: removing job {job}, trial_index: {trial_index}")
|
|
4946
5343
|
global_vars["jobs"].remove((job, trial_index))
|
|
4947
5344
|
else:
|
|
4948
5345
|
_fatal_error("ax_client could not be found", 101)
|
|
4949
5346
|
except Exception as e:
|
|
4950
|
-
|
|
4951
|
-
print_debug(f"ERROR in line {get_line_info()}: {e}")
|
|
5347
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
4952
5348
|
return False
|
|
4953
5349
|
job.cancel()
|
|
4954
5350
|
return True
|
|
@@ -4961,23 +5357,120 @@ def abandon_all_jobs() -> None:
|
|
|
4961
5357
|
if not abandoned:
|
|
4962
5358
|
print_debug(f"Job {job} could not be abandoned.")
|
|
4963
5359
|
|
|
4964
|
-
def
|
|
4965
|
-
if
|
|
4966
|
-
|
|
5360
|
+
def write_result_to_trace_file(res: str) -> bool:
    """Write ``res`` to ``optimization_trace.html`` in the current run folder.

    Problems are reported on stderr instead of being raised.

    Returns:
        ``True`` when the text was written and the file closed cleanly;
        ``False`` when ``res`` is ``None``, the file could not be opened,
        nothing was written, or writing/closing failed.
    """
    if res is None:
        sys.stderr.write("Provided result is None, nothing to write\n")
        return False

    target_file = os.path.join(get_current_run_folder(), "optimization_trace.html")

    try:
        handle = open(target_file, "w", encoding="utf-8")
    except OSError as error:
        sys.stderr.write("Unable to open target file for writing\n")
        sys.stderr.write(str(error) + "\n")
        return False

    try:
        char_count = handle.write(str(res))
        handle.flush()

        if not char_count:
            sys.stderr.write("No data was written to the file\n")
            handle.close()
            return False
    except Exception as error:
        sys.stderr.write("Error occurred while writing to file\n")
        sys.stderr.write(str(error) + "\n")
        handle.close()
        return False
    else:
        # Only reached after a successful, non-empty write.
        try:
            handle.close()
        except Exception as error:
            sys.stderr.write("Failed to properly close file\n")
            sys.stderr.write(str(error) + "\n")
            return False

    return True
|
|
5397
|
+
|
|
5398
|
+
def render(plot_config: AxPlotConfig) -> None:
    """Write the plot in ``plot_config`` as a standalone Plotly HTML snippet.

    The plot data is embedded into a small HTML/JavaScript page (including
    helpers that decode base64 ``bdata`` arrays in the traces) and saved via
    ``write_result_to_trace_file``. Nothing is written when ``plot_config`` is
    ``None`` or carries no ``data``.
    """
    # The previous check was ``"data" not in plot_config``. The next statement
    # reads ``plot_config.data`` as an attribute, so the membership test looked
    # at the object's values rather than its field names.
    # NOTE(review): assumes AxPlotConfig is namedtuple-like, where that test is
    # always true and the function never wrote anything — confirm against Ax.
    plot_data = getattr(plot_config, "data", None)
    if plot_data is None:
        return None

    # NOTE(review): whitespace inside this JavaScript block was not visible in
    # the reviewed text; it is laid out conventionally here.
    repair_funcs = """
function decodeBData(obj) {
    if (!obj || typeof obj !== "object") {
        return obj;
    }

    if (obj.bdata && obj.dtype) {
        var binary_string = atob(obj.bdata);
        var len = binary_string.length;
        var bytes = new Uint8Array(len);

        for (var i = 0; i < len; i++) {
            bytes[i] = binary_string.charCodeAt(i);
        }

        switch (obj.dtype) {
            case "i1": return Array.from(new Int8Array(bytes.buffer));
            case "i2": return Array.from(new Int16Array(bytes.buffer));
            case "i4": return Array.from(new Int32Array(bytes.buffer));
            case "f4": return Array.from(new Float32Array(bytes.buffer));
            case "f8": return Array.from(new Float64Array(bytes.buffer));
            default:
                console.error("Unknown dtype:", obj.dtype);
                return [];
        }
    }

    return obj;
}

function repairTraces(traces) {
    var fixed = [];

    for (var i = 0; i < traces.length; i++) {
        var t = traces[i];

        if (t.x) {
            t.x = decodeBData(t.x);
        }

        if (t.y) {
            t.y = decodeBData(t.y);
        }

        fixed.push(t);
    }

    return fixed;
}
"""

    # The data is embedded via its Python string form; the page defines
    # ``True``/``False`` constants so those literals are valid JavaScript.
    res: str = str(plot_data)

    res = f"<div id='plot' style='width:100%;height:600px;'></div>\n<script type='text/javascript' src='https://cdn.plot.ly/plotly-latest.min.js'></script><script>{repair_funcs}\nconst True = true;\nconst False = false;\nconst data = {res};\ndata.data = repairTraces(data.data);\nPlotly.newPlot(document.getElementById('plot'), data.data, data.layout);</script>"

    write_result_to_trace_file(res)

    return None
|
|
4977
5462
|
|
|
4978
5463
|
def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None) -> None:
|
|
4979
5464
|
global END_PROGRAM_RAN
|
|
4980
5465
|
|
|
5466
|
+
#dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].X)
|
|
5467
|
+
#dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.training_data[0].Y)
|
|
5468
|
+
#dier(global_gs.current_node.generator_specs[0]._fitted_adapter.generator._surrogate.outcomes)
|
|
5469
|
+
|
|
5470
|
+
if ax_client is not None:
|
|
5471
|
+
if len(arg_result_names) == 1:
|
|
5472
|
+
render(ax_client.get_optimization_trace())
|
|
5473
|
+
|
|
4981
5474
|
wait_for_jobs_to_complete()
|
|
4982
5475
|
|
|
4983
5476
|
show_pareto_or_error_msg(get_current_run_folder(), arg_result_names)
|
|
@@ -5011,7 +5504,7 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5011
5504
|
_exit = new_exit
|
|
5012
5505
|
except (SignalUSR, SignalINT, SignalCONT, KeyboardInterrupt):
|
|
5013
5506
|
print_red("\n⚠ You pressed CTRL+C or a signal was sent. Program execution halted while ending program.")
|
|
5014
|
-
|
|
5507
|
+
print_red("\n⚠ KeyboardInterrupt signal was sent. Ending program will still run.")
|
|
5015
5508
|
new_exit = show_end_table_and_save_end_files()
|
|
5016
5509
|
if new_exit > 0:
|
|
5017
5510
|
_exit = new_exit
|
|
@@ -5032,19 +5525,29 @@ def end_program(_force: Optional[bool] = False, exit_code: Optional[int] = None)
|
|
|
5032
5525
|
|
|
5033
5526
|
my_exit(_exit)
|
|
5034
5527
|
|
|
5528
|
+
def save_ax_client_to_json_file(checkpoint_filepath: str) -> None:
    """Serialize the global ``ax_client`` to ``checkpoint_filepath``.

    When no client is set, ``my_exit(101)`` is called instead and nothing is saved.
    """
    if ax_client:
        ax_client.save_to_json_file(checkpoint_filepath)
        return None

    # No usable client: hand over to the exit helper.
    my_exit(101)

    return None
|
|
5537
|
+
|
|
5035
5538
|
def save_checkpoint(trial_nr: int = 0, eee: Union[None, str, Exception] = None) -> None:
|
|
5036
5539
|
if trial_nr > 3:
|
|
5037
5540
|
if eee:
|
|
5038
|
-
|
|
5541
|
+
print_red(f"Error during saving checkpoint: {eee}")
|
|
5039
5542
|
else:
|
|
5040
|
-
|
|
5543
|
+
print_red("Error during saving checkpoint")
|
|
5041
5544
|
return
|
|
5042
5545
|
|
|
5043
5546
|
try:
|
|
5044
5547
|
checkpoint_filepath = get_state_file_name('checkpoint.json')
|
|
5045
5548
|
|
|
5046
5549
|
if ax_client:
|
|
5047
|
-
|
|
5550
|
+
save_ax_client_to_json_file(checkpoint_filepath)
|
|
5048
5551
|
else:
|
|
5049
5552
|
_fatal_error("Something went wrong using the ax_client", 101)
|
|
5050
5553
|
except Exception as e:
|
|
@@ -5207,7 +5710,7 @@ def parse_equation_item(comparer_found: bool, item: str, parsed: list, parsed_or
|
|
|
5207
5710
|
})
|
|
5208
5711
|
elif item in [">=", "<="]:
|
|
5209
5712
|
if comparer_found:
|
|
5210
|
-
|
|
5713
|
+
print_red("There is already one comparison operator! Cannot have more than one in an equation!")
|
|
5211
5714
|
return_totally = True
|
|
5212
5715
|
comparer_found = True
|
|
5213
5716
|
|
|
@@ -5418,20 +5921,9 @@ def check_equation(variables: list, equation: str) -> Union[str, bool]:
|
|
|
5418
5921
|
def set_objectives() -> dict:
|
|
5419
5922
|
objectives = {}
|
|
5420
5923
|
|
|
5421
|
-
|
|
5422
|
-
|
|
5423
|
-
|
|
5424
|
-
if "=" in rn:
|
|
5425
|
-
key, value = rn.split('=', 1)
|
|
5426
|
-
else:
|
|
5427
|
-
key = rn
|
|
5428
|
-
value = ""
|
|
5429
|
-
|
|
5430
|
-
if value not in ["min", "max"]:
|
|
5431
|
-
if value:
|
|
5432
|
-
print_yellow(f"Value '{value}' for --result_names {rn} is not a valid value. Must be min or max. Will be set to min.")
|
|
5433
|
-
|
|
5434
|
-
value = "min"
|
|
5924
|
+
k = 0
|
|
5925
|
+
for key in arg_result_names:
|
|
5926
|
+
value = arg_result_min_or_max[k]
|
|
5435
5927
|
|
|
5436
5928
|
_min = True
|
|
5437
5929
|
|
|
@@ -5439,6 +5931,7 @@ def set_objectives() -> dict:
|
|
|
5439
5931
|
_min = False
|
|
5440
5932
|
|
|
5441
5933
|
objectives[key] = ObjectiveProperties(minimize=_min)
|
|
5934
|
+
k = k + 1
|
|
5442
5935
|
|
|
5443
5936
|
return objectives
|
|
5444
5937
|
|
|
@@ -5660,7 +6153,7 @@ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_param
|
|
|
5660
6153
|
|
|
5661
6154
|
replace_parameters_for_continued_jobs(args.parameter, cli_params_experiment_parameters)
|
|
5662
6155
|
|
|
5663
|
-
|
|
6156
|
+
save_ax_client_to_json_file(original_ax_client_file)
|
|
5664
6157
|
|
|
5665
6158
|
load_original_generation_strategy(original_ax_client_file)
|
|
5666
6159
|
load_ax_client_from_experiment_parameters()
|
|
@@ -5690,6 +6183,109 @@ def load_from_checkpoint(continue_previous_job: str, cli_params_experiment_param
|
|
|
5690
6183
|
|
|
5691
6184
|
return experiment_args, gpu_string, gpu_color
|
|
5692
6185
|
|
|
6186
|
+
def get_experiment_args_import_python_script() -> str:
    """Return the import preamble used at the top of the generated debug script.

    The text ends with one empty line so further code can be appended directly.
    """
    preamble_lines = (
        "from ax.service.ax_client import AxClient, ObjectiveProperties",
        "from ax.adapter.registry import Generators",
        "import random",
        "",
        "",
    )

    return "\n".join(preamble_lines)
|
|
6193
|
+
|
|
6194
|
+
def get_generate_and_test_random_function_str() -> str:
    """Return Python source that drives the debug ``ax_client`` with random results.

    The generated function requests ``n`` trials from ``ax_client`` and completes each
    one with a ``random.uniform(0, 1)`` value per name in ``arg_result_names``. The
    generated script then calls it with ``args.num_random_steps + 1``.
    """
    # NOTE(review): the whitespace inside this separator and the template below was
    # re-indented so the emitted code is valid Python - confirm against the real file.
    raw_data_entries = ",\n                ".join(
        f'"{name}": random.uniform(0, 1)' for name in arg_result_names
    )

    return f"""
def generate_and_test_random_parameters(n: int) -> None:
    for _ in range(n):
        print("======================================")
        parameters, trial_index = ax_client.get_next_trial()
        print("Trial Index:", trial_index)
        print("Suggested parameters:", parameters)

        ax_client.complete_trial(
            trial_index=trial_index,
            raw_data={{
                {raw_data_entries}
            }}
        )

generate_and_test_random_parameters({args.num_random_steps + 1})
"""
|
|
6216
|
+
|
|
6217
|
+
def get_global_gs_string() -> str:
    """Return Python source that defines ``global_gs`` for the debug script.

    The generated strategy has a SOBOL step with ``args.num_random_steps`` trials
    followed by an unlimited step (``num_trials=-1``) using ``Generators.<args.model>``.
    """
    # Only pass a seed to the SOBOL step when one was given on the command line.
    seed_str = ""
    if args.seed is not None:
        seed_str = f"model_kwargs={{'seed': {args.seed}}},"

    # NOTE(review): template indentation was reconstructed - confirm against the real file.
    return f"""from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy

global_gs = GenerationStrategy(
    steps=[
        GenerationStep(
            generator=Generators.SOBOL,
            num_trials={args.num_random_steps},
            max_parallelism=5,
            {seed_str}
        ),
        GenerationStep(
            generator=Generators.{args.model},
            num_trials=-1,
            max_parallelism=5,
        ),
    ]
)
"""
|
|
6240
|
+
|
|
6241
|
+
def get_debug_ax_client_str() -> str:
    """Return Python source that creates ``ax_client`` from ``global_gs`` in the debug script."""
    # NOTE(review): template indentation was reconstructed - confirm against the real file.
    return """
ax_client = AxClient(
    verbose_logging=True,
    enforce_sequential_optimization=False,
    generation_strategy=global_gs
)
"""
|
|
6249
|
+
|
|
6250
|
+
def write_ax_debug_python_code(experiment_args: dict) -> None:
    """Write a standalone ``debug.py`` reproducing the experiment setup into the run folder.

    The script is assembled from the import preamble, the generation strategy, the
    ``AxClient`` construction, the pretty-printed ``experiment_args`` and a driver that
    feeds random results. Nothing is written for a custom generation strategy or for a
    model listed in ``uncontinuable_models``. Write errors are reported, not raised.

    Args:
        experiment_args: Keyword arguments later passed to ``create_experiment``.
    """
    if args.generation_strategy:
        print_debug("Cannot write debug code for custom generation_strategy")
        return None

    if args.model in uncontinuable_models:
        print_debug(f"Cannot write debug code for uncontinuable model {args.model}")
        return None

    # The parts are evaluated and concatenated in script order.
    python_code = "".join([
        get_experiment_args_import_python_script(),
        get_global_gs_string(),
        get_debug_ax_client_str(),
        "experiment_args = " + pformat(experiment_args, width=120, compact=False),
        "\nax_client.create_experiment(**experiment_args)\n",
        get_generate_and_test_random_function_str(),
    ])

    file_path = f"{get_current_run_folder()}/debug.py"

    try:
        print_debug(python_code)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(python_code)
    except Exception as e:
        # Best effort: the debug script is a convenience and must not abort the run.
        print_red(f"Error while writing {file_path}: {e}")

    return None
|
|
6276
|
+
|
|
6277
|
+
def create_ax_client_experiment(experiment_args: dict) -> None:
    """Dump the debug script, then create the experiment on the global ``ax_client``.

    When no client is set, ``my_exit(101)`` is called and nothing else happens.
    """
    if ax_client:
        # Write the debug script first so it exists even if experiment creation fails.
        write_ax_debug_python_code(experiment_args)

        ax_client.create_experiment(**experiment_args)

        return None

    my_exit(101)

    return None
|
|
6288
|
+
|
|
5693
6289
|
def create_new_experiment() -> Tuple[dict, str, str]:
|
|
5694
6290
|
if ax_client is None:
|
|
5695
6291
|
print_red("create_new_experiment: ax_client is None")
|
|
@@ -5719,7 +6315,7 @@ def create_new_experiment() -> Tuple[dict, str, str]:
|
|
|
5719
6315
|
experiment_args = set_experiment_constraints(get_constraints(), experiment_args, experiment_parameters)
|
|
5720
6316
|
|
|
5721
6317
|
try:
|
|
5722
|
-
|
|
6318
|
+
create_ax_client_experiment(experiment_args)
|
|
5723
6319
|
new_metrics = [Metric(k) for k in arg_result_names if k not in ax_client.metric_names]
|
|
5724
6320
|
ax_client.experiment.add_tracking_metrics(new_metrics)
|
|
5725
6321
|
except AssertionError as error:
|
|
@@ -5733,7 +6329,7 @@ def create_new_experiment() -> Tuple[dict, str, str]:
|
|
|
5733
6329
|
|
|
5734
6330
|
return experiment_args, gpu_string, gpu_color
|
|
5735
6331
|
|
|
5736
|
-
def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[
|
|
6332
|
+
def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict | list]) -> Optional[Tuple[dict, str, str]]:
|
|
5737
6333
|
continue_previous_job = args.worker_generator_path or args.continue_previous_job
|
|
5738
6334
|
|
|
5739
6335
|
check_ax_client()
|
|
@@ -5743,7 +6339,7 @@ def get_experiment_parameters(cli_params_experiment_parameters: Optional[dict |
|
|
|
5743
6339
|
else:
|
|
5744
6340
|
experiment_args, gpu_string, gpu_color = create_new_experiment()
|
|
5745
6341
|
|
|
5746
|
-
return
|
|
6342
|
+
return experiment_args, gpu_string, gpu_color
|
|
5747
6343
|
|
|
5748
6344
|
def get_type_short(typename: str) -> str:
|
|
5749
6345
|
if typename == "RangeParameter":
|
|
@@ -5802,7 +6398,6 @@ def parse_single_experiment_parameter_table(classic_params: Optional[Union[list,
|
|
|
5802
6398
|
_upper = param["bounds"][1]
|
|
5803
6399
|
|
|
5804
6400
|
_possible_int_lower = str(helpers.to_int_when_possible(_lower))
|
|
5805
|
-
#print(f"name: {_name}, _possible_int_lower: {_possible_int_lower}, lower: {_lower}")
|
|
5806
6401
|
_possible_int_upper = str(helpers.to_int_when_possible(_upper))
|
|
5807
6402
|
|
|
5808
6403
|
rows.append([_name, _short_type, _possible_int_lower, _possible_int_upper, "", value_type, log_scale])
|
|
@@ -5879,19 +6474,62 @@ def print_ax_parameter_constraints_table(experiment_args: dict) -> None:
|
|
|
5879
6474
|
|
|
5880
6475
|
return None
|
|
5881
6476
|
|
|
6477
|
+
def check_base_for_print_overview() -> Optional[bool]:
    """Check the preconditions for printing the result-names overview.

    Warns when ``--result_names`` was given for a continued job. Returns ``True`` when
    ``ax_client``, its experiment and its optimization config are all set; otherwise
    prints which one is missing and returns ``None``.
    """
    if (
        args.continue_previous_job is not None
        and arg_result_names is not None
        and len(arg_result_names) != 0
        and original_result_names is not None
        and len(original_result_names) != 0
    ):
        print_yellow("--result_names will be ignored in continued jobs. The result names from the previous job will be used.")

    # Find the first missing link in the chain client -> experiment -> optimization config.
    problem = None
    if ax_client is None:
        problem = "ax_client was None"
    elif ax_client.experiment is None:
        problem = "ax_client.experiment was None"
    elif ax_client.experiment.optimization_config is None:
        problem = "ax_client.experiment.optimization_config was None"

    if problem is not None:
        print_red(problem)
        return None

    return True
|
|
6494
|
+
|
|
6495
|
+
def get_config_objectives() -> Any:
    """Return the objectives of the current experiment's optimization config.

    Returns:
        For a multi-objective problem, the ``objectives`` of the config's objective.
        For a single-objective problem, a one-element list holding the objective.
        ``None`` when the client, experiment, optimization config or the
        multi-objective list is missing.
    """
    if not ax_client:
        # The message names this function; it previously named create_new_experiment.
        print_red("get_config_objectives: ax_client is None")
        my_exit(101)

        return None

    experiment = ax_client.experiment
    if not (experiment and experiment.optimization_config):
        print_debug("ax_client.experiment or optimization_config was None")
        return None

    opt_config = experiment.optimization_config

    if not opt_config.is_moo_problem:
        return [opt_config.objective]

    objective = getattr(opt_config, "objective", None)
    if objective and getattr(objective, "objectives", None) is not None:
        return objective.objectives

    print_debug("ax_client.experiment.optimization_config.objective was None")

    return None
|
|
6518
|
+
|
|
5882
6519
|
def print_result_names_overview_table() -> None:
|
|
5883
6520
|
if not ax_client:
|
|
5884
6521
|
_fatal_error("Tried to access ax_client in print_result_names_overview_table, but it failed, because the ax_client was not defined.", 101)
|
|
5885
6522
|
|
|
5886
6523
|
return None
|
|
5887
6524
|
|
|
5888
|
-
if
|
|
5889
|
-
|
|
6525
|
+
if check_base_for_print_overview() is None:
|
|
6526
|
+
return None
|
|
5890
6527
|
|
|
5891
|
-
|
|
5892
|
-
|
|
5893
|
-
|
|
5894
|
-
config_objectives
|
|
6528
|
+
config_objectives = get_config_objectives()
|
|
6529
|
+
|
|
6530
|
+
if config_objectives is None:
|
|
6531
|
+
print_red("config_objectives not found")
|
|
6532
|
+
return None
|
|
5895
6533
|
|
|
5896
6534
|
res_names = []
|
|
5897
6535
|
res_min_max = []
|
|
@@ -5986,11 +6624,13 @@ def print_overview_tables(classic_params: Optional[Union[list, dict]], experimen
|
|
|
5986
6624
|
print_result_names_overview_table()
|
|
5987
6625
|
|
|
5988
6626
|
def update_progress_bar(nr: int) -> None:
    """Advance the global progress bar by ``nr`` after calling ``log_data()``.

    A missing progress bar or a failing update is reported via ``print_red``.
    """
    log_data()

    if progress_bar is None:
        print_red("update_progress_bar: progress_bar was None")
        return

    try:
        progress_bar.update(nr)
    except Exception as e:
        print_red(f"Error updating progress bar: {e}")
|
|
5996
6636
|
|
|
@@ -5998,19 +6638,17 @@ def get_current_model_name() -> str:
|
|
|
5998
6638
|
if overwritten_to_random:
|
|
5999
6639
|
return "Random*"
|
|
6000
6640
|
|
|
6641
|
+
gs_model = "unknown model"
|
|
6642
|
+
|
|
6001
6643
|
if ax_client:
|
|
6002
6644
|
try:
|
|
6003
6645
|
if args.generation_strategy:
|
|
6004
|
-
idx = getattr(
|
|
6646
|
+
idx = getattr(global_gs, "current_step_index", None)
|
|
6005
6647
|
if isinstance(idx, int):
|
|
6006
6648
|
if 0 <= idx < len(generation_strategy_names):
|
|
6007
6649
|
gs_model = generation_strategy_names[int(idx)]
|
|
6008
|
-
else:
|
|
6009
|
-
gs_model = "unknown model"
|
|
6010
|
-
else:
|
|
6011
|
-
gs_model = "unknown model"
|
|
6012
6650
|
else:
|
|
6013
|
-
gs_model = getattr(
|
|
6651
|
+
gs_model = getattr(global_gs, "current_node_name", "unknown model")
|
|
6014
6652
|
|
|
6015
6653
|
if gs_model:
|
|
6016
6654
|
return str(gs_model)
|
|
@@ -6117,7 +6755,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
|
|
|
6117
6755
|
global _last_count_time, _last_count_result
|
|
6118
6756
|
|
|
6119
6757
|
now = int(time.time())
|
|
6120
|
-
if _last_count_result != (0, "") and now - _last_count_time <
|
|
6758
|
+
if _last_count_result != (0, "") and now - _last_count_time < 5:
|
|
6121
6759
|
return _last_count_result
|
|
6122
6760
|
|
|
6123
6761
|
_len = len(global_vars["jobs"])
|
|
@@ -6139,6 +6777,7 @@ def count_jobs_in_squeue() -> tuple[int, str]:
|
|
|
6139
6777
|
check=True,
|
|
6140
6778
|
text=True
|
|
6141
6779
|
)
|
|
6780
|
+
|
|
6142
6781
|
if "slurm_load_jobs error" in result.stderr:
|
|
6143
6782
|
_last_count_result = (_len, "Detected slurm_load_jobs error in stderr.")
|
|
6144
6783
|
_last_count_time = now
|
|
@@ -6179,6 +6818,8 @@ def log_worker_numbers() -> None:
|
|
|
6179
6818
|
if len(WORKER_PERCENTAGE_USAGE) == 0 or WORKER_PERCENTAGE_USAGE[len(WORKER_PERCENTAGE_USAGE) - 1] != this_values:
|
|
6180
6819
|
WORKER_PERCENTAGE_USAGE.append(this_values)
|
|
6181
6820
|
|
|
6821
|
+
write_worker_usage()
|
|
6822
|
+
|
|
6182
6823
|
def get_slurm_in_brackets(in_brackets: list) -> list:
|
|
6183
6824
|
if is_slurm_job():
|
|
6184
6825
|
workers_strings = get_workers_string()
|
|
@@ -6263,6 +6904,8 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
|
|
|
6263
6904
|
global last_progress_bar_desc
|
|
6264
6905
|
global last_progress_bar_refresh_time
|
|
6265
6906
|
|
|
6907
|
+
log_data()
|
|
6908
|
+
|
|
6266
6909
|
if isinstance(new_msgs, str):
|
|
6267
6910
|
new_msgs = [new_msgs]
|
|
6268
6911
|
|
|
@@ -6280,7 +6923,7 @@ def progressbar_description(new_msgs: Union[str, List[str]] = []) -> None:
|
|
|
6280
6923
|
print_red("Cannot update progress bar! It is None.")
|
|
6281
6924
|
|
|
6282
6925
|
def clean_completed_jobs() -> None:
|
|
6283
|
-
job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail"]
|
|
6926
|
+
job_states_to_be_removed = ["early_stopped", "abandoned", "cancelled", "timeout", "interrupted", "failed", "preempted", "node_fail", "boot_fail", "finished"]
|
|
6284
6927
|
job_states_to_be_ignored = ["ready", "completed", "unknown", "pending", "running", "completing", "out_of_memory", "requeued", "resv_del_hold"]
|
|
6285
6928
|
|
|
6286
6929
|
for job, trial_index in global_vars["jobs"][:]:
|
|
@@ -6425,7 +7068,7 @@ def get_generation_node_for_index(
|
|
|
6425
7068
|
|
|
6426
7069
|
return generation_node
|
|
6427
7070
|
except Exception as e:
|
|
6428
|
-
|
|
7071
|
+
print_red(f"Error while get_generation_node_for_index: {e}")
|
|
6429
7072
|
return "MANUAL"
|
|
6430
7073
|
|
|
6431
7074
|
def _get_generation_node_for_index_index_valid(
|
|
@@ -6598,11 +7241,21 @@ def check_ax_client() -> None:
|
|
|
6598
7241
|
if ax_client is None or not ax_client:
|
|
6599
7242
|
_fatal_error("insert_job_into_ax_client: ax_client was not defined where it should have been", 101)
|
|
6600
7243
|
|
|
7244
|
+
def attach_ax_client_data(arm_params: dict) -> Optional[Tuple[Any, int]]:
    """Attach ``arm_params`` as a trial on the global ``ax_client``.

    Returns whatever ``ax_client.attach_trial`` returns. When no client is set,
    ``my_exit(101)`` is called and ``None`` is returned.
    """
    if ax_client:
        return ax_client.attach_trial(arm_params)

    my_exit(101)

    return None
|
|
7253
|
+
|
|
6601
7254
|
def attach_trial(arm_params: dict) -> Tuple[Any, int]:
    """Attach ``arm_params`` as a trial and return the resulting tuple.

    Raises:
        RuntimeError: If ``ax_client`` is ``None``, or if the attach helper did not
            return a tuple with at least two elements.
    """
    if ax_client is None:
        raise RuntimeError("attach_trial: ax_client was empty")

    attached = attach_ax_client_data(arm_params)

    if isinstance(attached, tuple) and len(attached) >= 2:
        return attached

    raise RuntimeError("attach_trial didn't return the expected tuple")
|
|
@@ -6633,7 +7286,7 @@ def complete_trial_if_result(trial_idx: int, result: dict, __status: Optional[An
|
|
|
6633
7286
|
is_ok = False
|
|
6634
7287
|
|
|
6635
7288
|
if is_ok:
|
|
6636
|
-
|
|
7289
|
+
complete_ax_client_trial(trial_idx, result)
|
|
6637
7290
|
update_status(__status, base_str, "Completed trial")
|
|
6638
7291
|
else:
|
|
6639
7292
|
print_debug("Empty job encountered")
|
|
@@ -7030,7 +7683,7 @@ def get_parameters_from_outfile(stdout_path: str) -> Union[None, dict, str]:
|
|
|
7030
7683
|
if not args.tests:
|
|
7031
7684
|
original_print(f"get_parameters_from_outfile: The file '{stdout_path}' was not found.")
|
|
7032
7685
|
except Exception as e:
|
|
7033
|
-
|
|
7686
|
+
print_red(f"get_parameters_from_outfile: There was an error: {e}")
|
|
7034
7687
|
|
|
7035
7688
|
return None
|
|
7036
7689
|
|
|
@@ -7048,7 +7701,7 @@ def get_hostname_from_outfile(stdout_path: Optional[str]) -> Optional[str]:
|
|
|
7048
7701
|
original_print(f"The file '{stdout_path}' was not found.")
|
|
7049
7702
|
return None
|
|
7050
7703
|
except Exception as e:
|
|
7051
|
-
|
|
7704
|
+
print_red(f"There was an error: {e}")
|
|
7052
7705
|
return None
|
|
7053
7706
|
|
|
7054
7707
|
def add_to_global_error_list(msg: str) -> None:
|
|
@@ -7084,22 +7737,70 @@ def mark_trial_as_failed(trial_index: int, _trial: Any) -> None:
|
|
|
7084
7737
|
|
|
7085
7738
|
return None
|
|
7086
7739
|
|
|
7087
|
-
|
|
7740
|
+
log_ax_client_trial_failure(trial_index)
|
|
7088
7741
|
_trial.mark_failed(unsafe=True)
|
|
7089
7742
|
except ValueError as e:
|
|
7090
7743
|
print_debug(f"mark_trial_as_failed error: {e}")
|
|
7091
7744
|
|
|
7092
7745
|
return None
|
|
7093
7746
|
|
|
7094
|
-
def check_valid_result(result: Union[None,
|
|
7747
|
+
def check_valid_result(result: Union[None, dict]) -> bool:
    """Return whether ``result`` contains none of the "nothing found" marker values.

    Nested dicts, lists, tuples and sets are searched recursively. ``None`` as a result
    and any error during the check yield ``False``.
    """
    not_found_markers = [
        VAL_IF_NOTHING_FOUND,
        -VAL_IF_NOTHING_FOUND,
        -99999999999999997168788049560464200849936328366177157906432,
        99999999999999997168788049560464200849936328366177157906432
    ]

    def flatten_values(obj: Any) -> Any:
        # Collect every leaf value of an arbitrarily nested container.
        leaves = []
        try:
            if isinstance(obj, dict):
                children = obj.values()
            elif isinstance(obj, (list, tuple, set)):
                children = obj
            else:
                leaves.append(obj)
                children = ()

            for child in children:
                leaves.extend(flatten_values(child))
        except Exception as e:
            print_red(f"Error while flattening values: {e}")
        return leaves

    if result is None:
        return False

    try:
        return not any(leaf in not_found_markers for leaf in flatten_values(result))
    except Exception as e:
        print_red(f"Error while checking result validity: {e}")
        return False
|
|
7782
|
+
|
|
7783
|
+
def update_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
    """Update the data of trial ``trial_idx`` with ``result`` as raw data.

    When no global ``ax_client`` is set, ``my_exit(101)`` is called instead.
    """
    if ax_client:
        get_trial_by_index(trial_idx).update_trial_data(raw_data=result)

        return None

    my_exit(101)

    return None
|
|
7794
|
+
|
|
7795
|
+
def complete_ax_client_trial(trial_idx: int, result: Union[list, dict]) -> None:
    """Complete trial ``trial_idx`` on the global ``ax_client`` with ``result`` as raw data.

    When no client is set, ``my_exit(101)`` is called instead.
    """
    if ax_client:
        ax_client.complete_trial(trial_index=trial_idx, raw_data=result)

        return None

    my_exit(101)

    return None
|
|
7103
7804
|
|
|
7104
7805
|
def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -> None:
|
|
7105
7806
|
if ax_client is None:
|
|
@@ -7108,25 +7809,64 @@ def _finish_job_core_helper_complete_trial(trial_index: int, raw_result: dict) -
|
|
|
7108
7809
|
|
|
7109
7810
|
try:
|
|
7110
7811
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result}...")
|
|
7111
|
-
|
|
7812
|
+
complete_ax_client_trial(trial_index, raw_result)
|
|
7112
7813
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result}... Done!")
|
|
7113
7814
|
except ax.exceptions.core.UnsupportedError as e:
|
|
7114
7815
|
if f"{e}":
|
|
7115
7816
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure. Trying to update trial...")
|
|
7116
|
-
|
|
7817
|
+
update_ax_client_trial(trial_index, raw_result)
|
|
7117
7818
|
print_debug(f"Completing trial: {trial_index} with result: {raw_result} after failure... Done!")
|
|
7118
7819
|
else:
|
|
7119
7820
|
_fatal_error(f"Error completing trial: {e}", 234)
|
|
7120
7821
|
|
|
7121
7822
|
return None
|
|
7122
7823
|
|
|
7123
|
-
def
|
|
7824
|
+
def format_result_for_display(result: dict) -> str:
    """Format a job result as a short, human-readable string.

    Each entry becomes ``"<key>: <value>"``; a two-element list or tuple is read as
    ``(value, SEM)`` and the SEM is appended when it is not ``None``. Numbers are shown
    with six decimals, NaN and infinities symbolically. Entries are joined by ``", "``.
    A non-dict ``result`` is formatted as a single value. Formatting errors are
    embedded in the output instead of being raised.

    Args:
        result: Mapping of result name to a value or a ``(value, SEM)`` pair.

    Returns:
        The formatted string; empty for an empty dict.
    """
    def safe_float(v: Any) -> str:
        try:
            if v is None:
                return "None"
            if isinstance(v, float):
                if math.isnan(v):
                    return "NaN"
                if math.isinf(v):
                    return "∞" if v > 0 else "-∞"
                return f"{v:.6f}"
            if isinstance(v, int):
                # Ints are never NaN or infinite. Ints beyond float range cannot be
                # formatted with ``.6f``, so show them verbatim instead of an error.
                try:
                    return f"{v:.6f}"
                except OverflowError:
                    return str(v)
            return str(v)
        except Exception as e:
            return f"<error: {e}>"

    try:
        if not isinstance(result, dict):
            return safe_float(result)

        parts = []
        for key, val in result.items():
            try:
                if isinstance(val, (list, tuple)) and len(val) == 2:
                    main, sem = val
                    main_str = safe_float(main)
                    if sem is not None:
                        parts.append(f"{key}: {main_str} (SEM: {safe_float(sem)})")
                    else:
                        parts.append(f"{key}: {main_str}")
                else:
                    parts.append(f"{key}: {safe_float(val)}")
            except Exception as e:
                parts.append(f"{key}: <error: {e}>")

        return ", ".join(parts)
    except Exception as e:
        return f"<error formatting result: {e}>"
|
|
7862
|
+
|
|
7863
|
+
def _finish_job_core_helper_mark_success(_trial: ax.core.trial.Trial, result: dict) -> None:
|
|
7124
7864
|
print_debug(f"Marking trial {_trial} as completed")
|
|
7125
7865
|
_trial.mark_completed(unsafe=True)
|
|
7126
7866
|
|
|
7127
7867
|
succeeded_jobs(1)
|
|
7128
7868
|
|
|
7129
|
-
progressbar_description(f"new result: {result}")
|
|
7869
|
+
progressbar_description(f"new result: {format_result_for_display(result)}")
|
|
7130
7870
|
update_progress_bar(1)
|
|
7131
7871
|
|
|
7132
7872
|
save_results_csv()
|
|
@@ -7140,8 +7880,8 @@ def _finish_job_core_helper_mark_failure(job: Any, trial_index: int, _trial: Any
|
|
|
7140
7880
|
if job:
|
|
7141
7881
|
try:
|
|
7142
7882
|
progressbar_description("job_failed")
|
|
7143
|
-
|
|
7144
|
-
_trial
|
|
7883
|
+
log_ax_client_trial_failure(trial_index)
|
|
7884
|
+
mark_trial_as_failed(trial_index, _trial)
|
|
7145
7885
|
except Exception as e:
|
|
7146
7886
|
print_red(f"\nERROR while trying to mark job as failure: {e}")
|
|
7147
7887
|
job.cancel()
|
|
@@ -7158,16 +7898,16 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7158
7898
|
result = job.result()
|
|
7159
7899
|
print_debug(f"finish_job_core: trial-index: {trial_index}, job.result(): {result}, state: {state_from_job(job)}")
|
|
7160
7900
|
|
|
7161
|
-
raw_result = result
|
|
7162
|
-
result_keys = list(result.keys())
|
|
7163
|
-
result = result[result_keys[0]]
|
|
7164
7901
|
this_jobs_finished += 1
|
|
7165
7902
|
|
|
7166
7903
|
if ax_client:
|
|
7167
|
-
_trial =
|
|
7904
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7905
|
+
|
|
7906
|
+
if _trial is None:
|
|
7907
|
+
return 0
|
|
7168
7908
|
|
|
7169
7909
|
if check_valid_result(result):
|
|
7170
|
-
_finish_job_core_helper_complete_trial(trial_index,
|
|
7910
|
+
_finish_job_core_helper_complete_trial(trial_index, result)
|
|
7171
7911
|
|
|
7172
7912
|
try:
|
|
7173
7913
|
_finish_job_core_helper_mark_success(_trial, result)
|
|
@@ -7175,7 +7915,7 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7175
7915
|
if len(arg_result_names) > 1 and count_done_jobs() > 1 and not job_calculate_pareto_front(get_current_run_folder(), True):
|
|
7176
7916
|
print_red("job_calculate_pareto_front post job failed")
|
|
7177
7917
|
except Exception as e:
|
|
7178
|
-
|
|
7918
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
7179
7919
|
else:
|
|
7180
7920
|
_finish_job_core_helper_mark_failure(job, trial_index, _trial)
|
|
7181
7921
|
else:
|
|
@@ -7184,6 +7924,8 @@ def finish_job_core(job: Any, trial_index: int, this_jobs_finished: int) -> int:
|
|
|
7184
7924
|
print_debug(f"finish_job_core: removing job {job}, trial_index: {trial_index}")
|
|
7185
7925
|
global_vars["jobs"].remove((job, trial_index))
|
|
7186
7926
|
|
|
7927
|
+
log_data()
|
|
7928
|
+
|
|
7187
7929
|
force_live_share()
|
|
7188
7930
|
|
|
7189
7931
|
return this_jobs_finished
|
|
@@ -7196,11 +7938,14 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
|
|
|
7196
7938
|
if job:
|
|
7197
7939
|
try:
|
|
7198
7940
|
progressbar_description("job_failed")
|
|
7199
|
-
_trial =
|
|
7200
|
-
|
|
7941
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7942
|
+
if _trial is None:
|
|
7943
|
+
return None
|
|
7944
|
+
|
|
7945
|
+
log_ax_client_trial_failure(trial_index)
|
|
7201
7946
|
mark_trial_as_failed(trial_index, _trial)
|
|
7202
7947
|
except Exception as e:
|
|
7203
|
-
|
|
7948
|
+
print_debug(f"ERROR in line {get_line_info()}: {e}")
|
|
7204
7949
|
job.cancel()
|
|
7205
7950
|
orchestrate_job(job, trial_index)
|
|
7206
7951
|
|
|
@@ -7215,10 +7960,12 @@ def _finish_previous_jobs_helper_handle_failed_job(job: Any, trial_index: int) -
|
|
|
7215
7960
|
|
|
7216
7961
|
def _finish_previous_jobs_helper_handle_exception(job: Any, trial_index: int, error: Exception) -> int:
|
|
7217
7962
|
if "None for metric" in str(error):
|
|
7218
|
-
|
|
7219
|
-
|
|
7220
|
-
|
|
7221
|
-
|
|
7963
|
+
err_msg = f"\n⚠ It seems like the program that was about to be run didn't have 'RESULT: <FLOAT>' in it's output string.\nError: {error}\nJob-result: {job.result()}"
|
|
7964
|
+
|
|
7965
|
+
if count_done_jobs() == 0:
|
|
7966
|
+
print_red(err_msg)
|
|
7967
|
+
else:
|
|
7968
|
+
print_debug(err_msg)
|
|
7222
7969
|
else:
|
|
7223
7970
|
print_red(f"\n⚠ {error}")
|
|
7224
7971
|
|
|
@@ -7237,7 +7984,10 @@ def _finish_previous_jobs_helper_process_job(job: Any, trial_index: int, this_jo
|
|
|
7237
7984
|
this_jobs_finished += _finish_previous_jobs_helper_handle_exception(job, trial_index, error)
|
|
7238
7985
|
return this_jobs_finished
|
|
7239
7986
|
|
|
7240
|
-
def _finish_previous_jobs_helper_check_and_process(
|
|
7987
|
+
def _finish_previous_jobs_helper_check_and_process(__args: Tuple[Any, int]) -> int:
|
|
7988
|
+
job, trial_index = __args
|
|
7989
|
+
|
|
7990
|
+
this_jobs_finished = 0
|
|
7241
7991
|
if job is None:
|
|
7242
7992
|
print_debug(f"finish_previous_jobs: job {job} is None")
|
|
7243
7993
|
return this_jobs_finished
|
|
@@ -7250,10 +8000,6 @@ def _finish_previous_jobs_helper_check_and_process(job: Any, trial_index: int, t
|
|
|
7250
8000
|
|
|
7251
8001
|
return this_jobs_finished
|
|
7252
8002
|
|
|
7253
|
-
def _finish_previous_jobs_helper_wrapper(__args: Tuple[Any, int]) -> int:
|
|
7254
|
-
job, trial_index = __args
|
|
7255
|
-
return _finish_previous_jobs_helper_check_and_process(job, trial_index, 0)
|
|
7256
|
-
|
|
7257
8003
|
def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
7258
8004
|
global JOBS_FINISHED
|
|
7259
8005
|
|
|
@@ -7266,13 +8012,10 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
|
7266
8012
|
|
|
7267
8013
|
jobs_copy = global_vars["jobs"][:]
|
|
7268
8014
|
|
|
7269
|
-
|
|
7270
|
-
print_debug(f"jobs in finish_previous_jobs: {jobs_copy}")
|
|
7271
|
-
|
|
7272
|
-
finishing_jobs_start_time = time.time()
|
|
8015
|
+
#finishing_jobs_start_time = time.time()
|
|
7273
8016
|
|
|
7274
8017
|
with ThreadPoolExecutor() as finish_job_executor:
|
|
7275
|
-
futures = [finish_job_executor.submit(
|
|
8018
|
+
futures = [finish_job_executor.submit(_finish_previous_jobs_helper_check_and_process, (job, trial_index)) for job, trial_index in jobs_copy]
|
|
7276
8019
|
|
|
7277
8020
|
for future in as_completed(futures):
|
|
7278
8021
|
try:
|
|
@@ -7280,11 +8023,11 @@ def finish_previous_jobs(new_msgs: List[str] = []) -> None:
|
|
|
7280
8023
|
except Exception as e:
|
|
7281
8024
|
print_red(f"⚠ Exception in parallel job handling: {e}")
|
|
7282
8025
|
|
|
7283
|
-
finishing_jobs_end_time = time.time()
|
|
8026
|
+
#finishing_jobs_end_time = time.time()
|
|
7284
8027
|
|
|
7285
|
-
finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
|
|
8028
|
+
#finishing_jobs_runtime = finishing_jobs_end_time - finishing_jobs_start_time
|
|
7286
8029
|
|
|
7287
|
-
print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
|
|
8030
|
+
#print_debug(f"Finishing jobs took {finishing_jobs_runtime} second(s)")
|
|
7288
8031
|
|
|
7289
8032
|
if this_jobs_finished > 0:
|
|
7290
8033
|
save_results_csv()
|
|
@@ -7450,24 +8193,43 @@ def submit_new_job(parameters: Union[dict, str], trial_index: int) -> Any:
|
|
|
7450
8193
|
|
|
7451
8194
|
print_debug(f"Done submitting new job, took {elapsed} seconds")
|
|
7452
8195
|
|
|
8196
|
+
log_data()
|
|
8197
|
+
|
|
7453
8198
|
return new_job
|
|
7454
8199
|
|
|
8200
|
+
def get_ax_client_trial(trial_index: int) -> Optional[ax.core.trial.Trial]:
|
|
8201
|
+
if not ax_client:
|
|
8202
|
+
my_exit(101)
|
|
8203
|
+
|
|
8204
|
+
return None
|
|
8205
|
+
|
|
8206
|
+
try:
|
|
8207
|
+
log_data()
|
|
8208
|
+
|
|
8209
|
+
return ax_client.get_trial(trial_index)
|
|
8210
|
+
except KeyError:
|
|
8211
|
+
error_without_print(f"get_ax_client_trial: trial_index {trial_index} failed")
|
|
8212
|
+
return None
|
|
8213
|
+
|
|
7455
8214
|
def orchestrator_start_trial(parameters: Union[dict, str], trial_index: int) -> None:
|
|
7456
8215
|
if submitit_executor and ax_client:
|
|
7457
8216
|
new_job = submit_new_job(parameters, trial_index)
|
|
7458
8217
|
if new_job:
|
|
7459
8218
|
submitted_jobs(1)
|
|
7460
8219
|
|
|
7461
|
-
_trial =
|
|
8220
|
+
_trial = get_ax_client_trial(trial_index)
|
|
7462
8221
|
|
|
7463
|
-
|
|
7464
|
-
|
|
7465
|
-
|
|
7466
|
-
|
|
7467
|
-
|
|
8222
|
+
if _trial is not None:
|
|
8223
|
+
try:
|
|
8224
|
+
_trial.mark_staged(unsafe=True)
|
|
8225
|
+
except Exception as e:
|
|
8226
|
+
print_debug(f"orchestrator_start_trial: error {e}")
|
|
8227
|
+
_trial.mark_running(unsafe=True, no_runner_required=True)
|
|
7468
8228
|
|
|
7469
|
-
|
|
7470
|
-
|
|
8229
|
+
print_debug(f"orchestrator_start_trial: appending job {new_job} to global_vars['jobs'], trial_index: {trial_index}")
|
|
8230
|
+
global_vars["jobs"].append((new_job, trial_index))
|
|
8231
|
+
else:
|
|
8232
|
+
print_red("Trial was none in orchestrator_start_trial")
|
|
7471
8233
|
else:
|
|
7472
8234
|
print_red("orchestrator_start_trial: Failed to start new job")
|
|
7473
8235
|
elif ax_client:
|
|
@@ -7492,7 +8254,7 @@ def handle_restart(stdout_path: str, trial_index: int) -> None:
|
|
|
7492
8254
|
if parameters:
|
|
7493
8255
|
orchestrator_start_trial(parameters, trial_index)
|
|
7494
8256
|
else:
|
|
7495
|
-
|
|
8257
|
+
print_red(f"Could not determine parameters from outfile {stdout_path} for restarting job")
|
|
7496
8258
|
|
|
7497
8259
|
def check_alternate_path(path: str) -> str:
|
|
7498
8260
|
if os.path.exists(path):
|
|
@@ -7588,7 +8350,11 @@ def execute_evaluation(_params: list) -> Optional[int]:
|
|
|
7588
8350
|
|
|
7589
8351
|
return None
|
|
7590
8352
|
|
|
7591
|
-
_trial =
|
|
8353
|
+
_trial = get_ax_client_trial(trial_index)
|
|
8354
|
+
|
|
8355
|
+
if _trial is None:
|
|
8356
|
+
error_without_print(f"execute_evaluation: _trial was not in execute_evaluation for params {_params}")
|
|
8357
|
+
return None
|
|
7592
8358
|
|
|
7593
8359
|
def mark_trial_stage(stage: str, error_msg: str) -> None:
|
|
7594
8360
|
try:
|
|
@@ -7656,7 +8422,7 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
|
|
|
7656
8422
|
my_exit(144)
|
|
7657
8423
|
|
|
7658
8424
|
if new_job is None:
|
|
7659
|
-
|
|
8425
|
+
print_debug("handle_failed_job: job is None")
|
|
7660
8426
|
|
|
7661
8427
|
return None
|
|
7662
8428
|
|
|
@@ -7667,16 +8433,24 @@ def handle_failed_job(error: Union[None, Exception, str], trial_index: int, new_
|
|
|
7667
8433
|
|
|
7668
8434
|
return None
|
|
7669
8435
|
|
|
8436
|
+
def log_ax_client_trial_failure(trial_index: int) -> None:
|
|
8437
|
+
if not ax_client:
|
|
8438
|
+
my_exit(101)
|
|
8439
|
+
|
|
8440
|
+
return
|
|
8441
|
+
|
|
8442
|
+
ax_client.log_trial_failure(trial_index=trial_index)
|
|
8443
|
+
|
|
7670
8444
|
def cancel_failed_job(trial_index: int, new_job: Job) -> None:
|
|
7671
8445
|
print_debug("Trying to cancel job that failed")
|
|
7672
8446
|
if new_job:
|
|
7673
8447
|
try:
|
|
7674
8448
|
if ax_client:
|
|
7675
|
-
|
|
8449
|
+
log_ax_client_trial_failure(trial_index)
|
|
7676
8450
|
else:
|
|
7677
8451
|
_fatal_error("ax_client not defined", 101)
|
|
7678
8452
|
except Exception as e:
|
|
7679
|
-
|
|
8453
|
+
print_red(f"ERROR in line {get_line_info()}: {e}")
|
|
7680
8454
|
new_job.cancel()
|
|
7681
8455
|
|
|
7682
8456
|
print_debug(f"cancel_failed_job: removing job {new_job}, trial_index: {trial_index}")
|
|
@@ -7730,6 +8504,8 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
|
|
|
7730
8504
|
_submitted_jobs = submitted_jobs()
|
|
7731
8505
|
_failed_jobs = failed_jobs()
|
|
7732
8506
|
|
|
8507
|
+
log_data()
|
|
8508
|
+
|
|
7733
8509
|
max_failed_jobs = max_eval
|
|
7734
8510
|
|
|
7735
8511
|
if args.max_failed_jobs is not None and args.max_failed_jobs > 0:
|
|
@@ -7761,7 +8537,7 @@ def break_run_search(_name: str, _max_eval: Optional[int]) -> bool:
|
|
|
7761
8537
|
|
|
7762
8538
|
return _ret
|
|
7763
8539
|
|
|
7764
|
-
def
|
|
8540
|
+
def calculate_nr_of_jobs_to_get(simulated_jobs: int, currently_running_jobs: int) -> int:
|
|
7765
8541
|
"""Calculates the number of jobs to retrieve."""
|
|
7766
8542
|
return min(
|
|
7767
8543
|
max_eval + simulated_jobs - count_done_jobs(),
|
|
@@ -7774,7 +8550,7 @@ def remove_extra_spaces(text: str) -> str:
|
|
|
7774
8550
|
raise ValueError("Input must be a string")
|
|
7775
8551
|
return re.sub(r'\s+', ' ', text).strip()
|
|
7776
8552
|
|
|
7777
|
-
def
|
|
8553
|
+
def get_trials_message(nr_of_jobs_to_get: int, full_nr_of_jobs_to_get: int, trial_durations: List[float]) -> str:
|
|
7778
8554
|
"""Generates the appropriate message for the number of trials being retrieved."""
|
|
7779
8555
|
ret = ""
|
|
7780
8556
|
if full_nr_of_jobs_to_get > 1:
|
|
@@ -7859,51 +8635,52 @@ def get_batched_arms(nr_of_jobs_to_get: int) -> list:
|
|
|
7859
8635
|
print_red("get_batched_arms: ax_client was None")
|
|
7860
8636
|
return []
|
|
7861
8637
|
|
|
7862
|
-
# Experiment-Status laden
|
|
7863
8638
|
load_experiment_state()
|
|
7864
8639
|
|
|
7865
|
-
while len(batched_arms)
|
|
8640
|
+
while len(batched_arms) < nr_of_jobs_to_get:
|
|
7866
8641
|
if attempts > args.max_attempts_for_generation:
|
|
7867
8642
|
print_debug(f"get_batched_arms: Stopped after {attempts} attempts: could not generate enough arms "
|
|
7868
8643
|
f"(got {len(batched_arms)} out of {nr_of_jobs_to_get}).")
|
|
7869
8644
|
break
|
|
7870
8645
|
|
|
7871
|
-
|
|
7872
|
-
print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting {remaining} more arm(s).")
|
|
7873
|
-
|
|
7874
|
-
print_debug("get pending observations")
|
|
7875
|
-
t0 = time.time()
|
|
7876
|
-
pending_observations = get_pending_observation_features(experiment=ax_client.experiment)
|
|
7877
|
-
dt = time.time() - t0
|
|
7878
|
-
print_debug(f"got pending observations: {pending_observations} (took {dt:.2f} seconds)")
|
|
8646
|
+
#print_debug(f"get_batched_arms: Attempt {attempts + 1}: requesting 1 more arm")
|
|
7879
8647
|
|
|
7880
|
-
|
|
7881
|
-
batched_generator_run = global_gs.gen(
|
|
8648
|
+
pending_observations = get_pending_observation_features(
|
|
7882
8649
|
experiment=ax_client.experiment,
|
|
7883
|
-
|
|
7884
|
-
pending_observations=pending_observations
|
|
8650
|
+
include_out_of_design_points=True
|
|
7885
8651
|
)
|
|
7886
|
-
|
|
8652
|
+
|
|
8653
|
+
try:
|
|
8654
|
+
#print_debug("getting global_gs.gen() with n=1")
|
|
8655
|
+
|
|
8656
|
+
batched_generator_run: Any = global_gs.gen(
|
|
8657
|
+
experiment=ax_client.experiment,
|
|
8658
|
+
n=1,
|
|
8659
|
+
pending_observations=pending_observations,
|
|
8660
|
+
)
|
|
8661
|
+
print_debug(f"got global_gs.gen(): {batched_generator_run}")
|
|
8662
|
+
except Exception as e:
|
|
8663
|
+
print_debug(f"global_gs.gen failed: {e}")
|
|
8664
|
+
traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
|
|
8665
|
+
break
|
|
7887
8666
|
|
|
7888
8667
|
depth = 0
|
|
7889
8668
|
path = "batched_generator_run"
|
|
7890
|
-
while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run)
|
|
7891
|
-
print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
|
|
8669
|
+
while isinstance(batched_generator_run, (list, tuple)) and len(batched_generator_run) == 1:
|
|
8670
|
+
#print_debug(f"Depth {depth}, path {path}, type {type(batched_generator_run).__name__}, length {len(batched_generator_run)}: {batched_generator_run}")
|
|
7892
8671
|
batched_generator_run = batched_generator_run[0]
|
|
7893
8672
|
path += "[0]"
|
|
7894
8673
|
depth += 1
|
|
7895
8674
|
|
|
7896
|
-
print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
|
|
8675
|
+
#print_debug(f"Final flat object at depth {depth}, path {path}: {batched_generator_run} (type {type(batched_generator_run).__name__})")
|
|
7897
8676
|
|
|
7898
|
-
|
|
7899
|
-
new_arms = batched_generator_run.arms
|
|
7900
|
-
print_debug(f"new_arms: {new_arms}")
|
|
8677
|
+
new_arms = getattr(batched_generator_run, "arms", [])
|
|
7901
8678
|
if not new_arms:
|
|
7902
8679
|
print_debug("get_batched_arms: No new arms were generated in this attempt.")
|
|
7903
8680
|
else:
|
|
7904
|
-
print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s),
|
|
8681
|
+
print_debug(f"get_batched_arms: Generated {len(new_arms)} new arm(s), now at {len(batched_arms) + len(new_arms)} of {nr_of_jobs_to_get}.")
|
|
8682
|
+
batched_arms.extend(new_arms)
|
|
7905
8683
|
|
|
7906
|
-
batched_arms.extend(new_arms)
|
|
7907
8684
|
attempts += 1
|
|
7908
8685
|
|
|
7909
8686
|
print_debug(f"get_batched_arms: Finished with {len(batched_arms)} arm(s) after {attempts} attempt(s).")
|
|
@@ -7942,16 +8719,14 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7942
8719
|
retries += 1
|
|
7943
8720
|
continue
|
|
7944
8721
|
|
|
7945
|
-
|
|
7946
|
-
progressbar_description(_get_trials_message(cnt + 1, n, trial_durations))
|
|
8722
|
+
progressbar_description(get_trials_message(cnt + 1, n, trial_durations))
|
|
7947
8723
|
|
|
7948
8724
|
try:
|
|
7949
8725
|
result = create_and_handle_trial(arm)
|
|
7950
8726
|
if result is not None:
|
|
7951
8727
|
trial_index, trial_duration, trial_successful = result
|
|
7952
|
-
|
|
7953
8728
|
except TrialRejected as e:
|
|
7954
|
-
print_debug(f"Trial rejected: {e}")
|
|
8729
|
+
print_debug(f"generate_trials: Trial rejected, error: {e}")
|
|
7955
8730
|
retries += 1
|
|
7956
8731
|
continue
|
|
7957
8732
|
|
|
@@ -7961,14 +8736,23 @@ def generate_trials(n: int, recursion: bool) -> Tuple[Dict[int, Any], bool]:
|
|
|
7961
8736
|
cnt += 1
|
|
7962
8737
|
trials_dict[trial_index] = arm.parameters
|
|
7963
8738
|
|
|
7964
|
-
|
|
8739
|
+
finalized = finalize_generation(trials_dict, cnt, n, start_time)
|
|
8740
|
+
|
|
8741
|
+
return finalized
|
|
7965
8742
|
|
|
7966
8743
|
except Exception as e:
|
|
7967
|
-
return
|
|
8744
|
+
return handle_generation_failure(e, n, recursion)
|
|
7968
8745
|
|
|
7969
8746
|
class TrialRejected(Exception):
|
|
7970
8747
|
pass
|
|
7971
8748
|
|
|
8749
|
+
def mark_abandoned(trial: Any, reason: str, trial_index: int) -> None:
|
|
8750
|
+
try:
|
|
8751
|
+
print_debug(f"[INFO] Marking trial {trial.index} ({trial.arm.name}) as abandoned, trial-index: {trial_index}. Reason: {reason}")
|
|
8752
|
+
trial.mark_abandoned(reason)
|
|
8753
|
+
except Exception as e:
|
|
8754
|
+
print_red(f"[ERROR] Could not mark trial as abandoned: {e}")
|
|
8755
|
+
|
|
7972
8756
|
def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
7973
8757
|
if ax_client is None:
|
|
7974
8758
|
print_red("ax_client is None in create_and_handle_trial")
|
|
@@ -7994,7 +8778,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
7994
8778
|
arm = trial.arms[0]
|
|
7995
8779
|
if deduplicated_arm(arm):
|
|
7996
8780
|
print_debug(f"Duplicated arm: {arm}")
|
|
7997
|
-
|
|
8781
|
+
mark_abandoned(trial, "Duplication detected", trial_index)
|
|
7998
8782
|
raise TrialRejected("Duplicate arm.")
|
|
7999
8783
|
|
|
8000
8784
|
arms_by_name_for_deduplication[arm.name] = arm
|
|
@@ -8003,7 +8787,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
8003
8787
|
|
|
8004
8788
|
if not has_no_post_generation_constraints_or_matches_constraints(post_generation_constraints, params):
|
|
8005
8789
|
print_debug(f"Trial {trial_index} does not meet post-generation constraints. Marking abandoned. Params: {params}, constraints: {post_generation_constraints}")
|
|
8006
|
-
|
|
8790
|
+
mark_abandoned(trial, "Post-Generation-Constraint failed", trial_index)
|
|
8007
8791
|
abandoned_trial_indices.append(trial_index)
|
|
8008
8792
|
raise TrialRejected("Post-generation constraints not met.")
|
|
8009
8793
|
|
|
@@ -8011,7 +8795,7 @@ def create_and_handle_trial(arm: Any) -> Optional[Tuple[int, float, bool]]:
|
|
|
8011
8795
|
end = time.time()
|
|
8012
8796
|
return trial_index, float(end - start), True
|
|
8013
8797
|
|
|
8014
|
-
def
|
|
8798
|
+
def finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int, start_time: float) -> Tuple[Dict[int, Any], bool]:
|
|
8015
8799
|
total_time = time.time() - start_time
|
|
8016
8800
|
|
|
8017
8801
|
log_gen_times.append(total_time)
|
|
@@ -8022,7 +8806,7 @@ def _finalize_generation(trials_dict: Dict[int, Any], cnt: int, requested: int,
|
|
|
8022
8806
|
|
|
8023
8807
|
return trials_dict, False
|
|
8024
8808
|
|
|
8025
|
-
def
|
|
8809
|
+
def handle_generation_failure(
|
|
8026
8810
|
e: Exception,
|
|
8027
8811
|
requested: int,
|
|
8028
8812
|
recursion: bool
|
|
@@ -8038,19 +8822,19 @@ def _handle_generation_failure(
|
|
|
8038
8822
|
)):
|
|
8039
8823
|
msg = str(e)
|
|
8040
8824
|
if msg not in error_8_saved:
|
|
8041
|
-
|
|
8825
|
+
print_exhaustion_warning(e, recursion)
|
|
8042
8826
|
error_8_saved.append(msg)
|
|
8043
8827
|
|
|
8044
8828
|
if not recursion and args.revert_to_random_when_seemingly_exhausted:
|
|
8045
8829
|
print_debug("Switching to random search strategy.")
|
|
8046
|
-
|
|
8830
|
+
set_global_gs_to_sobol()
|
|
8047
8831
|
return fetch_next_trials(requested, True)
|
|
8048
8832
|
|
|
8049
|
-
print_red(f"
|
|
8833
|
+
print_red(f"handle_generation_failure: General Exception: {e}")
|
|
8050
8834
|
|
|
8051
8835
|
return {}, True
|
|
8052
8836
|
|
|
8053
|
-
def
|
|
8837
|
+
def print_exhaustion_warning(e: Exception, recursion: bool) -> None:
|
|
8054
8838
|
if not recursion and args.revert_to_random_when_seemingly_exhausted:
|
|
8055
8839
|
print_yellow(f"\n⚠Error 8: {e} From now (done jobs: {count_done_jobs()}) on, random points will be generated.")
|
|
8056
8840
|
else:
|
|
@@ -8095,21 +8879,24 @@ def get_model_gen_kwargs() -> dict:
|
|
|
8095
8879
|
"fit_out_of_design": args.fit_out_of_design
|
|
8096
8880
|
}
|
|
8097
8881
|
|
|
8098
|
-
def
|
|
8882
|
+
def set_global_gs_to_sobol() -> None:
|
|
8099
8883
|
global global_gs
|
|
8100
8884
|
global overwritten_to_random
|
|
8101
8885
|
|
|
8886
|
+
print("Reverting to SOBOL")
|
|
8887
|
+
|
|
8102
8888
|
global_gs = GenerationStrategy(
|
|
8103
8889
|
name="Random*",
|
|
8104
8890
|
nodes=[
|
|
8105
8891
|
GenerationNode(
|
|
8106
|
-
|
|
8107
|
-
|
|
8108
|
-
|
|
8109
|
-
|
|
8110
|
-
|
|
8111
|
-
|
|
8112
|
-
|
|
8892
|
+
name="Sobol",
|
|
8893
|
+
should_deduplicate=True,
|
|
8894
|
+
generator_specs=[ # type: ignore[arg-type]
|
|
8895
|
+
GeneratorSpec( # type: ignore[arg-type]
|
|
8896
|
+
Generators.SOBOL, # type: ignore[arg-type]
|
|
8897
|
+
model_gen_kwargs=get_model_gen_kwargs() # type: ignore[arg-type]
|
|
8898
|
+
) # type: ignore[arg-type]
|
|
8899
|
+
] # type: ignore[arg-type]
|
|
8113
8900
|
)
|
|
8114
8901
|
]
|
|
8115
8902
|
)
|
|
@@ -8494,10 +9281,10 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
|
|
|
8494
9281
|
|
|
8495
9282
|
for s in splitted_by_comma:
|
|
8496
9283
|
if "=" not in s:
|
|
8497
|
-
|
|
9284
|
+
print_red(f"'{s}' does not contain '='")
|
|
8498
9285
|
my_exit(123)
|
|
8499
9286
|
if s.count("=") != 1:
|
|
8500
|
-
|
|
9287
|
+
print_red(f"There can only be one '=' in the gen_strat_str's element '{s}'")
|
|
8501
9288
|
my_exit(123)
|
|
8502
9289
|
|
|
8503
9290
|
model_name, nr_str = s.split("=")
|
|
@@ -8507,13 +9294,13 @@ def parse_generation_strategy_string(gen_strat_str: str) -> tuple[list[dict[str,
|
|
|
8507
9294
|
_fatal_error(f"Model {matching_model} is not valid for custom generation strategy.", 56)
|
|
8508
9295
|
|
|
8509
9296
|
if not matching_model:
|
|
8510
|
-
|
|
9297
|
+
print_red(f"'{model_name}' not found in SUPPORTED_MODELS")
|
|
8511
9298
|
my_exit(123)
|
|
8512
9299
|
|
|
8513
9300
|
try:
|
|
8514
9301
|
nr = int(nr_str)
|
|
8515
9302
|
except ValueError:
|
|
8516
|
-
|
|
9303
|
+
print_red(f"Invalid number of generations '{nr_str}' for model '{model_name}'")
|
|
8517
9304
|
my_exit(123)
|
|
8518
9305
|
|
|
8519
9306
|
gen_strat_list.append({matching_model: nr})
|
|
@@ -8637,7 +9424,7 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8637
9424
|
if model_name == "TPE":
|
|
8638
9425
|
if len(arg_result_names) != 1:
|
|
8639
9426
|
_fatal_error(f"Has {len(arg_result_names)} results. TPE currently only supports single-objective-optimization.", 108)
|
|
8640
|
-
return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py",
|
|
9427
|
+
return ExternalProgramGenerationNode(external_generator=f"python3 {script_dir}/.tpe.py", name="EXTERNAL_GENERATOR")
|
|
8641
9428
|
|
|
8642
9429
|
external_generators = {
|
|
8643
9430
|
"PSEUDORANDOM": f"python3 {script_dir}/.random_generator.py",
|
|
@@ -8651,10 +9438,10 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8651
9438
|
cmd = external_generators[model_name]
|
|
8652
9439
|
if model_name == "EXTERNAL_GENERATOR" and not cmd:
|
|
8653
9440
|
_fatal_error("--external_generator is missing. Cannot create points for EXTERNAL_GENERATOR without it.", 204)
|
|
8654
|
-
return ExternalProgramGenerationNode(external_generator=cmd,
|
|
9441
|
+
return ExternalProgramGenerationNode(external_generator=cmd, name="EXTERNAL_GENERATOR")
|
|
8655
9442
|
|
|
8656
9443
|
trans_crit = [
|
|
8657
|
-
|
|
9444
|
+
MinTrials(
|
|
8658
9445
|
threshold=threshold,
|
|
8659
9446
|
block_transition_if_unmet=True,
|
|
8660
9447
|
transition_to=target_model,
|
|
@@ -8670,11 +9457,12 @@ def create_node(model_name: str, threshold: int, next_model_name: Optional[str])
|
|
|
8670
9457
|
if model_name.lower() != "sobol":
|
|
8671
9458
|
kwargs["model_kwargs"] = get_model_kwargs()
|
|
8672
9459
|
|
|
8673
|
-
model_spec = [GeneratorSpec(selected_model, **kwargs)]
|
|
9460
|
+
model_spec = [GeneratorSpec(selected_model, **kwargs)] # type: ignore[arg-type]
|
|
8674
9461
|
|
|
8675
9462
|
res = GenerationNode(
|
|
8676
|
-
|
|
9463
|
+
name=model_name,
|
|
8677
9464
|
generator_specs=model_spec,
|
|
9465
|
+
should_deduplicate=True,
|
|
8678
9466
|
transition_criteria=trans_crit
|
|
8679
9467
|
)
|
|
8680
9468
|
|
|
@@ -8685,7 +9473,7 @@ def get_optimizer_kwargs() -> dict:
|
|
|
8685
9473
|
"sequential": False
|
|
8686
9474
|
}
|
|
8687
9475
|
|
|
8688
|
-
def create_step(model_name: str, _num_trials: int
|
|
9476
|
+
def create_step(model_name: str, _num_trials: int, index: int) -> GenerationStep:
|
|
8689
9477
|
model_enum = get_model_from_name(model_name)
|
|
8690
9478
|
|
|
8691
9479
|
return GenerationStep(
|
|
@@ -8700,17 +9488,16 @@ def create_step(model_name: str, _num_trials: int = -1, index: Optional[int] = N
|
|
|
8700
9488
|
)
|
|
8701
9489
|
|
|
8702
9490
|
def set_global_generation_strategy() -> None:
|
|
8703
|
-
|
|
8704
|
-
continue_not_supported_on_custom_generation_strategy()
|
|
9491
|
+
continue_not_supported_on_custom_generation_strategy()
|
|
8705
9492
|
|
|
8706
|
-
|
|
8707
|
-
|
|
8708
|
-
|
|
8709
|
-
|
|
8710
|
-
|
|
8711
|
-
|
|
8712
|
-
|
|
8713
|
-
|
|
9493
|
+
try:
|
|
9494
|
+
if args.generation_strategy is None:
|
|
9495
|
+
setup_default_generation_strategy()
|
|
9496
|
+
else:
|
|
9497
|
+
setup_custom_generation_strategy()
|
|
9498
|
+
except Exception as e:
|
|
9499
|
+
print_red(f"Unexpected error in generation strategy setup: {e}")
|
|
9500
|
+
my_exit(111)
|
|
8714
9501
|
|
|
8715
9502
|
if global_gs is None:
|
|
8716
9503
|
print_red("global_gs is None after setup!")
|
|
@@ -8861,10 +9648,14 @@ def execute_trials(
|
|
|
8861
9648
|
result = future.result()
|
|
8862
9649
|
print_debug(f"result in execute_trials: {result}")
|
|
8863
9650
|
except Exception as exc:
|
|
8864
|
-
|
|
9651
|
+
failed_args = future_to_args[future]
|
|
9652
|
+
print_red(f"execute_trials: Error at executing a trial with args {failed_args}: {exc}")
|
|
9653
|
+
traceback.print_exc()
|
|
8865
9654
|
|
|
8866
9655
|
end_time = time.time()
|
|
8867
9656
|
|
|
9657
|
+
log_data()
|
|
9658
|
+
|
|
8868
9659
|
duration = float(end_time - start_time)
|
|
8869
9660
|
job_submit_durations.append(duration)
|
|
8870
9661
|
job_submit_nrs.append(cnt)
|
|
@@ -8898,20 +9689,20 @@ def create_and_execute_next_runs(next_nr_steps: int, phase: Optional[str], _max_
|
|
|
8898
9689
|
done_optimizing: bool = False
|
|
8899
9690
|
|
|
8900
9691
|
try:
|
|
8901
|
-
done_optimizing, trial_index_to_param =
|
|
8902
|
-
|
|
9692
|
+
done_optimizing, trial_index_to_param = create_and_execute_next_runs_run_loop(_max_eval, phase)
|
|
9693
|
+
create_and_execute_next_runs_finish(done_optimizing)
|
|
8903
9694
|
except Exception as e:
|
|
8904
9695
|
stacktrace = traceback.format_exc()
|
|
8905
9696
|
print_debug(f"Warning: create_and_execute_next_runs encountered an exception: {e}\n{stacktrace}")
|
|
8906
9697
|
return handle_exceptions_create_and_execute_next_runs(e)
|
|
8907
9698
|
|
|
8908
|
-
return
|
|
9699
|
+
return create_and_execute_next_runs_return_value(trial_index_to_param)
|
|
8909
9700
|
|
|
8910
|
-
def
|
|
9701
|
+
def create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Optional[str]) -> Tuple[bool, Optional[Dict]]:
|
|
8911
9702
|
done_optimizing = False
|
|
8912
9703
|
trial_index_to_param: Optional[Dict] = None
|
|
8913
9704
|
|
|
8914
|
-
nr_of_jobs_to_get =
|
|
9705
|
+
nr_of_jobs_to_get = calculate_nr_of_jobs_to_get(get_nr_of_imported_jobs(), len(global_vars["jobs"]))
|
|
8915
9706
|
|
|
8916
9707
|
__max_eval = _max_eval if _max_eval is not None else 0
|
|
8917
9708
|
new_nr_of_jobs_to_get = min(__max_eval - (submitted_jobs() - failed_jobs()), nr_of_jobs_to_get)
|
|
@@ -8925,6 +9716,7 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
|
|
|
8925
9716
|
|
|
8926
9717
|
for _ in range(range_nr):
|
|
8927
9718
|
trial_index_to_param, done_optimizing = get_next_trials(get_next_trials_nr)
|
|
9719
|
+
log_data()
|
|
8928
9720
|
if done_optimizing:
|
|
8929
9721
|
continue
|
|
8930
9722
|
|
|
@@ -8949,13 +9741,13 @@ def _create_and_execute_next_runs_run_loop(_max_eval: Optional[int], phase: Opti
|
|
|
8949
9741
|
|
|
8950
9742
|
return done_optimizing, trial_index_to_param
|
|
8951
9743
|
|
|
8952
|
-
def
|
|
9744
|
+
def create_and_execute_next_runs_finish(done_optimizing: bool) -> None:
|
|
8953
9745
|
finish_previous_jobs(["finishing jobs"])
|
|
8954
9746
|
|
|
8955
9747
|
if done_optimizing:
|
|
8956
9748
|
end_program(False, 0)
|
|
8957
9749
|
|
|
8958
|
-
def
|
|
9750
|
+
def create_and_execute_next_runs_return_value(trial_index_to_param: Optional[Dict]) -> int:
|
|
8959
9751
|
try:
|
|
8960
9752
|
if trial_index_to_param:
|
|
8961
9753
|
res = len(trial_index_to_param.keys())
|
|
@@ -9066,7 +9858,7 @@ def execute_nvidia_smi() -> None:
|
|
|
9066
9858
|
if not host:
|
|
9067
9859
|
print_debug("host not defined")
|
|
9068
9860
|
except Exception as e:
|
|
9069
|
-
|
|
9861
|
+
print_red(f"execute_nvidia_smi: An error occurred: {e}")
|
|
9070
9862
|
if is_slurm_job() and not args.force_local_execution:
|
|
9071
9863
|
_sleep(30)
|
|
9072
9864
|
|
|
@@ -9104,11 +9896,6 @@ def run_search() -> bool:
|
|
|
9104
9896
|
|
|
9105
9897
|
return False
|
|
9106
9898
|
|
|
9107
|
-
async def start_logging_daemon() -> None:
|
|
9108
|
-
while True:
|
|
9109
|
-
log_data()
|
|
9110
|
-
time.sleep(30)
|
|
9111
|
-
|
|
9112
9899
|
def should_break_search() -> bool:
|
|
9113
9900
|
ret = False
|
|
9114
9901
|
|
|
@@ -9157,10 +9944,8 @@ def check_search_space_exhaustion(nr_of_items: int) -> bool:
|
|
|
9157
9944
|
print_debug(_wrn)
|
|
9158
9945
|
progressbar_description(_wrn)
|
|
9159
9946
|
|
|
9160
|
-
live_share()
|
|
9161
9947
|
return True
|
|
9162
9948
|
|
|
9163
|
-
live_share()
|
|
9164
9949
|
return False
|
|
9165
9950
|
|
|
9166
9951
|
def finalize_jobs() -> None:
|
|
@@ -9174,7 +9959,7 @@ def finalize_jobs() -> None:
|
|
|
9174
9959
|
handle_slurm_execution()
|
|
9175
9960
|
|
|
9176
9961
|
def go_through_jobs_that_are_not_completed_yet() -> None:
|
|
9177
|
-
print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
|
|
9962
|
+
#print_debug(f"Waiting for jobs to finish (currently, len(global_vars['jobs']) = {len(global_vars['jobs'])}")
|
|
9178
9963
|
|
|
9179
9964
|
nr_jobs_left = len(global_vars['jobs'])
|
|
9180
9965
|
if nr_jobs_left == 1:
|
|
@@ -9244,372 +10029,60 @@ def parse_orchestrator_file(_f: str, _test: bool = False) -> Union[dict, None]:
|
|
|
9244
10029
|
|
|
9245
10030
|
return data
|
|
9246
10031
|
except Exception as e:
|
|
9247
|
-
|
|
10032
|
+
print_red(f"Error while parse_experiment_parameters({_f}): {e}")
|
|
9248
10033
|
else:
|
|
9249
10034
|
print_red(f"{_f} could not be found")
|
|
9250
10035
|
|
|
9251
|
-
return None
|
|
9252
|
-
|
|
9253
|
-
def set_orchestrator() -> None:
|
|
9254
|
-
with spinner("Setting orchestrator..."):
|
|
9255
|
-
global orchestrator
|
|
9256
|
-
|
|
9257
|
-
if args.orchestrator_file:
|
|
9258
|
-
if SYSTEM_HAS_SBATCH:
|
|
9259
|
-
orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
|
|
9260
|
-
else:
|
|
9261
|
-
print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
|
|
9262
|
-
|
|
9263
|
-
def check_if_has_random_steps() -> None:
|
|
9264
|
-
if (not args.continue_previous_job and "--continue" not in sys.argv) and (args.num_random_steps == 0 or not args.num_random_steps) and args.model not in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
|
|
9265
|
-
_fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
|
|
9266
|
-
|
|
9267
|
-
def add_exclude_to_defective_nodes() -> None:
|
|
9268
|
-
with spinner("Adding excluded nodes..."):
|
|
9269
|
-
if args.exclude:
|
|
9270
|
-
entries = [entry.strip() for entry in args.exclude.split(',')]
|
|
9271
|
-
|
|
9272
|
-
for entry in entries:
|
|
9273
|
-
count_defective_nodes(None, entry)
|
|
9274
|
-
|
|
9275
|
-
def check_max_eval(_max_eval: int) -> None:
|
|
9276
|
-
with spinner("Checking max_eval..."):
|
|
9277
|
-
if not _max_eval:
|
|
9278
|
-
_fatal_error("--max_eval needs to be set!", 19)
|
|
9279
|
-
|
|
9280
|
-
def parse_parameters() -> Any:
|
|
9281
|
-
cli_params_experiment_parameters = None
|
|
9282
|
-
if args.parameter:
|
|
9283
|
-
parse_experiment_parameters()
|
|
9284
|
-
cli_params_experiment_parameters = experiment_parameters
|
|
9285
|
-
|
|
9286
|
-
return cli_params_experiment_parameters
|
|
9287
|
-
|
|
9288
|
-
def create_pareto_front_table(idxs: List[int], metric_x: str, metric_y: str) -> Table:
|
|
9289
|
-
table = Table(title=f"Pareto-Front for {metric_y}/{metric_x}:", show_lines=True)
|
|
9290
|
-
|
|
9291
|
-
rows = _pareto_front_table_read_csv()
|
|
9292
|
-
if not rows:
|
|
9293
|
-
table.add_column("No data found")
|
|
9294
|
-
return table
|
|
9295
|
-
|
|
9296
|
-
filtered_rows = _pareto_front_table_filter_rows(rows, idxs)
|
|
9297
|
-
if not filtered_rows:
|
|
9298
|
-
table.add_column("No matching entries")
|
|
9299
|
-
return table
|
|
9300
|
-
|
|
9301
|
-
param_cols, result_cols = _pareto_front_table_get_columns(filtered_rows[0])
|
|
9302
|
-
|
|
9303
|
-
_pareto_front_table_add_headers(table, param_cols, result_cols)
|
|
9304
|
-
_pareto_front_table_add_rows(table, filtered_rows, param_cols, result_cols)
|
|
9305
|
-
|
|
9306
|
-
return table
|
|
9307
|
-
|
|
9308
|
-
def _pareto_front_table_read_csv() -> List[Dict[str, str]]:
|
|
9309
|
-
with open(RESULT_CSV_FILE, mode="r", encoding="utf-8", newline="") as f:
|
|
9310
|
-
return list(csv.DictReader(f))
|
|
9311
|
-
|
|
9312
|
-
def _pareto_front_table_filter_rows(rows: List[Dict[str, str]], idxs: List[int]) -> List[Dict[str, str]]:
|
|
9313
|
-
result = []
|
|
9314
|
-
for row in rows:
|
|
9315
|
-
try:
|
|
9316
|
-
trial_index = int(row["trial_index"])
|
|
9317
|
-
except (KeyError, ValueError):
|
|
9318
|
-
continue
|
|
9319
|
-
|
|
9320
|
-
if row.get("trial_status", "").strip().upper() == "COMPLETED" and trial_index in idxs:
|
|
9321
|
-
result.append(row)
|
|
9322
|
-
return result
|
|
9323
|
-
|
|
9324
|
-
def _pareto_front_table_get_columns(first_row: Dict[str, str]) -> Tuple[List[str], List[str]]:
|
|
9325
|
-
all_columns = list(first_row.keys())
|
|
9326
|
-
ignored_cols = set(special_col_names) - {"trial_index"}
|
|
9327
|
-
|
|
9328
|
-
param_cols = [col for col in all_columns if col not in ignored_cols and col not in arg_result_names and not col.startswith("OO_Info_")]
|
|
9329
|
-
result_cols = [col for col in arg_result_names if col in all_columns]
|
|
9330
|
-
return param_cols, result_cols
|
|
9331
|
-
|
|
9332
|
-
def _pareto_front_table_add_headers(table: Table, param_cols: List[str], result_cols: List[str]) -> None:
|
|
9333
|
-
for col in param_cols:
|
|
9334
|
-
table.add_column(col, justify="center")
|
|
9335
|
-
for col in result_cols:
|
|
9336
|
-
table.add_column(Text(f"{col}", style="cyan"), justify="center")
|
|
9337
|
-
|
|
9338
|
-
def _pareto_front_table_add_rows(table: Table, rows: List[Dict[str, str]], param_cols: List[str], result_cols: List[str]) -> None:
|
|
9339
|
-
for row in rows:
|
|
9340
|
-
values = [str(helpers.to_int_when_possible(row[col])) for col in param_cols]
|
|
9341
|
-
result_values = [Text(str(helpers.to_int_when_possible(row[col])), style="cyan") for col in result_cols]
|
|
9342
|
-
table.add_row(*values, *result_values, style="bold green")
|
|
9343
|
-
|
|
9344
|
-
def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
|
|
9345
|
-
if not os.path.exists(RESULT_CSV_FILE):
|
|
9346
|
-
print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
|
|
9347
|
-
return None
|
|
9348
|
-
|
|
9349
|
-
return create_pareto_front_table(idxs, metric_x, metric_y)
|
|
9350
|
-
|
|
9351
|
-
def supports_sixel() -> bool:
    """Heuristically decide whether the terminal can display sixel images.

    Returns ``True`` when the lower-cased ``TERM`` environment variable
    contains ``xterm`` or ``mlterm``. Otherwise ``tput setab 256`` is run and
    ``True`` is returned only if its standard output contains ``sixel``.
    A failing or missing ``tput`` yields ``False``.
    """
    terminal = os.environ.get("TERM", "").lower()
    if any(marker in terminal for marker in ("xterm", "mlterm")):
        return True

    try:
        probe = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False

    # NOTE(review): unclear whether this tput output ever mentions "sixel" — confirm.
    return probe.returncode == 0 and "sixel" in probe.stdout.lower()
|
|
9364
|
-
|
|
9365
|
-
def plot_pareto_frontier_sixel(data: Any, x_metric: str, y_metric: str) -> None:
    """Render the Pareto-front scatter plot for one metric pair to the terminal.

    Args:
        data: Nested mapping read as ``data[x_metric][y_metric]["means"]``,
            which must hold one value sequence under ``x_metric`` and one
            under ``y_metric``. ``None`` prints a notice and returns.
        x_metric: Metric plotted on the x axis.
        y_metric: Metric plotted on the y axis.

    The plot is written to a temporary PNG and handed to
    ``print_image_to_cli``. Nothing is plotted when ``supports_sixel()``
    reports that the console cannot show sixel images.
    """
    if data is None:
        print("[italic yellow]The data seems to be empty. Cannot plot pareto frontier.[/]")
        return

    if not supports_sixel():
        print(f"[italic yellow]Your console does not support sixel-images. Will not print Pareto-frontier as a matplotlib-sixel-plot for {x_metric}/{y_metric}.[/]")
        return

    # Imported lazily so matplotlib is only loaded when a plot is really drawn.
    import matplotlib.pyplot as plt

    means = data[x_metric][y_metric]["means"]

    x_values = means[x_metric]
    y_values = means[y_metric]

    fig, _ax = plt.subplots()

    try:
        _ax.scatter(x_values, y_values, s=50, marker='x', c='blue', label='Data Points')

        _ax.set_xlabel(x_metric)
        _ax.set_ylabel(y_metric)

        _ax.set_title(f'Pareto-Front {x_metric}/{y_metric}')

        _ax.ticklabel_format(style='plain', axis='both', useOffset=False)

        with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
            # Save this exact figure rather than whatever pyplot considers current.
            fig.savefig(tmp_file.name, dpi=300)

            # The file is deleted when the block ends, so print it inside.
            print_image_to_cli(tmp_file.name, 1000)
    finally:
        # Always release the figure, even if plotting or printing failed.
        plt.close(fig)
|
|
9398
|
-
|
|
9399
|
-
def _pareto_front_general_validate_shapes(x: np.ndarray, y: np.ndarray) -> None:
|
|
9400
|
-
if x.shape != y.shape:
|
|
9401
|
-
raise ValueError("Input arrays x and y must have the same shape.")
|
|
9402
|
-
|
|
9403
|
-
def _pareto_front_general_compare(
|
|
9404
|
-
xi: float, yi: float, xj: float, yj: float,
|
|
9405
|
-
x_minimize: bool, y_minimize: bool
|
|
9406
|
-
) -> bool:
|
|
9407
|
-
x_better_eq = xj <= xi if x_minimize else xj >= xi
|
|
9408
|
-
y_better_eq = yj <= yi if y_minimize else yj >= yi
|
|
9409
|
-
x_strictly_better = xj < xi if x_minimize else xj > xi
|
|
9410
|
-
y_strictly_better = yj < yi if y_minimize else yj > yi
|
|
9411
|
-
|
|
9412
|
-
return bool(x_better_eq and y_better_eq and (x_strictly_better or y_strictly_better))
|
|
9413
|
-
|
|
9414
|
-
def _pareto_front_general_find_dominated(
|
|
9415
|
-
x: np.ndarray, y: np.ndarray, x_minimize: bool, y_minimize: bool
|
|
9416
|
-
) -> np.ndarray:
|
|
9417
|
-
num_points = len(x)
|
|
9418
|
-
is_dominated = np.zeros(num_points, dtype=bool)
|
|
9419
|
-
|
|
9420
|
-
for i in range(num_points):
|
|
9421
|
-
for j in range(num_points):
|
|
9422
|
-
if i == j:
|
|
9423
|
-
continue
|
|
9424
|
-
|
|
9425
|
-
if _pareto_front_general_compare(x[i], y[i], x[j], y[j], x_minimize, y_minimize):
|
|
9426
|
-
is_dominated[i] = True
|
|
9427
|
-
break
|
|
9428
|
-
|
|
9429
|
-
return is_dominated
|
|
9430
|
-
|
|
9431
|
-
def pareto_front_general(
    x: np.ndarray,
    y: np.ndarray,
    x_minimize: bool = True,
    y_minimize: bool = True
) -> np.ndarray:
    """Return the indices of all non-dominated points of a 2-objective set.

    Args:
        x: Values of the first objective.
        y: Values of the second objective; must have the same shape as ``x``.
        x_minimize: Whether smaller ``x`` values are better.
        y_minimize: Whether smaller ``y`` values are better.

    Returns:
        Index array of the points that no other point dominates. If anything
        goes wrong (for example a shape mismatch), the error is printed and
        an empty integer array is returned instead of raising.
    """
    try:
        _pareto_front_general_validate_shapes(x, y)
        dominated_mask = _pareto_front_general_find_dominated(x, y, x_minimize, y_minimize)
        front_indices = np.where(~dominated_mask)[0]
    except Exception as e:
        print("Error in pareto_front_general:", str(e))
        front_indices = np.array([], dtype=int)

    return front_indices
|
|
9444
|
-
|
|
9445
|
-
def _pareto_front_aggregate_data(path_to_calculate: str) -> Optional[Dict[Tuple[int, str], Dict[str, Dict[str, float]]]]:
    """Collect the numeric result values of every trial in a run folder.

    Reads the metric names from ``result_names.txt`` and the trial rows from
    the results CSV, both located directly in ``path_to_calculate``.

    Args:
        path_to_calculate: Run folder containing both files.

    Returns:
        ``None`` if either file is missing. Otherwise a mapping from
        ``(trial_index, arm_name)`` to ``{"means": {metric: value}}`` holding
        only the metrics whose cell could be parsed as ``float``. Rows
        without a usable ``trial_index``/``arm_name`` are skipped.
    """
    results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    result_names_file = f"{path_to_calculate}/result_names.txt"

    if not os.path.exists(results_csv_file) or not os.path.exists(result_names_file):
        return None

    with open(result_names_file, mode="r", encoding="utf-8") as f:
        result_names = [line.strip() for line in f if line.strip()]

    records: dict = defaultdict(lambda: {'means': {}})

    with open(results_csv_file, encoding="utf-8", mode="r", newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            try:
                key = (int(row['trial_index']), row['arm_name'])
            except (KeyError, TypeError, ValueError):
                # Malformed or truncated row: skip it instead of aborting
                # the whole aggregation.
                continue

            for metric in result_names:
                if metric not in row:
                    continue

                try:
                    # float() runs before records[key] is touched, so a bad
                    # cell does not create an empty entry for this trial.
                    records[key]['means'][metric] = float(row[metric])
                except (TypeError, ValueError):
                    # TypeError covers None cells, which DictReader produces
                    # for rows shorter than the header.
                    continue

    return records
|
|
9472
|
-
|
|
9473
|
-
def _pareto_front_filter_complete_points(
|
|
9474
|
-
path_to_calculate: str,
|
|
9475
|
-
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
9476
|
-
primary_name: str,
|
|
9477
|
-
secondary_name: str
|
|
9478
|
-
) -> List[Tuple[Tuple[int, str], float, float]]:
|
|
9479
|
-
points = []
|
|
9480
|
-
for key, metrics in records.items():
|
|
9481
|
-
means = metrics['means']
|
|
9482
|
-
if primary_name in means and secondary_name in means:
|
|
9483
|
-
x_val = means[primary_name]
|
|
9484
|
-
y_val = means[secondary_name]
|
|
9485
|
-
points.append((key, x_val, y_val))
|
|
9486
|
-
if len(points) == 0:
|
|
9487
|
-
raise ValueError(f"No full data points with both objectives found in {path_to_calculate}.")
|
|
9488
|
-
return points
|
|
9489
|
-
|
|
9490
|
-
def _pareto_front_transform_objectives(
    points: List[Tuple[Any, float, float]],
    primary_name: str,
    secondary_name: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Turn the points into two arrays in which smaller is always better.

    The position of each objective in ``arg_result_names`` selects its mode
    from ``arg_result_min_or_max``; values of a ``"max"`` objective are
    negated, values of a ``"min"`` objective are kept.

    Returns:
        ``(x, y)`` arrays built from element 1 and element 2 of every point.

    Raises:
        ValueError: If an objective name is not in ``arg_result_names`` or
            its mode is neither ``"min"`` nor ``"max"``.
    """
    primary_position = arg_result_names.index(primary_name)
    secondary_position = arg_result_names.index(secondary_name)

    x = np.array([point[1] for point in points])
    y = np.array([point[2] for point in points])

    def _orient(values: np.ndarray, name: str, position: int) -> np.ndarray:
        mode = arg_result_min_or_max[position]
        if mode == "max":
            return -values
        if mode == "min":
            return values
        raise ValueError(f"Unknown mode for {name}: {mode}")

    x = _orient(x, primary_name, primary_position)
    y = _orient(y, secondary_name, secondary_position)

    return x, y
|
|
9512
|
-
|
|
9513
|
-
def _pareto_front_select_pareto_points(
    x: np.ndarray,
    y: np.ndarray,
    x_minimize: bool,
    y_minimize: bool,
    points: List[Tuple[Any, float, float]],
    num_points: int
) -> List[Tuple[Any, float, float]]:
    """Pick at most ``num_points`` Pareto-optimal points, ordered by ``x``.

    The non-dominated indices come from ``pareto_front_general``; they are
    sorted by ascending ``x`` value, cut to the first ``num_points`` entries
    and mapped back to the corresponding items of ``points``.
    """
    front = pareto_front_general(x, y, x_minimize, y_minimize)
    ascending_by_x = np.argsort(x[front])
    chosen = front[ascending_by_x][:num_points]

    return [points[position] for position in chosen]
|
|
9526
|
-
|
|
9527
|
-
def _pareto_front_build_return_structure(
|
|
9528
|
-
path_to_calculate: str,
|
|
9529
|
-
selected_points: List[Tuple[Any, float, float]],
|
|
9530
|
-
records: Dict[Tuple[int, str], Dict[str, Dict[str, float]]],
|
|
9531
|
-
absolute_metrics: List[str],
|
|
9532
|
-
primary_name: str,
|
|
9533
|
-
secondary_name: str
|
|
9534
|
-
) -> dict:
|
|
9535
|
-
results_csv_file = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
|
|
9536
|
-
result_names_file = f"{path_to_calculate}/result_names.txt"
|
|
9537
|
-
|
|
9538
|
-
with open(result_names_file, mode="r", encoding="utf-8") as f:
|
|
9539
|
-
result_names = [line.strip() for line in f if line.strip()]
|
|
9540
|
-
|
|
9541
|
-
csv_rows = {}
|
|
9542
|
-
with open(results_csv_file, mode="r", encoding="utf-8", newline='') as csvfile:
|
|
9543
|
-
reader = csv.DictReader(csvfile)
|
|
9544
|
-
for row in reader:
|
|
9545
|
-
trial_index = int(row['trial_index'])
|
|
9546
|
-
csv_rows[trial_index] = row
|
|
9547
|
-
|
|
9548
|
-
ignored_columns = {'trial_index', 'arm_name', 'trial_status', 'generation_node'}
|
|
9549
|
-
ignored_columns.update(result_names)
|
|
9550
|
-
|
|
9551
|
-
param_dicts = []
|
|
9552
|
-
idxs = []
|
|
9553
|
-
means_dict = defaultdict(list)
|
|
10036
|
+
return None
|
|
9554
10037
|
|
|
9555
|
-
|
|
9556
|
-
|
|
9557
|
-
|
|
9558
|
-
print_debug(f"_pareto_front_build_return_structure: trial_index '{trial_index}' could not be found and row returned as None")
|
|
9559
|
-
continue
|
|
10038
|
+
def set_orchestrator() -> None:
    """Load the orchestrator configuration named by ``--orchestrator_file``.

    Without ``args.orchestrator_file`` nothing happens. With it, the global
    ``orchestrator`` is set from ``parse_orchestrator_file`` when
    ``SYSTEM_HAS_SBATCH`` is true; otherwise a warning is printed and the
    option is ignored.
    """
    global orchestrator

    with spinner("Setting orchestrator..."):
        if not args.orchestrator_file:
            return

        if not SYSTEM_HAS_SBATCH:
            print_yellow("--orchestrator_file will be ignored on non-sbatch-systems.")
            return

        orchestrator = parse_orchestrator_file(args.orchestrator_file, False)
|
|
9562
10047
|
|
|
9563
|
-
|
|
9564
|
-
|
|
9565
|
-
|
|
9566
|
-
try:
|
|
9567
|
-
param_dict[key] = int(value)
|
|
9568
|
-
except ValueError:
|
|
9569
|
-
try:
|
|
9570
|
-
param_dict[key] = float(value)
|
|
9571
|
-
except ValueError:
|
|
9572
|
-
param_dict[key] = value
|
|
10048
|
+
def check_if_has_random_steps() -> None:
    """Abort when a fresh run has no random steps and needs them.

    The check is skipped for continued jobs (``args.continue_previous_job``
    set or ``--continue`` on the command line), when
    ``args.num_random_steps`` is non-zero, and for the models
    ``EXTERNAL_GENERATOR``, ``SOBOL`` and ``PSEUDORANDOM``. In every other
    case ``_fatal_error`` is called with exit code 233.
    """
    if args.continue_previous_job or "--continue" in sys.argv:
        return

    no_random_steps = args.num_random_steps == 0 or not args.num_random_steps
    if not no_random_steps:
        return

    if args.model in ["EXTERNAL_GENERATOR", "SOBOL", "PSEUDORANDOM"]:
        return

    _fatal_error("You have no random steps set. This is only allowed in continued jobs. To start, you need either some random steps, or a continued run.", 233)
|
|
9573
10051
|
|
|
9574
|
-
|
|
10052
|
+
def add_exclude_to_defective_nodes() -> None:
    """Register every node listed in ``--exclude`` as defective.

    ``args.exclude`` is split at commas; each whitespace-stripped entry is
    passed to ``count_defective_nodes(None, entry)``. Nothing happens when
    ``args.exclude`` is empty.
    """
    with spinner("Adding excluded nodes..."):
        if not args.exclude:
            return

        for raw_node in args.exclude.split(','):
            count_defective_nodes(None, raw_node.strip())
|
|
9578
10059
|
|
|
9579
|
-
|
|
9580
|
-
|
|
9581
|
-
|
|
9582
|
-
|
|
9583
|
-
"param_dicts": param_dicts,
|
|
9584
|
-
"means": dict(means_dict),
|
|
9585
|
-
"idxs": idxs
|
|
9586
|
-
},
|
|
9587
|
-
"absolute_metrics": absolute_metrics
|
|
9588
|
-
}
|
|
9589
|
-
}
|
|
10060
|
+
def check_max_eval(_max_eval: int) -> None:
    """Abort with exit code 19 when ``_max_eval`` is falsy (unset or 0)."""
    with spinner("Checking max_eval..."):
        if _max_eval:
            return

        _fatal_error("--max_eval needs to be set!", 19)
|
|
9590
10064
|
|
|
9591
|
-
|
|
10065
|
+
def parse_parameters() -> Any:
    """Parse the ``--parameter`` definitions given on the command line.

    Returns:
        ``None`` when ``args.parameter`` is empty. Otherwise the module-level
        ``experiment_parameters`` as it is after calling
        ``parse_experiment_parameters()``.
    """
    if not args.parameter:
        return None

    parse_experiment_parameters()
    return experiment_parameters
|
|
9603
10072
|
|
|
9604
|
-
|
|
9605
|
-
|
|
10073
|
+
def supports_sixel() -> bool:
    """Heuristically decide whether the terminal can display sixel images.

    Returns ``True`` when the lower-cased ``TERM`` environment variable
    contains ``xterm`` or ``mlterm``. Otherwise ``tput setab 256`` is run and
    ``True`` is returned only if its standard output contains ``sixel``.
    A failing or missing ``tput`` yields ``False``.
    """
    terminal = os.environ.get("TERM", "").lower()
    if any(marker in terminal for marker in ("xterm", "mlterm")):
        return True

    try:
        probe = subprocess.run(["tput", "setab", "256"], capture_output=True, text=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False

    # NOTE(review): unclear whether this tput output ever mentions "sixel" — confirm.
    return probe.returncode == 0 and "sixel" in probe.stdout.lower()
|
|
9613
10086
|
|
|
9614
10087
|
def save_experiment_state() -> None:
|
|
9615
10088
|
try:
|
|
@@ -9617,14 +10090,14 @@ def save_experiment_state() -> None:
|
|
|
9617
10090
|
print_red("save_experiment_state: ax_client or ax_client.experiment is None, cannot save.")
|
|
9618
10091
|
return
|
|
9619
10092
|
state_path = get_current_run_folder("experiment_state.json")
|
|
9620
|
-
|
|
10093
|
+
save_ax_client_to_json_file(state_path)
|
|
9621
10094
|
except Exception as e:
|
|
9622
|
-
|
|
10095
|
+
print_debug(f"Error saving experiment state: {e}")
|
|
9623
10096
|
|
|
9624
10097
|
def wait_for_state_file(state_path: str, min_size: int = 5, max_wait_seconds: int = 60) -> bool:
|
|
9625
10098
|
try:
|
|
9626
10099
|
if not os.path.exists(state_path):
|
|
9627
|
-
|
|
10100
|
+
print_debug(f"[ERROR] File '{state_path}' does not exist.")
|
|
9628
10101
|
return False
|
|
9629
10102
|
|
|
9630
10103
|
i = 0
|
|
@@ -9843,38 +10316,47 @@ def get_result_minimize_flag(path_to_calculate: str, resname: str) -> bool:
|
|
|
9843
10316
|
|
|
9844
10317
|
return minmax[index] == "min"
|
|
9845
10318
|
|
|
9846
|
-
def
|
|
9847
|
-
|
|
10319
|
+
def post_job_calculate_pareto_front() -> None:
    """Calculate Pareto fronts for ``--calculate_pareto_front_of_job`` and exit.

    Does nothing when the option is not set. Otherwise every given path is
    expanded with ``find_results_paths``, the Pareto front is calculated once
    for each distinct result path, and the program ends through ``my_exit``:
    code 24 if a lookup or a calculation failed, code 0 otherwise.
    """
    if not args.calculate_pareto_front_of_job:
        return

    failure = False

    _paths_to_calculate = []

    # dict.fromkeys drops duplicate arguments but keeps their order.
    for _path_to_calculate in dict.fromkeys(args.calculate_pareto_front_of_job):
        try:
            found_paths = find_results_paths(_path_to_calculate)
        except (FileNotFoundError, NotADirectoryError) as e:
            print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")

            failure = True
            continue

        for _fp in found_paths:
            if _fp not in _paths_to_calculate:
                _paths_to_calculate.append(_fp)

    # Iterate the collected, de-duplicated paths. Looping over the result of
    # the last lookup here would skip the paths of all earlier arguments.
    for path_to_calculate in _paths_to_calculate:
        if not job_calculate_pareto_front(path_to_calculate):
            failure = True

    if failure:
        my_exit(24)

    my_exit(0)
|
|
10348
|
+
|
|
10349
|
+
def pareto_front_as_rich_table(idxs: list, metric_x: str, metric_y: str) -> Optional[Table]:
    """Build the Pareto-front table for a metric pair, if the results CSV exists.

    Returns:
        The result of ``create_pareto_front_table(idxs, metric_x, metric_y)``,
        or ``None`` (after a debug message) when ``RESULT_CSV_FILE`` is missing.
    """
    if os.path.exists(RESULT_CSV_FILE):
        return create_pareto_front_table(idxs, metric_x, metric_y)

    print_debug(f"pareto_front_as_rich_table: File '{RESULT_CSV_FILE}' not found")
    return None
|
|
9873
10355
|
|
|
9874
10356
|
def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_sixel_and_table: bool = False) -> None:
|
|
9875
10357
|
if len(res_names) <= 1:
|
|
9876
10358
|
print_debug(f"--result_names (has {len(res_names)} entries) must be at least 2.")
|
|
9877
|
-
return
|
|
10359
|
+
return None
|
|
9878
10360
|
|
|
9879
10361
|
pareto_front_data: dict = get_pareto_front_data(path_to_calculate, res_names)
|
|
9880
10362
|
|
|
@@ -9895,8 +10377,16 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
|
|
|
9895
10377
|
else:
|
|
9896
10378
|
print(f"Not showing Pareto-front-sixel for {path_to_calculate}")
|
|
9897
10379
|
|
|
9898
|
-
if
|
|
9899
|
-
|
|
10380
|
+
if calculated_frontier is None:
|
|
10381
|
+
print_debug("ERROR: calculated_frontier is None")
|
|
10382
|
+
return None
|
|
10383
|
+
|
|
10384
|
+
try:
|
|
10385
|
+
if len(calculated_frontier[metric_x][metric_y]["idxs"]):
|
|
10386
|
+
pareto_points[metric_x][metric_y] = sorted(calculated_frontier[metric_x][metric_y]["idxs"])
|
|
10387
|
+
except AttributeError:
|
|
10388
|
+
print_debug(f"ERROR: calculated_frontier structure invalid for ({metric_x}, {metric_y})")
|
|
10389
|
+
return None
|
|
9900
10390
|
|
|
9901
10391
|
rich_table = pareto_front_as_rich_table(
|
|
9902
10392
|
calculated_frontier[metric_x][metric_y]["idxs"],
|
|
@@ -9921,6 +10411,8 @@ def show_pareto_frontier_data(path_to_calculate: str, res_names: list, disable_s
|
|
|
9921
10411
|
|
|
9922
10412
|
live_share_after_pareto()
|
|
9923
10413
|
|
|
10414
|
+
return None
|
|
10415
|
+
|
|
9924
10416
|
def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_color: str) -> None:
|
|
9925
10417
|
cpu_count = os.cpu_count()
|
|
9926
10418
|
|
|
@@ -9932,9 +10424,11 @@ def show_available_hardware_and_generation_strategy_string(gpu_string: str, gpu_
|
|
|
9932
10424
|
pass
|
|
9933
10425
|
|
|
9934
10426
|
if gpu_string:
|
|
9935
|
-
console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]
|
|
10427
|
+
console.print(f"[green]You have {cpu_count} CPUs available for the main process.[/green] [{gpu_color}]{gpu_string}[/{gpu_color}]")
|
|
9936
10428
|
else:
|
|
9937
|
-
print_green(f"You have {cpu_count} CPUs available for the main process.
|
|
10429
|
+
print_green(f"You have {cpu_count} CPUs available for the main process.")
|
|
10430
|
+
|
|
10431
|
+
print_green(gs_string)
|
|
9938
10432
|
|
|
9939
10433
|
def write_args_overview_table() -> None:
|
|
9940
10434
|
table = Table(title="Arguments Overview")
|
|
@@ -10201,112 +10695,6 @@ def find_results_paths(base_path: str) -> list:
|
|
|
10201
10695
|
|
|
10202
10696
|
return list(set(found_paths))
|
|
10203
10697
|
|
|
10204
|
-
def post_job_calculate_pareto_front() -> None:
    """Calculate Pareto fronts for ``--calculate_pareto_front_of_job`` and exit.

    Does nothing when the option is not set. Otherwise every given path is
    expanded with ``find_results_paths``, the Pareto front is calculated once
    for each distinct result path, and the program ends through ``my_exit``:
    code 24 if a lookup or a calculation failed, code 0 otherwise.
    """
    if not args.calculate_pareto_front_of_job:
        return

    failure = False

    _paths_to_calculate = []

    # dict.fromkeys drops duplicate arguments but keeps their order.
    for _path_to_calculate in dict.fromkeys(args.calculate_pareto_front_of_job):
        try:
            found_paths = find_results_paths(_path_to_calculate)
        except (FileNotFoundError, NotADirectoryError) as e:
            print_red(f"post_job_calculate_pareto_front: find_results_paths('{_path_to_calculate}') failed with {e}")

            failure = True
            continue

        for _fp in found_paths:
            if _fp not in _paths_to_calculate:
                _paths_to_calculate.append(_fp)

    # Iterate the collected, de-duplicated paths. Looping over the result of
    # the last lookup here would skip the paths of all earlier arguments.
    for path_to_calculate in _paths_to_calculate:
        if not job_calculate_pareto_front(path_to_calculate):
            failure = True

    if failure:
        my_exit(24)

    my_exit(0)
|
|
10233
|
-
|
|
10234
|
-
def job_calculate_pareto_front(path_to_calculate: str, disable_sixel_and_table: bool = False) -> bool:
    """Calculate and show the Pareto front of an already finished run folder.

    Args:
        path_to_calculate: Run folder. It must contain
            ``state_files/ax_client.experiment.json``,
            ``state_files/checkpoint.json``, the results CSV and a
            ``result_names.txt`` with at least two non-empty lines.
        disable_sixel_and_table: Passed through to
            ``show_pareto_or_error_msg``.

    Returns:
        ``True`` after the Pareto front was handed to
        ``show_pareto_or_error_msg``. ``False`` if the path is empty, a
        required file is missing or unreadable, fewer than two result names
        exist, or no experiment parameters could be loaded.

    Side effects:
        Sets the globals ``RESULT_CSV_FILE``, ``CURRENT_RUN_FOLDER`` and
        ``arg_result_names`` to match the given run folder.
    """
    global CURRENT_RUN_FOLDER
    global RESULT_CSV_FILE
    global arg_result_names

    pf_start_time = time.time()

    # An empty path is rejected silently, as before. The second, message-
    # printing guard that used to follow was unreachable and is gone.
    if not path_to_calculate:
        return False

    if not os.path.exists(path_to_calculate):
        print_red(f"Path '{path_to_calculate}' does not exist")
        return False

    ax_client_json = f"{path_to_calculate}/state_files/ax_client.experiment.json"

    if not os.path.exists(ax_client_json):
        print_red(f"Path '{ax_client_json}' not found")
        return False

    checkpoint_file: str = f"{path_to_calculate}/state_files/checkpoint.json"
    if not os.path.exists(checkpoint_file):
        print_red(f"The checkpoint file '{checkpoint_file}' does not exist")
        return False

    # The global is updated before the existence check, as in the original.
    RESULT_CSV_FILE = f"{path_to_calculate}/{RESULTS_CSV_FILENAME}"
    if not os.path.exists(RESULT_CSV_FILE):
        print_red(f"{RESULT_CSV_FILE} not found")
        return False

    res_names_file = f"{path_to_calculate}/result_names.txt"
    if not os.path.exists(res_names_file):
        print_red(f"File '{res_names_file}' does not exist")
        return False

    try:
        with open(res_names_file, "r", encoding="utf-8") as file:
            # One result name per line; blank lines are ignored.
            res_names = [line.strip() for line in file if line.strip()]
    except (OSError, UnicodeError) as e:
        print_red(f"Error reading file '{res_names_file}': {e}")
        return False

    if len(res_names) < 2:
        print_red(f"Error: There are less than 2 result names (is: {len(res_names)}, {', '.join(res_names)}) in {path_to_calculate}. Cannot continue calculating the pareto front.")
        return False

    load_username_to_args(path_to_calculate)

    CURRENT_RUN_FOLDER = path_to_calculate

    arg_result_names = res_names

    load_experiment_parameters_from_checkpoint_file(checkpoint_file, False)

    if experiment_parameters is None:
        return False

    show_pareto_or_error_msg(path_to_calculate, res_names, disable_sixel_and_table)

    pf_end_time = time.time()

    print_debug(f"Calculating the Pareto-front took {pf_end_time - pf_start_time} seconds")

    return True
|
|
10309
|
-
|
|
10310
10698
|
def set_arg_states_from_continue() -> None:
|
|
10311
10699
|
if args.continue_previous_job and not args.num_random_steps:
|
|
10312
10700
|
num_random_steps_file = f"{args.continue_previous_job}/state_files/num_random_steps"
|
|
@@ -10349,27 +10737,27 @@ def run_program_once(params: Optional[dict] = None) -> None:
|
|
|
10349
10737
|
params = {}
|
|
10350
10738
|
|
|
10351
10739
|
if isinstance(args.run_program_once, str):
|
|
10352
|
-
command_str = args.run_program_once
|
|
10740
|
+
command_str = decode_if_base64(args.run_program_once)
|
|
10353
10741
|
for k, v in params.items():
|
|
10354
10742
|
placeholder = f"%({k})"
|
|
10355
10743
|
command_str = command_str.replace(placeholder, str(v))
|
|
10356
10744
|
|
|
10357
|
-
|
|
10358
|
-
|
|
10359
|
-
|
|
10360
|
-
|
|
10361
|
-
|
|
10362
|
-
|
|
10745
|
+
print(f"Executing command: [cyan]{command_str}[/cyan]")
|
|
10746
|
+
result = subprocess.run(command_str, shell=True, check=True)
|
|
10747
|
+
if result.returncode == 0:
|
|
10748
|
+
print("[bold green]Setup script completed successfully ✅[/bold green]")
|
|
10749
|
+
else:
|
|
10750
|
+
print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
|
|
10363
10751
|
|
|
10364
|
-
|
|
10752
|
+
my_exit(57)
|
|
10365
10753
|
|
|
10366
10754
|
elif isinstance(args.run_program_once, (list, tuple)):
|
|
10367
10755
|
with spinner("run_program_once: Executing command list: [cyan]{args.run_program_once}[/cyan]"):
|
|
10368
10756
|
result = subprocess.run(args.run_program_once, check=True)
|
|
10369
10757
|
if result.returncode == 0:
|
|
10370
|
-
|
|
10758
|
+
print("[bold green]Setup script completed successfully ✅[/bold green]")
|
|
10371
10759
|
else:
|
|
10372
|
-
|
|
10760
|
+
print(f"[bold red]Setup script failed with exit code {result.returncode} ❌[/bold red]")
|
|
10373
10761
|
|
|
10374
10762
|
my_exit(57)
|
|
10375
10763
|
|
|
@@ -10387,8 +10775,14 @@ def show_omniopt_call() -> None:
|
|
|
10387
10775
|
|
|
10388
10776
|
original_print(oo_call + " " + cleaned)
|
|
10389
10777
|
|
|
10778
|
+
if args.dependency is not None and args.dependency != "":
|
|
10779
|
+
print(f"Dependency: {args.dependency}")
|
|
10780
|
+
|
|
10781
|
+
if args.ui_url is not None and args.ui_url != "":
|
|
10782
|
+
print_yellow("--ui_url is deprecated. Do not use it anymore. It will be ignored and one day be removed.")
|
|
10783
|
+
|
|
10390
10784
|
def main() -> None:
|
|
10391
|
-
global RESULT_CSV_FILE,
|
|
10785
|
+
global RESULT_CSV_FILE, LOGFILE_DEBUG_GET_NEXT_TRIALS
|
|
10392
10786
|
|
|
10393
10787
|
check_if_has_random_steps()
|
|
10394
10788
|
|
|
@@ -10470,7 +10864,7 @@ def main() -> None:
|
|
|
10470
10864
|
exp_params = get_experiment_parameters(cli_params_experiment_parameters)
|
|
10471
10865
|
|
|
10472
10866
|
if exp_params is not None:
|
|
10473
|
-
|
|
10867
|
+
experiment_args, gpu_string, gpu_color = exp_params
|
|
10474
10868
|
print_debug(f"experiment_parameters: {experiment_parameters}")
|
|
10475
10869
|
|
|
10476
10870
|
set_orchestrator()
|
|
@@ -10491,6 +10885,8 @@ def main() -> None:
|
|
|
10491
10885
|
|
|
10492
10886
|
write_files_and_show_overviews()
|
|
10493
10887
|
|
|
10888
|
+
live_share()
|
|
10889
|
+
|
|
10494
10890
|
#if args.continue_previous_job:
|
|
10495
10891
|
# insert_jobs_from_csv(f"{args.continue_previous_job}/{RESULTS_CSV_FILENAME}")
|
|
10496
10892
|
|
|
@@ -10710,8 +11106,6 @@ def run_search_with_progress_bar() -> None:
|
|
|
10710
11106
|
wait_for_jobs_to_complete()
|
|
10711
11107
|
|
|
10712
11108
|
def complex_tests(_program_name: str, wanted_stderr: str, wanted_exit_code: int, wanted_signal: Union[int, None], res_is_none: bool = False) -> int:
|
|
10713
|
-
#print_yellow(f"Test suite: {_program_name}")
|
|
10714
|
-
|
|
10715
11109
|
nr_errors: int = 0
|
|
10716
11110
|
|
|
10717
11111
|
program_path: str = f"./.tests/test_wronggoing_stuff.bin/bin/{_program_name}"
|
|
@@ -10785,7 +11179,7 @@ def test_find_paths(program_code: str) -> int:
|
|
|
10785
11179
|
for i in files:
|
|
10786
11180
|
if i not in string:
|
|
10787
11181
|
if os.path.exists(i):
|
|
10788
|
-
print("Missing {i} in find_file_paths string!")
|
|
11182
|
+
print(f"Missing {i} in find_file_paths string!")
|
|
10789
11183
|
nr_errors += 1
|
|
10790
11184
|
|
|
10791
11185
|
return nr_errors
|
|
@@ -11166,17 +11560,16 @@ Exit-Code: 159
|
|
|
11166
11560
|
|
|
11167
11561
|
my_exit(nr_errors)
|
|
11168
11562
|
|
|
11169
|
-
def
|
|
11563
|
+
def main_wrapper() -> None:
|
|
11170
11564
|
print(f"Run-UUID: {run_uuid}")
|
|
11171
11565
|
|
|
11172
11566
|
auto_wrap_namespace(globals())
|
|
11173
11567
|
|
|
11174
11568
|
print_logo()
|
|
11175
11569
|
|
|
11176
|
-
start_logging_daemon() # type: ignore[unused-coroutine]
|
|
11177
|
-
|
|
11178
11570
|
fool_linter(args.num_cpus_main_job)
|
|
11179
11571
|
fool_linter(args.flame_graph)
|
|
11572
|
+
fool_linter(args.memray)
|
|
11180
11573
|
|
|
11181
11574
|
with warnings.catch_warnings():
|
|
11182
11575
|
warnings.simplefilter("ignore")
|
|
@@ -11233,6 +11626,9 @@ def stack_trace_wrapper(func: Any, regex: Any = None) -> Any:
|
|
|
11233
11626
|
def auto_wrap_namespace(namespace: Any) -> Any:
|
|
11234
11627
|
enable_beartype = any(os.getenv(v) for v in ("ENABLE_BEARTYPE", "CI"))
|
|
11235
11628
|
|
|
11629
|
+
if args.beartype:
|
|
11630
|
+
enable_beartype = True
|
|
11631
|
+
|
|
11236
11632
|
excluded_functions = {
|
|
11237
11633
|
"log_time_and_memory_wrapper",
|
|
11238
11634
|
"collect_runtime_stats",
|
|
@@ -11241,7 +11637,6 @@ def auto_wrap_namespace(namespace: Any) -> Any:
|
|
|
11241
11637
|
"_record_stats",
|
|
11242
11638
|
"_open",
|
|
11243
11639
|
"_check_memory_leak",
|
|
11244
|
-
"start_logging_daemon",
|
|
11245
11640
|
"get_current_run_folder",
|
|
11246
11641
|
"show_func_name_wrapper"
|
|
11247
11642
|
}
|
|
@@ -11267,7 +11662,7 @@ def auto_wrap_namespace(namespace: Any) -> Any:
|
|
|
11267
11662
|
|
|
11268
11663
|
if __name__ == "__main__":
|
|
11269
11664
|
try:
|
|
11270
|
-
|
|
11665
|
+
main_wrapper()
|
|
11271
11666
|
except (SignalUSR, SignalINT, SignalCONT) as e:
|
|
11272
|
-
print_red(f"
|
|
11667
|
+
print_red(f"main_wrapper failed with exception {e}")
|
|
11273
11668
|
end_program(True)
|