climate-ref 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,308 @@
+ """
+ HPC-based executor that submits work through a job scheduler.
+
+ Use this executor if you want to:
+ - run REF under HPC workflows
+ - run REF across multiple nodes
+
+ """
+
+ import os
+ import time
+ from typing import Any, cast
+
+ import parsl
+ from loguru import logger
+ from parsl import python_app
+ from parsl.config import Config as ParslConfig
+ from parsl.executors import HighThroughputExecutor
+ from parsl.launchers import SrunLauncher
+ from parsl.providers import SlurmProvider
+ from tqdm import tqdm
+
+ from climate_ref.config import Config
+ from climate_ref.database import Database
+ from climate_ref.models import Execution
+ from climate_ref.slurm import HAS_REAL_SLURM, SlurmChecker
+ from climate_ref_core.diagnostics import ExecutionDefinition, ExecutionResult
+ from climate_ref_core.exceptions import DiagnosticError, ExecutionError
+ from climate_ref_core.executor import execute_locally
+
+ from .local import ExecutionFuture, process_result
+
+
+ @python_app
+ def _process_run(definition: ExecutionDefinition, log_level: str) -> ExecutionResult:
+     """Run the function on compute nodes"""
+     # This is a catch-all for any exceptions that occur in the process; they need to be
+     # re-raised for parsl retries to work
+     try:
+         return execute_locally(definition=definition, log_level=log_level, raise_error=True)
+     except DiagnosticError as e:  # pragma: no cover
+         # any diagnostic error will be caught here
+         logger.exception("Error running diagnostic")
+         return cast(ExecutionResult, e.result)
+
+
+ def _to_float(x: Any) -> float | None:
+     if x is None:
+         return None
+     if isinstance(x, int | float):
+         return float(x)
+     try:
+         return float(x)
+     except (ValueError, TypeError):
+         return None
+
+
+ def _to_int(x: Any) -> int | None:
+     if x is None:
+         return None
+     if isinstance(x, int):
+         return x
+     try:
+         return int(float(x))  # Handles both "123" and "123.0"
+     except (ValueError, TypeError):
+         return None
+
+
+ class HPCExecutor:
+     """
+     Run diagnostics by submitting a job script
+     """
+
+     name = "hpc"
+
+     def __init__(
+         self,
+         *,
+         database: Database | None = None,
+         config: Config | None = None,
+         **executor_config: str | float | int,
+     ) -> None:
+         config = config or Config.default()
+         database = database or Database.from_config(config, run_migrations=False)
+
+         self.config = config
+         self.database = database
+
+         self.scheduler = executor_config.get("scheduler", "slurm")
+         self.account = str(executor_config.get("account", os.environ.get("USER")))
+         self.username = executor_config.get("username", os.environ.get("USER"))
+         self.partition = str(executor_config.get("partition")) if executor_config.get("partition") else None
+         self.qos = str(executor_config.get("qos")) if executor_config.get("qos") else None
+         self.req_nodes = int(executor_config.get("req_nodes", 1))
+         self.walltime = str(executor_config.get("walltime", "00:10:00"))
+         self.log_dir = str(executor_config.get("log_dir", "runinfo"))
+
+         self.cores_per_worker = _to_int(executor_config.get("cores_per_worker"))
+         self.mem_per_worker = _to_float(executor_config.get("mem_per_worker"))
+
+         hours, minutes, seconds = map(int, self.walltime.split(":"))
+         total_minutes = hours * 60 + minutes + seconds / 60
+         self.total_minutes = total_minutes
+
+         if executor_config.get("validation") and HAS_REAL_SLURM:
+             self._validate_slurm_params()
+
+         self._initialize_parsl()
+
+         self.parsl_results: list[ExecutionFuture] = []
+
+     def _validate_slurm_params(self) -> None:
+         """Validate the Slurm configuration using SlurmChecker.
+
+         Raises
+         ------
+         ValueError
+             If the account, partition or QOS is invalid or inaccessible.
+         """
+         slurm_checker = SlurmChecker()
+         if self.account and not slurm_checker.get_account_info(self.account):
+             raise ValueError(f"Account: {self.account} not valid")
+
+         partition_limits = None
+         node_info = None
+
+         if self.partition:
+             if not slurm_checker.get_partition_info(self.partition):
+                 raise ValueError(f"Partition: {self.partition} not valid")
+
+             if not slurm_checker.can_account_use_partition(self.account, self.partition):
+                 raise ValueError(f"Account: {self.account} cannot access partition: {self.partition}")
+
+             partition_limits = slurm_checker.get_partition_limits(self.partition)
+             node_info = slurm_checker.get_node_from_partition(self.partition)
+
+         qos_limits = None
+         if self.qos:
+             if not slurm_checker.get_qos_info(self.qos):
+                 raise ValueError(f"QOS: {self.qos} not valid")
+
+             if not slurm_checker.can_account_use_qos(self.account, self.qos):
+                 raise ValueError(f"Account: {self.account} cannot access qos: {self.qos}")
+
+             qos_limits = slurm_checker.get_qos_limits(self.qos)
+
+         max_cores_per_node = int(node_info["cpus"]) if node_info else None
+         if max_cores_per_node and self.cores_per_worker:
+             if self.cores_per_worker > max_cores_per_node:
+                 raise ValueError(
+                     f"cores_per_worker: {self.cores_per_worker} "
+                     f"is larger than the maximum number of cores in a node ({max_cores_per_node})"
+                 )
+
+         max_mem_per_node = float(node_info["real_memory"]) if node_info else None
+         if max_mem_per_node and self.mem_per_worker:
+             if self.mem_per_worker > max_mem_per_node:
+                 raise ValueError(
+                     f"mem_per_worker: {self.mem_per_worker} "
+                     f"is larger than the maximum memory in a node ({max_mem_per_node})"
+                 )
+
+         max_walltime_partition = (
+             partition_limits["max_time_minutes"] if partition_limits else self.total_minutes
+         )
+         max_walltime_qos = qos_limits["max_time_minutes"] if qos_limits else self.total_minutes
+
+         max_walltime_minutes = min(float(max_walltime_partition), float(max_walltime_qos))
+
+         if self.total_minutes > float(max_walltime_minutes):
+             raise ValueError(
+                 f"Walltime: {self.walltime} exceeds the maximum time "
+                 f"{max_walltime_minutes} allowed by {self.partition} and {self.qos}"
+             )
+
+     def _initialize_parsl(self) -> None:
+         executor_config = self.config.executor.config
+
+         provider = SlurmProvider(
+             account=self.account,
+             partition=self.partition,
+             qos=self.qos,
+             nodes_per_block=self.req_nodes,
+             max_blocks=int(executor_config.get("max_blocks", 1)),
+             scheduler_options=executor_config.get("scheduler_options", "#SBATCH -C cpu"),
+             worker_init=executor_config.get("worker_init", "source .venv/bin/activate"),
+             launcher=SrunLauncher(
+                 debug=True,
+                 overrides=executor_config.get("overrides", ""),
+             ),
+             walltime=self.walltime,
+             cmd_timeout=int(executor_config.get("cmd_timeout", 120)),
+         )
+         executor = HighThroughputExecutor(
+             label="ref_hpc_executor",
+             cores_per_worker=self.cores_per_worker if self.cores_per_worker else 1,
+             mem_per_worker=self.mem_per_worker,
+             max_workers_per_node=_to_int(executor_config.get("max_workers_per_node", 16)),
+             cpu_affinity=str(executor_config.get("cpu_affinity")),
+             provider=provider,
+         )
+
+         hpc_config = ParslConfig(
+             run_dir=self.log_dir, executors=[executor], retries=int(executor_config.get("retries", 2))
+         )
+         parsl.load(hpc_config)
+
+     def run(
+         self,
+         definition: ExecutionDefinition,
+         execution: Execution | None = None,
+     ) -> None:
+         """
+         Submit a diagnostic execution to the job scheduler
+
+         Parameters
+         ----------
+         definition
+             A description of the information needed for this execution of the diagnostic
+         execution
+             A database model representing the execution of the diagnostic.
+             If provided, the result will be updated in the database when completed.
+         """
+         # Submit the execution to Parsl
+         # and track the future so we can wait for it to complete
+         future = _process_run(
+             definition=definition,
+             log_level=self.config.log_level,
+         )
+
+         self.parsl_results.append(
+             ExecutionFuture(
+                 future=future,
+                 definition=definition,
+                 execution_id=execution.id if execution else None,
+             )
+         )
+
+     def join(self, timeout: float) -> None:
+         """
+         Wait for all diagnostics to finish
+
+         This will block until all diagnostics have completed.
+
+         Parameters
+         ----------
+         timeout
+             Timeout in seconds (not used by the HPCExecutor)
+
+         Raises
+         ------
+         ExecutionError
+             If an execution fails with an unexpected error
+         """
+         start_time = time.time()
+         refresh_time = 0.5
+
+         results = self.parsl_results
+         t = tqdm(total=len(results), desc="Waiting for executions to complete", unit="execution")
+
+         try:
+             while results:
+                 # Iterate over a copy of the list and remove finished tasks
+                 for result in results[:]:
+                     if result.future.done():
+                         try:
+                             execution_result = result.future.result(timeout=0)
+                         except Exception as e:
+                             # Something went wrong when attempting to run the execution
+                             # This is likely a failure in the execution itself not the diagnostic
+                             raise ExecutionError(
+                                 f"Failed to execute {result.definition.execution_slug()!r}"
+                             ) from e
+
+                         assert execution_result is not None, "Execution result should not be None"
+                         assert isinstance(execution_result, ExecutionResult), (
+                             "Execution result should be of type ExecutionResult"
+                         )
+
+                         # Process the result in the main process
+                         # The results should be committed after each execution
+                         with self.database.session.begin():
+                             execution = (
+                                 self.database.session.get(Execution, result.execution_id)
+                                 if result.execution_id
+                                 else None
+                             )
+                             process_result(self.config, self.database, execution_result, execution)
+                         logger.debug(f"Execution completed: {result}")
+                         t.update(n=1)
+                         results.remove(result)
+
+                 # Break early to avoid waiting for one more sleep cycle
+                 if len(results) == 0:
+                     break
+
+                 elapsed_time = time.time() - start_time
+
+                 if elapsed_time > self.total_minutes * 60:
+                     logger.debug(f"Time elapsed {elapsed_time} for joining the results")
+
+                 # Wait for a short time before checking for completed executions
+                 time.sleep(refresh_time)
+         finally:
+             t.close()
+             if parsl.dfk():
+                 parsl.dfk().cleanup()
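For orientation, a minimal sketch of driving the new executor directly follows. The import path, account, partition and QOS names are assumptions for illustration only; they are not taken from this diff, and the keyword arguments simply mirror the `executor_config` entries read in `HPCExecutor.__init__` above.

```python
# Hypothetical usage sketch of the new HPC executor (illustrative names only).
from climate_ref.config import Config
from climate_ref.database import Database
from climate_ref.executor.hpc import HPCExecutor  # assumed module path

config = Config.default()
database = Database.from_config(config, run_migrations=False)

executor = HPCExecutor(
    config=config,
    database=database,
    account="m0000",        # hypothetical Slurm account
    partition="cpu",        # hypothetical partition
    qos="regular",          # hypothetical QOS
    req_nodes=2,
    walltime="01:30:00",
    cores_per_worker=4,
    validation=True,        # only validated when pyslurm is importable
)

# Instantiation calls parsl.load(), so build the executor once per process, then:
# executor.run(definition=..., execution=...)   # queue each execution
# executor.join(timeout=0)                      # block until the Parsl futures finish
```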
@@ -1,4 +1,5 @@
  import concurrent.futures
+ import multiprocessing
  import time
  from concurrent.futures import Future, ProcessPoolExecutor
  from typing import Any
@@ -13,7 +14,7 @@ from climate_ref.models import Execution
  from climate_ref_core.diagnostics import ExecutionDefinition, ExecutionResult
  from climate_ref_core.exceptions import ExecutionError
  from climate_ref_core.executor import execute_locally
- from climate_ref_core.logging import add_log_handler
+ from climate_ref_core.logging import initialise_logging

  from .result_handling import handle_execution_result

@@ -63,11 +64,17 @@ class ExecutionFuture:
      execution_id: int | None = None


- def _process_initialiser() -> None:
+ def _process_initialiser() -> None:  # pragma: no cover
      # Setup the logging for the process
      # This replaces the loguru default handler
      try:
-         add_log_handler()
+         logger.remove()
+         config = Config.default()
+         initialise_logging(
+             level=config.log_level,
+             format=config.log_format,
+             log_directory=config.paths.log,
+         )
      except Exception as e:
          # Don't raise an exception here as that would kill the process pool
          # We want to log the error and continue
@@ -118,7 +125,12 @@ class LocalExecutor:
          if pool is not None:
              self.pool = pool
          else:
-             self.pool = ProcessPoolExecutor(max_workers=n, initializer=_process_initialiser)
+             self.pool = ProcessPoolExecutor(
+                 max_workers=n,
+                 initializer=_process_initialiser,
+                 # Explicitly set the context to "spawn" to avoid issues with hanging on MacOS
+                 mp_context=multiprocessing.get_context("spawn"),
+             )
          self._results: list[ExecutionFuture] = []

      def run(
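The change above pins the pool to the "spawn" start method, so worker processes re-import the module instead of forking the parent. A standalone sketch of the same ProcessPoolExecutor pattern is shown below; the names are illustrative and not climate-ref APIs.

```python
# Standalone sketch: "spawn" context plus an importable initializer.
import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def _init_worker() -> None:
    # Runs once in each freshly spawned worker; with "spawn" nothing is
    # inherited from the parent's memory (unlike "fork"), so any per-process
    # setup (logging, config) must happen here.
    pass


def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    with ProcessPoolExecutor(
        max_workers=2,
        initializer=_init_worker,
        mp_context=multiprocessing.get_context("spawn"),
    ) as pool:
        print(list(pool.map(square, range(4))))
```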
@@ -214,9 +226,17 @@ class LocalExecutor:
                  elapsed_time = time.time() - start_time

                  if elapsed_time > timeout:
+                     for result in results:
+                         logger.warning(
+                             f"Execution {result.definition.execution_slug()} "
+                             f"did not complete within the timeout"
+                         )
+                     self.pool.shutdown(wait=False, cancel_futures=True)
                      raise TimeoutError("Not all tasks completed within the specified timeout")

                  # Wait for a short time before checking for completed executions
                  time.sleep(refresh_time)
          finally:
              t.close()
+
+         logger.info("All executions completed successfully")
@@ -129,7 +129,6 @@ def handle_execution_result(
          cv.validate_metrics(cmec_metric_bundle)
      except (ResultValidationError, AssertionError):
          logger.exception("Diagnostic values do not conform with the controlled vocabulary")
-         # TODO: Mark the diagnostic execution result as failed once the CV has stabilised
          # execution.mark_failed()

      # Perform a bulk insert of scalar values
climate_ref/slurm.py ADDED
@@ -0,0 +1,192 @@
+ import importlib.util
+ from typing import Any
+
+ HAS_REAL_SLURM = importlib.util.find_spec("pyslurm") is not None
+
+
+ class SlurmChecker:
+     """Check and get slurm settings."""
+
+     def __init__(self, intest: bool = False) -> None:
+         if HAS_REAL_SLURM:
+             import pyslurm  # type: ignore
+
+             self.slurm_association: dict[int, Any] | None = pyslurm.db.Associations.load()
+             self.slurm_partition: dict[str, Any] | None = pyslurm.Partitions.load()
+             self.slurm_qos: dict[str, Any] | None = pyslurm.qos().get()
+             self.slurm_node: dict[str, Any] | None = pyslurm.Nodes.load()
+         elif intest:
+             import pyslurm
+
+             self.slurm_association = pyslurm.db.Associations.load()  # dict [num -> Association]
+             self.slurm_partition = pyslurm.Partitions.load()  # collection
+             self.slurm_qos = pyslurm.qos().get()  # dict
+             self.slurm_node = pyslurm.Nodes.load()  # dict
+         else:
+             print("Warning: pyslurm not found. Skipping HPCExecutor config validations")
+             self.slurm_association = None
+             self.slurm_partition = None
+             self.slurm_qos = None
+             self.slurm_node = None
+
+     def get_partition_info(self, partition_name: str) -> Any:
+         """Return the partition info if it exists in the Slurm configuration, otherwise None."""
+         return self.slurm_partition.get(partition_name) if self.slurm_partition else None
+
+     def get_qos_info(self, qos_name: str) -> Any:
+         """Return the QOS info if it exists in the Slurm configuration, otherwise None."""
+         return self.slurm_qos.get(qos_name) if self.slurm_qos else None
+
+     def get_account_info(self, account_name: str) -> list[Any]:
+         """Get all associations for an account"""
+         if self.slurm_association:
+             return [a for a in self.slurm_association.values() if a.account == account_name]
+         else:
+             return [None]
+
+     def can_account_use_partition(self, account_name: str, partition_name: str) -> bool:
+         """
+         Check if an account has access to a specific partition.
+
+         Returns
+         -------
+         bool
+             True if accessible, False if not accessible or an error occurred
+         """
+         account_info = self.get_account_info(account_name)
+         if not account_info:
+             return False
+
+         partition_info = self.get_partition_info(partition_name)
+
+         if not partition_info:
+             return False
+
+         allowed_partitions = account_info[0].partition
+         if allowed_partitions is None:
+             return True
+         else:
+             return partition_name in allowed_partitions
+
+     def can_account_use_qos(self, account_name: str, qos_name: str) -> bool:
+         """
+         Check if an account has access to a specific qos.
+
+         Returns
+         -------
+         bool
+             True if accessible, False if not accessible or an error occurred
+         """
+         account_info = self.get_account_info(account_name)
+
+         if not account_info:
+             return False
+
+         qos_info = self.get_qos_info(qos_name)
+         if not qos_info:
+             return False
+
+         # Prefer the association of a specific (hardcoded) user when one is present
+         sample_acc = account_info[0]
+         for acc in account_info:
+             if acc.user == "minxu":
+                 sample_acc = acc
+                 break
+
+         allowed_qoss = sample_acc.qos
+         if allowed_qoss is None:
+             return True
+         else:
+             return qos_name in allowed_qoss
+
+     def get_partition_limits(self, partition_name: str) -> dict[str, str | int] | None:
+         """
+         Get time and size limits for a specific partition.
+
+         Returns
+         -------
+         Dict with 'max_time_minutes', 'default_time_minutes', 'max_nodes', 'total_nodes'
+         and 'total_cpus', or None if the partition doesn't exist or an error occurred
+         """
+         partition_info = self.get_partition_info(partition_name)
+         if not partition_info:
+             return None
+
+         return {
+             "max_time_minutes": partition_info.to_dict().get("max_time", 0),  # in minutes
+             "default_time_minutes": partition_info.to_dict().get("default_time", 30),  # in minutes
+             "max_nodes": partition_info.to_dict().get("max_node", 1),
+             "total_nodes": partition_info.to_dict().get("total_nodes", 0),
+             "total_cpus": partition_info.to_dict().get("total_cpus", 0),
+         }
+
+     def get_node_from_partition(self, partition_name: str) -> dict[str, str | int] | None:
+         """
+         Get the node information for a specific partition.
+
+         Returns
+         -------
+         Dict describing a sample node in the partition (CPUs, sockets, memory and name),
+         or None if the partition doesn't exist
+         """
+         partition_info = self.get_partition_info(partition_name)
+         if not partition_info:
+             return None
+
+         sample_node = None
+
+         if self.slurm_node:
+             for node in self.slurm_node.values():
+                 if partition_name in node.partitions and "cpu" in node.available_features:
+                     sample_node = node
+                     break
+
+         # Fall back to conservative defaults when no matching node is found
+         return {
+             "cpus": int(sample_node.total_cpus) if sample_node is not None else 1,
+             "cores_per_socket": int(sample_node.cores_per_socket) if sample_node is not None else 1,
+             "sockets": int(sample_node.sockets) if sample_node is not None else 1,
+             "threads_per_core": int(sample_node.threads_per_core) if sample_node is not None else 1,
+             "real_memory": int(sample_node.real_memory) if sample_node is not None else 215,
+             "node_names": sample_node.name if sample_node is not None else "unknown",
+         }
+
+     def get_qos_limits(self, qos_name: str) -> dict[str, str | int]:
+         """
+         Get time and job limits for a specific qos.
+
+         Returns
+         -------
+         Dict with 'max_time_minutes', 'max_jobs_pu', 'max_submit_jobs_pu', 'max_tres_pj'
+         and 'default_time_minutes'
+         """
+         qos_info = self.get_qos_info(qos_name)
+
+         return {
+             "max_time_minutes": qos_info.get("max_wall_pj", 1.0e6),
+             "max_jobs_pu": qos_info.get("max_jobs_pu", 1.0e6),
+             "max_submit_jobs_pu": qos_info.get("max_submit_jobs_pu", 1.0e6),
+             "max_tres_pj": qos_info.get("max_tres_pj").split("=")[0],
+             "default_time_minutes": 120,
+         }
+
+     def check_account_partition_access_with_limits(
+         self, account_name: str, partition_name: str
+     ) -> dict[str, Any]:
+         """
+         Comprehensive check of account access and partition limits.
+
+         Returns a dictionary with all relevant information.
+         """
+         result = {
+             "account_exists": bool(self.get_account_info(account_name)),
+             "partition_exists": bool(self.get_partition_info(partition_name)),
+             "has_access": False,
+             "time_limits": None,
+             "error": "none",
+         }
+
+         try:
+             if result["account_exists"] and result["partition_exists"]:
+                 result["has_access"] = self.can_account_use_partition(account_name, partition_name)
+                 if result["has_access"]:
+                     result["time_limits"] = self.get_partition_info(partition_name).to_dict().get("max_time")
+         except Exception as e:
+             result["error"] = str(e)
+
+         return result
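A minimal sketch of how SlurmChecker might be used as a pre-flight check before configuring the HPC executor follows; the account and partition names are placeholders, and pyslurm must be installed for the checks to do anything.

```python
# Hypothetical pre-flight check using SlurmChecker (placeholder account/partition names).
from climate_ref.slurm import HAS_REAL_SLURM, SlurmChecker

if HAS_REAL_SLURM:
    checker = SlurmChecker()
    report = checker.check_account_partition_access_with_limits("m0000", "cpu")
    if not report["has_access"]:
        raise SystemExit(f"Cannot submit to partition 'cpu': {report['error']}")
    limits = checker.get_partition_limits("cpu")
    print(f"Partition walltime limit: {limits['max_time_minutes']} minutes")
else:
    print("pyslurm not installed; skipping Slurm validation")
```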
climate_ref/solver.py CHANGED
@@ -245,6 +245,57 @@ def _solve_from_data_requirements(
      )


+ @define
+ class SolveFilterOptions:
+     """
+     Options to filter the diagnostics that are solved
+     """
+
+     diagnostic: list[str] | None = None
+     """
+     Check if the diagnostic slug contains any of the provided values
+     """
+     provider: list[str] | None = None
+     """
+     Check if the provider slug contains any of the provided values
+     """
+
+
+ def matches_filter(diagnostic: Diagnostic, filters: SolveFilterOptions | None) -> bool:
+     """
+     Check if a diagnostic matches the provided filters
+
+     Each filter is optional and a diagnostic will match if it satisfies all the provided filters.
+     i.e. the filters are ANDed together.
+
+     Parameters
+     ----------
+     diagnostic
+         Diagnostic to check against the filters
+     filters
+         Collection of filters to apply to the diagnostic
+
+         If no filters are provided, the diagnostic is considered to match
+
+     Returns
+     -------
+     True if the diagnostic matches the filters, False otherwise
+     """
+     if filters is None:
+         return True
+
+     diagnostic_slug = diagnostic.slug
+     provider_slug = diagnostic.provider.slug
+
+     if filters.provider and not any(f.lower() in provider_slug for f in filters.provider):
+         return False
+
+     if filters.diagnostic and not any(f.lower() in diagnostic_slug for f in filters.diagnostic):
+         return False
+
+     return True
+
+
  @define
  class ExecutionSolver:
      """
@@ -278,7 +329,9 @@ class ExecutionSolver:
          },
      )

-     def solve(self) -> typing.Generator[DiagnosticExecution, None, None]:
+     def solve(
+         self, filters: SolveFilterOptions | None = None
+     ) -> typing.Generator[DiagnosticExecution, None, None]:
          """
          Solve which executions need to be calculated for a dataset

@@ -293,17 +346,23 @@ class ExecutionSolver:
          """
          for provider in self.provider_registry.providers:
              for diagnostic in provider.diagnostics():
+                 # Filter the diagnostic based on the provided filters
+                 if not matches_filter(diagnostic, filters):
+                     logger.debug(f"Skipping {diagnostic.full_slug()} due to filter")
+                     continue
                  yield from solve_executions(self.data_catalog, diagnostic, provider)


  def solve_required_executions(  # noqa: PLR0913
      db: Database,
      dry_run: bool = False,
+     execute: bool = True,
      solver: ExecutionSolver | None = None,
      config: Config | None = None,
      timeout: int = 60,
      one_per_provider: bool = False,
      one_per_diagnostic: bool = False,
+     filters: SolveFilterOptions | None = None,
  ) -> None:
      """
      Solve for executions that require recalculation
@@ -328,7 +387,7 @@ def solve_required_executions(  # noqa: PLR0913
      diagnostic_count = {}
      provider_count = {}

-     for potential_execution in solver.solve():
+     for potential_execution in solver.solve(filters):
          # The diagnostic output is first written to the scratch directory
          definition = potential_execution.build_execution_definition(output_root=config.paths.scratch)

@@ -371,6 +430,7 @@ def solve_required_executions(  # noqa: PLR0913
              logger.info(f"Created new execution group: {potential_execution.execution_slug()!r}")
              db.session.flush()

+         # TODO: Move this logic to the solver
          # Check if we should run given the one_per_provider or one_per_diagnostic flags
          one_of_check_failed = (
              one_per_provider and provider_count.get(diagnostic.provider.slug, 0) > 0
@@ -403,10 +463,11 @@ def solve_required_executions(  # noqa: PLR0913
              # Add links to the datasets used in the execution
              execution.register_datasets(db, definition.datasets)

-             executor.run(
-                 definition=definition,
-                 execution=execution,
-             )
+             if execute:
+                 executor.run(
+                     definition=definition,
+                     execution=execution,
+                 )

          provider_count[diagnostic.provider.slug] += 1
          diagnostic_count[diagnostic.full_slug()] += 1
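A minimal sketch of how a caller might combine the new filters and execute arguments when solving is shown below; the provider and diagnostic substrings are illustrative, and the solver is left to be built internally from its defaults.

```python
# Sketch of using the new filter options when solving (illustrative substrings).
from climate_ref.config import Config
from climate_ref.database import Database
from climate_ref.solver import SolveFilterOptions, solve_required_executions

config = Config.default()
db = Database.from_config(config, run_migrations=False)

filters = SolveFilterOptions(
    provider=["esmvaltool"],   # keep providers whose slug contains this substring
    diagnostic=["enso"],       # keep diagnostics whose slug contains this substring
)

# Register the matching execution groups without actually running the diagnostics
solve_required_executions(
    db,
    config=config,
    execute=False,
    filters=filters,
)
```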