micress-micpy 0.3.2b2__py3-none-any.whl → 0.4.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
micpy/bin copy.py ADDED
@@ -0,0 +1,865 @@
1
+ """The `micpy.bin` module provides methods to read and write binary files."""
2
+
3
+ from collections.abc import Callable as ABCCallable
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from time import time
7
+ from typing import Callable, Generator, IO, List, Optional, Tuple, Union, overload
8
+ import os
9
+ import sys
10
+
11
+ import gzip
12
+ import rapidgzip
13
+
14
+ import numpy as np
15
+
16
+ from micpy import geo
17
+ from micpy import utils
18
+ from micpy.matplotlib import matplotlib, pyplot
19
+
20
+
21
+ __all__ = ["File", "Field", "Series", "plot", "PlotArgs"]
22
+
23
+
24
class Footer:
    """Trailing record of a field.

    The footer holds a single 32-bit integer named ``length``.
    """

    TYPE = [("length", np.int32)]
    SIZE = np.dtype(TYPE).itemsize

    def __init__(self, length: int = 0):
        # Pass the value through the structured dtype so it is stored with the
        # same integer type that `to_bytes` serializes.
        record = np.array((length,), dtype=self.TYPE)
        self.body_length = record["length"]

    def to_bytes(self):
        """Serialize the footer into its binary representation."""
        packed = np.array((self.body_length,), dtype=self.TYPE)
        return packed.tobytes()
37
+
38
+
39
class Header:
    """Leading record of a field.

    The header stores three values: a size entry, the time of the field, and
    the number of values in the field body.
    """

    TYPE = [("size", np.int32), ("time", np.float32), ("length", np.int32)]
    SIZE = np.dtype(TYPE).itemsize

    def __init__(self, size: int, time: float, length: int):
        """Create a header and derive the total size of one field.

        Raises:
            `ValueError`: If ``size`` is inconsistent with ``length``.
        """
        self.size = size
        self.time = round(float(time), 7)
        self.body_length = length

        # One field occupies: header + body (4 bytes per value) + footer.
        self.field_size = Header.SIZE + 4 * self.body_length + Footer.SIZE

        # The size entry must equal the field size minus 8 bytes.
        expected_size = self.field_size - 8
        if self.size != expected_size:
            raise ValueError("Invalid header")

    @staticmethod
    def from_bytes(data: bytes):
        """Parse a header from the first `Header.SIZE` bytes of ``data``."""
        records = np.frombuffer(data[: Header.SIZE], dtype=Header.TYPE)
        size, time, length = records[0].item()
        return Header(size, time, length)

    def to_bytes(self):
        """Serialize the header into its binary representation."""
        values = (self.size, self.time, self.body_length)
        return np.array(values, dtype=Header.TYPE).tobytes()

    @staticmethod
    def read(file: IO[bytes]) -> "Header":
        """Read the header located at the very beginning of a binary file."""
        file.seek(0)
        return Header.from_bytes(file.read(Header.SIZE))

    def get_field_count(self, file_size: int) -> int:
        """Return how many complete fields fit into ``file_size`` bytes."""
        return file_size // self.field_size
