climate-ref 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- climate_ref/cli/__init__.py +20 -6
- climate_ref/cli/datasets.py +30 -7
- climate_ref/cli/solve.py +38 -3
- climate_ref/config.py +10 -1
- climate_ref/dataset_registry/obs4ref_reference.txt +44 -13
- climate_ref/dataset_registry/sample_data.txt +8 -6
- climate_ref/datasets/base.py +62 -4
- climate_ref/datasets/cmip6.py +14 -40
- climate_ref/datasets/obs4mips.py +11 -54
- climate_ref/executor/__init__.py +2 -1
- climate_ref/executor/hpc.py +308 -0
- climate_ref/executor/local.py +24 -4
- climate_ref/executor/result_handling.py +0 -1
- climate_ref/slurm.py +192 -0
- climate_ref/solver.py +67 -6
- climate_ref/testing.py +7 -5
- {climate_ref-0.5.5.dist-info → climate_ref-0.6.1.dist-info}/METADATA +3 -2
- {climate_ref-0.5.5.dist-info → climate_ref-0.6.1.dist-info}/RECORD +22 -20
- {climate_ref-0.5.5.dist-info → climate_ref-0.6.1.dist-info}/WHEEL +0 -0
- {climate_ref-0.5.5.dist-info → climate_ref-0.6.1.dist-info}/entry_points.txt +0 -0
- {climate_ref-0.5.5.dist-info → climate_ref-0.6.1.dist-info}/licenses/LICENCE +0 -0
- {climate_ref-0.5.5.dist-info → climate_ref-0.6.1.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""
HPC-based executor that submits diagnostic executions to a job scheduler.

Use this executor if you want to:

- run the REF as part of an HPC workflow
- run REF diagnostics across multiple compute nodes

Executions are dispatched through parsl using a Slurm provider (see ``HPCExecutor``).
"""
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import time
|
|
12
|
+
from typing import Any, cast
|
|
13
|
+
|
|
14
|
+
import parsl
|
|
15
|
+
from loguru import logger
|
|
16
|
+
from parsl import python_app
|
|
17
|
+
from parsl.config import Config as ParslConfig
|
|
18
|
+
from parsl.executors import HighThroughputExecutor
|
|
19
|
+
from parsl.launchers import SrunLauncher
|
|
20
|
+
from parsl.providers import SlurmProvider
|
|
21
|
+
from tqdm import tqdm
|
|
22
|
+
|
|
23
|
+
from climate_ref.config import Config
|
|
24
|
+
from climate_ref.database import Database
|
|
25
|
+
from climate_ref.models import Execution
|
|
26
|
+
from climate_ref.slurm import HAS_REAL_SLURM, SlurmChecker
|
|
27
|
+
from climate_ref_core.diagnostics import ExecutionDefinition, ExecutionResult
|
|
28
|
+
from climate_ref_core.exceptions import DiagnosticError, ExecutionError
|
|
29
|
+
from climate_ref_core.executor import execute_locally
|
|
30
|
+
|
|
31
|
+
from .local import ExecutionFuture, process_result
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@python_app
def _process_run(definition: ExecutionDefinition, log_level: str) -> ExecutionResult:
    """
    Execute a single diagnostic on a compute node as a parsl app.

    Exceptions other than ``DiagnosticError`` deliberately propagate out of the app so
    that parsl's retry mechanism is triggered. A ``DiagnosticError`` is logged and the
    result attached to it is returned instead (so it is not retried).
    """
    try:
        outcome = execute_locally(definition=definition, log_level=log_level, raise_error=True)
    except DiagnosticError as exc:  # pragma: no cover
        # Diagnostic failures are reported through the result carried by the exception
        logger.exception("Error running diagnostic")
        return cast(ExecutionResult, exc.result)
    return outcome
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _to_float(x: Any) -> float | None:
|
|
48
|
+
if x is None:
|
|
49
|
+
return None
|
|
50
|
+
if isinstance(x, int | float):
|
|
51
|
+
return float(x)
|
|
52
|
+
try:
|
|
53
|
+
return float(x)
|
|
54
|
+
except (ValueError, TypeError):
|
|
55
|
+
return None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _to_int(x: Any) -> int | None:
|
|
59
|
+
if x is None:
|
|
60
|
+
return None
|
|
61
|
+
if isinstance(x, int):
|
|
62
|
+
return x
|
|
63
|
+
try:
|
|
64
|
+
return int(float(x)) # Handles both "123" and "123.0"
|
|
65
|
+
except (ValueError, TypeError):
|
|
66
|
+
return None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class HPCExecutor:
    """
    Run diagnostics by submitting them to an HPC job scheduler via parsl.

    Executions are dispatched as parsl apps to a ``HighThroughputExecutor`` whose
    workers are provisioned through a ``SlurmProvider``.
    """

    name = "hpc"

    def __init__(
        self,
        *,
        database: Database | None = None,
        config: Config | None = None,
        **executor_config: str | float | int,
    ) -> None:
        """
        Configure the executor and load the parsl configuration.

        Parameters
        ----------
        database
            Database in which execution results are recorded.
            Defaults to one created from ``config`` (without running migrations).
        config
            REF configuration. Defaults to ``Config.default()``.
        executor_config
            Scheduler options such as ``account``, ``partition``, ``qos``, ``req_nodes``,
            ``walltime`` (``HH:MM:SS``), ``log_dir``, ``cores_per_worker``,
            ``mem_per_worker`` and ``validation``.

        Raises
        ------
        ValueError
            If ``walltime`` is not in ``HH:MM:SS`` form, or if validation is enabled and
            the requested resources are not valid for the Slurm cluster.
        """
        config = config or Config.default()
        database = database or Database.from_config(config, run_migrations=False)

        self.config = config
        self.database = database

        self.scheduler = executor_config.get("scheduler", "slurm")
        self.account = str(executor_config.get("account", os.environ.get("USER")))
        self.username = executor_config.get("username", os.environ.get("USER"))
        self.partition = str(executor_config.get("partition")) if executor_config.get("partition") else None
        self.qos = str(executor_config.get("qos")) if executor_config.get("qos") else None
        self.req_nodes = int(executor_config.get("req_nodes", 1))
        self.walltime = str(executor_config.get("walltime", "00:10:00"))
        self.log_dir = str(executor_config.get("log_dir", "runinfo"))

        self.cores_per_worker = _to_int(executor_config.get("cores_per_worker"))
        self.mem_per_worker = _to_float(executor_config.get("mem_per_worker"))

        hours, minutes, seconds = map(int, self.walltime.split(":"))
        self.total_minutes = hours * 60 + minutes + seconds / 60

        # Validation requires pyslurm; it is silently skipped when pyslurm is unavailable
        if executor_config.get("validation") and HAS_REAL_SLURM:
            self._validate_slurm_params()

        self._initialize_parsl()

        self.parsl_results: list[ExecutionFuture] = []

    def _validate_slurm_params(self) -> None:
        """Validate the Slurm configuration using SlurmChecker.

        Raises
        ------
        ValueError: If account, partition or QOS are invalid or inaccessible, or if the
            requested cores, memory or walltime exceed the limits reported by Slurm.
        """
        slurm_checker = SlurmChecker()
        if self.account and not slurm_checker.get_account_info(self.account):
            raise ValueError(f"Account: {self.account} not valid")

        partition_limits = None
        node_info = None

        if self.partition:
            if not slurm_checker.get_partition_info(self.partition):
                raise ValueError(f"Partition: {self.partition} not valid")

            if not slurm_checker.can_account_use_partition(self.account, self.partition):
                raise ValueError(f"Account: {self.account} cannot access partition: {self.partition}")

            partition_limits = slurm_checker.get_partition_limits(self.partition)
            node_info = slurm_checker.get_node_from_partition(self.partition)

        qos_limits = None
        if self.qos:
            if not slurm_checker.get_qos_info(self.qos):
                raise ValueError(f"QOS: {self.qos} not valid")

            if not slurm_checker.can_account_use_qos(self.account, self.qos):
                raise ValueError(f"Account: {self.account} cannot access qos: {self.qos}")

            qos_limits = slurm_checker.get_qos_limits(self.qos)

        max_cores_per_node = int(node_info["cpus"]) if node_info else None
        if max_cores_per_node and self.cores_per_worker:
            if self.cores_per_worker > max_cores_per_node:
                raise ValueError(
                    f"cores_per_worker: {self.cores_per_worker} "
                    f"larger than the maximum in a node {max_cores_per_node}"
                )

        max_mem_per_node = float(node_info["real_memory"]) if node_info else None
        if max_mem_per_node and self.mem_per_worker:
            if self.mem_per_worker > max_mem_per_node:
                raise ValueError(
                    f"mem_per_worker: {self.mem_per_worker} "
                    f"larger than the maximum mem in a node {max_mem_per_node}"
                )

        # A walltime limit may be non-numeric (e.g. "UNLIMITED"); such a limit does not
        # constrain the request, so only numeric limits are compared.
        walltime_limits = [
            _to_float(limits["max_time_minutes"]) for limits in (partition_limits, qos_limits) if limits
        ]
        numeric_limits = [limit for limit in walltime_limits if limit is not None]
        if numeric_limits:
            max_walltime_minutes = min(numeric_limits)
            if self.total_minutes > max_walltime_minutes:
                raise ValueError(
                    f"Walltime: {self.walltime} exceeds the maximum time "
                    f"{max_walltime_minutes} allowed by {self.partition} and {self.qos}"
                )

    def _initialize_parsl(self) -> None:
        """
        Build the parsl configuration (Slurm provider + HighThroughputExecutor) and load it.

        Provider and executor tuning options are read from ``self.config.executor.config``.
        """
        executor_config = self.config.executor.config

        provider = SlurmProvider(
            account=self.account,
            partition=self.partition,
            qos=self.qos,
            nodes_per_block=self.req_nodes,
            max_blocks=int(executor_config.get("max_blocks", 1)),
            scheduler_options=executor_config.get("scheduler_options", "#SBATCH -C cpu"),
            worker_init=executor_config.get("worker_init", "source .venv/bin/activate"),
            launcher=SrunLauncher(
                debug=True,
                overrides=executor_config.get("overrides", ""),
            ),
            walltime=self.walltime,
            cmd_timeout=int(executor_config.get("cmd_timeout", 120)),
        )
        executor = HighThroughputExecutor(
            label="ref_hpc_executor",
            cores_per_worker=self.cores_per_worker if self.cores_per_worker else 1,
            mem_per_worker=self.mem_per_worker,
            max_workers_per_node=_to_int(executor_config.get("max_workers_per_node", 16)),
            # parsl disables CPU affinity with the literal string "none";
            # str(None) would produce "None", which is not a recognised strategy
            cpu_affinity=str(executor_config.get("cpu_affinity") or "none"),
            provider=provider,
        )

        hpc_config = ParslConfig(
            run_dir=self.log_dir, executors=[executor], retries=int(executor_config.get("retries", 2))
        )
        parsl.load(hpc_config)

    def run(
        self,
        definition: ExecutionDefinition,
        execution: Execution | None = None,
    ) -> None:
        """
        Submit a diagnostic execution to the HPC scheduler via parsl.

        The returned parsl future is tracked; call ``join`` to wait for the results.

        Parameters
        ----------
        definition
            A description of the information needed for this execution of the diagnostic
        execution
            A database model representing the execution of the diagnostic.
            If provided, the result will be updated in the database when completed.
        """
        # Submit the execution as a parsl app and track the future so ``join`` can wait for it
        future = _process_run(
            definition=definition,
            log_level=self.config.log_level,
        )

        self.parsl_results.append(
            ExecutionFuture(
                future=future,
                definition=definition,
                execution_id=execution.id if execution else None,
            )
        )

    def join(self, timeout: float) -> None:
        """
        Wait for all submitted diagnostics to finish and record their results.

        Blocks until every submitted execution has completed. Each completed result is
        processed and committed to the database individually. Afterwards (or if an
        error is raised) the parsl DataFlowKernel is cleaned up.

        Parameters
        ----------
        timeout
            Accepted for interface compatibility with other executors but not used:
            HPC jobs may wait in the scheduler queue for an unpredictable time. A debug
            message is logged once the configured walltime has elapsed.

        Raises
        ------
        ExecutionError
            If an execution failed to run (its parsl future raised an exception)
        """
        start_time = time.time()
        refresh_time = 0.5
        walltime_seconds = self.total_minutes * 60
        walltime_logged = False

        results = self.parsl_results
        t = tqdm(total=len(results), desc="Waiting for executions to complete", unit="execution")

        try:
            while results:
                # Iterate over a copy of the list and remove finished tasks
                for result in results[:]:
                    if not result.future.done():
                        continue

                    try:
                        execution_result = result.future.result(timeout=0)
                    except Exception as e:
                        # Something went wrong when attempting to run the execution
                        # This is likely a failure in the execution itself not the diagnostic
                        raise ExecutionError(
                            f"Failed to execute {result.definition.execution_slug()!r}"
                        ) from e

                    assert execution_result is not None, "Execution result should not be None"
                    assert isinstance(execution_result, ExecutionResult), (
                        "Execution result should be of type ExecutionResult"
                    )

                    # Process the result in the main process
                    # The results should be committed after each execution
                    with self.database.session.begin():
                        execution = (
                            self.database.session.get(Execution, result.execution_id)
                            if result.execution_id
                            else None
                        )
                        process_result(self.config, self.database, execution_result, execution)
                    logger.debug(f"Execution completed: {result}")
                    t.update(n=1)
                    results.remove(result)

                # Break early to avoid waiting for one more sleep cycle
                if len(results) == 0:
                    break

                elapsed_time = time.time() - start_time
                if not walltime_logged and elapsed_time > walltime_seconds:
                    # Log once rather than on every polling cycle
                    logger.debug(f"Time elapsed {elapsed_time} for joining the results")
                    walltime_logged = True

                # Wait for a short time before checking for completed executions
                time.sleep(refresh_time)
        finally:
            t.close()
            if parsl.dfk():
                parsl.dfk().cleanup()
|
climate_ref/executor/local.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import concurrent.futures
|
|
2
|
+
import multiprocessing
|
|
2
3
|
import time
|
|
3
4
|
from concurrent.futures import Future, ProcessPoolExecutor
|
|
4
5
|
from typing import Any
|
|
@@ -13,7 +14,7 @@ from climate_ref.models import Execution
|
|
|
13
14
|
from climate_ref_core.diagnostics import ExecutionDefinition, ExecutionResult
|
|
14
15
|
from climate_ref_core.exceptions import ExecutionError
|
|
15
16
|
from climate_ref_core.executor import execute_locally
|
|
16
|
-
from climate_ref_core.logging import
|
|
17
|
+
from climate_ref_core.logging import initialise_logging
|
|
17
18
|
|
|
18
19
|
from .result_handling import handle_execution_result
|
|
19
20
|
|
|
@@ -63,11 +64,17 @@ class ExecutionFuture:
|
|
|
63
64
|
execution_id: int | None = None
|
|
64
65
|
|
|
65
66
|
|
|
66
|
-
def _process_initialiser() -> None:
|
|
67
|
+
def _process_initialiser() -> None: # pragma: no cover
|
|
67
68
|
# Setup the logging for the process
|
|
68
69
|
# This replaces the loguru default handler
|
|
69
70
|
try:
|
|
70
|
-
|
|
71
|
+
logger.remove()
|
|
72
|
+
config = Config.default()
|
|
73
|
+
initialise_logging(
|
|
74
|
+
level=config.log_level,
|
|
75
|
+
format=config.log_format,
|
|
76
|
+
log_directory=config.paths.log,
|
|
77
|
+
)
|
|
71
78
|
except Exception as e:
|
|
72
79
|
# Don't raise an exception here as that would kill the process pool
|
|
73
80
|
# We want to log the error and continue
|
|
@@ -118,7 +125,12 @@ class LocalExecutor:
|
|
|
118
125
|
if pool is not None:
|
|
119
126
|
self.pool = pool
|
|
120
127
|
else:
|
|
121
|
-
self.pool = ProcessPoolExecutor(
|
|
128
|
+
self.pool = ProcessPoolExecutor(
|
|
129
|
+
max_workers=n,
|
|
130
|
+
initializer=_process_initialiser,
|
|
131
|
+
# Explicitly set the context to "spawn" to avoid issues with hanging on MacOS
|
|
132
|
+
mp_context=multiprocessing.get_context("spawn"),
|
|
133
|
+
)
|
|
122
134
|
self._results: list[ExecutionFuture] = []
|
|
123
135
|
|
|
124
136
|
def run(
|
|
@@ -214,9 +226,17 @@ class LocalExecutor:
|
|
|
214
226
|
elapsed_time = time.time() - start_time
|
|
215
227
|
|
|
216
228
|
if elapsed_time > timeout:
|
|
229
|
+
for result in results:
|
|
230
|
+
logger.warning(
|
|
231
|
+
f"Execution {result.definition.execution_slug()} "
|
|
232
|
+
f"did not complete within the timeout"
|
|
233
|
+
)
|
|
234
|
+
self.pool.shutdown(wait=False, cancel_futures=True)
|
|
217
235
|
raise TimeoutError("Not all tasks completed within the specified timeout")
|
|
218
236
|
|
|
219
237
|
# Wait for a short time before checking for completed executions
|
|
220
238
|
time.sleep(refresh_time)
|
|
221
239
|
finally:
|
|
222
240
|
t.close()
|
|
241
|
+
|
|
242
|
+
logger.info("All executions completed successfully")
|
|
@@ -129,7 +129,6 @@ def handle_execution_result(
|
|
|
129
129
|
cv.validate_metrics(cmec_metric_bundle)
|
|
130
130
|
except (ResultValidationError, AssertionError):
|
|
131
131
|
logger.exception("Diagnostic values do not conform with the controlled vocabulary")
|
|
132
|
-
# TODO: Mark the diagnostic execution result as failed once the CV has stabilised
|
|
133
132
|
# execution.mark_failed()
|
|
134
133
|
|
|
135
134
|
# Perform a bulk insert of scalar values
|
climate_ref/slurm.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
# True when the optional ``pyslurm`` package can be located. The HPC executor only
# runs its Slurm parameter validation when this is True.
HAS_REAL_SLURM = importlib.util.find_spec("pyslurm") is not None
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SlurmChecker:
|
|
8
|
+
"""Check and get slurm settings."""
|
|
9
|
+
|
|
10
|
+
def __init__(self, intest: bool = False) -> None:
|
|
11
|
+
if HAS_REAL_SLURM:
|
|
12
|
+
import pyslurm # type: ignore
|
|
13
|
+
|
|
14
|
+
self.slurm_association: dict[int, Any] | None = pyslurm.db.Associations.load()
|
|
15
|
+
self.slurm_partition: dict[str, Any] | None = pyslurm.Partitions.load()
|
|
16
|
+
self.slurm_qos: dict[str, Any] | None = pyslurm.qos().get()
|
|
17
|
+
self.slurm_node: dict[str, Any] | None = pyslurm.Nodes.load()
|
|
18
|
+
elif intest:
|
|
19
|
+
import pyslurm
|
|
20
|
+
|
|
21
|
+
self.slurm_association = pyslurm.db.Associations.load() # dict [num -> Association]
|
|
22
|
+
self.slurm_partition = pyslurm.Partitions.load() # collection
|
|
23
|
+
self.slurm_qos = pyslurm.qos().get() # dict
|
|
24
|
+
self.slurm_node = pyslurm.Nodes.load() # dict
|
|
25
|
+
else:
|
|
26
|
+
print("Warning: pyslurm not found. Skipping HPCExecutor config validations")
|
|
27
|
+
self.slurm_association = None
|
|
28
|
+
self.slurm_partition = None
|
|
29
|
+
self.slurm_qos = None
|
|
30
|
+
self.slurm_node = None
|
|
31
|
+
|
|
32
|
+
def get_partition_info(self, partition_name: str) -> Any:
|
|
33
|
+
"""Check if a partition exists in the Slurm configuration."""
|
|
34
|
+
return self.slurm_partition.get(partition_name) if self.slurm_partition else None
|
|
35
|
+
|
|
36
|
+
def get_qos_info(self, qos_name: str) -> Any:
|
|
37
|
+
"""Check if a qos exists in the Slurm configuration."""
|
|
38
|
+
return self.slurm_qos.get(qos_name) if self.slurm_qos else None
|
|
39
|
+
|
|
40
|
+
def get_account_info(self, account_name: str) -> list[Any]:
|
|
41
|
+
"""Get all associations for an account"""
|
|
42
|
+
if self.slurm_association:
|
|
43
|
+
return [a for a in self.slurm_association.values() if a.account == account_name]
|
|
44
|
+
else:
|
|
45
|
+
return [None]
|
|
46
|
+
|
|
47
|
+
def can_account_use_partition(self, account_name: str, partition_name: str) -> bool:
|
|
48
|
+
"""
|
|
49
|
+
Check if an account has access to a specific partition.
|
|
50
|
+
|
|
51
|
+
Returns
|
|
52
|
+
-------
|
|
53
|
+
bool: True if accessible, False if not accessible or error occurred
|
|
54
|
+
"""
|
|
55
|
+
account_info = self.get_account_info(account_name)
|
|
56
|
+
if not account_info:
|
|
57
|
+
return False
|
|
58
|
+
|
|
59
|
+
partition_info = self.get_partition_info(partition_name)
|
|
60
|
+
|
|
61
|
+
if not partition_info:
|
|
62
|
+
return False
|
|
63
|
+
|
|
64
|
+
allowed_partitions = account_info[0].partition
|
|
65
|
+
if allowed_partitions is None:
|
|
66
|
+
return True
|
|
67
|
+
else:
|
|
68
|
+
return partition_name in allowed_partitions
|
|
69
|
+
|
|
70
|
+
def can_account_use_qos(self, account_name: str, qos_name: str) -> bool:
|
|
71
|
+
"""
|
|
72
|
+
Check if an account has access to a specific qos.
|
|
73
|
+
|
|
74
|
+
Returns
|
|
75
|
+
-------
|
|
76
|
+
bool: True if accessible, False if not accessible or error occurred
|
|
77
|
+
"""
|
|
78
|
+
account_info = self.get_account_info(account_name)
|
|
79
|
+
|
|
80
|
+
if not account_info:
|
|
81
|
+
return False
|
|
82
|
+
|
|
83
|
+
qos_info = self.get_qos_info(qos_name)
|
|
84
|
+
if not qos_info:
|
|
85
|
+
return False
|
|
86
|
+
|
|
87
|
+
sample_acc = account_info[0]
|
|
88
|
+
for acc in account_info:
|
|
89
|
+
if acc.user == "minxu":
|
|
90
|
+
sample_acc = acc
|
|
91
|
+
break
|
|
92
|
+
|
|
93
|
+
allowed_qoss = sample_acc.qos
|
|
94
|
+
if allowed_qoss is None:
|
|
95
|
+
return True
|
|
96
|
+
else:
|
|
97
|
+
return qos_name in allowed_qoss
|
|
98
|
+
|
|
99
|
+
def get_partition_limits(self, partition_name: str) -> dict[str, str | int] | None:
|
|
100
|
+
"""
|
|
101
|
+
Get time limits for a specific partition.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
Dict with 'max_time' and 'default_time' (strings or UNLIMITED)
|
|
106
|
+
or None if partition doesn't exist or error occurred
|
|
107
|
+
"""
|
|
108
|
+
partition_info = self.get_partition_info(partition_name)
|
|
109
|
+
if not partition_info:
|
|
110
|
+
return None
|
|
111
|
+
|
|
112
|
+
return {
|
|
113
|
+
"max_time_minutes": partition_info.to_dict().get("max_time", 0), # in minutes
|
|
114
|
+
"default_time_minutes": partition_info.to_dict().get("default_time", 30), # in minutes
|
|
115
|
+
"max_nodes": partition_info.to_dict().get("max_node", 1),
|
|
116
|
+
"total_nodes": partition_info.to_dict().get("total_nodes", 0),
|
|
117
|
+
"total_cpus": partition_info.to_dict().get("total_cpus", 0),
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
def get_node_from_partition(self, partition_name: str) -> dict[str, str | int] | None:
|
|
121
|
+
"""
|
|
122
|
+
Get the node information for a specific partition.
|
|
123
|
+
|
|
124
|
+
Returns
|
|
125
|
+
-------
|
|
126
|
+
Dicts
|
|
127
|
+
"""
|
|
128
|
+
partition_info = self.get_partition_info(partition_name)
|
|
129
|
+
if not partition_info:
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
sample_node = None
|
|
133
|
+
|
|
134
|
+
if self.slurm_node:
|
|
135
|
+
for node in self.slurm_node.values():
|
|
136
|
+
if partition_name in node.partitions and "cpu" in node.available_features:
|
|
137
|
+
sample_node = node
|
|
138
|
+
break
|
|
139
|
+
|
|
140
|
+
return {
|
|
141
|
+
"cpus": int(sample_node.total_cpus) if sample_node is not None else 1,
|
|
142
|
+
"cores_per_socket": int(sample_node.cores_per_socket) if sample_node is not None else 1,
|
|
143
|
+
"sockets": int(sample_node.sockets) if sample_node is not None else 1,
|
|
144
|
+
"threads_per_core": int(sample_node.threads_per_core) if sample_node is not None else 1,
|
|
145
|
+
"real_memory": int(sample_node.real_memory) if sample_node is not None else 215,
|
|
146
|
+
"node_names": sample_node.name if sample_node is not None else "unknown",
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
def get_qos_limits(self, qos_name: str) -> dict[str, str | int]:
|
|
150
|
+
"""
|
|
151
|
+
Get time limits for a specific qos.
|
|
152
|
+
|
|
153
|
+
Returns
|
|
154
|
+
-------
|
|
155
|
+
Dict with 'max_time' and 'default_time' (strings or UNLIMITED)
|
|
156
|
+
or None if partition doesn't exist or error occurred
|
|
157
|
+
"""
|
|
158
|
+
qos_info = self.get_qos_info(qos_name)
|
|
159
|
+
|
|
160
|
+
return {
|
|
161
|
+
"max_time_minutes": qos_info.get("max_wall_pj", 1.0e6),
|
|
162
|
+
"max_jobs_pu": qos_info.get("max_jobs_pu", 1.0e6),
|
|
163
|
+
"max_submit_jobs_pu": qos_info.get("max_submit_jobs_pu", 1.0e6),
|
|
164
|
+
"max_tres_pj": qos_info.get("max_tres_pj").split("=")[0],
|
|
165
|
+
"default_time_minutes": 120,
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
def check_account_partition_access_with_limits(
|
|
169
|
+
self, account_name: str, partition_name: str
|
|
170
|
+
) -> dict[str, Any]:
|
|
171
|
+
"""
|
|
172
|
+
Comprehensive check of account access and partition limits.
|
|
173
|
+
|
|
174
|
+
Returns dictionary with all relevant information.
|
|
175
|
+
"""
|
|
176
|
+
result = {
|
|
177
|
+
"account_exists": True if self.get_account_info(account_name) else False,
|
|
178
|
+
"partition_exists": True if self.get_partition_info(partition_name) else False,
|
|
179
|
+
"has_access": False,
|
|
180
|
+
"time_limits": None,
|
|
181
|
+
"error": "none",
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
try:
|
|
185
|
+
if result["account_exists"] and result["partition_exists"]:
|
|
186
|
+
result["has_access"] = self.can_account_use_partition(account_name, partition_name)
|
|
187
|
+
if result["has_access"]:
|
|
188
|
+
result["time_limits"] = self.get_partition_info(partition_name).to_dict().get("max_time")
|
|
189
|
+
except Exception as e:
|
|
190
|
+
result["error"] = str(e)
|
|
191
|
+
|
|
192
|
+
return result
|
climate_ref/solver.py
CHANGED
|
@@ -245,6 +245,57 @@ def _solve_from_data_requirements(
|
|
|
245
245
|
)
|
|
246
246
|
|
|
247
247
|
|
|
248
|
+
@define
class SolveFilterOptions:
    """
    Optional substring filters restricting which diagnostics the solver considers
    """

    diagnostic: list[str] | None = None
    """
    Values to look for in the diagnostic slug; a match on any one of them is enough
    """

    provider: list[str] | None = None
    """
    Values to look for in the provider slug; a match on any one of them is enough
    """
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def matches_filter(diagnostic: Diagnostic, filters: SolveFilterOptions | None) -> bool:
    """
    Check if a diagnostic matches the provided filters

    Each filter is optional and a diagnostic will match if it satisfies all the provided filters.
    i.e. the filters are ANDed together. Within a single filter the values are ORed: the
    diagnostic matches if any value is a case-insensitive substring of the relevant slug.

    Parameters
    ----------
    diagnostic
        Diagnostic to check against the filters
    filters
        Collection of filters to apply to the diagnostic

        If no filters are provided, the diagnostic is considered to match

    Returns
    -------
        True if the diagnostic matches the filters, False otherwise
    """
    if filters is None:
        return True

    # Lower-case both sides so matching is case-insensitive in either direction
    diagnostic_slug = diagnostic.slug.lower()
    provider_slug = diagnostic.provider.slug.lower()

    if filters.provider and not any(f.lower() in provider_slug for f in filters.provider):
        return False

    if filters.diagnostic and not any(f.lower() in diagnostic_slug for f in filters.diagnostic):
        return False

    return True
|
|
297
|
+
|
|
298
|
+
|
|
248
299
|
@define
|
|
249
300
|
class ExecutionSolver:
|
|
250
301
|
"""
|
|
@@ -278,7 +329,9 @@ class ExecutionSolver:
|
|
|
278
329
|
},
|
|
279
330
|
)
|
|
280
331
|
|
|
281
|
-
def solve(
|
|
332
|
+
def solve(
|
|
333
|
+
self, filters: SolveFilterOptions | None = None
|
|
334
|
+
) -> typing.Generator[DiagnosticExecution, None, None]:
|
|
282
335
|
"""
|
|
283
336
|
Solve which executions need to be calculated for a dataset
|
|
284
337
|
|
|
@@ -293,17 +346,23 @@ class ExecutionSolver:
|
|
|
293
346
|
"""
|
|
294
347
|
for provider in self.provider_registry.providers:
|
|
295
348
|
for diagnostic in provider.diagnostics():
|
|
349
|
+
# Filter the diagnostic based on the provided filters
|
|
350
|
+
if not matches_filter(diagnostic, filters):
|
|
351
|
+
logger.debug(f"Skipping {diagnostic.full_slug()} due to filter")
|
|
352
|
+
continue
|
|
296
353
|
yield from solve_executions(self.data_catalog, diagnostic, provider)
|
|
297
354
|
|
|
298
355
|
|
|
299
356
|
def solve_required_executions( # noqa: PLR0913
|
|
300
357
|
db: Database,
|
|
301
358
|
dry_run: bool = False,
|
|
359
|
+
execute: bool = True,
|
|
302
360
|
solver: ExecutionSolver | None = None,
|
|
303
361
|
config: Config | None = None,
|
|
304
362
|
timeout: int = 60,
|
|
305
363
|
one_per_provider: bool = False,
|
|
306
364
|
one_per_diagnostic: bool = False,
|
|
365
|
+
filters: SolveFilterOptions | None = None,
|
|
307
366
|
) -> None:
|
|
308
367
|
"""
|
|
309
368
|
Solve for executions that require recalculation
|
|
@@ -328,7 +387,7 @@ def solve_required_executions( # noqa: PLR0913
|
|
|
328
387
|
diagnostic_count = {}
|
|
329
388
|
provider_count = {}
|
|
330
389
|
|
|
331
|
-
for potential_execution in solver.solve():
|
|
390
|
+
for potential_execution in solver.solve(filters):
|
|
332
391
|
# The diagnostic output is first written to the scratch directory
|
|
333
392
|
definition = potential_execution.build_execution_definition(output_root=config.paths.scratch)
|
|
334
393
|
|
|
@@ -371,6 +430,7 @@ def solve_required_executions( # noqa: PLR0913
|
|
|
371
430
|
logger.info(f"Created new execution group: {potential_execution.execution_slug()!r}")
|
|
372
431
|
db.session.flush()
|
|
373
432
|
|
|
433
|
+
# TODO: Move this logic to the solver
|
|
374
434
|
# Check if we should run given the one_per_provider or one_per_diagnostic flags
|
|
375
435
|
one_of_check_failed = (
|
|
376
436
|
one_per_provider and provider_count.get(diagnostic.provider.slug, 0) > 0
|
|
@@ -403,10 +463,11 @@ def solve_required_executions( # noqa: PLR0913
|
|
|
403
463
|
# Add links to the datasets used in the execution
|
|
404
464
|
execution.register_datasets(db, definition.datasets)
|
|
405
465
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
466
|
+
if execute:
|
|
467
|
+
executor.run(
|
|
468
|
+
definition=definition,
|
|
469
|
+
execution=execution,
|
|
470
|
+
)
|
|
410
471
|
|
|
411
472
|
provider_count[diagnostic.provider.slug] += 1
|
|
412
473
|
diagnostic_count[diagnostic.full_slug()] += 1
|