earthkit-workflows 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -209,7 +209,7 @@ def run_locally(
209
209
 
210
210
  # start bridge itself
211
211
  logger.debug("starting bridge")
212
- b = Bridge(c, hosts)
212
+ b = Bridge(c, hosts, job.checkpointSpec)
213
213
  start = perf_counter_ns()
214
214
  result = run(job, b, preschedule, report_address=report_address)
215
215
  end = perf_counter_ns()
@@ -274,7 +274,7 @@ def main_dist(
274
274
  logging.config.dictConfig(logging_config)
275
275
  tp = ThreadPoolExecutor(max_workers=1)
276
276
  preschedule_fut = tp.submit(precompute, jobInstanceRich.jobInstance)
277
- b = Bridge(controller_url, hosts)
277
+ b = Bridge(controller_url, hosts, jobInstanceRich.checkpointSpec)
278
278
  preschedule = preschedule_fut.result()
279
279
  tp.shutdown()
280
280
  start = perf_counter_ns()
cascade/controller/act.py CHANGED
@@ -9,12 +9,14 @@
9
9
  """Implements the invocation of Bridge/Executor methods given a sequence of Actions"""
10
10
 
11
11
  import logging
12
+ from typing import Iterable, Iterator, cast
12
13
 
13
- import cascade.executor.checkpoints as checkpoints
14
14
  from cascade.controller.core import State
15
15
  from cascade.executor.bridge import Bridge
16
- from cascade.executor.msg import TaskSequence
17
- from cascade.low.execution_context import JobExecutionContext
16
+ from cascade.executor.checkpoints import build_retrieve_command, possible_repersist, retrieve_dataset
17
+ from cascade.executor.msg import DatasetPublished, TaskSequence
18
+ from cascade.low.core import DatasetId
19
+ from cascade.low.execution_context import JobExecutionContext, VirtualCheckpointHost
18
20
  from cascade.low.tracing import TaskLifecycle, TransmitLifecycle, mark
19
21
  from cascade.scheduler.core import Assignment
20
22
 
@@ -75,17 +77,46 @@ def flush_queues(bridge: Bridge, state: State, context: JobExecutionContext):
75
77
  """
76
78
 
77
79
  for dataset, host in state.drain_fetching_queue():
78
- bridge.fetch(dataset, host)
80
+ if host != VirtualCheckpointHost:
81
+ bridge.fetch(dataset, host)
82
+ else:
83
+ # NOTE we would rather not be here, but we dont generally expect
84
+ # checkpointed datasets to be outputs. If needbe, send a command
85
+ # to any worker, or spawn a thread with this
86
+ logger.warning(f"execute checkpoint retrieve on controller")
87
+ # NOTE the host is the virtual one so the message is not really valid, but no big deal
88
+ virtual_command = build_retrieve_command(bridge.checkpoint_spec, dataset, host)
89
+ buffer = retrieve_dataset(virtual_command)
90
+ try:
91
+ # the cast is wrong but ty is bit confused about memoryview anyway
92
+ state.receive_payload(dataset, cast(bytes, buffer.view()), buffer.deser_fun)
93
+ finally:
94
+ buffer.close()
79
95
 
80
96
  for dataset, host in state.drain_persist_queue():
81
- if context.checkpoint_spec is None:
82
- raise TypeError(f"unexpected persist need when checkpoint storage not configured")
83
- persist_params = checkpoints.serialize_persist_params(context.checkpoint_spec)
84
- bridge.persist(dataset, host, context.checkpoint_spec.storage_type, persist_params)
97
+ if host != VirtualCheckpointHost:
98
+ bridge.persist(dataset, host)
99
+ else:
100
+ possible_repersist(dataset, bridge.checkpoint_spec)
101
+ state.acknowledge_persist(dataset)
85
102
 
86
103
  for ds in state.drain_purging_queue():
87
104
  for host in context.purge_dataset(ds):
88
- logger.debug(f"issuing purge of {ds=} to {host=}")
89
- bridge.purge(host, ds)
105
+ if host != VirtualCheckpointHost:
106
+ logger.debug(f"issuing purge of {ds=} to {host=}")
107
+ bridge.purge(host, ds)
90
108
 
91
109
  return state
110
+
111
def virtual_checkpoint_publish(datasets: Iterable[DatasetId]) -> Iterator[DatasetPublished]:
    """Simulate publication events for datasets already available in the checkpoint.

    Nothing is sent over the wire. Each dataset id is turned into a
    `DatasetPublished` whose origin is `VirtualCheckpointHost` and which carries no
    transmit index. `controller.notify` can then bring the execution contexts into
    the right state. Invoked once at job start, after the checkpoint has been listed.

    The returned iterator is lazy: events are built only as it is consumed.
    """

    def as_published(dataset: DatasetId) -> DatasetPublished:
        return DatasetPublished(origin=VirtualCheckpointHost, ds=dataset, transmit_idx=None)

    return map(as_published, datasets)
@@ -10,6 +10,7 @@ from typing import Any, Iterator
10
10
  import cascade.executor.serde as serde
11
11
  from cascade.executor.msg import DatasetPersistSuccess, DatasetTransmitPayload
12
12
  from cascade.low.core import DatasetId, HostId, TaskId
13
+ from cascade.low.execution_context import VirtualCheckpointHost
13
14
 
14
15
  logger = logging.getLogger(__name__)
15
16
 
@@ -67,18 +68,16 @@ class State:
67
68
  ):
68
69
  self.persist_queue[dataset] = at
69
70
 
70
def receive_payload(self, ds: DatasetId, payload: bytes, deser_fun: str) -> None:
    """Deserialize `payload` into `self.outputs[ds]`, then consider purging `ds`.

    `deser_fun` names the deserializer passed to `serde.des_output`. The declared
    output type handed along is always "Any".
    """
    # NOTE ifneedbe get annotation from job.tasks[event.ds.task].definition.output_schema[event.ds.output]
    value = serde.des_output(payload, "Any", deser_fun)
    self.outputs[ds] = value
    self._consider_purge(ds)

def acknowledge_persist(self, ds: DatasetId) -> None:
    """Drop `ds` from the set of datasets still awaiting persistence, then consider purging it."""
    pending = self.to_persist
    pending.discard(ds)
    self._consider_purge(ds)
82
81
 
83
82
  def task_done(self, task: TaskId, inputs: set[DatasetId]) -> None:
84
83
  """Marks that the inputs are not needed for this task anymore, considers purge of each"""
@@ -9,15 +9,17 @@
9
9
  import logging
10
10
 
11
11
  import cascade.executor.serde as serde
12
- from cascade.controller.act import act, flush_queues
12
+ from cascade.controller.act import act, flush_queues, virtual_checkpoint_publish
13
13
  from cascade.controller.core import State, init_state
14
14
  from cascade.controller.notify import notify
15
15
  from cascade.controller.report import Reporter
16
16
  from cascade.executor.bridge import Bridge, Event
17
+ from cascade.executor.checkpoints import list_persisted_datasets
17
18
  from cascade.low.core import JobInstance, JobInstanceRich, type_dec
18
19
  from cascade.low.execution_context import init_context
19
20
  from cascade.low.tracing import ControllerPhases, Microtrace, label, mark, timer
20
21
  from cascade.scheduler.api import assign, init_schedule, plan
22
+ from cascade.scheduler.checkpoints import trim_with_persisted, virtual_update_schedule
21
23
  from cascade.scheduler.core import Preschedule
22
24
 
23
25
  logger = logging.getLogger(__name__)
@@ -30,6 +32,9 @@ def run(
30
32
  report_address: str | None = None,
31
33
  ) -> State:
32
34
  env = bridge.get_environment()
35
+ persisted = list_persisted_datasets(job.checkpointSpec) if job.checkpointSpec is not None else []
36
+ jobInstance, preschedule, persisted_valid = trim_with_persisted(job, preschedule, set(persisted))
37
+ job.jobInstance = jobInstance
33
38
  context = init_context(env, job, preschedule.edge_o, preschedule.edge_i)
34
39
  outputs = set(context.job_instance.ext_outputs)
35
40
  logger.debug(f"starting with {env=} and {report_address=}")
@@ -42,6 +47,7 @@ def run(
42
47
  for serdeTypeEnc, (serdeSer, serdeDes) in context.job_instance.serdes.items():
43
48
  serde.SerdeRegistry.register(type_dec(serdeTypeEnc), serdeSer, serdeDes)
44
49
  reporter = Reporter(report_address)
50
+ notify_wrapper = lambda events: notify(state, schedule, context, events, reporter)
45
51
 
46
52
  try:
47
53
  total_gpus = sum(worker.gpu for worker in env.workers.values())
@@ -49,6 +55,10 @@ def run(
49
55
  if needs_gpus and total_gpus == 0:
50
56
  raise ValueError("environment contains no gpu yet job demands one")
51
57
 
58
+ virtual_update_schedule(persisted_valid, schedule, context)
59
+ virtual_events = virtual_checkpoint_publish(persisted_valid)
60
+ timer(notify_wrapper, Microtrace.ctrl_notify)(virtual_events)
61
+
52
62
  while (
53
63
  state.has_awaitable()
54
64
  or context.has_awaitable()
@@ -68,11 +78,9 @@ def run(
68
78
 
69
79
  mark({"action": ControllerPhases.wait})
70
80
  if state.has_awaitable() or context.has_awaitable():
71
- logger.debug(f"about to await bridge with {context.ongoing_total=}")
81
+ logger.debug(f"about to await bridge with {context.ongoing_total=}, {context.remaining=} and {state.has_awaitable()=}")
72
82
  events = timer(bridge.recv_events, Microtrace.ctrl_wait)()
73
- timer(notify, Microtrace.ctrl_notify)(
74
- state, schedule, context, events, reporter
75
- )
83
+ timer(notify_wrapper, Microtrace.ctrl_notify)(events)
76
84
  logger.debug(f"received {len(events)} events")
77
85
  except Exception as ex:
78
86
  logger.error("crash in controller, shuting down")
@@ -12,14 +12,14 @@
12
12
  # Thus the caller always *must* use the return value and cease using the input.
13
13
 
14
14
  import logging
15
- from typing import Iterable
15
+ from typing import Iterable, cast
16
16
 
17
17
  from cascade.controller.core import State
18
18
  from cascade.controller.report import Reporter
19
19
  from cascade.executor.bridge import Event
20
- from cascade.executor.msg import DatasetPersistSuccess, DatasetPublished, DatasetTransmitPayload
20
+ from cascade.executor.msg import DatasetPersistSuccess, DatasetPublished, DatasetRetrieveSuccess, DatasetTransmitPayload
21
21
  from cascade.low.core import DatasetId, HostId, WorkerId
22
- from cascade.low.execution_context import DatasetStatus, JobExecutionContext
22
+ from cascade.low.execution_context import DatasetStatus, JobExecutionContext, VirtualCheckpointHost
23
23
  from cascade.low.func import assert_never
24
24
  from cascade.low.tracing import TaskLifecycle, TransmitLifecycle, mark
25
25
  from cascade.scheduler.api import gang_check_ready
@@ -103,26 +103,34 @@ def notify(
103
103
  elif context.is_last_output_of(event.ds):
104
104
  worker = event.origin
105
105
  task = event.ds.task
106
- if not isinstance(worker, WorkerId):
106
+ isWorker = isinstance(worker, WorkerId)
107
+ isVirtual = worker == VirtualCheckpointHost
108
+ if not isWorker and not isVirtual:
107
109
  raise ValueError(
108
110
  f"malformed event, expected origin to be WorkerId: {event}"
109
111
  )
110
112
  logger.debug(f"last output of {task}, assuming completion")
111
- mark(
112
- {
113
- "task": task,
114
- "action": TaskLifecycle.completed,
115
- "worker": repr(worker),
116
- "host": "controller",
117
- }
118
- )
119
113
  state.task_done(task, context.edge_i.get(event.ds.task, set()))
120
- context.task_done(task, worker)
114
+ if isWorker:
115
+ mark(
116
+ {
117
+ "task": task,
118
+ "action": TaskLifecycle.completed,
119
+ "worker": repr(worker),
120
+ "host": "controller",
121
+ }
122
+ )
123
+ worker = cast(WorkerId, worker) # ty cant yet derive this to be true
124
+ context.task_done_at(task, worker)
125
+ else:
126
+ context.task_done(task)
121
127
  reporter.send_progress(context)
122
128
  elif isinstance(event, DatasetTransmitPayload):
123
- state.receive_payload(event)
129
+ state.receive_payload(event.header.ds, event.value, event.header.deser_fun)
124
130
  reporter.send_result(event.header.ds, event.value)
125
131
  elif isinstance(event, DatasetPersistSuccess):
126
- state.acknowledge_persist(event)
132
+ state.acknowledge_persist(event.ds)
133
+ elif isinstance(event, DatasetRetrieveSuccess):
134
+ pass
127
135
  else:
128
136
  assert_never(event)
@@ -11,6 +11,7 @@
11
11
  import logging
12
12
  import time
13
13
 
14
+ from cascade.executor.checkpoints import build_persist_command, build_retrieve_command, serialize_params
14
15
  from cascade.executor.comms import GraceWatcher, Listener, ReliableSender
15
16
  from cascade.executor.comms import default_message_resend_ms as resend_grace_ms
16
17
  from cascade.executor.executor import heartbeat_grace_ms as executor_heartbeat_grace_ms
@@ -21,6 +22,9 @@ from cascade.executor.msg import (
21
22
  DatasetPersistSuccess,
22
23
  DatasetPublished,
23
24
  DatasetPurge,
25
+ DatasetRetrieveCommand,
26
+ DatasetRetrieveFailure,
27
+ DatasetRetrieveSuccess,
24
28
  DatasetTransmitCommand,
25
29
  DatasetTransmitFailure,
26
30
  DatasetTransmitPayload,
@@ -32,19 +36,21 @@ from cascade.executor.msg import (
32
36
  TaskFailure,
33
37
  TaskSequence,
34
38
  )
35
- from cascade.low.core import CheckpointStorageType, DatasetId, Environment, HostId, Worker, WorkerId
39
+ from cascade.low.core import CheckpointSpec, DatasetId, Environment, HostId, Worker, WorkerId
40
+ from cascade.low.execution_context import VirtualCheckpointHost
36
41
  from cascade.low.func import assert_never
37
42
 
38
43
  logger = logging.getLogger(__name__)
39
44
 
40
# Message kinds surfaced to the controller loop; controller.notify handles each of them.
Event = DatasetPublished | DatasetTransmitPayload | DatasetPersistSuccess | DatasetRetrieveSuccess
# Failure/exit messages -- presumably each one terminates the job (handling is outside this view).
# TODO consider retries here, esp on the Persist/Retrieve Failures
ToShutdown = TaskFailure | ExecutorFailure | DatasetRetrieveFailure | DatasetTransmitFailure | DatasetPersistFailure | ExecutorExit
# Commands flowing controller -> executor (the bridge itself sends transmit/persist/retrieve
# commands), so the bridge never expects to receive one. DatasetRetrieveCommand is such a
# command too and must be listed here alongside the others.
Unsupported = (
    TaskSequence
    | DatasetPurge
    | DatasetTransmitCommand
    | DatasetPersistCommand
    | DatasetRetrieveCommand
    | ExecutorShutdown
)
44
49
 
45
50
 
46
51
  class Bridge:
47
- def __init__(self, controller_url: str, expected_executors: int) -> None:
52
+ def __init__(self, controller_url: str, expected_executors: int, checkpoint_spec: CheckpointSpec|None=None) -> None:
53
+ self.checkpoint_spec = checkpoint_spec
48
54
  self.mlistener = Listener(controller_url)
49
55
  self.heartbeat_checker: dict[HostId, GraceWatcher] = {}
50
56
  self.transmit_idx_counter = 0
@@ -152,24 +158,23 @@ class Bridge:
152
158
  self._send(host, m)
153
159
 
154
160
def transmit(self, ds: DatasetId, source: HostId, target: HostId) -> None:
    """Make dataset `ds` available on host `target`.

    For a real `source` host, its data server is sent a `DatasetTransmitCommand`.
    The command carries the target's data address and a fresh transmit index.
    When `source` is `VirtualCheckpointHost`, the dataset lives in checkpoint
    storage instead, so the target's data server is told to retrieve it from
    there; no transmit index is consumed in that case.
    """
    if source != VirtualCheckpointHost:
        transmit_command = DatasetTransmitCommand(
            source=source,
            target=target,
            daddress=self.sender.hosts["data." + target][1],
            ds=ds,
            idx=self.transmit_idx_counter,
        )
        self.transmit_idx_counter += 1
        self.sender.send("data." + source, transmit_command)
        return
    retrieve_command = build_retrieve_command(self.checkpoint_spec, ds, target)
    self.sender.send("data." + target, retrieve_command)

def persist(self, ds: DatasetId, source: HostId) -> None:
    """Ask the data server of host `source` to persist `ds` to checkpoint storage.

    `build_persist_command` raises ValueError when checkpointing is not configured
    or the spec has no persist id.
    """
    persist_command = build_persist_command(self.checkpoint_spec, ds, source)
    self.sender.send("data." + source, persist_command)
173
178
 
174
179
  def fetch(self, ds: DatasetId, source: HostId) -> None:
175
180
  m = DatasetTransmitCommand(
@@ -8,35 +8,127 @@
8
8
 
9
9
  """Handles the checkpoint management: storage, retrieval"""
10
10
 
11
+ import io
12
+ import logging
13
+ import os
11
14
  import pathlib
12
15
 
13
- from cascade.executor.msg import DatasetPersistCommand
14
- from cascade.low.core import CheckpointSpec
16
+ from cascade.executor.msg import DatasetPersistCommand, DatasetRetrieveCommand
17
+ from cascade.executor.platform import advise_seqread
18
+ from cascade.executor.runner.memory import ds2shmid
19
+ from cascade.executor.serde import DefaultSerde
20
+ from cascade.low.core import CheckpointSpec, DatasetId, HostId
21
+ from cascade.low.execution_context import VirtualCheckpointHost
15
22
  from cascade.low.func import assert_never
16
- from cascade.shm.client import AllocatedBuffer
23
+ from cascade.shm.client import AllocatedBuffer, allocate
17
24
 
25
+ logger = logging.getLogger(__name__)
26
+
27
def serialize_params(spec: CheckpointSpec, id_: str) -> str:
    """Render the storage-specific location string for one checkpoint id.

    `id_` is either the persist id or the retrieve id of `spec`. For the "fs"
    storage type the result is the path `<storage_params>/<id_>`.

    Raises:
        TypeError: for "fs", when `spec.storage_params` is not a str.
    """
    # NOTE we call this every time we store, ideally call this once when building `low.execution_context`
    storage_type = spec.storage_type
    if storage_type == "fs":
        params = spec.storage_params
        if not isinstance(params, str):
            raise TypeError(f"expected checkpoint storage params to be str, gotten {params.__class__}")
        return str(pathlib.Path(params) / id_)
    else:
        assert_never(storage_type)
38
+
39
def build_persist_command(checkpoint_spec: CheckpointSpec|None, ds: DatasetId, hostId: HostId) -> DatasetPersistCommand:
    """Build the command telling the data server of `hostId` to persist `ds`.

    The persist location is derived from the spec's persist id via `serialize_params`.

    Raises:
        ValueError: checkpointing is not configured, or the spec has no persist id.
    """
    if checkpoint_spec is None:
        raise ValueError("unexpected persist need when checkpoint storage not configured")
    persist_id = checkpoint_spec.persist_id
    if not persist_id:
        raise ValueError("unexpected persist need when there is no persist id")
    location = serialize_params(checkpoint_spec, persist_id)
    return DatasetPersistCommand(
        source=hostId,
        ds=ds,
        storage_type=checkpoint_spec.storage_type,
        persist_params=location,
    )
18
52
 
19
53
  def persist_dataset(command: DatasetPersistCommand, buf: AllocatedBuffer) -> None:
20
54
  match command.storage_type:
21
55
  case "fs":
22
56
  root = pathlib.Path(command.persist_params)
23
57
  root.mkdir(parents=True, exist_ok=True)
24
- file = root / repr(command.ds)
58
+ file = root / command.ds.ser()
25
59
  # TODO what about overwrites / concurrent writes? Append uuid?
26
60
  file.write_bytes(buf.view())
27
61
  case s:
28
62
  assert_never(s)
29
63
 
30
- def serialize_persist_params(spec: CheckpointSpec) -> str:
31
- # NOTE we call this every time we store, ideally call this once when building `low.execution_context`
64
+ def list_persisted_datasets(spec: CheckpointSpec) -> list[DatasetId]:
32
65
  match spec.storage_type:
33
66
  case "fs":
34
- if not isinstance(spec.storage_params, str):
35
- raise TypeError(f"expected checkpoint storage params to be str, gotten {spec.storage_params.__class__}")
36
- if spec.persist_id is None:
37
- raise TypeError(f"serialize_persist_params called, but persist_id is None")
38
- root = pathlib.Path(spec.storage_params)
39
- return str(root / spec.persist_id)
67
+ if not spec.persist_id:
68
+ raise ValueError("unexpected list persisted when there is no persist id")
69
+ root = pathlib.Path(spec.storage_params) / spec.persist_id
70
+ if not root.exists():
71
+ return [] # we mkdir only at a first persist, so absence of folder is valid emptiness
72
+ files = (x for x in root.iterdir() if x.is_file())
73
+ return [DatasetId.des(file.parts[-1]) for file in files]
40
74
  case s:
41
75
  assert_never(s)
42
76
 
77
def build_retrieve_command(checkpoint_spec: CheckpointSpec|None, ds: DatasetId, hostId: HostId) -> DatasetRetrieveCommand:
    """Build the command telling the data server of `hostId` to load `ds` from checkpoint storage.

    The read location is derived from the spec's retrieve id via `serialize_params`.

    Raises:
        ValueError: checkpointing is not configured, or the spec has no retrieve id.
    """
    if checkpoint_spec is None:
        raise ValueError("unexpected retrieve need when checkpoint storage not configured")
    id_ = checkpoint_spec.retrieve_id
    if not id_:
        raise ValueError("unexpected retrieve when there is no retrieve id")
    retrieve_params = serialize_params(checkpoint_spec, id_)
    return DatasetRetrieveCommand(
        target=hostId,
        ds=ds,
        storage_type=checkpoint_spec.storage_type,
        retrieve_params=retrieve_params,
    )
90
+
91
+ def retrieve_dataset(command: DatasetRetrieveCommand) -> AllocatedBuffer:
92
+ match command.storage_type:
93
+ case "fs":
94
+ shm_key = ds2shmid(command.ds)
95
+ fpath = pathlib.Path(command.retrieve_params) / command.ds.ser()
96
+ fd = os.open(fpath, os.O_RDONLY)
97
+ try:
98
+ advise_seqread(fd)
99
+ size = os.fstat(fd).st_size
100
+ # TODO dont use default serde, get it via the command
101
+ buf = allocate(shm_key, size, DefaultSerde)
102
+ # once on 3.14+, replace with this
103
+ # os.readinto(fd, buf.view())
104
+ with io.FileIO(fd, closefd=False) as raw_io:
105
+ raw_io.readinto(buf.view())
106
+ finally:
107
+ os.close(fd)
108
+ return buf
109
+ case s:
110
+ assert_never(s)
111
+
112
def possible_repersist(dataset: DatasetId, checkpointSpec: CheckpointSpec|None) -> None:
    """Copy `dataset` from the retrieve-id checkpoint location to the persist-id one.

    When both ids are equal the dataset is assumed to be there already, so only a
    warning is logged. Otherwise the dataset is read with `retrieve_dataset` and
    written back with `persist_dataset`; the intermediate buffer is always closed.

    Raises:
        ValueError: checkpointing is not configured, or the retrieve or persist id is missing.
    """
    # NOTE blocking -> unfortunate for controller, but we dont expect this to be frequent/hot.
    # If needbe, spawn a thread or something. In that case needs a completion callback
    if not checkpointSpec:
        raise ValueError("unexpected repersist when checkpoint storage not configured")
    retrieve_id = checkpointSpec.retrieve_id
    if not retrieve_id:
        raise ValueError("unexpected repersist when no retrieve id")
    persist_id = checkpointSpec.persist_id
    if not persist_id:
        raise ValueError("unexpected repersist when no persist id")

    if retrieve_id == persist_id:
        # we assume reproducibility---bold!---so we better warn about it
        logger.warning(f"no-op for persist of {dataset} as was already persisted under the same id {retrieve_id}")
        return

    # NOTE the host is the virtual one so the message is not really valid, but no big deal
    retrieval = build_retrieve_command(checkpointSpec, dataset, VirtualCheckpointHost)
    persistence = build_persist_command(checkpointSpec, dataset, VirtualCheckpointHost)
    staged = retrieve_dataset(retrieval)
    try:
        persist_dataset(persistence, staged)
    finally:
        staged.close()
@@ -21,7 +21,7 @@ from time import time_ns
21
21
  from typing import cast
22
22
 
23
23
  import cascade.shm.client as shm_client
24
- from cascade.executor.checkpoints import persist_dataset
24
+ from cascade.executor.checkpoints import persist_dataset, retrieve_dataset
25
25
  from cascade.executor.comms import Listener, callback, send_data
26
26
  from cascade.executor.msg import (
27
27
  Ack,
@@ -31,6 +31,9 @@ from cascade.executor.msg import (
31
31
  DatasetPersistSuccess,
32
32
  DatasetPublished,
33
33
  DatasetPurge,
34
+ DatasetRetrieveCommand,
35
+ DatasetRetrieveFailure,
36
+ DatasetRetrieveSuccess,
34
37
  DatasetTransmitCommand,
35
38
  DatasetTransmitFailure,
36
39
  DatasetTransmitPayload,
@@ -63,7 +66,7 @@ class DataServer:
63
66
  self.cap = 2
64
67
  self.ds_proc_tp: PythonExecutor = ThreadPoolExecutor(max_workers=self.cap)
65
68
  self.futs_in_progress: dict[
66
- DatasetTransmitCommand | DatasetTransmitPayload | DatasetPersistCommand, Future
69
+ DatasetTransmitCommand | DatasetTransmitPayload | DatasetPersistCommand | DatasetRetrieveCommand, Future
67
70
  ] = {}
68
71
  self.awaiting_confirmation: dict[int, tuple[DatasetTransmitCommand, int]] = {}
69
72
  self.invalid: set[DatasetId] = (
@@ -97,7 +100,7 @@ class DataServer:
97
100
  return
98
101
  wait(self.futs_in_progress.values(), return_when=FIRST_COMPLETED)
99
102
 
100
- def store_payload(self, payload: DatasetTransmitPayload) -> int:
103
+ def _store_payload(self, payload: DatasetTransmitPayload) -> int:
101
104
  try:
102
105
  l = len(payload.value)
103
106
  try:
@@ -153,11 +156,30 @@ class DataServer:
153
156
  time_ns()
154
157
  ) # not actually consumed but uniform signature with send_payload simplifies typing
155
158
 
156
- def persist_payload(self, command: DatasetPersistCommand) -> int:
159
def _retrieve_dataset(self, command: DatasetRetrieveCommand) -> int:
    """Load `command.ds` from checkpoint storage into shared memory on this host.

    Submitted to the data server's thread pool. The outcome is reported via
    `callback` to `self.maddress`: `DatasetRetrieveSuccess` on success, or
    `DatasetRetrieveFailure` after logging the exception. Any retrieved buffer is
    closed before returning. Returns the completion time in nanoseconds.
    """
    retrieved: None | shm_client.AllocatedBuffer = None
    try:
        if command.target != self.host:
            raise ValueError(f"invalid host in {command=}")
        retrieved = retrieve_dataset(command)
        logger.debug(f"dataset for {command} retrieved")
        callback(self.maddress, DatasetRetrieveSuccess(host=self.host, ds=command.ds))
    except Exception as exc:
        logger.exception(f"failed to retrieve dataset for {command}, reporting up")
        failure = DatasetRetrieveFailure(host=self.host, detail=f"{command!r} -> {exc!r}")
        callback(self.maddress, failure)
    finally:
        if retrieved is not None:
            retrieved.close()
    return time_ns()
177
+
178
+ def _persist_dataset(self, command: DatasetPersistCommand) -> int:
157
179
  buf: None | shm_client.AllocatedBuffer = None
158
180
  try:
159
181
  if command.source != self.host:
160
- raise ValueError(f"invalid {command=}")
182
+ raise ValueError(f"invalid host in {command=}")
161
183
  buf = shm_client.get(key=ds2shmid(command.ds))
162
184
  persist_dataset(command, buf)
163
185
  logger.debug(f"dataset for {command} persisted")
@@ -173,7 +195,7 @@ class DataServer:
173
195
  buf.close()
174
196
  return time_ns()
175
197
 
176
- def send_payload(self, command: DatasetTransmitCommand) -> int:
198
+ def _send_payload(self, command: DatasetTransmitCommand) -> int:
177
199
  buf: None | shm_client.AllocatedBuffer = None
178
200
  payload: None | DatasetTransmitPayload = None
179
201
  try:
@@ -240,7 +262,7 @@ class DataServer:
240
262
  }
241
263
  )
242
264
  self.awaiting_confirmation[m.idx] = (m, -1)
243
- fut = self.ds_proc_tp.submit(self.send_payload, m)
265
+ fut = self.ds_proc_tp.submit(self._send_payload, m)
244
266
  self.futs_in_progress[m] = fut
245
267
  elif isinstance(m, DatasetPersistCommand):
246
268
  if m.ds in self.invalid:
@@ -248,7 +270,11 @@ class DataServer:
248
270
  f"unexpected persist command {m} as the dataset was already purged"
249
271
  )
250
272
  # TODO mark?
251
- fut = self.ds_proc_tp.submit(self.persist_payload, m)
273
+ fut = self.ds_proc_tp.submit(self._persist_dataset, m)
274
+ self.futs_in_progress[m] = fut
275
+ elif isinstance(m, DatasetRetrieveCommand):
276
+ # TODO mark?
277
+ fut = self.ds_proc_tp.submit(self._retrieve_dataset, m)
252
278
  self.futs_in_progress[m] = fut
253
279
  elif isinstance(m, DatasetTransmitPayload):
254
280
  if m.header.ds in self.invalid:
@@ -263,7 +289,7 @@ class DataServer:
263
289
  "target": self.host,
264
290
  }
265
291
  )
266
- fut = self.ds_proc_tp.submit(self.store_payload, m)
292
+ fut = self.ds_proc_tp.submit(self._store_payload, m)
267
293
  self.futs_in_progress[m] = fut
268
294
  elif isinstance(m, Ack):
269
295
  logger.debug(f"confirmed transmit {m.idx}")
@@ -273,7 +299,7 @@ class DataServer:
273
299
  # TODO submit this as a future? This actively blocks the whole server
274
300
  to_wait = []
275
301
  for commandProg, fut in self.futs_in_progress.items():
276
- if isinstance(commandProg, DatasetTransmitCommand|DatasetPersistCommand):
302
+ if isinstance(commandProg, DatasetTransmitCommand|DatasetPersistCommand|DatasetRetrieveCommand):
277
303
  val = commandProg.ds
278
304
  elif isinstance(commandProg, DatasetTransmitPayload):
279
305
  val = commandProg.header.ds
@@ -323,7 +349,7 @@ class DataServer:
323
349
  self.awaiting_confirmation.pop(e)
324
350
  else:
325
351
  logger.warning(f"submitting a retry of {command}")
326
- fut = self.ds_proc_tp.submit(self.send_payload, command)
352
+ fut = self.ds_proc_tp.submit(self._send_payload, command)
327
353
  self.futs_in_progress[command] = fut
328
354
  self.awaiting_confirmation[e] = (command, -1)
329
355
  except:
@@ -39,6 +39,8 @@ from cascade.executor.msg import (
39
39
  DatasetPersistSuccess,
40
40
  DatasetPublished,
41
41
  DatasetPurge,
42
+ DatasetRetrieveFailure,
43
+ DatasetRetrieveSuccess,
42
44
  DatasetTransmitFailure,
43
45
  ExecutorExit,
44
46
  ExecutorFailure,
@@ -60,6 +62,8 @@ from cascade.shm.server import entrypoint as shm_server
60
62
  logger = logging.getLogger(__name__)
61
63
  heartbeat_grace_ms = 2 * comms_default_timeout_ms
62
64
 
65
# Data-server messages that are relayed to the controller unchanged: the executor
# applies no extra logic to any of them.
JustForwardToController = (
    DatasetTransmitFailure
    | DatasetPersistSuccess
    | DatasetPersistFailure
    | DatasetRetrieveFailure
)
63
67
 
64
68
  def address_of(port: int) -> BackboneAddress:
65
69
  return f"tcp://{platform.get_bindabble_self()}:{port}"
@@ -308,7 +312,12 @@ class Executor:
308
312
  callback(worker_address(worker), m)
309
313
  self.datasets.add(m.ds)
310
314
  self.to_controller(m)
311
- elif isinstance(m, DatasetTransmitFailure|DatasetPersistSuccess|DatasetPersistFailure):
315
+ elif isinstance(m, DatasetRetrieveSuccess):
316
+ availability_notification = DatasetPublished(ds=m.ds, origin=self.host, transmit_idx=None)
317
+ for worker in self.workers:
318
+ callback(worker_address(worker), availability_notification)
319
+ self.to_controller(m)
320
+ elif isinstance(m, JustForwardToController):
312
321
  self.to_controller(m)
313
322
  else:
314
323
  # NOTE transmit and store are handled in DataServer (which has its own socket)
cascade/executor/msg.py CHANGED
@@ -138,6 +138,22 @@ class DatasetPersistSuccess:
138
138
  host: HostId
139
139
  ds: DatasetId
140
140
 
141
@dataclass(frozen=True)
class DatasetRetrieveCommand:
    """Instructs the data server on `target` to load `ds` from checkpoint storage.

    `retrieve_params` says where to read from, serialized per storage type. For
    "fs" it is the checkpoint directory that holds the dataset files.
    """

    target: HostId
    ds: DatasetId
    storage_type: CheckpointStorageType
    retrieve_params: str  # storage-type-specific serialization of params


@dataclass(frozen=True)
class DatasetRetrieveFailure:
    """Reports that `host` failed to retrieve a dataset; `detail` describes the command and the error."""

    host: HostId
    detail: str


@dataclass(frozen=True)
class DatasetRetrieveSuccess:
    """Reports that `ds` was loaded from checkpoint storage into shared memory on `host`."""

    host: HostId
    ds: DatasetId
141
157
 
142
158
  @dataclass(frozen=True)
143
159
  class ExecutorFailure:
@@ -197,6 +213,9 @@ Message = (
197
213
  | DatasetPersistCommand
198
214
  | DatasetPersistFailure
199
215
  | DatasetPersistSuccess
216
+ | DatasetRetrieveCommand
217
+ | DatasetRetrieveFailure
218
+ | DatasetRetrieveSuccess
200
219
  | ExecutorFailure
201
220
  | ExecutorExit
202
221
  | ExecutorRegistration