lsst-pipe-base 30.0.0rc2__py3-none-any.whl → 30.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. lsst/pipe/base/_instrument.py +31 -20
  2. lsst/pipe/base/_quantumContext.py +3 -3
  3. lsst/pipe/base/_status.py +43 -10
  4. lsst/pipe/base/_task_metadata.py +2 -2
  5. lsst/pipe/base/all_dimensions_quantum_graph_builder.py +8 -3
  6. lsst/pipe/base/automatic_connection_constants.py +20 -1
  7. lsst/pipe/base/cli/cmd/__init__.py +18 -2
  8. lsst/pipe/base/cli/cmd/commands.py +149 -4
  9. lsst/pipe/base/connectionTypes.py +72 -160
  10. lsst/pipe/base/connections.py +6 -9
  11. lsst/pipe/base/execution_reports.py +0 -5
  12. lsst/pipe/base/graph/graph.py +11 -10
  13. lsst/pipe/base/graph/quantumNode.py +4 -4
  14. lsst/pipe/base/graph_walker.py +8 -10
  15. lsst/pipe/base/log_capture.py +40 -80
  16. lsst/pipe/base/log_on_close.py +76 -0
  17. lsst/pipe/base/mp_graph_executor.py +51 -15
  18. lsst/pipe/base/pipeline.py +5 -6
  19. lsst/pipe/base/pipelineIR.py +2 -8
  20. lsst/pipe/base/pipelineTask.py +5 -7
  21. lsst/pipe/base/pipeline_graph/_dataset_types.py +2 -2
  22. lsst/pipe/base/pipeline_graph/_edges.py +32 -22
  23. lsst/pipe/base/pipeline_graph/_mapping_views.py +4 -7
  24. lsst/pipe/base/pipeline_graph/_pipeline_graph.py +14 -7
  25. lsst/pipe/base/pipeline_graph/expressions.py +2 -2
  26. lsst/pipe/base/pipeline_graph/io.py +7 -10
  27. lsst/pipe/base/pipeline_graph/visualization/_dot.py +13 -12
  28. lsst/pipe/base/pipeline_graph/visualization/_layout.py +16 -18
  29. lsst/pipe/base/pipeline_graph/visualization/_merge.py +4 -7
  30. lsst/pipe/base/pipeline_graph/visualization/_printer.py +10 -10
  31. lsst/pipe/base/pipeline_graph/visualization/_status_annotator.py +7 -0
  32. lsst/pipe/base/prerequisite_helpers.py +2 -1
  33. lsst/pipe/base/quantum_graph/_common.py +19 -20
  34. lsst/pipe/base/quantum_graph/_multiblock.py +37 -31
  35. lsst/pipe/base/quantum_graph/_predicted.py +113 -15
  36. lsst/pipe/base/quantum_graph/_provenance.py +1136 -45
  37. lsst/pipe/base/quantum_graph/aggregator/__init__.py +0 -1
  38. lsst/pipe/base/quantum_graph/aggregator/_communicators.py +204 -289
  39. lsst/pipe/base/quantum_graph/aggregator/_config.py +87 -9
  40. lsst/pipe/base/quantum_graph/aggregator/_ingester.py +13 -12
  41. lsst/pipe/base/quantum_graph/aggregator/_scanner.py +49 -235
  42. lsst/pipe/base/quantum_graph/aggregator/_structs.py +6 -116
  43. lsst/pipe/base/quantum_graph/aggregator/_supervisor.py +29 -39
  44. lsst/pipe/base/quantum_graph/aggregator/_workers.py +303 -0
  45. lsst/pipe/base/quantum_graph/aggregator/_writer.py +34 -351
  46. lsst/pipe/base/quantum_graph/formatter.py +171 -0
  47. lsst/pipe/base/quantum_graph/ingest_graph.py +413 -0
  48. lsst/pipe/base/quantum_graph/visualization.py +5 -1
  49. lsst/pipe/base/quantum_graph_builder.py +33 -9
  50. lsst/pipe/base/quantum_graph_executor.py +116 -13
  51. lsst/pipe/base/quantum_graph_skeleton.py +31 -35
  52. lsst/pipe/base/quantum_provenance_graph.py +29 -12
  53. lsst/pipe/base/separable_pipeline_executor.py +19 -3
  54. lsst/pipe/base/single_quantum_executor.py +67 -42
  55. lsst/pipe/base/struct.py +4 -0
  56. lsst/pipe/base/testUtils.py +3 -3
  57. lsst/pipe/base/tests/mocks/_storage_class.py +2 -1
  58. lsst/pipe/base/version.py +1 -1
  59. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/METADATA +3 -3
  60. lsst_pipe_base-30.0.1.dist-info/RECORD +129 -0
  61. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/WHEEL +1 -1
  62. lsst_pipe_base-30.0.0rc2.dist-info/RECORD +0 -125
  63. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/entry_points.txt +0 -0
  64. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/COPYRIGHT +0 -0
  65. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/LICENSE +0 -0
  66. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/bsd_license.txt +0 -0
  67. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/gpl-v3.0.txt +0 -0
  68. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/top_level.txt +0 -0
  69. {lsst_pipe_base-30.0.0rc2.dist-info → lsst_pipe_base-30.0.1.dist-info}/zip-safe +0 -0
lsst/pipe/base/quantum_graph/aggregator/_structs.py

@@ -27,68 +27,16 @@
 
 from __future__ import annotations
 
-__all__ = (
-    "InProgressScan",
-    "IngestRequest",
-    "ScanReport",
-    "ScanStatus",
-    "WriteRequest",
-)
+__all__ = ("IngestRequest", "ScanReport")
 
 import dataclasses
-import enum
 import uuid
 
+from lsst.daf.butler import DatasetRef
 from lsst.daf.butler.datastore.record_data import DatastoreRecordData
 
 from .._common import DatastoreName
-from .._predicted import PredictedDatasetModel
-from .._provenance import (
-    ProvenanceLogRecordsModel,
-    ProvenanceQuantumAttemptModel,
-    ProvenanceTaskMetadataModel,
-)
-
-
-class ScanStatus(enum.Enum):
-    """Status enum for quantum scanning.
-
-    Note that this records the status for the *scanning* which is distinct
-    from the status of the quantum's execution.
-    """
-
-    INCOMPLETE = enum.auto()
-    """The quantum is not necessarily done running, and cannot be scanned
-    conclusively yet.
-    """
-
-    ABANDONED = enum.auto()
-    """The quantum's execution appears to have failed but we cannot rule out
-    the possibility that it could be recovered, but we've also waited long
-    enough (according to `ScannerTimeConfigDict.retry_timeout`) that it's time
-    to stop trying for now.
-
-    This state means a later run with `ScannerConfig.assume_complete` is
-    required.
-    """
-
-    SUCCESSFUL = enum.auto()
-    """The quantum was conclusively scanned and was executed successfully,
-    unblocking scans for downstream quanta.
-    """
-
-    FAILED = enum.auto()
-    """The quantum was conclusively scanned and failed execution, blocking
-    scans for downstream quanta.
-    """
-
-    BLOCKED = enum.auto()
-    """A quantum upstream of this one failed."""
-
-    INIT = enum.auto()
-    """Init quanta need special handling, because they don't have logs and
-    metadata.
-    """
+from .._provenance import ProvenanceQuantumScanStatus
 
 
 @dataclasses.dataclass
@@ -98,7 +46,7 @@ class ScanReport:
     quantum_id: uuid.UUID
     """Unique ID of the quantum."""
 
-    status: ScanStatus
+    status: ProvenanceQuantumScanStatus
     """Combined status of the scan and the execution of the quantum."""
 
 
@@ -109,69 +57,11 @@ class IngestRequest:
     producer_id: uuid.UUID
     """ID of the quantum that produced these datasets."""
 
-    datasets: list[PredictedDatasetModel]
+    refs: list[DatasetRef]
     """Registry information about the datasets."""
 
     records: dict[DatastoreName, DatastoreRecordData]
     """Datastore information about the datasets."""
 
     def __bool__(self) -> bool:
-        return bool(self.datasets or self.records)
-
-
-@dataclasses.dataclass
-class InProgressScan:
-    """A struct that represents a quantum that is being scanned."""
-
-    quantum_id: uuid.UUID
-    """Unique ID for the quantum."""
-
-    status: ScanStatus
-    """Combined status for the scan and the execution of the quantum."""
-
-    attempts: list[ProvenanceQuantumAttemptModel] = dataclasses.field(default_factory=list)
-    """Provenance information about each attempt to run the quantum."""
-
-    outputs: dict[uuid.UUID, bool] = dataclasses.field(default_factory=dict)
-    """Unique IDs of the output datasets mapped to whether they were actually
-    produced.
-    """
-
-    metadata: ProvenanceTaskMetadataModel = dataclasses.field(default_factory=ProvenanceTaskMetadataModel)
-    """Task metadata information for each attempt."""
-
-    logs: ProvenanceLogRecordsModel = dataclasses.field(default_factory=ProvenanceLogRecordsModel)
-    """Log records for each attempt."""
-
-
-@dataclasses.dataclass
-class WriteRequest:
-    """A struct that represents a request to write provenance for a quantum."""
-
-    quantum_id: uuid.UUID
-    """Unique ID for the quantum."""
-
-    status: ScanStatus
-    """Combined status for the scan and the execution of the quantum."""
-
-    existing_outputs: set[uuid.UUID] = dataclasses.field(default_factory=set)
-    """Unique IDs of the output datasets that were actually written."""
-
-    quantum: bytes = b""
-    """Serialized quantum provenance model.
-
-    This may be empty for quanta that had no attempts.
-    """
-
-    metadata: bytes = b""
-    """Serialized task metadata."""
-
-    logs: bytes = b""
-    """Serialized logs."""
-
-    is_compressed: bool = False
-    """Whether the `quantum`, `metadata`, and `log` attributes are
-    compressed.
-    """
+        return bool(self.refs or self.records)
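
Net effect of the _structs.py changes: IngestRequest now carries resolved DatasetRef objects instead of PredictedDatasetModel entries, and the scanning status enum moves into .._provenance as ProvenanceQuantumScanStatus. A minimal sketch of the slimmed-down struct in use; the import path follows the module layout in the file list above, and the empty placeholder values are illustrative rather than taken from the diff:

    import uuid

    from lsst.pipe.base.quantum_graph.aggregator._structs import IngestRequest

    # Illustrative values: in the aggregator, refs and records come from the Scanner.
    request = IngestRequest(
        producer_id=uuid.uuid4(),
        refs=[],     # list[DatasetRef] resolved during the scan
        records={},  # dict[DatastoreName, DatastoreRecordData]
    )
    assert not request  # __bool__ is False when both refs and records are empty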
lsst/pipe/base/quantum_graph/aggregator/_supervisor.py

@@ -42,19 +42,18 @@ from lsst.utils.usage import get_peak_mem_usage
 from ...graph_walker import GraphWalker
 from ...pipeline_graph import TaskImportMode
 from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader
+from .._provenance import ProvenanceQuantumScanData, ProvenanceQuantumScanStatus
 from ._communicators import (
     IngesterCommunicator,
     ScannerCommunicator,
-    SpawnProcessContext,
     SupervisorCommunicator,
-    ThreadingContext,
-    Worker,
     WriterCommunicator,
 )
 from ._config import AggregatorConfig
 from ._ingester import Ingester
 from ._scanner import Scanner
-from ._structs import ScanReport, ScanStatus, WriteRequest
+from ._structs import ScanReport
+from ._workers import SpawnWorkerFactory, ThreadWorkerFactory
 from ._writer import Writer
 
 
@@ -116,6 +115,17 @@ class Supervisor:
                 self.comms.request_scan(ready_set.pop())
             for scan_return in self.comms.poll():
                 self.handle_report(scan_return)
+        if self.comms.config.incomplete:
+            quantum_or_quanta = "quanta" if self.n_abandoned != 1 else "quantum"
+            self.comms.progress.log.info(
+                "%d %s incomplete/failed abandoned; re-run with incomplete=False to finish.",
+                self.n_abandoned,
+                quantum_or_quanta,
+            )
+        self.comms.progress.log.info(
+            "Scanning complete after %0.1fs; waiting for workers to finish.",
+            self.comms.progress.elapsed_time,
+        )
 
     def handle_report(self, scan_report: ScanReport) -> None:
         """Handle a report from a scanner.
@@ -126,18 +136,22 @@
             Information about the scan.
         """
         match scan_report.status:
-            case ScanStatus.SUCCESSFUL | ScanStatus.INIT:
+            case ProvenanceQuantumScanStatus.SUCCESSFUL | ProvenanceQuantumScanStatus.INIT:
                 self.comms.log.debug("Scan complete for %s: quantum succeeded.", scan_report.quantum_id)
                 self.walker.finish(scan_report.quantum_id)
-            case ScanStatus.FAILED:
+            case ProvenanceQuantumScanStatus.FAILED:
                 self.comms.log.debug("Scan complete for %s: quantum failed.", scan_report.quantum_id)
                 blocked_quanta = self.walker.fail(scan_report.quantum_id)
                 for blocked_quantum_id in blocked_quanta:
-                    if self.comms.config.output_path is not None:
-                        self.comms.request_write(WriteRequest(blocked_quantum_id, status=ScanStatus.BLOCKED))
+                    if self.comms.config.is_writing_provenance:
+                        self.comms.request_write(
+                            ProvenanceQuantumScanData(
+                                blocked_quantum_id, status=ProvenanceQuantumScanStatus.BLOCKED
+                            )
+                        )
                 self.comms.progress.scans.update(1)
                 self.comms.progress.quantum_ingests.update(len(blocked_quanta))
-            case ScanStatus.ABANDONED:
+            case ProvenanceQuantumScanStatus.ABANDONED:
                 self.comms.log.debug("Abandoning scan for %s: quantum has not succeeded (yet).")
                 self.walker.fail(scan_report.quantum_id)
                 self.n_abandoned += 1
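
The match in handle_report dispatches on the ProvenanceQuantumScanStatus members referenced above (SUCCESSFUL, INIT, FAILED, ABANDONED, BLOCKED). A standalone sketch of the same dispatch pattern, with a stand-in enum so it runs without the package; the summaries paraphrase the removed ScanStatus docstrings:

    import enum

    class Status(enum.Enum):  # stand-in for ProvenanceQuantumScanStatus
        SUCCESSFUL = enum.auto()
        INIT = enum.auto()
        FAILED = enum.auto()
        ABANDONED = enum.auto()
        BLOCKED = enum.auto()

    def describe(status: Status) -> str:
        match status:
            case Status.SUCCESSFUL | Status.INIT:
                return "scan conclusive; downstream scans unblocked"
            case Status.FAILED:
                return "scan conclusive; downstream scans blocked"
            case Status.ABANDONED:
                return "scan abandoned; may be recoverable on a later run"
            case _:
                return "not scanned; an upstream quantum failed"

    print(describe(Status.FAILED))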
@@ -161,55 +175,31 @@ def aggregate_graph(predicted_path: str, butler_path: str, config: AggregatorCon
         Configuration for the aggregator.
     """
     log = getLogger("lsst.pipe.base.quantum_graph.aggregator")
-    ctx = ThreadingContext() if config.n_processes == 1 else SpawnProcessContext()
-    scanners: list[Worker] = []
-    ingester: Worker
-    writer: Worker | None = None
-    with SupervisorCommunicator(log, config.n_processes, ctx, config) as comms:
+    worker_factory = ThreadWorkerFactory() if config.n_processes == 1 else SpawnWorkerFactory()
+    with SupervisorCommunicator(log, config.n_processes, worker_factory, config) as comms:
         comms.progress.log.verbose("Starting workers.")
-        if config.output_path is not None:
+        if config.is_writing_provenance:
             writer_comms = WriterCommunicator(comms)
-            writer = ctx.make_worker(
+            comms.workers[writer_comms.name] = worker_factory.make_worker(
                 target=Writer.run,
                 args=(predicted_path, writer_comms),
                 name=writer_comms.name,
             )
-            writer.start()
         for scanner_id in range(config.n_processes):
            scanner_comms = ScannerCommunicator(comms, scanner_id)
-            worker = ctx.make_worker(
+            comms.workers[scanner_comms.name] = worker_factory.make_worker(
                 target=Scanner.run,
                 args=(predicted_path, butler_path, scanner_comms),
                 name=scanner_comms.name,
             )
-            worker.start()
-            scanners.append(worker)
         ingester_comms = IngesterCommunicator(comms)
-        ingester = ctx.make_worker(
+        comms.workers[ingester_comms.name] = worker_factory.make_worker(
             target=Ingester.run,
             args=(predicted_path, butler_path, ingester_comms),
             name=ingester_comms.name,
         )
-        ingester.start()
         supervisor = Supervisor(predicted_path, comms)
         supervisor.loop()
-        log.info(
-            "Scanning complete after %0.1fs; waiting for workers to finish.",
-            comms.progress.elapsed_time,
-        )
-        comms.wait_for_workers_to_finish()
-        if supervisor.n_abandoned:
-            raise RuntimeError(
-                f"{supervisor.n_abandoned} {'quanta' if supervisor.n_abandoned > 1 else 'quantum'} "
-                "abandoned because they did not succeed. Re-run with assume_complete=True after all retry "
-                "attempts have been exhausted."
-            )
-        for w in scanners:
-            w.join()
-        ingester.join()
-        if writer is not None and writer.is_alive():
-            log.info("Waiting for writer process to close (garbage collecting can be very slow).")
-            writer.join()
     # We can't get memory usage for children until they've joined.
     parent_mem, child_mem = get_peak_mem_usage()
     # This is actually an upper bound on the peak (since the peaks could be
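
The aggregate_graph changes replace per-kind bookkeeping (ctx, scanners, ingester, writer) with a WorkerFactory and a name-keyed worker registry on the communicator, leaving joining and the abandoned-quanta accounting to Supervisor.loop and the communicator's shutdown path. A condensed sketch of the factory-selection pattern; start_workers and the standalone dict are illustrative, since in the real code the registry lives on SupervisorCommunicator:

    from lsst.pipe.base.quantum_graph.aggregator._workers import (
        SpawnWorkerFactory,
        ThreadWorkerFactory,
        Worker,
    )

    def start_workers(n_processes: int) -> dict[str, Worker]:
        # One process: run workers as in-process threads; otherwise spawn
        # subprocesses, mirroring the n_processes == 1 branch in the diff.
        factory = ThreadWorkerFactory() if n_processes == 1 else SpawnWorkerFactory()
        workers: dict[str, Worker] = {}
        for scanner_id in range(n_processes):
            # Hypothetical name; the real names come from the communicators.
            name = f"scanner-{scanner_id}"
            workers[name] = factory.make_worker(target=print, args=(name,), name=name)
        return workers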
lsst/pipe/base/quantum_graph/aggregator/_workers.py (new file)

@@ -0,0 +1,303 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("Event", "Queue", "SpawnWorkerFactory", "ThreadWorkerFactory", "Worker", "WorkerFactory")
+
+import multiprocessing.context
+import multiprocessing.synchronize
+import queue
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import Any, Literal, overload
+
+_TINY_TIMEOUT = 0.01
+
+type Event = threading.Event | multiprocessing.synchronize.Event
+
+
+class Worker(ABC):
+    """A thin abstraction over `threading.Thread` and `multiprocessing.Process`
+    that also provides a variable to track whether it reported successful
+    completion.
+    """
+
+    def __init__(self) -> None:
+        self.successful = False
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Name of the worker, as assigned at creation."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def join(self, timeout: float | None = None) -> None:
+        """Wait for the worker to finish.
+
+        Parameters
+        ----------
+        timeout : `float`, optional
+            How long to wait in seconds. If the timeout is exceeded,
+            `is_alive` can be used to see whether the worker finished or not.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def is_alive(self) -> bool:
+        """Return whether the worker is still running."""
+        raise NotImplementedError()
+
+    def kill(self) -> None:
+        """Kill the worker, if possible."""
+
+
+class Queue[T](ABC):
+    """A thin abstraction over `queue.Queue` and `multiprocessing.Queue` that
+    provides better control over disorderly shutdowns.
+    """
+
+    @overload
+    def get(self, *, block: Literal[True]) -> T: ...
+
+    @overload
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ...
+
+    @abstractmethod
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
+        """Get an object or return `None` if the queue is empty.
+
+        Parameters
+        ----------
+        timeout : `float` or `None`, optional
+            Maximum number of seconds to wait while blocking.
+        block : `bool`, optional
+            Whether to block until an object is available.
+
+        Returns
+        -------
+        obj : `object` or `None`
+            Object from the queue, or `None` if it was empty. Note that this
+            is different from the behavior of the built-in Python queues,
+            which raise `queue.Empty` instead.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def put(self, item: T) -> None:
+        """Add an object to the queue.
+
+        Parameters
+        ----------
+        item : `object`
+            Item to add.
+        """
+        raise NotImplementedError()
+
+    def clear(self) -> bool:
+        """Clear out all objects currently on the queue.
+
+        This does not guarantee that more objects will not be added later.
+        """
+        found_anything: bool = False
+        while self.get() is not None:
+            found_anything = True
+        return found_anything
+
+    def kill(self) -> None:
+        """Prepare a queue for a disorderly shutdown, without assuming that
+        any other workers using it are still alive and functioning.
+        """
+
+
+class WorkerFactory(ABC):
+    """A simple abstract interface that can be implemented by both threading
+    and multiprocessing.
+    """
+
+    @abstractmethod
+    def make_queue(self) -> Queue[Any]:
+        """Make an empty queue that can be used to pass objects between
+        workers created by this factory.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def make_event(self) -> Event:
+        """Make an event that can be used to communicate a boolean state change
+        to workers created by this factory.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def make_worker(
+        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
+    ) -> Worker:
+        """Make a worker that runs the given callable.
+
+        Parameters
+        ----------
+        target : `~collections.abc.Callable`
+            A callable to invoke on the worker.
+        args : `tuple`
+            Positional arguments to pass to the callable.
+        name : `str`, optional
+            Human-readable name for the worker.
+
+        Returns
+        -------
+        worker : `Worker`
+            Process or thread that is already running the given callable.
+        """
+        raise NotImplementedError()
+
+
+class _ThreadWorker(Worker):
+    """An implementation of `Worker` backed by the `threading` module."""
+
+    def __init__(self, thread: threading.Thread):
+        super().__init__()
+        self._thread = thread
+
+    @property
+    def name(self) -> str:
+        return self._thread.name
+
+    def join(self, timeout: float | None = None) -> None:
+        self._thread.join(timeout=timeout)
+
+    def is_alive(self) -> bool:
+        return self._thread.is_alive()
+
+
+class _ThreadQueue[T](Queue[T]):
+    def __init__(self) -> None:
+        self._impl = queue.Queue[T]()
+
+    @overload
+    def get(self, *, block: Literal[True]) -> T: ...
+
+    @overload
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ...
+
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
+        try:
+            return self._impl.get(block=block, timeout=timeout)
+        except queue.Empty:
+            return None
+
+    def put(self, item: T) -> None:
+        self._impl.put(item, block=False)
+
+
+class ThreadWorkerFactory(WorkerFactory):
+    """An implementation of `WorkerFactory` backed by the `threading`
+    module.
+    """
+
+    def make_queue(self) -> Queue[Any]:
+        return _ThreadQueue()
+
+    def make_event(self) -> Event:
+        return threading.Event()
+
+    def make_worker(
+        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
+    ) -> Worker:
+        thread = threading.Thread(target=target, args=args, name=name)
+        thread.start()
+        return _ThreadWorker(thread)
+
+
+class _ProcessWorker(Worker):
+    """An implementation of `Worker` backed by the `multiprocessing` module."""
+
+    def __init__(self, process: multiprocessing.context.SpawnProcess):
+        super().__init__()
+        self._process = process
+
+    @property
+    def name(self) -> str:
+        return self._process.name
+
+    def join(self, timeout: float | None = None) -> None:
+        self._process.join(timeout=timeout)
+
+    def is_alive(self) -> bool:
+        return self._process.is_alive()
+
+    def kill(self) -> None:
+        """Kill the worker, if possible."""
+        self._process.kill()
+
+
+class _ProcessQueue[T](Queue[T]):
+    def __init__(self, impl: multiprocessing.Queue):
+        self._impl = impl
+
+    @overload
+    def get(self, *, block: Literal[True]) -> T: ...
+
+    @overload
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ...
+
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
+        try:
+            return self._impl.get(block=block, timeout=timeout)
+        except queue.Empty:
+            return None
+
+    def put(self, item: T) -> None:
+        self._impl.put(item, block=False)
+
+    def kill(self) -> None:
+        self._impl.cancel_join_thread()
+        self._impl.close()
+
+
+class SpawnWorkerFactory(WorkerFactory):
+    """An implementation of `WorkerFactory` backed by the `multiprocessing`
+    module, with new processes started by spawning.
+    """
+
+    def __init__(self) -> None:
+        self._ctx = multiprocessing.get_context("spawn")
+
+    def make_queue(self) -> Queue[Any]:
+        return _ProcessQueue(self._ctx.Queue())
+
+    def make_event(self) -> Event:
+        return self._ctx.Event()
+
+    def make_worker(
+        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
+    ) -> Worker:
+        process = self._ctx.Process(target=target, args=args, name=name)
+        process.start()
+        return _ProcessWorker(process)
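
Because both factories expose the same queue, event, and worker constructors, code written against WorkerFactory runs unchanged on threads or spawned processes. A minimal usage sketch with the threading backend; the echo target is illustrative:

    from lsst.pipe.base.quantum_graph.aggregator._workers import ThreadWorkerFactory

    def echo(q, done):  # hypothetical worker body
        q.put("ready")
        done.wait()

    factory = ThreadWorkerFactory()
    q = factory.make_queue()
    done = factory.make_event()
    worker = factory.make_worker(target=echo, args=(q, done), name="echo")
    print(q.get(block=True))  # blocks until the worker puts "ready"
    done.set()  # let the worker exit
    worker.join()
    assert not worker.is_alive()

With SpawnWorkerFactory the target must instead be defined at module scope so the spawn start method can pickle it.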