earthkit-workflows 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cascade/benchmarks/tests.py +173 -0
- cascade/benchmarks/util.py +4 -2
- cascade/executor/data_server.py +1 -5
- cascade/executor/executor.py +8 -9
- cascade/executor/runner/memory.py +22 -19
- cascade/executor/runner/packages.py +61 -21
- cascade/low/builders.py +42 -5
- cascade/shm/api.py +53 -12
- cascade/shm/client.py +29 -16
- cascade/shm/dataset.py +15 -1
- cascade/shm/server.py +28 -15
- earthkit/workflows/_version.py +1 -1
- earthkit/workflows/decorators.py +23 -9
- earthkit/workflows/fluent.py +25 -17
- earthkit/workflows/mark.py +3 -4
- {earthkit_workflows-0.4.7.dist-info → earthkit_workflows-0.5.1.dist-info}/METADATA +1 -1
- {earthkit_workflows-0.4.7.dist-info → earthkit_workflows-0.5.1.dist-info}/RECORD +20 -20
- earthkit/workflows/py.typed +0 -0
- {earthkit_workflows-0.4.7.dist-info → earthkit_workflows-0.5.1.dist-info}/WHEEL +0 -0
- {earthkit_workflows-0.4.7.dist-info → earthkit_workflows-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {earthkit_workflows-0.4.7.dist-info → earthkit_workflows-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
# (C) Copyright 2025- ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
|
|
9
|
+
"""Module for simplifying writing tests
|
|
10
|
+
|
|
11
|
+
Similar to util, but not similar enough to unify
|
|
12
|
+
|
|
13
|
+
For a single given task, it can spin up an shm server, put all of the task's inputs into it, execute the task, store the outputs in memory, and retrieve the result.
|
|
14
|
+
See the `demo()` function at the very end
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from time import perf_counter_ns
|
|
21
|
+
from typing import Any, Callable
|
|
22
|
+
|
|
23
|
+
import cloudpickle
|
|
24
|
+
|
|
25
|
+
import cascade.executor.platform as platform
|
|
26
|
+
import cascade.shm.api as shm_api
|
|
27
|
+
import cascade.shm.client as shm_client
|
|
28
|
+
from cascade.executor.comms import Listener as ZmqListener
|
|
29
|
+
from cascade.executor.config import logging_config
|
|
30
|
+
from cascade.executor.msg import BackboneAddress, DatasetPublished
|
|
31
|
+
from cascade.executor.runner.memory import Memory, ds2shmid
|
|
32
|
+
from cascade.executor.runner.packages import PackagesEnv
|
|
33
|
+
from cascade.executor.runner.runner import ExecutionContext, run
|
|
34
|
+
from cascade.low.builders import TaskBuilder
|
|
35
|
+
from cascade.low.core import DatasetId
|
|
36
|
+
from cascade.shm.server import entrypoint as shm_server
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@contextmanager
def setup_shm(testId: str):
    """Run an shm server in a child process for the duration of the context.

    The server socket is the unix socket ``/tmp/tcShm-{testId}``. It is published
    via ``shm_api.publish_socket_addr``, i.e. through an environment variable, so
    that shm clients can find it. The shm prefix ``tc{testId}`` is passed to the
    server.

    On exit the server is asked to shut down. If it is still alive afterwards, it
    is terminated and finally killed. This escalation runs even when the graceful
    shutdown request itself fails.
    """
    mp_ctx = platform.get_mp_ctx("executor-aux")
    shm_socket = f"/tmp/tcShm-{testId}"
    shm_api.publish_socket_addr(shm_socket)
    shm_process = mp_ctx.Process(
        target=shm_server,
        kwargs={
            "logging_config": logging_config,
            "shm_pref": f"tc{testId}",
        },
    )
    shm_process.start()
    try:
        # NOTE inside the try, so that the server process is reaped even if ensure fails
        shm_client.ensure()
        yield
    except Exception as e:
        # NOTE we log like this in case shm shutdown freezes
        logger.exception(f"gotten {repr(e)}, proceed with shm shutdown")
        raise
    finally:
        try:
            shm_client.shutdown(timeout_sec=1.0)
        except Exception:
            # NOTE a failed graceful shutdown must not skip the escalation below
            logger.exception("graceful shm shutdown failed, proceeding to terminate")
        shm_process.join(1)
        if shm_process.is_alive():
            shm_process.terminate()
            shm_process.join(1)
        if shm_process.is_alive():
            shm_process.kill()
            shm_process.join(1)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def simple_runner(callback: BackboneAddress, executionContext: ExecutionContext):
    """Run the one and only task of ``executionContext`` in the current process.

    Used as the target of the runner process started by ``run_test``. The task's
    declared environment is installed via ``PackagesEnv``. The task is then executed
    with a ``Memory`` constructed with ``callback`` and the name ``"testWorker"``.
    The memory is flushed before both context managers are exited.

    Raises:
        ValueError: if the context holds a number of tasks other than one.
    """
    task_count = len(executionContext.tasks)
    if task_count != 1:
        raise ValueError(f"expected 1 task, gotten {task_count}")
    (taskId,) = executionContext.tasks
    taskInstance = executionContext.tasks[taskId]

    with Memory(callback, "testWorker") as memory, PackagesEnv() as pckg:
        # NOTE no extra_env (os.environ overrides of a task sequence) is applied here,
        # as this single-task harness has no task sequence
        pckg.extend(taskInstance.definition.environment)
        run(taskId, executionContext, memory)
        memory.flush()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class CallableInstance:
    """A python callable bundled with its inputs and its expected result.

    This is the unit of work consumed by ``callable2ctx`` and ``run_test``.
    """

    # the function to be executed as a task
    func: Callable
    # keyword arguments, keyed by parameter name
    kwargs: dict[str, Any]
    # positional arguments as (position, value) pairs
    args: list[tuple[int, Any]]
    # package requirements, passed as the task's environment
    env: list[str]
    # compared with `==` against the deserialized task output
    exp_output: Any
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def callable2ctx(
    callableInstance: CallableInstance, callback: BackboneAddress
) -> ExecutionContext:
    """Build a single-task ExecutionContext for ``callableInstance``.

    Every keyword and positional argument is cloudpickled and written to shm.
    Each is stored under a dataset id of the synthetic task ``"taskId"``:
    ``kwarg.<name>`` for keyword arguments and ``pos.<position>`` for positional
    ones. An shm server must therefore already be running (e.g. via ``setup_shm``).

    All outputs declared in the callable's output schema are marked for publishing.
    ``callback`` is passed through as the context's callback address.
    """
    task_id = "taskId"
    taskInstance = TaskBuilder.from_callable(
        callableInstance.func, callableInstance.env
    )
    param_source: dict[int | str, tuple[DatasetId, str]] = {}
    params = [
        (key, DatasetId(task_id, f"kwarg.{key}"), value)
        for key, value in callableInstance.kwargs.items()
    ] + [
        (key, DatasetId(task_id, f"pos.{key}"), value)
        for key, value in callableInstance.args
    ]
    for key, ds_key, value in params:
        raw = cloudpickle.dumps(value)
        size = len(raw)
        buf = shm_client.allocate(ds2shmid(ds_key), size, "cloudpickle.loads")
        try:
            buf.view()[:size] = raw
        finally:
            # NOTE always release the mapping, even if the copy fails
            buf.close()
        param_source[key] = (ds_key, "Any")

    return ExecutionContext(
        tasks={task_id: taskInstance},
        param_source={task_id: param_source},
        callback=callback,
        publish={
            DatasetId(task_id, output)
            for output, _ in taskInstance.definition.output_schema
        },
    )
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def run_test(
    callableInstance: CallableInstance, testId: str, max_runtime_sec: int
) -> Any:
    """Execute ``callableInstance`` as a single task and check its output.

    The steps all run inside an shm server scoped by ``setup_shm(testId)``:

    1. The inputs are stored in shm.
    2. The task runs in a child process (``simple_runner``).
    3. This process listens on ``ipc:///tmp/tc{testId}`` for the publication of
       output ``"0"``.
    4. That output is read from shm and deserialized with cloudpickle.
    5. It is compared with ``callableInstance.exp_output`` and returned.

    NOTE(review): ``max_runtime_sec`` only bounds the wait for the publication
    message. When it elapses, a warning is logged and the runner is still joined
    without a timeout.

    Returns:
        The deserialized task output.

    Raises:
        ValueError: on any message other than the expected publication. The runner
            process is terminated in that case.
        AssertionError: if the output does not equal ``exp_output``.
    """
    with setup_shm(testId):
        addr = f"ipc:///tmp/tc{testId}"
        listener = ZmqListener(addr)
        ec_ctx = callable2ctx(callableInstance, addr)
        mp_ctx = platform.get_mp_ctx("executor-aux")
        runner = mp_ctx.Process(target=simple_runner, args=(addr, ec_ctx))
        runner.start()
        output = DatasetId("taskId", "0")
        expected = [
            DatasetPublished(origin="testWorker", ds=output, transmit_idx=None)
        ]

        try:
            end = perf_counter_ns() + max_runtime_sec * int(1e9)
            while perf_counter_ns() < end:
                mess = listener.recv_messages()
                if mess == expected:
                    break
                if mess:
                    raise ValueError(mess)
            else:
                logger.warning(
                    f"no publication of {output} within {max_runtime_sec}s, waiting for runner"
                )
            runner.join()
        finally:
            # NOTE on failure, don't leave the runner process orphaned
            if runner.is_alive():
                runner.terminate()
                runner.join(1)

        output_buf = shm_client.get(ds2shmid(output))
        try:
            output_des = cloudpickle.loads(output_buf.view())
        finally:
            output_buf.close()
        assert output_des == callableInstance.exp_output
        return output_des
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def demo():
    """Example: run a numpy-based function as an isolated task and check its result."""

    def myfunc(n: int) -> int:
        # NOTE imported inside the function: numpy is requested via `env` below and
        # installed into the runner's temporary venv before the task executes
        import numpy as np

        # NOTE converted to a python int, so that the result matches the annotation
        return int(np.arange(n).sum())

    ci = CallableInstance(
        func=myfunc, kwargs={"n": 4}, args=[], env=["numpy"], exp_output=6
    )
    run_test(ci, "numpyTest1", 2)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
if __name__ == "__main__":
    # only when executed as a script, never on import
    demo()
|
cascade/benchmarks/util.py
CHANGED
|
@@ -17,6 +17,7 @@ import subprocess
|
|
|
17
17
|
import sys
|
|
18
18
|
from concurrent.futures import ThreadPoolExecutor
|
|
19
19
|
from time import perf_counter_ns
|
|
20
|
+
from typing import Any
|
|
20
21
|
|
|
21
22
|
import orjson
|
|
22
23
|
|
|
@@ -28,7 +29,7 @@ from cascade.executor.comms import callback
|
|
|
28
29
|
from cascade.executor.config import logging_config, logging_config_filehandler
|
|
29
30
|
from cascade.executor.executor import Executor
|
|
30
31
|
from cascade.executor.msg import BackboneAddress, ExecutorShutdown
|
|
31
|
-
from cascade.low.core import JobInstance
|
|
32
|
+
from cascade.low.core import DatasetId, JobInstance
|
|
32
33
|
from cascade.low.func import msum
|
|
33
34
|
from cascade.scheduler.precompute import precompute
|
|
34
35
|
from earthkit.workflows.graph import Graph, deduplicate_nodes
|
|
@@ -159,7 +160,7 @@ def run_locally(
|
|
|
159
160
|
portBase: int = 12345,
|
|
160
161
|
log_base: str | None = None,
|
|
161
162
|
report_address: str | None = None,
|
|
162
|
-
) ->
|
|
163
|
+
) -> dict[DatasetId, Any]:
|
|
163
164
|
if log_base is not None:
|
|
164
165
|
log_path = f"{log_base}.controller.txt"
|
|
165
166
|
logging.config.dictConfig(logging_config_filehandler(log_path))
|
|
@@ -216,6 +217,7 @@ def run_locally(
|
|
|
216
217
|
if os.environ.get("CASCADE_DEBUG_PRINT"):
|
|
217
218
|
for key, value in result.outputs.items():
|
|
218
219
|
print(f"{key} => {value}")
|
|
220
|
+
return result.outputs
|
|
219
221
|
except Exception:
|
|
220
222
|
# NOTE we log this to get the stacktrace into the logfile
|
|
221
223
|
logger.exception("controller failure, proceed with executor shutdown")
|
cascade/executor/data_server.py
CHANGED
|
@@ -20,7 +20,6 @@ from concurrent.futures import Executor as PythonExecutor
|
|
|
20
20
|
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
|
21
21
|
from time import time_ns
|
|
22
22
|
|
|
23
|
-
import cascade.shm.api as shm_api
|
|
24
23
|
import cascade.shm.client as shm_client
|
|
25
24
|
from cascade.executor.comms import Listener, callback, send_data
|
|
26
25
|
from cascade.executor.msg import (
|
|
@@ -48,7 +47,6 @@ class DataServer:
|
|
|
48
47
|
maddress: BackboneAddress,
|
|
49
48
|
daddress: BackboneAddress,
|
|
50
49
|
host: str,
|
|
51
|
-
shm_port: int,
|
|
52
50
|
logging_config: dict,
|
|
53
51
|
):
|
|
54
52
|
logging.config.dictConfig(logging_config)
|
|
@@ -58,7 +56,6 @@ class DataServer:
|
|
|
58
56
|
self.daddress = daddress
|
|
59
57
|
self.dlistener = Listener(daddress)
|
|
60
58
|
self.terminating = False
|
|
61
|
-
shm_api.publish_client_port(shm_port)
|
|
62
59
|
self.cap = 2
|
|
63
60
|
self.ds_proc_tp: PythonExecutor = ThreadPoolExecutor(max_workers=self.cap)
|
|
64
61
|
self.futs_in_progress: dict[
|
|
@@ -305,8 +302,7 @@ def start_data_server(
|
|
|
305
302
|
maddress: BackboneAddress,
|
|
306
303
|
daddress: BackboneAddress,
|
|
307
304
|
host: str,
|
|
308
|
-
shm_port: int,
|
|
309
305
|
logging_config: dict,
|
|
310
306
|
):
|
|
311
|
-
server = DataServer(maddress, daddress, host,
|
|
307
|
+
server = DataServer(maddress, daddress, host, logging_config)
|
|
312
308
|
server.recv_loop()
|
cascade/executor/executor.py
CHANGED
|
@@ -18,6 +18,7 @@ the tasks themselves.
|
|
|
18
18
|
import atexit
|
|
19
19
|
import logging
|
|
20
20
|
import os
|
|
21
|
+
import uuid
|
|
21
22
|
from multiprocessing.process import BaseProcess
|
|
22
23
|
from typing import Iterable
|
|
23
24
|
|
|
@@ -94,8 +95,8 @@ class Executor:
|
|
|
94
95
|
self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
|
|
95
96
|
self.sender.add_host("controller", controller_address)
|
|
96
97
|
# TODO make the shm server params configurable
|
|
97
|
-
shm_port = portBase + 2
|
|
98
|
-
shm_api.
|
|
98
|
+
shm_port = f"/tmp/cascShmSock-{uuid.uuid4()}" # portBase + 2
|
|
99
|
+
shm_api.publish_socket_addr(shm_port)
|
|
99
100
|
ctx = platform.get_mp_ctx("executor-aux")
|
|
100
101
|
if log_base:
|
|
101
102
|
shm_logging = logging_config_filehandler(f"{log_base}.shm.txt")
|
|
@@ -104,12 +105,11 @@ class Executor:
|
|
|
104
105
|
logger.debug("about to start an shm process")
|
|
105
106
|
self.shm_process = ctx.Process(
|
|
106
107
|
target=shm_server,
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
),
|
|
108
|
+
kwargs={
|
|
109
|
+
"capacity": shm_vol_gb * (1024**3) if shm_vol_gb else None,
|
|
110
|
+
"logging_config": shm_logging,
|
|
111
|
+
"shm_pref": f"sCasc{host}",
|
|
112
|
+
},
|
|
113
113
|
)
|
|
114
114
|
self.shm_process.start()
|
|
115
115
|
self.daddress = address_of(portBase + 1)
|
|
@@ -124,7 +124,6 @@ class Executor:
|
|
|
124
124
|
self.mlistener.address,
|
|
125
125
|
self.daddress,
|
|
126
126
|
self.host,
|
|
127
|
-
shm_port,
|
|
128
127
|
dsr_logging,
|
|
129
128
|
),
|
|
130
129
|
)
|
|
@@ -12,6 +12,7 @@ Interaction with shm
|
|
|
12
12
|
|
|
13
13
|
import hashlib
|
|
14
14
|
import logging
|
|
15
|
+
import sys
|
|
15
16
|
from contextlib import AbstractContextManager
|
|
16
17
|
from typing import Any, Literal
|
|
17
18
|
|
|
@@ -119,27 +120,29 @@ class Memory(AbstractContextManager):
|
|
|
119
120
|
self.local.pop(inputId)
|
|
120
121
|
|
|
121
122
|
# NOTE poor man's gpu mem management -- currently torch only. Given the task sequence limitation,
|
|
122
|
-
# this may not be the best place to invoke.
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
if free / total < 0.8:
|
|
131
|
-
torch.cuda.empty_cache()
|
|
123
|
+
# this may not be the best place to invoke.
|
|
124
|
+
if (
|
|
125
|
+
"torch" in sys.modules
|
|
126
|
+
): # if no task on this worker imported torch, no need to flush
|
|
127
|
+
try:
|
|
128
|
+
import torch
|
|
129
|
+
|
|
130
|
+
if torch.cuda.is_available():
|
|
132
131
|
free, total = torch.cuda.mem_get_info()
|
|
133
|
-
logger.debug(f"cuda mem avail
|
|
132
|
+
logger.debug(f"cuda mem avail: {free/total:.2%}")
|
|
134
133
|
if free / total < 0.8:
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
logger.
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
134
|
+
torch.cuda.empty_cache()
|
|
135
|
+
free, total = torch.cuda.mem_get_info()
|
|
136
|
+
logger.debug(
|
|
137
|
+
f"cuda mem avail post cache empty: {free/total:.2%}"
|
|
138
|
+
)
|
|
139
|
+
if free / total < 0.8:
|
|
140
|
+
# NOTE this ofc makes low sense if there is any other application (like browser or ollama)
|
|
141
|
+
# that the user may be running
|
|
142
|
+
logger.warning("cuda mem avail low despite cache empty!")
|
|
143
|
+
logger.debug(torch.cuda.memory_summary())
|
|
144
|
+
except Exception:
|
|
145
|
+
logger.exception("failed to free cuda cache")
|
|
143
146
|
|
|
144
147
|
def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
|
|
145
148
|
# this is required so that the Shm can be properly freed, otherwise you get 'pointers cannot be closed'
|
|
@@ -12,6 +12,7 @@ Note that venv itself is left untouched after the run finishes -- we extend sys
|
|
|
12
12
|
with a temporary directory and install in there
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
+
import importlib
|
|
15
16
|
import logging
|
|
16
17
|
import os
|
|
17
18
|
import site
|
|
@@ -24,6 +25,51 @@ from typing import Literal
|
|
|
24
25
|
logger = logging.getLogger(__name__)
|
|
25
26
|
|
|
26
27
|
|
|
28
|
+
class Commands:
    """Builders of the ``uv`` command lines used to create and populate venvs.

    Implemented as static methods rather than lambdas stored as class attributes.
    This way the builders work both via the class and via an instance, since no
    ``self`` gets bound.
    """

    @staticmethod
    def venv_command(name: str) -> list[str]:
        """Return the command that creates a new venv at ``name``."""
        return ["uv", "venv", name]

    @staticmethod
    def install_command(name: str) -> list[str]:
        """Return the base command that installs packages into the venv at ``name``.

        Packages are installed with ``--prefix``. Pre-releases are only considered
        for requirements that explicitly ask for one (``--prerelease explicit``).

        A new list is returned on every call, so callers may append further
        options and the package names to it.
        """
        return [
            "uv",
            "pip",
            "install",
            "--prefix",
            name,
            "--prerelease",
            "explicit",
        ]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def run_command(command: list[str]) -> None:
    """Run ``command`` (as an argument list, no shell) and raise if it fails.

    Stdout and stderr are captured. On a non-zero exit they are decoded and
    included in both the logged error and the raised exception.

    Raises:
        ValueError: if the executable cannot be found (chained to the original
            FileNotFoundError), or if the command exits with a non-zero code.
    """
    try:
        result = subprocess.run(command, check=False, capture_output=True)
    except FileNotFoundError as ex:
        raise ValueError(f"command failure: {ex}") from ex
    if result.returncode != 0:
        # NOTE decode so that the message contains readable text rather than b'...' reprs
        stderr = result.stderr.decode(errors="replace")
        stdout = result.stdout.decode(errors="replace")
        msg = f"command failed with {result.returncode}. Stderr: {stderr}, Stdout: {stdout}, Args: {result.args}"
        logger.error(msg)
        raise ValueError(msg)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def new_venv() -> tempfile.TemporaryDirectory:
    """Create a venv in a new temporary directory and make it importable.

    1. Creates a new temporary directory with a venv inside.
    2. Prepends the venv's site-packages to sys.path, so that packages installed
       there take precedence over already available ones.

    The caller owns the returned directory. It must call ``cleanup()`` on it and
    remove the added sys.path entries; this function does neither.

    Raises:
        ValueError: if the venv cannot be created. The temporary directory is
            removed before the error propagates.
    """
    logger.debug("creating a new venv")
    td = tempfile.TemporaryDirectory(prefix="cascade_runner_venv_")
    # NOTE we create a venv instead of just plain directory, because some of the packages create files
    # outside of site-packages. Thus we then install with --prefix, not with --target
    try:
        run_command(Commands.venv_command(td.name))
    except BaseException:
        # NOTE don't leave an orphaned temporary directory behind
        td.cleanup()
        raise

    # NOTE not sure if getsitepackages was intended for this -- if issues, attempt replacing
    # with something like f"{td.name}/lib/python*/site-packages" + globbing
    # NOTE(review): the site-packages path is derived from the *running* interpreter's version;
    # confirm `uv venv` picks the same python, otherwise the paths won't match
    extra_sp = site.getsitepackages(prefixes=[td.name])
    # NOTE this makes the explicit packages go first, in case of a different version
    logger.debug(f"extending sys.path with {extra_sp}")
    sys.path = extra_sp + sys.path
    logger.debug(f"new sys.path: {sys.path}")

    return td
|
|
71
|
+
|
|
72
|
+
|
|
27
73
|
class PackagesEnv(AbstractContextManager):
|
|
28
74
|
def __init__(self) -> None:
|
|
29
75
|
self.td: tempfile.TemporaryDirectory | None = None
|
|
@@ -32,38 +78,32 @@ class PackagesEnv(AbstractContextManager):
|
|
|
32
78
|
if not packages:
|
|
33
79
|
return
|
|
34
80
|
if self.td is None:
|
|
35
|
-
|
|
36
|
-
self.td = tempfile.TemporaryDirectory()
|
|
37
|
-
venv_command = ["uv", "venv", self.td.name]
|
|
38
|
-
# NOTE we create a venv instead of just plain directory, because some of the packages create files
|
|
39
|
-
# outside of site-packages. Thus we then install with --prefix, not with --target
|
|
40
|
-
subprocess.run(venv_command, check=True)
|
|
81
|
+
self.td = new_venv()
|
|
41
82
|
|
|
42
83
|
logger.debug(
|
|
43
84
|
f"installing {len(packages)} packages: {','.join(packages[:3])}{',...' if len(packages) > 3 else ''}"
|
|
44
85
|
)
|
|
45
|
-
install_command =
|
|
46
|
-
"uv",
|
|
47
|
-
"pip",
|
|
48
|
-
"install",
|
|
49
|
-
"--prefix",
|
|
50
|
-
self.td.name,
|
|
51
|
-
"--prerelease",
|
|
52
|
-
"allow",
|
|
53
|
-
]
|
|
86
|
+
install_command = Commands.install_command(self.td.name)
|
|
54
87
|
if os.environ.get("VENV_OFFLINE", "") == "YES":
|
|
55
88
|
install_command += ["--offline"]
|
|
56
89
|
if cache_dir := os.environ.get("VENV_CACHE", ""):
|
|
57
90
|
install_command += ["--cache-dir", cache_dir]
|
|
58
91
|
install_command.extend(set(packages))
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
#
|
|
64
|
-
|
|
92
|
+
logger.debug(f"running install command: {' '.join(install_command)}")
|
|
93
|
+
run_command(install_command)
|
|
94
|
+
|
|
95
|
+
# NOTE we need this due to namespace packages:
|
|
96
|
+
# 1. task 1 installs ns.pkg1 in its venv
|
|
97
|
+
# 2. task 1 finishes, task 2 starts on the same worker
|
|
98
|
+
# 3. task 2 starts, installs ns.pkg2. However, importlib is in a state that ns is aware only of pkg1 submod
|
|
99
|
+
# Additionally, the caches are invalid anyway, because task 1's venv is already deleted
|
|
100
|
+
importlib.invalidate_caches()
|
|
101
|
+
# TODO some namespace packages may require a reimport because they dynamically build `__all__` -- eg earthkit
|
|
65
102
|
|
|
66
103
|
def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
|
|
104
|
+
sys.path = [
|
|
105
|
+
p for p in sys.path if self.td is None or not p.startswith(self.td.name)
|
|
106
|
+
]
|
|
67
107
|
if self.td is not None:
|
|
68
108
|
self.td.cleanup()
|
|
69
109
|
return False
|
cascade/low/builders.py
CHANGED
|
@@ -6,10 +6,11 @@
|
|
|
6
6
|
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
|
|
9
|
+
import importlib
|
|
9
10
|
import inspect
|
|
10
11
|
import itertools
|
|
11
12
|
from dataclasses import dataclass, field, replace
|
|
12
|
-
from typing import Callable, Iterable, Iterator, Type, cast
|
|
13
|
+
from typing import Any, Callable, Iterable, Iterator, Type, cast
|
|
13
14
|
|
|
14
15
|
import pyrsistent
|
|
15
16
|
from typing_extensions import Self
|
|
@@ -29,8 +30,17 @@ class TaskBuilder(TaskInstance):
|
|
|
29
30
|
@classmethod
|
|
30
31
|
def from_callable(cls, f: Callable, environment: list[str] | None = None) -> Self:
|
|
31
32
|
def type2str(t: str | Type) -> str:
|
|
32
|
-
type_name: str
|
|
33
|
-
|
|
33
|
+
type_name: str
|
|
34
|
+
if isinstance(t, str):
|
|
35
|
+
type_name = t
|
|
36
|
+
elif isinstance(t, tuple):
|
|
37
|
+
# TODO properly break down etc
|
|
38
|
+
type_name = "tuple"
|
|
39
|
+
elif t.__module__ == "builtins":
|
|
40
|
+
type_name = t.__name__
|
|
41
|
+
else:
|
|
42
|
+
type_name = f"{t.__module__}.{t.__name__}"
|
|
43
|
+
return "Any" if type_name in ("_empty", "inspect._empty") else type_name
|
|
34
44
|
|
|
35
45
|
sig = inspect.signature(f)
|
|
36
46
|
input_schema = {
|
|
@@ -91,10 +101,14 @@ class TaskBuilder(TaskInstance):
|
|
|
91
101
|
class JobBuilder:
|
|
92
102
|
nodes: pyrsistent.PMap = field(default_factory=lambda: pyrsistent.m())
|
|
93
103
|
edges: pyrsistent.PVector = field(default_factory=lambda: pyrsistent.v())
|
|
104
|
+
outputs: pyrsistent.PVector = field(default_factory=lambda: pyrsistent.v())
|
|
94
105
|
|
|
95
106
|
def with_node(self, name: str, task: TaskInstance) -> Self:
|
|
96
107
|
return replace(self, nodes=self.nodes.set(name, task))
|
|
97
108
|
|
|
109
|
+
def with_output(self, task: str, output: str = Node.DEFAULT_OUTPUT) -> Self:
|
|
110
|
+
return replace(self, outputs=self.outputs.append(DatasetId(task, output)))
|
|
111
|
+
|
|
98
112
|
def with_edge(
|
|
99
113
|
self, source: str, sink: str, into: str | int, frum: str = Node.DEFAULT_OUTPUT
|
|
100
114
|
) -> Self:
|
|
@@ -116,7 +130,24 @@ class JobBuilder:
|
|
|
116
130
|
"marsParamList",
|
|
117
131
|
"grib",
|
|
118
132
|
}
|
|
119
|
-
|
|
133
|
+
|
|
134
|
+
def getType(
|
|
135
|
+
fqn: str,
|
|
136
|
+
) -> (
|
|
137
|
+
Any
|
|
138
|
+
): # NOTE: typing.Type return type is tempting but not true for builtin aliases
|
|
139
|
+
if fqn.startswith("tuple"):
|
|
140
|
+
# TODO recursive parsing of tuples etc!
|
|
141
|
+
return tuple
|
|
142
|
+
if "." in fqn:
|
|
143
|
+
mpath, name = fqn.rsplit(".", 1)
|
|
144
|
+
return getattr(importlib.import_module(mpath), name)
|
|
145
|
+
else:
|
|
146
|
+
return eval(fqn)
|
|
147
|
+
|
|
148
|
+
_isinstance = (
|
|
149
|
+
lambda v, t: t == "Any" or t in skipped or isinstance(v, getType(t))
|
|
150
|
+
)
|
|
120
151
|
|
|
121
152
|
# static input types
|
|
122
153
|
static_kw_errors: Iterable[str] = (
|
|
@@ -157,7 +188,9 @@ class JobBuilder:
|
|
|
157
188
|
lambda t1, t2: t2 == "Any"
|
|
158
189
|
or t1 == t2
|
|
159
190
|
or (t1, t2) in legits
|
|
160
|
-
or
|
|
191
|
+
or t1
|
|
192
|
+
== "typing.Iterator" # TODO replace with type extraction *and* check that this is multi-output
|
|
193
|
+
or issubclass(getType(t1), getType(t2))
|
|
161
194
|
)
|
|
162
195
|
if not _issubclass(output_param, input_param):
|
|
163
196
|
yield f"edge connects two incompatible nodes: {edge}"
|
|
@@ -169,6 +202,9 @@ class JobBuilder:
|
|
|
169
202
|
# all inputs present
|
|
170
203
|
# TODO
|
|
171
204
|
|
|
205
|
+
# all outputs created
|
|
206
|
+
# TODO
|
|
207
|
+
|
|
172
208
|
errors = list(itertools.chain(static_kw_errors, edge_errors))
|
|
173
209
|
if errors:
|
|
174
210
|
return Either.error(errors)
|
|
@@ -177,5 +213,6 @@ class JobBuilder:
|
|
|
177
213
|
JobInstance(
|
|
178
214
|
tasks=cast(dict[str, TaskInstance], pyrsistent.thaw(self.nodes)),
|
|
179
215
|
edges=pyrsistent.thaw(self.edges),
|
|
216
|
+
ext_outputs=pyrsistent.thaw(self.outputs),
|
|
180
217
|
)
|
|
181
218
|
)
|
cascade/shm/api.py
CHANGED
|
@@ -6,13 +6,17 @@
|
|
|
6
6
|
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
|
|
9
|
+
import logging
|
|
9
10
|
import os
|
|
11
|
+
import socket
|
|
10
12
|
from dataclasses import dataclass
|
|
11
13
|
from enum import Enum, auto
|
|
12
14
|
from typing import Protocol, Type, runtime_checkable
|
|
13
15
|
|
|
14
16
|
from typing_extensions import Self
|
|
15
17
|
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
16
20
|
# TODO too much manual serde... either automate it based on dataclass field inspection, or just pickle it
|
|
17
21
|
# (mind the server.recv/client.recv comment tho)
|
|
18
22
|
# Also, consider switching from GetRequest, PurgeRequest, to DatasetRequest(get|purge|...)
|
|
@@ -244,15 +248,52 @@ def deser(data: bytes) -> Comm:
|
|
|
244
248
|
return b2c[data[:1]].deser(data[1:])
|
|
245
249
|
|
|
246
250
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
def
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
251
|
+
client_socket_envvar = "CASCADE_SHM_SOCKET"
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def publish_socket_addr(sock: int | str) -> None:
|
|
255
|
+
if isinstance(sock, int):
|
|
256
|
+
ssock = f"port:{sock}"
|
|
257
|
+
else:
|
|
258
|
+
ssock = f"file:{sock}"
|
|
259
|
+
os.environ[client_socket_envvar] = ssock
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def get_socket_addr() -> tuple[socket.socket, int | str]:
    """Create an unconnected socket for the address published in ``client_socket_envvar``.

    Returns:
        A ``(sock, addr)`` pair, in one of two forms:

        - ``addr`` is an int port, and ``sock`` is an AF_INET datagram socket;
        - ``addr`` is a str unix socket path, and ``sock`` is an AF_UNIX stream socket.

    Raises:
        ValueError: if the env var is unset or empty, if it lacks the ``kind:``
            prefix, or if a port is not an integer.
        NotImplementedError: for an unknown address kind.
    """
    ssock = os.getenv(client_socket_envvar)
    if not ssock:
        raise ValueError(f"missing sock addr in {client_socket_envvar}")
    kind, sep, raw_addr = ssock.partition(":")
    if not sep:
        raise ValueError(f"malformed sock addr in {client_socket_envvar}: {ssock}")
    addr: int | str
    if kind == "port":
        addr = int(raw_addr)
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    elif kind == "file":
        # TODO can we support SOCK_DGRAM too? Problem with response address
        addr = raw_addr
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    else:
        raise NotImplementedError(kind)
    return sock, addr
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def get_client_socket():
    """Return a socket connected to the shm server published in the environment.

    Ports are connected on ``localhost``; unix socket paths are connected directly.
    If connecting fails (e.g. the server is not up yet), the socket is closed
    before the error propagates. This avoids leaking a descriptor on every retry.
    """
    sock, addr = get_socket_addr()
    target = ("localhost", addr) if isinstance(addr, int) else addr
    try:
        sock.connect(target)
    except BaseException:
        sock.close()
        raise
    return sock
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_server_socket():
    """Create the shm server socket, bound to the address from ``get_socket_addr``.

    The setup depends on the address kind:

    - For a port, the datagram socket is bound on all interfaces.
    - For a unix socket path, any leftover file at that path is removed first.
      The stream socket is then bound and put into listening mode.

    If any of this fails, the socket is closed before the error propagates.
    """
    sock, addr = get_socket_addr()
    try:
        if isinstance(addr, int):
            sock.bind(("0.0.0.0", addr))
        else:
            try:
                os.unlink(addr)
                logger.warning(f"unlinking at {addr}")
            except FileNotFoundError:
                pass
            sock.bind(addr)
            # NOTE only stream sockets can listen; the datagram (port) variant cannot
            sock.listen(32)
    except BaseException:
        sock.close()
        raise
    return sock
|
cascade/shm/client.py
CHANGED
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
import logging
|
|
10
10
|
import multiprocessing.resource_tracker
|
|
11
|
-
import
|
|
11
|
+
import sys
|
|
12
12
|
import time
|
|
13
13
|
from multiprocessing.shared_memory import SharedMemory
|
|
14
14
|
from typing import Callable, Type, TypeVar
|
|
@@ -17,14 +17,19 @@ import cascade.shm.api as api
|
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
19
19
|
|
|
20
|
-
# TODO eleminate in favour of track=False, once we are on python 3.13+
|
|
21
|
-
is_unregister = True # NOTE exposed for pytest control
|
|
22
|
-
|
|
23
20
|
|
|
24
21
|
class ConflictError(Exception):
|
|
25
22
|
pass
|
|
26
23
|
|
|
27
24
|
|
|
25
|
+
# NOTE before python 3.13, SharedMemory has no `track` argument, so each opened
# segment gets unregistered from the resource tracker by hand (is_unregister=True).
# From 3.13 on, tracking is disabled at construction via shm_kwargs instead.
is_unregister = sys.version_info < (3, 13)
shm_kwargs = {} if is_unregister else {"track": False}
|
|
31
|
+
|
|
32
|
+
|
|
28
33
|
class AllocatedBuffer:
|
|
29
34
|
def __init__(
|
|
30
35
|
self,
|
|
@@ -34,8 +39,9 @@ class AllocatedBuffer:
|
|
|
34
39
|
close_callback: Callable[[], None] | None,
|
|
35
40
|
deser_fun: str,
|
|
36
41
|
):
|
|
42
|
+
self.shm: SharedMemory | None
|
|
37
43
|
try:
|
|
38
|
-
self.shm
|
|
44
|
+
self.shm = SharedMemory(shmid, create=create, size=l, **shm_kwargs)
|
|
39
45
|
except FileExistsError:
|
|
40
46
|
# NOTE this is quite wrong as instead of crashing, it would lead to undefined behaviour
|
|
41
47
|
# However, as the systems we operate on don't seem to be reliable wrt cleanup/isolation,
|
|
@@ -43,13 +49,16 @@ class AllocatedBuffer:
|
|
|
43
49
|
logger.error(
|
|
44
50
|
f"attempted opening {shmid=} but gotten FileExists. Will delete and retry"
|
|
45
51
|
)
|
|
46
|
-
_shm = SharedMemory(shmid, create=False)
|
|
52
|
+
_shm = SharedMemory(shmid, create=False, **shm_kwargs)
|
|
53
|
+
_shm.close()
|
|
47
54
|
_shm.unlink()
|
|
48
|
-
self.shm
|
|
55
|
+
self.shm = SharedMemory(shmid, create=create, size=l, **shm_kwargs)
|
|
49
56
|
self.l = l
|
|
50
57
|
self.readonly = not create
|
|
51
58
|
self.close_callback = close_callback
|
|
52
59
|
self.deser_fun = deser_fun
|
|
60
|
+
if is_unregister:
|
|
61
|
+
multiprocessing.resource_tracker.unregister(self.shm._name, "shared_memory") # type: ignore # _name
|
|
53
62
|
|
|
54
63
|
def view(self) -> memoryview:
|
|
55
64
|
if not self.shm:
|
|
@@ -60,14 +69,19 @@ class AllocatedBuffer:
|
|
|
60
69
|
return mv
|
|
61
70
|
|
|
62
71
|
def close(self) -> None:
|
|
63
|
-
if self.shm is not None:
|
|
72
|
+
if hasattr(self, "shm") and self.shm is not None:
|
|
64
73
|
self.shm.close()
|
|
65
|
-
if is_unregister:
|
|
66
|
-
multiprocessing.resource_tracker.unregister(self.shm._name, "shared_memory") # type: ignore # _name
|
|
67
74
|
if self.close_callback:
|
|
68
75
|
self.close_callback()
|
|
69
76
|
self.shm = None
|
|
70
77
|
|
|
78
|
+
def __del__(self) -> None:
|
|
79
|
+
if hasattr(self, "shm") and self.shm is not None:
|
|
80
|
+
try:
|
|
81
|
+
logger.error(f"missed close() call on {self.shm._name}") # type: ignore # _name
|
|
82
|
+
except Exception as e:
|
|
83
|
+
logger.exception(f"failed to log due to {repr(e)}")
|
|
84
|
+
|
|
71
85
|
# TODO context manager
|
|
72
86
|
|
|
73
87
|
|
|
@@ -80,11 +94,10 @@ def _send_command(comm: api.Comm, resp_class: Type[T], timeout_sec: float = 60.0
|
|
|
80
94
|
# timeout_i and coeff determine rate of busy-waits: coeff=1 is additive, =2 is exponential
|
|
81
95
|
# eventually this busy-waits will go away as we switch to event driven behaviour
|
|
82
96
|
while timeout_sec > 0:
|
|
83
|
-
sock =
|
|
84
|
-
client_port = api.get_client_port()
|
|
85
|
-
sock.connect(("localhost", client_port))
|
|
97
|
+
sock = api.get_client_socket()
|
|
86
98
|
logger.debug(f"sending message {comm}")
|
|
87
99
|
sock.send(api.ser(comm))
|
|
100
|
+
# TODO rewrite to poller with timeout
|
|
88
101
|
response_raw = sock.recv(1024) # TODO or recv(4) + recv(int.from_bytes)?
|
|
89
102
|
sock.close()
|
|
90
103
|
response_com = api.deser(response_raw)
|
|
@@ -148,9 +161,9 @@ def status(key: str) -> api.DatasetStatus:
|
|
|
148
161
|
return response.status
|
|
149
162
|
|
|
150
163
|
|
|
151
|
-
def shutdown() -> None:
|
|
164
|
+
def shutdown(timeout_sec: float = 2.0) -> None:
|
|
152
165
|
comm = api.ShutdownCommand()
|
|
153
|
-
_send_command(comm, api.OkResponse)
|
|
166
|
+
_send_command(comm, api.OkResponse, timeout_sec)
|
|
154
167
|
|
|
155
168
|
|
|
156
169
|
def ensure() -> None:
|
|
@@ -161,7 +174,7 @@ def ensure() -> None:
|
|
|
161
174
|
try:
|
|
162
175
|
_send_command(comm, api.OkResponse)
|
|
163
176
|
logger.debug("shm server responds ok, leaving ensure loop")
|
|
164
|
-
except ConnectionRefusedError:
|
|
177
|
+
except (ConnectionRefusedError, FileNotFoundError):
|
|
165
178
|
time.sleep(0.1)
|
|
166
179
|
continue
|
|
167
180
|
break
|
cascade/shm/dataset.py
CHANGED
|
@@ -277,7 +277,21 @@ class Manager:
|
|
|
277
277
|
return ds.shmid, ds.size, rdid, ds.deser_fun, ""
|
|
278
278
|
|
|
279
279
|
def purge(self, key: str, is_exit: bool = False) -> None:
|
|
280
|
-
|
|
280
|
+
if key not in self.datasets:
|
|
281
|
+
# NOTE known cases for this appearing:
|
|
282
|
+
# - a1/ a TaskSequence([t1, t2], publish={t2}) is sent by controller
|
|
283
|
+
# - a2/ t2 is computed and published by worker
|
|
284
|
+
# - a3/ controller sees t1 not required anymore => sends purge
|
|
285
|
+
# - a4/ as t1 was never materialized in shm => keyerror here
|
|
286
|
+
# - a/ this would be fixed by expaned dataset2host/worker model at the controller
|
|
287
|
+
# - b1/ purge request is sent by the data server
|
|
288
|
+
# - b2/ it is received by shm and purged, but ack fails on zmq
|
|
289
|
+
# - b3/ data server retries, but now the key is unknown
|
|
290
|
+
# - b/ this would be fixed by keeping track of tombstones
|
|
291
|
+
logger.warning(
|
|
292
|
+
f"key unknown to this shm instance: {key}, ignoring purge request"
|
|
293
|
+
)
|
|
294
|
+
return
|
|
281
295
|
try:
|
|
282
296
|
logger.debug(f"attempting purge-inquire of {key}")
|
|
283
297
|
try:
|
cascade/shm/server.py
CHANGED
|
@@ -10,7 +10,7 @@ import logging
|
|
|
10
10
|
import logging.config
|
|
11
11
|
import signal
|
|
12
12
|
import socket
|
|
13
|
-
from typing import Any
|
|
13
|
+
from typing import Any, cast
|
|
14
14
|
|
|
15
15
|
import cascade.shm.api as api
|
|
16
16
|
import cascade.shm.dataset as dataset
|
|
@@ -21,20 +21,35 @@ logger = logging.getLogger(__name__)
|
|
|
21
21
|
class LocalServer:
|
|
22
22
|
"""Handles the socket communication, and the invocation of dataset.Manager which has the business logic"""
|
|
23
23
|
|
|
24
|
-
def __init__(self,
|
|
25
|
-
self.sock =
|
|
26
|
-
|
|
27
|
-
|
|
24
|
+
def __init__(self, shm_pref: str, capacity: int | None = None):
|
|
25
|
+
self.sock = api.get_server_socket()
|
|
26
|
+
logger.info(
|
|
27
|
+
f"shm server starting with {self.sock=} with {capacity=} and prefix {shm_pref}"
|
|
28
|
+
)
|
|
28
29
|
self.manager = dataset.Manager(shm_pref, capacity)
|
|
29
30
|
signal.signal(signal.SIGINT, self.atexit)
|
|
30
31
|
signal.signal(signal.SIGTERM, self.atexit)
|
|
31
32
|
|
|
32
|
-
def receive(self) -> tuple[api.Comm, str]:
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
def receive(self) -> tuple[api.Comm, str | socket.socket]:
|
|
34
|
+
# TODO recv(1024) or recv(4) + recv(int.from_bytes)?
|
|
35
|
+
if self.sock.type == socket.SOCK_DGRAM:
|
|
36
|
+
b, resp = self.sock.recvfrom(1024)
|
|
37
|
+
elif self.sock.type == socket.SOCK_STREAM:
|
|
38
|
+
resp, _addr = self.sock.accept()
|
|
39
|
+
b = resp.recv(1024)
|
|
40
|
+
else:
|
|
41
|
+
raise NotImplementedError(self.sock.type)
|
|
42
|
+
return api.deser(b), resp
|
|
35
43
|
|
|
36
|
-
def respond(self, comm: api.Comm, address: str) -> None:
|
|
37
|
-
|
|
44
|
+
def respond(self, comm: api.Comm, address: str | socket.socket) -> None:
|
|
45
|
+
m = api.ser(comm)
|
|
46
|
+
if self.sock.type == socket.SOCK_DGRAM:
|
|
47
|
+
self.sock.sendto(m, address)
|
|
48
|
+
elif self.sock.type == socket.SOCK_STREAM:
|
|
49
|
+
logger.debug(f"will send to {address} message {m}")
|
|
50
|
+
cast(socket.socket, address).send(m)
|
|
51
|
+
else:
|
|
52
|
+
raise NotImplementedError(self.sock.type)
|
|
38
53
|
|
|
39
54
|
def atexit(self, signum: int, frame: Any) -> None:
|
|
40
55
|
self.manager.atexit()
|
|
@@ -43,6 +58,7 @@ class LocalServer:
|
|
|
43
58
|
def start(self):
|
|
44
59
|
while True:
|
|
45
60
|
payload, client = self.receive()
|
|
61
|
+
logger.debug(f"gotten {payload=}")
|
|
46
62
|
try:
|
|
47
63
|
if isinstance(payload, api.AllocateRequest):
|
|
48
64
|
shmid, error = self.manager.add(
|
|
@@ -96,21 +112,18 @@ class LocalServer:
|
|
|
96
112
|
except Exception as e:
|
|
97
113
|
logger.exception(f"failure during handling of {payload}")
|
|
98
114
|
response = api.OkResponse(error=repr(e))
|
|
115
|
+
logger.debug(f"sending {response=} to {client}")
|
|
99
116
|
self.respond(response, client)
|
|
100
117
|
|
|
101
118
|
|
|
102
119
|
def entrypoint(
|
|
103
|
-
port: int,
|
|
104
120
|
capacity: int | None = None,
|
|
105
121
|
logging_config: dict | None = None,
|
|
106
122
|
shm_pref: str = "shm",
|
|
107
123
|
):
|
|
108
124
|
if logging_config:
|
|
109
125
|
logging.config.dictConfig(logging_config)
|
|
110
|
-
server = LocalServer(
|
|
111
|
-
logger.info(
|
|
112
|
-
f"shm server starting on {port=} with {capacity=} and prefix {shm_pref}"
|
|
113
|
-
)
|
|
126
|
+
server = LocalServer(shm_pref, capacity)
|
|
114
127
|
try:
|
|
115
128
|
server.start()
|
|
116
129
|
except Exception as e:
|
earthkit/workflows/_version.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
# Do not change! Do not track in version control!
|
|
2
|
-
__version__ = "0.
|
|
2
|
+
__version__ = "0.5.1"
|
earthkit/workflows/decorators.py
CHANGED
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
|
|
9
9
|
from functools import wraps
|
|
10
|
-
from typing import Callable, ParamSpec, TypeVar
|
|
10
|
+
from typing import Any, Callable, Concatenate, ParamSpec, ParamSpecArgs, TypeVar
|
|
11
11
|
|
|
12
12
|
from .fluent import Payload
|
|
13
13
|
|
|
@@ -15,16 +15,30 @@ P = ParamSpec("P")
|
|
|
15
15
|
R = TypeVar("R")
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
def as_payload(func: Callable[P, R])
|
|
19
|
-
"""Wrap a function and return a
|
|
18
|
+
def as_payload(func: Callable[Concatenate[ParamSpecArgs, P], R]):
|
|
19
|
+
"""Wrap a function and return a Payload object.
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
Forces the function to be called with keyword arguments only, with args being passed
|
|
22
|
+
once the payload is executed from earlier Nodes.
|
|
23
|
+
|
|
24
|
+
Set `metadata` to pass metadata to the payload.
|
|
25
|
+
|
|
26
|
+
Examples
|
|
27
|
+
--------
|
|
28
|
+
```python
|
|
29
|
+
@as_payload
|
|
30
|
+
def my_function(a, b, *, keyword):
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
my_function(1, 2, keyword='test') # Raises an error
|
|
34
|
+
my_function(b=2, keyword='test') # OK, a will be passed from earlier nodes
|
|
35
|
+
my_function(keyword='test') # OK, a and b will be passed from earlier nodes
|
|
36
|
+
|
|
37
|
+
```
|
|
22
38
|
"""
|
|
23
|
-
from .fluent import Payload
|
|
24
39
|
|
|
25
|
-
@wraps(func)
|
|
26
|
-
def decorator(
|
|
27
|
-
|
|
28
|
-
return Payload(func, args, kwargs, metadata=metadata)
|
|
40
|
+
@wraps(func, assigned=["__name__", "__doc__"])
|
|
41
|
+
def decorator(*, metadata: dict[str, Any] | None = None, **kwargs) -> Payload:
|
|
42
|
+
return Payload(func, args=None, kwargs=kwargs, metadata=metadata)
|
|
29
43
|
|
|
30
44
|
return decorator
|
earthkit/workflows/fluent.py
CHANGED
|
@@ -10,7 +10,16 @@ from __future__ import annotations
|
|
|
10
10
|
|
|
11
11
|
import functools
|
|
12
12
|
import hashlib
|
|
13
|
-
from typing import
|
|
13
|
+
from typing import (
|
|
14
|
+
Any,
|
|
15
|
+
Callable,
|
|
16
|
+
Hashable,
|
|
17
|
+
Iterable,
|
|
18
|
+
Optional,
|
|
19
|
+
ParamSpec,
|
|
20
|
+
Sequence,
|
|
21
|
+
TypeVar,
|
|
22
|
+
)
|
|
14
23
|
|
|
15
24
|
import numpy as np
|
|
16
25
|
import xarray as xr
|
|
@@ -20,8 +29,6 @@ from .graph import Graph
|
|
|
20
29
|
from .graph import Node as BaseNode
|
|
21
30
|
from .graph import Output
|
|
22
31
|
|
|
23
|
-
ActionType = TypeVar("ActionType", bound="Action")
|
|
24
|
-
|
|
25
32
|
|
|
26
33
|
class Payload:
|
|
27
34
|
"""Class for detailing function, args and kwargs to be computing in a graph node"""
|
|
@@ -68,7 +75,7 @@ class Payload:
|
|
|
68
75
|
def __str__(self) -> str:
|
|
69
76
|
return f"{self.name()}{self.args}{self.kwargs}:{self.metadata}"
|
|
70
77
|
|
|
71
|
-
def __eq__(self, other
|
|
78
|
+
def __eq__(self, other) -> bool:
|
|
72
79
|
if not isinstance(other, Payload):
|
|
73
80
|
return False
|
|
74
81
|
return str(self) == str(other)
|
|
@@ -86,15 +93,16 @@ def custom_hash(string: str) -> str:
|
|
|
86
93
|
Coord = tuple[str, list[Any]]
|
|
87
94
|
Input = BaseNode | Output
|
|
88
95
|
|
|
89
|
-
|
|
96
|
+
P = ParamSpec("P")
|
|
97
|
+
R = TypeVar("R")
|
|
90
98
|
|
|
91
99
|
|
|
92
|
-
def capture_payload_metadata(func:
|
|
100
|
+
def capture_payload_metadata(func: Callable[P, R]) -> Callable[P, R]:
|
|
93
101
|
"""Wrap a function which returns a new action and insert
|
|
94
102
|
given `payload_metadata`
|
|
95
103
|
"""
|
|
96
104
|
|
|
97
|
-
@functools.wraps(func)
|
|
105
|
+
# @functools.wraps(func)
|
|
98
106
|
def decorator(*args, **kwargs):
|
|
99
107
|
metadata = kwargs.pop("payload_metadata", {})
|
|
100
108
|
result = func(*args, **kwargs)
|
|
@@ -162,7 +170,7 @@ class Node(BaseNode):
|
|
|
162
170
|
|
|
163
171
|
class Action:
|
|
164
172
|
|
|
165
|
-
REGISTRY: dict[str, Action] = {}
|
|
173
|
+
REGISTRY: dict[str, type[Action]] = {}
|
|
166
174
|
|
|
167
175
|
def __init__(self, nodes: xr.DataArray, yields: Optional[Coord] = None):
|
|
168
176
|
if yields:
|
|
@@ -224,24 +232,17 @@ class Action:
|
|
|
224
232
|
f"Action class {obj} already has an attribute {name}, will not override"
|
|
225
233
|
)
|
|
226
234
|
|
|
227
|
-
cls.REGISTRY[name] = obj
|
|
235
|
+
cls.REGISTRY[name] = obj
|
|
228
236
|
|
|
229
237
|
@classmethod
|
|
230
238
|
def flush_registry(cls):
|
|
231
239
|
"""Flush the registry of all registered actions"""
|
|
232
240
|
cls.REGISTRY = {}
|
|
233
241
|
|
|
234
|
-
def as_action(self, other
|
|
242
|
+
def as_action(self, other) -> Action:
|
|
235
243
|
"""Parse action into another action class"""
|
|
236
244
|
return other(self.nodes)
|
|
237
245
|
|
|
238
|
-
def __getattr__(self, attr):
|
|
239
|
-
if attr in Action.REGISTRY:
|
|
240
|
-
return RegisteredAction(
|
|
241
|
-
attr, Action.REGISTRY[attr], self
|
|
242
|
-
) # When the attr is a registered action class
|
|
243
|
-
raise AttributeError(f"{self.__class__.__name__} has no attribute {attr!r}")
|
|
244
|
-
|
|
245
246
|
def join(
|
|
246
247
|
self,
|
|
247
248
|
other_action: "Action",
|
|
@@ -803,6 +804,13 @@ class Action:
|
|
|
803
804
|
if dim_name in self.nodes.coords and len(self.nodes.coords[dim_name]) == 1:
|
|
804
805
|
self.nodes = self.nodes.squeeze(dim_name, drop=drop)
|
|
805
806
|
|
|
807
|
+
def __getattr__(self, attr):
|
|
808
|
+
if attr in Action.REGISTRY:
|
|
809
|
+
return RegisteredAction(
|
|
810
|
+
attr, Action.REGISTRY[attr], self
|
|
811
|
+
) # When the attr is a registered action class
|
|
812
|
+
raise AttributeError(f"{self.__class__.__name__} has no attribute {attr!r}")
|
|
813
|
+
|
|
806
814
|
|
|
807
815
|
class RegisteredAction:
|
|
808
816
|
"""Wrapper around registered actions"""
|
earthkit/workflows/mark.py
CHANGED
|
@@ -15,10 +15,9 @@ def add_execution_metadata(**kwargs) -> Callable[[F], F]:
|
|
|
15
15
|
"""Add execution metadata to a function."""
|
|
16
16
|
|
|
17
17
|
def decorator(func: F) -> F:
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
func._cascade.update(kwargs)
|
|
18
|
+
kw = getattr(func, "_cascade", {}).copy()
|
|
19
|
+
kw.update(kwargs)
|
|
20
|
+
setattr(func, "_cascade", kw)
|
|
22
21
|
return func
|
|
23
22
|
|
|
24
23
|
return decorator
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: earthkit-workflows
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.1
|
|
4
4
|
Summary: Earthkit Workflows is a Python library for declaring earthkit task DAGs, as well as scheduling and executing them on heterogeneous computing systems.
|
|
5
5
|
Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -10,7 +10,8 @@ cascade/benchmarks/job1.py,sha256=MOcZZYgf36MzHCjtby0lQyenM1ODUlagG8wtt2CbpnI,46
|
|
|
10
10
|
cascade/benchmarks/matmul.py,sha256=5STuvPY6Q37E2pKRCde9dQjL5M6tx7tkES9cBLZ6eK4,1972
|
|
11
11
|
cascade/benchmarks/plotting.py,sha256=vSz9HHbqZwMXHpBUS-In6xsXGgK7QIoQTTiYfSwYwZs,4428
|
|
12
12
|
cascade/benchmarks/reporting.py,sha256=MejaM-eekbMYLAnuBxGv_t4dR1ODJs4Rpc0fiZSGjyw,5410
|
|
13
|
-
cascade/benchmarks/
|
|
13
|
+
cascade/benchmarks/tests.py,sha256=eeQE0YR4FKi5k9BMJaTcXjKF5eIu3xXJsHc099P0Jio,5537
|
|
14
|
+
cascade/benchmarks/util.py,sha256=wP7lDI6v9ATIF96uagVB-23EiagCTVYJhUUy-_CfqQ8,9892
|
|
14
15
|
cascade/controller/__init__.py,sha256=p4C2p3S_0nUGamP9Mi6cSa5bvpiWbI6sVWtGhFnNqjw,1278
|
|
15
16
|
cascade/controller/act.py,sha256=WHIsk4H-Bbyl_DABX2VWhyKy_cNnp12x1nilatPCL8I,2981
|
|
16
17
|
cascade/controller/core.py,sha256=NqvZ5g5GNphwOpzdXbCI0_fxIzzmO97_n2xZKswK72Q,3589
|
|
@@ -20,15 +21,15 @@ cascade/controller/report.py,sha256=rKGYmq4nIiDqKuP_C7YSwpEAUOPdjILlDcbKkdUt30s,
|
|
|
20
21
|
cascade/executor/bridge.py,sha256=WDE-GM2Bv7nUk1-nV-otMGuaRYw1-Vmd7PWploXBp6Y,8267
|
|
21
22
|
cascade/executor/comms.py,sha256=-9qrKwva6WXkHRQtzSnLFy5gB3bOWuxYJP5fL6Uavw8,8736
|
|
22
23
|
cascade/executor/config.py,sha256=8azy_sXdvDGO0zTNqA0pdtkXsyihM4FQ4U1W_3Dhua0,1571
|
|
23
|
-
cascade/executor/data_server.py,sha256=
|
|
24
|
-
cascade/executor/executor.py,sha256=
|
|
24
|
+
cascade/executor/data_server.py,sha256=TSFJdSR9PtKSvvLTosHt0ITQlqtGGAl5N_io6wtvL0A,13569
|
|
25
|
+
cascade/executor/executor.py,sha256=OwLrhSLm4bIHsWdnjXlnQntxGIOHgrIPSSVZ5nbNWvQ,13686
|
|
25
26
|
cascade/executor/msg.py,sha256=7HI0rKeCRaV1ONR4HWEa64nHbu-p6-QdBwJNitmst48,4340
|
|
26
27
|
cascade/executor/platform.py,sha256=mRUauodvRle9rAbtFr5n9toKzIgt_pecNlhOjon4dvY,2348
|
|
27
28
|
cascade/executor/serde.py,sha256=z6klTOZqW_BVGrbIRNz4FN0_XTfRiKBRQuvgsQIuyAo,2827
|
|
28
29
|
cascade/executor/runner/__init__.py,sha256=30BM80ZyA7w3IrGiKKLSFuhRehbR2Mm99OJ8q5PJ63c,1547
|
|
29
30
|
cascade/executor/runner/entrypoint.py,sha256=WyxOFGAYDQD_fXsM4H9_6xBrnAmQrCTUnljfcW6-BoM,7918
|
|
30
|
-
cascade/executor/runner/memory.py,sha256=
|
|
31
|
-
cascade/executor/runner/packages.py,sha256=
|
|
31
|
+
cascade/executor/runner/memory.py,sha256=VEOrYfFNGNBM7vMY05wjbX3L0U-RJZWpm_Ud4bMUR5g,6486
|
|
32
|
+
cascade/executor/runner/packages.py,sha256=lic5ItjyDpcQVRBFOZssvnco9bmxWpq_JRFDeShVR8k,4150
|
|
32
33
|
cascade/executor/runner/runner.py,sha256=zqpkvxdWLbwyUFaUbZmSj0KQEBNRpmF8gwVotiaamhc,4870
|
|
33
34
|
cascade/gateway/__init__.py,sha256=1EzMKdLFXEucj0YWOlyVqLx4suOntitwM03T_rRubIk,829
|
|
34
35
|
cascade/gateway/__main__.py,sha256=kmfklSeA7v5ie75SBHOql-eHuY6x4eTHlItMYqCQ1Pg,969
|
|
@@ -37,7 +38,7 @@ cascade/gateway/client.py,sha256=1p4Tvrf-BH0LQHOES5rY1z3JNIfmXcqWG2kYl4rpcE0,406
|
|
|
37
38
|
cascade/gateway/router.py,sha256=9oTkqssb3dHF24TIaAn_7oQoNfm4qkOvriufbOJxnyE,11582
|
|
38
39
|
cascade/gateway/server.py,sha256=BfUKpU2nCEB_zI4BdZU_9zHYHX1WoQaLARCTxMSP0Nk,3971
|
|
39
40
|
cascade/low/__init__.py,sha256=5cw2taOGITK_gFbICftzK2YLdEAnLUY5OzblFzdHss4,769
|
|
40
|
-
cascade/low/builders.py,sha256=
|
|
41
|
+
cascade/low/builders.py,sha256=F2W47zIa8tfBxHvvRekNd8SV7l4HOPEisQujhh9gisQ,8428
|
|
41
42
|
cascade/low/core.py,sha256=_3x4ka_pmCgZbfwFeyhq8S4M6wmh0s24VRCLhk5yQFM,6444
|
|
42
43
|
cascade/low/dask.py,sha256=xToT_vyfkgUUxSFN7dS7qLttxzuBbBZfDylPzGg7sPg,3319
|
|
43
44
|
cascade/low/execution_context.py,sha256=cdDJLYhreo4T7t4qXgFBosncubZpTrm0hELo7q4miqo,6640
|
|
@@ -52,18 +53,17 @@ cascade/scheduler/core.py,sha256=umORLC6SDeOyS4z8nQuVFkDukBJ96JfH4hdLSj6Km20,337
|
|
|
52
53
|
cascade/scheduler/precompute.py,sha256=AhTn8RgnU4XuV_WAgbVXz9z0YRpNS6LCY1dJeHdTfCc,8709
|
|
53
54
|
cascade/shm/__init__.py,sha256=R9QgGSnsl_YDjFjAUQkoleM_5yGM37ce9S8a4ReA1mE,3854
|
|
54
55
|
cascade/shm/algorithms.py,sha256=SGxnJF4ovUaywTunMJWkG77l5DN-jXx7HgABt3sRJXM,2356
|
|
55
|
-
cascade/shm/api.py,sha256=
|
|
56
|
-
cascade/shm/client.py,sha256=
|
|
57
|
-
cascade/shm/dataset.py,sha256=
|
|
56
|
+
cascade/shm/api.py,sha256=TFK0ioKJpJ2-rTxwk_O5BtB6AKjgwfM8CIl-VZaUIZo,7180
|
|
57
|
+
cascade/shm/client.py,sha256=7rUG0bra7XTJRumywQ-Gos4pWeZoXpTZqseh36uNWFg,6312
|
|
58
|
+
cascade/shm/dataset.py,sha256=QAALiWK0fyMLet9XFXmATm-c9gTuF77cifGbjP3WjXo,13155
|
|
58
59
|
cascade/shm/disk.py,sha256=Fdl_pKOseaXroRp01OwqWVsdI-sSmiFizIFCdxBuMWM,2653
|
|
59
60
|
cascade/shm/func.py,sha256=ZWikgnSLCmbSoW2LDRJwtjxdwTxkR00OUHAsIRQ-ChE,638
|
|
60
|
-
cascade/shm/server.py,sha256=
|
|
61
|
+
cascade/shm/server.py,sha256=geWo2BuF8sa_BqY8akh6ardWFfKfDJktWujrdDHn624,5648
|
|
61
62
|
earthkit/workflows/__init__.py,sha256=-p4anEn0YQbYWM2tbXb0Vc3wq4-m6kFhcNEgAVu5Jis,1948
|
|
62
|
-
earthkit/workflows/_version.py,sha256=
|
|
63
|
-
earthkit/workflows/decorators.py,sha256=
|
|
64
|
-
earthkit/workflows/fluent.py,sha256=
|
|
65
|
-
earthkit/workflows/mark.py,sha256=
|
|
66
|
-
earthkit/workflows/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
63
|
+
earthkit/workflows/_version.py,sha256=gCSK2O7UArCP94MovPvZF039nZ9JlETEx89VakvCxQg,72
|
|
64
|
+
earthkit/workflows/decorators.py,sha256=YK6AN-Ta9cAOX__DjZbn_vNYdpRL98N6dbF31E6Vu1c,1478
|
|
65
|
+
earthkit/workflows/fluent.py,sha256=3CvZfdLjXCoGR0VJDTB8_PDFgR7n-UhGLdKo7E5zuvM,30161
|
|
66
|
+
earthkit/workflows/mark.py,sha256=otgR6ar_9R7q5VRFD6RlLUROfjhyiaMIsgcleW2icKI,1322
|
|
67
67
|
earthkit/workflows/taskgraph.py,sha256=RsT1Qlng1uPZSaSBNqE8vFsoI5J8DDcQl468YPX-kCY,4460
|
|
68
68
|
earthkit/workflows/transformers.py,sha256=BsUUvnG-UyerT3XUYcHc1qJkSsLc0ZX3Zxqq70tJWLU,2105
|
|
69
69
|
earthkit/workflows/utility.py,sha256=ygqn1s846WQbo7HGY46Z8N1AXrDFGwyygSgsv4YnGJ8,1344
|
|
@@ -89,8 +89,8 @@ earthkit/workflows/graph/split.py,sha256=t-Sji5eZb01QO1szqmDNTodDDALqdo-0R0x1ESs
|
|
|
89
89
|
earthkit/workflows/graph/transform.py,sha256=BZ8n7ePUnuGgoHkMqZC3SLzifu4oq6q6t6vka0khFtg,3842
|
|
90
90
|
earthkit/workflows/graph/visit.py,sha256=MP-aFSqOl7aqJY2i7QTgY4epqb6yM7_lK3ofvOqfahw,1755
|
|
91
91
|
earthkit/workflows/plugins/__init__.py,sha256=nhMAC0eMLxoJamjqB5Ns0OWy0OuxEJ_YvaDFGEQITls,129
|
|
92
|
-
earthkit_workflows-0.
|
|
93
|
-
earthkit_workflows-0.
|
|
94
|
-
earthkit_workflows-0.
|
|
95
|
-
earthkit_workflows-0.
|
|
96
|
-
earthkit_workflows-0.
|
|
92
|
+
earthkit_workflows-0.5.1.dist-info/licenses/LICENSE,sha256=73MJ7twXMKnWwmzmrMiFwUeY7c6JTvxphVggeUq9Sq4,11381
|
|
93
|
+
earthkit_workflows-0.5.1.dist-info/METADATA,sha256=r90HXtWZwIhQ1Em7bIlByUbkiT1YQDAa9nyOUvejago,1571
|
|
94
|
+
earthkit_workflows-0.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
95
|
+
earthkit_workflows-0.5.1.dist-info/top_level.txt,sha256=oNrH3Km3hK5kDkTOiM-8G8OQglvZcy-gUKy7rlooWXs,17
|
|
96
|
+
earthkit_workflows-0.5.1.dist-info/RECORD,,
|
earthkit/workflows/py.typed
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|