earthkit-workflows 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cascade/benchmarks/anemoi.py +1 -1
  2. cascade/benchmarks/dask.py +4 -4
  3. cascade/benchmarks/dist.py +3 -3
  4. cascade/benchmarks/job1.py +4 -5
  5. cascade/benchmarks/matmul.py +4 -4
  6. cascade/benchmarks/tests.py +3 -3
  7. cascade/benchmarks/util.py +22 -19
  8. cascade/controller/act.py +7 -0
  9. cascade/controller/core.py +31 -4
  10. cascade/controller/impl.py +5 -4
  11. cascade/controller/notify.py +4 -1
  12. cascade/executor/bridge.py +17 -4
  13. cascade/executor/checkpoints.py +42 -0
  14. cascade/executor/data_server.py +38 -5
  15. cascade/executor/executor.py +3 -1
  16. cascade/executor/msg.py +21 -2
  17. cascade/executor/platform.py +1 -1
  18. cascade/executor/runner/entrypoint.py +2 -2
  19. cascade/executor/runner/memory.py +1 -1
  20. cascade/gateway/api.py +2 -7
  21. cascade/gateway/client.py +1 -1
  22. cascade/gateway/router.py +9 -170
  23. cascade/gateway/server.py +5 -4
  24. cascade/gateway/spawning.py +163 -0
  25. cascade/low/builders.py +2 -2
  26. cascade/low/core.py +30 -1
  27. cascade/low/dask.py +1 -1
  28. cascade/low/execution_context.py +15 -5
  29. cascade/low/func.py +1 -1
  30. cascade/low/into.py +9 -3
  31. cascade/scheduler/assign.py +11 -11
  32. cascade/shm/api.py +4 -4
  33. cascade/shm/client.py +1 -0
  34. cascade/shm/disk.py +2 -2
  35. earthkit/workflows/_version.py +1 -1
  36. earthkit/workflows/backends/__init__.py +0 -1
  37. earthkit/workflows/backends/earthkit.py +1 -1
  38. earthkit/workflows/fluent.py +14 -11
  39. earthkit_workflows-0.6.0.dist-info/METADATA +132 -0
  40. {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/RECORD +43 -41
  41. {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/WHEEL +1 -1
  42. earthkit_workflows-0.5.0.dist-info/METADATA +0 -44
  43. {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/licenses/LICENSE +0 -0
  44. {earthkit_workflows-0.5.0.dist-info → earthkit_workflows-0.6.0.dist-info}/top_level.txt +0 -0
cascade/benchmarks/anemoi.py CHANGED
@@ -10,7 +10,7 @@ from earthkit.workflows import Cascade
 
 
 def get_graph(lead_time, ensemble_members, CKPT=None, date="2024-12-02T00:00"):
-    import anemoicascade as ac
+    import anemoicascade as ac  # ty: ignore[unresolved-import]
 
     CKPT = (
         CKPT
cascade/benchmarks/dask.py CHANGED
@@ -5,9 +5,9 @@ from cascade.low.core import JobInstance
 from cascade.low.dask import graph2job
 
 
-def get_job(job: str) -> JobInstance:
+def get_job(job_name: str) -> JobInstance:
 
-    if job == "add":
+    if job_name == "add":
 
         def add(x, y):
             result = x + y
@@ -21,7 +21,7 @@ def get_job(job: str) -> JobInstance:
             dataset for task in job.tasks for dataset in job.outputs_of(task)
         ]
         return job
-    elif job == "groupby":
+    elif job_name == "groupby":
         df = dd.DataFrame.from_dict({"x": [0, 0, 1, 1], "y": [1, 2, 3, 4]})
         df = df.groupby("x").sum()
         job = graph2job(df.__dask_graph__())
@@ -30,4 +30,4 @@ def get_job(job: str) -> JobInstance:
         ]
         return job
     else:
-        raise NotImplementedError(job)
+        raise NotImplementedError(job_name)
cascade/benchmarks/dist.py CHANGED
@@ -26,7 +26,7 @@ def dist_func_torch(a: int) -> int:
     import datetime as dt
 
     import numpy as np
-    import torch.distributed as dist
+    import torch.distributed as dist  # ty: ignore[unresolved-import]
 
     world_size = int(os.environ["CASCADE_GANG_WORLD_SIZE"])
     rank = int(os.environ["CASCADE_GANG_RANK"])
@@ -61,8 +61,8 @@ def dist_func_jax(a: int) -> int:
     os.environ["JAX_NUM_CPU_DEVICES"] = "1"
     os.environ["JAX_PLATFORM_NAME"] = "cpu"
     os.environ["JAX_PLATFORMS"] = "cpu"
-    import jax
-    import jax.numpy as jp
+    import jax  # ty: ignore[unresolved-import]
+    import jax.numpy as jp  # ty: ignore[unresolved-import]
 
     jax.config.update("jax_platforms", "cpu")
     jax.config.update("jax_platform_name", "cpu")
cascade/benchmarks/job1.py CHANGED
@@ -16,10 +16,9 @@ Controlled by env var params: JOB1_{DATA_ROOT, GRID, ...}, see below
 import os
 
 import earthkit.data
-
 from earthkit.workflows.fluent import Payload
-from earthkit.workflows.plugins.pproc.fluent import from_source
-from earthkit.workflows.plugins.pproc.utils.window import Range
+from earthkit.workflows.plugins.pproc.fluent import from_source  # ty: ignore
+from earthkit.workflows.plugins.pproc.utils.window import Range  # ty: ignore
 
 # *** PARAMS ***
 
@@ -137,7 +136,7 @@ def download_inputs():
     }
     data = earthkit.data.from_source("mars", **ekp)
     with open(f"{data_root}/data_{number}_{step}.grib", "wb") as f:
-        data.write(f)
+        data.write(f)  # ty: ignore
 
 
 def download_climatology():
@@ -157,7 +156,7 @@ def download_climatology():
     }
     data = earthkit.data.from_source("mars", **ekp)
     with open(f"{data_root}/data_clim_{step}.grib", "wb") as f:
-        data.write(f)
+        data.write(f)  # ty: ignore
 
 
 if __name__ == "__main__":
cascade/benchmarks/matmul.py CHANGED
@@ -1,9 +1,9 @@
 import os
 from typing import Any
 
-import jax
-import jax.numpy as jp
-import jax.random as jr
+import jax  # ty: ignore[unresolved-import]
+import jax.numpy as jp  # ty: ignore[unresolved-import]
+import jax.random as jr  # ty: ignore[unresolved-import]
 
 from cascade.low.builders import JobBuilder, TaskBuilder
 from cascade.low.core import JobInstance
@@ -65,7 +65,7 @@ def execute_locally():
 
     from multiprocessing.shared_memory import SharedMemory
 
-    mem = SharedMemory("benchmark_tmp", create=True, size=m0.nbytes)
+    mem = SharedMemory("benchmark_tmp", create=True, size=m0.nbytes); assert mem.buf is not None
     mem.buf[:] = m0.tobytes()
 
 
cascade/benchmarks/tests.py CHANGED
@@ -32,7 +32,7 @@ from cascade.executor.runner.memory import Memory, ds2shmid
 from cascade.executor.runner.packages import PackagesEnv
 from cascade.executor.runner.runner import ExecutionContext, run
 from cascade.low.builders import TaskBuilder
-from cascade.low.core import DatasetId
+from cascade.low.core import DatasetId, WorkerId
 from cascade.shm.server import entrypoint as shm_server
 
 logger = logging.getLogger(__name__)
@@ -75,7 +75,7 @@ def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext)
         raise ValueError(f"expected 1 task, gotten {len(tasks)}")
     taskId = tasks[0]
     taskInstance = executionContext.tasks[taskId]
-    with Memory(callback, "testWorker") as memory, PackagesEnv() as pckg:
+    with Memory(callback, WorkerId(host="testHost", worker="testWorker")) as memory, PackagesEnv() as pckg:
         # for key, value in taskSequence.extra_env.items():
         #     os.environ[key] = value
 
@@ -142,7 +142,7 @@ def run_test(
     while perf_counter_ns() < end:
         mess = listener.recv_messages()
         if mess == [
-            DatasetPublished(origin="testWorker", ds=output, transmit_idx=None)
+            DatasetPublished(origin=WorkerId(host="testHost", worker="testWorker"), ds=output, transmit_idx=None)
         ]:
             break
         elif not mess:
cascade/benchmarks/util.py CHANGED
@@ -29,7 +29,7 @@ from cascade.executor.comms import callback
 from cascade.executor.config import logging_config, logging_config_filehandler
 from cascade.executor.executor import Executor
 from cascade.executor.msg import BackboneAddress, ExecutorShutdown
-from cascade.low.core import DatasetId, JobInstance
+from cascade.low.core import DatasetId, JobInstance, JobInstanceRich
 from cascade.low.func import msum
 from cascade.scheduler.precompute import precompute
 from earthkit.workflows.graph import Graph, deduplicate_nodes
@@ -37,15 +37,16 @@ from earthkit.workflows.graph import Graph, deduplicate_nodes
 logger = logging.getLogger("cascade.benchmarks")
 
 
-def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
+def get_job(benchmark: str | None, instance_path: str | None) -> JobInstanceRich:
     # NOTE because of os.environ, we don't import all... ideally we'd have some file-based init/config mech instead
     if benchmark is not None and instance_path is not None:
         raise TypeError("specified both benchmark name and job instance")
     elif instance_path is not None:
         with open(instance_path, "rb") as f:
             d = orjson.loads(f.read())
-            return JobInstance(**d)
+            return JobInstanceRich(**d)
     elif benchmark is not None:
+        instance: JobInstance
         if benchmark.startswith("j1"):
             import cascade.benchmarks.job1 as job1
 
@@ -58,25 +59,26 @@ def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
                 msum((v for k, v in graphs.items() if k.startswith(prefix)), Graph)
             )
             graphs["j1.all"] = union("j1.")
-            return cascade.low.into.graph2job(graphs[benchmark])
+            instance = cascade.low.into.graph2job(graphs[benchmark])
         elif benchmark.startswith("generators"):
             import cascade.benchmarks.generators as generators
 
-            return generators.get_job()
+            instance = generators.get_job()
         elif benchmark.startswith("matmul"):
             import cascade.benchmarks.matmul as matmul
 
-            return matmul.get_job()
+            instance = matmul.get_job()
        elif benchmark.startswith("dist"):
            import cascade.benchmarks.dist as dist

-            return dist.get_job()
+            instance = dist.get_job()
         elif benchmark.startswith("dask"):
             import cascade.benchmarks.dask as dask
 
-            return dask.get_job(benchmark[len("dask.") :])
+            instance = dask.get_job(benchmark[len("dask.") :])
         else:
             raise NotImplementedError(benchmark)
+        return JobInstanceRich(jobInstance=instance, checkpointSpec=None)
     else:
         raise TypeError("specified neither benchmark name nor job instance")
 
@@ -116,7 +118,7 @@ def get_gpu_count(host_idx: int, worker_count: int) -> int:
 
 
 def launch_executor(
-    job_instance: JobInstance,
+    job: JobInstanceRich,
     controller_address: BackboneAddress,
     workers_per_host: int,
     portBase: int,
@@ -136,7 +138,7 @@
     logger.info(f"will set {gpu_count} gpus on host {i}")
     os.environ["CASCADE_GPU_COUNT"] = str(gpu_count)
     executor = Executor(
-        job_instance,
+        job.jobInstance,
         controller_address,
         workers_per_host,
         f"h{i}",
@@ -154,7 +156,7 @@
 
 
 def run_locally(
-    job: JobInstance,
+    job: JobInstanceRich,
     hosts: int,
     workers: int,
     portBase: int = 12345,
@@ -195,7 +197,7 @@
         ps.append(p)
 
     # compute preschedule
-    preschedule = precompute(job)
+    preschedule = precompute(job.jobInstance)
 
     # check processes started healthy
     for i, p in enumerate(ps):
@@ -240,9 +242,9 @@ def main_local(
     port_base: int = 12345,
     log_base: str | None = None,
 ) -> None:
-    jobInstance = get_job(job, instance)
+    jobInstanceRich = get_job(job, instance)
     run_locally(
-        jobInstance,
+        jobInstanceRich,
         hosts,
         workers_per_host,
         report_address=report_address,
@@ -266,17 +268,17 @@ def main_dist(
     """
     launch = perf_counter_ns()
 
-    jobInstance = get_job(job, instance)
+    jobInstanceRich = get_job(job, instance)
 
     if idx == 0:
         logging.config.dictConfig(logging_config)
         tp = ThreadPoolExecutor(max_workers=1)
-        preschedule_fut = tp.submit(precompute, jobInstance)
+        preschedule_fut = tp.submit(precompute, jobInstanceRich.jobInstance)
         b = Bridge(controller_url, hosts)
         preschedule = preschedule_fut.result()
         tp.shutdown()
         start = perf_counter_ns()
-        run(jobInstance, b, preschedule, report_address=report_address)
+        run(jobInstanceRich, b, preschedule, report_address=report_address)
         end = perf_counter_ns()
         print(
             f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
@@ -284,12 +286,13 @@
     else:
        gpu_count = get_gpu_count(0, workers_per_host)
        launch_executor(
-            jobInstanceRich,
+            jobInstanceRich,
            controller_url,
            workers_per_host,
            12345,
            idx,
            shm_vol_gb,
            gpu_count,
-            f"tcp://{platform.get_bindabble_self()}",
+            log_base=None,  # TODO handle log collection for dist scenario
+            url_base=f"tcp://{platform.get_bindabble_self()}",
        )
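Note: the recurring change in this file is the switch from passing a bare JobInstance to a JobInstanceRich wrapper. A minimal sketch of how a caller without checkpointing adapts, assuming only the jobInstance and checkpointSpec fields that appear in this diff (the full class definition lives in cascade.low.core and is not shown here):

# Sketch only: constructor fields taken from the hunks above, not from the full class definition.
from cascade.low.core import JobInstance, JobInstanceRich


def wrap_plain_job(instance: JobInstance) -> JobInstanceRich:
    # No checkpointing configured: mirrors what get_job() does for the built-in benchmarks.
    return JobInstanceRich(jobInstance=instance, checkpointSpec=None)


# Downstream code unwraps the plain instance where the old API is still expected,
# e.g. preschedule = precompute(rich.jobInstance) as in run_locally/main_dist above.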
cascade/controller/act.py CHANGED
@@ -10,6 +10,7 @@
 
 import logging
 
+import cascade.executor.checkpoints as checkpoints
 from cascade.controller.core import State
 from cascade.executor.bridge import Bridge
 from cascade.executor.msg import TaskSequence
@@ -76,6 +77,12 @@ def flush_queues(bridge: Bridge, state: State, context: JobExecutionContext):
     for dataset, host in state.drain_fetching_queue():
         bridge.fetch(dataset, host)
 
+    for dataset, host in state.drain_persist_queue():
+        if context.checkpoint_spec is None:
+            raise TypeError(f"unexpected persist need when checkpoint storage not configured")
+        persist_params = checkpoints.serialize_persist_params(context.checkpoint_spec)
+        bridge.persist(dataset, host, context.checkpoint_spec.storage_type, persist_params)
+
     for ds in state.drain_purging_queue():
         for host in context.purge_dataset(ds):
             logger.debug(f"issuing purge of {ds=} to {host=}")
cascade/controller/core.py CHANGED
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Any, Iterator
 
 import cascade.executor.serde as serde
-from cascade.executor.msg import DatasetTransmitPayload
+from cascade.executor.msg import DatasetPersistSuccess, DatasetTransmitPayload
 from cascade.low.core import DatasetId, HostId, TaskId
 
 logger = logging.getLogger(__name__)
@@ -16,10 +16,14 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class State:
-    # key add by core.initialize, value add by notify.notify
+    # key add by core.init_state, value add by notify.notify
     outputs: dict[DatasetId, Any]
+    # key add by core.init_state, value add by notify.notify
+    to_persist: set[DatasetId]
     # add by notify.notify, remove by act.flush_queues
     fetching_queue: dict[DatasetId, HostId]
+    # add by notify.notify, remove by act.flush_queues
+    persist_queue: dict[DatasetId, HostId]
     # add by notify.notify, removed by act.flush_queues
     purging_queue: list[DatasetId]
     # add by core.init_state, remove by notify.notify
@@ -31,13 +35,16 @@
         for e in self.outputs.values():
             if e is None:
                 return True
+        if self.to_persist:
+            return True
         return False
 
     def _consider_purge(self, dataset: DatasetId) -> None:
         """If dataset not required anymore, add to purging_queue"""
         no_dependants = not self.purging_tracker.get(dataset, None)
         not_required_output = self.outputs.get(dataset, 1) is not None
-        if no_dependants and not_required_output:
+        not_required_persist = not dataset in self.to_persist
+        if all((no_dependants, not_required_output, not_required_persist)):
             logger.debug(f"adding {dataset=} to purging queue")
             if dataset in self.purging_tracker:
                 self.purging_tracker.pop(dataset)
@@ -52,6 +59,14 @@
         ):
             self.fetching_queue[dataset] = at
 
+    def consider_persist(self, dataset: DatasetId, at: HostId) -> None:
+        """If required as persist and not yet acknowledged, add to persist queue"""
+        if (
+            dataset in self.to_persist
+            and dataset not in self.persist_queue
+        ):
+            self.persist_queue[dataset] = at
+
     def receive_payload(self, payload: DatasetTransmitPayload) -> None:
         """Stores deserialized value into outputs, considers purge"""
         # NOTE ifneedbe get annotation from job.tasks[event.ds.task].definition.output_schema[event.ds.output]
@@ -60,6 +75,11 @@
         )
         self._consider_purge(payload.header.ds)
 
+    def acknowledge_persist(self, payload: DatasetPersistSuccess) -> None:
+        """Marks acknowledged, considers purge"""
+        self.to_persist.discard(payload.ds)
+        self._consider_purge(payload.ds)
+
     def task_done(self, task: TaskId, inputs: set[DatasetId]) -> None:
         """Marks that the inputs are not needed for this task anymore, considers purge of each"""
         for sourceDataset in inputs:
@@ -76,15 +96,22 @@
             yield dataset, host
         self.fetching_queue = {}
 
+    def drain_persist_queue(self) -> Iterator[tuple[DatasetId, HostId]]:
+        for dataset, host in self.persist_queue.items():
+            yield dataset, host
+        self.persist_queue = {}
+
 
-def init_state(outputs: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
+def init_state(outputs: set[DatasetId], to_persist: set[DatasetId], edge_o: dict[DatasetId, set[TaskId]]) -> State:
     purging_tracker = {
         ds: {task for task in dependants} for ds, dependants in edge_o.items()
     }
 
     return State(
         outputs={e: None for e in outputs},
+        to_persist={e for e in to_persist},
         fetching_queue={},
         purging_queue=[],
         purging_tracker=purging_tracker,
+        persist_queue={},
     )
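Note: the new persist bookkeeping mirrors the existing output bookkeeping: a dataset becomes purgeable only once it has no remaining dependants, is not an outstanding output, and is no longer awaiting persistence. A rough sketch of the controller-side lifecycle using the State methods shown above (illustrative only; real DatasetId and HostId values come from the running job):

# Sketch of the persist lifecycle based on the State methods in this diff.
from cascade.controller.core import init_state
from cascade.low.core import DatasetId, HostId


def persist_lifecycle(ds: DatasetId, host: HostId) -> None:
    # the dataset is both an external output and marked for persistence
    state = init_state(outputs={ds}, to_persist={ds}, edge_o={})

    # notify.notify calls these when the executor reports DatasetPublished
    state.consider_fetch(ds, host)
    state.consider_persist(ds, host)

    # act.flush_queues drains the queue and turns each entry into bridge.persist(...)
    for dataset, at in state.drain_persist_queue():
        print(f"would persist {dataset} from {at}")

    # only after a DatasetPersistSuccess is acknowledged does the dataset become
    # purgeable: state.acknowledge_persist(...) discards it from to_persist and
    # re-runs _consider_purge.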
cascade/controller/impl.py CHANGED
@@ -14,7 +14,7 @@ from cascade.controller.core import State, init_state
 from cascade.controller.notify import notify
 from cascade.controller.report import Reporter
 from cascade.executor.bridge import Bridge, Event
-from cascade.low.core import JobInstance, type_dec
+from cascade.low.core import JobInstance, JobInstanceRich, type_dec
 from cascade.low.execution_context import init_context
 from cascade.low.tracing import ControllerPhases, Microtrace, label, mark, timer
 from cascade.scheduler.api import assign, init_schedule, plan
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 
 def run(
-    job: JobInstance,
+    job: JobInstanceRich,
     bridge: Bridge,
     preschedule: Preschedule,
     report_address: str | None = None,
@@ -34,7 +34,8 @@
     outputs = set(context.job_instance.ext_outputs)
     logger.debug(f"starting with {env=} and {report_address=}")
     schedule = timer(init_schedule, Microtrace.ctrl_init)(preschedule, context)
-    state = init_state(outputs, context.edge_o)
+    to_persist = set(job.checkpointSpec.to_persist) if job.checkpointSpec is not None else set()
+    state = init_state(outputs, to_persist, context.edge_o)
 
     label("host", "controller")
     events: list[Event] = []
@@ -44,7 +45,7 @@
 
     try:
         total_gpus = sum(worker.gpu for worker in env.workers.values())
-        needs_gpus = any(task.definition.needs_gpu for task in job.tasks.values())
+        needs_gpus = any(task.definition.needs_gpu for task in job.jobInstance.tasks.values())
         if needs_gpus and total_gpus == 0:
             raise ValueError("environment contains no gpu yet job demands one")
 
cascade/controller/notify.py CHANGED
@@ -17,7 +17,7 @@ from typing import Iterable
 from cascade.controller.core import State
 from cascade.controller.report import Reporter
 from cascade.executor.bridge import Event
-from cascade.executor.msg import DatasetPublished, DatasetTransmitPayload
+from cascade.executor.msg import DatasetPersistSuccess, DatasetPublished, DatasetTransmitPayload
 from cascade.low.core import DatasetId, HostId, WorkerId
 from cascade.low.execution_context import DatasetStatus, JobExecutionContext
 from cascade.low.func import assert_never
@@ -89,6 +89,7 @@ def notify(
         context.host2ds[host][event.ds] = DatasetStatus.available
         context.ds2host[event.ds][host] = DatasetStatus.available
         state.consider_fetch(event.ds, host)
+        state.consider_persist(event.ds, host)
         consider_computable(schedule, state, context, event.ds, host)
         if event.transmit_idx is not None:
             mark(
@@ -121,5 +122,7 @@
     elif isinstance(event, DatasetTransmitPayload):
         state.receive_payload(event)
         reporter.send_result(event.header.ds, event.value)
+    elif isinstance(event, DatasetPersistSuccess):
+        state.acknowledge_persist(event)
     else:
         assert_never(event)
cascade/executor/bridge.py CHANGED
@@ -16,6 +16,9 @@ from cascade.executor.comms import default_message_resend_ms as resend_grace_ms
 from cascade.executor.executor import heartbeat_grace_ms as executor_heartbeat_grace_ms
 from cascade.executor.msg import (
     Ack,
+    DatasetPersistCommand,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
     DatasetPublished,
     DatasetPurge,
     DatasetTransmitCommand,
@@ -29,14 +32,15 @@ from cascade.executor.msg import (
     TaskFailure,
     TaskSequence,
 )
-from cascade.low.core import DatasetId, Environment, HostId, Worker, WorkerId
+from cascade.low.core import CheckpointStorageType, DatasetId, Environment, HostId, Worker, WorkerId
 from cascade.low.func import assert_never
 
 logger = logging.getLogger(__name__)
 
-Event = DatasetPublished | DatasetTransmitPayload
-ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | ExecutorExit
-Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | ExecutorShutdown
+Event = DatasetPublished | DatasetTransmitPayload | DatasetPersistSuccess
+# TODO consider retries here, esp on the PersistFailure
+ToShutdown = TaskFailure | ExecutorFailure | DatasetTransmitFailure | DatasetPersistFailure | ExecutorExit
+Unsupported = TaskSequence | DatasetPurge | DatasetTransmitCommand | DatasetPersistCommand | ExecutorShutdown
 
 
 class Bridge:
@@ -158,6 +162,15 @@ class Bridge:
         self.transmit_idx_counter += 1
         self.sender.send("data." + source, m)
 
+    def persist(self, ds: DatasetId, source: HostId, storage_type: CheckpointStorageType, persist_params: str) -> None:
+        m = DatasetPersistCommand(
+            source=source,
+            ds=ds,
+            storage_type=storage_type,
+            persist_params=persist_params,
+        )
+        self.sender.send("data." + source, m)
+
     def fetch(self, ds: DatasetId, source: HostId) -> None:
         m = DatasetTransmitCommand(
             source=source,
cascade/executor/checkpoints.py ADDED
@@ -0,0 +1,42 @@
+# (C) Copyright 2025- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Handles the checkpoint management: storage, retrieval"""
+
+import pathlib
+
+from cascade.executor.msg import DatasetPersistCommand
+from cascade.low.core import CheckpointSpec
+from cascade.low.func import assert_never
+from cascade.shm.client import AllocatedBuffer
+
+
+def persist_dataset(command: DatasetPersistCommand, buf: AllocatedBuffer) -> None:
+    match command.storage_type:
+        case "fs":
+            root = pathlib.Path(command.persist_params)
+            root.mkdir(parents=True, exist_ok=True)
+            file = root / repr(command.ds)
+            # TODO what about overwrites / concurrent writes? Append uuid?
+            file.write_bytes(buf.view())
+        case s:
+            assert_never(s)
+
+def serialize_persist_params(spec: CheckpointSpec) -> str:
+    # NOTE we call this every time we store, ideally call this once when building `low.execution_context`
+    match spec.storage_type:
+        case "fs":
+            if not isinstance(spec.storage_params, str):
+                raise TypeError(f"expected checkpoint storage params to be str, gotten {spec.storage_params.__class__}")
+            if spec.persist_id is None:
+                raise TypeError(f"serialize_persist_params called, but persist_id is None")
+            root = pathlib.Path(spec.storage_params)
+            return str(root / spec.persist_id)
+        case s:
+            assert_never(s)
+
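Note: for the only storage backend in this diff ("fs"), the persist parameters reduce to a directory path. A small usage sketch, assuming a CheckpointSpec can be built from the storage_type, storage_params, persist_id and to_persist fields referenced here (its actual constructor lives in cascade.low.core and is not shown in this diff):

# Hypothetical example; the exact CheckpointSpec constructor is an assumption.
from cascade.executor.checkpoints import serialize_persist_params
from cascade.low.core import CheckpointSpec

spec = CheckpointSpec(
    storage_type="fs",
    storage_params="/data/checkpoints",  # root directory for "fs" storage
    persist_id="run-2025-01-01",         # sub-directory identifying this job run
    to_persist=set(),                    # DatasetIds to persist, filled in by the caller
)

# The controller serializes this for every persist command; for "fs" it is just
# "<storage_params>/<persist_id>", and the data server then writes each dataset
# to "<that directory>/<repr(DatasetId)>" via persist_dataset().
print(serialize_persist_params(spec))  # -> /data/checkpoints/run-2025-01-01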
cascade/executor/data_server.py CHANGED
@@ -15,16 +15,20 @@ large data object.
 
 import logging
 import logging.config
-from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED
+from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
 from concurrent.futures import Executor as PythonExecutor
-from concurrent.futures import Future, ThreadPoolExecutor, wait
 from time import time_ns
+from typing import cast
 
 import cascade.shm.client as shm_client
+from cascade.executor.checkpoints import persist_dataset
 from cascade.executor.comms import Listener, callback, send_data
 from cascade.executor.msg import (
     Ack,
     BackboneAddress,
+    DatasetPersistCommand,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
     DatasetPublished,
     DatasetPurge,
     DatasetTransmitCommand,
@@ -59,7 +63,7 @@ class DataServer:
         self.cap = 2
         self.ds_proc_tp: PythonExecutor = ThreadPoolExecutor(max_workers=self.cap)
         self.futs_in_progress: dict[
-            DatasetTransmitCommand | DatasetTransmitPayload, Future
+            DatasetTransmitCommand | DatasetTransmitPayload | DatasetPersistCommand, Future
         ] = {}
         self.awaiting_confirmation: dict[int, tuple[DatasetTransmitCommand, int]] = {}
         self.invalid: set[DatasetId] = (
@@ -149,6 +153,26 @@
             time_ns()
         )  # not actually consumed but uniform signature with send_payload simplifies typing
 
+    def persist_payload(self, command: DatasetPersistCommand) -> int:
+        buf: None | shm_client.AllocatedBuffer = None
+        try:
+            if command.source != self.host:
+                raise ValueError(f"invalid {command=}")
+            buf = shm_client.get(key=ds2shmid(command.ds))
+            persist_dataset(command, buf)
+            logger.debug(f"dataset for {command} persisted")
+            callback(self.maddress, DatasetPersistSuccess(host=self.host, ds=command.ds))
+        except Exception as e:
+            logger.exception(f"failed to persist dataset for {command}, reporting up")
+            callback(
+                self.maddress,
+                DatasetPersistFailure(host=self.host, detail=f"{repr(command)} -> {repr(e)}"),
+            )
+        finally:
+            if buf is not None:
+                buf.close()
+        return time_ns()
+
     def send_payload(self, command: DatasetTransmitCommand) -> int:
         buf: None | shm_client.AllocatedBuffer = None
         payload: None | DatasetTransmitPayload = None
@@ -171,7 +195,7 @@
                 ds=command.ds,
                 deser_fun=buf.deser_fun,
             )
-            payload = DatasetTransmitPayload(header, value=buf.view())
+            payload = DatasetTransmitPayload(header, value=cast(bytes, buf.view()))
             syn = Syn(command.idx, self.dlistener.address)
             send_data(command.daddress, payload, syn)
             logger.debug(f"payload for {command} sent")
@@ -218,6 +242,14 @@
             self.awaiting_confirmation[m.idx] = (m, -1)
             fut = self.ds_proc_tp.submit(self.send_payload, m)
             self.futs_in_progress[m] = fut
+        elif isinstance(m, DatasetPersistCommand):
+            if m.ds in self.invalid:
+                raise ValueError(
+                    f"unexpected persist command {m} as the dataset was already purged"
+                )
+            # TODO mark?
+            fut = self.ds_proc_tp.submit(self.persist_payload, m)
+            self.futs_in_progress[m] = fut
         elif isinstance(m, DatasetTransmitPayload):
             if m.header.ds in self.invalid:
                 logger.warning(
@@ -238,9 +270,10 @@
             self.acks.add(m.idx)
         elif isinstance(m, DatasetPurge):
             # we need to handle potential commands transmitting this dataset, as otherwise they'd fail
+            # TODO submit this as a future? This actively blocks the whole server
            to_wait = []
            for commandProg, fut in self.futs_in_progress.items():
-                if isinstance(commandProg, DatasetTransmitCommand):
+                if isinstance(commandProg, DatasetTransmitCommand|DatasetPersistCommand):
                    val = commandProg.ds
                elif isinstance(commandProg, DatasetTransmitPayload):
                    val = commandProg.header.ds
cascade/executor/executor.py CHANGED
@@ -35,6 +35,8 @@ from cascade.executor.data_server import start_data_server
 from cascade.executor.msg import (
     Ack,
     BackboneAddress,
+    DatasetPersistFailure,
+    DatasetPersistSuccess,
     DatasetPublished,
     DatasetPurge,
     DatasetTransmitFailure,
@@ -306,7 +308,7 @@ class Executor:
             callback(worker_address(worker), m)
             self.datasets.add(m.ds)
             self.to_controller(m)
-        elif isinstance(m, DatasetTransmitFailure):
+        elif isinstance(m, DatasetTransmitFailure|DatasetPersistSuccess|DatasetPersistFailure):
             self.to_controller(m)
         else:
             # NOTE transmit and store are handled in DataServer (which has its own socket)
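Note: taken together, these hunks add a small persist protocol on top of the existing data plane: the controller drains State.persist_queue, Bridge.persist sends a DatasetPersistCommand to the owning host's data server, the data server writes the shared-memory buffer through cascade.executor.checkpoints.persist_dataset, and the resulting DatasetPersistSuccess or DatasetPersistFailure is forwarded by the executor back to the controller. A compact sketch of the command construction, mirroring Bridge.persist above (values illustrative):

# Sketch mirroring Bridge.persist; message fields are those shown in this diff.
from cascade.executor.msg import DatasetPersistCommand
from cascade.low.core import DatasetId, HostId


def build_persist_command(ds: DatasetId, source: HostId, persist_params: str) -> DatasetPersistCommand:
    # The command carries everything the data server needs, so the executor side
    # stays stateless about checkpoint configuration.
    return DatasetPersistCommand(
        source=source,
        ds=ds,
        storage_type="fs",              # the only storage backend present in this diff
        persist_params=persist_params,  # for "fs": the target directory
    )


# The bridge routes it with sender.send("data." + source, command); the data server
# replies with DatasetPersistSuccess(host=..., ds=...) or DatasetPersistFailure(host=..., detail=...).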