furu 0.0.3-py3-none-any.whl → 0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,271 @@
+from __future__ import annotations
+
+import time
+import uuid
+from collections.abc import Mapping
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+
+from ..adapters import SubmititAdapter
+from ..adapters.submitit import SubmititJob
+from ..config import FURU_CONFIG
+from ..core import Furu
+from ..errors import FuruExecutionError
+from ..storage.state import StateManager, _FuruState, _StateResultFailed
+from .plan import DependencyPlan, build_plan, topo_order_todo
+from .slurm_spec import SlurmSpec, SlurmSpecExtraValue
+from .submitit_factory import make_executor_for_spec
+
+
+@dataclass
+class SlurmDagSubmission:
+    plan: DependencyPlan
+    job_id_by_hash: dict[str, str]
+    root_job_ids: dict[str, str]
+    run_id: str
+
+
+def _job_id_from_state(obj: Furu, directory: Path | None = None) -> str | None:
+    state = obj.get_state(directory)
+    attempt = state.attempt
+    if attempt is None:
+        return None
+    job_id = attempt.scheduler.get("job_id")
+    if job_id is None:
+        return None
+    return str(job_id)
+
+
+def _attempt_is_terminal(obj: Furu, directory: Path | None = None) -> bool:
+    state = obj.get_state(directory)
+    attempt = state.attempt
+    if attempt is None:
+        return False
+    return attempt.status in StateManager.TERMINAL_STATUSES
+
+
+def _set_submitit_job_id(directory: Path, job_id: str) -> bool:
+    updated = False
+
+    def mutate(state: _FuruState) -> None:
+        nonlocal updated
+        attempt = state.attempt
+        if attempt is None:
+            return
+        if attempt.backend != "submitit":
+            return
+        if (
+            attempt.status not in {"queued", "running"}
+            and attempt.status not in StateManager.TERMINAL_STATUSES
+        ):
+            return
+        existing = attempt.scheduler.get("job_id")
+        if existing == job_id:
+            updated = True
+            return
+        attempt.scheduler["job_id"] = job_id
+        updated = True
+
+    StateManager.update_state(directory, mutate)
+    return updated
+
+
+def _wait_for_job_id(
+    obj: Furu,
+    adapter: SubmititAdapter,
+    job: SubmititJob | None,
+    *,
+    timeout_sec: float = 15.0,
+    poll_interval_sec: float = 0.25,
+) -> str:
+    deadline = time.time() + timeout_sec
+    directory = obj._base_furu_dir()
+    last_job_id: str | None = None
+
+    while True:
+        job_id = _job_id_from_state(obj, directory)
+        if job_id:
+            if job is None:
+                return job_id
+            adapter_job_id = adapter.get_job_id(job)
+            if adapter_job_id is None or str(adapter_job_id) == job_id:
+                return job_id
+            last_job_id = str(adapter_job_id)
+            _set_submitit_job_id(directory, last_job_id)
+            state_job_id = _job_id_from_state(obj, directory)
+            if state_job_id is not None and state_job_id == last_job_id:
+                return state_job_id
+            if _attempt_is_terminal(obj, directory):
+                return last_job_id
+
+        if job is None:
+            job = adapter.load_job(directory)
+
+        if job is not None:
+            job_id = adapter.get_job_id(job)
+            if job_id:
+                last_job_id = job_id
+                _set_submitit_job_id(directory, job_id)
+                state_job_id = _job_id_from_state(obj, directory)
+                if state_job_id is not None and state_job_id == job_id:
+                    return state_job_id
+                if _attempt_is_terminal(obj, directory):
+                    return str(job_id)
+
+        if time.time() >= deadline:
+            suffix = f" Last seen job_id={last_job_id}." if last_job_id else ""
+            raise TimeoutError(
+                "Timed out waiting for submitit job_id for "
+                f"{obj.__class__.__name__} ({obj._furu_hash}).{suffix}"
+            )
+
+        time.sleep(poll_interval_sec)
+
+
+def _job_id_for_in_progress(obj: Furu) -> str:
+    state = obj.get_state()
+    attempt = state.attempt
+    if attempt is None:
+        raise RuntimeError(
+            "Cannot wire Slurm DAG dependency for IN_PROGRESS "
+            f"{obj.__class__.__name__} ({obj._furu_hash}) without an attempt."
+        )
+    if attempt.backend != "submitit":
+        raise FuruExecutionError(
+            "Cannot wire afterok dependencies to non-submitit in-progress nodes. "
+            "Use pool mode or wait until completed."
+        )
+
+    # If the dependency has already become terminal and failed (or otherwise did not
+    # succeed), wiring `afterok` would permanently block dependents.
+    if isinstance(state.result, _StateResultFailed) or (
+        attempt.status in StateManager.TERMINAL_STATUSES and attempt.status != "success"
+    ):
+        raise FuruExecutionError(
+            "Cannot wire afterok dependency to a terminal non-success dependency. "
+            f"Dependency {obj.__class__.__name__} ({obj._furu_hash}) status={attempt.status}."
+        )
+
+    job_id = attempt.scheduler.get("job_id")
+    if job_id:
+        resolved = str(job_id)
+    else:
+        adapter = SubmititAdapter(executor=None)
+        resolved = _wait_for_job_id(obj, adapter, None)
+
+    # Re-check after waiting: the attempt could flip to terminal while we're
+    # retrieving job_id. If it ended non-success, fail fast instead of wiring
+    # dependents to an `afterok` that will never unblock.
+    state2 = obj.get_state()
+    attempt2 = state2.attempt
+    if attempt2 is not None and attempt2.status in StateManager.TERMINAL_STATUSES:
+        if attempt2.status != "success" or isinstance(state2.result, _StateResultFailed):
+            raise FuruExecutionError(
+                "Cannot wire afterok dependency: dependency became terminal and did not succeed. "
+                f"Dependency {obj.__class__.__name__} ({obj._furu_hash}) status={attempt2.status} "
+                f"job_id={resolved}."
+            )
+
+    return resolved
+
+
+def _make_run_id() -> str:
+    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+    token = uuid.uuid4().hex[:6]
+    return f"{stamp}-{token}"
+
+
+def submit_slurm_dag(
+    roots: list[Furu],
+    *,
+    specs: dict[str, SlurmSpec],
+    submitit_root: Path | None = None,
+) -> SlurmDagSubmission:
+    if "default" not in specs:
+        raise KeyError("Missing slurm spec for key 'default'.")
+
+    run_id = _make_run_id()
+    plan = build_plan(roots)
+    failed = [node for node in plan.nodes.values() if node.status == "FAILED"]
+    if failed:
+        names = ", ".join(
+            f"{node.obj.__class__.__name__}({node.obj._furu_hash})" for node in failed
+        )
+        raise RuntimeError(f"Cannot submit slurm DAG with failed dependencies: {names}")
+
+    order = topo_order_todo(plan)
+    job_id_by_hash: dict[str, str] = {}
+    root_job_ids: dict[str, str] = {}
+
+    root_hashes = {root._furu_hash for root in roots}
+
+    for digest in order:
+        node = plan.nodes[digest]
+        dep_job_ids: list[str] = []
+        for dep_hash in sorted(node.deps_pending):
+            dep_node = plan.nodes[dep_hash]
+            if dep_node.status == "IN_PROGRESS":
+                dep_job_ids.append(_job_id_for_in_progress(dep_node.obj))
+            elif dep_node.status == "TODO":
+                dep_job_ids.append(job_id_by_hash[dep_hash])
+
+        spec_key = node.spec_key
+        if spec_key not in specs:
+            raise KeyError(
+                "Missing slurm spec for key "
+                f"'{spec_key}' for node {node.obj.__class__.__name__} ({digest})."
+            )
+
+        spec = specs[spec_key]
+        executor = make_executor_for_spec(
+            spec_key,
+            spec,
+            kind="nodes",
+            submitit_root=submitit_root,
+            run_id=run_id,
+        )
+        if dep_job_ids:
+            dependency = "afterok:" + ":".join(dep_job_ids)
+            slurm_params: dict[str, SlurmSpecExtraValue] = {"dependency": dependency}
+            if spec.extra:
+                extra_params = spec.extra.get("slurm_additional_parameters")
+                if extra_params is not None:
+                    if not isinstance(extra_params, Mapping):
+                        raise TypeError(
+                            "slurm_additional_parameters must be a mapping when provided."
+                        )
+                    slurm_params = {
+                        **dict(extra_params),
+                        "dependency": dependency,
+                    }
+            executor.update_parameters(slurm_additional_parameters=slurm_params)
+
+        adapter = SubmititAdapter(executor)
+        job = node.obj._submit_once(
+            adapter,
+            directory=node.obj._base_furu_dir(),
+            on_job_id=None,
+            allow_failed=FURU_CONFIG.retry_failed,
+        )
+        job_id = _wait_for_job_id(node.obj, adapter, job)
+        job_id_by_hash[digest] = job_id
+        if digest in root_hashes:
+            root_job_ids[digest] = job_id
+
+    for root in roots:
+        digest = root._furu_hash
+        if digest in root_job_ids:
+            continue
+        node = plan.nodes.get(digest)
+        if node is None:
+            continue
+        if node.status == "IN_PROGRESS":
+            root_job_ids[digest] = _job_id_for_in_progress(node.obj)
+
+    return SlurmDagSubmission(
+        plan=plan,
+        job_id_by_hash=job_id_by_hash,
+        root_job_ids=root_job_ids,
+        run_id=run_id,
+    )
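
For orientation, a minimal usage sketch of the new `submit_slurm_dag` entry point follows. It is inferred only from what is visible in this diff: the `submit_slurm_dag(roots, *, specs, submitit_root=None)` signature, the requirement that `specs` contain a `"default"` key, and the `SlurmDagSubmission` fields. The import paths, the `TrainStep` class, and the `SlurmSpec` constructor arguments are assumptions, not part of this release.

```python
from pathlib import Path

# Import paths are assumptions; the diff does not show where the new module
# lives inside the furu package.
from furu.slurm.slurm_spec import SlurmSpec
from furu.slurm.dag import submit_slurm_dag

# `TrainStep` stands in for any user-defined Furu subclass (hypothetical).
roots = [TrainStep(seed=0), TrainStep(seed=1)]

specs = {
    # A "default" spec is mandatory; submit_slurm_dag raises KeyError without it.
    "default": SlurmSpec(...),  # constructor arguments are not shown in this diff
    # Additional keys are resolved per node via node.spec_key.
    "gpu": SlurmSpec(...),
}

submission = submit_slurm_dag(
    roots,
    specs=specs,
    submitit_root=Path("./submitit"),  # optional; defaults to None
)

print(submission.run_id)        # timestamp plus short token, e.g. "20240102-153000-a1b2c3"
print(submission.root_job_ids)  # Slurm job ids keyed by root furu hash
```

Dependencies between nodes are wired by the module itself: each node whose plan dependencies are still pending is submitted with an `afterok:<job_id>:...` Slurm dependency built from the job ids collected above.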