77
+
78
+
79
class Field(np.ndarray):
    """A field: an array of values annotated with a time and a grid spacing.

    Attributes:
        time: Time associated with the field.
        spacing: Grid spacing associated with the field.
    """

    def __new__(cls, data, time: float, spacing: Tuple[float, float, float]):
        """Create a field as a view on ``data`` and attach its metadata."""
        obj = np.asarray(data).view(cls)
        obj.time = time
        obj.spacing = spacing
        return obj

    def __array_finalize__(self, obj):
        """Propagate `time` and `spacing` to views and derived arrays."""
        if obj is None:
            return

        # pylint: disable=attribute-defined-outside-init
        self.time = getattr(obj, "time", None)
        self.spacing = getattr(obj, "spacing", None)

    @staticmethod
    def _decode(data: bytes, shape, spacing, integers: bool) -> "Field":
        """Parse one field (header, body, footer) from ``data``.

        Args:
            data (bytes): Bytes that start with a field header.
            shape: Shape to give the body, or `None` to keep it flat.
            spacing: Spacing to attach to the field.
            integers (bool): If `True`, convert the body to ``int32`` when all
                values are (close to) integers.
        """
        header = Header.from_bytes(data)

        # The body starts right after the header; `count` limits parsing to the
        # body values, so the trailing footer bytes are ignored.
        body = np.frombuffer(
            data[header.SIZE : header.field_size],
            count=header.body_length,
            dtype="float32",
        )
        if integers and np.all(np.isclose(body, body.astype("int32"))):
            body = body.astype("int32")

        if shape is not None:
            body = body.reshape(shape)

        return Field(body, time=header.time, spacing=spacing)

    @staticmethod
    def from_bytes(data: bytes, shape=None, spacing=None):
        """Create a new field from bytes.

        Integer-valued bodies are converted to ``int32``.
        """
        return Field._decode(data, shape, spacing, integers=True)

    @staticmethod
    def from_bytes2(data: bytes, shape=None, spacing=None):
        """Create a new field from bytes, always keeping the body as ``float32``."""
        return Field._decode(data, shape, spacing, integers=False)

    def to_bytes(self):
        """Convert the field to bytes.

        The body is always serialized as ``float32`` because that is the only
        type `from_bytes` parses. Writing the array's raw bytes would corrupt
        fields held as ``int32`` (which `from_bytes` itself produces) and would
        fail the header size check for any dtype that is not 4 bytes wide.
        """
        body = np.asarray(self, dtype=np.float32)
        header = Header(
            size=body.size * body.itemsize + 8,
            time=self.time,
            length=body.size,
        )
        footer = Footer(length=body.size)

        return header.to_bytes() + body.tobytes() + footer.to_bytes()

    @staticmethod
    def dimensions(shape: Tuple[int, int, int]) -> int:
        """Get the number of dimensions of a shape.

        Returns 1 if the first two extents are both 1, 2 if only the second
        extent is 1, and 3 otherwise.
        """
        first, second, _ = shape

        if second != 1:
            return 3
        return 1 if first == 1 else 2

    @staticmethod
    def read(
        file: IO[bytes],
        field_id: int,
        field_size: int,
        field_count: int,
        shape=None,
        spacing=None,
    ) -> "Field":
        """Read a field from a binary file.

        Args:
            file (IO[bytes]): Seekable binary file.
            field_id (int): Index of the field; negative values count from the end.
            field_size (int): Size of one field in bytes.
            field_count (int): Number of fields in the file.
            shape: Shape to give the field, or `None` to keep it flat.
            spacing: Spacing to attach to the field.

        Raises:
            `IndexError`: If ``field_id`` is out of range.
            `EOFError`: If the file ends before the field is complete.
        """
        if field_id < 0:
            field_id += field_count

        if not 0 <= field_id < field_count:
            raise IndexError("field_id out of range")

        file.seek(field_size * field_id)

        field_data = file.read(field_size)
        if len(field_data) < field_size:
            raise EOFError("Unexpected end of file")

        return Field.from_bytes(field_data, shape=shape, spacing=spacing)

    @staticmethod
    def read2(
        file: IO[bytes],
        field_id: int,
        field_size: int,
        shape=None,
        spacing=None,
    ) -> "Field":
        """Read a field from a binary file without knowing the field count.

        The end of the file is only located when ``field_id`` is negative.
        Progress and timings are reported through the module-level `info`.
        The body is kept as ``float32`` (see `from_bytes2`).

        Raises:
            `IndexError`: If a negative ``field_id`` reaches before the first field.
            `EOFError`: If the file ends before the field is complete.
        """
        if field_id < 0:
            start = time()
            info("Indexing entire file...")
            file.seek(0, 2)
            field_count = file.tell() // field_size
            field_id += field_count
            info(f"Indexing completed in {time() - start:.2f} seconds.")
            if field_id < 0:
                # Avoid seeking to a negative offset.
                raise IndexError("field_id out of range")

        start = time()
        info(f"Seeking to field ID {field_id}...")
        file.seek(field_size * field_id)
        info(f"Seek completed in {time() - start:.2f} seconds.")

        start = time()
        info(f"Reading data for field ID {field_id}...")
        field_data = file.read(field_size)
        if len(field_data) < field_size:
            raise EOFError("Unexpected end of file")
        info(f"Read completed in {time() - start:.2f} seconds.")

        start = time()
        info(f"Creating Field object for field ID {field_id}...")
        field = Field.from_bytes2(field_data, shape=shape, spacing=spacing)
        info(f"Field object created in {time() - start:.2f} seconds.")
        return field

    def to_file(self, file: IO[bytes], geometry: bool = True):
        """Write the field to a binary file.

        Args:
            file (IO[bytes]): Binary file.
            geometry (bool, optional): `True` if geometry should be written, `False`
                otherwise. Defaults to `True`.
        """

        file.write(self.to_bytes())

        if geometry:
            # The geometry file name is derived from the name of the binary file.
            geo_filename, geo_data = geo.build(file.name, self.shape, self.spacing)
            geo.write(geo_filename, geo_data, geo.Type.BASIC)

    def write(
        self,
        filename: str,
        compressed: bool = True,
        geometry: bool = True,
    ):
        """Write the field to a binary file.

        Args:
            filename (str): Filename of the binary file.
            compressed (bool, optional): `True` if file should be compressed, `False`
                otherwise. Defaults to `True`.
            geometry (bool, optional): `True` if geometry should be written, `False`
                otherwise. Defaults to `True`.
        """

        file_open = gzip.open if compressed else open
        with file_open(filename, "wb") as file:
            self.to_file(file, geometry)
