climate-ref 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to their public registries. It is provided for informational purposes only.
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Protocol, cast
+from typing import Any, Protocol, cast

 import pandas as pd
 from loguru import logger
@@ -35,6 +35,31 @@ def _log_duplicate_metadata(
     )


+class DatasetParsingFunction(Protocol):
+    """
+    Protocol for a function that parses metadata from a file or directory
+    """
+
+    def __call__(self, file: str, **kwargs: Any) -> dict[str, Any]:
+        """
+        Parse a file or directory and return metadata for the dataset
+
+        Parameters
+        ----------
+        file
+            File or directory to parse
+
+        kwargs
+            Additional keyword arguments to pass to the parsing function.
+
+        Returns
+        -------
+        :
+            Data catalog containing the metadata for the dataset
+        """
+        ...
+
+
 class DatasetAdapter(Protocol):
     """
     An adapter to provide a common interface for different dataset types
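Any callable with this signature satisfies the protocol. For orientation, a minimal conforming parser could look like the sketch below (a hypothetical example, not part of the package; only the signature is prescribed by `DatasetParsingFunction`):

    from typing import Any

    def parse_minimal(file: str, **kwargs: Any) -> dict[str, Any]:
        # A toy parser: derive metadata from the path alone.
        # Real parsers also populate the adapter's metadata columns.
        return {"path": file}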
@@ -173,7 +198,7 @@ class DatasetAdapter(Protocol):
         slug = unique_slugs[0]

         dataset_metadata = data_catalog_dataset[list(self.dataset_specific_metadata)].iloc[0].to_dict()
-        dataset, created = db.get_or_create(DatasetModel, slug=slug, **dataset_metadata)
+        dataset, created = db.get_or_create(DatasetModel, defaults=dataset_metadata, slug=slug)
         if not created:
             logger.warning(f"{dataset} already exists in the database. Skipping")
             return None
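The switch to `defaults=` matters because, in the common get-or-create pattern, the keyword filters identify the row while `defaults` is only applied when a new row has to be created, so the dataset metadata no longer participates in the lookup. A rough sketch of that pattern, assuming `db.get_or_create` follows the usual SQLAlchemy-style implementation (the actual implementation is outside this diff):

    def get_or_create(session, model, defaults=None, **filters):
        # The filter columns (here: slug) identify the row
        instance = session.query(model).filter_by(**filters).one_or_none()
        if instance is not None:
            return instance, False
        # The remaining metadata is only applied when creating a new row
        instance = model(**filters, **(defaults or {}))
        session.add(instance)
        return instance, True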
@@ -212,6 +237,7 @@ class DatasetAdapter(Protocol):
                 {
                     **{k: getattr(file, k) for k in self.file_specific_metadata},
                     **{k: getattr(file.dataset, k) for k in self.dataset_specific_metadata},
+                    "finalised": file.dataset.finalised,
                 }
                 for file in result
             ],
@@ -1,18 +1,17 @@
 from __future__ import annotations

-import traceback
 import warnings
 from datetime import datetime
 from pathlib import Path
 from typing import Any

 import pandas as pd
-import xarray as xr
 from ecgtools import Builder
-from ecgtools.parsers.utilities import extract_attr_with_regex  # type: ignore
 from loguru import logger

-from climate_ref.datasets.base import DatasetAdapter
+from climate_ref.config import Config
+from climate_ref.datasets.base import DatasetAdapter, DatasetParsingFunction
+from climate_ref.datasets.cmip6_parsers import parse_cmip6_complete, parse_cmip6_drs
 from climate_ref.models.dataset import CMIP6Dataset


@@ -22,16 +21,19 @@ def _parse_datetime(dt_str: pd.Series[str]) -> pd.Series[datetime | Any]:
     """

     def _inner(date_string: str | None) -> datetime | None:
-        if not date_string:
+        if not date_string or pd.isnull(date_string):
             return None

         # Try to parse the date string with and without milliseconds
-        try:
-            dt = datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
-        except ValueError:
-            dt = datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
+        for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"):
+            try:
+                return datetime.strptime(date_string, fmt)
+            except ValueError:
+                continue

-        return dt
+        # If all parsing attempts fail, log an error and return None
+        logger.error(f"Failed to parse date string: {date_string}")
+        return None

     return pd.Series(
         [_inner(dt) for dt in dt_str],
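The rewritten helper now accepts date-only strings, such as the estimated start and end dates produced by the DRS parser, in addition to the two timestamp formats. A self-contained sketch of the same format-cascade logic with a few worked inputs:

    from datetime import datetime

    formats = ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f")

    def parse_first_match(value: str):
        # Mirrors the loop above: the first matching format wins, None if all fail
        for fmt in formats:
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        return None

    parse_first_match("2014-12-16")           # datetime(2014, 12, 16, 0, 0)
    parse_first_match("2014-12-16 12:00:00")  # datetime(2014, 12, 16, 12, 0)
    parse_first_match("not-a-date")           # None (the real helper also logs an error)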
@@ -44,15 +46,16 @@ def _apply_fixes(data_catalog: pd.DataFrame) -> pd.DataFrame:
     def _fix_parent_variant_label(group: pd.DataFrame) -> pd.DataFrame:
         if group["parent_variant_label"].nunique() == 1:
             return group
-        group["parent_variant_label"] = group["variant_label"].iloc[0]
+        group["parent_variant_label"] = group["parent_variant_label"].iloc[0]

         return group

-    data_catalog = (
-        data_catalog.groupby("instance_id")
-        .apply(_fix_parent_variant_label, include_groups=False)
-        .reset_index(level="instance_id")
-    )
+    if "parent_variant_label" in data_catalog:
+        data_catalog = (
+            data_catalog.groupby("instance_id")
+            .apply(_fix_parent_variant_label, include_groups=False)
+            .reset_index(level="instance_id")
+        )

     if "branch_time_in_child" in data_catalog:
         data_catalog["branch_time_in_child"] = _clean_branch_time(data_catalog["branch_time_in_child"])
@@ -68,88 +71,6 @@ def _clean_branch_time(branch_time: pd.Series[str]) -> pd.Series[float]:
     return pd.to_numeric(branch_time.astype(str).str.replace("D", ""), errors="coerce")


-def parse_cmip6(file: str) -> dict[str, Any]:
-    """
-    Parser for CMIP6
-
-    This function parses the CMIP6 dataset and returns a dictionary with the metadata.
-    This was copied from the ecgtools package, but we want to log the exception when it fails.
-    """
-    keys = sorted(
-        {
-            "activity_id",
-            "branch_method",
-            "branch_time_in_child",
-            "branch_time_in_parent",
-            "experiment",
-            "experiment_id",
-            "frequency",
-            "grid",
-            "grid_label",
-            "institution_id",
-            "nominal_resolution",
-            "parent_activity_id",
-            "parent_experiment_id",
-            "parent_source_id",
-            "parent_time_units",
-            "parent_variant_label",
-            "realm",
-            "product",
-            "source_id",
-            "source_type",
-            "sub_experiment",
-            "sub_experiment_id",
-            "table_id",
-            "variable_id",
-            "variant_label",
-        }
-    )
-
-    try:
-        with xr.open_dataset(file, chunks={}, use_cftime=True) as ds:
-            info = {key: ds.attrs.get(key) for key in keys}
-            info["member_id"] = info["variant_label"]
-
-            variable_id = info["variable_id"]
-            if variable_id:  # pragma: no branch
-                attrs = ds[variable_id].attrs
-                for attr in ["standard_name", "long_name", "units"]:
-                    info[attr] = attrs.get(attr)
-
-            # Set the default of # of vertical levels to 1
-            vertical_levels = 1
-            start_time, end_time = None, None
-            init_year = None
-            try:
-                vertical_levels = ds[ds.cf["vertical"].name].size
-            except (KeyError, AttributeError, ValueError):
-                ...
-
-            try:
-                start_time, end_time = str(ds.cf["T"][0].data), str(ds.cf["T"][-1].data)
-            except (KeyError, AttributeError, ValueError):
-                ...
-            if info.get("sub_experiment_id"):  # pragma: no branch
-                init_year = extract_attr_with_regex(info["sub_experiment_id"], r"\d{4}")
-                if init_year:  # pragma: no cover
-                    init_year = int(init_year)
-            info["vertical_levels"] = vertical_levels
-            info["init_year"] = init_year
-            info["start_time"] = start_time
-            info["end_time"] = end_time
-            if not (start_time and end_time):
-                info["time_range"] = None
-            else:
-                info["time_range"] = f"{start_time}-{end_time}"
-            info["path"] = str(file)
-            info["version"] = extract_attr_with_regex(str(file), regex=r"v\d{4}\d{2}\d{2}|v\d{1}") or "v0"
-            return info
-
-    except Exception:
-        logger.exception(f"Failed to parse {file}")
-        return {"INVALID_ASSET": file, "TRACEBACK": traceback.format_exc()}
-
-
 class CMIP6DatasetAdapter(DatasetAdapter):
     """
     Adapter for CMIP6 datasets
@@ -191,6 +112,7 @@ class CMIP6DatasetAdapter(DatasetAdapter):
         "standard_name",
         "long_name",
         "units",
+        "finalised",
         slug_column,
     )

@@ -208,8 +130,30 @@ class CMIP6DatasetAdapter(DatasetAdapter):
         "grid_label",
     )

-    def __init__(self, n_jobs: int = 1):
+    def __init__(self, n_jobs: int = 1, config: Config | None = None):
         self.n_jobs = n_jobs
+        self.config = config or Config.default()
+
+    def get_parsing_function(self) -> DatasetParsingFunction:
+        """
+        Get the parsing function for CMIP6 datasets based on configuration
+
+        The parsing function used is determined by the `cmip6_parser` configuration value:
+        - "drs": Use the DRS parser (default)
+        - "complete": Use the complete parser that extracts all available metadata
+
+        Returns
+        -------
+        :
+            The appropriate parsing function based on configuration
+        """
+        parser_type = self.config.cmip6_parser
+        if parser_type == "complete":
+            logger.info("Using complete CMIP6 parser")
+            return parse_cmip6_complete
+        else:
+            logger.info(f"Using DRS CMIP6 parser (config value: {parser_type})")
+            return parse_cmip6_drs

     def find_local_datasets(self, file_or_directory: Path) -> pd.DataFrame:
         """
@@ -228,6 +172,8 @@ class CMIP6DatasetAdapter(DatasetAdapter):
        :
            Data catalog containing the metadata for the dataset
        """
+        parsing_function = self.get_parsing_function()
+
         with warnings.catch_warnings():
             # Ignore the DeprecationWarning from xarray
             warnings.simplefilter("ignore", DeprecationWarning)
@@ -237,7 +183,7 @@ class CMIP6DatasetAdapter(DatasetAdapter):
                 depth=10,
                 include_patterns=["*.nc"],
                 joblib_parallel_kwargs={"n_jobs": self.n_jobs},
-            ).build(parsing_func=parse_cmip6)  # type: ignore
+            ).build(parsing_func=parsing_function)

         datasets: pd.DataFrame = builder.df.drop(["init_year"], axis=1)

@@ -254,6 +200,14 @@ class CMIP6DatasetAdapter(DatasetAdapter):
             lambda row: "CMIP6." + ".".join([row[item] for item in drs_items]), axis=1
         )

+        # Add in any missing metadata columns
+        missing_columns = set(self.dataset_specific_metadata + self.file_specific_metadata) - set(
+            datasets.columns
+        )
+        if missing_columns:
+            for column in missing_columns:
+                datasets[column] = pd.NA
+
         # Temporary fix for some datasets
         # TODO: Replace with a standalone package that contains metadata fixes for CMIP6 datasets
         datasets = _apply_fixes(datasets)
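Taken together, the adapter now resolves its parser from configuration instead of the hard-coded `parse_cmip6`. A usage sketch, assuming only what is visible in the hunks above (`Config.default()`, the `cmip6_parser` attribute, and the new `config` argument); the file path is purely illustrative:

    from climate_ref.config import Config
    from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter

    config = Config.default()
    adapter = CMIP6DatasetAdapter(n_jobs=4, config=config)

    # parse_cmip6_complete when config.cmip6_parser == "complete",
    # otherwise the filename/directory-based parse_cmip6_drs
    parsing_function = adapter.get_parsing_function()
    metadata = parsing_function("/path/to/a/cmip6_file.nc")  # illustrative path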
@@ -0,0 +1,189 @@
+"""
+CMIP6 parser functions for extracting metadata from netCDF files
+
+Additional non-official DRS's may be added in the future.
+"""
+
+import traceback
+from typing import Any
+
+import xarray as xr
+from ecgtools.parsers.cmip import parse_cmip6_using_directories  # type: ignore
+from ecgtools.parsers.utilities import extract_attr_with_regex  # type: ignore
+from loguru import logger
+
+
+def _parse_daterange(date_range: str) -> tuple[str | None, str | None]:
+    """
+    Parse a date range string into start and end dates
+
+    The output from this is an estimated date range until the file is completely parsed.
+
+    Parameters
+    ----------
+    date_range
+        Date range string in the format "YYYYMM-YYYYMM"
+
+    Returns
+    -------
+    :
+        Tuple containing start and end dates as strings in the format "YYYY-MM-DD"
+    """
+    try:
+        start, end = date_range.split("-")
+        if len(start) != 6 or len(end) != 6:  # noqa: PLR2004
+            raise ValueError("Date range must be in the format 'YYYYMM-YYYYMM'")
+
+        start = f"{start[:4]}-{start[4:6]}-01"
+        # Up to the 30th of the month, assuming a 30-day month
+        # These values will be corrected later when the file is parsed
+        end = f"{end[:4]}-{end[4:6]}-30"
+
+        return start, end
+    except ValueError:
+        logger.error(f"Invalid date range format: {date_range}")
+        return None, None
+
+
+def parse_cmip6_complete(file: str, **kwargs: Any) -> dict[str, Any]:
+    """
+    Complete parser for CMIP6 files
+
+    This parser loads each file and extracts all available metadata.
+
+    For some filesystems this may be slow, as it involves a lot of I/O operations.
+
+    Parameters
+    ----------
+    file
+        File to parse
+    kwargs
+        Additional keyword arguments (not used, but required for compatibility)
+
+    Returns
+    -------
+    :
+        Dictionary with extracted metadata
+    """
+    keys = sorted(
+        {
+            "activity_id",
+            "branch_method",
+            "branch_time_in_child",
+            "branch_time_in_parent",
+            "experiment",
+            "experiment_id",
+            "frequency",
+            "grid",
+            "grid_label",
+            "institution_id",
+            "nominal_resolution",
+            "parent_activity_id",
+            "parent_experiment_id",
+            "parent_source_id",
+            "parent_time_units",
+            "parent_variant_label",
+            "realm",
+            "product",
+            "source_id",
+            "source_type",
+            "sub_experiment",
+            "sub_experiment_id",
+            "table_id",
+            "variable_id",
+            "variant_label",
+        }
+    )
+
+    try:
+        with xr.open_dataset(file, chunks={}, use_cftime=True) as ds:
+            info = {key: ds.attrs.get(key) for key in keys}
+            info["member_id"] = info["variant_label"]
+
+            variable_id = info["variable_id"]
+            if variable_id:  # pragma: no branch
+                attrs = ds[variable_id].attrs
+                for attr in ["standard_name", "long_name", "units"]:
+                    info[attr] = attrs.get(attr)
+
+            # Set the default of # of vertical levels to 1
+            vertical_levels = 1
+            start_time, end_time = None, None
+            init_year = None
+            try:
+                vertical_levels = ds[ds.cf["vertical"].name].size
+            except (KeyError, AttributeError, ValueError):
+                ...
+
+            try:
+                start_time, end_time = str(ds.cf["T"][0].data), str(ds.cf["T"][-1].data)
+            except (KeyError, AttributeError, ValueError):
+                ...
+            if info.get("sub_experiment_id"):  # pragma: no branch
+                init_year = extract_attr_with_regex(info["sub_experiment_id"], r"\d{4}")
+                if init_year:  # pragma: no cover
+                    init_year = int(init_year)
+            info["vertical_levels"] = vertical_levels
+            info["init_year"] = init_year
+            info["start_time"] = start_time
+            info["end_time"] = end_time
+            if not (start_time and end_time):
+                info["time_range"] = None
+            else:
+                info["time_range"] = f"{start_time}-{end_time}"
+            info["path"] = str(file)
+            info["version"] = extract_attr_with_regex(str(file), regex=r"v\d{4}\d{2}\d{2}|v\d{1}") or "v0"
+
+            # Mark the dataset as finalised
+            # This is used to indicate that the dataset has been fully parsed and is ready for use
+            info["finalised"] = True
+
+            return info
+
+    except Exception:
+        logger.exception(f"Failed to parse {file}")
+        return {"INVALID_ASSET": file, "TRACEBACK": traceback.format_exc()}
+
+
+def parse_cmip6_drs(file: str, **kwargs: Any) -> dict[str, Any]:
+    """
+    DRS parser for CMIP6 files
+
+    This parser extracts metadata according to the CMIP6 Data Reference Syntax (DRS).
+    This includes the essential metadata required to identify the dataset and is included in the filename.
+
+    Parameters
+    ----------
+    file
+        File to parse
+    kwargs
+        Additional keyword arguments (not used, but required for compatibility)
+
+    Returns
+    -------
+    :
+        Dictionary with extracted metadata
+    """
+    info: dict[str, Any] = parse_cmip6_using_directories(file)
+
+    if "INVALID_ASSET" in info:
+        logger.warning(f"Failed to parse {file}: {info['INVALID_ASSET']}")
+        return info
+
+    # The member_id is technically incorrect
+    # but for simplicity we are going to ignore sub-experiments for the DRS parser
+    info["variant_label"] = info["member_id"]
+
+    # Rename the `dcpp_init_year` key to `init_year` if it exists
+    if "dcpp_init_year" in info:
+        info["init_year"] = info.pop("dcpp_init_year")
+
+    if info.get("time_range"):
+        # Parse the time range if it exists
+        start_time, end_time = _parse_daterange(info["time_range"])
+        info["start_time"] = start_time
+        info["end_time"] = end_time
+
+    info["finalised"] = False
+
+    return info
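The DRS parser trades completeness for speed: `start_time` and `end_time` are estimated from the filename's date range and the result is marked `finalised = False`, whereas the complete parser reads the time axis itself and marks the dataset `finalised = True`. A worked example of the estimate, calling the private helper defined above (shown for illustration only):

    from climate_ref.datasets.cmip6_parsers import _parse_daterange

    _parse_daterange("185001-201412")  # ("1850-01-01", "2014-12-30")
    _parse_daterange("1850-2014")      # (None, None), logged as an invalid range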
@@ -15,8 +15,17 @@ from climate_ref.datasets.cmip6 import _parse_datetime
 from climate_ref.models.dataset import Dataset, Obs4MIPsDataset


-def parse_obs4mips(file: str) -> dict[str, Any | None]:
-    """Parser for obs4mips"""
+def parse_obs4mips(file: str, **kwargs: Any) -> dict[str, Any]:
+    """
+    Parser for obs4mips
+
+    Parameters
+    ----------
+    file
+        File to parse
+    kwargs
+        Additional keyword arguments (not used, but required for protocol compatibility)
+    """
     keys = sorted(
         list(
             {
@@ -106,6 +115,7 @@ class Obs4MIPsDatasetAdapter(DatasetAdapter):

     dataset_specific_metadata = (
         "activity_id",
+        "finalised",
         "frequency",
         "grid",
         "grid_label",
@@ -159,7 +169,7 @@ class Obs4MIPsDatasetAdapter(DatasetAdapter):
                 depth=10,
                 include_patterns=["*.nc"],
                 joblib_parallel_kwargs={"n_jobs": self.n_jobs},
-            ).build(parsing_func=parse_obs4mips)  # type: ignore[arg-type]
+            ).build(parsing_func=parse_obs4mips)

         datasets = builder.df
         if datasets.empty:
@@ -178,4 +188,5 @@ class Obs4MIPsDatasetAdapter(DatasetAdapter):
         datasets["instance_id"] = datasets.apply(
             lambda row: "obs4MIPs." + ".".join([row[item] for item in drs_items]), axis=1
         )
+        datasets["finalised"] = True
         return datasets
@@ -9,7 +9,14 @@ The simplest executor is the `LocalExecutor`, which runs the diagnostic in the s
 This is useful for local testing and debugging.
 """

-from .hpc import HPCExecutor
+from climate_ref_core.exceptions import InvalidExecutorException
+
+try:
+    from .hpc import HPCExecutor
+except InvalidExecutorException as exc:
+    # This exception is reraised when importing the executor as `climate_ref.executors.HPCExecutor`
+    HPCExecutor = exc  # type: ignore
+
 from .local import LocalExecutor
 from .result_handling import handle_execution_result
 from .synchronous import SynchronousExecutor
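With this guard, importing `climate_ref.executors` no longer fails outright when `parsl` is missing; the `HPCExecutor` name is bound to the captured exception instead. A sketch of the resulting behaviour, assuming `parsl` is not installed:

    from climate_ref.executors import HPCExecutor

    if isinstance(HPCExecutor, Exception):
        # Without `parsl`, the name holds the captured exception and can be
        # re-raised when the HPC executor is actually requested
        raise HPCExecutor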
@@ -5,12 +5,18 @@ If you want to
 - run REF under the HPC workflows
 - run REF in multiple nodes

+The `HPCExecutor` requires the optional `parsl` dependency.
+This dependency (and therefore this executor) is not available on Windows.
 """

 try:
     import parsl
 except ImportError:  # pragma: no cover
-    raise ImportError("The HPCExecutor requires the `parsl` package")
+    from climate_ref_core.exceptions import InvalidExecutorException
+
+    raise InvalidExecutorException(
+        "climate_ref_core.executor.hpc.HPCExecutor", "The HPCExecutor requires the `parsl` package"
+    )

 import os
 import time
@@ -21,7 +27,7 @@ from loguru import logger
 from parsl import python_app
 from parsl.config import Config as ParslConfig
 from parsl.executors import HighThroughputExecutor
-from parsl.launchers import SrunLauncher
+from parsl.launchers import SimpleLauncher, SrunLauncher
 from parsl.providers import SlurmProvider
 from tqdm import tqdm

@@ -34,6 +40,7 @@ from climate_ref_core.exceptions import DiagnosticError, ExecutionError
 from climate_ref_core.executor import execute_locally

 from .local import ExecutionFuture, process_result
+from .pbs_scheduler import SmartPBSProvider


 @python_app
@@ -96,8 +103,9 @@ class HPCExecutor:
         self.account = str(executor_config.get("account", os.environ.get("USER")))
         self.username = executor_config.get("username", os.environ.get("USER"))
         self.partition = str(executor_config.get("partition")) if executor_config.get("partition") else None
+        self.queue = str(executor_config.get("queue")) if executor_config.get("queue") else None
         self.qos = str(executor_config.get("qos")) if executor_config.get("qos") else None
-        self.req_nodes = int(executor_config.get("req_nodes", 1))
+        self.req_nodes = int(executor_config.get("req_nodes", 1)) if self.scheduler == "slurm" else 1
         self.walltime = str(executor_config.get("walltime", "00:10:00"))
         self.log_dir = str(executor_config.get("log_dir", "runinfo"))

@@ -181,21 +189,47 @@ class HPCExecutor:
     def _initialize_parsl(self) -> None:
         executor_config = self.config.executor.config

-        provider = SlurmProvider(
-            account=self.account,
-            partition=self.partition,
-            qos=self.qos,
-            nodes_per_block=self.req_nodes,
-            max_blocks=int(executor_config.get("max_blocks", 1)),
-            scheduler_options=executor_config.get("scheduler_options", "#SBATCH -C cpu"),
-            worker_init=executor_config.get("worker_init", "source .venv/bin/activate"),
-            launcher=SrunLauncher(
-                debug=True,
-                overrides=executor_config.get("overrides", ""),
-            ),
-            walltime=self.walltime,
-            cmd_timeout=int(executor_config.get("cmd_timeout", 120)),
-        )
+        provider: SlurmProvider | SmartPBSProvider
+        if self.scheduler == "slurm":
+            provider = SlurmProvider(
+                account=self.account,
+                partition=self.partition,
+                qos=self.qos,
+                nodes_per_block=self.req_nodes,
+                max_blocks=int(executor_config.get("max_blocks", 1)),
+                scheduler_options=executor_config.get("scheduler_options", "#SBATCH -C cpu"),
+                worker_init=executor_config.get("worker_init", "source .venv/bin/activate"),
+                launcher=SrunLauncher(
+                    debug=True,
+                    overrides=executor_config.get("overrides", ""),
+                ),
+                walltime=self.walltime,
+                cmd_timeout=int(executor_config.get("cmd_timeout", 120)),
+            )
+
+        elif self.scheduler == "pbs":
+            provider = SmartPBSProvider(
+                account=self.account,
+                queue=self.queue,
+                worker_init=executor_config.get("worker_init", "source .venv/bin/activate"),
+                nodes_per_block=_to_int(executor_config.get("nodes_per_block", 1)),
+                cpus_per_node=_to_int(executor_config.get("cpus_per_node", None)),
+                ncpus=_to_int(executor_config.get("ncpus", None)),
+                mem=executor_config.get("mem", "4GB"),
+                jobfs=executor_config.get("jobfs", "10GB"),
+                storage=executor_config.get("storage", ""),
+                init_blocks=executor_config.get("init_blocks", 1),
+                min_blocks=executor_config.get("min_blocks", 0),
+                max_blocks=executor_config.get("max_blocks", 1),
+                parallelism=executor_config.get("parallelism", 1),
+                scheduler_options=executor_config.get("scheduler_options", ""),
+                launcher=SimpleLauncher(),
+                walltime=self.walltime,
+                cmd_timeout=int(executor_config.get("cmd_timeout", 120)),
+            )
+        else:
+            raise ValueError(f"Unsupported scheduler: {self.scheduler}")
+
         executor = HighThroughputExecutor(
             label="ref_hpc_executor",
             cores_per_worker=self.cores_per_worker if self.cores_per_worker else 1,
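For reference, an executor configuration that would exercise the PBS branch above might look like the sketch below. The key names are taken from the `executor_config.get(...)` calls in this method; the `scheduler` key and the overall configuration layout are assumptions, since they are defined outside this diff:

    executor_config = {
        "scheduler": "pbs",          # selects the SmartPBSProvider branch (assumed key)
        "account": "my_project",     # placeholder project/account name
        "queue": "normal",
        "walltime": "01:00:00",
        "mem": "16GB",
        "jobfs": "10GB",
        "cpus_per_node": 8,
        "max_blocks": 2,
        "worker_init": "source .venv/bin/activate",
    }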
@@ -206,8 +240,11 @@ class HPCExecutor:
         )

         hpc_config = ParslConfig(
-            run_dir=self.log_dir, executors=[executor], retries=int(executor_config.get("retries", 2))
+            run_dir=self.log_dir,
+            executors=[executor],
+            retries=int(executor_config.get("retries", 2)),
         )
+
         parsl.load(hpc_config)

     def run(