earthkit-workflows 0.3.6__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,9 @@ import logging.config
26
26
  import multiprocessing
27
27
  import os
28
28
  import subprocess
29
+ import sys
29
30
  from concurrent.futures import ThreadPoolExecutor
31
+ from socket import getfqdn
30
32
  from time import perf_counter_ns
31
33
 
32
34
  import fire
@@ -41,7 +43,7 @@ from cascade.executor.executor import Executor
41
43
  from cascade.executor.msg import BackboneAddress, ExecutorShutdown
42
44
  from cascade.low.core import JobInstance
43
45
  from cascade.low.func import msum
44
- from cascade.scheduler.graph import precompute
46
+ from cascade.scheduler.precompute import precompute
45
47
  from earthkit.workflows.graph import Graph, deduplicate_nodes
46
48
 
47
49
  logger = logging.getLogger("cascade.benchmarks")
@@ -73,14 +75,28 @@ def get_job(benchmark: str | None, instance_path: str | None) -> JobInstance:
73
75
  import cascade.benchmarks.generators as generators
74
76
 
75
77
  return generators.get_job()
78
+ elif benchmark.startswith("matmul"):
79
+ import cascade.benchmarks.matmul as matmul
80
+
81
+ return matmul.get_job()
82
+ elif benchmark.startswith("dist"):
83
+ import cascade.benchmarks.dist as dist
84
+
85
+ return dist.get_job()
76
86
  else:
77
87
  raise NotImplementedError(benchmark)
78
88
  else:
79
89
  raise TypeError("specified neither benchmark name nor job instance")
80
90
 
81
91
 
82
- def get_gpu_count() -> int:
92
+ def get_cuda_count() -> int:
83
93
  try:
94
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
95
+ # TODO we don't want to just count, we want to actually use exactly these ids
96
+ # NOTE this is particularly useful for "" value -- careful when refactoring
97
+ visible = os.environ["CUDA_VISIBLE_DEVICES"]
98
+ visible_count = sum(1 for e in visible if e == ",") + (1 if visible else 0)
99
+ return visible_count
84
100
  gpus = sum(
85
101
  1
86
102
  for l in subprocess.run(
@@ -91,12 +107,22 @@ def get_gpu_count() -> int:
91
107
  if "GPU" in l
92
108
  )
93
109
  except:
94
- # TODO support macos
95
110
  logger.exception("unable to determine available gpus")
96
111
  gpus = 0
97
112
  return gpus
98
113
 
99
114
 
115
+ def get_gpu_count(host_idx: int, worker_count: int) -> int:
116
+ if sys.platform == "darwin":
117
+ # we should inspect GPU capability details to prevent overcommit
118
+ return worker_count
119
+ else:
120
+ if host_idx == 0:
121
+ return get_cuda_count()
122
+ else:
123
+ return 0
124
+
125
+
100
126
  def launch_executor(
101
127
  job_instance: JobInstance,
102
128
  controller_address: BackboneAddress,
@@ -106,6 +132,7 @@ def launch_executor(
106
132
  shm_vol_gb: int | None,
107
133
  gpu_count: int,
108
134
  log_base: str | None,
135
+ url_base: str,
109
136
  ):
110
137
  if log_base is not None:
111
138
  log_base = f"{log_base}.host{i}"
@@ -123,6 +150,7 @@ def launch_executor(
123
150
  portBase,
124
151
  shm_vol_gb,
125
152
  log_base,
153
+ url_base,
126
154
  )
127
155
  executor.register()
128
156
  executor.recv_loop()
@@ -147,14 +175,21 @@ def run_locally(
147
175
  m = f"tcp://localhost:{portBase+1}"
148
176
  ps = []
149
177
  for i, executor in enumerate(range(hosts)):
150
- if i == 0:
151
- gpu_count = get_gpu_count()
152
- else:
153
- gpu_count = 0
178
+ gpu_count = get_gpu_count(i, workers)
154
179
  # NOTE forkserver/spawn seem to forget venv, we need fork
155
180
  p = multiprocessing.get_context("fork").Process(
156
181
  target=launch_executor,
157
- args=(job, c, workers, portBase + 1 + i * 10, i, None, gpu_count, log_base),
182
+ args=(
183
+ job,
184
+ c,
185
+ workers,
186
+ portBase + 1 + i * 10,
187
+ i,
188
+ None,
189
+ gpu_count,
190
+ log_base,
191
+ "tcp://localhost",
192
+ ),
158
193
  )
159
194
  p.start()
160
195
  ps.append(p)
@@ -228,7 +263,7 @@ def main_dist(
228
263
  f"compute took {(end-start)/1e9:.3f}s, including startup {(end-launch)/1e9:.3f}s"
229
264
  )
230
265
  else:
231
- gpu_count = get_gpu_count()
266
+ gpu_count = get_gpu_count(0, workers_per_host)
232
267
  launch_executor(
233
268
  jobInstance,
234
269
  controller_url,
@@ -237,6 +272,7 @@ def main_dist(
237
272
  idx,
238
273
  shm_vol_gb,
239
274
  gpu_count,
275
+ f"tcp://{getfqdn()}",
240
276
  )
241
277
 
242
278
 
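
The hunks above replace the old get_gpu_count with a CUDA_VISIBLE_DEVICES-aware get_cuda_count plus a per-host wrapper that hands all GPUs to host 0 (except on darwin, where it returns the worker count). A standalone sketch of the counting rule, including the empty-string edge case called out in the NOTE; the function name here is illustrative, not part of the package:

# Standalone sketch of the comma-counting logic in get_cuda_count above: an unset
# variable means "ask nvidia-smi", an explicitly empty value means "no GPUs", and
# "0,2" means two visible devices.
from collections.abc import Mapping


def visible_cuda_count(env: Mapping[str, str]) -> int | None:
    if "CUDA_VISIBLE_DEVICES" not in env:
        return None  # caller falls back to counting `nvidia-smi -L` lines
    visible = env["CUDA_VISIBLE_DEVICES"]
    return visible.count(",") + (1 if visible else 0)


assert visible_cuda_count({}) is None
assert visible_cuda_count({"CUDA_VISIBLE_DEVICES": ""}) == 0
assert visible_cuda_count({"CUDA_VISIBLE_DEVICES": "0"}) == 1
assert visible_cuda_count({"CUDA_VISIBLE_DEVICES": "0,2"}) == 2
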
@@ -0,0 +1,123 @@
1
+ """Demonstrates gang scheduling capabilities, ie, multiple nodes capable of mutual communication.
2
+
3
+ The job is a source -> (dist group) -> sink, where:
4
+ source just returns an int,
5
+ dist group is L nodes to be scheduled as a single gang
6
+ rank=0 node broadcasts a buffer containing the node's input
7
+ each node returns its input multiplied by the broadcast buffer
8
+ sink returns the sum of all inputs
9
+
10
+ There are multiple implementations of that:
11
+ torch
12
+ jax (actually does a mesh-shard global sum instead of broadcast -- the point is to showcase dist init)
13
+ """
14
+
15
+ import os
16
+
17
+ from cascade.low.builders import JobBuilder, TaskBuilder
18
+ from cascade.low.core import JobInstance, SchedulingConstraint
19
+
20
+
21
+ def source_func() -> int:
22
+ return 42
23
+
24
+
25
+ def dist_func_torch(a: int) -> int:
26
+ import datetime as dt
27
+
28
+ import numpy as np
29
+ import torch.distributed as dist
30
+
31
+ world_size = int(os.environ["CASCADE_GANG_WORLD_SIZE"])
32
+ rank = int(os.environ["CASCADE_GANG_RANK"])
33
+ coordinator = os.environ["CASCADE_GANG_COORDINATOR"]
34
+ print(f"starting with envvars: {rank=}/{world_size=}, {coordinator=}")
35
+ dist.init_process_group(
36
+ backend="gloo",
37
+ init_method=coordinator,
38
+ timeout=dt.timedelta(minutes=1),
39
+ world_size=world_size,
40
+ rank=rank,
41
+ )
42
+ group_ranks = np.arange(world_size, dtype=int)
43
+ group = dist.new_group(group_ranks)
44
+
45
+ if rank == 0:
46
+ buf = [a]
47
+ dist.broadcast_object_list(buf, src=0, group=group)
48
+ print("broadcast ok")
49
+ else:
50
+ buf = np.array([0], dtype=np.uint64)
51
+ dist.broadcast_object_list(buf, src=0, group=group)
52
+ print(f"broadcast recevied {buf}")
53
+
54
+ return a * buf[0]
55
+
56
+
57
+ def dist_func_jax(a: int) -> int:
58
+ world_size = int(os.environ["CASCADE_GANG_WORLD_SIZE"])
59
+ rank = int(os.environ["CASCADE_GANG_RANK"])
60
+ coordinator = os.environ["CASCADE_GANG_COORDINATOR"]
61
+ os.environ["JAX_NUM_CPU_DEVICES"] = "1"
62
+ os.environ["JAX_PLATFORM_NAME"] = "cpu"
63
+ os.environ["JAX_PLATFORMS"] = "cpu"
64
+ import jax
65
+ import jax.numpy as jp
66
+
67
+ jax.config.update("jax_platforms", "cpu")
68
+ jax.config.update("jax_platform_name", "cpu")
69
+ # NOTE neither of the above seems to actually help with an init error message :(
70
+ print(f"starting with envvars: {rank=}/{world_size=}, {coordinator=}")
71
+ if coordinator.startswith("tcp://"):
72
+ coordinator = coordinator[len("tcp://") :]
73
+ jax.distributed.initialize(coordinator, num_processes=world_size, process_id=rank)
74
+ assert jax.device_count() == world_size
75
+
76
+ mesh = jax.make_mesh((world_size,), ("i",))
77
+ global_data = jp.arange(world_size)
78
+ sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("i"))
79
+ global_array = jax.device_put(global_data, sharding)
80
+ result = jp.sum(global_array)
81
+ print(f"worker {rank}# got result {result=}")
82
+ return a + result
83
+
84
+
85
+ def build_dist_func(impl: str):
86
+ if impl == "torch":
87
+ return dist_func_torch
88
+ elif impl == "jax":
89
+ return dist_func_jax
90
+ else:
91
+ raise NotImplementedError(impl)
92
+
93
+
94
+ def sink_func(**kwargs) -> int:
95
+ c = 0
96
+ for _, v in kwargs.items():
97
+ c += v
98
+ print(f"sink accumulated {c}")
99
+ return c
100
+
101
+
102
+ def get_job() -> JobInstance:
103
+ source_node = TaskBuilder.from_callable(source_func)
104
+ sink_node = TaskBuilder.from_callable(sink_func)
105
+ job = JobBuilder().with_node("source", source_node).with_node("sink", sink_node)
106
+ L = int(os.environ["DIST_L"])
107
+ IMPL = os.environ["DIST_IMPL"]
108
+ node = TaskBuilder.from_callable(build_dist_func(IMPL))
109
+
110
+ for i in range(L):
111
+ job = (
112
+ job.with_node(f"proc{i}", node)
113
+ .with_edge("source", f"proc{i}", "a")
114
+ .with_edge(f"proc{i}", "sink", f"v{i}")
115
+ )
116
+ job.nodes["sink"].definition.input_schema[
117
+ f"v{i}"
118
+ ] = "int" # TODO put some allow_kw into TaskDefinition instead to allow this
119
+
120
+ job = job.build().get_or_raise()
121
+ job.ext_outputs = list(job.outputs_of("sink"))
122
+ job.constraints = [SchedulingConstraint(gang=[f"proc{i}" for i in range(L)])]
123
+ return job
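
A hypothetical way to build the gang job locally using only the environment variables read by get_job above; the import path matches the `benchmark.startswith("dist")` branch in __main__, everything else is illustrative:

# Hypothetical local construction of the gang job; env var names are taken from
# get_job above, the rest is an assumption for illustration.
import os

os.environ["DIST_L"] = "4"         # four gang members proc0..proc3
os.environ["DIST_IMPL"] = "torch"  # or "jax"

import cascade.benchmarks.dist as dist

job = dist.get_job()
print(job.constraints)  # -> one SchedulingConstraint with gang=["proc0", ..., "proc3"]
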
@@ -16,10 +16,10 @@ Controlled by env var params: JOB1_{DATA_ROOT, GRID, ...}, see below
16
16
  import os
17
17
 
18
18
  import earthkit.data
19
- from ppcascade.fluent import from_source
20
- from ppcascade.utils.window import Range
21
19
 
22
20
  from earthkit.workflows.fluent import Payload
21
+ from earthkit.workflows.plugins.pproc.fluent import from_source
22
+ from earthkit.workflows.plugins.pproc.utils.window import Range
23
23
 
24
24
  # *** PARAMS ***
25
25
 
@@ -0,0 +1,73 @@
1
+ import os
2
+ from typing import Any
3
+
4
+ import jax
5
+ import jax.numpy as jp
6
+ import jax.random as jr
7
+
8
+ from cascade.low.builders import JobBuilder, TaskBuilder
9
+ from cascade.low.core import JobInstance
10
+
11
+
12
+ def get_funcs():
13
+ K = int(os.environ["MATMUL_K"])
14
+ size = (2**K, 2**K)
15
+ E = int(os.environ["MATMUL_E"])
16
+
17
+ def source() -> Any:
18
+ k0 = jr.key(0)
19
+ m = jr.uniform(key=k0, shape=size)
20
+ return m
21
+
22
+ def powr(m: Any) -> Any:
23
+ print(f"powr device is {m.device}")
24
+ return m**E * jp.percentile(m, 0.7)
25
+
26
+ return source, powr
27
+
28
+
29
+ def get_job() -> JobInstance:
30
+ L = int(os.environ["MATMUL_L"])
31
+ # D = os.environ["MATMUL_D"]
32
+ # it would be tempting to with jax.default_device(jax.devices(D)):
33
+ # alas, it doesn't work because we can't inject this at deser time
34
+
35
+ source, powr = get_funcs()
36
+ source_node = TaskBuilder.from_callable(source)
37
+ if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "":
38
+ source_node.definition.needs_gpu = True
39
+ # currently no need to set True downstream since scheduler prefers no transfer
40
+
41
+ job = JobBuilder().with_node("source", source_node)
42
+ prv = "source"
43
+ for i in range(L):
44
+ cur = f"pow{i}"
45
+ node = TaskBuilder.from_callable(powr)
46
+ job = job.with_node(cur, node).with_edge(prv, cur, 0)
47
+ prv = cur
48
+
49
+ job = job.build().get_or_raise()
50
+ job.ext_outputs = list(job.outputs_of(cur))
51
+ return job
52
+
53
+
54
+ def execute_locally():
55
+ L = int(os.environ["MATMUL_L"])
56
+
57
+ source, powr = get_funcs()
58
+
59
+ device = "gpu" if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "" else "cpu"
60
+ print(f"device is {device}")
61
+ with jax.default_device(jax.devices(device)[0]):
62
+ m0 = source()
63
+ for _ in range(L):
64
+ m0 = powr(m0)
65
+
66
+ from multiprocessing.shared_memory import SharedMemory
67
+
68
+ mem = SharedMemory("benchmark_tmp", create=True, size=m0.nbytes)
69
+ mem.buf[:] = m0.tobytes()
70
+
71
+
72
+ if __name__ == "__main__":
73
+ execute_locally()
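
A hypothetical local run of the matmul benchmark with small sizes; the MATMUL_* names come from the module above, the rest is illustrative:

# Hypothetical local run; MATMUL_* names come from the module above. Note that
# execute_locally creates a SharedMemory block named "benchmark_tmp", which must not
# already exist and is left for the caller to unlink.
import os

os.environ["MATMUL_K"] = "8"  # 256 x 256 matrices
os.environ["MATMUL_E"] = "2"
os.environ["MATMUL_L"] = "3"  # three chained powr nodes

from cascade.benchmarks.matmul import execute_locally, get_job

job = get_job()     # JobInstance: source -> pow0 -> pow1 -> pow2
execute_locally()   # runs the same chain in-process, writes the result to shared memory
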
cascade/controller/act.py CHANGED
@@ -51,6 +51,7 @@ def act(bridge: Bridge, assignment: Assignment) -> None:
51
51
  worker=assignment.worker,
52
52
  tasks=assignment.tasks,
53
53
  publish=assignment.outputs,
54
+ extra_env=assignment.extra_env,
54
55
  )
55
56
 
56
57
  for task in assignment.tasks:
@@ -43,6 +43,11 @@ def run(
43
43
  reporter = Reporter(report_address)
44
44
 
45
45
  try:
46
+ total_gpus = sum(worker.gpu for worker in env.workers.values())
47
+ needs_gpus = any(task.definition.needs_gpu for task in job.tasks.values())
48
+ if needs_gpus and total_gpus == 0:
49
+ raise ValueError("environment contains no gpu yet job demands one")
50
+
46
51
  while (
47
52
  state.has_awaitable()
48
53
  or context.has_awaitable()
@@ -22,6 +22,7 @@ from cascade.low.core import DatasetId, HostId, WorkerId
22
22
  from cascade.low.execution_context import DatasetStatus, JobExecutionContext
23
23
  from cascade.low.func import assert_never
24
24
  from cascade.low.tracing import TaskLifecycle, TransmitLifecycle, mark
25
+ from cascade.scheduler.api import gang_check_ready
25
26
  from cascade.scheduler.assign import set_worker2task_overhead
26
27
  from cascade.scheduler.core import Schedule
27
28
 
@@ -67,6 +68,7 @@ def consider_computable(
67
68
  # NOTE this is a task newly made computable, so we need to calc
68
69
  # `overhead` for all hosts/workers assigned to the component
69
70
  set_worker2task_overhead(schedule, context, worker, child_task)
71
+ gang_check_ready(child_task, component.gang_preparation)
70
72
 
71
73
 
72
74
  # TODO refac less explicit mutation of context, use class methods
@@ -46,7 +46,7 @@ class Bridge:
46
46
  self.transmit_idx_counter = 0
47
47
  self.sender = ReliableSender(self.mlistener.address, resend_grace_ms)
48
48
  registered = 0
49
- self.environment = Environment(workers={})
49
+ self.environment = Environment(workers={}, host_url_base={})
50
50
  logger.debug("about to start receiving registrations")
51
51
  registration_grace = time.time_ns() + 3 * 60 * 1_000_000_000
52
52
  while registered < expected_executors:
@@ -69,6 +69,7 @@ class Bridge:
69
69
  self.environment.workers[worker.worker_id] = Worker(
70
70
  cpu=worker.cpu, gpu=worker.gpu, memory_mb=worker.memory_mb
71
71
  )
72
+ self.environment.host_url_base[message.host] = message.url_base
72
73
  registered += 1
73
74
  self.heartbeat_checker[message.host] = GraceWatcher(
74
75
  2 * executor_heartbeat_grace_ms
@@ -27,6 +27,7 @@ logging_config = {
27
27
  "cascade.executor": {"level": "DEBUG"},
28
28
  "cascade.scheduler": {"level": "DEBUG"},
29
29
  "cascade.gateway": {"level": "DEBUG"},
30
+ "earthkit.workflows": {"level": "DEBUG"},
30
31
  "httpcore": {"level": "ERROR"},
31
32
  "httpx": {"level": "ERROR"},
32
33
  "": {"level": "WARNING", "handlers": ["default"]},
@@ -69,8 +69,9 @@ class Executor:
69
69
  workers: int,
70
70
  host: HostId,
71
71
  portBase: int,
72
- shm_vol_gb: int | None = None,
73
- log_base: str | None = None,
72
+ shm_vol_gb: int | None,
73
+ log_base: str | None,
74
+ url_base: str,
74
75
  ) -> None:
75
76
  self.job_instance = job_instance
76
77
  self.param_source = param_source(job_instance.edges)
@@ -138,6 +139,7 @@ class Executor:
138
139
  )
139
140
  for idx, worker_id in enumerate(self.workers.keys())
140
141
  ],
142
+ url_base=url_base,
141
143
  )
142
144
  logger.debug("constructed executor")
143
145
 
cascade/executor/msg.py CHANGED
@@ -71,6 +71,7 @@ class TaskSequence:
71
71
  worker: WorkerId # worker for running those tasks
72
72
  tasks: list[TaskId] # to be executed in the given order
73
73
  publish: set[DatasetId] # set of outputs to be published
74
+ extra_env: list[tuple[str, str]] # extra env vars to set
74
75
 
75
76
 
76
77
  @dataclass(frozen=True)
@@ -147,6 +148,7 @@ class ExecutorRegistration:
147
148
  host: HostId
148
149
  maddress: BackboneAddress
149
150
  daddress: BackboneAddress
151
+ url_base: str # used e.g. for dist comms init
150
152
  workers: list[Worker]
151
153
 
152
154
 
@@ -11,6 +11,7 @@
11
11
  import logging
12
12
  import logging.config
13
13
  import os
14
+ import sys
14
15
  from dataclasses import dataclass
15
16
 
16
17
  import zmq
@@ -67,6 +68,25 @@ class RunnerContext:
67
68
  )
68
69
 
69
70
 
71
+ class Config:
72
+ """Some parameters to drive behaviour. Currently not exposed externally -- no clear argument
73
+ that they should be. As is, just a means of code experimentation.
74
+ """
75
+
76
+ # flushing approach -- when we finish the computation of a task sequence, there is a question of what
77
+ # to do with the output. We could either publish & drop, or publish and retain in memory. The
78
+ # former is slower -- if the next task sequence needs this output, it requires a fetch & deser
79
+ # from cashme. But the latter is more risky -- we effectively have the same dataset twice in
80
+ # system memory. The `posttask_flush` below goes the former way, the `pretask_flush` is a careful
81
+ # way of the latter -- we drop the output from memory only if the *next* task sequence does not need
82
+ # it, i.e., we retain a cache of age 1. We could ultimately have the controller decide about this, or
83
+ # decide dynamically based on memory pressure -- but neither is easy.
84
+ posttask_flush = False # after task is done, drop all outputs from memory
85
+ pretask_flush = (
86
+ True # when we receive a task, we drop those in memory that won't be needed
87
+ )
88
+
89
+
70
90
  def worker_address(workerId: WorkerId) -> BackboneAddress:
71
91
  return f"ipc:///tmp/{repr(workerId)}.socket"
72
92
 
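
A self-contained toy model of the two flushing strategies described in the Config comment above; the dict stands in for the runner's in-memory datasets and the dataset ids are made up for illustration:

# Toy model only, not package code.
def posttask_flush(local: dict, published: set) -> None:
    # publish & drop: after the sequence, only published datasets stay in memory
    for ds in [ds for ds in local if ds not in published]:
        del local[ds]


def pretask_flush(local: dict, required_by_next: set) -> None:
    # age-1 cache: on receiving the next sequence, drop only what it will not need
    for ds in [ds for ds in local if ds not in required_by_next]:
        del local[ds]


local = {"source.0": b"...", "pow0.0": b"..."}   # made-up dataset ids
pretask_flush(local, required_by_next={"pow0.0"})
assert set(local) == {"pow0.0"}
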
@@ -79,11 +99,17 @@ def execute_sequence(
79
99
  ) -> None:
80
100
  taskId: TaskId | None = None
81
101
  try:
102
+ for key, value in taskSequence.extra_env:
103
+ os.environ[key] = value
82
104
  executionContext = runnerContext.project(taskSequence)
83
105
  for taskId in taskSequence.tasks:
84
106
  pckg.extend(executionContext.tasks[taskId].definition.environment)
85
107
  run(taskId, executionContext, memory)
86
- memory.flush()
108
+ if Config.posttask_flush:
109
+ memory.flush()
110
+ for key, _ in taskSequence.extra_env:
111
+ # NOTE we should in principle restore the previous value, but we don't expect collisions
112
+ del os.environ[key]
87
113
  except Exception as e:
88
114
  logger.exception("runner failure, about to report")
89
115
  callback(
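
The NOTE above about restoring previous values could be handled with a small context manager; an illustrative sketch (not part of cascade) that sets the extra env vars for the duration of a task sequence and restores whatever was there before:

# Illustrative only: set the gang env vars for the duration of the block, then restore.
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temporary_env(extra_env: list[tuple[str, str]]) -> Iterator[None]:
    previous = {key: os.environ.get(key) for key, _ in extra_env}
    try:
        for key, value in extra_env:
            os.environ[key] = value
        yield
    finally:
        for key, old in previous.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old
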
@@ -107,10 +133,17 @@ def entrypoint(runnerContext: RunnerContext):
107
133
  PackagesEnv() as pckg,
108
134
  ):
109
135
  label("worker", repr(runnerContext.workerId))
110
- gpu_id = str(runnerContext.workerId.worker_num())
111
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_id)
112
- # NOTE check any(task.definition.needs_gpu) anywhere?
113
- # TODO configure OMP_NUM_THREADS, blas, mkl, etc -- not clear how tho
136
+ worker_num = runnerContext.workerId.worker_num()
137
+ gpus = int(os.environ.get("CASCADE_GPU_COUNT", "0"))
138
+ if sys.platform != "darwin":
139
+ os.environ["CUDA_VISIBLE_DEVICES"] = (
140
+ str(worker_num) if worker_num < gpus else ""
141
+ )
142
+ # NOTE check any(task.definition.needs_gpu) anywhere?
143
+ # TODO configure OMP_NUM_THREADS, blas, mkl, etc -- not clear how tho
144
+ else:
145
+ if gpus != 1:
146
+ logger.warning("unexpected absence of gpu on darwin")
114
147
 
115
148
  for serdeTypeEnc, (serdeSer, serdeDes) in runnerContext.job.serdes.items():
116
149
  serde.SerdeRegistry.register(type_dec(serdeTypeEnc), serdeSer, serdeDes)
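
The per-worker device assignment above pins worker n to GPU n and hides the GPUs from the remaining workers; for example, with four workers and CASCADE_GPU_COUNT=2 (standalone illustration, not package code):

# Resulting CUDA_VISIBLE_DEVICES values for 4 workers and CASCADE_GPU_COUNT=2,
# mirroring the assignment in entrypoint above.
gpus = 2
for worker_num in range(4):
    cuda_visible = str(worker_num) if worker_num < gpus else ""
    print(f"worker {worker_num}: CUDA_VISIBLE_DEVICES={cuda_visible!r}")
# worker 0: CUDA_VISIBLE_DEVICES='0'
# worker 1: CUDA_VISIBLE_DEVICES='1'
# worker 2: CUDA_VISIBLE_DEVICES=''
# worker 3: CUDA_VISIBLE_DEVICES=''
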
@@ -151,6 +184,9 @@ def entrypoint(runnerContext: RunnerContext):
151
184
  for key, _ in runnerContext.job.tasks[task].definition.output_schema
152
185
  }
153
186
  missing_ds = required - availab_ds
187
+ if Config.pretask_flush:
188
+ extraneous_ds = availab_ds - required
189
+ memory.flush(extraneous_ds)
154
190
  if missing_ds:
155
191
  waiting_ts = mDes
156
192
  for ds in availab_ds.intersection(required):
@@ -51,7 +51,6 @@ class Memory(AbstractContextManager):
51
51
  else:
52
52
  outputValue = "ok"
53
53
 
54
- # TODO how do we purge from here over time?
55
54
  self.local[outputId] = outputValue
56
55
 
57
56
  if isPublish:
@@ -68,6 +67,18 @@ class Memory(AbstractContextManager):
68
67
  self.callback,
69
68
  DatasetPublished(ds=outputId, origin=self.worker, transmit_idx=None),
70
69
  )
70
+ else:
71
+ # NOTE even if it's not actually published, we send the message to allow for
72
+ # marking the task itself as completed -- it's odd, but arguably better than
73
+ # introducing a TaskCompleted message. TODO we should fine-grain host-wide
74
+ # and worker-only publishes at the `controller.notify` level, to not cause
75
+ # incorrect shm.purge calls at workflow end, which log an annoying key error
76
+ logger.debug(f"fake publish of {outputId} for the sake of task completion")
77
+ shmid = ds2shmid(outputId)
78
+ callback(
79
+ self.callback,
80
+ DatasetPublished(ds=outputId, origin=self.worker, transmit_idx=None),
81
+ )
71
82
 
72
83
  def provide(self, inputId: DatasetId, annotation: str) -> Any:
73
84
  if inputId not in self.local:
@@ -85,18 +96,24 @@ class Memory(AbstractContextManager):
85
96
 
86
97
  def pop(self, ds: DatasetId) -> None:
87
98
  if ds in self.local:
99
+ logger.debug(f"popping local {ds}")
88
100
  val = self.local.pop(ds) # noqa: F841
89
101
  del val
90
102
  if ds in self.bufs:
103
+ logger.debug(f"popping buf {ds}")
91
104
  buf = self.bufs.pop(ds)
92
105
  buf.close()
93
106
 
94
- def flush(self) -> None:
95
- # NOTE poor man's memory management -- just drop those locals that weren't published. Called
107
+ def flush(self, datasets: set[DatasetId] = set()) -> None:
108
+ # NOTE poor man's memory management -- just drop those locals that didn't come from cashme. Called
96
109
  # after every taskSequence. In principle, we could purge some locals earlier, and ideally scheduler
97
110
  # would invoke some targeted purges to also remove some published ones earlier (eg, they are still
98
111
  # needed somewhere but not here)
99
- purgeable = [inputId for inputId in self.local if inputId not in self.bufs]
112
+ purgeable = [
113
+ inputId
114
+ for inputId in self.local
115
+ if inputId not in self.bufs and (not datasets or inputId in datasets)
116
+ ]
100
117
  logger.debug(f"will flush {len(purgeable)} datasets")
101
118
  for inputId in purgeable:
102
119
  self.local.pop(inputId)
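
The new `datasets` parameter above uses a mutable default (set()); it is only ever read, so sharing the default object is harmless here, but the conventional idiom avoids the pitfall altogether. An illustrative standalone variant of the same filter:

# Illustrative free-function variant, not the package method.
def select_purgeable(local: dict, bufs: dict, datasets: set | None = None) -> list:
    return [
        ds
        for ds in local
        if ds not in bufs and (datasets is None or ds in datasets)
    ]


assert select_purgeable({"a": 1, "b": 2}, {"b": 2}) == ["a"]
assert select_purgeable({"a": 1, "b": 2}, {}, {"b"}) == ["b"]
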
@@ -115,6 +132,8 @@ class Memory(AbstractContextManager):
115
132
  free, total = torch.cuda.mem_get_info()
116
133
  logger.debug(f"cuda mem avail post cache empty: {free/total:.2%}")
117
134
  if free / total < 0.8:
135
+ # NOTE this of course makes little sense if there is any other application (like a browser or ollama)
136
+ # that the user may be running
118
137
  logger.warning("cuda mem avail low despite cache empty!")
119
138
  logger.debug(torch.cuda.memory_summary())
120
139
  except ImportError:
cascade/low/core.py CHANGED
@@ -106,15 +106,26 @@ def type_enc(t: Type) -> str:
106
106
  return b64encode(cloudpickle.dumps(t)).decode("ascii")
107
107
 
108
108
 
109
+ class SchedulingConstraint(BaseModel):
110
+ gang: list[TaskId] = Field(
111
+ description="this set of TaskIds must be started at the same time, with ranks and address list as envvar",
112
+ )
113
+
114
+
109
115
  class JobInstance(BaseModel):
110
116
  tasks: dict[TaskId, TaskInstance]
111
117
  edges: list[Task2TaskEdge]
112
118
  serdes: dict[str, tuple[str, str]] = Field(
113
- {},
119
+ default_factory=lambda: {},
114
120
  description="for each Type with custom serde, add entry here. The string is fully qualified name of the ser/des functions",
115
121
  )
116
122
  ext_outputs: list[DatasetId] = Field(
117
- [], description="ids to externally materialize"
123
+ default_factory=lambda: [],
124
+ description="ids to externally materialize",
125
+ )
126
+ constraints: list[SchedulingConstraint] = Field(
127
+ default_factory=lambda: [],
128
+ description="constraints for the scheduler such as gangs",
118
129
  )
119
130
 
120
131
  def outputs_of(self, task_id: TaskId) -> set[DatasetId]:
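
The gang constraint is consumed by the dist benchmark earlier in this diff; judging from that benchmark, a gang member is launched with the following environment variables (an inference from the benchmark code, not an interface documented in this diff):

# What a gang member is expected to see at run time, inferred from
# cascade.benchmarks.dist above.
import os

world_size = int(os.environ.get("CASCADE_GANG_WORLD_SIZE", "1"))  # == len(constraint.gang)
rank = int(os.environ.get("CASCADE_GANG_RANK", "0"))              # this task's index in the gang
coordinator = os.environ.get("CASCADE_GANG_COORDINATOR", "")      # presumably derived from host_url_base
print(f"{rank=}/{world_size=}, {coordinator=}")
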
@@ -157,6 +168,7 @@ class Worker(BaseModel):
157
168
 
158
169
  class Environment(BaseModel):
159
170
  workers: dict[WorkerId, Worker]
171
+ host_url_base: dict[HostId, str]
160
172
 
161
173
 
162
174
  class TaskExecutionRecord(BaseModel):
@@ -108,6 +108,12 @@ class JobExecutionContext:
108
108
  self.idle_workers.add(worker)
109
109
 
110
110
  def dataset_preparing(self, dataset: DatasetId, worker: WorkerId) -> None:
111
+ # NOTE Currently this is invoked during `build_assignment`, as we need
112
+ # some state transition to allow fusing opportunities as well as
113
+ # preventing double transmits. This may not be the best idea, e.g. for long
114
+ # fusing chains -- instead, we may execute this transition at the time
115
+ # it actually happens, and granularize the preparing state into
116
+ # (will_appear, is_appearing), etc
111
117
  # NOTE Currently, these `if`s are necessary because we issue transmit
112
118
  # command when host *has* DS but worker does *not*. This ends up no-op,
113
119
  # but we totally dont want host state to reset -- it wouldnt recover