parsl 2024.2.12__py3-none-any.whl → 2024.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. parsl/channels/errors.py +1 -4
  2. parsl/configs/{comet.py → expanse.py} +5 -5
  3. parsl/dataflow/dflow.py +12 -12
  4. parsl/executors/flux/executor.py +5 -3
  5. parsl/executors/high_throughput/executor.py +56 -10
  6. parsl/executors/high_throughput/mpi_prefix_composer.py +137 -0
  7. parsl/executors/high_throughput/mpi_resource_management.py +217 -0
  8. parsl/executors/high_throughput/process_worker_pool.py +65 -9
  9. parsl/executors/radical/executor.py +6 -3
  10. parsl/executors/radical/rpex_worker.py +2 -2
  11. parsl/jobs/states.py +5 -5
  12. parsl/monitoring/db_manager.py +2 -1
  13. parsl/monitoring/monitoring.py +7 -4
  14. parsl/multiprocessing.py +3 -4
  15. parsl/providers/cobalt/cobalt.py +6 -0
  16. parsl/providers/pbspro/pbspro.py +18 -4
  17. parsl/providers/pbspro/template.py +2 -2
  18. parsl/providers/slurm/slurm.py +17 -4
  19. parsl/providers/slurm/template.py +2 -2
  20. parsl/serialize/__init__.py +7 -2
  21. parsl/serialize/facade.py +32 -1
  22. parsl/tests/test_error_handling/test_resource_spec.py +6 -0
  23. parsl/tests/test_htex/test_htex.py +66 -3
  24. parsl/tests/test_monitoring/test_incomplete_futures.py +65 -0
  25. parsl/tests/test_mpi_apps/__init__.py +0 -0
  26. parsl/tests/test_mpi_apps/test_bad_mpi_config.py +41 -0
  27. parsl/tests/test_mpi_apps/test_mpi_mode_disabled.py +51 -0
  28. parsl/tests/test_mpi_apps/test_mpi_mode_enabled.py +171 -0
  29. parsl/tests/test_mpi_apps/test_mpi_prefix.py +71 -0
  30. parsl/tests/test_mpi_apps/test_mpi_scheduler.py +158 -0
  31. parsl/tests/test_mpi_apps/test_resource_spec.py +145 -0
  32. parsl/tests/test_providers/test_cobalt_deprecation_warning.py +16 -0
  33. parsl/tests/test_providers/test_pbspro_template.py +28 -0
  34. parsl/tests/test_providers/test_slurm_template.py +29 -0
  35. parsl/tests/test_radical/test_mpi_funcs.py +1 -0
  36. parsl/tests/test_scaling/test_scale_down.py +6 -5
  37. parsl/tests/test_serialization/test_htex_code_cache.py +57 -0
  38. parsl/tests/test_serialization/test_pack_resource_spec.py +22 -0
  39. parsl/usage_tracking/usage.py +29 -55
  40. parsl/utils.py +12 -35
  41. parsl/version.py +1 -1
  42. {parsl-2024.2.12.data → parsl-2024.2.26.data}/scripts/process_worker_pool.py +65 -9
  43. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/METADATA +2 -2
  44. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/RECORD +50 -37
  45. parsl/configs/cooley.py +0 -29
  46. parsl/configs/theta.py +0 -33
  47. {parsl-2024.2.12.data → parsl-2024.2.26.data}/scripts/exec_parsl_function.py +0 -0
  48. {parsl-2024.2.12.data → parsl-2024.2.26.data}/scripts/parsl_coprocess.py +0 -0
  49. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/LICENSE +0 -0
  50. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/WHEEL +0 -0
  51. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/entry_points.txt +0 -0
  52. {parsl-2024.2.12.dist-info → parsl-2024.2.26.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,40 @@
1
1
  import pathlib
2
+ from unittest import mock
2
3
 
3
4
  import pytest
4
5
 
5
6
  from parsl import curvezmq
6
7
  from parsl import HighThroughputExecutor
8
+ from parsl.multiprocessing import ForkProcess
9
+
10
+ _MOCK_BASE = "parsl.executors.high_throughput.executor"
11
+
12
+
13
+ @pytest.fixture
14
+ def encrypted(request: pytest.FixtureRequest):
15
+ if hasattr(request, "param"):
16
+ return request.param
17
+ return True
18
+
19
+
20
+ @pytest.fixture
21
+ def htex(encrypted: bool):
22
+ htex = HighThroughputExecutor(encrypted=encrypted)
23
+
24
+ yield htex
25
+
26
+ htex.shutdown()
7
27
 
8
28
 
9
29
  @pytest.mark.local
10
- @pytest.mark.parametrize("encrypted", (True, False))
30
+ @pytest.mark.parametrize("encrypted", (True, False), indirect=True)
11
31
  @pytest.mark.parametrize("cert_dir_provided", (True, False))
12
32
  def test_htex_start_encrypted(
13
- encrypted: bool, cert_dir_provided: bool, tmpd_cwd: pathlib.Path
33
+ encrypted: bool,
34
+ cert_dir_provided: bool,
35
+ htex: HighThroughputExecutor,
36
+ tmpd_cwd: pathlib.Path,
14
37
  ):
15
- htex = HighThroughputExecutor(encrypted=encrypted)
16
38
  htex.run_dir = str(tmpd_cwd)
17
39
  if cert_dir_provided:
18
40
  provided_base_dir = tmpd_cwd / "provided"
@@ -44,3 +66,44 @@ def test_htex_start_encrypted(
44
66
  assert htex.outgoing_q.zmq_context.cert_dir is None
45
67
  assert htex.incoming_q.zmq_context.cert_dir is None
46
68
  assert htex.command_client.zmq_context.cert_dir is None
69
+
70
+
71
+ @pytest.mark.local
72
+ @pytest.mark.parametrize("started", (True, False))
73
+ @pytest.mark.parametrize("timeout_expires", (True, False))
74
+ @mock.patch(f"{_MOCK_BASE}.logger")
75
+ def test_htex_shutdown(
76
+ mock_logger: mock.MagicMock,
77
+ started: bool,
78
+ timeout_expires: bool,
79
+ htex: HighThroughputExecutor,
80
+ ):
81
+ mock_ix_proc = mock.Mock(spec=ForkProcess)
82
+
83
+ if started:
84
+ htex.interchange_proc = mock_ix_proc
85
+ mock_ix_proc.is_alive.return_value = True
86
+
87
+ if not timeout_expires:
88
+ # Simulate termination of the Interchange process
89
+ def kill_interchange(*args, **kwargs):
90
+ mock_ix_proc.is_alive.return_value = False
91
+
92
+ mock_ix_proc.terminate.side_effect = kill_interchange
93
+
94
+ htex.shutdown()
95
+
96
+ mock_logs = mock_logger.info.call_args_list
97
+ if started:
98
+ assert mock_ix_proc.terminate.called
99
+ assert mock_ix_proc.join.called
100
+ assert {"timeout": 10} == mock_ix_proc.join.call_args[1]
101
+ if timeout_expires:
102
+ assert "Unable to terminate Interchange" in mock_logs[1][0][0]
103
+ assert mock_ix_proc.kill.called
104
+ assert "Attempting" in mock_logs[0][0][0]
105
+ assert "Finished" in mock_logs[-1][0][0]
106
+ else:
107
+ assert not mock_ix_proc.terminate.called
108
+ assert not mock_ix_proc.join.called
109
+ assert "has not started" in mock_logs[0][0][0]
@@ -0,0 +1,65 @@
1
+ import logging
2
+ import os
3
+ import parsl
4
+ import pytest
5
+ import random
6
+
7
+ from concurrent.futures import Future
8
+
9
+
10
+ @parsl.python_app
11
+ def this_app(inputs=()):
12
+ return inputs[0]
13
+
14
+
15
+ @pytest.mark.local
16
+ def test_future_representation(tmpd_cwd):
17
+ import sqlalchemy
18
+ from sqlalchemy import text
19
+ from parsl.tests.configs.htex_local_alternate import fresh_config
20
+
21
+ monitoring_db = str(tmpd_cwd / "monitoring.db")
22
+ monitoring_url = "sqlite:///" + monitoring_db
23
+
24
+ c = fresh_config()
25
+ c.monitoring.logging_endpoint = monitoring_url
26
+ c.run_dir = tmpd_cwd
27
+
28
+ parsl.load(c)
29
+
30
+ # we're going to pass this TOKEN into an app via a pre-requisite Future,
31
+ # and then expect to see it appear in the monitoring database.
32
+ TOKEN = random.randint(0, 1000000)
33
+
34
+ # make a Future that has no result yet
35
+ # invoke a task that depends on it
36
+ # inspect and assert something about the monitoring recorded value of that Future
37
+ # make the Future complete
38
+ # inspect and assert something about the monitoring recorded value of that Future
39
+
40
+ f1 = Future()
41
+
42
+ f2 = this_app(inputs=[f1])
43
+
44
+ f1.set_result(TOKEN)
45
+
46
+ assert f2.result() == TOKEN
47
+
48
+ # this cleanup gives a barrier that allows the monitoring code to store
49
+ # everything it has in the database - without this, there's a race
50
+ # condition that the task will not have arrived in the database yet.
51
+ # A different approach for this test might be to poll the DB for a few
52
+ # seconds, with the assumption "data will arrive in the DB within
53
+ # 30 seconds, but probably much sooner".
54
+ parsl.dfk().cleanup()
55
+ parsl.clear()
56
+
57
+ engine = sqlalchemy.create_engine(monitoring_url)
58
+ with engine.begin() as connection:
59
+ result = connection.execute(text("SELECT COUNT(*) FROM task"))
60
+ (task_count, ) = result.first()
61
+ assert task_count == 1
62
+
63
+ result = connection.execute(text("SELECT task_inputs FROM task"))
64
+ (task_inputs, ) = result.first()
65
+ assert task_inputs == "[" + repr(TOKEN) + "]"
File without changes
@@ -0,0 +1,41 @@
1
+ import pytest
2
+
3
+ from parsl import Config
4
+ from parsl.executors import HighThroughputExecutor
5
+ from parsl.launchers import SrunLauncher, SingleNodeLauncher, SimpleLauncher, AprunLauncher
6
+ from parsl.providers import SlurmProvider
7
+
8
+
9
+ @pytest.mark.local
10
+ def test_bad_launcher_with_mpi_mode():
11
+ """AssertionError if a launcher other than SingleNodeLauncher is supplied"""
12
+
13
+ for launcher in [SrunLauncher(), SimpleLauncher(), AprunLauncher()]:
14
+ with pytest.raises(AssertionError):
15
+ Config(executors=[
16
+ HighThroughputExecutor(
17
+ enable_mpi_mode=True,
18
+ provider=SlurmProvider(launcher=launcher),
19
+ )
20
+ ])
21
+
22
+
23
+ @pytest.mark.local
24
+ def test_correct_launcher_with_mpi_mode():
25
+ """Confirm that SingleNodeLauncher works with mpi_mode"""
26
+
27
+ config = Config(executors=[
28
+ HighThroughputExecutor(
29
+ enable_mpi_mode=True,
30
+ provider=SlurmProvider(launcher=SingleNodeLauncher()),
31
+ )
32
+ ])
33
+ assert isinstance(config.executors[0].provider.launcher, SingleNodeLauncher)
34
+
35
+ config = Config(executors=[
36
+ HighThroughputExecutor(
37
+ enable_mpi_mode=True,
38
+ provider=SlurmProvider(),
39
+ )
40
+ ])
41
+ assert isinstance(config.executors[0].provider.launcher, SingleNodeLauncher)
@@ -0,0 +1,51 @@
1
+ import logging
2
+ from typing import Dict
3
+ import pytest
4
+ import parsl
5
+ from parsl import python_app
6
+ from parsl.tests.configs.htex_local import fresh_config
7
+
8
+ EXECUTOR_LABEL = "MPI_TEST"
9
+
10
+
11
+ def local_setup():
12
+ config = fresh_config()
13
+ config.executors[0].label = EXECUTOR_LABEL
14
+ config.executors[0].max_workers = 1
15
+ config.executors[0].enable_mpi_mode = False
16
+ parsl.load(config)
17
+
18
+
19
+ def local_teardown():
20
+ parsl.dfk().cleanup()
21
+ parsl.clear()
22
+
23
+
24
+ @python_app
25
+ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict:
26
+ import os
27
+
28
+ parsl_vars = {}
29
+ for key in os.environ:
30
+ if key.startswith("PARSL_"):
31
+ parsl_vars[key] = os.environ[key]
32
+ return parsl_vars
33
+
34
+
35
+ @pytest.mark.local
36
+ def test_only_resource_specs_set():
37
+ """Confirm that resource_spec env vars are set while launch prefixes are not
38
+ when enable_mpi_mode = False"""
39
+ resource_spec = {
40
+ "num_nodes": 4,
41
+ "ranks_per_node": 2,
42
+ }
43
+
44
+ future = get_env_vars(parsl_resource_specification=resource_spec)
45
+
46
+ result = future.result()
47
+ assert isinstance(result, Dict)
48
+ assert "PARSL_DEFAULT_PREFIX" not in result
49
+ assert "PARSL_SRUN_PREFIX" not in result
50
+ assert result["PARSL_NUM_NODES"] == str(resource_spec["num_nodes"])
51
+ assert result["PARSL_RANKS_PER_NODE"] == str(resource_spec["ranks_per_node"])
@@ -0,0 +1,171 @@
1
+ import logging
2
+ import random
3
+ from typing import Dict
4
+ import pytest
5
+ import parsl
6
+ from parsl import python_app, bash_app
7
+ from parsl.tests.configs.htex_local import fresh_config
8
+
9
+ import os
10
+
11
+ EXECUTOR_LABEL = "MPI_TEST"
12
+
13
+
14
+ def local_setup():
15
+ config = fresh_config()
16
+ config.executors[0].label = EXECUTOR_LABEL
17
+ config.executors[0].max_workers = 2
18
+ config.executors[0].enable_mpi_mode = True
19
+ config.executors[0].mpi_launcher = "mpiexec"
20
+
21
+ cwd = os.path.abspath(os.path.dirname(__file__))
22
+ pbs_nodefile = os.path.join(cwd, "mocks", "pbs_nodefile")
23
+
24
+ config.executors[0].provider.worker_init = f"export PBS_NODEFILE={pbs_nodefile}"
25
+
26
+ parsl.load(config)
27
+
28
+
29
+ def local_teardown():
30
+ parsl.dfk().cleanup()
31
+ parsl.clear()
32
+
33
+
34
+ @python_app
35
+ def get_env_vars(parsl_resource_specification: Dict = {}) -> Dict:
36
+ import os
37
+
38
+ parsl_vars = {}
39
+ for key in os.environ:
40
+ if key.startswith("PARSL_"):
41
+ parsl_vars[key] = os.environ[key]
42
+ return parsl_vars
43
+
44
+
45
+ @pytest.mark.local
46
+ def test_only_resource_specs_set():
47
+ """Confirm that resource_spec env vars and MPI launch prefixes are both set
48
+ when enable_mpi_mode = True"""
49
+ resource_spec = {
50
+ "num_nodes": 2,
51
+ "ranks_per_node": 2,
52
+ }
53
+
54
+ future = get_env_vars(parsl_resource_specification=resource_spec)
55
+
56
+ result = future.result()
57
+ assert isinstance(result, Dict)
58
+ logging.warning(f"Got table: {result}")
59
+ assert "PARSL_MPI_PREFIX" in result
60
+ assert "PARSL_MPIEXEC_PREFIX" in result
61
+ assert result["PARSL_MPI_PREFIX"] == result["PARSL_MPIEXEC_PREFIX"]
62
+ assert result["PARSL_NUM_NODES"] == str(resource_spec["num_nodes"])
63
+ assert result["PARSL_RANKS_PER_NODE"] == str(resource_spec["ranks_per_node"])
64
+ assert result["PARSL_NUM_RANKS"] == str(resource_spec["ranks_per_node"] * resource_spec["num_nodes"])
65
+
66
+
67
+ @bash_app
68
+ def echo_launch_cmd(
69
+ parsl_resource_specification: Dict,
70
+ stdout=parsl.AUTO_LOGNAME,
71
+ stderr=parsl.AUTO_LOGNAME,
72
+ ):
73
+ return 'echo "$PARSL_MPI_PREFIX hostname"'
74
+
75
+
76
+ @pytest.mark.local
77
+ def test_bash_default_prefix_set():
78
+ """Confirm that bash apps see an mpiexec launch prefix in PARSL_MPI_PREFIX
79
+ when enable_mpi_mode = True"""
80
+ resource_spec = {
81
+ "num_nodes": 2,
82
+ "ranks_per_node": 2,
83
+ }
84
+
85
+ future = echo_launch_cmd(parsl_resource_specification=resource_spec)
86
+
87
+ result = future.result()
88
+ assert result == 0
89
+ with open(future.stdout) as f:
90
+ output = f.readlines()
91
+ assert output[0].startswith("mpiexec")
92
+ logging.warning(f"output : {output}")
93
+
94
+
95
+ @pytest.mark.local
96
+ def test_bash_multiple_set():
97
+ """Confirm that multiple apps can run without blocking each other out
98
+ when enable_mpi_mode = True"""
99
+ resource_spec = {
100
+ "num_nodes": 2,
101
+ "num_ranks": 4,
102
+ }
103
+ futures = []
104
+ for i in range(4):
105
+ resource_spec["num_nodes"] = i + 1
106
+ future = echo_launch_cmd(parsl_resource_specification=resource_spec)
107
+ futures.append(future)
108
+
109
+ for future in futures:
110
+ result = future.result()
111
+ assert result == 0
112
+ with open(future.stdout) as f:
113
+ output = f.readlines()
114
+ assert output[0].startswith("mpiexec")
115
+
116
+
117
+ @bash_app
118
+ def bash_resource_spec(parsl_resource_specification=None, stdout=parsl.AUTO_LOGNAME):
119
+ total_ranks = (
120
+ parsl_resource_specification["ranks_per_node"] * parsl_resource_specification["num_nodes"]
121
+ )
122
+ return f'echo "{total_ranks}"'
123
+
124
+
125
+ @pytest.mark.local
126
+ def test_bash_app_using_resource_spec():
127
+ resource_spec = {
128
+ "num_nodes": 2,
129
+ "ranks_per_node": 2,
130
+ }
131
+ future = bash_resource_spec(parsl_resource_specification=resource_spec)
132
+ assert future.result() == 0
133
+ with open(future.stdout) as f:
134
+ output = f.readlines()
135
+ total_ranks = resource_spec["num_nodes"] * resource_spec["ranks_per_node"]
136
+ assert int(output[0].strip()) == total_ranks
137
+
138
+
139
+ @python_app
140
+ def mock_app(sleep_dur: float = 0.0, parsl_resource_specification: Dict = {}):
141
+ import os
142
+ import time
143
+ time.sleep(sleep_dur)
144
+
145
+ total_ranks = int(os.environ["PARSL_NUM_NODES"]) * int(os.environ["PARSL_RANKS_PER_NODE"])
146
+ nodes = os.environ["PARSL_MPI_NODELIST"].split(',')
147
+
148
+ return total_ranks, nodes
149
+
150
+
151
+ @pytest.mark.local
152
+ def test_simulated_load(rounds: int = 100):
153
+
154
+ node_choices = (1, 2, 4)
155
+ sleep_choices = (0, 0.01, 0.02, 0.04)
156
+ ranks_per_node = (4, 8)
157
+
158
+ futures = {}
159
+ for i in range(rounds):
160
+ resource_spec = {
161
+ "num_nodes": random.choice(node_choices),
162
+ "ranks_per_node": random.choice(ranks_per_node),
163
+ }
164
+ future = mock_app(sleep_dur=random.choice(sleep_choices),
165
+ parsl_resource_specification=resource_spec)
166
+ futures[future] = resource_spec
167
+
168
+ for future in futures:
169
+ total_ranks, nodes = future.result(timeout=10)
170
+ assert len(nodes) == futures[future]["num_nodes"]
171
+ assert total_ranks == futures[future]["num_nodes"] * futures[future]["ranks_per_node"]
@@ -0,0 +1,71 @@
1
+ import logging
2
+ import pytest
3
+
4
+ from parsl.executors.high_throughput.mpi_resource_management import Scheduler
5
+ from parsl.executors.high_throughput.mpi_prefix_composer import (
6
+ compose_srun_launch_cmd,
7
+ compose_aprun_launch_cmd,
8
+ compose_mpiexec_launch_cmd,
9
+ compose_all,
10
+ )
11
+
12
+
13
+ resource_spec = {"num_nodes": 2,
14
+ "num_ranks": 8,
15
+ "ranks_per_node": 4}
16
+
17
+
18
+ @pytest.mark.local
19
+ def test_srun_launch_cmd():
20
+ prefix_name, composed_prefix = compose_srun_launch_cmd(
21
+ resource_spec=resource_spec, node_hostnames=["node1", "node2"]
22
+ )
23
+ assert prefix_name == "PARSL_SRUN_PREFIX"
24
+ logging.warning(composed_prefix)
25
+
26
+ assert "None" not in composed_prefix
27
+
28
+
29
+ @pytest.mark.local
30
+ def test_aprun_launch_cmd():
31
+ prefix_name, composed_prefix = compose_aprun_launch_cmd(
32
+ resource_spec=resource_spec, node_hostnames=["node1", "node2"]
33
+ )
34
+ logging.warning(composed_prefix)
35
+ assert prefix_name == "PARSL_APRUN_PREFIX"
36
+ assert "None" not in composed_prefix
37
+
38
+
39
+ @pytest.mark.local
40
+ def test_mpiexec_launch_cmd():
41
+ prefix_name, composed_prefix = compose_mpiexec_launch_cmd(
42
+ resource_spec=resource_spec, node_hostnames=["node1", "node2"]
43
+ )
44
+ logging.warning(composed_prefix)
45
+ assert prefix_name == "PARSL_MPIEXEC_PREFIX"
46
+ assert "None" not in composed_prefix
47
+
48
+
49
+ @pytest.mark.local
50
+ def test_slurm_launch_cmd():
51
+ table = compose_all(
52
+ mpi_launcher="srun",
53
+ resource_spec=resource_spec,
54
+ node_hostnames=["NODE001", "NODE002"],
55
+ )
56
+
57
+ assert "PARSL_MPI_PREFIX" in table
58
+ assert "PARSL_SRUN_PREFIX" in table
59
+
60
+
61
+ @pytest.mark.local
62
+ def test_default_launch_cmd():
63
+ table = compose_all(
64
+ mpi_launcher="srun",
65
+ resource_spec=resource_spec,
66
+ node_hostnames=["NODE001", "NODE002"],
67
+ )
68
+
69
+ assert "PARSL_MPI_PREFIX" in table
70
+ assert "PARSL_SRUN_PREFIX" in table
71
+ assert table["PARSL_MPI_PREFIX"] == table["PARSL_SRUN_PREFIX"]
@@ -0,0 +1,158 @@
1
+ import logging
2
+ import os
3
+ from unittest import mock
4
+ import pytest
5
+ import pickle
6
+ from parsl.executors.high_throughput.mpi_resource_management import TaskScheduler, MPITaskScheduler
7
+ from parsl.multiprocessing import SpawnContext
8
+ from parsl.serialize import pack_res_spec_apply_message, unpack_res_spec_apply_message
9
+
10
+
11
+ @pytest.fixture(autouse=True)
12
+ def set_pbs_nodefile_envvars():
13
+ cwd = os.path.abspath(os.path.dirname(__file__))
14
+ pbs_nodefile = os.path.join(cwd, "mocks", "pbs_nodefile.8")
15
+ with mock.patch.dict(os.environ, {"PBS_NODEFILE": pbs_nodefile}):
16
+ yield
17
+
18
+
19
+ @pytest.mark.local
20
+ def test_NoopScheduler():
21
+ task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue()
22
+ scheduler = TaskScheduler(task_q, result_q)
23
+
24
+ scheduler.put_task("TaskFoo")
25
+ assert task_q.get() == "TaskFoo"
26
+
27
+ result_q.put("Result1")
28
+ assert scheduler.get_result(True, 1) == "Result1"
29
+
30
+
31
+ @pytest.mark.local
32
+ def test_MPISched_put_task():
33
+ task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue()
34
+ scheduler = MPITaskScheduler(task_q, result_q)
35
+
36
+ assert scheduler.available_nodes
37
+ assert len(scheduler.available_nodes) == 8
38
+ assert scheduler._free_node_counter.value == 8
39
+
40
+ mock_task_buffer = pack_res_spec_apply_message("func",
41
+ "args",
42
+ "kwargs",
43
+ resource_specification={"num_nodes": 2,
44
+ "ranks_per_node": 2})
45
+ task_package = {"task_id": 1, "buffer": mock_task_buffer}
46
+ scheduler.put_task(task_package)
47
+
48
+ assert scheduler._free_node_counter.value == 6
49
+
50
+
51
+ @pytest.mark.local
52
+ def test_MPISched_get_result():
53
+ task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue()
54
+ scheduler = MPITaskScheduler(task_q, result_q)
55
+
56
+ assert scheduler.available_nodes
57
+ assert len(scheduler.available_nodes) == 8
58
+ assert scheduler._free_node_counter.value == 8
59
+
60
+ nodes = [scheduler.nodes_q.get() for _ in range(4)]
61
+ scheduler._free_node_counter.value = 4
62
+ scheduler._map_tasks_to_nodes[1] = nodes
63
+
64
+ result_package = pickle.dumps({"task_id": 1, "type": "result", "buffer": "Foo"})
65
+ result_q.put(result_package)
66
+ result_received = scheduler.get_result(block=True, timeout=1)
67
+ assert result_received == result_package
68
+
69
+ assert scheduler._free_node_counter.value == 8
70
+
71
+
72
+ @pytest.mark.local
73
+ def test_MPISched_roundtrip():
74
+ task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue()
75
+ scheduler = MPITaskScheduler(task_q, result_q)
76
+
77
+ assert scheduler.available_nodes
78
+ assert len(scheduler.available_nodes) == 8
79
+
80
+ for round in range(1, 9):
81
+ assert scheduler._free_node_counter.value == 8
82
+
83
+ mock_task_buffer = pack_res_spec_apply_message("func",
84
+ "args",
85
+ "kwargs",
86
+ resource_specification={"num_nodes": round,
87
+ "ranks_per_node": 2})
88
+ task_package = {"task_id": round, "buffer": mock_task_buffer}
89
+ scheduler.put_task(task_package)
90
+
91
+ assert scheduler._free_node_counter.value == 8 - round
92
+
93
+ # Pop in a mock result
94
+ result_pkl = pickle.dumps({"task_id": round, "type": "result", "buffer": "RESULT BUF"})
95
+ result_q.put(result_pkl)
96
+
97
+ got_result = scheduler.get_result(True, 1)
98
+ assert got_result == result_pkl
99
+
100
+
101
+ @pytest.mark.local
102
+ def test_MPISched_contention():
103
+ """Second task has to wait for the first task due to insufficient resources"""
104
+ task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue()
105
+ scheduler = MPITaskScheduler(task_q, result_q)
106
+
107
+ assert scheduler.available_nodes
108
+ assert len(scheduler.available_nodes) == 8
109
+
110
+ assert scheduler._free_node_counter.value == 8
111
+
112
+ mock_task_buffer = pack_res_spec_apply_message("func",
113
+ "args",
114
+ "kwargs",
115
+ resource_specification={
116
+ "num_nodes": 8,
117
+ "ranks_per_node": 2
118
+ })
119
+ task_package = {"task_id": 1, "buffer": mock_task_buffer}
120
+ scheduler.put_task(task_package)
121
+
122
+ assert scheduler._free_node_counter.value == 0
123
+ assert scheduler._backlog_queue.empty()
124
+
125
+ mock_task_buffer = pack_res_spec_apply_message("func",
126
+ "args",
127
+ "kwargs",
128
+ resource_specification={
129
+ "num_nodes": 8,
130
+ "ranks_per_node": 2
131
+ })
132
+ task_package = {"task_id": 2, "buffer": mock_task_buffer}
133
+ scheduler.put_task(task_package)
134
+
135
+ # Second task should now be in the backlog_queue
136
+ assert not scheduler._backlog_queue.empty()
137
+
138
+ # Confirm that the first task is available and has all 8 nodes provisioned
139
+ task_on_worker_side = task_q.get()
140
+ assert task_on_worker_side['task_id'] == 1
141
+ _, _, _, resource_spec = unpack_res_spec_apply_message(task_on_worker_side['buffer'])
142
+ assert len(resource_spec['MPI_NODELIST'].split(',')) == 8
143
+ assert task_q.empty() # Confirm that task 2 is not yet scheduled
144
+
145
+ # Simulate worker returning result and the scheduler picking up result
146
+ result_pkl = pickle.dumps({"task_id": 1, "type": "result", "buffer": "RESULT BUF"})
147
+ result_q.put(result_pkl)
148
+ got_result = scheduler.get_result(True, 1)
149
+ assert got_result == result_pkl
150
+
151
+ # Now task2 must be scheduled
152
+ assert scheduler._backlog_queue.empty()
153
+
154
+ # Pop in a mock result
155
+ task_on_worker_side = task_q.get()
156
+ assert task_on_worker_side['task_id'] == 2
157
+ _, _, _, resource_spec = unpack_res_spec_apply_message(task_on_worker_side['buffer'])
158
+ assert len(resource_spec['MPI_NODELIST'].split(',')) == 8