earthkit-workflows 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cascade/benchmarks/tests.py +173 -0
- cascade/benchmarks/util.py +4 -2
- cascade/controller/report.py +6 -4
- cascade/executor/data_server.py +1 -5
- cascade/executor/executor.py +8 -9
- cascade/executor/runner/memory.py +22 -19
- cascade/executor/runner/packages.py +61 -21
- cascade/gateway/__main__.py +5 -2
- cascade/gateway/api.py +2 -1
- cascade/gateway/router.py +49 -9
- cascade/gateway/server.py +11 -5
- cascade/low/builders.py +41 -4
- cascade/shm/api.py +53 -12
- cascade/shm/client.py +29 -16
- cascade/shm/dataset.py +15 -1
- cascade/shm/server.py +28 -15
- earthkit/workflows/_version.py +1 -1
- earthkit/workflows/decorators.py +23 -9
- earthkit/workflows/fluent.py +25 -17
- earthkit/workflows/mark.py +3 -4
- {earthkit_workflows-0.4.6.dist-info → earthkit_workflows-0.5.0.dist-info}/METADATA +1 -1
- {earthkit_workflows-0.4.6.dist-info → earthkit_workflows-0.5.0.dist-info}/RECORD +25 -25
- earthkit/workflows/py.typed +0 -0
- {earthkit_workflows-0.4.6.dist-info → earthkit_workflows-0.5.0.dist-info}/WHEEL +0 -0
- {earthkit_workflows-0.4.6.dist-info → earthkit_workflows-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {earthkit_workflows-0.4.6.dist-info → earthkit_workflows-0.5.0.dist-info}/top_level.txt +0 -0
cascade/benchmarks/tests.py ADDED

```diff
@@ -0,0 +1,173 @@
+# (C) Copyright 2025- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""Module for simplifying writing tests
+
+Similar to util, but not enough to unify
+
+It is capable, for a single given task, to spin an shm server, put all task's inputs into it, execute the task, store outputs in memory, and retrieve the result.
+See the `demo()` function at the very end
+"""
+
+import logging
+from contextlib import contextmanager
+from dataclasses import dataclass
+from time import perf_counter_ns
+from typing import Any, Callable
+
+import cloudpickle
+
+import cascade.executor.platform as platform
+import cascade.shm.api as shm_api
+import cascade.shm.client as shm_client
+from cascade.executor.comms import Listener as ZmqListener
+from cascade.executor.config import logging_config
+from cascade.executor.msg import BackboneAddress, DatasetPublished
+from cascade.executor.runner.memory import Memory, ds2shmid
+from cascade.executor.runner.packages import PackagesEnv
+from cascade.executor.runner.runner import ExecutionContext, run
+from cascade.low.builders import TaskBuilder
+from cascade.low.core import DatasetId
+from cascade.shm.server import entrypoint as shm_server
+
+logger = logging.getLogger(__name__)
+
+
+@contextmanager
+def setup_shm(testId: str):
+    mp_ctx = platform.get_mp_ctx("executor-aux")
+    shm_socket = f"/tmp/tcShm-{testId}"
+    shm_api.publish_socket_addr(shm_socket)
+    shm_process = mp_ctx.Process(
+        target=shm_server,
+        kwargs={
+            "logging_config": logging_config,
+            "shm_pref": f"tc{testId}",
+        },
+    )
+    shm_process.start()
+    shm_client.ensure()
+    try:
+        yield
+    except Exception as e:
+        # NOTE we log like this in case shm shutdown freezes
+        logger.exception(f"gotten {repr(e)}, proceed with shm shutdown")
+        raise
+    finally:
+        shm_client.shutdown(timeout_sec=1.0)
+        shm_process.join(1)
+        if shm_process.is_alive():
+            shm_process.terminate()
+            shm_process.join(1)
+        if shm_process.is_alive():
+            shm_process.kill()
+            shm_process.join(1)
+
+
+def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext):
+    tasks = list(executionContext.tasks.keys())
+    if len(tasks) != 1:
+        raise ValueError(f"expected 1 task, gotten {len(tasks)}")
+    taskId = tasks[0]
+    taskInstance = executionContext.tasks[taskId]
+    with Memory(callback, "testWorker") as memory, PackagesEnv() as pckg:
+        # for key, value in taskSequence.extra_env.items():
+        #     os.environ[key] = value
+
+        pckg.extend(taskInstance.definition.environment)
+        run(taskId, executionContext, memory)
+        memory.flush()
+
+
+@dataclass
+class CallableInstance:
+    func: Callable
+    kwargs: dict[str, Any]
+    args: list[tuple[int, Any]]
+    env: list[str]
+    exp_output: Any
+
+
+def callable2ctx(
+    callableInstance: CallableInstance, callback: BackboneAddress
+) -> ExecutionContext:
+    taskInstance = TaskBuilder.from_callable(
+        callableInstance.func, callableInstance.env
+    )
+    param_source = {}
+    params = [
+        (key, DatasetId("taskId", f"kwarg.{key}"), value)
+        for key, value in callableInstance.kwargs.items()
+    ] + [
+        (key, DatasetId("taskId", f"pos.{key}"), value)
+        for key, value in callableInstance.args
+    ]
+    for key, ds_key, value in params:
+        raw = cloudpickle.dumps(value)
+        L = len(raw)
+        buf = shm_client.allocate(ds2shmid(ds_key), L, "cloudpickle.loads")
+        buf.view()[:L] = raw
+        buf.close()
+        param_source[key] = (ds_key, "Any")
+
+    return ExecutionContext(
+        tasks={"taskId": taskInstance},
+        param_source={"taskId": param_source},
+        callback=callback,
+        publish={
+            DatasetId("taskId", output)
+            for output, _ in taskInstance.definition.output_schema
+        },
+    )
+
+
+def run_test(
+    callableInstance: CallableInstance, testId: str, max_runtime_sec: int
+) -> Any:
+    with setup_shm(testId):
+        addr = f"ipc:///tmp/tc{testId}"
+        listener = ZmqListener(addr)
+        ec_ctx = callable2ctx(callableInstance, addr)
+        mp_ctx = platform.get_mp_ctx("executor-aux")
+        runner = mp_ctx.Process(target=simple_runner, args=(addr, ec_ctx))
+        runner.start()
+        output = DatasetId("taskId", "0")
+
+        end = perf_counter_ns() + max_runtime_sec * int(1e9)
+        while perf_counter_ns() < end:
+            mess = listener.recv_messages()
+            if mess == [
+                DatasetPublished(origin="testWorker", ds=output, transmit_idx=None)
+            ]:
+                break
+            elif not mess:
+                continue
+            else:
+                raise ValueError(mess)
+
+        runner.join()
+        output_buf = shm_client.get(ds2shmid(output))
+        output_des = cloudpickle.loads(output_buf.view())
+        output_buf.close()
+        assert output_des == callableInstance.exp_output
+
+
+def demo():
+    def myfunc(l: int) -> float:
+        import numpy as np
+
+        return np.arange(l).sum()
+
+    ci = CallableInstance(
+        func=myfunc, kwargs={"l": 4}, args=[], env=["numpy"], exp_output=6
+    )
+    run_test(ci, "numpyTest1", 2)
+
+
+if __name__ == "__main__":
+    demo()
```
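Mirroring the `demo()` at the end of the new module, a test author wraps a plain function in `CallableInstance` and hands it to `run_test`, which spins up the shm server, serializes the inputs, runs the task in a subprocess, and asserts on the published output. The sketch below is hypothetical: the function name, test id, and the assumption that positional index `0` maps to the first argument are illustrative, not part of the diff.

```python
# Hypothetical usage of the new cascade.benchmarks.tests helpers, modelled on demo().
from cascade.benchmarks.tests import CallableInstance, run_test


def scaled_sum(xs: list[int], factor: int = 1) -> int:
    return factor * sum(xs)


ci = CallableInstance(
    func=scaled_sum,
    kwargs={"factor": 2},   # stored under DatasetId("taskId", "kwarg.factor")
    args=[(0, [1, 2, 3])],  # positional index 0 -> DatasetId("taskId", "pos.0")
    env=[],                 # no extra packages for PackagesEnv to install
    exp_output=12,          # run_test asserts the deserialized result equals this
)

if __name__ == "__main__":
    # testId feeds the shm socket path and prefix; keep it short and unique per test.
    run_test(ci, "scaledSumTest", max_runtime_sec=5)
```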
cascade/benchmarks/util.py CHANGED

```diff
@@ -17,6 +17,7 @@ import subprocess
 import sys
 from concurrent.futures import ThreadPoolExecutor
 from time import perf_counter_ns
+from typing import Any
 
 import orjson
 
@@ -28,7 +29,7 @@ from cascade.executor.comms import callback
 from cascade.executor.config import logging_config, logging_config_filehandler
 from cascade.executor.executor import Executor
 from cascade.executor.msg import BackboneAddress, ExecutorShutdown
-from cascade.low.core import JobInstance
+from cascade.low.core import DatasetId, JobInstance
 from cascade.low.func import msum
 from cascade.scheduler.precompute import precompute
 from earthkit.workflows.graph import Graph, deduplicate_nodes
@@ -159,7 +160,7 @@ def run_locally(
     portBase: int = 12345,
     log_base: str | None = None,
     report_address: str | None = None,
-) ->
+) -> dict[DatasetId, Any]:
     if log_base is not None:
         log_path = f"{log_base}.controller.txt"
         logging.config.dictConfig(logging_config_filehandler(log_path))
@@ -216,6 +217,7 @@ def run_locally(
         if os.environ.get("CASCADE_DEBUG_PRINT"):
            for key, value in result.outputs.items():
                print(f"{key} => {value}")
+        return result.outputs
     except Exception:
         # NOTE we log this to get the stacktrace into the logfile
         logger.exception("controller failure, proceed with executor shutdown")
```
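With the new return annotation, `run_locally` hands back the collected outputs keyed by `DatasetId` rather than only printing them when `CASCADE_DEBUG_PRINT` is set. A hedged sketch of a caller, assuming the job is a `JobInstance` built elsewhere and passed as the first argument (that calling convention is an assumption, not shown in this diff):

```python
from typing import Any

from cascade.benchmarks.util import run_locally
from cascade.low.core import DatasetId, JobInstance


def run_and_collect(job: JobInstance) -> dict[DatasetId, Any]:
    # run_locally now returns the outputs dict; iterate it directly instead of
    # relying on the CASCADE_DEBUG_PRINT side channel.
    outputs = run_locally(job)
    for ds_id, value in outputs.items():
        print(f"{ds_id} => {value}")
    return outputs
```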
cascade/controller/report.py CHANGED

```diff
@@ -27,6 +27,7 @@ JobId = str
 
 @dataclass
 class JobProgress:
+    started: bool
     completed: bool
     pct: (
         str | None
@@ -35,19 +36,20 @@ class JobProgress:
 
     @classmethod
     def failed(cls, failure: str) -> Self:
-        return cls(True, None, failure)
+        return cls(True, True, None, failure)
 
     @classmethod
     def progressed(cls, pct: float) -> Self:
         progress = "{:.2%}".format(pct)[:-1]
-        return cls(False, progress, None)
+        return cls(True, False, progress, None)
 
     @classmethod
     def succeeded(cls) -> Self:
-        return cls(True, None, None)
+        return cls(True, True, None, None)
 
 
-JobProgressStarted = JobProgress(False, "0.00", None)
+JobProgressStarted = JobProgress(True, False, "0.00", None)
+JobProgressEnqueued = JobProgress(False, False, None, None)
 
 
 @dataclass
```
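The new `started` flag turns `JobProgress` into a three-stage state (enqueued, started, completed), with `pct` and `failure` filling in the details. A minimal sketch of how the constructors and sentinels map onto those states, using only the fields visible in this diff:

```python
from cascade.controller.report import (
    JobProgress,
    JobProgressEnqueued,
    JobProgressStarted,
)

# enqueued: accepted but not yet spawned
assert JobProgressEnqueued == JobProgress(False, False, None, None)
# started: running, 0% done
assert JobProgressStarted == JobProgress(True, False, "0.00", None)

running = JobProgress.progressed(0.42)  # started, not completed, pct == "42.00"
done = JobProgress.succeeded()          # started and completed, no failure
failed = JobProgress.failed("boom")     # completed with a failure message

assert running.started and not running.completed and running.pct == "42.00"
assert done.completed and failed.completed and failed.failure == "boom"
```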
cascade/executor/data_server.py CHANGED

```diff
@@ -20,7 +20,6 @@ from concurrent.futures import Executor as PythonExecutor
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from time import time_ns
 
-import cascade.shm.api as shm_api
 import cascade.shm.client as shm_client
 from cascade.executor.comms import Listener, callback, send_data
 from cascade.executor.msg import (
@@ -48,7 +47,6 @@ class DataServer:
         maddress: BackboneAddress,
         daddress: BackboneAddress,
         host: str,
-        shm_port: int,
         logging_config: dict,
     ):
         logging.config.dictConfig(logging_config)
@@ -58,7 +56,6 @@ class DataServer:
         self.daddress = daddress
         self.dlistener = Listener(daddress)
         self.terminating = False
-        shm_api.publish_client_port(shm_port)
         self.cap = 2
         self.ds_proc_tp: PythonExecutor = ThreadPoolExecutor(max_workers=self.cap)
         self.futs_in_progress: dict[
@@ -305,8 +302,7 @@ def start_data_server(
     maddress: BackboneAddress,
     daddress: BackboneAddress,
     host: str,
-    shm_port: int,
     logging_config: dict,
 ):
-    server = DataServer(maddress, daddress, host,
+    server = DataServer(maddress, daddress, host, logging_config)
     server.recv_loop()
```
cascade/executor/executor.py CHANGED

```diff
@@ -18,6 +18,7 @@ the tasks themselves.
 import atexit
 import logging
 import os
+import uuid
 from multiprocessing.process import BaseProcess
 from typing import Iterable
 
@@ -94,8 +95,8 @@ class Executor:
         self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
         self.sender.add_host("controller", controller_address)
         # TODO make the shm server params configurable
-        shm_port = portBase + 2
-        shm_api.
+        shm_port = f"/tmp/cascShmSock-{uuid.uuid4()}"  # portBase + 2
+        shm_api.publish_socket_addr(shm_port)
         ctx = platform.get_mp_ctx("executor-aux")
         if log_base:
             shm_logging = logging_config_filehandler(f"{log_base}.shm.txt")
@@ -104,12 +105,11 @@ class Executor:
         logger.debug("about to start an shm process")
         self.shm_process = ctx.Process(
             target=shm_server,
-
-
-
-
-
-            ),
+            kwargs={
+                "capacity": shm_vol_gb * (1024**3) if shm_vol_gb else None,
+                "logging_config": shm_logging,
+                "shm_pref": f"sCasc{host}",
+            },
         )
         self.shm_process.start()
         self.daddress = address_of(portBase + 1)
@@ -124,7 +124,6 @@ class Executor:
             self.mlistener.address,
             self.daddress,
             self.host,
-            shm_port,
             dsr_logging,
         ),
     )
```
cascade/executor/runner/memory.py CHANGED

```diff
@@ -12,6 +12,7 @@ Interaction with shm
 
 import hashlib
 import logging
+import sys
 from contextlib import AbstractContextManager
 from typing import Any, Literal
 
@@ -119,27 +120,29 @@ class Memory(AbstractContextManager):
         self.local.pop(inputId)
 
         # NOTE poor man's gpu mem management -- currently torch only. Given the task sequence limitation,
-        # this may not be the best place to invoke.
-
-
-
-
-
-
-
-        if free / total < 0.8:
-            torch.cuda.empty_cache()
+        # this may not be the best place to invoke.
+        if (
+            "torch" in sys.modules
+        ):  # if no task on this worker imported torch, no need to flush
+            try:
+                import torch
+
+                if torch.cuda.is_available():
                     free, total = torch.cuda.mem_get_info()
-                    logger.debug(f"cuda mem avail
+                    logger.debug(f"cuda mem avail: {free/total:.2%}")
                     if free / total < 0.8:
-
-
-                        logger.
-
-
-
-
-
+                        torch.cuda.empty_cache()
+                        free, total = torch.cuda.mem_get_info()
+                        logger.debug(
+                            f"cuda mem avail post cache empty: {free/total:.2%}"
+                        )
+                        if free / total < 0.8:
+                            # NOTE this ofc makes low sense if there is any other application (like browser or ollama)
+                            # that the user may be running
+                            logger.warning("cuda mem avail low despite cache empty!")
+                        logger.debug(torch.cuda.memory_summary())
+            except Exception:
+                logger.exception("failed to free cuda cache")
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
         # this is required so that the Shm can be properly freed, otherwise you get 'pointers cannot be closed'
```
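The refactored flush step only touches torch when some task has already imported it, and empties the CUDA cache once the free fraction of device memory drops below 80%. A standalone sketch of that same pattern outside the `Memory` class (the function name and `threshold` parameter are illustrative):

```python
import logging
import sys

logger = logging.getLogger(__name__)


def maybe_release_cuda_cache(threshold: float = 0.8) -> None:
    # Skip entirely unless torch was already imported by a task on this worker.
    if "torch" not in sys.modules:
        return
    try:
        import torch

        if not torch.cuda.is_available():
            return
        free, total = torch.cuda.mem_get_info()
        logger.debug(f"cuda mem avail: {free/total:.2%}")
        if free / total < threshold:
            torch.cuda.empty_cache()
            free, total = torch.cuda.mem_get_info()
            logger.debug(f"cuda mem avail post cache empty: {free/total:.2%}")
            if free / total < threshold:
                logger.warning("cuda mem avail low despite cache empty!")
    except Exception:
        logger.exception("failed to free cuda cache")
```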
cascade/executor/runner/packages.py CHANGED

```diff
@@ -12,6 +12,7 @@ Note that venv itself is left untouched after the run finishes -- we extend sys
 with a temporary directory and install in there
 """
 
+import importlib
 import logging
 import os
 import site
@@ -24,6 +25,51 @@ from typing import Literal
 logger = logging.getLogger(__name__)
 
 
+class Commands:
+    venv_command = lambda name: ["uv", "venv", name]
+    install_command = lambda name: [
+        "uv",
+        "pip",
+        "install",
+        "--prefix",
+        name,
+        "--prerelease",
+        "explicit",
+    ]
+
+
+def run_command(command: list[str]) -> None:
+    try:
+        result = subprocess.run(command, check=False, capture_output=True)
+    except FileNotFoundError as ex:
+        raise ValueError(f"command failure: {ex}")
+    if result.returncode != 0:
+        msg = f"command failed with {result.returncode}. Stderr: {result.stderr}, Stdout: {result.stdout}, Args: {result.args}"
+        logger.error(msg)
+        raise ValueError(msg)
+
+
+def new_venv() -> tempfile.TemporaryDirectory:
+    """1. Creates a new temporary directory with a venv inside.
+    2. Extends sys.path so that packages in that venv can be imported.
+    """
+    logger.debug("creating a new venv")
+    td = tempfile.TemporaryDirectory(prefix="cascade_runner_venv_")
+    # NOTE we create a venv instead of just plain directory, because some of the packages create files
+    # outside of site-packages. Thus we then install with --prefix, not with --target
+    run_command(Commands.venv_command(td.name))
+
+    # NOTE not sure if getsitepackages was intended for this -- if issues, attempt replacing
+    # with something like f"{td.name}/lib/python*/site-packages" + globbing
+    extra_sp = site.getsitepackages(prefixes=[td.name])
+    # NOTE this makes the explicit packages go first, in case of a different version
+    logger.debug(f"extending sys.path with {extra_sp}")
+    sys.path = extra_sp + sys.path
+    logger.debug(f"new sys.path: {sys.path}")
+
+    return td
+
+
 class PackagesEnv(AbstractContextManager):
     def __init__(self) -> None:
         self.td: tempfile.TemporaryDirectory | None = None
@@ -32,38 +78,32 @@ class PackagesEnv(AbstractContextManager):
         if not packages:
             return
         if self.td is None:
-
-            self.td = tempfile.TemporaryDirectory()
-            venv_command = ["uv", "venv", self.td.name]
-            # NOTE we create a venv instead of just plain directory, because some of the packages create files
-            # outside of site-packages. Thus we then install with --prefix, not with --target
-            subprocess.run(venv_command, check=True)
+            self.td = new_venv()
 
         logger.debug(
             f"installing {len(packages)} packages: {','.join(packages[:3])}{',...' if len(packages) > 3 else ''}"
         )
-        install_command =
-            "uv",
-            "pip",
-            "install",
-            "--prefix",
-            self.td.name,
-            "--prerelease",
-            "allow",
-        ]
+        install_command = Commands.install_command(self.td.name)
         if os.environ.get("VENV_OFFLINE", "") == "YES":
             install_command += ["--offline"]
         if cache_dir := os.environ.get("VENV_CACHE", ""):
             install_command += ["--cache-dir", cache_dir]
         install_command.extend(set(packages))
-
-
-
-
-        #
-
+        logger.debug(f"running install command: {' '.join(install_command)}")
+        run_command(install_command)
+
+        # NOTE we need this due to namespace packages:
+        # 1. task 1 installs ns.pkg1 in its venv
+        # 2. task 1 finishes, task 2 starts on the same worker
+        # 3. task 2 starts, installs ns.pkg2. However, importlib is in a state that ns is aware only of pkg1 submod
+        # Additionally, the caches are invalid anyway, because task 1's venv is already deleted
+        importlib.invalidate_caches()
+        # TODO some namespace packages may require a reimport because they dynamically build `__all__` -- eg earthkit
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
+        sys.path = [
+            p for p in sys.path if self.td is None or not p.startswith(self.td.name)
+        ]
         if self.td is not None:
             self.td.cleanup()
         return False
```
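A hedged usage sketch of the refactored `PackagesEnv`: `extend` builds a throwaway `uv` venv via `new_venv`, installs the requested packages with `--prefix`, and prepends the venv's site-packages to `sys.path`; `__exit__` now also strips those entries back out. This assumes `uv` is on `PATH`, and the package name is only an example:

```python
from cascade.executor.runner.packages import PackagesEnv

# Install task-specific packages into a temporary uv venv for the duration of the
# with-block; on exit the venv is deleted and its site-packages leave sys.path.
with PackagesEnv() as pckg:
    pckg.extend(["tabulate"])  # any pip-installable distribution name works here
    import tabulate            # importable only while the venv is on sys.path

    print(tabulate.tabulate([["a", 1], ["b", 2]]))
```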
cascade/gateway/__main__.py CHANGED

```diff
@@ -15,14 +15,17 @@ from cascade.gateway.server import serve
 
 
 def main(
-    url: str,
+    url: str,
+    log_base: str | None = None,
+    troika_config: str | None = None,
+    max_jobs: int | None = None,
 ) -> None:
     if log_base:
         log_path = f"{log_base}/gateway.txt"
         logging.config.dictConfig(logging_config_filehandler(log_path))
     else:
         logging.config.dictConfig(logging_config)
-    serve(url, log_base, troika_config)
+    serve(url, log_base, troika_config, max_jobs)
 
 
 if __name__ == "__main__":
```
cascade/gateway/api.py CHANGED

```diff
@@ -62,8 +62,9 @@ class JobProgressRequest(CascadeGatewayAPI):
 
 
 class JobProgressResponse(CascadeGatewayAPI):
-    progresses: dict[JobId, JobProgress]
+    progresses: dict[JobId, JobProgress | None]
     datasets: dict[JobId, list[DatasetId]]
+    queue_length: int
     error: str | None  # top level error
 
 
```
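Because `progresses` values can now be `None` (job unknown to the gateway) and the response carries `queue_length`, a client has to distinguish queued, running, finished, and unknown jobs. A minimal hedged sketch of that interpretation, using only fields visible in this diff and in `JobProgress`:

```python
from cascade.gateway.api import JobProgressResponse


def summarize(resp: JobProgressResponse) -> None:
    # queue_length counts jobs accepted by the gateway but not yet spawned
    print(f"jobs waiting in gateway queue: {resp.queue_length}")
    for job_id, progress in resp.progresses.items():
        if progress is None:
            print(f"{job_id}: unknown to the gateway")
        elif not progress.started:
            print(f"{job_id}: enqueued")
        elif progress.failure is not None:
            print(f"{job_id}: failed ({progress.failure})")
        elif progress.completed:
            print(f"{job_id}: completed, datasets: {resp.datasets.get(job_id, [])}")
        else:
            print(f"{job_id}: running at {progress.pct}%")
```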
cascade/gateway/router.py CHANGED

```diff
@@ -15,6 +15,7 @@ import os
 import stat
 import subprocess
 import uuid
+from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Iterable
 
@@ -22,7 +23,12 @@ import orjson
 import zmq
 
 import cascade.executor.platform as platform
-from cascade.controller.report import
+from cascade.controller.report import (
+    JobId,
+    JobProgress,
+    JobProgressEnqueued,
+    JobProgressStarted,
+)
 from cascade.executor.comms import get_context
 from cascade.gateway.api import JobSpec, TroikaSpec
 from cascade.low.core import DatasetId
@@ -202,16 +208,29 @@ def _spawn_subprocess(
 
 class JobRouter:
     def __init__(
-        self,
+        self,
+        poller: zmq.Poller,
+        log_base: str | None,
+        troika_config: str | None,
+        max_jobs: int | None,
     ):
         self.poller = poller
         self.jobs: dict[str, Job] = {}
+        self.active_jobs = 0
+        self.max_jobs = max_jobs
+        self.jobs_queue: OrderedDict[JobId, JobSpec] = OrderedDict()
         self.procs: dict[str, subprocess.Popen] = {}
         self.log_base = log_base
         self.troika_config = troika_config
 
-    def
-
+    def maybe_spawn(self) -> None:
+        if not self.jobs_queue:
+            return
+        if self.max_jobs is not None and self.active_jobs >= self.max_jobs:
+            logger.debug(f"already running {self.active_jobs}, no spawn")
+            return
+
+        job_id, job_spec = self.jobs_queue.popitem(False)
         base_addr = f"tcp://{platform.get_bindabble_self()}"
         socket = get_context().socket(zmq.PULL)
         port = socket.bind_to_random_port(base_addr)
@@ -222,18 +241,37 @@ class JobRouter:
         self.procs[job_id] = _spawn_subprocess(
             job_spec, full_addr, job_id, self.log_base, self.troika_config
         )
+        self.active_jobs += 1
+        return job_id
+
+    def enqueue_job(self, job_spec: JobSpec) -> JobId:
+        job_id = next_uuid(
+            set(self.jobs.keys()).union(self.jobs_queue.keys()),
+            lambda: str(uuid.uuid4()),
+        )
+        self.jobs_queue[job_id] = job_spec
+        self.maybe_spawn()
         return job_id
 
     def progress_of(
         self, job_ids: Iterable[JobId]
-    ) -> tuple[dict[JobId, JobProgress], dict[JobId, list[DatasetId]]]:
+    ) -> tuple[dict[JobId, JobProgress], dict[JobId, list[DatasetId]], int]:
         if not job_ids:
-            job_ids = self.jobs.keys()
-        progresses = {
+            job_ids = set(self.jobs.keys()).union(self.jobs_queue.keys())
+        progresses = {}
+        for job_id in job_ids:
+            if job_id in self.jobs:
+                progresses[job_id] = self.jobs[job_id].progress
+            elif job_id in self.jobs_queue:
+                progresses[job_id] = JobProgressEnqueued
+            else:
+                progresses[job_id] = None
         datasets = {
-            job_id: list(self.jobs[job_id].results.keys())
+            job_id: list(self.jobs[job_id].results.keys())
+            for job_id in job_ids
+            if job_id in self.jobs
         }
-        return progresses, datasets
+        return progresses, datasets, len(self.jobs_queue)
 
     def get_result(self, job_id: JobId, dataset_id: DatasetId) -> bytes:
         return self.jobs[job_id].results[dataset_id]
@@ -246,6 +284,8 @@ class JobRouter:
         job = self.jobs[job_id]
         if progress.completed:
             self.poller.unregister(job.socket)
+            self.active_jobs -= 1
+            self.maybe_spawn()
         if progress.failure is not None and job.progress.failure is None:
             job.progress = progress
         elif job.last_seen >= timestamp or job.progress.failure is not None:
```
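The admission logic added here is compact: submissions land in an ordered queue, `maybe_spawn` pops the oldest entry whenever fewer than `max_jobs` jobs are active, and each completion decrements the counter and re-triggers a spawn. A toy model of just that bookkeeping (no zmq, no subprocesses; all names are illustrative):

```python
from collections import OrderedDict


class ToyRouter:
    """Models only the new queue/active_jobs bookkeeping of JobRouter."""

    def __init__(self, max_jobs: int | None):
        self.max_jobs = max_jobs
        self.active_jobs = 0
        self.jobs_queue: OrderedDict[str, str] = OrderedDict()
        self.running: list[str] = []

    def maybe_spawn(self) -> None:
        if not self.jobs_queue:
            return
        if self.max_jobs is not None and self.active_jobs >= self.max_jobs:
            return  # at capacity; the job stays queued
        job_id, _spec = self.jobs_queue.popitem(last=False)  # FIFO, like popitem(False)
        self.active_jobs += 1
        self.running.append(job_id)

    def enqueue_job(self, job_id: str, spec: str) -> None:
        self.jobs_queue[job_id] = spec
        self.maybe_spawn()

    def on_completed(self, job_id: str) -> None:
        self.running.remove(job_id)
        self.active_jobs -= 1
        self.maybe_spawn()  # a finished job frees a slot for the next queued one


router = ToyRouter(max_jobs=1)
router.enqueue_job("a", "spec-a")
router.enqueue_job("b", "spec-b")
assert router.running == ["a"] and list(router.jobs_queue) == ["b"]
router.on_completed("a")
assert router.running == ["b"] and not router.jobs_queue
```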
cascade/gateway/server.py CHANGED

```diff
@@ -31,16 +31,19 @@ def handle_fe(socket: zmq.Socket, jobs: JobRouter) -> bool:
     rv: api.CascadeGatewayAPI
     if isinstance(m, api.SubmitJobRequest):
         try:
-            job_id = jobs.
+            job_id = jobs.enqueue_job(m.job)
             rv = api.SubmitJobResponse(job_id=job_id, error=None)
         except Exception as e:
             logger.exception(f"failed to spawn a job: {m}")
             rv = api.SubmitJobResponse(job_id=None, error=repr(e))
     elif isinstance(m, api.JobProgressRequest):
         try:
-            progresses, datasets = jobs.progress_of(m.job_ids)
+            progresses, datasets, queue_length = jobs.progress_of(m.job_ids)
             rv = api.JobProgressResponse(
-                progresses=progresses,
+                progresses=progresses,
+                datasets=datasets,
+                error=None,
+                queue_length=queue_length,
             )
         except Exception as e:
             logger.exception(f"failed to get progress of: {m}")
@@ -80,7 +83,10 @@ def handle_controller(socket: zmq.Socket, jobs: JobRouter) -> None:
 
 
 def serve(
-    url: str,
+    url: str,
+    log_base: str | None = None,
+    troika_config: str | None = None,
+    max_jobs: int | None = None,
 ) -> None:
     ctx = get_context()
     poller = zmq.Poller()
@@ -88,7 +94,7 @@ def serve(
     fe = ctx.socket(zmq.REP)
     fe.bind(url)
     poller.register(fe, flags=zmq.POLLIN)
-    jobs = JobRouter(poller, log_base, troika_config)
+    jobs = JobRouter(poller, log_base, troika_config, max_jobs)
 
     logger.debug("entering recv loop")
     is_break = False
```