earthkit-workflows 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cascade/benchmarks/anemoi.py +1 -1
- cascade/benchmarks/dask.py +4 -4
- cascade/benchmarks/dist.py +3 -3
- cascade/benchmarks/job1.py +4 -5
- cascade/benchmarks/matmul.py +4 -4
- cascade/benchmarks/tests.py +3 -3
- cascade/benchmarks/util.py +22 -19
- cascade/controller/act.py +7 -0
- cascade/controller/core.py +31 -4
- cascade/controller/impl.py +5 -4
- cascade/controller/notify.py +4 -1
- cascade/executor/bridge.py +17 -4
- cascade/executor/checkpoints.py +42 -0
- cascade/executor/data_server.py +38 -5
- cascade/executor/executor.py +3 -1
- cascade/executor/msg.py +21 -2
- cascade/executor/platform.py +1 -1
- cascade/executor/runner/entrypoint.py +2 -2
- cascade/executor/runner/memory.py +1 -1
- cascade/gateway/api.py +2 -7
- cascade/gateway/client.py +1 -1
- cascade/gateway/router.py +9 -170
- cascade/gateway/server.py +5 -4
- cascade/gateway/spawning.py +163 -0
- cascade/low/builders.py +2 -2
- cascade/low/core.py +30 -1
- cascade/low/dask.py +1 -1
- cascade/low/execution_context.py +15 -5
- cascade/low/func.py +1 -1
- cascade/low/into.py +9 -3
- cascade/scheduler/assign.py +11 -11
- cascade/shm/api.py +4 -4
- cascade/shm/client.py +1 -0
- cascade/shm/disk.py +2 -2
- earthkit/workflows/_version.py +1 -1
- earthkit/workflows/backends/__init__.py +0 -1
- earthkit/workflows/backends/earthkit.py +1 -1
- earthkit/workflows/fluent.py +14 -11
- earthkit_workflows-0.6.0.dist-info/METADATA +132 -0
- {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/RECORD +43 -41
- {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/WHEEL +1 -1
- earthkit_workflows-0.5.0.dist-info/METADATA +0 -44
- {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/top_level.txt +0 -0
cascade/benchmarks/anemoi.py
CHANGED
cascade/benchmarks/dask.py
CHANGED
@@ -5,9 +5,9 @@ from cascade.low.core import JobInstance
 from cascade.low.dask import graph2job


-def get_job(job: str) -> JobInstance:
+def get_job(job_name: str) -> JobInstance:

-    if job == "add":
+    if job_name == "add":

        def add(x, y):
            result = x + y
@@ -21,7 +21,7 @@ def get_job(job: str) -> JobInstance:
            dataset for task in job.tasks for dataset in job.outputs_of(task)
        ]
        return job
-    elif job == "groupby":
+    elif job_name == "groupby":
        df = dd.DataFrame.from_dict({"x": [0, 0, 1, 1], "y": [1, 2, 3, 4]})
        df = df.groupby("x").sum()
        job = graph2job(df.__dask_graph__())
@@ -30,4 +30,4 @@ def get_job(job: str) -> JobInstance:
        ]
        return job
    else:
-        raise NotImplementedError(job)
+        raise NotImplementedError(job_name)
cascade/benchmarks/dist.py
CHANGED
@@ -26,7 +26,7 @@ def dist_func_torch(a: int) -> int:
    import datetime as dt

    import numpy as np
-    import torch.distributed as dist
+    import torch.distributed as dist  # ty: ignore[unresolved-import]

    world_size = int(os.environ["CASCADE_GANG_WORLD_SIZE"])
    rank = int(os.environ["CASCADE_GANG_RANK"])
@@ -61,8 +61,8 @@ def dist_func_jax(a: int) -> int:
    os.environ["JAX_NUM_CPU_DEVICES"] = "1"
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
    os.environ["JAX_PLATFORMS"] = "cpu"
-    import jax
-    import jax.numpy as jp
+    import jax  # ty: ignore[unresolved-import]
+    import jax.numpy as jp  # ty: ignore[unresolved-import]

    jax.config.update("jax_platforms", "cpu")
    jax.config.update("jax_platform_name", "cpu")
cascade/benchmarks/job1.py
CHANGED
@@ -16,10 +16,9 @@ Controlled by env var params: JOB1_{DATA_ROOT, GRID, ...}, see below
 import os

 import earthkit.data
-
 from earthkit.workflows.fluent import Payload
-from earthkit.workflows.plugins.pproc.fluent import from_source
-from earthkit.workflows.plugins.pproc.utils.window import Range
+from earthkit.workflows.plugins.pproc.fluent import from_source  # ty: ignore
+from earthkit.workflows.plugins.pproc.utils.window import Range  # ty: ignore

 # *** PARAMS ***

@@ -137,7 +136,7 @@ def download_inputs():
    }
    data = earthkit.data.from_source("mars", **ekp)
    with open(f"{data_root}/data_{number}_{step}.grib", "wb") as f:
-        data.write(f)
+        data.write(f)  # ty: ignore


 def download_climatology():
@@ -157,7 +156,7 @@ def download_climatology():
    }
    data = earthkit.data.from_source("mars", **ekp)
    with open(f"{data_root}/data_clim_{step}.grib", "wb") as f:
-        data.write(f)
+        data.write(f)  # ty: ignore


 if __name__ == "__main__":
cascade/benchmarks/matmul.py
CHANGED
@@ -1,9 +1,9 @@
 import os
 from typing import Any

-import jax
-import jax.numpy as jp
-import jax.random as jr
+import jax  # ty: ignore[unresolved-import]
+import jax.numpy as jp  # ty: ignore[unresolved-import]
+import jax.random as jr  # ty: ignore[unresolved-import]

 from cascade.low.builders import JobBuilder, TaskBuilder
 from cascade.low.core import JobInstance
@@ -65,7 +65,7 @@ def execute_locally():

    from multiprocessing.shared_memory import SharedMemory

-    mem = SharedMemory("benchmark_tmp", create=True, size=m0.nbytes)
+    mem = SharedMemory("benchmark_tmp", create=True, size=m0.nbytes); assert mem.buf is not None
    mem.buf[:] = m0.tobytes()


cascade/benchmarks/tests.py
CHANGED
@@ -32,7 +32,7 @@ from cascade.executor.runner.memory import Memory, ds2shmid
 from cascade.executor.runner.packages import PackagesEnv
 from cascade.executor.runner.runner import ExecutionContext, run
 from cascade.low.builders import TaskBuilder
-from cascade.low.core import DatasetId
+from cascade.low.core import DatasetId, WorkerId
 from cascade.shm.server import entrypoint as shm_server

 logger = logging.getLogger(__name__)
@@ -75,7 +75,7 @@ def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext)
        raise ValueError(f"expected 1 task, gotten {len(tasks)}")
    taskId = tasks[0]
    taskInstance = executionContext.tasks[taskId]
-    with Memory(callback, "testWorker") as memory, PackagesEnv() as pckg:
+    with Memory(callback, WorkerId(host="testHost", worker="testWorker")) as memory, PackagesEnv() as pckg:
        # for key, value in taskSequence.extra_env.items():
        #     os.environ[key] = value

@@ -142,7 +142,7 @@ def run_test(
    while perf_counter_ns() < end:
        mess = listener.recv_messages()
        if mess == [
-            DatasetPublished(origin="testWorker", ds=output, transmit_idx=None)
+            DatasetPublished(origin=WorkerId(host="testHost", worker="testWorker"), ds=output, transmit_idx=None)
        ]:
            break
        elif not mess:
cascade/benchmarks/util.py
CHANGED
@@ -29,7 +29,7 @@ from cascade.executor.comms import callback
 from cascade.executor.config import logging_config, logging_config_filehandler
 from cascade.executor.executor import Executor
 from cascade.executor.msg import BackboneAddress, ExecutorShutdown
-from cascade.low.core import DatasetId, JobInstance
+from cascade.low.core import DatasetId, JobInstance, JobInstanceRich
 from cascade.low.func import msum
 from cascade.scheduler.precompute import precompute
 from earthkit.workflows.graph import Graph, deduplicate_nodes
@@ -37,15 +37,16 @@ from earthkit.workflows.graph import Graph, deduplicate_nodes
 logger = logging.getLogger("cascade.benchmarks")


-def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
+def get_job(benchmark: str | None, instance_path: str | None) -> JobInstanceRich:
    # NOTE because of os.environ, we don't import all... ideally we'd have some file-based init/config mech instead
    if benchmark is not None and instance_path is not None:
        raise TypeError("specified both benchmark name and job instance")
    elif instance_path is not None:
        with open(instance_path, "rb") as f:
            d = orjson.loads(f.read())
-        return JobInstance(**d)
+        return JobInstanceRich(**d)
    elif benchmark is not None:
+        instance: JobInstance
        if benchmark.startswith("j1"):
            import cascade.benchmarks.job1 as job1

@@ -58,25 +59,26 @@ def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
                msum((v for k, v in graphs.items() if k.startswith(prefix)), Graph)
            )
            graphs["j1.all"] = union("j1.")
-            return cascade.low.into.graph2job(graphs[benchmark])
+            instance = cascade.low.into.graph2job(graphs[benchmark])
        elif benchmark.startswith("generators"):
            import cascade.benchmarks.generators as generators

-            return generators.get_job()
+            instance = generators.get_job()
        elif benchmark.startswith("matmul"):
            import cascade.benchmarks.matmul as matmul

-            return matmul.get_job()
+            instance = matmul.get_job()
        elif benchmark.startswith("dist"):
            import cascade.benchmarks.dist as dist

-            return dist.get_job()
+            instance = dist.get_job()
        elif benchmark.startswith("dask"):
            import cascade.benchmarks.dask as dask

-            return dask.get_job(benchmark[len("dask.") :])
+            instance = dask.get_job(benchmark[len("dask.") :])
        else:
            raise NotImplementedError(benchmark)
+        return JobInstanceRich(jobInstance=instance, checkpointSpec=None)
    else:
        raise TypeError("specified neither benchmark name nor job instance")

@@ -116,7 +118,7 @@ def get_gpu_count(host_idx: int, worker_count: int) -> int:


 def launch_executor(
-    job: JobInstance,
+    job: JobInstanceRich,
    controller_address: BackboneAddress,
    workers_per_host: int,
    portBase: int,
@@ -136,7 +138,7 @@ def launch_executor(
        logger.info(f"will set {gpu_count} gpus on host {i}")
        os.environ["CASCADE_GPU_COUNT"] = str(gpu_count)
        executor = Executor(
-            job,
+            job.jobInstance,
            controller_address,
            workers_per_host,
            f"h{i}",
@@ -154,7 +156,7 @@ def launch_executor(


 def run_locally(
-    job: JobInstance,
+    job: JobInstanceRich,
    hosts: int,
    workers: int,
    portBase: int = 12345,
@@ -195,7 +197,7 @@ def run_locally(
        ps.append(p)

    # compute preschedule
-    preschedule = precompute(job)
+    preschedule = precompute(job.jobInstance)

    # check processes started healthy
    for i, p in enumerate(ps):
@@ -240,9 +242,9 @@ def main_local(
    port_base: int = 12345,
    log_base: str | None = None,
 ) -> None:
-    jobInstance = get_job(job, instance)
+    jobInstanceRich = get_job(job, instance)
    run_locally(
-        jobInstance,
+        jobInstanceRich,
        hosts,
        workers_per_host,
        report_address=report_address,
@@ -266,17 +268,17 @@ def main_dist(
    """
    launch = perf_counter_ns()

-    jobInstance = get_job(job, instance)
+    jobInstanceRich = get_job(job, instance)

    if idx == 0:
        logging.config.dictConfig(logging_config)
        tp = ThreadPoolExecutor(max_workers=1)
-        preschedule_fut = tp.submit(precompute, jobInstance)
+        preschedule_fut = tp.submit(precompute, jobInstanceRich.jobInstance)
        b = Bridge(controller_url, hosts)
        preschedule = preschedule_fut.result()
        tp.shutdown()
        start = perf_counter_ns()
-        run(jobInstance, b, preschedule, report_address=report_address)
+        run(jobInstanceRich, b, preschedule, report_address=report_address)
        end = perf_counter_ns()
        print(
            f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
@@ -284,12 +286,13 @@ def main_dist(
    else:
        gpu_count = get_gpu_count(0, workers_per_host)
        launch_executor(
-            jobInstance,
+            jobInstanceRich,
            controller_url,
            workers_per_host,
            12345,
            idx,
            shm_vol_gb,
            gpu_count,
-
+            log_base = None, # TODO handle log collection for dist scenario
+            url_base = f"tcp://{platform.get_bindabble_self()}",
        )
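The net effect of the util.py changes is that benchmark entrypoints now hand a JobInstanceRich wrapper to run/launch_executor instead of a bare JobInstance. A minimal sketch of the new construction, using only names that appear in this diff (matmul is just one example benchmark; checkpointSpec=None keeps checkpointing disabled):

from cascade.low.core import JobInstanceRich
import cascade.benchmarks.matmul as matmul

# wrap the plain JobInstance; with checkpointSpec=None nothing gets persisted
job = JobInstanceRich(jobInstance=matmul.get_job(), checkpointSpec=None)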
cascade/controller/act.py
CHANGED
@@ -10,6 +10,7 @@

 import logging

+import cascade.executor.checkpoints as checkpoints
 from cascade.controller.core import State
 from cascade.executor.bridge import Bridge
 from cascade.executor.msg import TaskSequence
@@ -76,6 +77,12 @@ def flush_queues(bridge: Bridge, state: State, context: JobExecutionContext):
    for dataset, host in state.drain_fetching_queue():
        bridge.fetch(dataset, host)

+    for dataset, host in state.drain_persist_queue():
+        if context.checkpoint_spec is None:
+            raise TypeError(f"unexpected persist need when checkpoint storage not configured")
+        persist_params = checkpoints.serialize_persist_params(context.checkpoint_spec)
+        bridge.persist(dataset, host, context.checkpoint_spec.storage_type, persist_params)
+
    for ds in state.drain_purging_queue():
        for host in context.purge_dataset(ds):
            logger.debug(f"issuing purge of {ds=} to {host=}")
cascade/controller/core.py
CHANGED
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Iterator

 import cascade.executor.serde as serde
-from cascade.executor.msg import DatasetTransmitPayload
+from cascade.executor.msg import DatasetPersistSuccess, DatasetTransmitPayload
 from cascade.low.core import DatasetId, HostId, TaskId

 logger = logging.getLogger(__name__)
@@ -16,10 +16,14 @@ logger = logging.getLogger(__name__)

 @dataclass
 class State:
-    # key add by core.
+    # key add by core.init_state, value add by notify.notify
    outputs: dict[DatasetId, Any]
+    # key add by core.init_state, value add by notify.notify
+    to_persist: set[DatasetId]
    # add by notify.notify, remove by act.flush_queues
    fetching_queue: dict[DatasetId, HostId]
+    # add by notify.notify, remove by act.flush_queues
+    persist_queue: dict[DatasetId, HostId]
    # add by notify.notify, removed by act.flush_queues
    purging_queue: list[DatasetId]
    # add by core.init_state, remove by notify.notify
@@ -31,13 +35,16 @@ class State:
        for e in self.outputs.values():
            if e is None:
                return True
+        if self.to_persist:
+            return True
        return False

    def _consider_purge(self, dataset: DatasetId) -> None:
        """If dataset not required anymore, add to purging_queue"""
        no_dependants = not self.purging_tracker.get(dataset, None)
        not_required_output = self.outputs.get(dataset, 1) is not None
-        if all((no_dependants, not_required_output)):
+        not_required_persist = not dataset in self.to_persist
+        if all((no_dependants, not_required_output, not_required_persist)):
            logger.debug(f"adding {dataset=} to purging queue")
            if dataset in self.purging_tracker:
                self.purging_tracker.pop(dataset)
@@ -52,6 +59,14 @@
        ):
            self.fetching_queue[dataset] = at

+    def consider_persist(self, dataset: DatasetId, at: HostId) -> None:
+        """If required as persist and not yet acknowledged, add to persist queue"""
+        if (
+            dataset in self.to_persist
+            and dataset not in self.persist_queue
+        ):
+            self.persist_queue[dataset] = at
+
    def receive_payload(self, payload: DatasetTransmitPayload) -> None:
        """Stores deserialized value into outputs, considers purge"""
        # NOTE ifneedbe get annotation from job.tasks[event.ds.task].definition.output_schema[event.ds.output]
@@ -60,6 +75,11 @@
        )
        self._consider_purge(payload.header.ds)

+    def acknowledge_persist(self, payload: DatasetPersistSuccess) -> None:
+        """Marks acknowledged, considers purge"""
+        self.to_persist.discard(payload.ds)
+        self._consider_purge(payload.ds)
+
    def task_done(self, task: TaskId, inputs: set[DatasetId]) -> None:
        """Marks that the inputs are not needed for this task anymore, considers purge of each"""
        for sourceDataset in inputs:
@@ -76,15 +96,22 @@
            yield dataset, host
        self.fetching_queue = {}

+    def drain_persist_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
+        for dataset, host in self.persist_queue.items():
+            yield dataset, host
+        self.persist_queue = {}
+

-def init_state(outputs: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
+def init_state(outputs: set[DatasetId], to_persist: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
    purging_tracker = {
        ds: {task for task in dependants} for ds, dependants in edge_o.items()
    }

    return State(
        outputs={e: None for e in outputs},
+        to_persist={e for e in to_persist},
        fetching_queue={},
        purging_queue=[],
        purging_tracker=purging_tracker,
+        persist_queue={},
    )
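These State additions give persisted datasets the same queue-and-drain lifecycle that fetching already uses. A rough illustrative sketch of the flow, assuming the DatasetId/HostId values and the edge_o map come from an existing job context (not a verbatim excerpt of the controller):

from cascade.controller.core import init_state

def persist_lifecycle(outputs, to_persist, edge_o, published):
    # published: (DatasetId, HostId) pairs reported via DatasetPublished events
    state = init_state(outputs, to_persist, edge_o)
    for ds, host in published:
        state.consider_persist(ds, host)  # queued only while ds is still in to_persist
    for ds, host in state.drain_persist_queue():
        pass  # act.flush_queues issues bridge.persist(ds, host, ...) here
    # when a DatasetPersistSuccess event arrives, notify.notify calls
    # state.acknowledge_persist(event), which drops ds from to_persist
    # and lets _consider_purge finally release the dataset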
cascade/controller/impl.py
CHANGED
@@ -14,7 +14,7 @@ from cascade.controller.core import State, init_state
 from cascade.controller.notify import notify
 from cascade.controller.report import Reporter
 from cascade.executor.bridge import Bridge, Event
-from cascade.low.core import JobInstance, type_dec
+from cascade.low.core import JobInstance, JobInstanceRich, type_dec
 from cascade.low.execution_context import init_context
 from cascade.low.tracing import ControllerPhases, Microtrace, label, mark, timer
 from cascade.scheduler.api import assign, init_schedule, plan
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)


 def run(
-    job: JobInstance,
+    job: JobInstanceRich,
    bridge: Bridge,
    preschedule: Preschedule,
    report_address: str | None = None,
@@ -34,7 +34,8 @@ def run(
    outputs = set(context.job_instance.ext_outputs)
    logger.debug(f"starting with {env=} and {report_address=}")
    schedule = timer(init_schedule, Microtrace.ctrl_init)(preschedule, context)
-    state = init_state(outputs, context.edge_o)
+    to_persist = set(job.checkpointSpec.to_persist) if job.checkpointSpec is not None else set()
+    state = init_state(outputs, to_persist, context.edge_o)

    label("host", "controller")
    events: list[Event] = []
@@ -44,7 +45,7 @@

    try:
        total_gpus = sum(worker.gpu for worker in env.workers.values())
-        needs_gpus = any(task.definition.needs_gpu for task in job.tasks.values())
+        needs_gpus = any(task.definition.needs_gpu for task in job.jobInstance.tasks.values())
        if needs_gpus and total_gpus == 0:
            raise ValueError("environment contains no gpu yet job demands one")

cascade/controller/notify.py
CHANGED
@@ -17,7 +17,7 @@ from typing import Iterable
 from cascade.controller.core import State
 from cascade.controller.report import Reporter
 from cascade.executor.bridge import Event
-from cascade.executor.msg import DatasetPublished, DatasetTransmitPayload
+from cascade.executor.msg import DatasetPersistSuccess, DatasetPublished, DatasetTransmitPayload
 from cascade.low.core import DatasetId, HostId, WorkerId
 from cascade.low.execution_context import DatasetStatus, JobExecutionContext
 from cascade.low.func import assert_never
@@ -89,6 +89,7 @@ def notify(
        context.host2ds[host][event.ds] = DatasetStatus.available
        context.ds2host[event.ds][host] = DatasetStatus.available
        state.consider_fetch(event.ds, host)
+        state.consider_persist(event.ds, host)
        consider_computable(schedule, state, context, event.ds, host)
        if event.transmit_idx is not None:
            mark(
@@ -121,5 +122,7 @@ def notify(
    elif isinstance(event, DatasetTransmitPayload):
        state.receive_payload(event)
        reporter.send_result(event.header.ds, event.value)
+    elif isinstance(event, DatasetPersistSuccess):
+        state.acknowledge_persist(event)
    else:
        assert_never(event)
cascade/executor/bridge.py
CHANGED
@@ -16,6 +16,9 @@ from cascade.executor.comms import default_message_resend_ms as resend_grace_ms
 from cascade.executor.executor import heartbeat_grace_ms as executor_heartbeat_grace_ms
 from cascade.executor.msg import (
    Ack,
+    DatasetPersistCommand,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
    DatasetPublished,
    DatasetPurge,
    DatasetTransmitCommand,
@@ -29,14 +32,15 @@ from cascade.executor.msg import (
    TaskFailure,
    TaskSequence,
 )
-from cascade.low.core import DatasetId, Environment, HostId, Worker, WorkerId
+from cascade.low.core import CheckpointStorageType, DatasetId, Environment, HostId, Worker, WorkerId
 from cascade.low.func import assert_never

 logger = logging.getLogger(__name__)

-Event = DatasetPublished | DatasetTransmitPayload
-
-
+Event = DatasetPublished | DatasetTransmitPayload | DatasetPersistSuccess
+# TODO consider retries here, esp on the PersistFailure
+ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | DatasetPersistFailure | ExecutorExit
+Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | DatasetPersistCommand | ExecutorShutdown


 class Bridge:
@@ -158,6 +162,15 @@
        self.transmit_idx_counter += 1
        self.sender.send("data." + source, m)

+    def persist(self, ds: DatasetId, source: HostId, storage_type: CheckpointStorageType, persist_params: str) -> None:
+        m = DatasetPersistCommand(
+            source=source,
+            ds=ds,
+            storage_type=storage_type,
+            persist_params=persist_params,
+        )
+        self.sender.send("data." + source, m)
+
    def fetch(self, ds: DatasetId, source: HostId) -> None:
        m = DatasetTransmitCommand(
            source=source,
cascade/executor/checkpoints.py
ADDED
@@ -0,0 +1,42 @@
+# (C) Copyright 2025- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Handles the checkpoint management: storage, retrieval"""
+
+import pathlib
+
+from cascade.executor.msg import DatasetPersistCommand
+from cascade.low.core import CheckpointSpec
+from cascade.low.func import assert_never
+from cascade.shm.client import AllocatedBuffer
+
+
+def persist_dataset(command: DatasetPersistCommand, buf: AllocatedBuffer) -> None:
+    match command.storage_type:
+        case "fs":
+            root = pathlib.Path(command.persist_params)
+            root.mkdir(parents=True, exist_ok=True)
+            file = root / repr(command.ds)
+            # TODO what about overwrites / concurrent writes? Append uuid?
+            file.write_bytes(buf.view())
+        case s:
+            assert_never(s)
+
+def serialize_persist_params(spec: CheckpointSpec) -> str:
+    # NOTE we call this every time we store, ideally call this once when building `low.execution_context`
+    match spec.storage_type:
+        case "fs":
+            if not isinstance(spec.storage_params, str):
+                raise TypeError(f"expected checkpoint storage params to be str, gotten {spec.storage_params.__class__}")
+            if spec.persist_id is None:
+                raise TypeError(f"serialize_persist_params called, but persist_id is None")
+            root = pathlib.Path(spec.storage_params)
+            return str(root / spec.persist_id)
+        case s:
+            assert_never(s)
+
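For the single storage type handled so far ("fs"), the two helpers above imply a flat directory layout: serialize_persist_params resolves the persist root once, and persist_dataset writes each dataset as one file named repr(ds) under that root. A stand-alone sketch of the path arithmetic (plain Python, not part of the package; the example values are invented and the exact repr format depends on DatasetId):

import pathlib

storage_params = "/scratch/checkpoints"   # CheckpointSpec.storage_params
persist_id = "run-0001"                   # CheckpointSpec.persist_id
ds_name = "dataset-repr-placeholder"      # stands in for repr(command.ds)

persist_params = str(pathlib.Path(storage_params) / persist_id)  # serialize_persist_params, "fs" branch
target = pathlib.Path(persist_params) / ds_name                  # persist_dataset writes buf.view() here
print(target)  # /scratch/checkpoints/run-0001/dataset-repr-placeholder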
cascade/executor/data_server.py
CHANGED
@@ -15,16 +15,20 @@ large data object.

 import logging
 import logging.config
-from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED
+from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
 from concurrent.futures import Executor as PythonExecutor
-from concurrent.futures import Future, ThreadPoolExecutor, wait
 from time import time_ns
+from typing import cast

 import cascade.shm.client as shm_client
+from cascade.executor.checkpoints import persist_dataset
 from cascade.executor.comms import Listener, callback, send_data
 from cascade.executor.msg import (
    Ack,
    BackboneAddress,
+    DatasetPersistCommand,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
    DatasetPublished,
    DatasetPurge,
    DatasetTransmitCommand,
@@ -59,7 +63,7 @@ class DataServer:
        self.cap = 2
        self.ds_proc_tp: PythonExecutor = ThreadPoolExecutor(max_workers=self.cap)
        self.futs_in_progress: dict[
-            DatasetTransmitCommand | DatasetTransmitPayload, Future
+            DatasetTransmitCommand | DatasetTransmitPayload | DatasetPersistCommand, Future
        ] = {}
        self.awaiting_confirmation: dict[int, tuple[DatasetTransmitCommand, int]] = {}
        self.invalid: set[DatasetId] = (
@@ -149,6 +153,26 @@ class DataServer:
            time_ns()
        )  # not actually consumed but uniform signature with send_payload simplifies typing

+    def persist_payload(self, command: DatasetPersistCommand) -> int:
+        buf: None | shm_client.AllocatedBuffer = None
+        try:
+            if command.source != self.host:
+                raise ValueError(f"invalid {command=}")
+            buf = shm_client.get(key=ds2shmid(command.ds))
+            persist_dataset(command, buf)
+            logger.debug(f"dataset for {command} persisted")
+            callback(self.maddress, DatasetPersistSuccess(host=self.host, ds=command.ds))
+        except Exception as e:
+            logger.exception(f"failed to persist dataset for {command}, reporting up")
+            callback(
+                self.maddress,
+                DatasetPersistFailure(host=self.host, detail=f"{repr(command)} -> {repr(e)}"),
+            )
+        finally:
+            if buf is not None:
+                buf.close()
+        return time_ns()
+
    def send_payload(self, command: DatasetTransmitCommand) -> int:
        buf: None | shm_client.AllocatedBuffer = None
        payload: None | DatasetTransmitPayload = None
@@ -171,7 +195,7 @@
            ds=command.ds,
            deser_fun=buf.deser_fun,
        )
-        payload = DatasetTransmitPayload(header, value=buf.view())
+        payload = DatasetTransmitPayload(header, value=cast(bytes, buf.view()))
        syn = Syn(command.idx, self.dlistener.address)
        send_data(command.daddress, payload, syn)
        logger.debug(f"payload for {command} sent")
@@ -218,6 +242,14 @@
            self.awaiting_confirmation[m.idx] = (m, -1)
            fut = self.ds_proc_tp.submit(self.send_payload, m)
            self.futs_in_progress[m] = fut
+        elif isinstance(m, DatasetPersistCommand):
+            if m.ds in self.invalid:
+                raise ValueError(
+                    f"unexpected persist command {m} as the dataset was already purged"
+                )
+            # TODO mark?
+            fut = self.ds_proc_tp.submit(self.persist_payload, m)
+            self.futs_in_progress[m] = fut
        elif isinstance(m, DatasetTransmitPayload):
            if m.header.ds in self.invalid:
                logger.warning(
@@ -238,9 +270,10 @@
            self.acks.add(m.idx)
        elif isinstance(m, DatasetPurge):
            # we need to handle potential commands transmitting this dataset, as otherwise they'd fail
+            # TODO submit this as a future? This actively blocks the whole server
            to_wait = []
            for commandProg, fut in self.futs_in_progress.items():
-                if isinstance(commandProg, DatasetTransmitCommand):
+                if isinstance(commandProg, DatasetTransmitCommand|DatasetPersistCommand):
                    val = commandProg.ds
                elif isinstance(commandProg, DatasetTransmitPayload):
                    val = commandProg.header.ds
cascade/executor/executor.py
CHANGED
@@ -35,6 +35,8 @@ from cascade.executor.data_server import start_data_server
 from cascade.executor.msg import (
    Ack,
    BackboneAddress,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
    DatasetPublished,
    DatasetPurge,
    DatasetTransmitFailure,
@@ -306,7 +308,7 @@ class Executor:
            callback(worker_address(worker), m)
            self.datasets.add(m.ds)
            self.to_controller(m)
-        elif isinstance(m, DatasetTransmitFailure):
+        elif isinstance(m, DatasetTransmitFailure|DatasetPersistSuccess|DatasetPersistFailure):
            self.to_controller(m)
        else:
            # NOTE transmit and store are handled in DataServer (which has its own socket)