parsl 2024.2.12__py3-none-any.whl → 2024.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. parsl/channels/errors.py +1 -4
  2. parsl/configs/{comet.py → expanse.py} +5 -5
  3. parsl/dataflow/dflow.py +12 -12
  4. parsl/executors/flux/executor.py +5 -3
  5. parsl/executors/high_throughput/executor.py +56 -10
  6. parsl/executors/high_throughput/mpi_prefix_composer.py +137 -0
  7. parsl/executors/high_throughput/mpi_resource_management.py +217 -0
  8. parsl/executors/high_throughput/process_worker_pool.py +65 -9
  9. parsl/executors/radical/executor.py +6 -3
  10. parsl/executors/radical/rpex_worker.py +2 -2
  11. parsl/jobs/states.py +5 -5
  12. parsl/monitoring/db_manager.py +2 -1
  13. parsl/monitoring/monitoring.py +7 -4
  14. parsl/multiprocessing.py +3 -4
  15. parsl/providers/cobalt/cobalt.py +6 -0
  16. parsl/providers/pbspro/pbspro.py +18 -4
  17. parsl/providers/pbspro/template.py +2 -2
  18. parsl/providers/slurm/slurm.py +17 -4
  19. parsl/providers/slurm/template.py +2 -2
  20. parsl/serialize/__init__.py +7 -2
  21. parsl/serialize/facade.py +32 -1
  22. parsl/tests/test_error_handling/test_resource_spec.py +6 -0
  23. parsl/tests/test_htex/test_htex.py +66 -3
  24. parsl/tests/test_monitoring/test_incomplete_futures.py +65 -0
  25. parsl/tests/test_mpi_apps/__init__.py +0 -0
  26. parsl/tests/test_mpi_apps/test_bad_mpi_config.py +41 -0
  27. parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py +51 -0
  28. parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py +171 -0
  29. parsl/tests/test_mpi_apps/test_mpi_prefix.py +71 -0
  30. parsl/tests/test_mpi_apps/test_mpi_scheduler.py +158 -0
  31. parsl/tests/test_mpi_apps/test_resource_spec.py +145 -0
  32. parsl/tests/test_providers/test_cobalt_deprecation_warning.py +16 -0
  33. parsl/tests/test_providers/test_pbspro_template.py +28 -0
  34. parsl/tests/test_providers/test_slurm_template.py +29 -0
  35. parsl/tests/test_radical/test_mpi_funcs.py +1 -0
  36. parsl/tests/test_scaling/test_scale_down.py +6 -5
  37. parsl/tests/test_serialization/test_htex_code_cache.py +57 -0
  38. parsl/tests/test_serialization/test_pack_resource_spec.py +22 -0
  39. parsl/usage_tracking/usage.py +29 -55
  40. parsl/utils.py +12 -35
  41. parsl/version.py +1 -1
  42. {parsl-2024.2.12.data → parsl-2024.2.26.data}/scripts/process_worker_pool.py +65 -9
  43. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/METADATA +2 -2
  44. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/RECORD +50 -37
  45. parsl/configs/cooley.py +0 -29
  46. parsl/configs/theta.py +0 -33
  47. {parsl-2024.2.12.data → parsl-2024.2.26.data}/scripts/exec_parsl_function.py +0 -0
  48. {parsl-2024.2.12.data → parsl-2024.2.26.data}/scripts/parsl_coprocess.py +0 -0
  49. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/LICENSE +0 -0
  50. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/WHEEL +0 -0
  51. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/entry_points.txt +0 -0
  52. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/top_level.txt +0 -0
parsl/channels/errors.py CHANGED
@@ -14,11 +14,8 @@ class ChannelError(ParslError):
         self.e = e
         self.hostname = hostname
 
-    def __repr__(self) -> str:
-        return "Hostname:{0}, Reason:{1}".format(self.hostname, self.reason)
-
     def __str__(self) -> str:
-        return self.__repr__()
+        return "Hostname:{0}, Reason:{1}".format(self.hostname, self.reason)
 
 
 class BadHostKeyException(ChannelError):
parsl/configs/{comet.py → expanse.py} RENAMED
@@ -7,11 +7,11 @@ from parsl.executors import HighThroughputExecutor
 config = Config(
     executors=[
         HighThroughputExecutor(
-            label='Comet_HTEX_multinode',
-            worker_logdir_root='YOUR_LOGDIR_ON_COMET',
-            max_workers=2,
+            label='Expanse_CPU_Multinode',
+            max_workers=32,
             provider=SlurmProvider(
-                'debug',
+                'compute',
+                account='YOUR_ALLOCATION_ON_EXPANSE',
                 launcher=SrunLauncher(),
                 # string to prepend to #SBATCH blocks in the submit
                 # script to the scheduler
@@ -19,7 +19,7 @@ config = Config(
                 # Command to be run before starting a worker, such as:
                 # 'module load Anaconda; source activate parsl_env'.
                 worker_init='',
-                walltime='00:10:00',
+                walltime='01:00:00',
                 init_blocks=1,
                 max_blocks=1,
                 nodes_per_block=2,
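
For reference, the renamed Expanse config assembles to roughly the following. This is a sketch built only from the hunks shown above; lines not visible in this diff (for example the scheduler_options entry suggested by the "#SBATCH" comment) are omitted.

from parsl.config import Config
from parsl.executors import HighThroughputExecutor
from parsl.launchers import SrunLauncher
from parsl.providers import SlurmProvider

config = Config(
    executors=[
        HighThroughputExecutor(
            label='Expanse_CPU_Multinode',
            max_workers=32,
            provider=SlurmProvider(
                'compute',
                account='YOUR_ALLOCATION_ON_EXPANSE',
                launcher=SrunLauncher(),
                worker_init='',
                walltime='01:00:00',
                init_blocks=1,
                max_blocks=1,
                nodes_per_block=2,
            ),
        )
    ]
)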
parsl/dataflow/dflow.py CHANGED
@@ -113,7 +113,7 @@ class DataFlowKernel:
             if self.monitoring.logdir is None:
                 self.monitoring.logdir = self.run_dir
             self.hub_address = self.monitoring.hub_address
-            self.hub_interchange_port = self.monitoring.start(self.run_id, self.run_dir)
+            self.hub_interchange_port = self.monitoring.start(self.run_id, self.run_dir, self.config.run_dir)
 
         self.time_began = datetime.datetime.now()
         self.time_completed: Optional[datetime.datetime] = None
@@ -678,10 +678,10 @@ class DataFlowKernel:
             task_record : The task record
 
         Returns:
-            Future that tracks the execution of the submitted executable
+            Future that tracks the execution of the submitted function
         """
         task_id = task_record['id']
-        executable = task_record['func']
+        function = task_record['func']
         args = task_record['args']
         kwargs = task_record['kwargs']
 
@@ -706,17 +706,17 @@
 
         if self.monitoring is not None and self.monitoring.resource_monitoring_enabled:
             wrapper_logging_level = logging.DEBUG if self.monitoring.monitoring_debug else logging.INFO
-            (executable, args, kwargs) = self.monitoring.monitor_wrapper(executable, args, kwargs, try_id, task_id,
-                                                                          self.monitoring.monitoring_hub_url,
-                                                                          self.run_id,
-                                                                          wrapper_logging_level,
-                                                                          self.monitoring.resource_monitoring_interval,
-                                                                          executor.radio_mode,
-                                                                          executor.monitor_resources(),
-                                                                          self.run_dir)
+            (function, args, kwargs) = self.monitoring.monitor_wrapper(function, args, kwargs, try_id, task_id,
+                                                                        self.monitoring.monitoring_hub_url,
+                                                                        self.run_id,
+                                                                        wrapper_logging_level,
+                                                                        self.monitoring.resource_monitoring_interval,
+                                                                        executor.radio_mode,
+                                                                        executor.monitor_resources(),
+                                                                        self.run_dir)
 
         with self.submitter_lock:
-            exec_fu = executor.submit(executable, task_record['resource_specification'], *args, **kwargs)
+            exec_fu = executor.submit(function, task_record['resource_specification'], *args, **kwargs)
             self.update_task_state(task_record, States.launched)
 
         self._send_task_log_info(task_record)
parsl/executors/flux/executor.py CHANGED
@@ -24,7 +24,7 @@ from parsl.executors.flux.flux_instance_manager import __file__ as _MANAGER_PATH
 from parsl.executors.errors import ScalingFailed
 from parsl.providers import LocalProvider
 from parsl.providers.base import ExecutionProvider
-from parsl.serialize import pack_apply_message, deserialize
+from parsl.serialize import deserialize, pack_res_spec_apply_message
 from parsl.serialize.errors import SerializationError
 from parsl.app.errors import AppException
 
@@ -284,8 +284,10 @@ class FluxExecutor(ParslExecutor, RepresentationMixin):
         infile = os.path.join(self.working_dir, f"{task_id}_in{os.extsep}pkl")
         outfile = os.path.join(self.working_dir, f"{task_id}_out{os.extsep}pkl")
         try:
-            fn_buf = pack_apply_message(
-                func, args, kwargs, buffer_threshold=1024 * 1024
+            fn_buf = pack_res_spec_apply_message(
+                func, args, kwargs,
+                resource_specification={},
+                buffer_threshold=1024 * 1024
             )
         except TypeError:
             raise SerializationError(func.__name__)
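
The switch from pack_apply_message to pack_res_spec_apply_message threads a resource_specification dict through the serialized task buffer (an empty dict here, since FluxExecutor manages resources itself). A minimal round-trip sketch, assuming the signatures implied by the hunks above and by mpi_resource_management.py below:

from parsl.serialize import (pack_res_spec_apply_message,
                             unpack_res_spec_apply_message)

def double(x):
    return 2 * x

# pack the function, its arguments and a resource spec into one buffer
buf = pack_res_spec_apply_message(double, (21,), {},
                                  resource_specification={"num_nodes": "1"},
                                  buffer_threshold=1024 * 1024)
# unpack recovers all four pieces on the worker side
fn, args, kwargs, spec = unpack_res_spec_apply_message(buf, {}, copy=False)
print(fn(*args, **kwargs), spec)   # 42 {'num_nodes': '1'}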
parsl/executors/high_throughput/executor.py CHANGED
@@ -6,12 +6,13 @@ import threading
 import queue
 import datetime
 import pickle
-from multiprocessing import Queue
+from multiprocessing import Process, Queue
 from typing import Dict, Sequence
 from typing import List, Optional, Tuple, Union, Callable
 import math
 
-from parsl.serialize import pack_apply_message, deserialize
+import parsl.launchers
+from parsl.serialize import pack_res_spec_apply_message, deserialize
 from parsl.serialize.errors import SerializationError, DeserializationError
 from parsl.app.errors import RemoteExceptionWrapper
 from parsl.jobs.states import JobStatus, JobState
@@ -19,7 +20,10 @@ from parsl.executors.high_throughput import zmq_pipes
 from parsl.executors.high_throughput import interchange
 from parsl.executors.errors import (
     BadMessage, ScalingFailed,
-    UnsupportedFeatureError
+)
+from parsl.executors.high_throughput.mpi_prefix_composer import (
+    VALID_LAUNCHERS,
+    validate_resource_spec
 )
 
 from parsl import curvezmq
@@ -50,6 +54,8 @@ DEFAULT_LAUNCH_CMD = ("process_worker_pool.py {debug} {max_workers} "
                       "{address_probe_timeout_string} "
                       "--hb_threshold={heartbeat_threshold} "
                       "--cpu-affinity {cpu_affinity} "
+                      "{enable_mpi_mode} "
+                      "--mpi-launcher={mpi_launcher} "
                       "--available-accelerators {accelerators}")
 
 
@@ -193,6 +199,17 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
     worker_logdir_root : string
         In case of a remote file system, specify the path to where logs will be kept.
 
+    enable_mpi_mode: bool
+        If enabled, MPI launch prefixes will be composed for the batch scheduler based on
+        the nodes available in each batch job and the resource_specification dict passed
+        from the app. This is an experimental feature, please refer to the following doc section
+        before use: https://parsl.readthedocs.io/en/stable/userguide/mpi_apps.html
+
+    mpi_launcher: str
+        This field is only used if enable_mpi_mode is set. Select one from the
+        list of supported MPI launchers = ("srun", "aprun", "mpiexec").
+        default: "mpiexec"
+
     encrypted : bool
         Flag to enable/disable encryption (CurveZMQ). Default is False.
     """
@@ -220,6 +237,8 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
                  poll_period: int = 10,
                  address_probe_timeout: Optional[int] = None,
                  worker_logdir_root: Optional[str] = None,
+                 enable_mpi_mode: bool = False,
+                 mpi_launcher: str = "mpiexec",
                  block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True,
                  encrypted: bool = False):
 
@@ -271,6 +290,7 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
         self.hub_port = None  # set to the correct hub port in dfk
         self.worker_ports = worker_ports
         self.worker_port_range = worker_port_range
+        self.interchange_proc: Optional[Process] = None
         self.interchange_port_range = interchange_port_range
         self.heartbeat_threshold = heartbeat_threshold
         self.heartbeat_period = heartbeat_period
@@ -281,6 +301,15 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
         self.encrypted = encrypted
         self.cert_dir = None
 
+        self.enable_mpi_mode = enable_mpi_mode
+        assert mpi_launcher in VALID_LAUNCHERS, \
+            f"mpi_launcher must be set to one of {VALID_LAUNCHERS}"
+        if self.enable_mpi_mode:
+            assert isinstance(self.provider.launcher, parsl.launchers.SingleNodeLauncher), \
+                "mpi_mode requires the provider to be configured to use a SingleNodeLauncher"
+
+        self.mpi_launcher = mpi_launcher
+
         if not launch_cmd:
             launch_cmd = DEFAULT_LAUNCH_CMD
         self.launch_cmd = launch_cmd
@@ -302,6 +331,7 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
         """
         debug_opts = "--debug" if self.worker_debug else ""
         max_workers = "" if self.max_workers == float('inf') else "--max_workers={}".format(self.max_workers)
+        enable_mpi_opts = "--enable_mpi_mode " if self.enable_mpi_mode else ""
 
         address_probe_timeout_string = ""
         if self.address_probe_timeout:
@@ -323,6 +353,8 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
                                        cert_dir=self.cert_dir,
                                        logdir=self.worker_logdir,
                                        cpu_affinity=self.cpu_affinity,
+                                       enable_mpi_mode=enable_mpi_opts,
+                                       mpi_launcher=self.mpi_launcher,
                                        accelerators=" ".join(self.available_accelerators))
         self.launch_cmd = l_cmd
         logger.debug("Launch command: {}".format(self.launch_cmd))
@@ -584,10 +616,7 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
         Returns:
               Future
         """
-        if resource_specification:
-            logger.error("Ignoring the call specification. "
-                         "Parsl call specification is not supported in HighThroughput Executor.")
-            raise UnsupportedFeatureError('resource specification', 'HighThroughput Executor', None)
+        validate_resource_spec(resource_specification)
 
         if self.bad_state_is_set:
             raise self.executor_exception
@@ -605,8 +634,9 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
         self.tasks[task_id] = fut
 
         try:
-            fn_buf = pack_apply_message(func, args, kwargs,
-                                        buffer_threshold=1024 * 1024)
+            fn_buf = pack_res_spec_apply_message(func, args, kwargs,
+                                                 resource_specification=resource_specification,
+                                                 buffer_threshold=1024 * 1024)
         except TypeError:
             raise SerializationError(func.__name__)
 
@@ -737,12 +767,28 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin):
             )
         return job_status
 
-    def shutdown(self):
+    def shutdown(self, timeout: float = 10.0):
        """Shutdown the executor, including the interchange. This does not
        shut down any workers directly - workers should be terminated by the
        scaling mechanism or by heartbeat timeout.
+
+        Parameters
+        ----------
+
+        timeout : float
+            Amount of time to wait for the Interchange process to terminate before
+            we forcefully kill it.
        """
+        if self.interchange_proc is None:
+            logger.info("HighThroughputExecutor has not started; skipping shutdown")
+            return
 
         logger.info("Attempting HighThroughputExecutor shutdown")
+
         self.interchange_proc.terminate()
+        self.interchange_proc.join(timeout=timeout)
+        if self.interchange_proc.is_alive():
+            logger.info("Unable to terminate Interchange process; sending SIGKILL")
+            self.interchange_proc.kill()
+
         logger.info("Finished HighThroughputExecutor shutdown attempt")
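
Taken together, these executor changes add an opt-in MPI mode: enable_mpi_mode and mpi_launcher are forwarded to the worker pool command line, submit() now validates and packs the per-task resource specification instead of raising UnsupportedFeatureError, and shutdown() gains a timeout. Below is a minimal usage sketch, not taken from this diff; the SlurmProvider settings, the partition name and the parsl_resource_specification keyword are assumptions, so consult https://parsl.readthedocs.io/en/stable/userguide/mpi_apps.html for the supported interface.

import parsl
from parsl import bash_app
from parsl.config import Config
from parsl.executors import HighThroughputExecutor
from parsl.launchers import SingleNodeLauncher
from parsl.providers import SlurmProvider

config = Config(
    executors=[
        HighThroughputExecutor(
            label="htex_mpi",
            enable_mpi_mode=True,               # new option introduced in this release
            mpi_launcher="srun",                # must be one of VALID_LAUNCHERS: srun, aprun, mpiexec
            provider=SlurmProvider(
                'compute',                      # hypothetical partition name
                nodes_per_block=2,
                launcher=SingleNodeLauncher(),  # asserted above when enable_mpi_mode is set
            ),
        )
    ]
)

@bash_app
def mpi_hello(parsl_resource_specification=None):
    # PARSL_MPI_PREFIX is composed per task by mpi_prefix_composer from the spec
    return '$PARSL_MPI_PREFIX hostname'

parsl.load(config)
spec = {"num_nodes": 2, "ranks_per_node": 2}    # num_ranks is derived by validate_resource_spec
future = mpi_hello(parsl_resource_specification=spec)
print(future.result())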
parsl/executors/high_throughput/mpi_prefix_composer.py ADDED
@@ -0,0 +1,137 @@
+import logging
+from typing import Dict, List, Tuple, Set
+
+logger = logging.getLogger(__name__)
+
+VALID_LAUNCHERS = ('srun',
+                   'aprun',
+                   'mpiexec')
+
+
+class InvalidResourceSpecification(Exception):
+    """Exception raised when Invalid keys are supplied via resource specification"""
+
+    def __init__(self, invalid_keys: Set[str]):
+        self.invalid_keys = invalid_keys
+
+    def __str__(self):
+        return f"Invalid resource specification options supplied: {self.invalid_keys}"
+
+
+def validate_resource_spec(resource_spec: Dict[str, str]):
+    """Basic validation of keys in the resource_spec
+
+    Raises: InvalidResourceSpecification if the resource_spec
+    is invalid (e.g, contains invalid keys)
+    """
+    user_keys = set(resource_spec.keys())
+    legal_keys = set(("ranks_per_node",
+                      "num_nodes",
+                      "num_ranks",
+                      "launcher_options",
+                      ))
+    invalid_keys = user_keys - legal_keys
+    if invalid_keys:
+        raise InvalidResourceSpecification(invalid_keys)
+    if "num_nodes" in resource_spec:
+        if not resource_spec.get("num_ranks") and resource_spec.get("ranks_per_node"):
+            resource_spec["num_ranks"] = str(int(resource_spec["num_nodes"]) * int(resource_spec["ranks_per_node"]))
+        if not resource_spec.get("ranks_per_node") and resource_spec.get("num_ranks"):
+            resource_spec["ranks_per_node"] = str(int(resource_spec["num_ranks"]) / int(resource_spec["num_nodes"]))
+    return
+
+
+def compose_mpiexec_launch_cmd(
+    resource_spec: Dict, node_hostnames: List[str]
+) -> Tuple[str, str]:
+    """Compose mpiexec launch command prefix"""
+
+    node_str = ",".join(node_hostnames)
+    args = [
+        "mpiexec",
+        "-n",
+        resource_spec.get("num_ranks"),
+        "-ppn",
+        resource_spec.get("ranks_per_node"),
+        "-hosts",
+        node_str,
+        resource_spec.get("launcher_options", ""),
+    ]
+    prefix = " ".join(str(arg) for arg in args)
+    return "PARSL_MPIEXEC_PREFIX", prefix
+
+
+def compose_srun_launch_cmd(
+    resource_spec: Dict, node_hostnames: List[str]
+) -> Tuple[str, str]:
+    """Compose srun launch command prefix"""
+
+    num_nodes = str(len(node_hostnames))
+    args = [
+        "srun",
+        "--ntasks",
+        resource_spec.get("num_ranks"),
+        "--ntasks-per-node",
+        resource_spec.get("ranks_per_node"),
+        "--nodelist",
+        ",".join(node_hostnames),
+        "--nodes",
+        num_nodes,
+        resource_spec.get("launcher_options", ""),
+    ]
+
+    prefix = " ".join(str(arg) for arg in args)
+    return "PARSL_SRUN_PREFIX", prefix
+
+
+def compose_aprun_launch_cmd(
+    resource_spec: Dict, node_hostnames: List[str]
+) -> Tuple[str, str]:
+    """Compose aprun launch command prefix"""
+
+    node_str = ",".join(node_hostnames)
+    args = [
+        "aprun",
+        "-n",
+        resource_spec.get("num_ranks"),
+        "-N",
+        resource_spec.get("ranks_per_node"),
+        "-node-list",
+        node_str,
+        resource_spec.get("launcher_options", ""),
+    ]
+    prefix = " ".join(str(arg) for arg in args)
+    return "PARSL_APRUN_PREFIX", prefix
+
+
+def compose_all(
+    mpi_launcher: str, resource_spec: Dict, node_hostnames: List[str]
+) -> Dict[str, str]:
+    """Compose all launch command prefixes and set the default"""
+
+    all_prefixes = {}
+    composers = [
+        compose_aprun_launch_cmd,
+        compose_srun_launch_cmd,
+        compose_mpiexec_launch_cmd,
+    ]
+    for composer in composers:
+        try:
+            key, prefix = composer(resource_spec, node_hostnames)
+            all_prefixes[key] = prefix
+        except Exception:
+            logging.exception(
+                f"Failed to compose launch prefix with {composer} from {resource_spec}"
+            )
+            pass
+
+    if mpi_launcher == "srun":
+        all_prefixes["PARSL_MPI_PREFIX"] = all_prefixes["PARSL_SRUN_PREFIX"]
+    elif mpi_launcher == "aprun":
+        all_prefixes["PARSL_MPI_PREFIX"] = all_prefixes["PARSL_APRUN_PREFIX"]
+    elif mpi_launcher == "mpiexec":
+        all_prefixes["PARSL_MPI_PREFIX"] = all_prefixes["PARSL_MPIEXEC_PREFIX"]
+    else:
+        raise RuntimeError(f"Unknown mpi_launcher:{mpi_launcher}")
+
+    return all_prefixes
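
A small illustration of the new composer module in isolation; the hostnames here are invented, and in normal operation these functions are driven by the worker pool rather than by user code.

from parsl.executors.high_throughput.mpi_prefix_composer import (
    compose_all,
    validate_resource_spec,
)

spec = {"num_nodes": "2", "ranks_per_node": "4"}
validate_resource_spec(spec)   # fills in num_ranks="8"; unknown keys raise InvalidResourceSpecification

prefixes = compose_all("mpiexec", spec, node_hostnames=["nid001", "nid002"])
print(prefixes["PARSL_MPI_PREFIX"])
# -> mpiexec -n 8 -ppn 4 -hosts nid001,nid002
# prefixes also carries PARSL_SRUN_PREFIX, PARSL_APRUN_PREFIX and PARSL_MPIEXEC_PREFIX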
parsl/executors/high_throughput/mpi_resource_management.py ADDED
@@ -0,0 +1,217 @@
+import logging
+import multiprocessing
+import os
+import pickle
+import queue
+import subprocess
+from enum import Enum
+from typing import Dict, List
+
+from parsl.multiprocessing import SpawnContext
+from parsl.serialize import (pack_res_spec_apply_message,
+                             unpack_res_spec_apply_message)
+
+logger = logging.getLogger(__name__)
+
+
+class Scheduler(Enum):
+    Unknown = 0
+    Slurm = 1
+    PBS = 2
+    Cobalt = 3
+
+
+def get_slurm_hosts_list() -> List[str]:
+    """Get list of slurm hosts from scontrol"""
+    cmd = "scontrol show hostname $SLURM_NODELIST"
+    b_output = subprocess.check_output(
+        cmd, stderr=subprocess.STDOUT, shell=True
+    )  # bytes
+    output = b_output.decode().strip().split()
+    return output
+
+
+def get_pbs_hosts_list() -> List[str]:
+    """Get list of PBS hosts from envvar: PBS_NODEFILE"""
+    nodefile_name = os.environ["PBS_NODEFILE"]
+    with open(nodefile_name) as f:
+        return [line.strip() for line in f.readlines()]
+
+
+def get_cobalt_hosts_list() -> List[str]:
+    """Get list of COBALT hosts from envvar: COBALT_NODEFILE"""
+    nodefile_name = os.environ["COBALT_NODEFILE"]
+    with open(nodefile_name) as f:
+        return [line.strip() for line in f.readlines()]
+
+
+def get_nodes_in_batchjob(scheduler: Scheduler) -> List[str]:
+    """Get nodelist from all supported schedulers"""
+    nodelist = []
+    if scheduler == Scheduler.Slurm:
+        nodelist = get_slurm_hosts_list()
+    elif scheduler == Scheduler.PBS:
+        nodelist = get_pbs_hosts_list()
+    elif scheduler == Scheduler.Cobalt:
+        nodelist = get_cobalt_hosts_list()
+    else:
+        raise RuntimeError(f"mpi_mode does not support scheduler:{scheduler}")
+    return nodelist
+
+
+def identify_scheduler() -> Scheduler:
+    """Use envvars to determine batch scheduler"""
+    if os.environ.get("SLURM_NODELIST"):
+        return Scheduler.Slurm
+    elif os.environ.get("PBS_NODEFILE"):
+        return Scheduler.PBS
+    elif os.environ.get("COBALT_NODEFILE"):
+        return Scheduler.Cobalt
+    else:
+        return Scheduler.Unknown
+
+
+class MPINodesUnavailable(Exception):
+    """Raised if there are no free nodes available for an MPI request"""
+
+    def __init__(self, requested: int, available: int):
+        self.requested = requested
+        self.available = available
+
+    def __str__(self):
+        return f"MPINodesUnavailable(requested={self.requested} available={self.available})"
+
+
+class TaskScheduler:
+    """Default TaskScheduler that does no taskscheduling
+
+    This class simply acts as an abstraction over the task_q and result_q
+    that can be extended to implement more complex task scheduling logic
+    """
+    def __init__(
+        self,
+        pending_task_q: multiprocessing.Queue,
+        pending_result_q: multiprocessing.Queue,
+    ):
+        self.pending_task_q = pending_task_q
+        self.pending_result_q = pending_result_q
+
+    def put_task(self, task) -> None:
+        return self.pending_task_q.put(task)
+
+    def get_result(self, block: bool, timeout: float):
+        return self.pending_result_q.get(block, timeout=timeout)
+
+
+class MPITaskScheduler(TaskScheduler):
+    """Extends TaskScheduler to schedule MPI functions over provisioned nodes
+    The MPITaskScheduler runs on a Manager on the lead node of a batch job, as
+    such it is expected to control task placement over this single batch job.
+
+    The MPITaskScheduler adds the following functionality:
+    1) Determine list of nodes attached to current batch job
+    2) put_task for execution onto workers:
+       a) if resources are available attach resource list
+       b) if unavailable place tasks into backlog
+    3) get_result will fetch a result and relinquish nodes,
+       and attempt to schedule tasks in backlog if any.
+    """
+    def __init__(
+        self,
+        pending_task_q: multiprocessing.Queue,
+        pending_result_q: multiprocessing.Queue,
+    ):
+        super().__init__(pending_task_q, pending_result_q)
+        self.scheduler = identify_scheduler()
+        # PriorityQueue is threadsafe
+        self._backlog_queue: queue.PriorityQueue = queue.PriorityQueue()
+        self._map_tasks_to_nodes: Dict[str, List[str]] = {}
+        self.available_nodes = get_nodes_in_batchjob(self.scheduler)
+        self._free_node_counter = SpawnContext.Value("i", len(self.available_nodes))
+        # mp.Value has issues with mypy
+        # issue https://github.com/python/typeshed/issues/8799
+        # from mypy 0.981 onwards
+        self.nodes_q: queue.Queue = queue.Queue()
+        for node in self.available_nodes:
+            self.nodes_q.put(node)
+
+        logger.info(
+            f"Starting MPITaskScheduler with {len(self.available_nodes)}"
+        )
+
+    def _get_nodes(self, num_nodes: int) -> List[str]:
+        """Thread safe method to acquire num_nodes from free resources
+
+        Raises: MPINodesUnavailable if there aren't enough resources
+        Returns: List of nodenames:str
+        """
+        logger.debug(
+            f"Requesting : {num_nodes=} we have {self._free_node_counter}"
+        )
+        acquired_nodes = []
+        with self._free_node_counter.get_lock():
+            if num_nodes <= self._free_node_counter.value:  # type: ignore[attr-defined]
+                self._free_node_counter.value -= num_nodes  # type: ignore[attr-defined]
+            else:
+                raise MPINodesUnavailable(
+                    requested=num_nodes, available=self._free_node_counter.value  # type: ignore[attr-defined]
+                )
+
+        for i in range(num_nodes):
+            node = self.nodes_q.get()
+            acquired_nodes.append(node)
+        return acquired_nodes
+
+    def _return_nodes(self, nodes: List[str]) -> None:
+        """Threadsafe method to return a list of nodes"""
+        for node in nodes:
+            self.nodes_q.put(node)
+        with self._free_node_counter.get_lock():
+            self._free_node_counter.value += len(nodes)  # type: ignore[attr-defined]
+
+    def put_task(self, task_package: dict):
+        """Schedule task if resources are available otherwise backlog the task"""
+        user_ns = locals()
+        user_ns.update({"__builtins__": __builtins__})
+        _f, _args, _kwargs, resource_spec = unpack_res_spec_apply_message(
+            task_package["buffer"], user_ns, copy=False
+        )
+
+        nodes_needed = resource_spec.get("num_nodes")
+        if nodes_needed:
+            try:
+                allocated_nodes = self._get_nodes(nodes_needed)
+            except MPINodesUnavailable:
+                logger.warning("Not enough resources, placing task into backlog")
+                self._backlog_queue.put((nodes_needed, task_package))
+                return
+            else:
+                resource_spec["MPI_NODELIST"] = ",".join(allocated_nodes)
+                self._map_tasks_to_nodes[task_package["task_id"]] = allocated_nodes
+                buffer = pack_res_spec_apply_message(_f, _args, _kwargs, resource_spec)
+                task_package["buffer"] = buffer
+
+        self.pending_task_q.put(task_package)
+
+    def _schedule_backlog_tasks(self):
+        """Attempt to schedule backlogged tasks"""
+        try:
+            _nodes_requested, task_package = self._backlog_queue.get(block=False)
+            self.put_task(task_package)
+        except queue.Empty:
+            return
+        else:
+            # Keep attempting to schedule tasks till we are out of resources
+            self._schedule_backlog_tasks()
+
+    def get_result(self, block: bool, timeout: float):
+        """Return result and relinquish provisioned nodes"""
+        result_pkl = self.pending_result_q.get(block, timeout=timeout)
+        result_dict = pickle.loads(result_pkl)
+        if result_dict["type"] == "result":
+            task_id = result_dict["task_id"]
+            nodes_to_reallocate = self._map_tasks_to_nodes[task_id]
+            self._return_nodes(nodes_to_reallocate)
+            self._schedule_backlog_tasks()
+
+        return result_pkl
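
The node-discovery helpers above can be exercised directly from inside a batch job; a short sketch follows (the environment is whatever the scheduler provides, e.g. SLURM_NODELIST or PBS_NODEFILE, so outside an allocation this reports Unknown).

from parsl.executors.high_throughput.mpi_resource_management import (
    Scheduler,
    get_nodes_in_batchjob,
    identify_scheduler,
)

scheduler = identify_scheduler()
if scheduler is Scheduler.Unknown:
    print("Not running inside a recognised Slurm/PBS/Cobalt batch job")
else:
    nodes = get_nodes_in_batchjob(scheduler)   # e.g. ['nid001', 'nid002', ...]
    print(f"{scheduler.name}: {len(nodes)} nodes available for MPI tasks")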