lsst-pipe-base 29.2025.3900__py3-none-any.whl → 29.2025.4000__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. lsst/pipe/base/dot_tools.py +14 -152
  2. lsst/pipe/base/exec_fixup_data_id.py +17 -44
  3. lsst/pipe/base/execution_graph_fixup.py +49 -18
  4. lsst/pipe/base/graph/graph.py +28 -9
  5. lsst/pipe/base/graph_walker.py +119 -0
  6. lsst/pipe/base/log_capture.py +5 -2
  7. lsst/pipe/base/mermaid_tools.py +11 -64
  8. lsst/pipe/base/mp_graph_executor.py +298 -236
  9. lsst/pipe/base/quantum_graph/__init__.py +32 -0
  10. lsst/pipe/base/quantum_graph/_common.py +610 -0
  11. lsst/pipe/base/quantum_graph/_multiblock.py +737 -0
  12. lsst/pipe/base/quantum_graph/_predicted.py +1874 -0
  13. lsst/pipe/base/quantum_graph/visualization.py +302 -0
  14. lsst/pipe/base/quantum_graph_builder.py +292 -34
  15. lsst/pipe/base/quantum_graph_executor.py +2 -1
  16. lsst/pipe/base/quantum_provenance_graph.py +16 -7
  17. lsst/pipe/base/separable_pipeline_executor.py +126 -15
  18. lsst/pipe/base/simple_pipeline_executor.py +44 -43
  19. lsst/pipe/base/single_quantum_executor.py +1 -40
  20. lsst/pipe/base/tests/mocks/__init__.py +1 -1
  21. lsst/pipe/base/tests/mocks/_pipeline_task.py +16 -1
  22. lsst/pipe/base/tests/mocks/{_in_memory_repo.py → _repo.py} +324 -45
  23. lsst/pipe/base/tests/mocks/_storage_class.py +6 -0
  24. lsst/pipe/base/tests/simpleQGraph.py +11 -5
  25. lsst/pipe/base/version.py +1 -1
  26. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/METADATA +2 -1
  27. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/RECORD +35 -29
  28. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/WHEEL +0 -0
  29. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/entry_points.txt +0 -0
  30. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/licenses/COPYRIGHT +0 -0
  31. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/licenses/LICENSE +0 -0
  32. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/licenses/bsd_license.txt +0 -0
  33. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/licenses/gpl-v3.0.txt +0 -0
  34. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/top_level.txt +0 -0
  35. {lsst_pipe_base-29.2025.3900.dist-info → lsst_pipe_base-29.2025.4000.dist-info}/zip-safe +0 -0
lsst/pipe/base/quantum_graph/_multiblock.py
@@ -0,0 +1,737 @@
+ # This file is part of pipe_base.
+ #
+ # Developed for the LSST Data Management System.
+ # This product includes software developed by the LSST Project
+ # (http://www.lsst.org).
+ # See the COPYRIGHT file at the top-level directory of this distribution
+ # for details of code ownership.
+ #
+ # This software is dual licensed under the GNU General Public License and also
+ # under a 3-clause BSD license. Recipients may choose which of these licenses
+ # to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+ # respectively. If you choose the GPL option then the following text applies
+ # (but note that there is still no warranty even if you opt for BSD instead):
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+ from __future__ import annotations
+
+ __all__ = (
+     "Address",
+     "AddressReader",
+     "AddressRow",
+     "AddressWriter",
+     "Compressor",
+     "Decompressor",
+     "InvalidQuantumGraphFileError",
+     "MultiblockReader",
+     "MultiblockWriter",
+ )
+
+ import dataclasses
+ import itertools
+ import logging
+ import uuid
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ from io import BufferedReader, BytesIO
+ from operator import attrgetter
+ from typing import IO, TYPE_CHECKING, ClassVar, Protocol, TypeVar
+
+ import pydantic
+
+ if TYPE_CHECKING:
+     import zipfile
+
+
+ _LOG = logging.getLogger(__name__)
+
+
+ _T = TypeVar("_T", bound=pydantic.BaseModel)
+
+
+ DEFAULT_PAGE_SIZE: int = 5_000_000
+ """Default page size for reading chunks of quantum graph files.
+
+ This is intended to be large enough to avoid any possibility of individual
+ reads suffering from per-seek overheads, especially in network file access,
+ while still being small enough to only minimally slow down tiny reads of
+ individual quanta (especially for execution).
+ """
+
+
+ class Compressor(Protocol):
+     """A protocol for objects with a `compress` method that takes and returns
+     `bytes`.
+     """
+
+     def compress(self, data: bytes) -> bytes:
+         """Compress the given data.
+
+         Parameters
+         ----------
+         data : `bytes`
+             Uncompressed data.
+
+         Returns
+         -------
+         compressed : `bytes`
+             Compressed data.
+         """
+         ...
+
+
+ class Decompressor(Protocol):
+     """A protocol for objects with a `decompress` method that takes and returns
+     `bytes`.
+     """
+
+     def decompress(self, data: bytes) -> bytes:
+         """Decompress the given data.
+
+         Parameters
+         ----------
+         data : `bytes`
+             Compressed data.
+
+         Returns
+         -------
+         decompressed : `bytes`
+             Uncompressed data.
+         """
+         ...
+
+
+ class InvalidQuantumGraphFileError(RuntimeError):
+     """An exception raised when a quantum graph file has internal
+     inconsistencies or does not actually appear to be a quantum graph file.
+     """
+
+
+ @dataclasses.dataclass(slots=True)
+ class Address:
+     """Struct that holds an address into a multi-block file."""
+
+     offset: int = 0
+     """Byte offset for the block."""
+
+     size: int = 0
+     """Size of the block.
+
+     This always includes the size of the tiny header that records the block
+     size. The size recorded in that header does not include the header itself,
+     so these sizes differ by the ``int_size`` used to write the multi-block file.
+
+     A size of zero is used (with, by convention, an offset of zero) to indicate
+     an absent block.
+     """
+
+     def __str__(self) -> str:
+         return f"{self.offset:06}[{self.size:06}]"
+
+
+ @dataclasses.dataclass(slots=True)
+ class AddressRow:
+     """The in-memory representation of a single row in an address file."""
+
+     key: uuid.UUID
+     """Universally unique identifier for this row."""
+
+     index: int
+     """Monotonically increasing integer ID; unique within this file only."""
+
+     addresses: list[Address] = dataclasses.field(default_factory=list)
+     """Offsets and sizes into multi-block files."""
+
+     def write(self, stream: IO[bytes], int_size: int) -> None:
+         """Write this address row to a file-like object.
+
+         Parameters
+         ----------
+         stream : `typing.IO` [ `bytes` ]
+             Binary file-like object.
+         int_size : `int`
+             Number of bytes to use for all integers.
+         """
+         stream.write(self.key.bytes)
+         stream.write(self.index.to_bytes(int_size))
+         for address in self.addresses:
+             stream.write(address.offset.to_bytes(int_size))
+             stream.write(address.size.to_bytes(int_size))
+
+     @classmethod
+     def read(cls, stream: IO[bytes], n_addresses: int, int_size: int) -> AddressRow:
+         """Read an address row from a file-like object.
+
+         Parameters
+         ----------
+         stream : `typing.IO` [ `bytes` ]
+             Binary file-like object.
+         n_addresses : `int`
+             Number of addresses included in each row.
+         int_size : `int`
+             Number of bytes to use for all integers.
+         """
+         key = uuid.UUID(int=int.from_bytes(stream.read(16)))
+         index = int.from_bytes(stream.read(int_size))
+         row = AddressRow(key, index)
+         for _ in range(n_addresses):
+             offset = int.from_bytes(stream.read(int_size))
+             size = int.from_bytes(stream.read(int_size))
+             row.addresses.append(Address(offset, size))
+         return row
+
+     def __str__(self) -> str:
+         return f"{self.key} {self.index:06} {' '.join(str(a) for a in self.addresses)}"
+
+
+ @dataclasses.dataclass
+ class AddressWriter:
+     """A helper object for writing address files for multi-block files."""
+
+     indices: dict[uuid.UUID, int] = dataclasses.field(default_factory=dict)
+     """Mapping from UUID to internal integer ID.
+
+     The internal integer ID must always correspond to the index into the
+     sorted list of all UUIDs, but this `dict` need not be sorted itself.
+     """
+
+     addresses: list[dict[uuid.UUID, Address]] = dataclasses.field(default_factory=list)
+     """Addresses to store with each UUID.
+
+     Every key in one of these dictionaries must have an entry in `indices`.
+     The converse is not true.
+     """
+
+     def write(self, stream: IO[bytes], int_size: int) -> None:
+         """Write all addresses to a file-like object.
+
+         Parameters
+         ----------
+         stream : `typing.IO` [ `bytes` ]
+             Binary file-like object.
+         int_size : `int`
+             Number of bytes to use for all integers.
+         """
+         for n, address_map in enumerate(self.addresses):
+             if not self.indices.keys() >= address_map.keys():
+                 raise AssertionError(
+                     f"Logic bug in quantum graph I/O: address map {n} of {len(self.addresses)} has IDs "
+                     f"{address_map.keys() - self.indices.keys()} not in the index map."
+                 )
+         stream.write(int_size.to_bytes(1))
+         stream.write(len(self.indices).to_bytes(int_size))
+         stream.write(len(self.addresses).to_bytes(int_size))
+         empty_address = Address()
+         for key in sorted(self.indices.keys(), key=attrgetter("int")):
+             row = AddressRow(key, self.indices[key], [m.get(key, empty_address) for m in self.addresses])
+             _LOG.debug("Wrote address %s.", row)
+             row.write(stream, int_size)
+
+     def write_to_zip(self, zf: zipfile.ZipFile, name: str, int_size: int) -> None:
+         """Write all addresses to a file in a zip archive.
+
+         Parameters
+         ----------
+         zf : `zipfile.ZipFile`
+             Zip archive to add the file to.
+         name : `str`
+             Base name for the address file; an extension will be added.
+         int_size : `int`
+             Number of bytes to use for all integers.
+         """
+         with zf.open(f"{name}.addr", mode="w") as stream:
+             self.write(stream, int_size=int_size)
+
+
+ @dataclasses.dataclass
+ class AddressReader:
+     """A helper object for reading address files for multi-block files."""
+
+     MAX_UUID_INT: ClassVar[int] = 2**128
+     """The maximum value of a UUID's integer form."""
+
+     stream: IO[bytes]
+     """Stream to read from."""
+
+     int_size: int
+     """Size of each integer in bytes."""
+
+     n_rows: int
+     """Number of address rows in the file (also the number of UUIDs)."""
+
+     n_addresses: int
+     """Number of addresses in each row."""
+
+     start_index: int
+     """Index of the first row."""
+
+     rows: dict[uuid.UUID, AddressRow]
+     """Rows that have already been read."""
+
+     rows_per_page: int
+     """Minimum number of rows to read at once."""
+
+     unread_pages: dict[int, int]
+     """Pages that have not yet been read, as a mapping from page index to the
+     number of rows in that page.
+
+     Values are always `rows_per_page` with the possible exception of the last
+     page.
+     """
+
+     @classmethod
+     def from_stream(cls, stream: IO[bytes], page_size: int, start_index: int = 0) -> AddressReader:
+         """Construct from a stream by reading the header.
+
+         Parameters
+         ----------
+         stream : `typing.IO` [ `bytes` ]
+             File-like object to read from.
+         page_size : `int`
+             Approximate number of bytes to read at a time when searching for an
+             address.
+         start_index : `int`, optional
+             Value of the first index in the file.
+         """
+         int_size = int.from_bytes(stream.read(1))
+         n_rows = int.from_bytes(stream.read(int_size))
+         n_addresses = int.from_bytes(stream.read(int_size))
+         rows_per_page = max(page_size // cls.compute_row_size(int_size, n_addresses), 1)
+         n_full_pages, last_rows_per_page = divmod(n_rows, rows_per_page)
+         unread_pages = dict.fromkeys(range(n_full_pages), rows_per_page)
+         if last_rows_per_page := n_rows % rows_per_page:
+             unread_pages[n_full_pages] = last_rows_per_page
+         return cls(
+             stream,
+             int_size=int_size,
+             n_rows=n_rows,
+             n_addresses=n_addresses,
+             start_index=start_index,
+             rows={},
+             rows_per_page=rows_per_page,
+             unread_pages=unread_pages,
+         )
+
+     @classmethod
+     @contextmanager
+     def open_in_zip(
+         cls,
+         zf: zipfile.ZipFile,
+         name: str,
+         page_size: int,
+         int_size: int | None = None,
+         start_index: int = 0,
+     ) -> Iterator[AddressReader]:
+         """Make a reader for an address file in a zip archive.
+
+         Parameters
+         ----------
+         zf : `zipfile.ZipFile`
+             Zip archive to read the file from.
+         name : `str`
+             Base name for the address file; an extension will be added.
+         page_size : `int`
+             Approximate number of bytes to read at a time when searching for an
+             address.
+         int_size : `int`, optional
+             Number of bytes to use for all integers. This is checked against
+             the size embedded in the file.
+         start_index : `int`, optional
+             Value of the first index in the file.
+
+         Returns
+         -------
+         reader : `contextlib.AbstractContextManager` [ `AddressReader` ]
+             Context manager that returns a reader when entered.
+         """
+         with zf.open(f"{name}.addr", mode="r") as stream:
+             result = cls.from_stream(stream, page_size=page_size, start_index=start_index)
+             if int_size is not None and result.int_size != int_size:
+                 raise InvalidQuantumGraphFileError(
+                     "int size in address file does not match int size in header."
+                 )
+             yield result
+
+     @staticmethod
+     def compute_header_size(int_size: int) -> int:
+         """Return the size (in bytes) of the header of an address file.
+
+         Parameters
+         ----------
+         int_size : `int`
+             Size of each integer in bytes.
+
+         Returns
+         -------
+         size : `int`
+             Size of the header in bytes.
+         """
+         return (
+             1  # int_size
+             + int_size  # number of rows
+             + int_size  # number of addresses in each row
+         )
+
+     @staticmethod
+     def compute_row_size(int_size: int, n_addresses: int) -> int:
+         """Return the size (in bytes) of each row of an address file.
+
+         Parameters
+         ----------
+         int_size : `int`
+             Size of each integer in bytes.
+         n_addresses : `int`
+             Number of addresses in each row.
+
+         Returns
+         -------
+         size : `int`
+             Size of each row in bytes.
+         """
+         return (
+             16  # uuid
+             + int_size
+             * (
+                 1  # index
+                 + 2 * n_addresses
+             )
+         )
+
+     @property
+     def header_size(self) -> int:
+         """The size (in bytes) of the header of this address file."""
+         return self.compute_header_size(self.int_size)
+
+     @property
+     def row_size(self) -> int:
+         """The size (in bytes) of each row of this address file."""
+         return self.compute_row_size(self.int_size, self.n_addresses)
+
+     def read_all(self) -> dict[uuid.UUID, AddressRow]:
+         """Read all addresses in the file.
+
+         Returns
+         -------
+         rows : `dict` [ `uuid.UUID`, `AddressRow` ]
+             Mapping of loaded address rows, keyed by UUID.
+         """
+         # Shortcut out if we've already read everything, but don't bother
+         # optimizing previous partial reads.
+         if self.unread_pages:
+             self.stream.seek(self.header_size)
+             data = self.stream.read()
+             buffer = BytesIO(data)
+             _LOG.debug("Reading all %d address rows.", self.n_rows)
+             for _ in range(self.n_rows):
+                 self._read_row(buffer)
+             self.unread_pages.clear()
+         return self.rows
+
+     def find(self, key: uuid.UUID) -> AddressRow:
+         """Read the row for the given UUID.
+
+         Parameters
+         ----------
+         key : `uuid.UUID`
+             UUID to find.
+
+         Returns
+         -------
+         row : `AddressRow`
+             Addresses for the given UUID.
+         """
+         if (row := self.rows.get(key)) is not None:
+             return row
+         guess_index_float = (key.int / self.MAX_UUID_INT) * self.n_rows + self.start_index
+         guess_page_float = (guess_index_float - self.start_index) / self.rows_per_page
+         guess_page = int(guess_page_float)
+         _LOG.debug(
+             "Searching for %s, starting at index %s of %s (%s rows per page).",
+             key,
+             guess_index_float,
+             self.n_rows,
+             self.rows_per_page,
+         )
+         for page in self._page_search_path(guess_page):
+             if page in self.unread_pages:
+                 self._read_page(page)
+                 if (row := self.rows.get(key)) is not None:
+                     return row
+             elif not self.unread_pages:
+                 raise LookupError(f"Address for UUID {key} not found.")
+         raise AssertionError("Logic error in page tracking.")
+
+     def _read_page(self, page_index: int) -> None:
+         rows_in_page = self.unread_pages[page_index]
+         _LOG.debug(
+             "Reading page %s (rows %s:%s).",
+             page_index,
+             page_index * self.rows_per_page,
+             page_index * self.rows_per_page + rows_in_page,
+         )
+         self.stream.seek(page_index * self.rows_per_page * self.row_size + self.header_size)
+         data = self.stream.read(self.row_size * rows_in_page)
+         page_stream = BytesIO(data)
+         for _ in range(rows_in_page):
+             self._read_row(page_stream)
+         del self.unread_pages[page_index]
+
+     def _read_row(self, page_stream: BytesIO) -> AddressRow:
+         row = AddressRow.read(page_stream, self.n_addresses, self.int_size)
+         self.rows[row.key] = row
+         _LOG.debug("Read address row %s.", row)
+         return row
+
+     def _page_search_path(self, mid: int) -> Iterator[int]:
+         yield mid
+         for abs_offset in itertools.count(1):
+             yield mid + abs_offset
+             yield mid - abs_offset
+
+
+ @dataclasses.dataclass
+ class MultiblockWriter:
+     """A helper object for writing multi-block files."""
+
+     stream: IO[bytes]
+     """A binary file-like object to write to."""
+
+     int_size: int
+     """Number of bytes to use for all integers."""
+
+     file_size: int = 0
+     """Running size of the full file."""
+
+     addresses: dict[uuid.UUID, Address] = dataclasses.field(default_factory=dict)
+     """Running map of all addresses added to the file so far.
+
+     When the multi-block file is fully written, this is appended to the
+     `AddressWriter.addresses` to write the corresponding address file.
+     """
+
+     @classmethod
+     @contextmanager
+     def open_in_zip(cls, zf: zipfile.ZipFile, name: str, int_size: int) -> Iterator[MultiblockWriter]:
+         """Open a writer for a file in a zip archive.
+
+         Parameters
+         ----------
+         zf : `zipfile.ZipFile`
+             Zip archive to add the file to.
+         name : `str`
+             Base name for the multi-block file; an extension will be added.
+         int_size : `int`
+             Number of bytes to use for all integers.
+
+         Returns
+         -------
+         writer : `contextlib.AbstractContextManager` [ `MultiblockWriter` ]
+             Context manager that returns a writer when entered.
+         """
+         with zf.open(f"{name}.mb", mode="w", force_zip64=True) as stream:
+             yield MultiblockWriter(stream, int_size)
+
+     def write_bytes(self, id: uuid.UUID, data: bytes) -> Address:
+         """Write raw bytes to the multi-block file.
+
+         Parameters
+         ----------
+         id : `uuid.UUID`
+             Unique ID of the object described by this block.
+         data : `bytes`
+             Data to store directly.
+
+         Returns
+         -------
+         address : `Address`
+             Address of the bytes just written.
+         """
+         self.stream.write(len(data).to_bytes(self.int_size))
+         self.stream.write(data)
+         block_size = len(data) + self.int_size
+         address = Address(offset=self.file_size, size=block_size)
+         self.file_size += block_size
+         self.addresses[id] = address
+         return address
+
+     def write_model(self, id: uuid.UUID, model: pydantic.BaseModel, compressor: Compressor) -> Address:
+         """Write a pydantic model to the multi-block file as compressed JSON.
+
+         Parameters
+         ----------
+         id : `uuid.UUID`
+             Unique ID of the object described by this block.
+         model : `pydantic.BaseModel`
+             Model to convert to JSON and compress.
+         compressor : `Compressor`
+             Object with a `compress` method that takes and returns `bytes`.
+
+         Returns
+         -------
+         address : `Address`
+             Address of the bytes just written.
+         """
+         json_data = model.model_dump_json().encode()
+         compressed_data = compressor.compress(json_data)
+         return self.write_bytes(id, compressed_data)
+
+
+ @dataclasses.dataclass
+ class MultiblockReader:
+     """A helper object for reading multi-block files."""
+
+     stream: IO[bytes]
+     """A binary file-like object to read from."""
+
+     int_size: int
+     """Number of bytes to use for all integers."""
+
+     @classmethod
+     @contextmanager
+     def open_in_zip(cls, zf: zipfile.ZipFile, name: str, *, int_size: int) -> Iterator[MultiblockReader]:
+         """Open a reader for a file in a zip archive.
+
+         Parameters
+         ----------
+         zf : `zipfile.ZipFile`
+             Zip archive to read the file from.
+         name : `str`
+             Base name for the multi-block file; an extension will be added.
+         int_size : `int`
+             Number of bytes to use for all integers.
+
+         Returns
+         -------
+         reader : `contextlib.AbstractContextManager` [ `MultiblockReader` ]
+             Context manager that returns a reader when entered.
+         """
+         with zf.open(f"{name}.mb", mode="r") as stream:
+             yield MultiblockReader(stream, int_size)
+
+     @classmethod
+     def read_all_bytes_in_zip(
+         cls, zf: zipfile.ZipFile, name: str, *, int_size: int, page_size: int
+     ) -> Iterator[bytes]:
+         """Iterate over all of the byte blocks in a file in a zip archive.
+
+         Parameters
+         ----------
+         zf : `zipfile.ZipFile`
+             Zip archive to read the file from.
+         name : `str`
+             Base name for the multi-block file; an extension will be added.
+         int_size : `int`
+             Number of bytes to use for all integers.
+         page_size : `int`
+             Approximate number of bytes to read at a time.
+
+         Returns
+         -------
+         byte_iter : `~collections.abc.Iterator` [ `bytes` ]
+             Iterator over blocks.
+         """
+         with zf.open(f"{name}.mb", mode="r") as zf_stream:
+             # The standard library typing of IO[bytes] tiers isn't consistent.
+             buffered_stream = BufferedReader(zf_stream)  # type: ignore[type-var]
+             size_data = buffered_stream.read(int_size)
+             while size_data:
+                 internal_size = int.from_bytes(size_data)
+                 yield buffered_stream.read(internal_size)
+                 size_data = buffered_stream.read(int_size)
+
+     @classmethod
+     def read_all_models_in_zip(
+         cls,
+         zf: zipfile.ZipFile,
+         name: str,
+         model_type: type[_T],
+         decompressor: Decompressor,
+         *,
+         int_size: int,
+         page_size: int,
+     ) -> Iterator[_T]:
+         """Iterate over all of the models in a file in a zip archive.
+
+         Parameters
+         ----------
+         zf : `zipfile.ZipFile`
+             Zip archive to read the file from.
+         name : `str`
+             Base name for the multi-block file; an extension will be added.
+         model_type : `type` [ `pydantic.BaseModel` ]
+             Pydantic model to validate JSON with.
+         decompressor : `Decompressor`
+             Object with a `decompress` method that takes and returns `bytes`.
+         int_size : `int`
+             Number of bytes to use for all integers.
+         page_size : `int`
+             Approximate number of bytes to read at a time.
+
+         Returns
+         -------
+         model_iter : `~collections.abc.Iterator` [ `pydantic.BaseModel` ]
+             Iterator over model instances.
+         """
+         for compressed_data in cls.read_all_bytes_in_zip(zf, name, int_size=int_size, page_size=page_size):
+             json_data = decompressor.decompress(compressed_data)
+             yield model_type.model_validate_json(json_data)
+
+     def read_bytes(self, address: Address) -> bytes | None:
+         """Read raw bytes from the multi-block file.
+
+         Parameters
+         ----------
+         address : `Address`
+             Offset and size of the data to read.
+
+         Returns
+         -------
+         data : `bytes` or `None`
+             Data read directly, or `None` if the address has zero size.
+         """
+         if not address.size:
+             return None
+         self.stream.seek(address.offset)
+         data = self.stream.read(address.size)
+         internal_size = int.from_bytes(data[: self.int_size])
+         data = data[self.int_size :]
+         if len(data) != internal_size:
+             raise InvalidQuantumGraphFileError(
+                 f"Internal size {internal_size} does not match loaded data size {len(data)}."
+             )
+         return data
+
+     def read_model(self, address: Address, model_type: type[_T], decompressor: Decompressor) -> _T | None:
+         """Read a single compressed JSON block.
+
+         Parameters
+         ----------
+         address : `Address`
+             Size and offset of the block.
+         model_type : `type` [ `pydantic.BaseModel` ]
+             Pydantic model to validate JSON with.
+         decompressor : `Decompressor`
+             Object with a `decompress` method that takes and returns `bytes`.
+
+         Returns
+         -------
+         model : `pydantic.BaseModel` or `None`
+             Validated model, or `None` if the address has zero size.
+         """
+         compressed_data = self.read_bytes(address)
+         if compressed_data is None:
+             return None
+         json_data = decompressor.decompress(compressed_data)
+         return model_type.model_validate_json(json_data)
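
Note on usage (editorial sketch, not part of the package contents above): the helpers in _multiblock.py compose in a simple pattern. A MultiblockWriter appends length-prefixed blocks and remembers their addresses, an AddressWriter persists those addresses keyed by UUID, and AddressReader plus MultiblockReader reverse the process. The sketch below assumes the private module path shown in this diff, a hypothetical Payload model, and uses the zlib module as the Compressor/Decompressor, since its module-level compress and decompress functions match those protocols.

import io
import uuid
import zipfile
import zlib

import pydantic

from lsst.pipe.base.quantum_graph._multiblock import (
    AddressReader,
    AddressWriter,
    MultiblockReader,
    MultiblockWriter,
)


class Payload(pydantic.BaseModel):
    # Hypothetical model; any pydantic.BaseModel works.
    value: int


INT_SIZE = 8  # bytes per integer; must match between writer and reader
key = uuid.uuid4()
buffer = io.BytesIO()

with zipfile.ZipFile(buffer, mode="w") as zf:
    # Write one compressed JSON block to "payloads.mb" and record its address.
    with MultiblockWriter.open_in_zip(zf, "payloads", INT_SIZE) as mb_writer:
        mb_writer.write_model(key, Payload(value=42), zlib)
    # Write the companion address file, "payloads.addr".
    addr_writer = AddressWriter(indices={key: 0}, addresses=[mb_writer.addresses])
    addr_writer.write_to_zip(zf, "payloads", INT_SIZE)

with zipfile.ZipFile(buffer, mode="r") as zf:
    # Look up the row for this UUID, then read and decompress its block.
    with AddressReader.open_in_zip(zf, "payloads", page_size=5_000_000, int_size=INT_SIZE) as addr_reader:
        row = addr_reader.find(key)
    with MultiblockReader.open_in_zip(zf, "payloads", int_size=INT_SIZE) as mb_reader:
        restored = mb_reader.read_model(row.addresses[0], Payload, zlib)

assert restored is not None and restored.value == 42

Any other object exposing matching compress/decompress methods could stand in for zlib here, which appears to be the point of making Compressor and Decompressor structural protocols.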
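
A second illustrative check (again editorial, using the same assumed import path): because the address-file layout is fixed-width, the byte count produced by AddressRow.write can be verified by hand against AddressReader.compute_row_size, i.e. 16 bytes of UUID plus int_size * (1 + 2 * n_addresses) bytes of integers per row.

import io
import uuid

from lsst.pipe.base.quantum_graph._multiblock import Address, AddressReader, AddressRow

INT_SIZE = 8
row = AddressRow(uuid.uuid4(), index=0, addresses=[Address(0, 42), Address(42, 17)])

buffer = io.BytesIO()
row.write(buffer, INT_SIZE)

# 16 (UUID) + 8 * (1 index + 2 * 2 offset/size integers) = 56 bytes.
expected = AddressReader.compute_row_size(INT_SIZE, n_addresses=2)
assert len(buffer.getvalue()) == expected == 56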