earthkit-workflows 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cascade/benchmarks/__main__.py +1 -277
- cascade/benchmarks/dask.py +33 -0
- cascade/benchmarks/util.py +294 -0
- cascade/executor/executor.py +9 -4
- cascade/executor/platform.py +12 -0
- cascade/executor/runner/entrypoint.py +8 -12
- cascade/gateway/__main__.py +4 -2
- cascade/gateway/api.py +13 -0
- cascade/gateway/router.py +86 -5
- cascade/gateway/server.py +4 -2
- cascade/low/dask.py +99 -0
- cascade/low/into.py +1 -1
- earthkit/workflows/_version.py +1 -1
- {earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/METADATA +1 -1
- {earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/RECORD +18 -15
- {earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/WHEEL +0 -0
- {earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/top_level.txt +0 -0
cascade/benchmarks/__main__.py
CHANGED
@@ -19,285 +19,9 @@ Make sure you correctly configure:
  - your venv (cascade, fiab, pproc-cascade, compatible version of earthkit-data, ...)
 """
 
-# TODO rework, simplify
-
-import logging
-import logging.config
-import multiprocessing
-import os
-import subprocess
-import sys
-from concurrent.futures import ThreadPoolExecutor
-from time import perf_counter_ns
-
 import fire
-import orjson
-
-import cascade.executor.platform as platform
-import cascade.low.into
-from cascade.controller.impl import run
-from cascade.executor.bridge import Bridge
-from cascade.executor.comms import callback
-from cascade.executor.config import logging_config, logging_config_filehandler
-from cascade.executor.executor import Executor
-from cascade.executor.msg import BackboneAddress, ExecutorShutdown
-from cascade.low.core import JobInstance
-from cascade.low.func import msum
-from cascade.scheduler.precompute import precompute
-from earthkit.workflows.graph import Graph, deduplicate_nodes
-
-logger = logging.getLogger("cascade.benchmarks")
-
-
-def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
-    # NOTE because of os.environ, we don't import all... ideally we'd have some file-based init/config mech instead
-    if benchmark is not None and instance_path is not None:
-        raise TypeError("specified both benchmark name and job instance")
-    elif instance_path is not None:
-        with open(instance_path, "rb") as f:
-            d = orjson.loads(f.read())
-            return JobInstance(**d)
-    elif benchmark is not None:
-        if benchmark.startswith("j1"):
-            import cascade.benchmarks.job1 as job1
-
-            graphs = {
-                "j1.prob": job1.get_prob(),
-                "j1.ensms": job1.get_ensms(),
-                "j1.efi": job1.get_efi(),
-            }
-            union = lambda prefix: deduplicate_nodes(
-                msum((v for k, v in graphs.items() if k.startswith(prefix)), Graph)
-            )
-            graphs["j1.all"] = union("j1.")
-            return cascade.low.into.graph2job(graphs[benchmark])
-        elif benchmark.startswith("generators"):
-            import cascade.benchmarks.generators as generators
-
-            return generators.get_job()
-        elif benchmark.startswith("matmul"):
-            import cascade.benchmarks.matmul as matmul
-
-            return matmul.get_job()
-        elif benchmark.startswith("dist"):
-            import cascade.benchmarks.dist as dist
-
-            return dist.get_job()
-        else:
-            raise NotImplementedError(benchmark)
-    else:
-        raise TypeError("specified neither benchmark name nor job instance")
-
-
-def get_cuda_count() -> int:
-    try:
-        if "CUDA_VISIBLE_DEVICES" in os.environ:
-            # TODO we dont want to just count, we want to actually use literally these ids
-            # NOTE this is particularly useful for "" value -- careful when refactoring
-            visible = os.environ["CUDA_VISIBLE_DEVICES"]
-            visible_count = sum(1 for e in visible if e == ",") + (1 if visible else 0)
-            return visible_count
-        gpus = sum(
-            1
-            for l in subprocess.run(
-                ["nvidia-smi", "--list-gpus"], check=True, capture_output=True
-            )
-            .stdout.decode("ascii")
-            .split("\n")
-            if "GPU" in l
-        )
-    except:
-        logger.exception("unable to determine available gpus")
-        gpus = 0
-    return gpus
-
-
-def get_gpu_count(host_idx: int, worker_count: int) -> int:
-    if sys.platform == "darwin":
-        # we should inspect some gpu capabilities details to prevent overcommit
-        return worker_count
-    else:
-        if host_idx == 0:
-            return get_cuda_count()
-        else:
-            return 0
-
-
-def launch_executor(
-    job_instance: JobInstance,
-    controller_address: BackboneAddress,
-    workers_per_host: int,
-    portBase: int,
-    i: int,
-    shm_vol_gb: int | None,
-    gpu_count: int,
-    log_base: str | None,
-    url_base: str,
-):
-    if log_base is not None:
-        log_base = f"{log_base}.host{i}"
-        log_path = f"{log_base}.txt"
-        logging.config.dictConfig(logging_config_filehandler(log_path))
-    else:
-        logging.config.dictConfig(logging_config)
-    try:
-        logger.info(f"will set {gpu_count} gpus on host {i}")
-        os.environ["CASCADE_GPU_COUNT"] = str(gpu_count)
-        executor = Executor(
-            job_instance,
-            controller_address,
-            workers_per_host,
-            f"h{i}",
-            portBase,
-            shm_vol_gb,
-            log_base,
-            url_base,
-        )
-        executor.register()
-        executor.recv_loop()
-    except Exception:
-        # NOTE we log this to get the stacktrace into the logfile
-        logger.exception("executor failure")
-        raise
-
-
-def run_locally(
-    job: JobInstance,
-    hosts: int,
-    workers: int,
-    portBase: int = 12345,
-    log_base: str | None = None,
-    report_address: str | None = None,
-):
-    if log_base is not None:
-        log_path = f"{log_base}.controller.txt"
-        logging.config.dictConfig(logging_config_filehandler(log_path))
-    else:
-        logging.config.dictConfig(logging_config)
-    logger.debug(f"local run starting with {hosts=} and {workers=} on {portBase=}")
-    launch = perf_counter_ns()
-    c = f"tcp://localhost:{portBase}"
-    m = f"tcp://localhost:{portBase+1}"
-    ps = []
-    try:
-        # executors forking
-        for i, executor in enumerate(range(hosts)):
-            gpu_count = get_gpu_count(i, workers)
-            # NOTE forkserver/spawn seem to forget venv, we need fork
-            logger.debug(f"forking into executor on host {i}")
-            p = multiprocessing.get_context("fork").Process(
-                target=launch_executor,
-                args=(
-                    job,
-                    c,
-                    workers,
-                    portBase + 1 + i * 10,
-                    i,
-                    None,
-                    gpu_count,
-                    log_base,
-                    "tcp://localhost",
-                ),
-            )
-            p.start()
-            ps.append(p)
-
-        # compute preschedule
-        preschedule = precompute(job)
-
-        # check processes started healthy
-        for i, p in enumerate(ps):
-            if not p.is_alive():
-                # TODO ideally we would somehow connect this with the Register message
-                # consumption in the Controller -- but there we don't assume that
-                # executors are on the same physical host
-                raise ValueError(f"executor {i} failed to live due to {p.exitcode}")
-
-        # start bridge itself
-        logger.debug("starting bridge")
-        b = Bridge(c, hosts)
-        start = perf_counter_ns()
-        run(job, b, preschedule, report_address=report_address)
-        end = perf_counter_ns()
-        print(
-            f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
-        )
-    except Exception:
-        # NOTE we log this to get the stacktrace into the logfile
-        logger.exception("controller failure, proceed with executor shutdown")
-        for p in ps:
-            if p.is_alive():
-                callback(m, ExecutorShutdown())
-                import time
-
-                time.sleep(1)
-                p.kill()
-        raise
-
-
-def main_local(
-    workers_per_host: int,
-    hosts: int = 1,
-    report_address: str | None = None,
-    job: str | None = None,
-    instance: str | None = None,
-    port_base: int = 12345,
-    log_base: str | None = None,
-) -> None:
-    jobInstance = get_job(job, instance)
-    run_locally(
-        jobInstance,
-        hosts,
-        workers_per_host,
-        report_address=report_address,
-        portBase=port_base,
-        log_base=log_base,
-    )
-
-
-def main_dist(
-    idx: int,
-    controller_url: str,
-    hosts: int = 3,
-    workers_per_host: int = 10,
-    shm_vol_gb: int = 64,
-    job: str | None = None,
-    instance: str | None = None,
-    report_address: str | None = None,
-) -> None:
-    """Entrypoint for *both* controller and worker -- they are on different hosts! Distinguished by idx: 0 for
-    controller, 1+ for worker. Assumed to come from slurm procid.
-    """
-    launch = perf_counter_ns()
-
-    jobInstance = get_job(job, instance)
-
-    if idx == 0:
-        logging.config.dictConfig(logging_config)
-        tp = ThreadPoolExecutor(max_workers=1)
-        preschedule_fut = tp.submit(precompute, jobInstance)
-        b = Bridge(controller_url, hosts)
-        preschedule = preschedule_fut.result()
-        tp.shutdown()
-        start = perf_counter_ns()
-        run(jobInstance, b, preschedule, report_address=report_address)
-        end = perf_counter_ns()
-        print(
-            f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
-        )
-    else:
-        gpu_count = get_gpu_count(0, workers_per_host)
-        launch_executor(
-            jobInstance,
-            controller_url,
-            workers_per_host,
-            12345,
-            idx,
-            shm_vol_gb,
-            gpu_count,
-            f"tcp://{platform.get_bindabble_self()}",
-        )
 
+from cascade.benchmarks.util import main_dist, main_local
 
 if __name__ == "__main__":
     fire.Fire({"local": main_local, "dist": main_dist})
cascade/benchmarks/dask.py
ADDED
@@ -0,0 +1,33 @@
+import dask.dataframe as dd
+from dask._task_spec import convert_legacy_graph
+
+from cascade.low.core import JobInstance
+from cascade.low.dask import graph2job
+
+
+def get_job(job: str) -> JobInstance:
+
+    if job == "add":
+
+        def add(x, y):
+            result = x + y
+            print(f"da {result=}")
+            return result
+
+        dl = {"a": 1, "b": 2, "c": (add, "a", "b")}
+        dn = convert_legacy_graph(dl)
+        job = graph2job(dn)
+        job.ext_outputs = [
+            dataset for task in job.tasks for dataset in job.outputs_of(task)
+        ]
+        return job
+    elif job == "groupby":
+        df = dd.DataFrame.from_dict({"x": [0, 0, 1, 1], "y": [1, 2, 3, 4]})
+        df = df.groupby("x").sum()
+        job = graph2job(df.__dask_graph__())
+        job.ext_outputs = [
+            dataset for task in job.tasks for dataset in job.outputs_of(task)
+        ]
+        return job
+    else:
+        raise NotImplementedError(job)
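Note: a minimal smoke test of the new dask benchmark module, assuming the 0.4.4 wheel plus dask are installed; the prints are illustrative only and not asserted anywhere in the package.

    from cascade.benchmarks.dask import get_job

    job = get_job("add")    # builds {"a": 1, "b": 2, "c": add(a, b)} and lowers it via graph2job
    print(len(job.tasks))   # the three dask keys become cascade tasks
    print(job.ext_outputs)  # every task output is exposed, as set in get_job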
cascade/benchmarks/util.py
ADDED
@@ -0,0 +1,294 @@
+# (C) Copyright 2025- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Contains utility methods for benchmark definitions and cluster starting"""
+
+# TODO rework, simplify, split into benchmark.util and cluster.setup or smth
+
+import logging
+import logging.config
+import multiprocessing
+import os
+import subprocess
+import sys
+from concurrent.futures import ThreadPoolExecutor
+from time import perf_counter_ns
+
+import orjson
+
+import cascade.executor.platform as platform
+import cascade.low.into
+from cascade.controller.impl import run
+from cascade.executor.bridge import Bridge
+from cascade.executor.comms import callback
+from cascade.executor.config import logging_config, logging_config_filehandler
+from cascade.executor.executor import Executor
+from cascade.executor.msg import BackboneAddress, ExecutorShutdown
+from cascade.low.core import JobInstance
+from cascade.low.func import msum
+from cascade.scheduler.precompute import precompute
+from earthkit.workflows.graph import Graph, deduplicate_nodes
+
+logger = logging.getLogger("cascade.benchmarks")
+
+
+def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
+    # NOTE because of os.environ, we don't import all... ideally we'd have some file-based init/config mech instead
+    if benchmark is not None and instance_path is not None:
+        raise TypeError("specified both benchmark name and job instance")
+    elif instance_path is not None:
+        with open(instance_path, "rb") as f:
+            d = orjson.loads(f.read())
+            return JobInstance(**d)
+    elif benchmark is not None:
+        if benchmark.startswith("j1"):
+            import cascade.benchmarks.job1 as job1
+
+            graphs = {
+                "j1.prob": job1.get_prob(),
+                "j1.ensms": job1.get_ensms(),
+                "j1.efi": job1.get_efi(),
+            }
+            union = lambda prefix: deduplicate_nodes(
+                msum((v for k, v in graphs.items() if k.startswith(prefix)), Graph)
+            )
+            graphs["j1.all"] = union("j1.")
+            return cascade.low.into.graph2job(graphs[benchmark])
+        elif benchmark.startswith("generators"):
+            import cascade.benchmarks.generators as generators
+
+            return generators.get_job()
+        elif benchmark.startswith("matmul"):
+            import cascade.benchmarks.matmul as matmul
+
+            return matmul.get_job()
+        elif benchmark.startswith("dist"):
+            import cascade.benchmarks.dist as dist
+
+            return dist.get_job()
+        elif benchmark.startswith("dask"):
+            import cascade.benchmarks.dask as dask
+
+            return dask.get_job(benchmark[len("dask.") :])
+        else:
+            raise NotImplementedError(benchmark)
+    else:
+        raise TypeError("specified neither benchmark name nor job instance")
+
+
+def get_cuda_count() -> int:
+    try:
+        if "CUDA_VISIBLE_DEVICES" in os.environ:
+            # TODO we dont want to just count, we want to actually use literally these ids
+            # NOTE this is particularly useful for "" value -- careful when refactoring
+            visible = os.environ["CUDA_VISIBLE_DEVICES"]
+            visible_count = sum(1 for e in visible if e == ",") + (1 if visible else 0)
+            return visible_count
+        gpus = sum(
+            1
+            for l in subprocess.run(
+                ["nvidia-smi", "--list-gpus"], check=True, capture_output=True
+            )
+            .stdout.decode("ascii")
+            .split("\n")
+            if "GPU" in l
+        )
+    except:
+        logger.exception("unable to determine available gpus")
+        gpus = 0
+    return gpus
+
+
+def get_gpu_count(host_idx: int, worker_count: int) -> int:
+    if sys.platform == "darwin":
+        # we should inspect some gpu capabilities details to prevent overcommit
+        return worker_count
+    else:
+        if host_idx == 0:
+            return get_cuda_count()
+        else:
+            return 0
+
+
+def launch_executor(
+    job_instance: JobInstance,
+    controller_address: BackboneAddress,
+    workers_per_host: int,
+    portBase: int,
+    i: int,
+    shm_vol_gb: int | None,
+    gpu_count: int,
+    log_base: str | None,
+    url_base: str,
+):
+    if log_base is not None:
+        log_base = f"{log_base}.host{i}"
+        log_path = f"{log_base}.txt"
+        logging.config.dictConfig(logging_config_filehandler(log_path))
+    else:
+        logging.config.dictConfig(logging_config)
+    try:
+        logger.info(f"will set {gpu_count} gpus on host {i}")
+        os.environ["CASCADE_GPU_COUNT"] = str(gpu_count)
+        executor = Executor(
+            job_instance,
+            controller_address,
+            workers_per_host,
+            f"h{i}",
+            portBase,
+            shm_vol_gb,
+            log_base,
+            url_base,
+        )
+        executor.register()
+        executor.recv_loop()
+    except Exception:
+        # NOTE we log this to get the stacktrace into the logfile
+        logger.exception("executor failure")
+        raise
+
+
+def run_locally(
+    job: JobInstance,
+    hosts: int,
+    workers: int,
+    portBase: int = 12345,
+    log_base: str | None = None,
+    report_address: str | None = None,
+) -> None:
+    if log_base is not None:
+        log_path = f"{log_base}.controller.txt"
+        logging.config.dictConfig(logging_config_filehandler(log_path))
+    else:
+        logging.config.dictConfig(logging_config)
+    logger.debug(f"local run starting with {hosts=} and {workers=} on {portBase=}")
+    launch = perf_counter_ns()
+    c = f"tcp://localhost:{portBase}"
+    m = f"tcp://localhost:{portBase+1}"
+    ps = []
+    try:
+        # executors forking
+        for i, executor in enumerate(range(hosts)):
+            gpu_count = get_gpu_count(i, workers)
+            # NOTE forkserver/spawn seem to forget venv, we need fork
+            logger.debug(f"forking into executor on host {i}")
+            p = multiprocessing.get_context("fork").Process(
+                target=launch_executor,
+                args=(
+                    job,
+                    c,
+                    workers,
+                    portBase + 1 + i * 10,
+                    i,
+                    None,
+                    gpu_count,
+                    log_base,
+                    "tcp://localhost",
+                ),
+            )
+            p.start()
+            ps.append(p)
+
+        # compute preschedule
+        preschedule = precompute(job)
+
+        # check processes started healthy
+        for i, p in enumerate(ps):
+            if not p.is_alive():
+                # TODO ideally we would somehow connect this with the Register message
+                # consumption in the Controller -- but there we don't assume that
+                # executors are on the same physical host
+                raise ValueError(f"executor {i} failed to live due to {p.exitcode}")
+
+        # start bridge itself
+        logger.debug("starting bridge")
+        b = Bridge(c, hosts)
+        start = perf_counter_ns()
+        result = run(job, b, preschedule, report_address=report_address)
+        end = perf_counter_ns()
+        print(
+            f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
+        )
+        if os.environ.get("CASCADE_DEBUG_PRINT"):
+            for key, value in result.outputs.items():
+                print(f"{key} => {value}")
+    except Exception:
+        # NOTE we log this to get the stacktrace into the logfile
+        logger.exception("controller failure, proceed with executor shutdown")
+        for p in ps:
+            if p.is_alive():
+                callback(m, ExecutorShutdown())
+                import time
+
+                time.sleep(1)
+                p.kill()
+        raise
+
+
+def main_local(
+    workers_per_host: int,
+    hosts: int = 1,
+    report_address: str | None = None,
+    job: str | None = None,
+    instance: str | None = None,
+    port_base: int = 12345,
+    log_base: str | None = None,
+) -> None:
+    jobInstance = get_job(job, instance)
+    run_locally(
+        jobInstance,
+        hosts,
+        workers_per_host,
+        report_address=report_address,
+        portBase=port_base,
+        log_base=log_base,
+    )
+
+
+def main_dist(
+    idx: int,
+    controller_url: str,
+    hosts: int = 3,
+    workers_per_host: int = 10,
+    shm_vol_gb: int = 64,
+    job: str | None = None,
+    instance: str | None = None,
+    report_address: str | None = None,
+) -> None:
+    """Entrypoint for *both* controller and worker -- they are on different hosts! Distinguished by idx: 0 for
+    controller, 1+ for worker. Assumed to come from slurm procid.
+    """
+    launch = perf_counter_ns()
+
+    jobInstance = get_job(job, instance)
+
+    if idx == 0:
+        logging.config.dictConfig(logging_config)
+        tp = ThreadPoolExecutor(max_workers=1)
+        preschedule_fut = tp.submit(precompute, jobInstance)
+        b = Bridge(controller_url, hosts)
+        preschedule = preschedule_fut.result()
+        tp.shutdown()
+        start = perf_counter_ns()
+        run(jobInstance, b, preschedule, report_address=report_address)
+        end = perf_counter_ns()
+        print(
+            f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
+        )
+    else:
+        gpu_count = get_gpu_count(0, workers_per_host)
+        launch_executor(
+            jobInstance,
+            controller_url,
+            workers_per_host,
+            12345,
+            idx,
+            shm_vol_gb,
+            gpu_count,
+            f"tcp://{platform.get_bindabble_self()}",
+        )
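Note: `__main__.py` now only re-exports these entrypoints, so a run can still be driven programmatically; a minimal sketch, with the benchmark name and worker counts purely illustrative:

    from cascade.benchmarks.util import main_local

    # equivalent to `python -m cascade.benchmarks local --workers_per_host 2 --job matmul`
    main_local(workers_per_host=2, hosts=1, job="matmul")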
cascade/executor/executor.py
CHANGED
@@ -22,6 +22,8 @@ from multiprocessing import get_context
 from multiprocessing.process import BaseProcess
 from typing import Iterable
 
+import cloudpickle
+
 import cascade.executor.platform as platform
 import cascade.shm.api as shm_api
 import cascade.shm.client as shm_client
@@ -187,9 +189,8 @@ class Executor:
 
     def start_workers(self, workers: Iterable[WorkerId]) -> None:
         # TODO this method assumes no other message will arrive to mlistener! Thus cannot be used for workers now
-        # NOTE
-
-        ctx = get_context("fork")
+        # NOTE fork would be better but causes issues on macos+torch with XPC_ERROR_CONNECTION_INVALID
+        ctx = get_context("forkserver")
         for worker in workers:
             runnerContext = RunnerContext(
                 workerId=worker,
@@ -198,7 +199,11 @@ class Executor:
                 callback=self.mlistener.address,
                 log_base=self.log_base,
             )
-
+            # NOTE we need to cloudpickle because runnerContext contains some lambdas
+            p = ctx.Process(
+                target=entrypoint,
+                kwargs={"runnerContext": cloudpickle.dumps(runnerContext)},
+            )
             p.start()
             self.workers[worker] = p
             logger.debug(f"started process {p.pid} for worker {worker}")
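Note: the switch from fork to forkserver is what forces the cloudpickle round trip, since the spawn-style start methods serialize process arguments with the stdlib pickle, which rejects lambdas. A standalone sketch of the pattern (names here are illustrative, not the package's API):

    import cloudpickle
    from multiprocessing import get_context


    def child(payload: bytes) -> None:
        fn = cloudpickle.loads(payload)  # restore the otherwise unpicklable callable
        fn()


    if __name__ == "__main__":
        ctx = get_context("forkserver")
        payload = cloudpickle.dumps(lambda: print("hello from worker"))
        p = ctx.Process(target=child, kwargs={"payload": payload})
        p.start()
        p.join()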
cascade/executor/platform.py
CHANGED
@@ -8,6 +8,7 @@
 
 """Macos-vs-Linux specific code"""
 
+import os
 import socket
 import sys
 
@@ -22,3 +23,14 @@ def get_bindabble_self():
     else:
         # NOTE not sure if fqdn or hostname is better -- all we need is for it to be resolvable within cluster
         return socket.gethostname()  # socket.getfqdn()
+
+
+def gpu_init(worker_num: int):
+    if sys.platform != "darwin":
+        # TODO there is implicit coupling with executor.executor and benchmarks.main -- make it cleaner!
+        gpus = int(os.environ.get("CASCADE_GPU_COUNT", "0"))
+        os.environ["CUDA_VISIBLE_DEVICES"] = (
+            str(worker_num) if worker_num < gpus else ""
+        )
+    else:
+        pass  # no macos specific gpu init due to unified mem model
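Note: the effect of the new `gpu_init` is to pin at most one CUDA device per worker, driven by the `CASCADE_GPU_COUNT` value the executor exports. A quick illustration, assuming the 0.4.4 wheel is installed and the interpreter runs on Linux:

    import os

    os.environ["CASCADE_GPU_COUNT"] = "2"

    import cascade.executor.platform as platform

    for worker_num in range(4):
        platform.gpu_init(worker_num)
        print(worker_num, repr(os.environ.get("CUDA_VISIBLE_DEVICES")))
    # expected: workers 0 and 1 get '0' and '1', later workers get ''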
cascade/executor/runner/entrypoint.py
CHANGED
@@ -11,11 +11,13 @@
 import logging
 import logging.config
 import os
-import sys
 from dataclasses import dataclass
+from typing import Any
 
+import cloudpickle
 import zmq
 
+import cascade.executor.platform as platform
 import cascade.executor.serde as serde
 from cascade.executor.comms import callback
 from cascade.executor.config import logging_config, logging_config_filehandler
@@ -118,7 +120,9 @@ def execute_sequence(
     )
 
 
-def entrypoint(runnerContext: RunnerContext):
+def entrypoint(runnerContext: Any):
+    """runnerContext is a cloudpickled instance of RunnerContext -- needed for forkserver mp context due to defautdicts"""
+    runnerContext = cloudpickle.loads(runnerContext)
     if runnerContext.log_base:
         log_path = f"{runnerContext.log_base}.{runnerContext.workerId.worker}"
         logging.config.dictConfig(logging_config_filehandler(log_path))
@@ -134,16 +138,8 @@ def entrypoint(runnerContext: RunnerContext):
     ):
         label("worker", repr(runnerContext.workerId))
         worker_num = runnerContext.workerId.worker_num()
-
-
-            os.environ["CUDA_VISIBLE_DEVICES"] = (
-                str(worker_num) if worker_num < gpus else ""
-            )
-            # NOTE check any(task.definition.needs_gpu) anywhere?
-            # TODO configure OMP_NUM_THREADS, blas, mkl, etc -- not clear how tho
-        else:
-            if gpus != 1:
-                logger.warning("unexpected absence of gpu on darwin")
+        platform.gpu_init(worker_num)
+        # TODO configure OMP_NUM_THREADS, blas, mkl, etc -- not clear how tho
 
     for serdeTypeEnc, (serdeSer, serdeDes) in runnerContext.job.serdes.items():
         serde.SerdeRegistry.register(type_dec(serdeTypeEnc), serdeSer, serdeDes)
cascade/gateway/__main__.py
CHANGED
@@ -14,13 +14,15 @@ from cascade.executor.config import logging_config, logging_config_filehandler
 from cascade.gateway.server import serve
 
 
-def main(
+def main(
+    url: str, log_base: str | None = None, troika_config: str | None = None
+) -> None:
     if log_base:
         log_path = f"{log_base}/gateway.txt"
         logging.config.dictConfig(logging_config_filehandler(log_path))
     else:
         logging.config.dictConfig(logging_config)
-    serve(url, log_base)
+    serve(url, log_base, troika_config)
 
 
 if __name__ == "__main__":
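Note: the troika config is threaded from the CLI down to the router; a minimal sketch of starting the gateway with it enabled, where the URL and config path are placeholders:

    from cascade.gateway.server import serve

    # equivalent to `python -m cascade.gateway tcp://0.0.0.0:8067 --troika_config /path/to/config.yml`
    serve("tcp://0.0.0.0:8067", log_base=None, troika_config="/path/to/config.yml")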
cascade/gateway/api.py
CHANGED
@@ -19,6 +19,18 @@ from cascade.low.core import DatasetId, JobInstance
 CascadeGatewayAPI = BaseModel
 
 
+@dataclass
+class TroikaSpec:
+    """Requires the gateway to have been started with --troika_config pointing
+    to some config.yml troika file. The connection must work (passwordlessly),
+    and must allow for script being copied. The remote host must have a venv
+    already in place, and must be able to resolve gateway's fqdn
+    """
+
+    venv: str  # remote host path to venv -- *do* include the bin/activate
+    conn: str  # which connection from config.yml to pick
+
+
 @dataclass
 class JobSpec:
     # job benchmark + envvars -- set to None/{} if using custom jobs instead
@@ -33,6 +45,7 @@ class JobSpec:
     workers_per_host: int
     hosts: int
     use_slurm: bool
+    troika: TroikaSpec | None = None
 
 
 class SubmitJobRequest(CascadeGatewayAPI):
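Note: a hedged sketch of how a troika-backed submission might be described; the field names follow this diff, but `JobSpec` has further fields not shown here, so treat the constructor call as illustrative only:

    from cascade.gateway.api import JobSpec, TroikaSpec

    troika = TroikaSpec(
        venv="/remote/path/venv/bin/activate",  # remote venv, *including* bin/activate
        conn="my-hpc",  # connection name from troika's config.yml
    )
    spec = JobSpec(
        benchmark_name="j1.prob",
        job_instance=None,
        envvars={},
        workers_per_host=4,
        hosts=1,
        use_slurm=False,  # single-host troika path; the slurm variant is NotImplemented
        troika=troika,
    )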
cascade/gateway/router.py
CHANGED
@@ -8,9 +8,11 @@
 
 """Represents information about submitted jobs. The main business logic of `cascade.gateway`"""
 
+import base64
 import itertools
 import logging
 import os
+import stat
 import subprocess
 import uuid
 from dataclasses import dataclass
@@ -22,7 +24,7 @@ import zmq
 import cascade.executor.platform as platform
 from cascade.controller.report import JobId, JobProgress, JobProgressStarted
 from cascade.executor.comms import get_context
-from cascade.gateway.api import JobSpec
+from cascade.gateway.api import JobSpec, TroikaSpec
 from cascade.low.core import DatasetId
 from cascade.low.func import next_uuid
 
@@ -43,6 +45,65 @@ class Job:
 local_job_port = 12345
 
 
+def _spawn_troika_singlehost(
+    job_spec: JobSpec, addr: str, job_id: JobId, troika: TroikaSpec, troika_config: str
+) -> subprocess.Popen:
+    script = "#!/bin/bash\n"
+    script += f"source {troika.venv}\n"
+    for k, v in job_spec.envvars.items():
+        script += f"export {k}={v}\n"
+    if job_spec.benchmark_name is not None:
+        if job_spec.job_instance is not None:
+            raise TypeError("specified both benchmark name and job instance")
+        script += "python -m cascade.benchmarks local"
+        script += f" --job {job_spec.benchmark_name}"
+    else:
+        if job_spec.job_instance is None:
+            raise TypeError("specified neither benchmark name nor job instance")
+        job_desc_raw = orjson.dumps(job_spec.job_instance.dict())
+        job_desc_enc = base64.b64encode(job_desc_raw).decode("ascii")
+        script += f'JOB_ENC="{job_desc_enc}"'
+        job_json_path = f"/tmp/cascJob.{job_id}.json"
+        script += f'echo "$JOB_ENC" | base64 --decode > {job_json_path}'
+        script += "python -m cascade.benchmarks local"
+        script += f" --instance {job_json_path}"
+
+    script += (
+        f" --workers_per_host {job_spec.workers_per_host} --hosts {job_spec.hosts}"
+    )
+    script += f" --report_address {addr},{job_id}"
+    # NOTE technically not needed to be globally unique, but we cant rely on troika environment isolation...
+    global local_job_port
+    script += f" --port_base {local_job_port}"
+    local_job_port += 1 + job_spec.hosts * job_spec.workers_per_host * 10
+    script += "\n"
+    script_path = f"/tmp/troikascade.{job_id}.sh"
+    with open(script_path, "w") as f:
+        f.write(script)
+    os.chmod(
+        script_path,
+        stat.S_IRUSR
+        | stat.S_IRGRP
+        | stat.S_IROTH
+        | stat.S_IWUSR
+        | stat.S_IXUSR
+        | stat.S_IXGRP
+        | stat.S_IXOTH,
+    )
+    return subprocess.Popen(
+        [
+            "troika",
+            "-c",
+            troika_config,
+            "submit",
+            "-o",
+            f"/tmp/output.{job_id}.txt",
+            troika.conn,
+            script_path,
+        ]
+    )
+
+
 def _spawn_local(
     job_spec: JobSpec, addr: str, job_id: JobId, log_base: str | None
 ) -> subprocess.Popen:
@@ -112,9 +173,26 @@ def _spawn_slurm(job_spec: JobSpec, addr: str, job_id: JobId) -> subprocess.Pope
 
 
 def _spawn_subprocess(
-    job_spec: JobSpec,
+    job_spec: JobSpec,
+    addr: str,
+    job_id: JobId,
+    log_base: str | None,
+    troika_config: str | None,
 ) -> subprocess.Popen:
-    if job_spec.
+    if job_spec.troika is not None:
+        if log_base is not None:
+            raise ValueError(f"unexpected {log_base=}")
+        if troika_config is None:
+            raise ValueError("cant spawn troika job without troika config")
+        if not job_spec.use_slurm:
+            return _spawn_troika_singlehost(
+                job_spec, addr, job_id, job_spec.troika, troika_config
+            )
+        else:
+            # TODO create a slurm script like in spawn_slurm, but dont refer to any other file
+            raise NotImplementedError
+
+    elif job_spec.use_slurm:
         if log_base is not None:
             raise ValueError(f"unexpected {log_base=}")
         return _spawn_slurm(job_spec, addr, job_id)
@@ -123,11 +201,14 @@ def _spawn_subprocess(
 
 
 class JobRouter:
-    def __init__(
+    def __init__(
+        self, poller: zmq.Poller, log_base: str | None, troika_config: str | None
+    ):
         self.poller = poller
         self.jobs: dict[str, Job] = {}
         self.procs: dict[str, subprocess.Popen] = {}
         self.log_base = log_base
+        self.troika_config = troika_config
 
     def spawn_job(self, job_spec: JobSpec) -> JobId:
         job_id = next_uuid(self.jobs.keys(), lambda: str(uuid.uuid4()))
@@ -139,7 +220,7 @@ class JobRouter:
         self.poller.register(socket, flags=zmq.POLLIN)
         self.jobs[job_id] = Job(socket, JobProgressStarted, -1, {})
         self.procs[job_id] = _spawn_subprocess(
-            job_spec, full_addr, job_id, self.log_base
+            job_spec, full_addr, job_id, self.log_base, self.troika_config
         )
         return job_id
 
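Note: the `local_job_port` bookkeeping reserves a contiguous block of ports per submitted job, mirroring the `portBase + 1 + i * 10` layout in `run_locally`; a quick check of the arithmetic, with illustrative values:

    local_job_port = 12345
    hosts, workers_per_host = 1, 4
    port_base_job = local_job_port
    local_job_port += 1 + hosts * workers_per_host * 10
    print(port_base_job, local_job_port)  # 12345 12386 -- the next job starts 41 ports higher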
cascade/gateway/server.py
CHANGED
@@ -79,14 +79,16 @@ def handle_controller(socket: zmq.Socket, jobs: JobRouter) -> None:
         jobs.put_result(report.job_id, dataset_id, result)
 
 
-def serve(
+def serve(
+    url: str, log_base: str | None = None, troika_config: str | None = None
+) -> None:
     ctx = get_context()
     poller = zmq.Poller()
 
     fe = ctx.socket(zmq.REP)
     fe.bind(url)
     poller.register(fe, flags=zmq.POLLIN)
-    jobs = JobRouter(poller, log_base)
+    jobs = JobRouter(poller, log_base, troika_config)
 
     logger.debug("entering recv loop")
     is_break = False
cascade/low/dask.py
ADDED
@@ -0,0 +1,99 @@
+# (C) Copyright 2025- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Experimental module to convert dask graphs into cascade jobs. May not preserve
+all semantics.
+
+We don't explicitly support legacy dask graph -- you need to invoke
+`dask._task_spec.convert_legacy_graph` yourself.
+For higher level dask objects, extract the graph via `__dask_graph__()` (defined
+eg on dd.DataFrame or dask.delayed objects).
+"""
+
+import logging
+from typing import Any
+
+from dask._task_spec import Alias, DataNode, Task, TaskRef
+
+from cascade.low.builders import TaskBuilder
+from cascade.low.core import DatasetId, JobInstance, Task2TaskEdge, TaskInstance
+from earthkit.workflows.graph import Node
+
+logger = logging.getLogger(__name__)
+
+
+def daskKeyRepr(key: str | int | float | tuple) -> str:
+    return repr(key)
+
+
+def task2task(key: str, task: Task) -> tuple[TaskInstance, list[Task2TaskEdge]]:
+    instance = TaskBuilder.from_callable(task.func)
+    edges: list[Task2TaskEdge] = []
+
+    for i, v in enumerate(task.args):
+        if isinstance(v, Alias | TaskRef):
+            edge = Task2TaskEdge(
+                source=DatasetId(task=daskKeyRepr(v.key), output=Node.DEFAULT_OUTPUT),
+                sink_task=key,
+                sink_input_ps=str(i),
+                sink_input_kw=None,
+            )
+            edges.append(edge)
+        elif isinstance(v, Task):
+            # TODO
+            raise NotImplementedError
+        else:
+            instance.static_input_ps[f"{i}"] = v
+    for k, v in task.kwargs.items():
+        if isinstance(v, Alias | TaskRef):
+            edge = Task2TaskEdge(
+                source=DatasetId(task=daskKeyRepr(v.key), output=Node.DEFAULT_OUTPUT),
+                sink_task=key,
+                sink_input_kw=k,
+                sink_input_ps=None,
+            )
+            edges.append(edge)
+        elif isinstance(v, Task):
+            # TODO
+            raise NotImplementedError
+        else:
+            instance.static_input_kw[k] = v
+
+    return instance, edges
+
+
+def graph2job(dask: dict) -> JobInstance:
+    task_nodes = {}
+    edges = []
+
+    for node, value in dask.items():
+        key = daskKeyRepr(node)
+        if isinstance(value, DataNode):
+
+            def provider() -> Any:
+                return value.value
+
+            task_nodes[key] = TaskBuilder.from_callable(provider)
+        elif isinstance(value, Task):
+            node, _edges = task2task(key, value)
+            task_nodes[key] = node
+            edges.extend(_edges)
+        elif isinstance(value, list | tuple | set):
+            # TODO implement, consult further:
+            # https://docs.dask.org/en/stable/spec.html
+            # https://docs.dask.org/en/stable/custom-graphs.html
+            # https://github.com/dask/dask/blob/main/dask/_task_spec.py#L829
+            logger.warning("encountered nested container => confused ostrich")
+            continue
+        else:
+            raise NotImplementedError
+
+    return JobInstance(
+        tasks=task_nodes,
+        edges=edges,
+    )
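Note: usage follows the module docstring (legacy dict graphs must be converted first) and mirrors the "add" benchmark above:

    from dask._task_spec import convert_legacy_graph

    from cascade.low.dask import graph2job


    def add(x, y):
        return x + y


    legacy = {"a": 1, "b": 2, "c": (add, "a", "b")}  # classic dask dict graph
    job = graph2job(convert_legacy_graph(legacy))  # cascade JobInstance with three tasks
    print(sorted(job.tasks))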
cascade/low/into.py
CHANGED
@@ -6,7 +6,7 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-"""Lowering of the
+"""Lowering of the earthkit.workflows.graph structures into cascade.low representation"""
 
 import logging
 from typing import Any, Callable, cast
earthkit/workflows/_version.py
CHANGED
@@ -1,2 +1,2 @@
 # Do not change! Do not track in version control!
-__version__ = "0.4.3"
+__version__ = "0.4.4"
{earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: earthkit-workflows
-Version: 0.4.3
+Version: 0.4.4
 Summary: Earthkit Workflows is a Python library for declaring earthkit task DAGs, as well as scheduling and executing them on heterogeneous computing systems.
 Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
 License-Expression: Apache-2.0
{earthkit_workflows-0.4.3.dist-info → earthkit_workflows-0.4.4.dist-info}/RECORD
CHANGED
@@ -1,14 +1,16 @@
 cascade/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cascade/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cascade/benchmarks/__init__.py,sha256=Gu8kEApmJ2zsIhT2zpm1-6n84-OwWnz-0vO8UHYtBzo,528
-cascade/benchmarks/__main__.py,sha256=
+cascade/benchmarks/__main__.py,sha256=z3Ib0NlIgMrn2zjrJhqqnJkjCIb4xKDSpO5vF9j-Onc,966
 cascade/benchmarks/anemoi.py,sha256=qtAI03HdtAmcksCgjIEZyNyUNzMp370KF4lAh5g4cOk,1077
+cascade/benchmarks/dask.py,sha256=U0B0jpLIeIs4Zl0SX_opMypXQXOIS5ER6mGdtPCgqkQ,953
 cascade/benchmarks/dist.py,sha256=ngXJJzegnMUVwDFPvGMG6997lamB-aSEHi74oBbayrE,4116
 cascade/benchmarks/generators.py,sha256=NK4fFisWsZdMkA2Auzrn-P7G5D9AKpo2JVnqXE44YT8,2169
 cascade/benchmarks/job1.py,sha256=MOcZZYgf36MzHCjtby0lQyenM1ODUlagG8wtt2CbpnI,4640
 cascade/benchmarks/matmul.py,sha256=5STuvPY6Q37E2pKRCde9dQjL5M6tx7tkES9cBLZ6eK4,1972
 cascade/benchmarks/plotting.py,sha256=vSz9HHbqZwMXHpBUS-In6xsXGgK7QIoQTTiYfSwYwZs,4428
 cascade/benchmarks/reporting.py,sha256=MejaM-eekbMYLAnuBxGv_t4dR1ODJs4Rpc0fiZSGjyw,5410
+cascade/benchmarks/util.py,sha256=obgRxtRcz023lvrtnI8vzDtgVqlqlRrWCDy_lwuML30,9835
 cascade/controller/__init__.py,sha256=p4C2p3S_0nUGamP9Mi6cSa5bvpiWbI6sVWtGhFnNqjw,1278
 cascade/controller/act.py,sha256=WHIsk4H-Bbyl_DABX2VWhyKy_cNnp12x1nilatPCL8I,2981
 cascade/controller/core.py,sha256=NqvZ5g5GNphwOpzdXbCI0_fxIzzmO97_n2xZKswK72Q,3589
@@ -19,27 +21,28 @@ cascade/executor/bridge.py,sha256=WDE-GM2Bv7nUk1-nV-otMGuaRYw1-Vmd7PWploXBp6Y,82
 cascade/executor/comms.py,sha256=-9qrKwva6WXkHRQtzSnLFy5gB3bOWuxYJP5fL6Uavw8,8736
 cascade/executor/config.py,sha256=8azy_sXdvDGO0zTNqA0pdtkXsyihM4FQ4U1W_3Dhua0,1571
 cascade/executor/data_server.py,sha256=xLIbLkWn8PnJl4lMP8ADHa2S0EgPwr0-bH7_Sib_Y70,13701
-cascade/executor/executor.py,sha256=
+cascade/executor/executor.py,sha256=3I9QnyX-YvJvGnMSM4kWfBJDgi_uUCv0M4ncXf4z85o,13659
 cascade/executor/msg.py,sha256=7HI0rKeCRaV1ONR4HWEa64nHbu-p6-QdBwJNitmst48,4340
-cascade/executor/platform.py,sha256=
+cascade/executor/platform.py,sha256=6uLdcH8mibvIQfF1nTSvtfyym4r6dLThNSF1JZ-6mLM,1393
 cascade/executor/serde.py,sha256=z6klTOZqW_BVGrbIRNz4FN0_XTfRiKBRQuvgsQIuyAo,2827
 cascade/executor/runner/__init__.py,sha256=30BM80ZyA7w3IrGiKKLSFuhRehbR2Mm99OJ8q5PJ63c,1547
-cascade/executor/runner/entrypoint.py,sha256=
+cascade/executor/runner/entrypoint.py,sha256=WyxOFGAYDQD_fXsM4H9_6xBrnAmQrCTUnljfcW6-BoM,7918
 cascade/executor/runner/memory.py,sha256=jkAV9T7-imciVcGvkV7OhRfosEpOQJU1OME7z-4ztAs,6371
 cascade/executor/runner/packages.py,sha256=OZjEOvKy8LQ2uguGZU1L7TVYz1415JOUGySRfU_D_sc,2513
 cascade/executor/runner/runner.py,sha256=zqpkvxdWLbwyUFaUbZmSj0KQEBNRpmF8gwVotiaamhc,4870
 cascade/gateway/__init__.py,sha256=1EzMKdLFXEucj0YWOlyVqLx4suOntitwM03T_rRubIk,829
-cascade/gateway/__main__.py,sha256=
-cascade/gateway/api.py,sha256=-
+cascade/gateway/__main__.py,sha256=F_wft7ja5ckM0SqeXsy_u2j-Ch6OTlpbTTlYtDkvGMI,917
+cascade/gateway/api.py,sha256=-Vuo9fDqFNFIofcHZ79UB1rTWnQR3D9Pna2CjqdyHaE,3021
 cascade/gateway/client.py,sha256=1p4Tvrf-BH0LQHOES5rY1z3JNIfmXcqWG2kYl4rpcE0,4061
-cascade/gateway/router.py,sha256=
-cascade/gateway/server.py,sha256=
+cascade/gateway/router.py,sha256=RcDniyPOZnu6_HuMUrQjZ4P-PoUbUVezvYXG_ryBLUg,10399
+cascade/gateway/server.py,sha256=vb3z0TfoMvSHqczhmYgzeXGVcw2M9yGpyW0t6d57Oag,3827
 cascade/low/__init__.py,sha256=5cw2taOGITK_gFbICftzK2YLdEAnLUY5OzblFzdHss4,769
 cascade/low/builders.py,sha256=_u5X8G_EF00hFt8Anv9AXo6yPf1O8MHDmqs2kKmREl0,7073
 cascade/low/core.py,sha256=_3x4ka_pmCgZbfwFeyhq8S4M6wmh0s24VRCLhk5yQFM,6444
+cascade/low/dask.py,sha256=xToT_vyfkgUUxSFN7dS7qLttxzuBbBZfDylPzGg7sPg,3319
 cascade/low/execution_context.py,sha256=cdDJLYhreo4T7t4qXgFBosncubZpTrm0hELo7q4miqo,6640
 cascade/low/func.py,sha256=ihL5n3cK-IJnATgP4Dub2m-Mp_jHMxJzCA1v4uMEsi8,5211
-cascade/low/into.py,sha256=
+cascade/low/into.py,sha256=lDOpO4gX-154BgLJWonVQZiGRbUqv-GhYy8qWBqJ1QQ,3402
 cascade/low/tracing.py,sha256=qvGVKB1huwcYoyvMYN-2wQ92pLQTErocTjpIjWv9glA,4511
 cascade/low/views.py,sha256=UwafO2EQHre17GjG8hdzO8b6qBRtTRtDlhOc1pTf8Io,1822
 cascade/scheduler/__init__.py,sha256=VT2qQ0gOQWHC4-T0FcCs59w8WZ94j2nUn7tiGm5XepA,1148
@@ -56,7 +59,7 @@ cascade/shm/disk.py,sha256=Fdl_pKOseaXroRp01OwqWVsdI-sSmiFizIFCdxBuMWM,2653
 cascade/shm/func.py,sha256=ZWikgnSLCmbSoW2LDRJwtjxdwTxkR00OUHAsIRQ-ChE,638
 cascade/shm/server.py,sha256=LnnNX0F6QJt5V_JLfmC3ZMHGNL5WpLY44wpB_pYDr7Y,5042
 earthkit/workflows/__init__.py,sha256=-p4anEn0YQbYWM2tbXb0Vc3wq4-m6kFhcNEgAVu5Jis,1948
-earthkit/workflows/_version.py,sha256=
+earthkit/workflows/_version.py,sha256=TdWpE_3Kcp1U42ekVTDBVAxSA3TKm_FWE5S-AZDssXw,72
 earthkit/workflows/decorators.py,sha256=DM4QAtQ2glUUcDecwPkXcdlu4dio7MvgpcdmU5LYvD8,937
 earthkit/workflows/fluent.py,sha256=IN_sqwr7W8wbwP7wTOklgnjVe34IUCmv1ku-DWVTCJc,30179
 earthkit/workflows/mark.py,sha256=PdsXmRfhw1SyyJ74mzFPsLRqMCdlYv556fFX4bqlh9Y,1319
@@ -86,8 +89,8 @@ earthkit/workflows/graph/split.py,sha256=t-Sji5eZb01QO1szqmDNTodDDALqdo-0R0x1ESs
 earthkit/workflows/graph/transform.py,sha256=BZ8n7ePUnuGgoHkMqZC3SLzifu4oq6q6t6vka0khFtg,3842
 earthkit/workflows/graph/visit.py,sha256=MP-aFSqOl7aqJY2i7QTgY4epqb6yM7_lK3ofvOqfahw,1755
 earthkit/workflows/plugins/__init__.py,sha256=nhMAC0eMLxoJamjqB5Ns0OWy0OuxEJ_YvaDFGEQITls,129
-earthkit_workflows-0.4.
-earthkit_workflows-0.4.
-earthkit_workflows-0.4.
-earthkit_workflows-0.4.
-earthkit_workflows-0.4.
+earthkit_workflows-0.4.4.dist-info/licenses/LICENSE,sha256=73MJ7twXMKnWwmzmrMiFwUeY7c6JTvxphVggeUq9Sq4,11381
+earthkit_workflows-0.4.4.dist-info/METADATA,sha256=DUjbGRAmD_GJrn0XwYMLFUMpCg8cLlNII4lz7WLWbms,1571
+earthkit_workflows-0.4.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+earthkit_workflows-0.4.4.dist-info/top_level.txt,sha256=oNrH3Km3hK5kDkTOiM-8G8OQglvZcy-gUKy7rlooWXs,17
+earthkit_workflows-0.4.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|