260
+
261
+
262
class Series(np.ndarray):
    """A stack of fields that remembers the time and spacing of each field."""

    def __new__(cls, fields: List[Field]):
        obj = np.asarray(fields).view(cls)

        # Keep the per-field metadata next to the stacked values.
        times = []
        spacings = []
        for field in fields:
            times.append(field.time)
        for field in fields:
            spacings.append(field.spacing)
        obj.times = times
        obj.spacings = spacings
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return

        # pylint: disable=attribute-defined-outside-init
        self.times = getattr(obj, "times", None)
        self.spacings = getattr(obj, "spacings", None)

    def iterate_fields(self):
        """Iterate over fields in the series.

        Yields:
            Field.
        """
        for values, stamp, cell_size in zip(self, self.times, self.spacings):
            yield Field(values, stamp, cell_size)

    def get_field(self, index: int) -> Field:
        """Get a field from the series.

        Args:
            index (int): Index of the field.

        Returns:
            Field.
        """
        values = self[index]
        return Field(values, self.times[index], self.spacings[index])

    def get_series(self, key: Union[int, slice, list]) -> "Series":
        """Get a series of fields.

        Args:
            key (Union[int, slice, list]): A single field ID, a slice object, or a
                list of field IDs.

        Returns:
            Series of fields.

        Raises:
            `TypeError`: If ``key`` is not an `int`, `slice`, or `list`.
        """
        # Normalize every accepted key type to an iterable of field indices.
        if isinstance(key, int):
            indices = [key]
        elif isinstance(key, slice):
            indices = range(*key.indices(len(self)))
        elif isinstance(key, list):
            indices = key
        else:
            raise TypeError("Invalid argument type")

        return Series([self.get_field(i) for i in indices])

    def write(self, filename: str, compressed: bool = True, geometry: bool = True):
        """Write the series to a binary file.

        Args:
            filename (str): Filename of the binary file.
            compressed (bool, optional): `True` if file should be compressed, `False`
                otherwise. Defaults to `True`.
            geometry (bool, optional): `True` if geometry should be written, `False`
                otherwise. Defaults to `True`.
        """

        opener = gzip.open if compressed else open
        with opener(filename, "wb") as stream:
            write_geometry = geometry
            for field in self.iterate_fields():
                field.to_file(stream, write_geometry)
                # Geometry is written at most once, together with the first field.
                write_geometry = False
