redzed 25.12.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redzed/circuit.py ADDED
@@ -0,0 +1,756 @@
1
+ """
2
+ The circuit runner.
3
+ - - - - - -
4
+ Part of the redzed package.
5
+ # Docs: https://redzed.readthedocs.io/en/latest/
6
+ # Project home: https://github.com/xitop/redzed/
7
+ """
8
+ from __future__ import annotations
9
+
10
# public API of this module
__all__ = ['CircuitState', 'get_circuit', 'reset_circuit', 'run', 'unique_name']
11
+
12
+ import asyncio
13
+ from collections.abc import Coroutine, Iterable, MutableMapping, Sequence
14
+ import contextlib
15
+ import enum
16
+ import itertools
17
+ import logging
18
+ import signal
19
+ import time
20
+ import typing as t
21
+
22
+ from .block import Block, PersistenceFlags
23
+ from .cron_service import Cron
24
+ from .debug import get_debug_level
25
+ from .initializers import AsyncInitializer
26
+ from .formula_trigger import Formula, Trigger
27
+ from .signal_shutdown import TerminatingSignal
28
+ from .undef import UNDEF
29
+ from .utils import check_async_coro, check_identifier, tasks_are_eager, time_period
30
+
31
# logger named after the package (not after this module)
_logger = logging.getLogger(__package__)
# the circuit returned by get_circuit(); None until created or after reset_circuit()
_current_circuit: Circuit|None = None

# [attr-defined]: access to rz_xxx attrs is guarded by has_method() checks
# mypy: disable-error-code="attr-defined"
36
+
37
+
38
def get_circuit() -> Circuit:
    """
    Return the current circuit.

    A new empty Circuit is created on first use (or after reset_circuit())
    and is returned by all subsequent calls.
    """
    global _current_circuit # pylint: disable=global-statement

    if _current_circuit is not None:
        return _current_circuit
    _current_circuit = Circuit()
    return _current_circuit
45
+
46
+
47
def reset_circuit() -> None:
    """
    Discard the current circuit.

    The next get_circuit() call will create a new circuit.

    Raise RuntimeError if the current circuit is neither under construction
    nor closed, i.e. a circuit that is being run cannot be reset.
    """
    global _current_circuit # pylint: disable=global-statement
    circuit = _current_circuit
    if circuit is None:
        return
    if circuit.get_state() not in (CircuitState.UNDER_CONSTRUCTION, CircuitState.CLOSED):
        raise RuntimeError("Cannot reset running circuit")
    # pylint: disable=protected-access
    # drop the references held by the old circuit to break some reference cycles
    for container in (circuit._blocks, circuit._triggers, circuit._errors):
        container.clear()
    _current_circuit = None
59
+
60
+
61
def unique_name(prefix: str = 'auto') -> str:
    """
    Return *prefix* extended with a numeric suffix.

    The result is a name not used yet in the current circuit.
    """
    circuit = get_circuit()
    return circuit.rz_unique_name(prefix)
64
+
65
+
66
class CircuitState(enum.IntEnum):
    """
    Life-cycle stages of a circuit.

    The stages are ordered. During the circuit's life-cycle
    the integer value may only increase.
    """

    # being built, the runner is not started yet
    UNDER_CONSTRUCTION = 0
    # the runner initializes itself
    INIT_CIRCUIT = 1
    # runner is started, now initializing blocks and triggers
    INIT_BLOCKS = 2
    # the circuit is running
    RUNNING = 3
    # shutting down
    SHUTDOWN = 4
    # runner has exited
    CLOSED = 5
79
+
80
+
81
class _TerminateTaskGroup(Exception):
    """
    Internal exception raised to terminate a task group.

    It is not a real error; the runner filters it out
    before reporting exceptions to the caller.
    """
