sibi_flux-2025.12.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sibi_dst/__init__.py +44 -0
  2. sibi_flux/__init__.py +49 -0
  3. sibi_flux/artifacts/__init__.py +7 -0
  4. sibi_flux/artifacts/base.py +166 -0
  5. sibi_flux/artifacts/parquet.py +360 -0
  6. sibi_flux/artifacts/parquet_engine/__init__.py +5 -0
  7. sibi_flux/artifacts/parquet_engine/executor.py +204 -0
  8. sibi_flux/artifacts/parquet_engine/manifest.py +101 -0
  9. sibi_flux/artifacts/parquet_engine/planner.py +544 -0
  10. sibi_flux/conf/settings.py +131 -0
  11. sibi_flux/core/__init__.py +5 -0
  12. sibi_flux/core/managed_resource/__init__.py +3 -0
  13. sibi_flux/core/managed_resource/_managed_resource.py +733 -0
  14. sibi_flux/core/type_maps/__init__.py +100 -0
  15. sibi_flux/dask_cluster/__init__.py +47 -0
  16. sibi_flux/dask_cluster/async_core.py +27 -0
  17. sibi_flux/dask_cluster/client_manager.py +549 -0
  18. sibi_flux/dask_cluster/core.py +322 -0
  19. sibi_flux/dask_cluster/exceptions.py +34 -0
  20. sibi_flux/dask_cluster/utils.py +49 -0
  21. sibi_flux/datacube/__init__.py +3 -0
  22. sibi_flux/datacube/_data_cube.py +332 -0
  23. sibi_flux/datacube/config_engine.py +152 -0
  24. sibi_flux/datacube/field_factory.py +48 -0
  25. sibi_flux/datacube/field_registry.py +122 -0
  26. sibi_flux/datacube/generator.py +677 -0
  27. sibi_flux/datacube/orchestrator.py +171 -0
  28. sibi_flux/dataset/__init__.py +3 -0
  29. sibi_flux/dataset/_dataset.py +162 -0
  30. sibi_flux/df_enricher/__init__.py +56 -0
  31. sibi_flux/df_enricher/async_enricher.py +201 -0
  32. sibi_flux/df_enricher/merger.py +253 -0
  33. sibi_flux/df_enricher/specs.py +45 -0
  34. sibi_flux/df_enricher/types.py +12 -0
  35. sibi_flux/df_helper/__init__.py +5 -0
  36. sibi_flux/df_helper/_df_helper.py +450 -0
  37. sibi_flux/df_helper/backends/__init__.py +34 -0
  38. sibi_flux/df_helper/backends/_params.py +173 -0
  39. sibi_flux/df_helper/backends/_strategies.py +295 -0
  40. sibi_flux/df_helper/backends/http/__init__.py +5 -0
  41. sibi_flux/df_helper/backends/http/_http_config.py +122 -0
  42. sibi_flux/df_helper/backends/parquet/__init__.py +7 -0
  43. sibi_flux/df_helper/backends/parquet/_parquet_options.py +268 -0
  44. sibi_flux/df_helper/backends/sqlalchemy/__init__.py +9 -0
  45. sibi_flux/df_helper/backends/sqlalchemy/_db_connection.py +256 -0
  46. sibi_flux/df_helper/backends/sqlalchemy/_db_gatekeeper.py +15 -0
  47. sibi_flux/df_helper/backends/sqlalchemy/_io_dask.py +386 -0
  48. sibi_flux/df_helper/backends/sqlalchemy/_load_from_db.py +134 -0
  49. sibi_flux/df_helper/backends/sqlalchemy/_model_registry.py +239 -0
  50. sibi_flux/df_helper/backends/sqlalchemy/_sql_model_builder.py +42 -0
  51. sibi_flux/df_helper/backends/utils.py +32 -0
  52. sibi_flux/df_helper/core/__init__.py +15 -0
  53. sibi_flux/df_helper/core/_defaults.py +104 -0
  54. sibi_flux/df_helper/core/_filter_handler.py +617 -0
  55. sibi_flux/df_helper/core/_params_config.py +185 -0
  56. sibi_flux/df_helper/core/_query_config.py +17 -0
  57. sibi_flux/df_validator/__init__.py +3 -0
  58. sibi_flux/df_validator/_df_validator.py +222 -0
  59. sibi_flux/logger/__init__.py +1 -0
  60. sibi_flux/logger/_logger.py +480 -0
  61. sibi_flux/mcp/__init__.py +26 -0
  62. sibi_flux/mcp/client.py +150 -0
  63. sibi_flux/mcp/router.py +126 -0
  64. sibi_flux/orchestration/__init__.py +9 -0
  65. sibi_flux/orchestration/_artifact_orchestrator.py +346 -0
  66. sibi_flux/orchestration/_pipeline_executor.py +212 -0
  67. sibi_flux/osmnx_helper/__init__.py +22 -0
  68. sibi_flux/osmnx_helper/_pbf_handler.py +384 -0
  69. sibi_flux/osmnx_helper/graph_loader.py +225 -0
  70. sibi_flux/osmnx_helper/utils.py +100 -0
  71. sibi_flux/pipelines/__init__.py +3 -0
  72. sibi_flux/pipelines/base.py +218 -0
  73. sibi_flux/py.typed +0 -0
  74. sibi_flux/readers/__init__.py +3 -0
  75. sibi_flux/readers/base.py +82 -0
  76. sibi_flux/readers/parquet.py +106 -0
  77. sibi_flux/utils/__init__.py +53 -0
  78. sibi_flux/utils/boilerplate/__init__.py +19 -0
  79. sibi_flux/utils/boilerplate/base_attacher.py +45 -0
  80. sibi_flux/utils/boilerplate/base_cube_router.py +283 -0
  81. sibi_flux/utils/boilerplate/base_data_cube.py +132 -0
  82. sibi_flux/utils/boilerplate/base_pipeline_template.py +54 -0
  83. sibi_flux/utils/boilerplate/hybrid_data_loader.py +193 -0
  84. sibi_flux/utils/clickhouse_writer/__init__.py +6 -0
  85. sibi_flux/utils/clickhouse_writer/_clickhouse_writer.py +225 -0
  86. sibi_flux/utils/common.py +7 -0
  87. sibi_flux/utils/credentials/__init__.py +3 -0
  88. sibi_flux/utils/credentials/_config_manager.py +155 -0
  89. sibi_flux/utils/dask_utils.py +14 -0
  90. sibi_flux/utils/data_utils/__init__.py +3 -0
  91. sibi_flux/utils/data_utils/_data_utils.py +389 -0
  92. sibi_flux/utils/dataframe_utils.py +52 -0
  93. sibi_flux/utils/date_utils/__init__.py +10 -0
  94. sibi_flux/utils/date_utils/_business_days.py +220 -0
  95. sibi_flux/utils/date_utils/_date_utils.py +311 -0
  96. sibi_flux/utils/date_utils/_file_age_checker.py +319 -0
  97. sibi_flux/utils/file_utils.py +48 -0
  98. sibi_flux/utils/filepath_generator/__init__.py +5 -0
  99. sibi_flux/utils/filepath_generator/_filepath_generator.py +185 -0
  100. sibi_flux/utils/parquet_saver/__init__.py +6 -0
  101. sibi_flux/utils/parquet_saver/_parquet_saver.py +436 -0
  102. sibi_flux/utils/parquet_saver/_write_gatekeeper.py +33 -0
  103. sibi_flux/utils/retry.py +46 -0
  104. sibi_flux/utils/storage/__init__.py +7 -0
  105. sibi_flux/utils/storage/_fs_registry.py +112 -0
  106. sibi_flux/utils/storage/_storage_manager.py +257 -0
  107. sibi_flux/utils/storage/factory.py +33 -0
  108. sibi_flux-2025.12.0.dist-info/METADATA +283 -0
  109. sibi_flux-2025.12.0.dist-info/RECORD +110 -0
  110. sibi_flux-2025.12.0.dist-info/WHEEL +4 -0
sibi_flux/dask_cluster/client_manager.py
@@ -0,0 +1,549 @@
+ """
+ Dask Client Manager module.
+
+ Manages Dask cluster lifecycles, registry paths, and client connections.
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import logging
+ import multiprocessing
+ import os
+ import shutil
+ import tempfile
+ import threading
+ import time
+ from contextlib import asynccontextmanager, contextmanager, suppress
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ import psutil
+
+ try:
+     from dask.distributed import Client, LocalCluster, get_client
+     HAS_DISTRIBUTED = True
+ except ImportError:
+     Client = object
+     LocalCluster = object
+     get_client = None
+     HAS_DISTRIBUTED = False
+
+ # Guarded import: remote workers (which might lack watchdog) can still import this module
+ try:
+     from watchdog.observers import Observer
+     from watchdog.events import FileSystemEventHandler
+
+     HAS_WATCHDOG = True
+ except ImportError:
+     Observer = object
+     FileSystemEventHandler = object
+     HAS_WATCHDOG = False
+
+ try:
+     from sibi_flux.logger import Logger
+
+     _default_logger = Logger.default_logger(logger_name="dask_lifecycle_manager")
+ except ImportError:
+     _default_logger = logging.getLogger(__name__)
+     if not _default_logger.handlers:
+         _default_logger.addHandler(logging.StreamHandler())
+         _default_logger.setLevel(logging.INFO)
+
+ _cores = multiprocessing.cpu_count()
+ DEFAULT_WORKERS = max(2, min(4, _cores))
+ DEFAULT_THREADS = 2
+
+
+ class RegistryEventHandler(FileSystemEventHandler):
+     """
+     Watchdog handler to trigger logic when the registry file changes.
+     Used to detect if the shared cluster configuration has been updated by another process.
+     """
+
+     def __init__(self, callback):
+         if not HAS_WATCHDOG:
+             return
+         self.callback = callback
+
+     def on_modified(self, event):
+         if not event.is_directory and event.src_path.endswith(
+             "shared_dask_cluster.json"
+         ):
+             self.callback()
+
+
+ class DaskClientMixin:
+     """
+     Stabilised Dask lifecycle manager.
+
+     This class manages a shared Dask LocalCluster instance, enabling:
+     - Singleton cluster management across processes via a file-based registry.
+     - Automatic healing and reconnection if the cluster becomes unresponsive.
+     - Event-driven monitoring using `watchdog` to detect external registry updates.
+     """
+
+     REGISTRY_PATH = Path(tempfile.gettempdir()) / "shared_dask_cluster.json"
+     REGISTRY_LOCK_PATH = REGISTRY_PATH.parent / (REGISTRY_PATH.name + ".lock")
+     WATCHDOG_INTERVAL = 30
+     HEAL_COOLDOWN = 5.0
+
+     def __init__(self, **kwargs):
+         self.dask_client: Optional[Client] = None
+         self.own_dask_client = False
+         self.logger: Any = kwargs.get("logger") or _default_logger
+         self._watchdog_task: Optional[asyncio.Task] = None
+         self._watchdog_stop = asyncio.Event()
+         self._fs_observer: Optional[Observer] = None
+         self._init_params = {}
+         self._last_heal_time = 0.0
+         self._last_registry_write = 0.0
+
+     def _deploy_local_code(self, client: Client) -> None:
+         """
+         Zips the local 'sibi_flux' package and uploads it to the scheduler.
+         Dask distributed will propagate this to all workers.
+         """
+         try:
+             import sibi_flux
+
+             pkg_path = Path(sibi_flux.__file__).parent
+             src_root = pkg_path.parent
+
+             # Reserve a temporary file name; the handle is closed before
+             # shutil.make_archive overwrites it with the real archive.
+             with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
+                 archive_path = Path(tmp.name)
+
+             shutil.make_archive(
+                 str(archive_path.with_suffix("")),  # trim extension, shutil adds it
+                 "zip",
+                 root_dir=str(src_root),
+                 base_dir="sibi_flux",
+             )
+
+             self.logger.info(f"Deploying local code to cluster: {archive_path}")
+             client.upload_file(str(archive_path))
+
+             # Cleanup
+             archive_path.unlink(missing_ok=True)
+
+         except Exception as e:
+             self.logger.warning(f"Failed to deploy local code: {e}")
+
+     def _verify_connectivity(self, client: Client) -> bool:
+         """
+         Verifies that the client can actually run tasks on the cluster.
+         This catches issues like missing environments, version mismatches,
+         or deserialization errors that occur after connection.
+         """
+         try:
+             # 1. Check for workers, giving the scheduler a moment to
+             # discover them if we have only just connected.
+             client.wait_for_workers(1, timeout=2.0)
+
+             def environment_probe():
+                 return True
+
+             f = client.submit(environment_probe)
+             # Retry loop for race conditions (workers restart/receive file)
+             for attempt in range(3):
+                 try:
+                     if f.result(timeout=5.0) is True:
+                         return True
+                 except Exception:
+                     if attempt < 2:
+                         time.sleep(2.0)
+                         continue
+             return False
+         except Exception as e:
+             self.logger.warning(
+                 f"Cluster verification failed (Environment mismatch?): {e}",
+                 extra={"error": str(e)},
+             )
+             return False
+
+     def _remove_registry(self):
+         with suppress(FileNotFoundError):
+             self.REGISTRY_PATH.unlink()
+
+     def _cleanup_stale_registry(self):
+         """Removes the registry file if the owning process is no longer running."""
+         reg = self._read_registry()
+         if reg and reg.get("pid"):
+             if not psutil.pid_exists(reg["pid"]):
+                 self.logger.warning(
+                     "Cleaning zombie registry", extra={"pid": reg["pid"]}
+                 )
+                 self._remove_registry()
+
+     def _retire_all_workers(self):
+         """Attempts to gracefully retire all workers."""
+         if self.dask_client:
+             with suppress(Exception):
+                 self.dask_client.retire_workers()
+
+     def _has_inflight(self) -> bool:
+         # Hook for subclasses to report in-flight work; the base class reports none.
+         return False
+
+     @contextmanager
+     def _registry_lock(self, timeout: float = 10.0):
+         """
+         Acquires an exclusive file lock to ensure atomic registry updates.
+         """
+         start_time = time.time()
+         while True:
+             try:
+                 # Atomically create the lock file. Fails if it exists.
+                 self.REGISTRY_LOCK_PATH.touch(exist_ok=False)
+                 try:
+                     yield
+                 finally:
+                     with suppress(OSError):
+                         self.REGISTRY_LOCK_PATH.unlink()
+                 return
+             except FileExistsError:
+                 if time.time() - start_time > timeout:
+                     raise TimeoutError("Could not acquire registry lock")
+                 time.sleep(0.05)
+
+     def _read_registry(self) -> Optional[Dict[str, Any]]:
+         """Reads and parses the registry file JSON."""
+         if not self.REGISTRY_PATH.exists():
+             return None
+         try:
+             with self.REGISTRY_PATH.open("r") as f:
+                 return json.load(f)
+         except (json.JSONDecodeError, OSError):
+             return None
+
+     def _write_registry(self, data: Dict[str, Any]) -> None:
+         """
+         Writes data to the registry file atomically via a temporary file.
+         Updates the timestamp of the last write.
+         """
+         self._last_registry_write = time.time()
+         # Use a temporary file in the same directory to ensure atomic move support
+         tmp = self.REGISTRY_PATH.with_suffix(".tmp")
+         with tmp.open("w") as f:
+             json.dump(data, f)
+         tmp.replace(self.REGISTRY_PATH)
+
+     def _init_dask_client(self, **kwargs) -> None:
+         self._init_params = kwargs
+         if not HAS_DISTRIBUTED:
+             self.logger.info(
+                 "Dask Distributed not installed. Skipping cluster initialization."
+             )
+             return
+
+         if kwargs.get("dask_client"):
+             self.dask_client = kwargs["dask_client"]
+             return
+         if kwargs.get("reuse_context_client", True):
+             with suppress(ValueError):
+                 # Reuse an active client from the current context if one
+                 # exists; get_client raises ValueError otherwise.
+                 self.dask_client = get_client()
+                 return
+         self._connect_or_create()
+         if kwargs.get("watchdog", True):
+             self._start_watchdog()
+
+     def _connect_or_create(self) -> None:
+         """
+         Connects to an existing Dask cluster or creates a new one.
+         Handles locking, registry validation, and fallback to ephemeral clusters.
+         """
+         p = self._init_params
+         try:
+             # 1. Explicit argument
+             explicit_address = p.get("scheduler_address")
+             if explicit_address:
+                 self.logger.info(
+                     f"Connecting to EXTERNAL Dask Scheduler (Explicit): {explicit_address}"
+                 )
+                 try:
+                     candidate_client = Client(
+                         address=explicit_address,
+                         timeout=p.get("timeout", 5),
+                         set_as_default=p.get("set_as_default", True),
+                     )
+                     if self._verify_connectivity(candidate_client):
+                         self.dask_client = candidate_client
+                         self.own_dask_client = False
+                         return
+                     else:
+                         candidate_client.close()
+                 except Exception as e:
+                     self.logger.error(f"Failed to connect to specified scheduler: {e}")
+
+             # 2. Environment variable
+             env_address = os.environ.get("DASK_SCHEDULER_ADDRESS")
+             if env_address:
+                 self.logger.info(
+                     f"Connecting to EXTERNAL Dask Scheduler (Env): {env_address}"
+                 )
+                 try:
+                     candidate_client = Client(
+                         address=env_address,
+                         timeout=p.get("timeout", 5),
+                         set_as_default=p.get("set_as_default", True),
+                     )
+
+                     # Optional: deploy code if requested
+                     if p.get("deploy_mode", False):
+                         self._deploy_local_code(candidate_client)
+
+                     if self._verify_connectivity(candidate_client):
+                         self.dask_client = candidate_client
+                         self.own_dask_client = False
+                         return
+                     else:
+                         self.logger.warning(
+                             "External cluster unusable. Falling back to LOCAL."
+                         )
+                         candidate_client.close()
+                 except Exception as e:
+                     self.logger.error(
+                         f"Failed to connect to env-specified scheduler: {e}"
+                     )
+
+             # 3. Shared local cluster via the file-based registry
+             try:
+                 lock_context = self._registry_lock()
+             except TimeoutError:
+                 self.logger.warning(
+                     "Lock acquisition timed out; falling back to ephemeral cluster."
+                 )
+                 self._bootstrap_new_cluster(p)
+                 return
+
+             try:
+                 with lock_context:
+                     self._cleanup_stale_registry()
+                     reg = self._read_registry()
+
+                     if reg and not reg.get("closing"):
+                         try:
+                             self.dask_client = Client(
+                                 address=reg["address"],
+                                 timeout=p.get("timeout", 30),
+                                 set_as_default=p.get("set_as_default", True),
+                             )
+                             reg["refcount"] = int(reg.get("refcount", 0)) + 1
+                             self._write_registry(reg)
+                             self.own_dask_client = False
+                             self.logger.info(
+                                 f"Connected to SHARED LOCAL Dask cluster at {reg['address']}\n"
+                                 f"Dashboard: {reg.get('dashboard_link', 'Unknown')}"
+                             )
+                             return
+                         except Exception:
+                             self.logger.warning("Failed to connect to shared cluster.")
+
+                     self._bootstrap_new_cluster(p)
+             except TimeoutError:
+                 # The generator-based lock raises on entry, not on creation.
+                 self.logger.warning("Lock context entry timed out; falling back.")
+                 self._bootstrap_new_cluster(p)
+
+         except Exception as e:
+             self.logger.error("Dask initialization failed", extra={"error": str(e)})
+
+     def _bootstrap_new_cluster(self, p: Dict[str, Any]) -> None:
+         """
+         Creates a new LocalCluster instance.
+
+         Configures the cluster based on input parameters or defaults.
+         Logs:
+             Silences 'distributed.shuffle' logs to ERROR level to reduce noise.
+         """
+         # Implements the behaviour documented above: quiet noisy shuffle logs.
+         logging.getLogger("distributed.shuffle").setLevel(logging.ERROR)
+
+         self.logger.info("Starting NEW LOCAL Dask cluster...")
+         cluster = LocalCluster(
+             n_workers=p.get("n_workers", DEFAULT_WORKERS),
+             threads_per_worker=p.get("threads_per_worker", DEFAULT_THREADS),
+             processes=p.get("processes", False),
+             dashboard_address=p.get("dashboard_address", ":0"),
+         )
+         self.dask_client = Client(cluster, set_as_default=p.get("set_as_default", True))
+         self.own_dask_client = True
+         self.logger.info(f"Local Cluster Dashboard: {cluster.dashboard_link}")
+         self._write_registry(
+             {
+                 "address": cluster.scheduler_address,
+                 "dashboard_link": cluster.dashboard_link,
+                 "refcount": 1,
+                 "closing": False,
+                 "pid": os.getpid(),
+                 "created_at": time.time(),
+             }
+         )
+
+     def _close_dask_client(self) -> None:
+         """
+         Closes the Dask client and cleans up resources.
+         Decrements refcount in registry and shuts down cluster if count hits zero.
+         """
+         if self._fs_observer:
+             self._fs_observer.stop()
+             self._fs_observer.join(timeout=2.0)
+             self._fs_observer = None
+
+         if self._watchdog_task:
+             self._watchdog_stop.set()
+
+         if not self.dask_client:
+             return
+
+         if not self.own_dask_client:
+             self.dask_client = None
+             return
+
+         try:
+             with self._registry_lock():
+                 if self._has_inflight():
+                     return
+
+                 reg = self._read_registry()
+                 if reg:
+                     reg["refcount"] = max(0, int(reg.get("refcount", 1)) - 1)
+                     if reg["refcount"] > 0:
+                         self._write_registry(reg)
+                         self.dask_client.close()
+                         return
+                     reg["closing"] = True
+                     self._write_registry(reg)
+                 self._retire_all_workers()
+                 self._remove_registry()
+
+                 cluster = getattr(self.dask_client, "cluster", None)
+                 self.dask_client.close()
+                 if cluster:
+                     cluster.close()
+                 self.dask_client = None
+         except Exception as e:
+             self.logger.error("Shutdown error", extra={"error": str(e)})
+
+     def _start_watchdog(self) -> None:
+         if not HAS_WATCHDOG:
+             return
+
+         try:
+             loop = asyncio.get_running_loop()
+         except RuntimeError:
+             return
+
+         def on_fs_event():
+             # Debounce events triggered by our own recent registry writes.
+             if time.time() - self._last_registry_write < 1.5:
+                 return
+             loop.call_soon_threadsafe(
+                 lambda: asyncio.create_task(self._heal_if_disconnected())
+             )
+
+         handler = RegistryEventHandler(on_fs_event)
+         self._fs_observer = Observer()
+         self._fs_observer.schedule(
+             handler, str(self.REGISTRY_PATH.parent), recursive=False
+         )
+         self._fs_observer.start()
+
+         async def liveness_checker():
+             while not self._watchdog_stop.is_set():
+                 await asyncio.sleep(self.WATCHDOG_INTERVAL)
+                 await self._heal_if_disconnected()
+
+         self._watchdog_task = loop.create_task(liveness_checker())
+
+     async def _heal_if_disconnected(self) -> None:
+         now = time.time()
+         if now - self._last_heal_time < self.HEAL_COOLDOWN:
+             return
+         is_healthy = False
+         if self.dask_client:
+             try:
+                 # Handle both sync and async clients
+                 if getattr(self.dask_client, "asynchronous", False):
+                     await self.dask_client.scheduler_info()
+                 else:
+                     # Sync client: scheduler_info() returns dict immediately
+                     self.dask_client.scheduler_info()
+                 is_healthy = True
+             except Exception:
+                 pass
+         if not is_healthy:
+             self._last_heal_time = now
+             self.logger.warning("Dask heartbeat lost. Healing.")
+             with suppress(Exception):
+                 if self.dask_client:
+                     await self.dask_client.close(timeout=1)
+             self.dask_client = None
+             self._connect_or_create()
+
+
+ _persistent_mixin: Optional[DaskClientMixin] = None
+ _singleton_lock = threading.Lock()
+
+
+ def get_persistent_client(**kwargs) -> Optional[Client]:
+     """
+     Retrieves or creates a singleton Dask client instance.
+     This client persists until `force_close_persistent_client` is called.
+     Returns None when dask.distributed is not installed.
+     """
+     if not HAS_DISTRIBUTED:
+         return None
+     global _persistent_mixin
+     with _singleton_lock:
+         if _persistent_mixin is None:
+             _persistent_mixin = DaskClientMixin(logger=kwargs.get("logger"))
+
+         # If client is None, initialize it
+         if _persistent_mixin.dask_client is None:
+             _persistent_mixin._init_dask_client(**kwargs)
+         else:
+             # Check if we need to upgrade from Local to External+Deploy
+             current_local = _persistent_mixin.own_dask_client
+             want_deploy = kwargs.get("deploy_mode", False)
+
+             # Update params for future healings
+             _persistent_mixin._init_params.update(kwargs)
+
+             if want_deploy and current_local:
+                 _persistent_mixin.logger.info(
+                     "Re-initialising client to apply 'deploy_mode' configuration."
+                 )
+                 _persistent_mixin._close_dask_client()
+                 # _init_dask_client will use the updated params
+                 _persistent_mixin._init_dask_client(**kwargs)
+
+         return _persistent_mixin.dask_client  # type: ignore
+
+
+ def force_close_persistent_client():
+     global _persistent_mixin
+     with _singleton_lock:
+         if _persistent_mixin:
+             _persistent_mixin._close_dask_client()
+             _persistent_mixin = None
+
+
+ def shared_dask_session(*, async_mode: bool = True, **kwargs):
+     """
+     Builds a context manager for a shared Dask session: asynchronous when
+     `async_mode` is True, synchronous otherwise.
+     """
+     mixin = DaskClientMixin(logger=kwargs.get("logger"))
+     mixin._init_dask_client(**kwargs)
+     if async_mode:
+
+         @asynccontextmanager
+         async def _async_manager():
+             try:
+                 yield mixin.dask_client
+             finally:
+                 mixin._close_dask_client()
+
+         return _async_manager()
+     else:
+
+         @contextmanager
+         def _sync_manager():
+             try:
+                 yield mixin.dask_client
+             finally:
+                 mixin._close_dask_client()
+
+         return _sync_manager()
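
For orientation, here is a minimal sketch of how this module's public entry points might be driven, assuming dask.distributed is installed; the worker counts and the submitted lambda are illustrative values, not defaults taken from the package.

    # Hypothetical driver for sibi_flux.dask_cluster.client_manager (not shipped in the wheel).
    from sibi_flux.dask_cluster.client_manager import (
        force_close_persistent_client,
        get_persistent_client,
        shared_dask_session,
    )

    # Singleton path: the first call boots (or joins) a shared LocalCluster
    # recorded in <tmpdir>/shared_dask_cluster.json; later calls reuse it,
    # and the registry refcount tracks how many processes are attached.
    client = get_persistent_client(n_workers=2, threads_per_worker=2)
    print(client.submit(lambda: 40 + 2).result())  # -> 42
    force_close_persistent_client()  # decrements refcount; shuts down at zero

    # Scoped path: a context-managed session that detaches on exit.
    # async_mode=True (the default) returns an async context manager instead.
    with shared_dask_session(async_mode=False) as client:
        print(client.scheduler_info()["address"])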