331
+
332
+
333
# Module-wide switch: `info` and `warn` only print while this is true.
verbose = False


def info(*args):
    """Print an informational message if module-level verbosity is enabled."""
    if not verbose:
        return
    print(*args)


def warn(*args):
    """Print a warning message if module-level verbosity is enabled.

    The message is written to standard output, not standard error.
    """
    if not verbose:
        return
    print(*args)
344
+
345
+
346
def index_filename(filename: str) -> Path:
    """Return the path of the index file that belongs to ``filename``.

    The index path is the given filename with ``.idx`` appended.
    """
    index_name = "{}.idx".format(filename)
    return Path(index_name)
348
+
349
+
350
def import_index(file, index: str):
    """Load an index from ``index`` into ``file``, if one is available.

    Nothing is raised when the index is missing, unreadable, or invalid; the
    problem is reported through `warn` and the function returns.

    Args:
        file: File object that provides an ``import_index`` method.
        index (str): Path of the index file.
    """
    index_path = Path(index)

    if not index_path.is_file():
        warn("Index file not found.")
    elif not os.access(index_path, os.R_OK):
        warn("Index file is not readable.")
    else:
        try:
            file.import_index(str(index_path))
            info("Index file imported.")
        except ValueError:
            warn("Index file is invalid.")
        except OSError:
            # Also covers PermissionError, FileNotFoundError and
            # IsADirectoryError, which are all subclasses of OSError.
            warn("Failed to import index file.")
367
+
368
+
369
def export_index(file, index: str):
    """Save the index of ``file`` to ``index``, if that location is writable.

    Nothing is raised when the index cannot be written; the problem is
    reported through `warn` and the function returns.

    Args:
        file: File object that provides an ``export_index`` method.
        index (str): Path of the index file.
    """
    target = Path(index)

    # Determine up front whether the target location is usable.
    if target.exists() and not target.is_file():
        problem = "Index path is not a file."
    elif not os.access(target.parent, os.W_OK):
        problem = "Index file is not writable."
    else:
        problem = None

    if problem is not None:
        warn(problem)
        return

    try:
        file.export_index(str(target))
        info("Index file exported.")
    except OSError:
        # Also covers PermissionError, IsADirectoryError and
        # FileNotFoundError, which are all subclasses of OSError.
        warn("Failed to export index file.")