83
+
84
+
85
+ @contextlib.contextmanager
86
+ def error_debug(item: Block|Formula|Trigger, suppress_error: bool = False) -> t.Iterator[None]:
87
+ """Add a note to raised exception -or- log and suppress exception."""
88
+ try:
89
+ yield None
90
+ except Exception as err:
91
+ if not suppress_error:
92
+ err.add_note(f"This {type(err).__name__} occurred in {item}")
93
+ raise
94
+ # errors should be suppressed only during the shutdown & cleanup
95
+ _logger.error("[Circuit] %s: Suppressing %s: %s", item, type(err).__name__, err)
96
+
97
+
98
class Circuit:
    """
    A container of all blocks.

    In this implementation, circuit blocks can be added, but not removed.
    """

    def __init__(self) -> None:
        self._state = CircuitState.UNDER_CONSTRUCTION
        # event used by .reached_state(); created on demand, released by ._set_state()
        self._state_change: asyncio.Event|None = None
        # all Blocks and Formulas belonging to this circuit stored by name
        self._blocks: dict[str, Block|Formula] = {}
        # all triggers belonging to this circuit
        self._triggers: list[Trigger] = []
        # exceptions occurred in the runner (registered by .abort())
        self._errors: list[Exception] = []
        # persistent state data back-end
        self.rz_persistent_dict: MutableMapping[str, t.Any]|None = None
        # timestamp of persistent data (Unix clock)
        self.rz_persistent_ts: float|None = None
        # runner's start timestamp (monotonic clock)
        self._start_ts: float|None = None
        # service tasks to be cancelled during the shutdown (see .create_service())
        self._auto_cancel_tasks: set[asyncio.Task[t.Any]] = set()
        # interval for checkpointing in seconds (see .set_persistent_storage())
        self._sync_time = 250.0

    def log_debug1(self, msg: str, *args: t.Any, **kwargs: t.Any) -> None:
        """Log a message if debugging is enabled."""
        if get_debug_level() >= 1:
            _logger.debug("[Circuit] "+ msg, *args, **kwargs)

    def log_debug2(self, msg: str, *args: t.Any, **kwargs: t.Any) -> None:
        """Log a message if verbose debugging is enabled."""
        if get_debug_level() >= 2:
            _logger.debug("[Circuit] "+ msg, *args, **kwargs)

    def log_info(self, msg: str, *args: t.Any, **kwargs: t.Any) -> None:
        """Log a message with _INFO_ priority."""
        _logger.info("[Circuit] "+ msg, *args, **kwargs)

    def log_warning(self, msg: str, *args: t.Any, **kwargs: t.Any) -> None:
        """Log a message with _WARNING_ priority."""
        _logger.warning("[Circuit] "+ msg, *args, **kwargs)

    def log_error(self, msg: str, *args: t.Any, **kwargs: t.Any) -> None:
        """Log a message with _ERROR_ priority."""
        _logger.error("[Circuit] "+ msg, *args, **kwargs)

    def _log_debug2_blocks(
            self,
            msg: str,
            *args: Sequence[Block]|Sequence[Formula]|Sequence[Trigger]
            ) -> None:
        """
        Log a debug message with a block count and a block name list.

        Each argument is a homogeneous sequence of blocks, formulas or triggers.
        Empty sequences are skipped. Nothing is logged unless the debug level is >= 2.
        """
        if get_debug_level() < 2:
            return
        parts = []
        for ilist in args:
            if (cnt := len(ilist)) == 0:
                continue
            itype = type(ilist[0])
            plural = "" if cnt == 1 else "s"
            if issubclass(itype, Block):
                names = ', '.join(b.name for b in ilist) # type: ignore[union-attr]
                parts.append(f"{cnt} block{plural}: {names}")
            elif itype is Formula:
                # issubclass is not needed, because Formula and Trigger are final classes
                names = ', '.join(f.name for f in ilist) # type: ignore[union-attr]
                parts.append(f"{cnt} formula{plural}: {names}")
            elif itype is Trigger:
                parts.append(f"{cnt} trigger{plural}")
        if parts:
            _logger.debug("[Circuit] %s. Processing %s", msg, "; ".join(parts))

    # --- circuit components storage ---

    def rz_add_item(self, item: Block|Formula|Trigger) -> None:
        """
        Add a circuit item.

        Raise RuntimeError if the runner has started or the circuit is closed,
        TypeError for a non-component and ValueError for a duplicate name.
        """
        self._check_not_started()
        if isinstance(item, Trigger):
            self._triggers.append(item)
            return
        if not isinstance(item, (Block, Formula)):
            raise TypeError(
                f"Expected a circuit component (Block/Formula/Trigger), but got {item!r}")
        if item.name in self._blocks:
            raise ValueError(f"Duplicate name '{item.name}'")
        self._blocks[item.name] = item

    _Block = t.TypeVar("_Block", bound=Block)

    @t.overload
    def get_items(self, btype: type[Formula]) -> Iterable[Formula]: ...
    @t.overload
    def get_items(self, btype: type[Trigger]) -> Iterable[Trigger]: ...
    @t.overload
    def get_items(self, btype: type[_Block]) -> Iterable[_Block]: ...
    def get_items(
            self, btype: type[Formula|Trigger|Block]) -> Iterable[Formula|Trigger|Block]:
        """
        Return an iterable of circuit components of selected type *btype*.

        The returned iterable might be a generator.
        """
        # no issubclass(), because Trigger is a "final" class
        if btype is Trigger:
            return self._triggers
        if not isinstance(btype, type) or not issubclass(btype, (Block, Formula)):
            raise TypeError(f"Expected a circuit component type, got {btype!r}")
        return (item for item in self._blocks.values() if isinstance(item, btype))

    @t.overload
    def resolve_name(self, ref: Block) -> Block: ...
    @t.overload
    def resolve_name(self, ref: Formula) -> Formula: ...
    @t.overload
    def resolve_name(self, ref: str) -> Block|Formula: ...
    def resolve_name(self, ref: Block|Formula|str) -> Block|Formula:
        """
        Resolve a reference by name if necessary.

        If *ref* is a string, return circuit's block or formula with that name.
        Raise a KeyError when not found. Special case during initialization:
        create internal blocks on demand.

        Return *ref* unchanged if it is already a valid block or formula object.
        """
        if isinstance(ref, (Block, Formula)):
            return ref  # name already resolved
        if not isinstance(ref, str):
            raise TypeError(f"Expected a name (string), got {ref!r}")
        try:
            return self._blocks[ref]
        except KeyError:
            if self._state <= CircuitState.INIT_BLOCKS:
                if ref == '_cron_local':
                    return Cron(ref, utc=False, comment="Time scheduler (local time)")
                if ref == '_cron_utc':
                    return Cron(ref, utc=True, comment="Time scheduler (UTC time)")
            raise KeyError(f"No block or formula named '{ref}' found") from None

    def rz_unique_name(self, prefix: str) -> str:
        """Add a numeric suffix to make the name unique."""
        check_identifier(prefix, "Block/Formula name prefix")
        delim = '' if prefix.endswith('_') else '_'
        # start counting from the number of similar names to skip most collisions
        num = sum(1 for n in self._blocks if n.startswith(prefix))
        while True:
            if (name := f"{prefix}{delim}{num}") not in self._blocks:
                return name
            num += 1
        # not reached

    # --- State management ---

    def get_state(self) -> CircuitState:
        """Get the circuit's state."""
        return self._state

    def _set_state(self, newstate: CircuitState) -> None:
        """
        Set the circuit's state.

        Transitions to the same or a lower state are silently ignored.
        Tasks waiting in .reached_state() are woken up.
        """
        if newstate <= self._state:
            # This is not allowed, but in the same time, it's not an error. Usually
            # it happens when there was an abort(), but the runner wasn't notified yet.
            return
        self._state = newstate
        self.log_debug2("State: %s", newstate.name)

        if self._state_change is not None:
            self._state_change.set()
            self._state_change = None

    async def reached_state(self, state: CircuitState) -> bool:
        """
        Async synchronization tool.

        Wait until the DESIRED OR HIGHER state is reached.
        Return True if the current state equals *state*,
        False if a higher state was reached.
        """
        if not isinstance(state, CircuitState):
            raise TypeError(f"Expected CircuitState, got {state!r}")
        while self._state < state:
            if self._state_change is None:
                self._state_change = asyncio.Event()
            await self._state_change.wait()
        return self._state == state

    def _check_not_started(self) -> None:
        """Raise an error if the circuit runner has started already."""
        if self._state == CircuitState.CLOSED:
            # A circuit may be closed before start (see shutdown),
            # let's use this message instead of the one below.
            raise RuntimeError("The circuit was closed")
        # allow adding special blocks in the INIT_CIRCUIT state
        if self._state > CircuitState.INIT_CIRCUIT:
            raise RuntimeError("Not allowed after the start")

    def after_shutdown(self) -> bool:
        """Test if we are past the shutdown() call."""
        return self._state >= CircuitState.SHUTDOWN

    async def _checkpointing_service(self, blocks: Sequence[Block]) -> None:
        """Save the state of *blocks* every _sync_time seconds while RUNNING."""
        while self._state <= CircuitState.RUNNING:
            await asyncio.sleep(self._sync_time)
            now = time.time()
            if self._state is CircuitState.RUNNING:
                for blk in blocks:
                    self.save_persistent_state(blk, now)

    def set_persistent_storage(
            self,
            persistent_dict: MutableMapping[str, t.Any]|None,
            *,
            sync_time: float|str|None = None
            ) -> None:
        """
        Setup the persistent state data storage.

        *sync_time* (converted with time_period()) sets the checkpointing
        interval; the current interval is kept when it is None.
        Not allowed after the runner's start.
        """
        self._check_not_started()
        self.rz_persistent_dict = persistent_dict
        if sync_time is not None:
            self._sync_time = time_period(sync_time)

    def _check_persistent_storage(self) -> None:
        """
        Check persistent state related settings.

        Without a storage, disable persistence of all blocks. With a storage,
        remove entries not used by any block and start the checkpointing
        service if some block requests interval checkpoints.
        """
        storage = self.rz_persistent_dict
        ps_blocks = [blk for blk in self.get_items(Block) if blk.rz_persistence]
        if storage is None:
            if ps_blocks:
                self.log_warning("No data storage, disabling state persistence")
                for blk in ps_blocks:
                    blk.rz_persistence = PersistenceFlags(0)
            return
        # clear the unused items
        used_keys = {pblk.rz_key for pblk in ps_blocks}
        for key in list(storage.keys()):
            if key not in used_keys:
                self.log_debug2("Removing unused persistent state for '%s'", key)
                del storage[key]
        # start checkpointing if necessary
        ch_blocks = [
            blk for blk in self.get_items(Block)
            if blk.rz_persistence & PersistenceFlags.INTERVAL]
        if ch_blocks:
            self.create_service(
                self._checkpointing_service(ch_blocks), name="Checkpointing service")

    # --- init/shutdown helpers ---

    def _init_block_core(self, blk: Block) -> t.Iterator[AsyncInitializer]:
        """
        Initialize with available initializers.

        Run sync initializers immediately. Yield async initializers
        for further processing.

        Persistent state is handled elsewhere.
        """
        for init in blk.rz_initializers:
            if not blk.is_undef():
                # already initialized, the remaining initializers are not needed
                return
            if isinstance(init, AsyncInitializer):
                yield init
            else:
                init.apply_to(blk)
        if blk.is_undef() and blk.has_method('rz_init_default'):
            blk.log_debug2("Calling the built-in default initializer")
            blk.rz_init_default()

    def init_block_sync(self, blk: Block) -> None:
        """Initialize a Block excluding async initializers."""
        for _ in self._init_block_core(blk):
            pass

    async def init_block_async(self, blk: Block) -> None:
        """Initialize a Block including async initializers."""
        for initializer in self._init_block_core(blk):
            task = asyncio.create_task(
                initializer.async_apply_to(blk),
                name=f"initializer {type(initializer).__name__} for block '{blk.name}'")
            blk.rz_set_inittask(task)
            try:
                await task
            except asyncio.CancelledError:
                # A cancelled initializer is not fatal; re-raise only
                # if this task itself is being cancelled.
                # [union-attr]: asyncio.current_task() cannot return None here
                if asyncio.current_task().cancelling() > 0: # type: ignore[union-attr]
                    raise

    async def _init_blocks(self, blocks: Sequence[Block]) -> None:
        """
        Initialize multiple logic blocks.

        Run async initializations concurrently.
        Raise RuntimeError if a block remains uninitialized.
        """
        # Init from value provided by initializers (specified with initial=... or built-in)
        uninitialized = [blk for blk in blocks if blk.is_undef()]
        sync_blocks: list[Block] = []
        async_blocks: list[Block] = []
        for blk in uninitialized:
            async_init = any(isinstance(init, AsyncInitializer) for init in blk.rz_initializers)
            (async_blocks if async_init else sync_blocks).append(blk)
        if sync_blocks:
            self._log_debug2_blocks(
                "Initializing blocks having sync initializers only", sync_blocks)
            for blk in sync_blocks:
                self.init_block_sync(blk)
        if async_blocks:
            self._log_debug2_blocks(
                "Initializing blocks having some async initializers", async_blocks)
            async with asyncio.TaskGroup() as tg:
                for blk in async_blocks:
                    tg.create_task(self.init_block_async(blk))
        # final check
        for blk in blocks:
            if blk.is_undef():
                raise RuntimeError(f"Block '{blk.name}' was not initialized")

    def save_persistent_state(self, blk: Block, now: float|None = None) -> None:
        """
        Save persistent state.

        It is assumed the block has the persistent state feature
        enabled and the storage is ready.

        An undefined state is not saved. The state is stored together with
        the timestamp *now* (Unix clock, current time if not given).
        Errors are logged, not raised.
        """
        assert self.rz_persistent_dict is not None
        if blk.is_undef():
            blk.log_debug2("Not saving undefined state")
            return
        if now is None:
            now = time.time()
        try:
            if (state := blk.rz_export_state()) is UNDEF:
                blk.log_error("Exported state was <UNDEF>")
                return
            self.rz_persistent_dict[blk.rz_key] = [state, now]
        except Exception as err:
            blk.log_error("Saving state failed with %r", err)

    async def _shutdown_block_async(self, blk: Block) -> None:
        """Shutdown a Block; errors and timeouts are logged and suppressed."""
        with error_debug(blk, suppress_error=True):
            async with asyncio.timeout(blk.rz_stop_timeout):
                await blk.rz_astop()

    # --- runner ---

    def runtime(self) -> float:
        """
        Return seconds since runner's start.

        Return 0.0 if it hasn't started yet.
        """
        return 0.0 if self._start_ts is None else time.monotonic() - self._start_ts

    async def _runner_init(self) -> None:
        """
        Run the circuit during the initialization phase.

        Pre-initialize blocks, formulas and triggers, initialize blocks,
        start everything and save the initial checkpoints.
        """
        self._set_state(CircuitState.INIT_CIRCUIT)
        await asyncio.sleep(0)  # allow reached_state() synchronization
        if self.after_shutdown():
            # It looks like a supporting task has failed immediately after the start
            return

        pe_blocks = [blk for blk in self.get_items(Block) if blk.has_method('rz_pre_init')]
        pe_formulas = list(self.get_items(Formula))
        pe_triggers = list(self.get_items(Trigger))
        self._log_debug2_blocks("Pre-initializing", pe_blocks, pe_formulas, pe_triggers)
        for pe in itertools.chain(pe_blocks, pe_formulas, pe_triggers):
            assert isinstance(pe, (Block, Formula, Trigger))  # @mypy
            with error_debug(pe):
                # union-attr: checked with .has_method()
                pe.rz_pre_init()  # type: ignore[union-attr]

        self._set_state(CircuitState.INIT_BLOCKS)
        await asyncio.sleep(0)
        await self._init_blocks(list(self.get_items(Block)))

        start_blocks = [blk for blk in self.get_items(Block) if blk.has_method('rz_start')]
        start_formulas = list(self.get_items(Formula))
        start_triggers = list(self.get_items(Trigger))
        self._log_debug2_blocks("Starting", start_formulas, start_triggers, start_blocks)
        # starting blocks after formulas and triggers
        for start in itertools.chain(start_formulas, start_triggers, start_blocks):
            assert isinstance(start, (Block, Formula, Trigger))  # @mypy
            with error_debug(start):
                # union-attr: checked with .has_method()
                start.rz_start()  # type: ignore[union-attr]

        if self.rz_persistent_dict is not None:
            # initial checkpoints
            ch_blocks = [
                blk for blk in self.get_items(Block)
                if blk.rz_persistence & (PersistenceFlags.EVENT | PersistenceFlags.INTERVAL)]
            if ch_blocks:
                now = time.time()
                for blk in ch_blocks:
                    self.save_persistent_state(blk, now)

    async def _runner_shutdown(self) -> None:
        """
        Run the circuit during the shutdown.

        Save persistent state, stop triggers and blocks, cancel the service
        tasks, stop blocks asynchronously and close them. Errors raised
        by individual items are logged and suppressed.
        """
        if self.rz_persistent_dict is not None:
            # save the state first, because stop may invalidate the state information
            ps_blocks = [blk for blk in self.get_items(Block) if blk.rz_persistence]
            if ps_blocks:
                self._log_debug2_blocks("Saving persistent state", ps_blocks)
                now = time.time()
                for blk in ps_blocks:
                    self.save_persistent_state(blk, now)

        stop_triggers = list(self.get_items(Trigger))
        stop_blocks = [blk for blk in self.get_items(Block) if blk.has_method('rz_stop')]
        # triggers must be stopped even if no block has a sync stop method
        if stop_triggers or stop_blocks:
            self._log_debug2_blocks("Stopping (sync)", stop_triggers, stop_blocks)
            for stop in itertools.chain(stop_triggers, stop_blocks):
                assert isinstance(stop, (Block, Trigger))  # @mypy
                with error_debug(stop, suppress_error=True):
                    # union-attr: checked with .has_method()
                    stop.rz_stop()  # type: ignore[union-attr]

        if self._auto_cancel_tasks:
            self.log_debug2("Cancelling %d service task(s)", len(self._auto_cancel_tasks))
            for task in self._auto_cancel_tasks:
                if not task.done():
                    task.cancel()
            await asyncio.sleep(0)
            for task in self._auto_cancel_tasks:
                if not task.done():
                    self.log_warning("Canceled service task %s did not terminate", task)

        stop_blocks = [
            blk for blk in self.get_items(Block)
            if blk.has_method('rz_astop', async_method=True)]
        if stop_blocks:
            self._log_debug2_blocks("Stopping (async)", stop_blocks)
            async with asyncio.TaskGroup() as tg:
                for blk in stop_blocks:
                    tg.create_task(self._shutdown_block_async(blk))

        close_blocks = [blk for blk in self.get_items(Block) if blk.has_method('rz_close')]
        if close_blocks:
            self._log_debug2_blocks("Closing", close_blocks)
            for close in close_blocks:
                with error_debug(close, suppress_error=True):
                    close.rz_close()

    async def _runner_core(self) -> t.NoReturn:
        """
        Run the circuit until shutdown/abort, then clean up.

        _runner_core() never exits normally without an exception.
        It must be cancelled to switch from running to shutting down.

        Please note that the cleanup could take some time depending
        on the outputs' .rz_stop_timeout values.

        When the runner terminates, it cannot be invoked again.
        """
        if self._state == CircuitState.CLOSED:
            raise RuntimeError("Cannot restart a closed circuit.")
        if self._state != CircuitState.UNDER_CONSTRUCTION:
            raise RuntimeError("The circuit is already running.")
        if not self._blocks:
            raise RuntimeError("The circuit is empty")
        if tasks_are_eager():
            self.log_debug2("Eager asyncio tasks detected")
        self._check_persistent_storage()
        self._start_ts = time.monotonic()
        try:
            try:
                await self._runner_init()
            except Exception as err:
                self.abort(err)
            else:
                # There might be errors reported with abort().
                # In such case the state has been set to SHUTDOWN.
                # _set_state(RUNNING) will be silently ignored.
                self._set_state(CircuitState.RUNNING)
            # wait until cancelled from the task group; possible causes:
            #   1. shutdown() or abort()
            #   2. failed supporting task (this includes unexpected termination)
            #   3. cancellation of the task group itself
            await asyncio.Future()
        except asyncio.CancelledError:
            # will be re-raised at the end if there won't be other exceptions
            pass
        # cancellation causes 2 and 3 do not modify the state
        if not self.after_shutdown():
            self._set_state(CircuitState.SHUTDOWN)
        await asyncio.sleep(0)  # let other tasks notice the SHUTDOWN state
        try:
            await self._runner_shutdown()
        except Exception as err:
            # If an exception is propagated from _runner_shutdown, it is probably a bug.
            # Calling abort is not necessary when shutting down, but the call will log
            # and register the exception to be included in the final ExceptionGroup.
            self.abort(err)
        self._set_state(CircuitState.CLOSED)

        if self._errors:
            raise ExceptionGroup("_runner_core() errors", self._errors)
        raise asyncio.CancelledError()

    # --- abort/shutdown ---

    def abort(self, err: Exception) -> None:
        """
        Abort the circuit runner due to an error.

        abort() is necessary only when an exception isn't propagated
        to the runner. The error is registered (once) for the final
        ExceptionGroup and logged; before the shutdown it also starts the shutdown.
        """
        if not isinstance(err, Exception):
            # one more reason to abort
            err = TypeError(f'abort(): expected an exception, got {err!r}')
        if err in self._errors:
            # the same error may be reported from several places
            return
        self._errors.append(err)
        if self.after_shutdown():
            self.log_error("Unhandled error during shutdown: %r", err)
        else:
            self.log_warning("Aborting due to an exception: %r", err)
            self.shutdown()

    def shutdown(self) -> None:
        """
        Stop the runner if it was started.

        Prevent the runner from starting if it wasn't started yet.
        """
        if self.after_shutdown():
            return
        if self._state == CircuitState.UNDER_CONSTRUCTION:
            self._set_state(CircuitState.CLOSED)
            return
        self._set_state(CircuitState.SHUTDOWN)
        # The shutdown monitor will be awakened and exits with an error. The task
        # group will detect it and cancel the runner and its supporting tasks.

    async def watchdog(
            self,
            coro: Coroutine[t.Any, t.Any, t.Any],
            immediate_start: bool,
            name: str|None
            ) -> None:
        """
        Detect *coro* termination before shutdown and treat it as an error. Add logging.

        Unless *immediate_start* is set, *coro* is started only after the circuit
        reaches the RUNNING state. *name* (if given) is applied to the current task.

        This is a low-level function for create_service.
        """
        this_task = asyncio.current_task()
        assert this_task is not None  # @mypy
        if name is None:
            name = this_task.get_name()
        elif this_task.get_name() != name:
            this_task.set_name(name)
        longname = f"Task '{name}' running '{coro.__name__}'"
        if not immediate_start:
            self.log_debug2("%s waiting for RUNNING state", longname)
            if not await self.reached_state(CircuitState.RUNNING):
                self.log_debug1("%s not started", longname)
                coro.close()  # won't be awaited; prevent a warning about that
                # Failed start! The return value does not matter now.
                # No abort() here, because this is not an error, it's a consequence.
                return
        self.log_debug1("%s started", longname)
        try:
            await coro  # return value of a service is ignored
        except asyncio.CancelledError:
            if self.after_shutdown():
                self.log_debug1("%s was cancelled", longname)
                raise
            err = RuntimeError(f"{longname} was cancelled before shutdown")
            self.abort(err)
            raise err from None
        except Exception as err:
            err.add_note(f"Error occurred in {longname}")
            self.abort(err)
            raise
        if self.after_shutdown():
            self.log_debug1("%s terminated", longname)
            return
        exc = RuntimeError(f"{longname} terminated before shutdown")
        self.abort(exc)
        raise exc

    def create_service(
            self, coro: Coroutine[t.Any, t.Any, t.Any],
            immediate_start: bool = False,
            auto_cancel: bool = True,
            **task_kwargs: t.Any
            ) -> asyncio.Task[None]:
        """
        Create a service task for the circuit.

        The coroutine is wrapped with .watchdog(). With *auto_cancel* set,
        the task is cancelled during the shutdown. *task_kwargs* are passed
        to asyncio.create_task(). Raise RuntimeError after the shutdown.
        """
        if self.after_shutdown():
            raise RuntimeError("Cannot create a service after shutdown")
        check_async_coro(coro)
        # Python 3.12 and 3.13 only: Eager tasks start to run before their name is set.
        # As a workaround we tell the watchdog wrapper the name.
        task = asyncio.create_task(
            self.watchdog(coro, immediate_start, task_kwargs.get('name')), **task_kwargs)
        if auto_cancel:
            self._auto_cancel_tasks.add(task)
        # mark exceptions as consumed, because they were reported with abort
        task.add_done_callback(lambda tsk: None if tsk.cancelled() else tsk.exception())
        return task

    async def _shutdown_monitor(self) -> t.NoReturn:
        """
        Helper task: exit with an error when shutdown starts.

        The failure of the shutdown monitor task cancels the task group.
        (if it wasn't cancelling already). The _TerminateTaskGroup
        error will be filtered out.
        """
        await self.reached_state(CircuitState.SHUTDOWN)
        raise _TerminateTaskGroup()

    async def rz_runner(self) -> None:
        """Run the circuit."""
        async with asyncio.TaskGroup() as tg:
            tg.create_task(self._shutdown_monitor(), name="Shutdown monitor")
            tg.create_task(self._runner_core(), name="Circuit runner core")
        # Not reached in practice: both tasks always end with an exception.
        # Raising here guarantees that a caller's task group gets terminated.
        raise _TerminateTaskGroup()
