lsst-pipe-base 30.0.0rc3__py3-none-any.whl → 30.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. lsst/pipe/base/_instrument.py +25 -15
  2. lsst/pipe/base/_quantumContext.py +3 -3
  3. lsst/pipe/base/_status.py +43 -10
  4. lsst/pipe/base/_task_metadata.py +2 -2
  5. lsst/pipe/base/all_dimensions_quantum_graph_builder.py +8 -3
  6. lsst/pipe/base/automatic_connection_constants.py +20 -1
  7. lsst/pipe/base/cli/cmd/__init__.py +18 -2
  8. lsst/pipe/base/cli/cmd/commands.py +149 -4
  9. lsst/pipe/base/connectionTypes.py +72 -160
  10. lsst/pipe/base/connections.py +6 -9
  11. lsst/pipe/base/execution_reports.py +0 -5
  12. lsst/pipe/base/graph/graph.py +11 -10
  13. lsst/pipe/base/graph/quantumNode.py +4 -4
  14. lsst/pipe/base/graph_walker.py +8 -10
  15. lsst/pipe/base/log_capture.py +1 -1
  16. lsst/pipe/base/log_on_close.py +4 -7
  17. lsst/pipe/base/pipeline.py +5 -6
  18. lsst/pipe/base/pipelineIR.py +2 -8
  19. lsst/pipe/base/pipelineTask.py +5 -7
  20. lsst/pipe/base/pipeline_graph/_dataset_types.py +2 -2
  21. lsst/pipe/base/pipeline_graph/_edges.py +32 -22
  22. lsst/pipe/base/pipeline_graph/_mapping_views.py +4 -7
  23. lsst/pipe/base/pipeline_graph/_pipeline_graph.py +14 -7
  24. lsst/pipe/base/pipeline_graph/expressions.py +2 -2
  25. lsst/pipe/base/pipeline_graph/io.py +7 -10
  26. lsst/pipe/base/pipeline_graph/visualization/_dot.py +13 -12
  27. lsst/pipe/base/pipeline_graph/visualization/_layout.py +16 -18
  28. lsst/pipe/base/pipeline_graph/visualization/_merge.py +4 -7
  29. lsst/pipe/base/pipeline_graph/visualization/_printer.py +10 -10
  30. lsst/pipe/base/pipeline_graph/visualization/_status_annotator.py +7 -0
  31. lsst/pipe/base/prerequisite_helpers.py +2 -1
  32. lsst/pipe/base/quantum_graph/_common.py +15 -17
  33. lsst/pipe/base/quantum_graph/_multiblock.py +36 -20
  34. lsst/pipe/base/quantum_graph/_predicted.py +7 -3
  35. lsst/pipe/base/quantum_graph/_provenance.py +501 -61
  36. lsst/pipe/base/quantum_graph/aggregator/__init__.py +0 -1
  37. lsst/pipe/base/quantum_graph/aggregator/_communicators.py +187 -240
  38. lsst/pipe/base/quantum_graph/aggregator/_config.py +87 -9
  39. lsst/pipe/base/quantum_graph/aggregator/_ingester.py +13 -12
  40. lsst/pipe/base/quantum_graph/aggregator/_scanner.py +15 -7
  41. lsst/pipe/base/quantum_graph/aggregator/_structs.py +3 -3
  42. lsst/pipe/base/quantum_graph/aggregator/_supervisor.py +19 -34
  43. lsst/pipe/base/quantum_graph/aggregator/_workers.py +303 -0
  44. lsst/pipe/base/quantum_graph/aggregator/_writer.py +3 -3
  45. lsst/pipe/base/quantum_graph/formatter.py +74 -4
  46. lsst/pipe/base/quantum_graph/ingest_graph.py +413 -0
  47. lsst/pipe/base/quantum_graph/visualization.py +5 -1
  48. lsst/pipe/base/quantum_graph_builder.py +21 -8
  49. lsst/pipe/base/quantum_graph_skeleton.py +31 -29
  50. lsst/pipe/base/quantum_provenance_graph.py +29 -12
  51. lsst/pipe/base/separable_pipeline_executor.py +1 -1
  52. lsst/pipe/base/single_quantum_executor.py +15 -8
  53. lsst/pipe/base/struct.py +4 -0
  54. lsst/pipe/base/testUtils.py +3 -3
  55. lsst/pipe/base/tests/mocks/_storage_class.py +2 -1
  56. lsst/pipe/base/version.py +1 -1
  57. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/METADATA +3 -3
  58. lsst_pipe_base-30.0.1.dist-info/RECORD +129 -0
  59. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/WHEEL +1 -1
  60. lsst_pipe_base-30.0.0rc3.dist-info/RECORD +0 -127
  61. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/entry_points.txt +0 -0
  62. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/COPYRIGHT +0 -0
  63. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/LICENSE +0 -0
  64. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/bsd_license.txt +0 -0
  65. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/licenses/gpl-v3.0.txt +0 -0
  66. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/top_level.txt +0 -0
  67. {lsst_pipe_base-30.0.0rc3.dist-info → lsst_pipe_base-30.0.1.dist-info}/zip-safe +0 -0
@@ -29,6 +29,8 @@ from __future__ import annotations
 
 __all__ = ("AggregatorConfig",)
 
+import sys
+from typing import TYPE_CHECKING, Any
 
 import pydantic
 
@@ -60,11 +62,13 @@ class AggregatorConfig(pydantic.BaseModel):
     n_processes: int = 1
     """Number of processes the scanner should use."""
 
-    assume_complete: bool = True
-    """If `True`, the aggregator can assume all quanta have run to completion
-    (including any automatic retries). If `False`, only successes can be
-    considered final, and quanta that appear to have failed or to have not been
-    executed are ignored.
+    incomplete: bool = False
+    """If `True`, do not expect the graph to have been executed to completion
+    yet, and only ingest the outputs of successful quanta.
+
+    This disables writing the provenance quantum graph, since this is likely to
+    be wasted effort that just complicates a follow-up run with
+    ``incomplete=False`` later.
     """
 
     defensive_ingest: bool = False
@@ -95,11 +99,10 @@ class AggregatorConfig(pydantic.BaseModel):
     """
 
     dry_run: bool = False
-    """If `True`, do not actually perform any deletions or central butler
-    ingests.
+    """If `True`, do not actually perform any central butler ingests.
 
-    Most log messages concerning deletions and ingests will still be emitted in
-    order to provide a better emulation of a real run.
+    Most log messages concerning ingests will still be emitted in order to
+    provide a better emulation of a real run.
     """
 
     interactive_status: bool = False
@@ -137,3 +140,78 @@ class AggregatorConfig(pydantic.BaseModel):
     """Enable support for storage classes by created by the
     lsst.pipe.base.tests.mocks package.
     """
+
+    promise_ingest_graph: bool = False
+    """If `True`, the aggregator will assume that `~.ingest_graph.ingest_graph`
+    will be run later to ingest metadata/log/config datasets, and will not
+    ingest them itself. This means that if `~.ingest_graph.ingest_graph` is
+    not run, those files will be abandoned in the butler storage root without
+    being present in the butler database, but it will speed up both processes.
+
+    It is *usually* safe to build a quantum graph for downstream processing
+    before or while running `~.ingest_graph.ingest_graph`, because
+    metadata/log/config datasets are rarely used as inputs. To check, use
+    ``pipetask build ... --show inputs`` to show the overall-inputs to the
+    graph and scan for these dataset types.
+    """
+
+    worker_check_timeout: float = 5.0
+    """Time to wait (s) for reports from subprocesses before running
+    process-alive checks.
+
+    These checks are designed to kill the main aggregator process when a
+    subprocess has been unexpectedly killed (e.g. for for using too much
+    memory).
+    """
+
+    @property
+    def is_writing_provenance(self) -> bool:
+        """Whether the aggregator is configured to write the provenance quantum
+        graph.
+        """
+        return self.output_path is not None and not self.incomplete
+
+    # Work around the fact that Sphinx chokes on Pydantic docstring formatting,
+    # when we inherit those docstrings in our public classes.
+    if "sphinx" in sys.modules and not TYPE_CHECKING:
+
+        def copy(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.copy`."""
+            return super().copy(*args, **kwargs)
+
+        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_dump`."""
+            return super().model_dump(*args, **kwargs)
+
+        def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_dump_json`."""
+            return super().model_dump(*args, **kwargs)
+
+        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_copy`."""
+            return super().model_copy(*args, **kwargs)
+
+        @classmethod
+        def model_construct(cls, *args: Any, **kwargs: Any) -> Any:  # type: ignore[misc, override]
+            """See `pydantic.BaseModel.model_construct`."""
+            return super().model_construct(*args, **kwargs)
+
+        @classmethod
+        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_json_schema`."""
+            return super().model_json_schema(*args, **kwargs)
+
+        @classmethod
+        def model_validate(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_validate`."""
+            return super().model_validate(*args, **kwargs)
+
+        @classmethod
+        def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_validate_json`."""
+            return super().model_validate_json(*args, **kwargs)
+
+        @classmethod
+        def model_validate_strings(cls, *args: Any, **kwargs: Any) -> Any:
+            """See `pydantic.BaseModel.model_validate_strings`."""
+            return super().model_validate_strings(*args, **kwargs)
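
For illustration, here is a minimal sketch of how the new configuration options above might be combined. The import path for AggregatorConfig and the existence of other fields such as output_path are assumed from context rather than shown in this hunk, and any required fields omitted here would also need to be supplied:

    from lsst.pipe.base.quantum_graph.aggregator import AggregatorConfig  # assumed import path

    # First pass while some quanta may still be retrying: ingest only the
    # outputs of successful quanta and skip the provenance quantum graph.
    config = AggregatorConfig(n_processes=8, incomplete=True, promise_ingest_graph=True)
    assert not config.is_writing_provenance  # always False while incomplete=True

    # Follow-up pass once execution has finished, via the standard pydantic copy.
    final_config = config.model_copy(update={"incomplete": False})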
@@ -43,7 +43,7 @@ from lsst.daf.butler.registry import ConflictingDefinitionError
 
 from ...pipeline_graph import TaskImportMode
 from .._common import DatastoreName
-from .._predicted import PredictedDatasetModel, PredictedQuantumGraphComponents, PredictedQuantumGraphReader
+from .._predicted import PredictedQuantumGraphComponents, PredictedQuantumGraphReader
 from ._communicators import IngesterCommunicator
 
 
@@ -140,7 +140,7 @@ class Ingester(AbstractContextManager):
         Notes
         -----
         This method is designed to run as the ``target`` in
-        `WorkerContext.make_worker`.
+        `WorkerFactory.make_worker`.
         """
         with comms, Ingester(predicted_path, butler_path, comms) as ingester:
             ingester.loop()
@@ -170,7 +170,7 @@ class Ingester(AbstractContextManager):
         for ingest_request in self.comms.poll():
             self.n_producers_pending += 1
             self.comms.log.debug(f"Got ingest request for producer {ingest_request.producer_id}.")
-            self.update_pending(ingest_request.datasets, ingest_request.records)
+            self.update_outputs_pending(refs=ingest_request.refs, records=ingest_request.records)
             if self.n_datasets_pending > self.comms.config.ingest_batch_size:
                 self.ingest()
         self.comms.log.info("All ingest requests received.")
@@ -266,31 +266,32 @@ class Ingester(AbstractContextManager):
             else:
                 del self.records_pending[datastore_name]
 
-    def update_pending(
-        self, datasets: list[PredictedDatasetModel], records: dict[DatastoreName, DatastoreRecordData]
+    def update_outputs_pending(
+        self,
+        refs: list[DatasetRef],
+        records: dict[DatastoreName, DatastoreRecordData],
     ) -> None:
         """Add an ingest request to the pending-ingest data structures.
 
         Parameters
         ----------
-        datasets : `list` [ `PredictedDatasetModel` ]
-            Registry information about the datasets.
+        refs : `list` [ `lsst.daf.butler.DatasetRef` ]
+            Registry information about regular quantum-output datasets.
         records : `dict` [ `str`, \
             `lsst.daf.butler.datastore.record_data.DatastoreRecordData` ]
             Datastore information about the datasets.
         """
-        n_given = len(datasets)
+        n_given = len(refs)
         if self.already_ingested is not None:
-            datasets = [d for d in datasets if d.dataset_id not in self.already_ingested]
-            kept = {d.dataset_id for d in datasets}
+            refs = [ref for ref in refs if ref.id not in self.already_ingested]
+            kept = {ref.id for ref in refs}
             self.n_datasets_skipped += n_given - len(kept)
             records = {
                 datastore_name: filtered_records
                 for datastore_name, original_records in records.items()
                 if (filtered_records := original_records.subset(kept)) is not None
             }
-        for dataset in datasets:
-            ref = self.predicted.make_dataset_ref(dataset)
+        for ref in refs:
             self.refs_pending[ref.datasetType.dimensions].append(ref)
         for datastore_name, datastore_records in records.items():
             if (existing_records := self.records_pending.get(datastore_name)) is not None:
@@ -161,7 +161,7 @@ class Scanner(AbstractContextManager):
         Notes
         -----
         This method is designed to run as the ``target`` in
-        `WorkerContext.make_worker`.
+        `WorkerFactory.make_worker`.
         """
         with comms, Scanner(predicted_path, butler_path, comms) as scanner:
             scanner.loop()
@@ -223,7 +223,7 @@ class Scanner(AbstractContextManager):
         logs = self._read_log(predicted_quantum)
         metadata = self._read_metadata(predicted_quantum)
         result = ProvenanceQuantumScanModels.from_metadata_and_logs(
-            predicted_quantum, metadata, logs, assume_complete=self.comms.config.assume_complete
+            predicted_quantum, metadata, logs, incomplete=self.comms.config.incomplete
         )
         if result.status is ProvenanceQuantumScanStatus.ABANDONED:
             self.comms.log.debug("Abandoning scan for failed quantum %s.", quantum_id)
@@ -233,7 +233,7 @@ class Scanner(AbstractContextManager):
             if predicted_output.dataset_id not in result.output_existence:
                 result.output_existence[predicted_output.dataset_id] = self.scan_dataset(predicted_output)
         to_ingest = self._make_ingest_request(predicted_quantum, result)
-        if self.comms.config.output_path is not None:
+        if self.comms.config.is_writing_provenance:
            to_write = result.to_scan_data(predicted_quantum, compressor=self.compressor)
            self.comms.request_write(to_write)
        self.comms.request_ingest(to_ingest)
@@ -261,15 +261,23 @@ class Scanner(AbstractContextManager):
         predicted_outputs_by_id = {
             d.dataset_id: d for d in itertools.chain.from_iterable(predicted_quantum.outputs.values())
         }
-        to_ingest_predicted: list[PredictedDatasetModel] = []
         to_ingest_refs: list[DatasetRef] = []
+        to_ignore: set[uuid.UUID] = set()
+        if self.comms.config.promise_ingest_graph:
+            if result.status is ProvenanceQuantumScanStatus.INIT:
+                if predicted_quantum.task_label:  # i.e. not the 'packages' producer
+                    to_ignore.add(
+                        predicted_quantum.outputs[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME][0].dataset_id
+                    )
+            else:
+                to_ignore.add(predicted_quantum.outputs[acc.METADATA_OUTPUT_CONNECTION_NAME][0].dataset_id)
+                to_ignore.add(predicted_quantum.outputs[acc.LOG_OUTPUT_CONNECTION_NAME][0].dataset_id)
         for dataset_id, was_produced in result.output_existence.items():
-            if was_produced:
+            if was_produced and dataset_id not in to_ignore:
                 predicted_output = predicted_outputs_by_id[dataset_id]
-                to_ingest_predicted.append(predicted_output)
                 to_ingest_refs.append(self.reader.components.make_dataset_ref(predicted_output))
         to_ingest_records = self.qbb._datastore.export_predicted_records(to_ingest_refs)
-        return IngestRequest(result.quantum_id, to_ingest_predicted, to_ingest_records)
+        return IngestRequest(result.quantum_id, to_ingest_refs, to_ingest_records)
 
     def _read_metadata(self, predicted_quantum: PredictedQuantumDatasetsModel) -> TaskMetadata | None:
         """Attempt to read the metadata dataset for a quantum.
@@ -32,10 +32,10 @@ __all__ = ("IngestRequest", "ScanReport")
 import dataclasses
 import uuid
 
+from lsst.daf.butler import DatasetRef
 from lsst.daf.butler.datastore.record_data import DatastoreRecordData
 
 from .._common import DatastoreName
-from .._predicted import PredictedDatasetModel
 from .._provenance import ProvenanceQuantumScanStatus
 
 
@@ -57,11 +57,11 @@ class IngestRequest:
     producer_id: uuid.UUID
     """ID of the quantum that produced these datasets."""
 
-    datasets: list[PredictedDatasetModel]
+    refs: list[DatasetRef]
     """Registry information about the datasets."""
 
     records: dict[DatastoreName, DatastoreRecordData]
     """Datastore information about the datasets."""
 
     def __bool__(self) -> bool:
-        return bool(self.datasets or self.records)
+        return bool(self.refs or self.records)
@@ -46,16 +46,14 @@ from .._provenance import ProvenanceQuantumScanData, ProvenanceQuantumScanStatus
 from ._communicators import (
     IngesterCommunicator,
     ScannerCommunicator,
-    SpawnProcessContext,
     SupervisorCommunicator,
-    ThreadingContext,
-    Worker,
     WriterCommunicator,
 )
 from ._config import AggregatorConfig
 from ._ingester import Ingester
 from ._scanner import Scanner
 from ._structs import ScanReport
+from ._workers import SpawnWorkerFactory, ThreadWorkerFactory
 from ._writer import Writer
 
 
@@ -117,6 +115,17 @@ class Supervisor:
                 self.comms.request_scan(ready_set.pop())
             for scan_return in self.comms.poll():
                 self.handle_report(scan_return)
+        if self.comms.config.incomplete:
+            quantum_or_quanta = "quanta" if self.n_abandoned != 1 else "quantum"
+            self.comms.progress.log.info(
+                "%d %s incomplete/failed abandoned; re-run with incomplete=False to finish.",
+                self.n_abandoned,
+                quantum_or_quanta,
+            )
+        self.comms.progress.log.info(
+            "Scanning complete after %0.1fs; waiting for workers to finish.",
+            self.comms.progress.elapsed_time,
+        )
 
     def handle_report(self, scan_report: ScanReport) -> None:
         """Handle a report from a scanner.
@@ -134,7 +143,7 @@ class Supervisor:
             self.comms.log.debug("Scan complete for %s: quantum failed.", scan_report.quantum_id)
             blocked_quanta = self.walker.fail(scan_report.quantum_id)
             for blocked_quantum_id in blocked_quanta:
-                if self.comms.config.output_path is not None:
+                if self.comms.config.is_writing_provenance:
                     self.comms.request_write(
                         ProvenanceQuantumScanData(
                             blocked_quantum_id, status=ProvenanceQuantumScanStatus.BLOCKED
@@ -166,55 +175,31 @@ def aggregate_graph(predicted_path: str, butler_path: str, config: AggregatorCon
         Configuration for the aggregator.
     """
     log = getLogger("lsst.pipe.base.quantum_graph.aggregator")
-    ctx = ThreadingContext() if config.n_processes == 1 else SpawnProcessContext()
-    scanners: list[Worker] = []
-    ingester: Worker
-    writer: Worker | None = None
-    with SupervisorCommunicator(log, config.n_processes, ctx, config) as comms:
+    worker_factory = ThreadWorkerFactory() if config.n_processes == 1 else SpawnWorkerFactory()
+    with SupervisorCommunicator(log, config.n_processes, worker_factory, config) as comms:
         comms.progress.log.verbose("Starting workers.")
-        if config.output_path is not None:
+        if config.is_writing_provenance:
            writer_comms = WriterCommunicator(comms)
-            writer = ctx.make_worker(
+            comms.workers[writer_comms.name] = worker_factory.make_worker(
                target=Writer.run,
                args=(predicted_path, writer_comms),
                name=writer_comms.name,
            )
-            writer.start()
        for scanner_id in range(config.n_processes):
            scanner_comms = ScannerCommunicator(comms, scanner_id)
-            worker = ctx.make_worker(
+            comms.workers[scanner_comms.name] = worker_factory.make_worker(
                target=Scanner.run,
                args=(predicted_path, butler_path, scanner_comms),
                name=scanner_comms.name,
            )
-            worker.start()
-            scanners.append(worker)
        ingester_comms = IngesterCommunicator(comms)
-        ingester = ctx.make_worker(
+        comms.workers[ingester_comms.name] = worker_factory.make_worker(
            target=Ingester.run,
            args=(predicted_path, butler_path, ingester_comms),
            name=ingester_comms.name,
        )
-        ingester.start()
        supervisor = Supervisor(predicted_path, comms)
        supervisor.loop()
-        log.info(
-            "Scanning complete after %0.1fs; waiting for workers to finish.",
-            comms.progress.elapsed_time,
-        )
-        comms.wait_for_workers_to_finish()
-        if supervisor.n_abandoned:
-            raise RuntimeError(
-                f"{supervisor.n_abandoned} {'quanta' if supervisor.n_abandoned > 1 else 'quantum'} "
-                "abandoned because they did not succeed. Re-run with assume_complete=True after all retry "
-                "attempts have been exhausted."
-            )
-        for w in scanners:
-            w.join()
-        ingester.join()
-        if writer is not None and writer.is_alive():
-            log.info("Waiting for writer process to close (garbage collecting can be very slow).")
-            writer.join()
    # We can't get memory usage for children until they've joined.
    parent_mem, child_mem = get_peak_mem_usage()
    # This is actually an upper bound on the peak (since the peaks could be
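
A rough usage sketch of the entry point above; the paths are placeholders, and importing aggregate_graph from the aggregator subpackage (rather than its private supervisor module) is an assumption:

    from lsst.pipe.base.quantum_graph.aggregator import AggregatorConfig, aggregate_graph

    # n_processes == 1 selects ThreadWorkerFactory; anything larger selects SpawnWorkerFactory.
    config = AggregatorConfig(n_processes=4)
    aggregate_graph("predicted_quanta.qgraph", "/path/to/butler/repo", config)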
@@ -0,0 +1,303 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("Event", "Queue", "SpawnWorkerFactory", "ThreadWorkerFactory", "Worker", "WorkerFactory")
+
+import multiprocessing.context
+import multiprocessing.synchronize
+import queue
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import Any, Literal, overload
+
+_TINY_TIMEOUT = 0.01
+
+type Event = threading.Event | multiprocessing.synchronize.Event
+
+
+class Worker(ABC):
+    """A thin abstraction over `threading.Thread` and `multiprocessing.Process`
+    that also provides a variable to track whether it reported successful
+    completion.
+    """
+
+    def __init__(self) -> None:
+        self.successful = False
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Name of the worker, as assigned at creation."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def join(self, timeout: float | None = None) -> None:
+        """Wait for the worker to finish.
+
+        Parameters
+        ----------
+        timeout : `float`, optional
+            How long to wait in seconds. If the timeout is exceeded,
+            `is_alive` can be used to see whether the worker finished or not.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def is_alive(self) -> bool:
+        """Return whether the worker is still running."""
+        raise NotImplementedError()
+
+    def kill(self) -> None:
+        """Kill the worker, if possible."""
+
+
+class Queue[T](ABC):
+    """A thin abstraction over `queue.Queue` and `multiprocessing.Queue` that
+    provides better control over disorderly shutdowns.
+    """
+
+    @overload
+    def get(self, *, block: Literal[True]) -> T: ...
+
+    @overload
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ...
+
+    @abstractmethod
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
+        """Get an object or return `None` if the queue is empty.
+
+        Parameters
+        ----------
+        timeout : `float` or `None`, optional
+            Maximum number of seconds to wait while blocking.
+        block : `bool`, optional
+            Whether to block until an object is available.
+
+        Returns
+        -------
+        obj : `object` or `None`
+            Object from the queue, or `None` if it was empty. Note that this
+            is different from the behavior of the built-in Python queues,
+            which raise `queue.Empty` instead.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def put(self, item: T) -> None:
+        """Add an object to the queue.
+
+        Parameters
+        ----------
+        item : `object`
+            Item to add.
+        """
+        raise NotImplementedError()
+
+    def clear(self) -> bool:
+        """Clear out all objects currently on the queue.
+
+        This does not guarantee that more objects will not be added later.
+        """
+        found_anything: bool = False
+        while self.get() is not None:
+            found_anything = True
+        return found_anything
+
+    def kill(self) -> None:
+        """Prepare a queue for a disorderly shutdown, without assuming that
+        any other workers using it are still alive and functioning.
+        """
+
+
+class WorkerFactory(ABC):
+    """A simple abstract interface that can be implemented by both threading
+    and multiprocessing.
+    """
+
+    @abstractmethod
+    def make_queue(self) -> Queue[Any]:
+        """Make an empty queue that can be used to pass objects between
+        workers created by this factory.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def make_event(self) -> Event:
+        """Make an event that can be used to communicate a boolean state change
+        to workers created by this factory.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def make_worker(
+        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
+    ) -> Worker:
+        """Make a worker that runs the given callable.
+
+        Parameters
+        ----------
+        target : `~collections.abc.Callable`
+            A callable to invoke on the worker.
+        args : `tuple`
+            Positional arguments to pass to the callable.
+        name : `str`, optional
+            Human-readable name for the worker.
+
+        Returns
+        -------
+        worker : `Worker`
+            Process or thread that is already running the given callable.
+        """
+        raise NotImplementedError()
+
+
+class _ThreadWorker(Worker):
+    """An implementation of `Worker` backed by the `threading` module."""
+
+    def __init__(self, thread: threading.Thread):
+        super().__init__()
+        self._thread = thread
+
+    @property
+    def name(self) -> str:
+        return self._thread.name
+
+    def join(self, timeout: float | None = None) -> None:
+        self._thread.join(timeout=timeout)
+
+    def is_alive(self) -> bool:
+        return self._thread.is_alive()
+
+
+class _ThreadQueue[T](Queue[T]):
+    def __init__(self) -> None:
+        self._impl = queue.Queue[T]()
+
+    @overload
+    def get(self, *, block: Literal[True]) -> T: ...
+
+    @overload
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ...
+
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
+        try:
+            return self._impl.get(block=block, timeout=timeout)
+        except queue.Empty:
+            return None
+
+    def put(self, item: T) -> None:
+        self._impl.put(item, block=False)
+
+
+class ThreadWorkerFactory(WorkerFactory):
+    """An implementation of `WorkerFactory` backed by the `threading`
+    module.
+    """
+
+    def make_queue(self) -> Queue[Any]:
+        return _ThreadQueue()
+
+    def make_event(self) -> Event:
+        return threading.Event()
+
+    def make_worker(
+        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
+    ) -> Worker:
+        thread = threading.Thread(target=target, args=args, name=name)
+        thread.start()
+        return _ThreadWorker(thread)
+
+
+class _ProcessWorker(Worker):
+    """An implementation of `Worker` backed by the `multiprocessing` module."""
+
+    def __init__(self, process: multiprocessing.context.SpawnProcess):
+        super().__init__()
+        self._process = process
+
+    @property
+    def name(self) -> str:
+        return self._process.name
+
+    def join(self, timeout: float | None = None) -> None:
+        self._process.join(timeout=timeout)
+
+    def is_alive(self) -> bool:
+        return self._process.is_alive()
+
+    def kill(self) -> None:
+        """Kill the worker, if possible."""
+        self._process.kill()
+
+
+class _ProcessQueue[T](Queue[T]):
+    def __init__(self, impl: multiprocessing.Queue):
+        self._impl = impl
+
+    @overload
+    def get(self, *, block: Literal[True]) -> T: ...
+
+    @overload
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None: ...
+
+    def get(self, *, timeout: float | None = None, block: bool = False) -> T | None:
+        try:
+            return self._impl.get(block=block, timeout=timeout)
+        except queue.Empty:
+            return None
+
+    def put(self, item: T) -> None:
+        self._impl.put(item, block=False)
+
+    def kill(self) -> None:
+        self._impl.cancel_join_thread()
+        self._impl.close()
+
+
+class SpawnWorkerFactory(WorkerFactory):
+    """An implementation of `WorkerFactory` backed by the `multiprocessing`
+    module, with new processes started by spawning.
+    """
+
+    def __init__(self) -> None:
+        self._ctx = multiprocessing.get_context("spawn")
+
+    def make_queue(self) -> Queue[Any]:
+        return _ProcessQueue(self._ctx.Queue())
+
+    def make_event(self) -> Event:
+        return self._ctx.Event()
+
+    def make_worker(
+        self, target: Callable[..., None], args: tuple[Any, ...], name: str | None = None
+    ) -> Worker:
+        process = self._ctx.Process(target=target, args=args, name=name)
+        process.start()
+        return _ProcessWorker(process)
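
Finally, a small self-contained sketch of how the factory abstraction in this new module is meant to be used; the work function and its None-sentinel protocol are illustrative only (the real targets are Scanner.run, Ingester.run, and Writer.run, as wired up in the supervisor changes above):

    from lsst.pipe.base.quantum_graph.aggregator._workers import ThreadWorkerFactory

    def work(inbox, done):
        # Drain the queue until a None sentinel arrives, then signal completion.
        while (item := inbox.get(block=True)) is not None:
            print("processing", item)
        done.set()

    factory = ThreadWorkerFactory()  # SpawnWorkerFactory would need an importable target
    inbox = factory.make_queue()
    done = factory.make_event()
    worker = factory.make_worker(target=work, args=(inbox, done), name="demo-worker")
    inbox.put("quantum-1")
    inbox.put(None)  # sentinel: tell the worker to stop
    worker.join()
    assert done.is_set() and not worker.is_alive()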