384
+
385
+
386
def read_field(
    filename: str,
    field_id: int,
    shape: tuple = None,
    spacing: tuple = None,
    threads: int = None,
) -> "Field":
    """Read a single field from a binary file.

    Missing ``shape`` or ``spacing`` values are taken from the geometry file
    that belongs to ``filename``, if exactly one can be found. For compressed
    files, an index file is imported before and exported after reading.

    Args:
        filename (str): Filename of the binary file.
        field_id (int): Index of the field; negative values count from the end.
        shape (tuple, optional): Shape of the field. Defaults to `None` (auto).
        spacing (tuple, optional): Spacing of the field. Defaults to `None` (auto).
        threads (int, optional): Number of threads for reading compressed files.
            Defaults to `None` (auto).

    Returns:
        Field.

    Raises:
        `FileNotFoundError`: If file is not found.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"File not found: {filename}")

    if threads is None:
        # os.cpu_count() returns None when the count cannot be determined.
        threads = min(4, os.cpu_count() or 1)

    if shape is None or spacing is None:
        try:
            geometry = geo.read(geo.find(filename), type=geo.Type.BASIC)
        except (geo.GeometryFileNotFoundError, geo.MultipleGeometryFilesError):
            warn("Caution: A geometry file was not found.")
        else:
            # Explicit `is None` checks: truth-testing would raise for NumPy
            # arrays with more than one element.
            if shape is None:
                shape = geometry["shape"][::-1]
            if spacing is None:
                spacing = geometry["spacing"][::-1]

    compressed = utils.is_compressed(filename)

    if compressed:
        file = rapidgzip.open(filename, threads)
    else:
        file = open(filename, "rb")

    # Make sure the file handle is released even if reading fails.
    try:
        if compressed:
            import_index(file, index_filename(filename))
            info(f"Imported file size: {file.size()} bytes")

        field = Field.read2(
            file=file,
            field_id=field_id,
            field_size=Header.read(file).field_size,
            shape=shape,
            spacing=spacing,
        )

        if compressed:
            export_index(file, index_filename(filename))
            info(f"Exported file size: {file.size()} bytes")
    finally:
        file.close()

    return field
431
+
432
+
433
class File:
    """A binary file."""

    def __init__(
        self,
        filename: str,
        parallelization: int = None,
        verbose: bool = True,
    ):
        """Initialize a binary file.

        Args:
            filename (str): File name.
            parallelization (int, optional): Number of threads for reading compressed
                files. Defaults to `None` (auto).
            verbose (bool, optional): Verbose output. Defaults to `True`.

        Raises:
            `FileNotFoundError`: If file is not found.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError(f"File not found: {filename}")

        self._filename: str = filename
        # os.cpu_count() returns None when the count cannot be determined.
        self._parallelization: int = (
            parallelization
            if parallelization is not None
            else min(4, os.cpu_count() or 1)
        )
        self._verbose: bool = verbose

        self._file: Optional[IO[bytes]] = None
        self._compressed: Optional[bool] = None
        self._header: Optional[Header] = None
        self._file_size: Optional[int] = None

        self.shape: Optional[np.ndarray] = None
        self.spacing: Optional[np.ndarray] = None

        try:
            self.find_geometry()
        except (geo.GeometryFileNotFoundError, geo.MultipleGeometryFilesError):
            self._warn("Caution: A geometry file was not found.")

    def __getitem__(self, key: Union[int, slice, list, Callable[[Field], bool]]):
        return self.read(key)

    def __enter__(self):
        return self.open()

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __iter__(self):
        self.open()
        return self.iterate()

    def _info(self, *args):
        if self._verbose:
            print(*args)

    def _warn(self, *args):
        if self._verbose:
            print(*args, file=sys.stderr)

    def _open_file(self):
        """Open the file, determine its size, and read the first header."""
        self._compressed = utils.is_compressed(self._filename)

        if self._compressed:
            self._file = rapidgzip.open(self._filename, self._parallelization)
        else:
            self._file = open(self._filename, "rb")

        try:
            self.import_index()
            # The size is needed for every file (compressed or not, with or
            # without an index file) to compute the number of fields.
            self._measure_file_size()
            self.export_index()

            self._header = Header.read(self._file)
        except Exception:
            # Do not leak the handle if the file turns out to be unusable.
            self._close_file()
            raise

    def _measure_file_size(self):
        """Record the size of the file content by seeking to its end."""
        if self._compressed:
            self._info("Creating index...")
        start = time()
        self._file.seek(0, 2)
        self._file_size = self._file.tell()
        if self._compressed:
            self._info(f"Index created in {time() - start:.2f} seconds.")

    def _close_file(self):
        if self._file:
            self._file.close()
        self._reset()

    def _reset(self):
        self._file = None
        self._compressed = None
        self._header = None
        self._file_size = None

    def open(self):
        """Open the file (no-op if it is already open).

        Returns:
            The file itself.
        """
        if not self._file:
            self._open_file()
        return self

    def close(self):
        """Close the file."""
        self._close_file()

    def times(self) -> List[float]:
        """Get the times of the fields in the file.

        The file is opened first if necessary.

        Returns:
            List of times.
        """
        self.open()

        times = []
        self._file.seek(0)
        while True:
            data = self._file.read(Header.SIZE)
            if len(data) < Header.SIZE:
                break
            header = Header.from_bytes(data)
            times.append(header.time)
            # Skip the body and footer to land on the next header.
            self._file.seek(header.field_size - Header.SIZE, 1)
        return times

    def _index_path(self) -> Path:
        return Path(f"{self._filename}.idx")

    def import_index(self):
        """Import index from a file.

        Only applies to opened, compressed files. Problems are reported as
        warnings and never raised.
        """
        if not self._compressed:
            return

        path = self._index_path()

        if not path.is_file():
            self._warn("Index file not found.")
            return
        if not os.access(path, os.R_OK):
            self._warn("Index file is not readable.")
            return
        try:
            self._file.import_index(str(path))
            self._info("Index file imported.")
        except ValueError:
            self._warn("Index file is invalid.")
        except OSError:
            # Also covers PermissionError, FileNotFoundError, IsADirectoryError.
            self._warn("Failed to import index file.")

    def export_index(self):
        """Export index to a file.

        Only applies to opened, compressed files. Problems are reported as
        warnings and never raised.
        """
        if not self._compressed:
            return

        path = self._index_path()

        if path.exists() and not path.is_file():
            self._warn("Index path is not a file.")
            return
        if not os.access(path.parent, os.W_OK):
            self._warn("Index file is not writable.")
            return
        try:
            self._file.export_index(str(path))
            self._info("Index file exported.")
        except OSError:
            # Also covers PermissionError, IsADirectoryError, FileNotFoundError.
            self._warn("Failed to export index file.")

    def set_geometry(
        self, shape: Tuple[int, int, int], spacing: Tuple[float, float, float]
    ):
        """Set the geometry.

        Args:
            shape (Tuple[int, int, int]): Shape of the geometry (z, y, x).
            spacing (Tuple[float, float, float]): Spacing of the geometry (dz, dy, dx) in μm.
        """

        self.shape = np.array(shape)
        self.spacing = np.array(spacing)

        self.print_geometry()

    def read_geometry(self, filename: str, compressed: Optional[bool] = None):
        """Read geometry from a file.

        Args:
            filename (str): Filename of a geometry file.
            compressed (bool, optional): `True` if file is compressed, `False`
                otherwise. Defaults to `None` (auto).
        """
        if compressed is None:
            compressed = utils.is_compressed(filename)

        geometry = geo.read(filename, type=geo.Type.BASIC, compressed=compressed)

        # The geometry file order is reversed before it is stored.
        shape = geometry["shape"][::-1]
        spacing = geometry["spacing"][::-1]

        self.set_geometry(shape, spacing)

    def find_geometry(self, compressed: Optional[bool] = None):
        """Find geometry file and read it.

        Args:
            compressed (bool, optional): True if file is compressed, False otherwise.
                Defaults to `None` (auto).

        Raises:
            `GeometryFileNotFoundError`: If no geometry file is found.
            `MultipleGeometryFilesError`: If multiple geometry files are found.
        """
        filename = geo.find(self._filename)

        self.read_geometry(filename, compressed=compressed)

    def print_geometry(self):
        """Print a summary of the geometry (only if verbose output is enabled)."""

        if self.shape is None or self.spacing is None:
            self._info("Geometry: None")
            return

        cells = np.asarray(self.shape)
        dimensions = Field.dimensions(cells)
        # NOTE(review): the stored spacing is scaled by 1e4 before it is
        # reported in μm, although `set_geometry` documents spacing as μm
        # already — confirm the intended unit of `self.spacing`.
        spacing = 1e4 * np.round(np.asarray(self.spacing, dtype=float), 7)
        size = cells * spacing

        self._info(f"Geometry: {dimensions}-Dimensional Grid")
        self._info(f"Grid Size [μm]: {tuple(size)}")
        self._info(f"Grid Shape (Cell Count): {tuple(cells)}")
        self._info(f"Grid Spacing (Cell Size) [μm]: {tuple(spacing)}")

    def read_field(self, field_id: int) -> Field:
        """Read a field from the file.

        The file is opened first if necessary.

        Args:
            field_id (int): Field ID.

        Returns:
            Field.

        Raises:
            `IndexError`: If ``field_id`` is out of range.
            `EOFError`: If the file ends before the field is complete.
        """
        self.open()

        return Field.read(
            file=self._file,
            field_id=field_id,
            field_size=self._header.field_size,
            field_count=self._header.get_field_count(self._file_size),
            shape=self.shape,
            spacing=self.spacing,
        )

    def iterate(self) -> Generator[Field, None, None]:
        """Iterate over fields in the file.

        Returns:
            A generator of fields.
        """

        field_id = 0
        while True:
            try:
                yield self.read_field(field_id)
                field_id += 1
            except (IndexError, EOFError):
                # Reading past the last field ends the iteration.
                break

    @overload
    def read(self, key: None = None) -> Series: ...
    @overload
    def read(self, key: int) -> Series: ...
    @overload
    def read(self, key: slice) -> Series: ...
    @overload
    def read(self, key: List[int]) -> Series: ...
    @overload
    def read(self, key: Callable[[Field], bool]) -> Series: ...

    def read(
        self, key: Optional[Union[int, slice, list, Callable[[Field], bool]]] = None
    ) -> Series:
        """Read a series of fields from the file.

        The file is opened first if necessary.

        Args:
            key (Union[int, slice, list, Callable[[Field], bool]]): A single field
                ID, a slice object, a list of field IDs, or a function that filters
                fields. Defaults to `None` (all fields).

        Returns:
            Series of fields.

        Raises:
            `TypeError`: If ``key`` has an unsupported type.
        """
        self.open()

        def with_progress(items):
            # Wrap the iterable in a progress indicator when verbose.
            if self._verbose:
                return utils.progress_indicator(
                    items, description="Reading", unit="Field"
                )
            return items

        if key is None:
            fields = list(with_progress(self.iterate()))
        elif isinstance(key, int):
            fields = [self.read_field(i) for i in with_progress([key])]
        elif isinstance(key, slice):
            indices = range(*key.indices(self._header.get_field_count(self._file_size)))
            fields = [self.read_field(i) for i in with_progress(indices)]
        elif isinstance(key, list):
            fields = [self.read_field(i) for i in with_progress(key)]
        elif isinstance(key, ABCCallable):
            fields = [field for field in with_progress(self.iterate()) if key(field)]
        else:
            raise TypeError("Invalid argument type")
        return Series(fields)
