earthkit-workflows 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,173 @@
1
+ # (C) Copyright 2025- ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+
9
+ """Module for simplifying writing tests
10
+
11
+ Similar to util, but not enough to unify
12
+
13
+ It is capable, for a single given task, to spin an shm server, put all task's inputs into it, execute the task, store outputs in memory, and retrieve the result.
14
+ See the `demo()` function at the very end
15
+ """
16
+
17
+ import logging
18
+ from contextlib import contextmanager
19
+ from dataclasses import dataclass
20
+ from time import perf_counter_ns
21
+ from typing import Any, Callable
22
+
23
+ import cloudpickle
24
+
25
+ import cascade.executor.platform as platform
26
+ import cascade.shm.api as shm_api
27
+ import cascade.shm.client as shm_client
28
+ from cascade.executor.comms import Listener as ZmqListener
29
+ from cascade.executor.config import logging_config
30
+ from cascade.executor.msg import BackboneAddress, DatasetPublished
31
+ from cascade.executor.runner.memory import Memory, ds2shmid
32
+ from cascade.executor.runner.packages import PackagesEnv
33
+ from cascade.executor.runner.runner import ExecutionContext, run
34
+ from cascade.low.builders import TaskBuilder
35
+ from cascade.low.core import DatasetId
36
+ from cascade.shm.server import entrypoint as shm_server
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ @contextmanager
42
+ def setup_shm(testId: str):
43
+ mp_ctx = platform.get_mp_ctx("executor-aux")
44
+ shm_socket = f"/tmp/tcShm-{testId}"
45
+ shm_api.publish_socket_addr(shm_socket)
46
+ shm_process = mp_ctx.Process(
47
+ target=shm_server,
48
+ kwargs={
49
+ "logging_config": logging_config,
50
+ "shm_pref": f"tc{testId}",
51
+ },
52
+ )
53
+ shm_process.start()
54
+ shm_client.ensure()
55
+ try:
56
+ yield
57
+ except Exception as e:
58
+ # NOTE we log like this in case shm shutdown freezes
59
+ logger.exception(f"gotten {repr(e)}, proceed with shm shutdown")
60
+ raise
61
+ finally:
62
+ shm_client.shutdown(timeout_sec=1.0)
63
+ shm_process.join(1)
64
+ if shm_process.is_alive():
65
+ shm_process.terminate()
66
+ shm_process.join(1)
67
+ if shm_process.is_alive():
68
+ shm_process.kill()
69
+ shm_process.join(1)
70
+
71
+
72
+ def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext):
73
+ tasks = list(executionContext.tasks.keys())
74
+ if len(tasks) != 1:
75
+ raise ValueError(f"expected 1 task, gotten {len(tasks)}")
76
+ taskId = tasks[0]
77
+ taskInstance = executionContext.tasks[taskId]
78
+ with Memory(callback, "testWorker") as memory, PackagesEnv() as pckg:
79
+ # for key, value in taskSequence.extra_env.items():
80
+ # os.environ[key] = value
81
+
82
+ pckg.extend(taskInstance.definition.environment)
83
+ run(taskId, executionContext, memory)
84
+ memory.flush()
85
+
86
+
87
+ @dataclass
88
+ class CallableInstance:
89
+ func: Callable
90
+ kwargs: dict[str, Any]
91
+ args: list[tuple[int, Any]]
92
+ env: list[str]
93
+ exp_output: Any
94
+
95
+
96
+ def callable2ctx(
97
+ callableInstance: CallableInstance, callback: BackboneAddress
98
+ ) -> ExecutionContext:
99
+ taskInstance = TaskBuilder.from_callable(
100
+ callableInstance.func, callableInstance.env
101
+ )
102
+ param_source = {}
103
+ params = [
104
+ (key, DatasetId("taskId", f"kwarg.{key}"), value)
105
+ for key, value in callableInstance.kwargs.items()
106
+ ] + [
107
+ (key, DatasetId("taskId", f"pos.{key}"), value)
108
+ for key, value in callableInstance.args
109
+ ]
110
+ for key, ds_key, value in params:
111
+ raw = cloudpickle.dumps(value)
112
+ L = len(raw)
113
+ buf = shm_client.allocate(ds2shmid(ds_key), L, "cloudpickle.loads")
114
+ buf.view()[:L] = raw
115
+ buf.close()
116
+ param_source[key] = (ds_key, "Any")
117
+
118
+ return ExecutionContext(
119
+ tasks={"taskId": taskInstance},
120
+ param_source={"taskId": param_source},
121
+ callback=callback,
122
+ publish={
123
+ DatasetId("taskId", output)
124
+ for output, _ in taskInstance.definition.output_schema
125
+ },
126
+ )
127
+
128
+
129
+ def run_test(
130
+ callableInstance: CallableInstance, testId: str, max_runtime_sec: int
131
+ ) -> Any:
132
+ with setup_shm(testId):
133
+ addr = f"ipc:///tmp/tc{testId}"
134
+ listener = ZmqListener(addr)
135
+ ec_ctx = callable2ctx(callableInstance, addr)
136
+ mp_ctx = platform.get_mp_ctx("executor-aux")
137
+ runner = mp_ctx.Process(target=simple_runner, args=(addr, ec_ctx))
138
+ runner.start()
139
+ output = DatasetId("taskId", "0")
140
+
141
+ end = perf_counter_ns() + max_runtime_sec * int(1e9)
142
+ while perf_counter_ns() < end:
143
+ mess = listener.recv_messages()
144
+ if mess == [
145
+ DatasetPublished(origin="testWorker", ds=output, transmit_idx=None)
146
+ ]:
147
+ break
148
+ elif not mess:
149
+ continue
150
+ else:
151
+ raise ValueError(mess)
152
+
153
+ runner.join()
154
+ output_buf = shm_client.get(ds2shmid(output))
155
+ output_des = cloudpickle.loads(output_buf.view())
156
+ output_buf.close()
157
+ assert output_des == callableInstance.exp_output
158
+
159
+
160
+ def demo():
161
+ def myfunc(l: int) -> float:
162
+ import numpy as np
163
+
164
+ return np.arange(l).sum()
165
+
166
+ ci = CallableInstance(
167
+ func=myfunc, kwargs={"l": 4}, args=[], env=["numpy"], exp_output=6
168
+ )
169
+ run_test(ci, "numpyTest1", 2)
170
+
171
+
172
+ if __name__ == "__main__":
173
+ demo()
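
The new testing helper above is exercised by `demo()`; for reference, a second minimal usage sketch. The file's path is not shown in this hunk, so the import below is a placeholder, and the wiring simply mirrors `demo()`:

```python
# Placeholder import path -- the hunk does not name the new file.
from cascade.executor.runner.testing import CallableInstance, run_test

def add_pair(a: int, b: int) -> int:
    return a + b

ci = CallableInstance(
    func=add_pair,
    kwargs={"a": 2, "b": 3},
    args=[],        # positional args would be passed as (position, value) pairs
    env=[],         # extra packages to install for the task, if any
    exp_output=5,   # run_test asserts the task's output equals this
)
run_test(ci, "addPairTest1", max_runtime_sec=2)
```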
@@ -17,6 +17,7 @@ import subprocess
17
17
  import sys
18
18
  from concurrent.futures import ThreadPoolExecutor
19
19
  from time import perf_counter_ns
20
+ from typing import Any
20
21
 
21
22
  import orjson
22
23
 
@@ -28,7 +29,7 @@ from cascade.executor.comms import callback
28
29
  from cascade.executor.config import logging_config, logging_config_filehandler
29
30
  from cascade.executor.executor import Executor
30
31
  from cascade.executor.msg import BackboneAddress, ExecutorShutdown
31
- from cascade.low.core import JobInstance
32
+ from cascade.low.core import DatasetId, JobInstance
32
33
  from cascade.low.func import msum
33
34
  from cascade.scheduler.precompute import precompute
34
35
  from earthkit.workflows.graph import Graph, deduplicate_nodes
@@ -159,7 +160,7 @@ def run_locally(
159
160
  portBase: int = 12345,
160
161
  log_base: str | None = None,
161
162
  report_address: str | None = None,
162
- ) -> None:
163
+ ) -> dict[DatasetId, Any]:
163
164
  if log_base is not None:
164
165
  log_path = f"{log_base}.controller.txt"
165
166
  logging.config.dictConfig(logging_config_filehandler(log_path))
@@ -216,6 +217,7 @@ def run_locally(
216
217
  if os.environ.get("CASCADE_DEBUG_PRINT"):
217
218
  for key, value in result.outputs.items():
218
219
  print(f"{key} => {value}")
220
+ return result.outputs
219
221
  except Exception:
220
222
  # NOTE we log this to get the stacktrace into the logfile
221
223
  logger.exception("controller failure, proceed with executor shutdown")
@@ -27,6 +27,7 @@ JobId = str
27
27
 
28
28
  @dataclass
29
29
  class JobProgress:
30
+ started: bool
30
31
  completed: bool
31
32
  pct: (
32
33
  str | None
@@ -35,19 +36,20 @@ class JobProgress:
35
36
 
36
37
  @classmethod
37
38
  def failed(cls, failure: str) -> Self:
38
- return cls(True, None, failure)
39
+ return cls(True, True, None, failure)
39
40
 
40
41
  @classmethod
41
42
  def progressed(cls, pct: float) -> Self:
42
43
  progress = "{:.2%}".format(pct)[:-1]
43
- return cls(False, progress, None)
44
+ return cls(True, False, progress, None)
44
45
 
45
46
  @classmethod
46
47
  def succeeded(cls) -> Self:
47
- return cls(True, None, None)
48
+ return cls(True, True, None, None)
48
49
 
49
50
 
50
- JobProgressStarted = JobProgress(False, "0.00", None)
51
+ JobProgressStarted = JobProgress(True, False, "0.00", None)
52
+ JobProgressEnqueued = JobProgress(False, False, None, None)
51
53
 
52
54
 
53
55
  @dataclass
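
With the new `started` flag and the `JobProgressEnqueued` sentinel, the job lifecycle reported by the gateway can be reconstructed from the constructors in this hunk:

```python
from cascade.controller.report import (
    JobProgress,
    JobProgressEnqueued,
    JobProgressStarted,
)

# queued but not yet spawned: neither started nor completed
assert not JobProgressEnqueued.started and not JobProgressEnqueued.completed

# spawned: started, reported at 0%
assert JobProgressStarted.started and JobProgressStarted.pct == "0.00"

# mid-run and terminal states via the classmethod constructors
running = JobProgress.progressed(0.5)     # started=True, completed=False, pct="50.00"
done = JobProgress.succeeded()            # started=True, completed=True
failed = JobProgress.failed("traceback")  # completed=True with a failure message
```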
@@ -20,7 +20,6 @@ from concurrent.futures import Executor as PythonExecutor
20
20
  from concurrent.futures import Future, ThreadPoolExecutor, wait
21
21
  from time import time_ns
22
22
 
23
- import cascade.shm.api as shm_api
24
23
  import cascade.shm.client as shm_client
25
24
  from cascade.executor.comms import Listener, callback, send_data
26
25
  from cascade.executor.msg import (
@@ -48,7 +47,6 @@ class DataServer:
48
47
  maddress: BackboneAddress,
49
48
  daddress: BackboneAddress,
50
49
  host: str,
51
- shm_port: int,
52
50
  logging_config: dict,
53
51
  ):
54
52
  logging.config.dictConfig(logging_config)
@@ -58,7 +56,6 @@ class DataServer:
58
56
  self.daddress = daddress
59
57
  self.dlistener = Listener(daddress)
60
58
  self.terminating = False
61
- shm_api.publish_client_port(shm_port)
62
59
  self.cap = 2
63
60
  self.ds_proc_tp: PythonExecutor = ThreadPoolExecutor(max_workers=self.cap)
64
61
  self.futs_in_progress: dict[
@@ -305,8 +302,7 @@ def start_data_server(
305
302
  maddress: BackboneAddress,
306
303
  daddress: BackboneAddress,
307
304
  host: str,
308
- shm_port: int,
309
305
  logging_config: dict,
310
306
  ):
311
- server = DataServer(maddress, daddress, host, shm_port, logging_config)
307
+ server = DataServer(maddress, daddress, host, logging_config)
312
308
  server.recv_loop()
@@ -18,6 +18,7 @@ the tasks themselves.
18
18
  import atexit
19
19
  import logging
20
20
  import os
21
+ import uuid
21
22
  from multiprocessing.process import BaseProcess
22
23
  from typing import Iterable
23
24
 
@@ -94,8 +95,8 @@ class Executor:
94
95
  self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
95
96
  self.sender.add_host("controller", controller_address)
96
97
  # TODO make the shm server params configurable
97
- shm_port = portBase + 2
98
- shm_api.publish_client_port(shm_port)
98
+ shm_port = f"/tmp/cascShmSock-{uuid.uuid4()}" # portBase + 2
99
+ shm_api.publish_socket_addr(shm_port)
99
100
  ctx = platform.get_mp_ctx("executor-aux")
100
101
  if log_base:
101
102
  shm_logging = logging_config_filehandler(f"{log_base}.shm.txt")
@@ -104,12 +105,11 @@ class Executor:
104
105
  logger.debug("about to start an shm process")
105
106
  self.shm_process = ctx.Process(
106
107
  target=shm_server,
107
- args=(
108
- shm_port,
109
- shm_vol_gb * (1024**3) if shm_vol_gb else None,
110
- shm_logging,
111
- f"sCasc{host}",
112
- ),
108
+ kwargs={
109
+ "capacity": shm_vol_gb * (1024**3) if shm_vol_gb else None,
110
+ "logging_config": shm_logging,
111
+ "shm_pref": f"sCasc{host}",
112
+ },
113
113
  )
114
114
  self.shm_process.start()
115
115
  self.daddress = address_of(portBase + 1)
@@ -124,7 +124,6 @@ class Executor:
124
124
  self.mlistener.address,
125
125
  self.daddress,
126
126
  self.host,
127
- shm_port,
128
127
  dsr_logging,
129
128
  ),
130
129
  )
@@ -12,6 +12,7 @@ Interaction with shm
12
12
 
13
13
  import hashlib
14
14
  import logging
15
+ import sys
15
16
  from contextlib import AbstractContextManager
16
17
  from typing import Any, Literal
17
18
 
@@ -119,27 +120,29 @@ class Memory(AbstractContextManager):
119
120
  self.local.pop(inputId)
120
121
 
121
122
  # NOTE poor man's gpu mem management -- currently torch only. Given the task sequence limitation,
122
- # this may not be the best place to invoke. Additionally, we may want to check first whether
123
- # the worker is gpu aware, etc
124
- try:
125
- import torch
126
-
127
- if torch.cuda.is_available():
128
- free, total = torch.cuda.mem_get_info()
129
- logger.debug(f"cuda mem avail: {free/total:.2%}")
130
- if free / total < 0.8:
131
- torch.cuda.empty_cache()
123
+ # this may not be the best place to invoke.
124
+ if (
125
+ "torch" in sys.modules
126
+ ): # if no task on this worker imported torch, no need to flush
127
+ try:
128
+ import torch
129
+
130
+ if torch.cuda.is_available():
132
131
  free, total = torch.cuda.mem_get_info()
133
- logger.debug(f"cuda mem avail post cache empty: {free/total:.2%}")
132
+ logger.debug(f"cuda mem avail: {free/total:.2%}")
134
133
  if free / total < 0.8:
135
- # NOTE this ofc makes low sense if there is any other application (like browser or ollama)
136
- # that the user may be running
137
- logger.warning("cuda mem avail low despite cache empty!")
138
- logger.debug(torch.cuda.memory_summary())
139
- except ImportError:
140
- return
141
- except Exception:
142
- logger.exception("failed to free cuda cache")
134
+ torch.cuda.empty_cache()
135
+ free, total = torch.cuda.mem_get_info()
136
+ logger.debug(
137
+ f"cuda mem avail post cache empty: {free/total:.2%}"
138
+ )
139
+ if free / total < 0.8:
140
+ # NOTE this ofc makes low sense if there is any other application (like browser or ollama)
141
+ # that the user may be running
142
+ logger.warning("cuda mem avail low despite cache empty!")
143
+ logger.debug(torch.cuda.memory_summary())
144
+ except Exception:
145
+ logger.exception("failed to free cuda cache")
143
146
 
144
147
  def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
145
148
  # this is required so that the Shm can be properly freed, otherwise you get 'pointers cannot be closed'
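
The rewritten cleanup only touches torch when some task in this worker process has already imported it. A standalone sketch of that guard (illustrative, not part of the package):

```python
import sys

def maybe_release_cuda_cache() -> None:
    """Flush torch's CUDA cache only if torch is already loaded in this process.

    Checking sys.modules first means workers that never used torch pay nothing:
    no import cost, no CUDA initialisation. Sketch of the guard used above.
    """
    if "torch" not in sys.modules:
        return
    import torch  # already imported elsewhere, so this just binds the name

    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info()
        if free / total < 0.8:
            torch.cuda.empty_cache()
```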
@@ -12,6 +12,7 @@ Note that venv itself is left untouched after the run finishes -- we extend sys
12
12
  with a temporary directory and install in there
13
13
  """
14
14
 
15
+ import importlib
15
16
  import logging
16
17
  import os
17
18
  import site
@@ -24,6 +25,51 @@ from typing import Literal
24
25
  logger = logging.getLogger(__name__)
25
26
 
26
27
 
28
+ class Commands:
29
+ venv_command = lambda name: ["uv", "venv", name]
30
+ install_command = lambda name: [
31
+ "uv",
32
+ "pip",
33
+ "install",
34
+ "--prefix",
35
+ name,
36
+ "--prerelease",
37
+ "explicit",
38
+ ]
39
+
40
+
41
+ def run_command(command: list[str]) -> None:
42
+ try:
43
+ result = subprocess.run(command, check=False, capture_output=True)
44
+ except FileNotFoundError as ex:
45
+ raise ValueError(f"command failure: {ex}")
46
+ if result.returncode != 0:
47
+ msg = f"command failed with {result.returncode}. Stderr: {result.stderr}, Stdout: {result.stdout}, Args: {result.args}"
48
+ logger.error(msg)
49
+ raise ValueError(msg)
50
+
51
+
52
+ def new_venv() -> tempfile.TemporaryDirectory:
53
+ """1. Creates a new temporary directory with a venv inside.
54
+ 2. Extends sys.path so that packages in that venv can be imported.
55
+ """
56
+ logger.debug("creating a new venv")
57
+ td = tempfile.TemporaryDirectory(prefix="cascade_runner_venv_")
58
+ # NOTE we create a venv instead of just plain directory, because some of the packages create files
59
+ # outside of site-packages. Thus we then install with --prefix, not with --target
60
+ run_command(Commands.venv_command(td.name))
61
+
62
+ # NOTE not sure if getsitepackages was intended for this -- if issues, attempt replacing
63
+ # with something like f"{td.name}/lib/python*/site-packages" + globbing
64
+ extra_sp = site.getsitepackages(prefixes=[td.name])
65
+ # NOTE this makes the explicit packages go first, in case of a different version
66
+ logger.debug(f"extending sys.path with {extra_sp}")
67
+ sys.path = extra_sp + sys.path
68
+ logger.debug(f"new sys.path: {sys.path}")
69
+
70
+ return td
71
+
72
+
27
73
  class PackagesEnv(AbstractContextManager):
28
74
  def __init__(self) -> None:
29
75
  self.td: tempfile.TemporaryDirectory | None = None
@@ -32,38 +78,32 @@ class PackagesEnv(AbstractContextManager):
32
78
  if not packages:
33
79
  return
34
80
  if self.td is None:
35
- logger.debug("creating a new venv")
36
- self.td = tempfile.TemporaryDirectory()
37
- venv_command = ["uv", "venv", self.td.name]
38
- # NOTE we create a venv instead of just plain directory, because some of the packages create files
39
- # outside of site-packages. Thus we then install with --prefix, not with --target
40
- subprocess.run(venv_command, check=True)
81
+ self.td = new_venv()
41
82
 
42
83
  logger.debug(
43
84
  f"installing {len(packages)} packages: {','.join(packages[:3])}{',...' if len(packages) > 3 else ''}"
44
85
  )
45
- install_command = [
46
- "uv",
47
- "pip",
48
- "install",
49
- "--prefix",
50
- self.td.name,
51
- "--prerelease",
52
- "allow",
53
- ]
86
+ install_command = Commands.install_command(self.td.name)
54
87
  if os.environ.get("VENV_OFFLINE", "") == "YES":
55
88
  install_command += ["--offline"]
56
89
  if cache_dir := os.environ.get("VENV_CACHE", ""):
57
90
  install_command += ["--cache-dir", cache_dir]
58
91
  install_command.extend(set(packages))
59
- subprocess.run(install_command, check=True)
60
- # NOTE not sure if getsitepackages was intended for this -- if issues, attempt replacing
61
- # with something like f"{self.td.name}/lib/python*/site-packages" + globbing
62
- extra_sp = site.getsitepackages(prefixes=[self.td.name])
63
- # NOTE this makes the explicit packages go first, in case of a different version
64
- sys.path = extra_sp + sys.path
92
+ logger.debug(f"running install command: {' '.join(install_command)}")
93
+ run_command(install_command)
94
+
95
+ # NOTE we need this due to namespace packages:
96
+ # 1. task 1 installs ns.pkg1 in its venv
97
+ # 2. task 1 finishes, task 2 starts on the same worker
98
+ # 3. task 2 installs ns.pkg2. However, importlib's caches leave ns aware of only the pkg1 submodule
99
+ # Additionally, the caches are invalid anyway, because task 1's venv is already deleted
100
+ importlib.invalidate_caches()
101
+ # TODO some namespace packages may require a reimport because they dynamically build `__all__` -- eg earthkit
65
102
 
66
103
  def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
104
+ sys.path = [
105
+ p for p in sys.path if self.td is None or not p.startswith(self.td.name)
106
+ ]
67
107
  if self.td is not None:
68
108
  self.td.cleanup()
69
109
  return False
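
Taken together, `new_venv` and the reworked `extend`/`__exit__` give each task a throwaway uv-managed environment. A minimal usage sketch (assumes `uv` is on PATH; the package name is illustrative):

```python
from cascade.executor.runner.packages import PackagesEnv

with PackagesEnv() as pckg:
    pckg.extend(["orjson"])  # lazily creates the venv, installs via `uv pip install --prefix ...`
    import orjson            # resolvable through the prepended site-packages entry

    print(orjson.dumps({"ok": True}))
# on exit, the venv's entries are stripped from sys.path and the directory is removed
```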
@@ -15,14 +15,17 @@ from cascade.gateway.server import serve
15
15
 
16
16
 
17
17
  def main(
18
- url: str, log_base: str | None = None, troika_config: str | None = None
18
+ url: str,
19
+ log_base: str | None = None,
20
+ troika_config: str | None = None,
21
+ max_jobs: int | None = None,
19
22
  ) -> None:
20
23
  if log_base:
21
24
  log_path = f"{log_base}/gateway.txt"
22
25
  logging.config.dictConfig(logging_config_filehandler(log_path))
23
26
  else:
24
27
  logging.config.dictConfig(logging_config)
25
- serve(url, log_base, troika_config)
28
+ serve(url, log_base, troika_config, max_jobs)
26
29
 
27
30
 
28
31
  if __name__ == "__main__":
cascade/gateway/api.py CHANGED
@@ -62,8 +62,9 @@ class JobProgressRequest(CascadeGatewayAPI):
62
62
 
63
63
 
64
64
  class JobProgressResponse(CascadeGatewayAPI):
65
- progresses: dict[JobId, JobProgress]
65
+ progresses: dict[JobId, JobProgress | None]
66
66
  datasets: dict[JobId, list[DatasetId]]
67
+ queue_length: int
67
68
  error: str | None # top level error
68
69
 
69
70
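
Since `progresses` values may now be `None` (unknown job id) and a `queue_length` is reported, a client-side sketch of interpreting the response (how the response is transported is out of scope here):

```python
from cascade.gateway.api import JobProgressResponse

def summarize(resp: JobProgressResponse) -> None:
    print(f"{resp.queue_length} job(s) waiting in the gateway queue")
    for job_id, progress in resp.progresses.items():
        if progress is None:
            state = "unknown job id"
        elif not progress.started:
            state = "enqueued"
        elif progress.failure is not None:
            state = f"failed: {progress.failure}"
        elif progress.completed:
            state = "succeeded"
        else:
            state = f"running ({progress.pct}%)"
        datasets = resp.datasets.get(job_id, [])
        print(f"{job_id}: {state}, {len(datasets)} dataset(s) available")
```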
 
cascade/gateway/router.py CHANGED
@@ -15,6 +15,7 @@ import os
15
15
  import stat
16
16
  import subprocess
17
17
  import uuid
18
+ from collections import OrderedDict
18
19
  from dataclasses import dataclass
19
20
  from typing import Iterable
20
21
 
@@ -22,7 +23,12 @@ import orjson
22
23
  import zmq
23
24
 
24
25
  import cascade.executor.platform as platform
25
- from cascade.controller.report import JobId, JobProgress, JobProgressStarted
26
+ from cascade.controller.report import (
27
+ JobId,
28
+ JobProgress,
29
+ JobProgressEnqueued,
30
+ JobProgressStarted,
31
+ )
26
32
  from cascade.executor.comms import get_context
27
33
  from cascade.gateway.api import JobSpec, TroikaSpec
28
34
  from cascade.low.core import DatasetId
@@ -202,16 +208,29 @@ def _spawn_subprocess(
202
208
 
203
209
  class JobRouter:
204
210
  def __init__(
205
- self, poller: zmq.Poller, log_base: str | None, troika_config: str | None
211
+ self,
212
+ poller: zmq.Poller,
213
+ log_base: str | None,
214
+ troika_config: str | None,
215
+ max_jobs: int | None,
206
216
  ):
207
217
  self.poller = poller
208
218
  self.jobs: dict[str, Job] = {}
219
+ self.active_jobs = 0
220
+ self.max_jobs = max_jobs
221
+ self.jobs_queue: OrderedDict[JobId, JobSpec] = OrderedDict()
209
222
  self.procs: dict[str, subprocess.Popen] = {}
210
223
  self.log_base = log_base
211
224
  self.troika_config = troika_config
212
225
 
213
- def spawn_job(self, job_spec: JobSpec) -> JobId:
214
- job_id = next_uuid(self.jobs.keys(), lambda: str(uuid.uuid4()))
226
+ def maybe_spawn(self) -> None:
227
+ if not self.jobs_queue:
228
+ return
229
+ if self.max_jobs is not None and self.active_jobs >= self.max_jobs:
230
+ logger.debug(f"already running {self.active_jobs}, no spawn")
231
+ return
232
+
233
+ job_id, job_spec = self.jobs_queue.popitem(False)
215
234
  base_addr = f"tcp://{platform.get_bindabble_self()}"
216
235
  socket = get_context().socket(zmq.PULL)
217
236
  port = socket.bind_to_random_port(base_addr)
@@ -222,18 +241,37 @@ class JobRouter:
222
241
  self.procs[job_id] = _spawn_subprocess(
223
242
  job_spec, full_addr, job_id, self.log_base, self.troika_config
224
243
  )
244
+ self.active_jobs += 1
245
+ return job_id
246
+
247
+ def enqueue_job(self, job_spec: JobSpec) -> JobId:
248
+ job_id = next_uuid(
249
+ set(self.jobs.keys()).union(self.jobs_queue.keys()),
250
+ lambda: str(uuid.uuid4()),
251
+ )
252
+ self.jobs_queue[job_id] = job_spec
253
+ self.maybe_spawn()
225
254
  return job_id
226
255
 
227
256
  def progress_of(
228
257
  self, job_ids: Iterable[JobId]
229
- ) -> tuple[dict[JobId, JobProgress], dict[JobId, list[DatasetId]]]:
258
+ ) -> tuple[dict[JobId, JobProgress], dict[JobId, list[DatasetId]], int]:
230
259
  if not job_ids:
231
- job_ids = self.jobs.keys()
232
- progresses = {job_id: self.jobs[job_id].progress for job_id in job_ids}
260
+ job_ids = set(self.jobs.keys()).union(self.jobs_queue.keys())
261
+ progresses = {}
262
+ for job_id in job_ids:
263
+ if job_id in self.jobs:
264
+ progresses[job_id] = self.jobs[job_id].progress
265
+ elif job_id in self.jobs_queue:
266
+ progresses[job_id] = JobProgressEnqueued
267
+ else:
268
+ progresses[job_id] = None
233
269
  datasets = {
234
- job_id: list(self.jobs[job_id].results.keys()) for job_id in job_ids
270
+ job_id: list(self.jobs[job_id].results.keys())
271
+ for job_id in job_ids
272
+ if job_id in self.jobs
235
273
  }
236
- return progresses, datasets
274
+ return progresses, datasets, len(self.jobs_queue)
237
275
 
238
276
  def get_result(self, job_id: JobId, dataset_id: DatasetId) -> bytes:
239
277
  return self.jobs[job_id].results[dataset_id]
@@ -246,6 +284,8 @@ class JobRouter:
246
284
  job = self.jobs[job_id]
247
285
  if progress.completed:
248
286
  self.poller.unregister(job.socket)
287
+ self.active_jobs -= 1
288
+ self.maybe_spawn()
249
289
  if progress.failure is not None and job.progress.failure is None:
250
290
  job.progress = progress
251
291
  elif job.last_seen >= timestamp or job.progress.failure is not None:
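
The router changes amount to a bounded-concurrency FIFO: submissions land in an ordered queue, `maybe_spawn` starts jobs only while fewer than `max_jobs` are active, and each completion frees a slot. A generic sketch of that pattern (not the package's code):

```python
from collections import OrderedDict
from typing import Any

class BoundedSpawner:
    """Sketch of the queueing discipline JobRouter adopts above."""

    def __init__(self, max_jobs: int | None):
        self.max_jobs = max_jobs
        self.active = 0
        self.queue: OrderedDict[str, Any] = OrderedDict()

    def enqueue(self, job_id: str, spec: Any) -> None:
        self.queue[job_id] = spec
        self.maybe_spawn()

    def maybe_spawn(self) -> None:
        if not self.queue:
            return
        if self.max_jobs is not None and self.active >= self.max_jobs:
            return  # every slot is taken; wait for a completion
        job_id, spec = self.queue.popitem(last=False)  # FIFO order
        self.active += 1
        self._start(job_id, spec)

    def on_completed(self, job_id: str) -> None:
        self.active -= 1
        self.maybe_spawn()  # a slot freed up, try the next queued job

    def _start(self, job_id: str, spec: Any) -> None:
        ...  # spawn the subprocess, register its socket with the poller, etc.
```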
cascade/gateway/server.py CHANGED
@@ -31,16 +31,19 @@ def handle_fe(socket: zmq.Socket, jobs: JobRouter) -> bool:
31
31
  rv: api.CascadeGatewayAPI
32
32
  if isinstance(m, api.SubmitJobRequest):
33
33
  try:
34
- job_id = jobs.spawn_job(m.job)
34
+ job_id = jobs.enqueue_job(m.job)
35
35
  rv = api.SubmitJobResponse(job_id=job_id, error=None)
36
36
  except Exception as e:
37
37
  logger.exception(f"failed to spawn a job: {m}")
38
38
  rv = api.SubmitJobResponse(job_id=None, error=repr(e))
39
39
  elif isinstance(m, api.JobProgressRequest):
40
40
  try:
41
- progresses, datasets = jobs.progress_of(m.job_ids)
41
+ progresses, datasets, queue_length = jobs.progress_of(m.job_ids)
42
42
  rv = api.JobProgressResponse(
43
- progresses=progresses, datasets=datasets, error=None
43
+ progresses=progresses,
44
+ datasets=datasets,
45
+ error=None,
46
+ queue_length=queue_length,
44
47
  )
45
48
  except Exception as e:
46
49
  logger.exception(f"failed to get progress of: {m}")
@@ -80,7 +83,10 @@ def handle_controller(socket: zmq.Socket, jobs: JobRouter) -> None:
80
83
 
81
84
 
82
85
  def serve(
83
- url: str, log_base: str | None = None, troika_config: str | None = None
86
+ url: str,
87
+ log_base: str | None = None,
88
+ troika_config: str | None = None,
89
+ max_jobs: int | None = None,
84
90
  ) -> None:
85
91
  ctx = get_context()
86
92
  poller = zmq.Poller()
@@ -88,7 +94,7 @@ def serve(
88
94
  fe = ctx.socket(zmq.REP)
89
95
  fe.bind(url)
90
96
  poller.register(fe, flags=zmq.POLLIN)
91
- jobs = JobRouter(poller, log_base, troika_config)
97
+ jobs = JobRouter(poller, log_base, troika_config, max_jobs)
92
98
 
93
99
  logger.debug("entering recv loop")
94
100
  is_break = False
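
End-to-end, the new `max_jobs` knob flows from `main` through `serve` into `JobRouter`. A hedged launch sketch (the bind URL is illustrative):

```python
from cascade.gateway.server import serve

# Keep at most two controller subprocesses alive at once; further submissions
# wait in the FIFO queue and are reported as enqueued until a slot frees up.
serve("tcp://0.0.0.0:5555", log_base=None, troika_config=None, max_jobs=2)
```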