712
+
713
+
714
# type variable for the exception types held in an exception group (see leaf_exceptions)
_ET = t.TypeVar("_ET", bound=BaseException)
715
+
716
+ def leaf_exceptions(group: BaseExceptionGroup[_ET]) -> list[_ET]:
717
+ """
718
+ Return a flat list of all 'leaf' exceptions.
719
+
720
+ Not using techniques from PEP-785, but leaving tracebacks unmodified.
721
+ """
722
+ result = []
723
+ for exc in group.exceptions:
724
+ if isinstance(exc, BaseExceptionGroup):
725
+ result.extend(leaf_exceptions(exc))
726
+ else:
727
+ result.append(exc)
728
+ return result
729
+
730
+
731
async def run(*coroutines: Coroutine[t.Any, t.Any, t.Any], catch_sigterm: bool = True) -> None:
    """
    The main entry point. Run the circuit together with supporting coroutines.

    Each supporting coroutine runs in its own task guarded by Circuit.watchdog().
    With *catch_sigterm*, a TerminatingSignal context for SIGTERM is active
    while running.

    If errors occur, raise an exception group with all exceptions.
    """
    sigterm_ctx = TerminatingSignal(signal.SIGTERM) if catch_sigterm else contextlib.nullcontext()
    with sigterm_ctx:
        circuit = get_circuit()
        total = len(coroutines)
        try:
            async with asyncio.TaskGroup() as tg:
                tg.create_task(circuit.rz_runner(), name="Circuit runner")
                for idx, coro in enumerate(coroutines, start=1):
                    task_name = (
                        "Supporting task" if total <= 1 else f"Supporting task {idx}/{total}")
                    tg.create_task(
                        circuit.watchdog(coro, immediate_start=True, name=task_name),
                        name=task_name)
        except ExceptionGroup as eg:
            # drop the internal termination marker and duplicates
            reported: list[Exception] = []
            for exc in leaf_exceptions(eg):
                if isinstance(exc, _TerminateTaskGroup) or exc in reported:
                    continue
                reported.append(exc)
            if reported:
                raise ExceptionGroup("Circuit runner exceptions", reported) from None
        circuit.log_debug2("Terminated normally")