745
+
746
+
747
@dataclass
class PlotArgs:
    """Arguments for plotting a field.

    All arguments are optional; `plot` falls back to its own defaults for
    every value that is left as `None`.

    Args:
        title (str, optional): Title of the plot. Defaults to `None`.
        xlabel (str, optional): Label of the x-axis. Defaults to `None`.
        ylabel (str, optional): Label of the y-axis. Defaults to `None`.
        figsize (Tuple[float, float], optional): Figure size. Defaults to `None`.
        dpi (int, optional): Figure DPI. Defaults to `None`.
        aspect (str, optional): Aspect ratio. Defaults to `equal`.
        ax (matplotlib.Axes, optional): Axes of the plot. Defaults to `None`.
        cax (matplotlib.Axes, optional): Axes of the color bar. Defaults to `None`.
        vmin (float, optional): Minimum value of the color bar. Defaults to `None`.
        vmax (float, optional): Maximum value of the color bar. Defaults to `None`.
        cmap (str, optional): Colormap. Defaults to `micpy`.
        alpha (float, optional): Transparency of the plot. Defaults to `1.0`.
        interpolation (str, optional): Interpolation method. Defaults to `none`.
        extent (Tuple[float, float, float, float], optional): Extent of the plot.
            Defaults to `None`.
    """

    title: Optional[str] = None
    xlabel: Optional[str] = None
    ylabel: Optional[str] = None
    figsize: Optional[Tuple[float, float]] = None
    dpi: Optional[int] = None
    aspect: str = "equal"
    ax: "matplotlib.Axes" = None
    cax: "matplotlib.Axes" = None
    vmin: Optional[float] = None
    vmax: Optional[float] = None
    cmap: str = "micpy"
    alpha: float = 1.0
    interpolation: str = "none"
    extent: Optional[Tuple[float, float, float, float]] = None
