earthkit-workflows 0.3.6__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- cascade/benchmarks/__main__.py +45 -9
- cascade/benchmarks/dist.py +123 -0
- cascade/benchmarks/job1.py +2 -2
- cascade/benchmarks/matmul.py +73 -0
- cascade/controller/act.py +1 -0
- cascade/controller/impl.py +5 -0
- cascade/controller/notify.py +2 -0
- cascade/executor/bridge.py +2 -1
- cascade/executor/config.py +1 -0
- cascade/executor/executor.py +4 -2
- cascade/executor/msg.py +2 -0
- cascade/executor/runner/entrypoint.py +41 -5
- cascade/executor/runner/memory.py +23 -4
- cascade/low/core.py +14 -2
- cascade/low/execution_context.py +6 -0
- cascade/scheduler/api.py +56 -1
- cascade/scheduler/assign.py +269 -58
- cascade/scheduler/core.py +19 -0
- cascade/scheduler/{graph.py → precompute.py} +101 -44
- earthkit/workflows/__init__.py +4 -0
- earthkit/workflows/_version.py +1 -1
- earthkit/workflows/backends/__init__.py +27 -11
- earthkit/workflows/plugins/__init__.py +4 -0
- {earthkit_workflows-0.3.6.dist-info → earthkit_workflows-0.4.1.dist-info}/METADATA +1 -1
- {earthkit_workflows-0.3.6.dist-info → earthkit_workflows-0.4.1.dist-info}/RECORD +28 -26
- {earthkit_workflows-0.3.6.dist-info → earthkit_workflows-0.4.1.dist-info}/WHEEL +0 -0
- {earthkit_workflows-0.3.6.dist-info → earthkit_workflows-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {earthkit_workflows-0.3.6.dist-info → earthkit_workflows-0.4.1.dist-info}/top_level.txt +0 -0
cascade/benchmarks/__main__.py
CHANGED
@@ -26,7 +26,9 @@ import logging.config
 import multiprocessing
 import os
 import subprocess
+import sys
 from concurrent.futures import ThreadPoolExecutor
+from socket import getfqdn
 from time import perf_counter_ns
 
 import fire
@@ -41,7 +43,7 @@ from cascade.executor.executor import Executor
 from cascade.executor.msg import BackboneAddress, ExecutorShutdown
 from cascade.low.core import JobInstance
 from cascade.low.func import msum
-from cascade.scheduler.graph import precompute
+from cascade.scheduler.precompute import precompute
 from earthkit.workflows.graph import Graph, deduplicate_nodes
 
 logger = logging.getLogger("cascade.benchmarks")
@@ -73,14 +75,28 @@ def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
             import cascade.benchmarks.generators as generators
 
             return generators.get_job()
+        elif benchmark.startswith("matmul"):
+            import cascade.benchmarks.matmul as matmul
+
+            return matmul.get_job()
+        elif benchmark.startswith("dist"):
+            import cascade.benchmarks.dist as dist
+
+            return dist.get_job()
         else:
            raise NotImplementedError(benchmark)
    else:
        raise TypeError("specified neither benchmark name nor job instance")
 
 
-def get_gpu_count() -> int:
+def get_cuda_count() -> int:
     try:
+        if "CUDA_VISIBLE_DEVICES" in os.environ:
+            # TODO we dont want to just count, we want to actually use literally these ids
+            # NOTE this is particularly useful for "" value -- careful when refactoring
+            visible = os.environ["CUDA_VISIBLE_DEVICES"]
+            visible_count = sum(1 for e in visible if e == ",") + (1 if visible else 0)
+            return visible_count
         gpus = sum(
             1
             for l in subprocess.run(
@@ -91,12 +107,22 @@ def get_gpu_count() -> int:
             if "GPU" in l
         )
     except:
-        # TODO support macos
         logger.exception("unable to determine available gpus")
         gpus = 0
     return gpus
 
 
+def get_gpu_count(host_idx: int, worker_count: int) -> int:
+    if sys.platform == "darwin":
+        # we should inspect some gpu capabilities details to prevent overcommit
+        return worker_count
+    else:
+        if host_idx == 0:
+            return get_cuda_count()
+        else:
+            return 0
+
+
 def launch_executor(
     job_instance: JobInstance,
     controller_address: BackboneAddress,
@@ -106,6 +132,7 @@ def launch_executor(
     shm_vol_gb: int | None,
     gpu_count: int,
     log_base: str | None,
+    url_base: str,
 ):
     if log_base is not None:
         log_base = f"{log_base}.host{i}"
@@ -123,6 +150,7 @@
         portBase,
         shm_vol_gb,
         log_base,
+        url_base,
     )
     executor.register()
     executor.recv_loop()
@@ -147,14 +175,21 @@ def run_locally(
     m = f"tcp://localhost:{portBase+1}"
     ps = []
     for i, executor in enumerate(range(hosts)):
-        if i == 0:
-            gpu_count = get_gpu_count()
-        else:
-            gpu_count = 0
+        gpu_count = get_gpu_count(i, workers)
         # NOTE forkserver/spawn seem to forget venv, we need fork
         p = multiprocessing.get_context("fork").Process(
             target=launch_executor,
-            args=(
+            args=(
+                job,
+                c,
+                workers,
+                portBase + 1 + i * 10,
+                i,
+                None,
+                gpu_count,
+                log_base,
+                "tcp://localhost",
+            ),
         )
         p.start()
         ps.append(p)
@@ -228,7 +263,7 @@ def main_dist(
             f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
         )
     else:
-        gpu_count = get_gpu_count()
+        gpu_count = get_gpu_count(0, workers_per_host)
         launch_executor(
             jobInstance,
             controller_url,
@@ -237,6 +272,7 @@
             idx,
             shm_vol_gb,
             gpu_count,
+            f"tcp://{getfqdn()}",
         )
 
 
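The comma-counting in get_cuda_count above deliberately treats an empty CUDA_VISIBLE_DEVICES as zero devices, per the NOTE. The same arithmetic in isolation (an illustrative sketch, not part of the package):

def cuda_count(visible: str) -> int:
    # one device per comma-separated entry; the empty string means zero devices
    return sum(1 for c in visible if c == ",") + (1 if visible else 0)

assert cuda_count("") == 0
assert cuda_count("0") == 1
assert cuda_count("0,1,3") == 3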
cascade/benchmarks/dist.py
ADDED
@@ -0,0 +1,123 @@
+"""Demonstrates gang scheduling capabilities, ie, multiple nodes capable of mutual communication.
+
+The job is a source -> (dist group) -> sink, where:
+    source just returns an int,
+    dist group is L nodes to be scheduled as a single gang
+        rank=0 node broadcasts a buffer containing the node's input
+        each node returns its input multiplied by broadcasted buffer
+    sink returns the sum of all inputs
+
+There are multiple implementations of that:
+    torch
+    jax (actually does a mesh-shard global sum instead of broadcast -- the point is to showcase dist init)
+"""
+
+import os
+
+from cascade.low.builders import JobBuilder, TaskBuilder
+from cascade.low.core import JobInstance, SchedulingConstraint
+
+
+def source_func() -> int:
+    return 42
+
+
+def dist_func_torch(a: int) -> int:
+    import datetime as dt
+
+    import numpy as np
+    import torch.distributed as dist
+
+    world_size = int(os.environ["CASCADE_GANG_WORLD_SIZE"])
+    rank = int(os.environ["CASCADE_GANG_RANK"])
+    coordinator = os.environ["CASCADE_GANG_COORDINATOR"]
+    print(f"starting with envvars: {rank=}/{world_size=}, {coordinator=}")
+    dist.init_process_group(
+        backend="gloo",
+        init_method=coordinator,
+        timeout=dt.timedelta(minutes=1),
+        world_size=world_size,
+        rank=rank,
+    )
+    group_ranks = np.arange(world_size, dtype=int)
+    group = dist.new_group(group_ranks)
+
+    if rank == 0:
+        buf = [a]
+        dist.broadcast_object_list(buf, src=0, group=group)
+        print("broadcast ok")
+    else:
+        buf = np.array([0], dtype=np.uint64)
+        dist.broadcast_object_list(buf, src=0, group=group)
+        print(f"broadcast recevied {buf}")
+
+    return a * buf[0]
+
+
+def dist_func_jax(a: int) -> int:
+    world_size = int(os.environ["CASCADE_GANG_WORLD_SIZE"])
+    rank = int(os.environ["CASCADE_GANG_RANK"])
+    coordinator = os.environ["CASCADE_GANG_COORDINATOR"]
+    os.environ["JAX_NUM_CPU_DEVICES"] = "1"
+    os.environ["JAX_PLATFORM_NAME"] = "cpu"
+    os.environ["JAX_PLATFORMS"] = "cpu"
+    import jax
+    import jax.numpy as jp
+
+    jax.config.update("jax_platforms", "cpu")
+    jax.config.update("jax_platform_name", "cpu")
+    # NOTE neither of the above seems to actually help with an init error message :(
+    print(f"starting with envvars: {rank=}/{world_size=}, {coordinator=}")
+    if coordinator.startswith("tcp://"):
+        coordinator = coordinator[len("tcp://") :]
+    jax.distributed.initialize(coordinator, num_processes=world_size, process_id=rank)
+    assert jax.device_count() == world_size
+
+    mesh = jax.make_mesh((world_size,), ("i",))
+    global_data = jp.arange(world_size)
+    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("i"))
+    global_array = jax.device_put(global_data, sharding)
+    result = jp.sum(global_array)
+    print(f"worker {rank}# got result {result=}")
+    return a + result
+
+
+def build_dist_func(impl: str):
+    if impl == "torch":
+        return dist_func_torch
+    elif impl == "jax":
+        return dist_func_jax
+    else:
+        raise NotImplementedError(impl)
+
+
+def sink_func(**kwargs) -> int:
+    c = 0
+    for _, v in kwargs.items():
+        c += v
+    print(f"sink accumulated {c}")
+    return c
+
+
+def get_job() -> JobInstance:
+    source_node = TaskBuilder.from_callable(source_func)
+    sink_node = TaskBuilder.from_callable(sink_func)
+    job = JobBuilder().with_node("source", source_node).with_node("sink", sink_node)
+    L = int(os.environ["DIST_L"])
+    IMPL = os.environ["DIST_IMPL"]
+    node = TaskBuilder.from_callable(build_dist_func(IMPL))
+
+    for i in range(L):
+        job = (
+            job.with_node(f"proc{i}", node)
+            .with_edge("source", f"proc{i}", "a")
+            .with_edge(f"proc{i}", "sink", f"v{i}")
+        )
+        job.nodes["sink"].definition.input_schema[
+            f"v{i}"
+        ] = "int"  # TODO put some allow_kw into TaskDefinition instead to allow this
+
+    job = job.build().get_or_raise()
+    job.ext_outputs = list(job.outputs_of("sink"))
+    job.constraints = [SchedulingConstraint(gang=[f"proc{i}" for i in range(L)])]
+    return job
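Since get_job above is driven entirely by the two DIST_* environment variables, the new benchmark can be sanity-checked without a scheduler. A minimal sketch, assuming cascade is importable (values are illustrative):

import os

os.environ["DIST_L"] = "4"  # four gang-scheduled proc nodes
os.environ["DIST_IMPL"] = "torch"  # or "jax"

from cascade.benchmarks.dist import get_job

job = get_job()
assert len(job.tasks) == 6  # source + sink + proc0..proc3
assert job.constraints[0].gang == [f"proc{i}" for i in range(4)]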
cascade/benchmarks/job1.py
CHANGED
@@ -16,10 +16,10 @@ Controlled by env var params: JOB1_{DATA_ROOT, GRID, ...}, see below
 import os
 
 import earthkit.data
-from ppcascade.fluent import from_source
-from ppcascade.utils.window import Range
 
 from earthkit.workflows.fluent import Payload
+from earthkit.workflows.plugins.pproc.fluent import from_source
+from earthkit.workflows.plugins.pproc.utils.window import Range
 
 # *** PARAMS ***
 
cascade/benchmarks/matmul.py
ADDED
@@ -0,0 +1,73 @@
+import os
+from typing import Any
+
+import jax
+import jax.numpy as jp
+import jax.random as jr
+
+from cascade.low.builders import JobBuilder, TaskBuilder
+from cascade.low.core import JobInstance
+
+
+def get_funcs():
+    K = int(os.environ["MATMUL_K"])
+    size = (2**K, 2**K)
+    E = int(os.environ["MATMUL_E"])
+
+    def source() -> Any:
+        k0 = jr.key(0)
+        m = jr.uniform(key=k0, shape=size)
+        return m
+
+    def powr(m: Any) -> Any:
+        print(f"powr device is {m.device}")
+        return m**E * jp.percentile(m, 0.7)
+
+    return source, powr
+
+
+def get_job() -> JobInstance:
+    L = int(os.environ["MATMUL_L"])
+    # D = os.environ["MATMUL_D"]
+    # it would be tempting to with jax.default_device(jax.devices(D)):
+    # alas, it doesn't work because we can't inject this at deser time
+
+    source, powr = get_funcs()
+    source_node = TaskBuilder.from_callable(source)
+    if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "":
+        source_node.definition.needs_gpu = True
+        # currently no need to set True downstream since scheduler prefers no transfer
+
+    job = JobBuilder().with_node("source", source_node)
+    prv = "source"
+    for i in range(L):
+        cur = f"pow{i}"
+        node = TaskBuilder.from_callable(powr)
+        job = job.with_node(cur, node).with_edge(prv, cur, 0)
+        prv = cur
+
+    job = job.build().get_or_raise()
+    job.ext_outputs = list(job.outputs_of(cur))
+    return job
+
+
+def execute_locally():
+    L = int(os.environ["MATMUL_L"])
+
+    source, powr = get_funcs()
+
+    device = "gpu" if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "" else "cpu"
+    print(f"device is {device}")
+    with jax.default_device(jax.devices(device)[0]):
+        m0 = source()
+        for _ in range(L):
+            m0 = powr(m0)
+
+    from multiprocessing.shared_memory import SharedMemory
+
+    mem = SharedMemory("benchmark_tmp", create=True, size=m0.nbytes)
+    mem.buf[:] = m0.tobytes()
+
+
+if __name__ == "__main__":
+    execute_locally()
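Like dist.py, the module is configured purely through environment variables, and execute_locally gives a scheduler-free smoke test. A minimal sketch, assuming jax is installed (values are illustrative):

import os

os.environ["MATMUL_K"] = "10"  # matrices of shape (2**10, 2**10)
os.environ["MATMUL_E"] = "2"  # exponent applied in each powr step
os.environ["MATMUL_L"] = "3"  # chain length: source -> pow0 -> pow1 -> pow2

from cascade.benchmarks.matmul import execute_locally, get_job

job = get_job()
assert set(job.tasks) == {"source", "pow0", "pow1", "pow2"}
execute_locally()  # runs the same chain in-process; result lands in the "benchmark_tmp" shared memory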
cascade/controller/act.py
CHANGED
cascade/controller/impl.py
CHANGED
@@ -43,6 +43,11 @@ def run(
     reporter = Reporter(report_address)
 
     try:
+        total_gpus = sum(worker.gpu for worker in env.workers.values())
+        needs_gpus = any(task.definition.needs_gpu for task in job.tasks.values())
+        if needs_gpus and total_gpus == 0:
+            raise ValueError("environment contains no gpu yet job demands one")
+
         while (
             state.has_awaitable()
             or context.has_awaitable()
cascade/controller/notify.py
CHANGED
@@ -22,6 +22,7 @@ from cascade.low.core import DatasetId, HostId, WorkerId
 from cascade.low.execution_context import DatasetStatus, JobExecutionContext
 from cascade.low.func import assert_never
 from cascade.low.tracing import TaskLifecycle, TransmitLifecycle, mark
+from cascade.scheduler.api import gang_check_ready
 from cascade.scheduler.assign import set_worker2task_overhead
 from cascade.scheduler.core import Schedule
 
@@ -67,6 +68,7 @@ def consider_computable(
             # NOTE this is a task newly made computable, so we need to calc
             # `overhead` for all hosts/workers assigned to the component
             set_worker2task_overhead(schedule, context, worker, child_task)
+            gang_check_ready(child_task, component.gang_preparation)
 
 
 # TODO refac less explicit mutation of context, use class methods
cascade/executor/bridge.py
CHANGED
@@ -46,7 +46,7 @@ class Bridge:
         self.transmit_idx_counter = 0
         self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
         registered = 0
-        self.environment = Environment(workers={})
+        self.environment = Environment(workers={}, host_url_base={})
         logger.debug("about to start receiving registrations")
         registration_grace = time.time_ns() + 3 * 60 * 1_000_000_000
         while registered < expected_executors:
@@ -69,6 +69,7 @@ class Bridge:
                 self.environment.workers[worker.worker_id] = Worker(
                     cpu=worker.cpu, gpu=worker.gpu, memory_mb=worker.memory_mb
                 )
+            self.environment.host_url_base[message.host] = message.url_base
             registered += 1
             self.heartbeat_checker[message.host] = GraceWatcher(
                 2 * executor_heartbeat_grace_ms
cascade/executor/config.py
CHANGED
@@ -27,6 +27,7 @@ logging_config = {
         "cascade.executor": {"level": "DEBUG"},
         "cascade.scheduler": {"level": "DEBUG"},
         "cascade.gateway": {"level": "DEBUG"},
+        "earthkit.workflows": {"level": "DEBUG"},
         "httpcore": {"level": "ERROR"},
         "httpx": {"level": "ERROR"},
         "": {"level": "WARNING", "handlers": ["default"]},
cascade/executor/executor.py
CHANGED
@@ -69,8 +69,9 @@ class Executor:
         workers: int,
         host: HostId,
         portBase: int,
-        shm_vol_gb: int | None
-        log_base: str | None
+        shm_vol_gb: int | None,
+        log_base: str | None,
+        url_base: str,
     ) -> None:
         self.job_instance = job_instance
         self.param_source = param_source(job_instance.edges)
@@ -138,6 +139,7 @@ class Executor:
                 )
                 for idx, worker_id in enumerate(self.workers.keys())
             ],
+            url_base=url_base,
         )
         logger.debug("constructed executor")
 
cascade/executor/msg.py
CHANGED
@@ -71,6 +71,7 @@ class TaskSequence:
     worker: WorkerId  # worker for running those tasks
     tasks: list[TaskId]  # to be executed in the given order
     publish: set[DatasetId]  # set of outputs to be published
+    extra_env: list[tuple[str, str]]  # extra env var to set
 
 
 @dataclass(frozen=True)
@@ -147,6 +148,7 @@ class ExecutorRegistration:
     host: HostId
     maddress: BackboneAddress
     daddress: BackboneAddress
+    url_base: str  # used for eg dist comms init
     workers: list[Worker]
 
 
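extra_env gives the controller a way to set per-sequence environment variables in a runner before execution; judging from the dist benchmark above, this plausibly carries the CASCADE_GANG_* variables, though the producer side is not part of this diff. An illustrative value only:

extra_env = [
    ("CASCADE_GANG_WORLD_SIZE", "4"),
    ("CASCADE_GANG_RANK", "0"),
    ("CASCADE_GANG_COORDINATOR", "tcp://host0:12345"),
]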
cascade/executor/runner/entrypoint.py
CHANGED
@@ -11,6 +11,7 @@
 import logging
 import logging.config
 import os
+import sys
 from dataclasses import dataclass
 
 import zmq
@@ -67,6 +68,25 @@ class RunnerContext:
     )
 
 
+class Config:
+    """Some parameters to drive behaviour. Currently not exposed externally -- no clear argument
+    that they should be. As is, just a means of code experimentation.
+    """
+
+    # flushing approach -- when we finish a computation of task sequence, there is a question what
+    # to do with the output. We could either publish & drop, or publish and retain in memory. The
+    # former is slower -- if the next task sequence needs this output, it requires a fetch & deser
+    # from cashme. But the latter is more risky -- we effectively have the same dataset twice in
+    # system memory. The `posttask_flush` below goes the former way, the `pretask_flush` is a careful
+    # way of latter -- we drop the output from memory only if the *next* task sequence does not need
+    # it, ie, we retain a cache of age 1. We could ultimately have controller decide about this, or
+    # decide dynamically based on memory pressure -- but neither is easy.
+    posttask_flush = False  # after task is done, drop all outputs from memory
+    pretask_flush = (
+        True  # when we receive a task, we drop those in memory that wont be needed
+    )
+
+
 def worker_address(workerId: WorkerId) -> BackboneAddress:
     return f"ipc:///tmp/{repr(workerId)}.socket"
 
@@ -79,11 +99,17 @@ def execute_sequence(
 ) -> None:
     taskId: TaskId | None = None
     try:
+        for key, value in taskSequence.extra_env.items():
+            os.environ[key] = value
         executionContext = runnerContext.project(taskSequence)
         for taskId in taskSequence.tasks:
             pckg.extend(executionContext.tasks[taskId].definition.environment)
             run(taskId, executionContext, memory)
-
+        if Config.posttask_flush:
+            memory.flush()
+        for key in taskSequence.extra_env.keys():
+            # NOTE we should in principle restore the previous value, but we dont expect collisions
+            del os.environ[key]
     except Exception as e:
         logger.exception("runner failure, about to report")
         callback(
@@ -107,10 +133,17 @@ def entrypoint(runnerContext: RunnerContext):
         PackagesEnv() as pckg,
     ):
         label("worker", repr(runnerContext.workerId))
-
-        os.environ
-
-
+        worker_num = runnerContext.workerId.worker_num()
+        gpus = int(os.environ.get("CASCADE_GPU_COUNT", "0"))
+        if sys.platform != "darwin":
+            os.environ["CUDA_VISIBLE_DEVICES"] = (
+                str(worker_num) if worker_num < gpus else ""
+            )
+            # NOTE check any(task.definition.needs_gpu) anywhere?
+            # TODO configure OMP_NUM_THREADS, blas, mkl, etc -- not clear how tho
+        else:
+            if gpus != 1:
+                logger.warning("unexpected absence of gpu on darwin")
 
         for serdeTypeEnc, (serdeSer, serdeDes) in runnerContext.job.serdes.items():
             serde.SerdeRegistry.register(type_dec(serdeTypeEnc), serdeSer, serdeDes)
@@ -151,6 +184,9 @@
                 for key, _ in runnerContext.job.tasks[task].definition.output_schema
             }
             missing_ds = required - availab_ds
+            if Config.pretask_flush:
+                extraneous_ds = availab_ds - required
+                memory.flush(extraneous_ds)
             if missing_ds:
                 waiting_ts = mDes
                 for ds in availab_ds.intersection(required):
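The Config comment above weighs two flushing strategies. A toy illustration of the trade-off, with a plain dict standing in for worker memory (names are illustrative, not the real cascade API):

def posttask_flush(memory: dict) -> None:
    # publish & drop: retain nothing between task sequences
    memory.clear()

def pretask_flush(memory: dict, required: set) -> None:
    # age-1 cache: on receiving the next sequence, drop only what it will not need
    for ds in set(memory) - required:
        del memory[ds]

memory = {"ds_a": 1, "ds_b": 2}
pretask_flush(memory, required={"ds_a"})
assert memory == {"ds_a": 1}  # ds_b dropped; ds_a reused without a cashme round-trip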
cascade/executor/runner/memory.py
CHANGED
@@ -51,7 +51,6 @@ class Memory(AbstractContextManager):
         else:
             outputValue = "ok"
 
-        # TODO how do we purge from here over time?
         self.local[outputId] = outputValue
 
         if isPublish:
@@ -68,6 +67,18 @@ class Memory(AbstractContextManager):
                 self.callback,
                 DatasetPublished(ds=outputId, origin=self.worker, transmit_idx=None),
             )
+        else:
+            # NOTE even if its not actually published, we send the message to allow for
+            # marking the task itself as completed -- its odd, but arguably better than
+            # introducing a TaskCompleted message. TODO we should fine-grain host-wide
+            # and worker-only publishes at the `controller.notify` level, to not cause
+            # incorrect shm.purge calls at workflow end, which log an annoying key error
+            logger.debug(f"fake publish of {outputId} for the sake of task completion")
+            shmid = ds2shmid(outputId)
+            callback(
+                self.callback,
+                DatasetPublished(ds=outputId, origin=self.worker, transmit_idx=None),
+            )
 
     def provide(self, inputId: DatasetId, annotation: str) -> Any:
         if inputId not in self.local:
@@ -85,18 +96,24 @@ class Memory(AbstractContextManager):
 
     def pop(self, ds: DatasetId) -> None:
         if ds in self.local:
+            logger.debug(f"popping local {ds}")
             val = self.local.pop(ds)  # noqa: F841
             del val
         if ds in self.bufs:
+            logger.debug(f"popping buf {ds}")
             buf = self.bufs.pop(ds)
             buf.close()
 
-    def flush(self) -> None:
-        # NOTE poor man's memory management -- just drop those locals that
+    def flush(self, datasets: set[DatasetId] = set()) -> None:
+        # NOTE poor man's memory management -- just drop those locals that didn't come from cashme. Called
         # after every taskSequence. In principle, we could purge some locals earlier, and ideally scheduler
         # would invoke some targeted purges to also remove some published ones earlier (eg, they are still
         # needed somewhere but not here)
-        purgeable = [
+        purgeable = [
+            inputId
+            for inputId in self.local
+            if inputId not in self.bufs and (not datasets or inputId in datasets)
+        ]
         logger.debug(f"will flush {len(purgeable)} datasets")
         for inputId in purgeable:
             self.local.pop(inputId)
@@ -115,6 +132,8 @@ class Memory(AbstractContextManager):
             free, total = torch.cuda.mem_get_info()
             logger.debug(f"cuda mem avail post cache empty: {free/total:.2%}")
             if free / total < 0.8:
+                # NOTE this ofc makes low sense if there is any other application (like browser or ollama)
+                # that the user may be running
                 logger.warning("cuda mem avail low despite cache empty!")
                 logger.debug(torch.cuda.memory_summary())
         except ImportError:
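Note the semantics of the new datasets parameter to flush: an empty set (the default) keeps the old behaviour of purging every non-buffer local, while a non-empty set restricts purging to the listed ids. The filter condition in isolation (a sketch, not the real method):

def purgeable(local: set[str], bufs: set[str], datasets: set[str] = frozenset()) -> list[str]:
    return sorted(ds for ds in local if ds not in bufs and (not datasets or ds in datasets))

assert purgeable({"a", "b"}, bufs={"b"}) == ["a"]  # empty filter: all non-buffer locals
assert purgeable({"a", "b"}, bufs=set(), datasets={"b"}) == ["b"]  # restricted to the given ids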
cascade/low/core.py
CHANGED
@@ -106,15 +106,26 @@ def type_enc(t: Type) -> str:
     return b64encode(cloudpickle.dumps(t)).decode("ascii")
 
 
+class SchedulingConstraint(BaseModel):
+    gang: list[TaskId] = Field(
+        description="this set of TaskIds must be started at the same time, with ranks and address list as envvar",
+    )
+
+
 class JobInstance(BaseModel):
     tasks: dict[TaskId, TaskInstance]
     edges: list[Task2TaskEdge]
     serdes: dict[str, tuple[str, str]] = Field(
-        {},
+        default_factory=lambda: {},
         description="for each Type with custom serde, add entry here. The string is fully qualified name of the ser/des functions",
     )
     ext_outputs: list[DatasetId] = Field(
-        [],
+        default_factory=lambda: [],
+        description="ids to externally materialize",
+    )
+    constraints: list[SchedulingConstraint] = Field(
+        default_factory=lambda: [],
+        description="constraints for the scheduler such as gangs",
     )
 
     def outputs_of(self, task_id: TaskId) -> set[DatasetId]:
@@ -157,6 +168,7 @@ class Worker(BaseModel):
 
 class Environment(BaseModel):
     workers: dict[WorkerId, Worker]
+    host_url_base: dict[HostId, str]
 
 
 class TaskExecutionRecord(BaseModel):
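SchedulingConstraint is what cascade/benchmarks/dist.py attaches to its job; declaring a gang is a one-liner. A minimal sketch mirroring that usage (task names illustrative):

from cascade.low.core import SchedulingConstraint

# the three tasks must be started together; each runner then receives its rank,
# the world size and a coordinator address via the CASCADE_GANG_* envvars
constraint = SchedulingConstraint(gang=["proc0", "proc1", "proc2"])
# job.constraints = [constraint]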
cascade/low/execution_context.py
CHANGED
@@ -108,6 +108,12 @@ class JobExecutionContext:
         self.idle_workers.add(worker)
 
     def dataset_preparing(self, dataset: DatasetId, worker: WorkerId) -> None:
+        # NOTE Currently this is invoked during `build_assignment`, as we need
+        # some state transition to allow fusing opportunities as well as
+        # preventing double transmits. This may not be the best idea, eg for long
+        # fusing chains -- instead, we may execute this transition at the time
+        # it actually happens, granularize the preparing state into
+        # (will_appear, is_appearing), etc
         # NOTE Currently, these `if`s are necessary because we issue transmit
         # command when host *has* DS but worker does *not*. This ends up no-op,
         # but we totally dont want host state to reset -- it wouldnt recover
|