782
+
783
+
784
def plot(
    field: Field,
    axis: str = "y",
    index: int = 0,
    args: Optional[PlotArgs] = None,
) -> Tuple["matplotlib.Figure", "matplotlib.Axes", "matplotlib.colorbar.Colorbar"]:
    """Plot a slice of the field.

    Args:
        field (Field): Field to plot. Must be three-dimensional.
        axis (str, optional): Axis to plot. Possible values are `x`, `y`, and `z`.
            Defaults to `y`.
        index (int, optional): Index of the slice. Defaults to `0`.
        args (PlotArgs, optional): Arguments for plotting. Defaults to `None`.

    Returns:
        Matplotlib figure, axes, and color bar.

    Raises:
        `ImportError`: If matplotlib is not installed.
        `ValueError`: If the field is not three-dimensional or ``axis`` is invalid.
    """

    if matplotlib is None:
        raise ImportError("matplotlib is not installed")

    if field.ndim != 3:
        raise ValueError("Invalid field shape")

    # The sliced axis is removed; the two remaining axes label the plot.
    if axis == "z":
        x, y = "x", "y"
        slice_2d = field[index, :, :]
    elif axis == "y":
        x, y = "x", "z"
        slice_2d = field[:, index, :]
    elif axis == "x":
        x, y = "y", "z"
        slice_2d = field[:, :, index]
    else:
        raise ValueError("Invalid axis")

    if args is None:
        args = PlotArgs()

    if args.ax is None:
        fig, ax = pyplot.subplots(figsize=args.figsize, dpi=args.dpi)
    else:
        fig, ax = args.ax.get_figure(), args.ax

    if args.title is not None:
        ax.set_title(args.title)
    elif isinstance(field, Field) and field.time is not None:
        # Fields without a time (e.g. views of plain arrays) get no default title.
        ax.set_title(f"t={np.round(field.time, 7)}s")

    ax.set_xlabel(args.xlabel if args.xlabel is not None else x)
    ax.set_ylabel(args.ylabel if args.ylabel is not None else y)

    if args.aspect is not None:
        ax.set_aspect(args.aspect)
    ax.set_frame_on(False)

    image = ax.imshow(
        slice_2d,
        cmap=args.cmap,
        vmin=args.vmin,
        vmax=args.vmax,
        interpolation=args.interpolation,
        alpha=args.alpha,
        extent=args.extent,
        origin="lower",
    )

    # Attach the color bar to the figure that owns `ax`; `pyplot.colorbar`
    # would use the current figure, which may differ when `args.ax` is given.
    bar = fig.colorbar(image, ax=ax, cax=args.cax)
    bar.locator = matplotlib.ticker.MaxNLocator(
        integer=np.issubdtype(slice_2d.dtype, np.integer)
    )
    bar.outline.set_visible(False)
    bar.update_ticks()

    return fig, ax, bar