dcnum 0.17.0-py3-none-any.whl → 0.23.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dcnum might be problematic.

Files changed (49)
  1. dcnum/_version.py +2 -2
  2. dcnum/feat/__init__.py +1 -1
  3. dcnum/feat/event_extractor_manager_thread.py +34 -25
  4. dcnum/feat/feat_background/base.py +22 -26
  5. dcnum/feat/feat_background/bg_copy.py +18 -12
  6. dcnum/feat/feat_background/bg_roll_median.py +20 -10
  7. dcnum/feat/feat_background/bg_sparse_median.py +55 -7
  8. dcnum/feat/feat_brightness/bright_all.py +41 -6
  9. dcnum/feat/feat_contour/__init__.py +4 -0
  10. dcnum/feat/{feat_moments/mt_legacy.py → feat_contour/moments.py} +32 -8
  11. dcnum/feat/feat_contour/volume.py +174 -0
  12. dcnum/feat/feat_texture/tex_all.py +28 -1
  13. dcnum/feat/gate.py +2 -2
  14. dcnum/feat/queue_event_extractor.py +30 -9
  15. dcnum/logic/ctrl.py +199 -49
  16. dcnum/logic/job.py +63 -2
  17. dcnum/logic/json_encoder.py +2 -0
  18. dcnum/meta/ppid.py +17 -3
  19. dcnum/read/__init__.py +1 -0
  20. dcnum/read/cache.py +100 -78
  21. dcnum/read/const.py +6 -4
  22. dcnum/read/hdf5_data.py +146 -23
  23. dcnum/read/mapped.py +87 -0
  24. dcnum/segm/__init__.py +6 -3
  25. dcnum/segm/segm_thresh.py +6 -18
  26. dcnum/segm/segm_torch/__init__.py +19 -0
  27. dcnum/segm/segm_torch/segm_torch_base.py +125 -0
  28. dcnum/segm/segm_torch/segm_torch_mpo.py +71 -0
  29. dcnum/segm/segm_torch/segm_torch_sto.py +88 -0
  30. dcnum/segm/segm_torch/torch_model.py +95 -0
  31. dcnum/segm/segm_torch/torch_postproc.py +93 -0
  32. dcnum/segm/segm_torch/torch_preproc.py +114 -0
  33. dcnum/segm/segmenter.py +181 -80
  34. dcnum/segm/segmenter_manager_thread.py +38 -30
  35. dcnum/segm/{segmenter_cpu.py → segmenter_mpo.py} +116 -44
  36. dcnum/segm/segmenter_sto.py +110 -0
  37. dcnum/write/__init__.py +2 -1
  38. dcnum/write/deque_writer_thread.py +9 -1
  39. dcnum/write/queue_collector_thread.py +8 -14
  40. dcnum/write/writer.py +128 -5
  41. {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/METADATA +4 -2
  42. dcnum-0.23.1.dist-info/RECORD +55 -0
  43. {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/WHEEL +1 -1
  44. dcnum/feat/feat_moments/__init__.py +0 -4
  45. dcnum/segm/segmenter_gpu.py +0 -64
  46. dcnum-0.17.0.dist-info/RECORD +0 -46
  47. /dcnum/feat/{feat_moments/ct_opencv.py → feat_contour/contour.py} +0 -0
  48. {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/LICENSE +0 -0
  49. {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/top_level.txt +0 -0
dcnum/read/mapped.py ADDED
@@ -0,0 +1,87 @@
+ import functools
+
+ import numbers
+
+ import h5py
+ import numpy as np
+
+
+ class MappedHDF5Dataset:
+     def __init__(self,
+                  h5ds: h5py.Dataset,
+                  mapping_indices: np.ndarray):
+         """An index-mapped object for accessing an HDF5 dataset
+
+         Parameters
+         ----------
+         h5ds: h5py.Dataset
+             HDF5 dataset from which to map data
+         mapping_indices: np.ndarray
+             numpy indexing array containing integer indices
+         """
+         self.h5ds = h5ds
+         self.mapping_indices = mapping_indices
+         self.shape = (mapping_indices.size,) + h5ds.shape[1:]
+
+     def __getitem__(self, idx):
+         if isinstance(idx, numbers.Integral):
+             return self.h5ds[self.mapping_indices[idx]]
+         else:
+             idx_mapped = self.mapping_indices[idx]
+             return self.h5ds[idx_mapped]
+
+
+ def get_mapping_indices(
+         index_mapping: numbers.Integral | slice | list | np.ndarray
+ ):
+     """Return integer numpy array with mapping indices for a range
+
+     Parameters
+     ----------
+     index_mapping: numbers.Integral | slice | list | np.ndarray
+         Several options you have here:
+         - integer: results in np.arange(integer)
+         - slice: results in np.arange(slice.start, slice.stop, slice.step)
+         - list or np.ndarray: returns the input as uint32 array
+     """
+     if isinstance(index_mapping, numbers.Integral):
+         return _get_mapping_indices_cached(index_mapping)
+     elif isinstance(index_mapping, slice):
+         return _get_mapping_indices_cached(
+             (index_mapping.start, index_mapping.stop, index_mapping.step))
+     elif isinstance(index_mapping, (np.ndarray, list)):
+         return np.array(index_mapping, dtype=np.uint32)
+     else:
+         raise ValueError(f"Invalid type for `index_mapping`: "
+                          f"{type(index_mapping)} ({index_mapping})")
+
+
+ @functools.lru_cache(maxsize=100)
+ def _get_mapping_indices_cached(
+         index_mapping: numbers.Integral | tuple
+ ):
+     if isinstance(index_mapping, numbers.Integral):
+         return np.arange(index_mapping)
+     elif isinstance(index_mapping, tuple):
+         im_slice = slice(*index_mapping)
+         if im_slice.stop is None or im_slice.start is None:
+             raise NotImplementedError(
+                 "Slices must have start and stop defined")
+         return np.arange(im_slice.start, im_slice.stop, im_slice.step)
+     elif isinstance(index_mapping, list):
+         return np.array(index_mapping, dtype=np.uint32)
+     else:
+         raise ValueError(f"Invalid type for cached `index_mapping`: "
+                          f"{type(index_mapping)} ({index_mapping})")
+
+
+ def get_mapped_object(obj, index_mapping=None):
+     if index_mapping is None:
+         return obj
+     elif isinstance(obj, h5py.Dataset):
+         return MappedHDF5Dataset(
+             obj,
+             mapping_indices=get_mapping_indices(index_mapping))
+     else:
+         raise ValueError(f"No recipe to convert object of type {type(obj)} "
+                          f"({obj}) to an index-mapped object")
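
The new mapped.py module gives index-mapped access to an HDF5 dataset without copying data. A minimal usage sketch under assumed names (the file "measurement.rtdc" and the "events/image" dataset key are illustrative, not part of this diff):

    import h5py
    from dcnum.read.mapped import get_mapped_object

    with h5py.File("measurement.rtdc", "r") as h5:
        ds = h5["events/image"]
        # View every other of the first 100 events; nothing is copied,
        # indices are resolved lazily in __getitem__.
        mapped = get_mapped_object(ds, index_mapping=slice(0, 100, 2))
        print(mapped.shape)      # (50,) + ds.shape[1:]
        first_image = mapped[0]  # resolves to ds[0]
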
dcnum/segm/__init__.py CHANGED
@@ -1,6 +1,9 @@
  # flake8: noqa: F401
- from .segmenter import Segmenter, get_available_segmenters
- from .segmenter_cpu import CPUSegmenter
- from .segmenter_gpu import GPUSegmenter
+ from .segmenter import (
+     Segmenter, SegmenterNotApplicableError, get_available_segmenters
+ )
+ from .segmenter_mpo import MPOSegmenter
+ from .segmenter_sto import STOSegmenter
  from .segmenter_manager_thread import SegmenterManagerThread
  from . import segm_thresh
+ from . import segm_torch
dcnum/segm/segm_thresh.py CHANGED
@@ -1,7 +1,7 @@
- from .segmenter_cpu import CPUSegmenter
+ from .segmenter_mpo import MPOSegmenter


- class SegmentThresh(CPUSegmenter):
+ class SegmentThresh(MPOSegmenter):
      mask_postprocessing = True
      mask_default_kwargs = {
          "clear_border": True,
@@ -10,22 +10,10 @@ class SegmentThresh(CPUSegmenter):
      }
      requires_background_correction = True

-     def __init__(self, thresh=-6, *args, **kwargs):
-         """Simple image thresholding segmentation
-
-         Parameters
-         ----------
-         thresh: int
-             grayscale threhold value for creating the mask image;
-             For a background-corrected image, pixels with values below
-             this value are considered to be part of the mask.
-         """
-         super(SegmentThresh, self).__init__(thresh=thresh, *args, **kwargs)
-
      @staticmethod
-     def segment_approach(image, *,
-                          thresh: float = -6):
-         """Mask retrieval as it is done in Shape-In
+     def segment_algorithm(image, *,
+                           thresh: float = -6):
+         """Mask retrieval using basic thresholding

          Parameters
          ----------
@@ -39,7 +27,7 @@ class SegmentThresh(CPUSegmenter):
          Returns
          -------
          mask: 2d boolean ndarray
-             Mask image for the give index
+             Mask image for the given index
          """
          assert thresh < 0, "threshold values above zero not supported!"
          return image < thresh
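
For orientation, the renamed `segment_algorithm` staticmethod can be called directly on a background-corrected image; a minimal sketch with synthetic data (the image shape is arbitrary):

    import numpy as np
    from dcnum.segm.segm_thresh import SegmentThresh

    # Background-corrected image: pixels darker than the background are
    # negative; everything below `thresh` becomes part of the mask.
    image_corr = np.random.randint(-30, 30, size=(80, 250)).astype(np.int16)
    mask = SegmentThresh.segment_algorithm(image_corr, thresh=-6)
    print(mask.dtype, mask.shape)  # bool (80, 250)
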
dcnum/segm/segm_torch/__init__.py ADDED
@@ -0,0 +1,19 @@
+ import importlib
+
+ try:
+     torch = importlib.import_module("torch")
+     req_maj = 2
+     req_min = 3
+     ver_tuple = torch.__version__.split(".")
+     act_maj = int(ver_tuple[0])
+     act_min = int(ver_tuple[1])
+     if act_maj < req_maj or (act_maj == req_maj and act_min < req_min):
+         raise ValueError(f"Your PyTorch version {act_maj}.{act_min} is not "
+                          f"supported, please update to at least "
+                          f"{req_maj}.{req_min}")
+ except ImportError:
+     pass
+ else:
+     from .segm_torch_mpo import SegmentTorchMPO  # noqa: F401
+     if torch.cuda.is_available():
+         from .segm_torch_sto import SegmentTorchSTO  # noqa: F401
dcnum/segm/segm_torch/segm_torch_base.py ADDED
@@ -0,0 +1,125 @@
+ import functools
+ import pathlib
+ import re
+ from typing import Dict
+
+ from ...meta import paths
+
+ from ..segmenter import Segmenter, SegmenterNotApplicableError
+
+ from .torch_model import load_model
+
+
+ class TorchSegmenterBase(Segmenter):
+     """Torch segmenters that use a pretrained model for segmentation"""
+     requires_background_correction = False
+     mask_postprocessing = True
+     mask_default_kwargs = {
+         "clear_border": True,
+         "fill_holes": True,
+         "closing_disk": 0,
+     }
+
+     @classmethod
+     def get_ppid_from_ppkw(cls, kwargs, kwargs_mask=None):
+         kwargs_new = kwargs.copy()
+         # Make sure that the `model_file` kwarg is actually just a filename
+         # so that the pipeline identifier only contains the name, but not
+         # the full path.
+         if "model_file" in kwargs:
+             model_file = kwargs["model_file"]
+             mpath = pathlib.Path(model_file)
+             if mpath.exists():
+                 # register the location of the file in the search path
+                 # registry so other threads/processes will find it.
+                 paths.register_search_path("torch_model_files", mpath.parent)
+             kwargs_new["model_file"] = mpath.name
+         return super(TorchSegmenterBase, cls).get_ppid_from_ppkw(kwargs_new,
+                                                                  kwargs_mask)
+
+     @classmethod
+     def validate_applicability(cls,
+                                segmenter_kwargs: Dict,
+                                meta: Dict = None,
+                                logs: Dict = None):
+         """Validate the applicability of this segmenter for a dataset
+
+         The applicability is defined by the metadata in the segmentation
+         model.
+
+         Parameters
+         ----------
+         segmenter_kwargs: dict
+             Keyword arguments for the segmenter
+         meta: dict
+             Dictionary of metadata from an :class:`HDF5Data` instance
+         logs: dict
+             Dictionary of logs from an :class:`HDF5Data` instance
+
+         Returns
+         -------
+         applicable: bool
+             True if the segmenter is applicable to the dataset
+
+         Raises
+         ------
+         SegmenterNotApplicableError
+             If the segmenter is not applicable to the dataset
+         """
+         if "model_file" not in segmenter_kwargs:
+             raise ValueError("A `model_file` must be provided in the "
+                              "`segmenter_kwargs` to validate applicability")
+
+         model_file = segmenter_kwargs["model_file"]
+         _, model_meta = load_model(model_file, device="cpu")
+
+         reasons_list = []
+         validators = {
+             "meta": functools.partial(
+                 cls._validate_applicability_item,
+                 data_dict=meta,
+                 reasons_list=reasons_list),
+             "logs": functools.partial(
+                 cls._validate_applicability_item,
+                 # convert logs to strings
+                 data_dict={key: "\n".join(val) for key, val in logs.items()},
+                 reasons_list=reasons_list)
+         }
+         for item in model_meta.get("validation", []):
+             it = item["type"]
+             if it in validators:
+                 validators[it](item)
+             else:
+                 reasons_list.append(
+                     f"invalid validation type {it} in {model_file}")
+
+         if reasons_list:
+             raise SegmenterNotApplicableError(segmenter_class=cls,
+                                               reasons_list=reasons_list)
+
+         return True
+
+     @staticmethod
+     def _validate_applicability_item(item, data_dict, reasons_list):
+         """Populate `reasons_list` with invalid entries
+
+         Example `item`::
+
+             {"type": "meta",
+              "key": "setup:region",
+              "allow-missing-key": False,
+              "regexp": "^channel$",
+              "regexp-negate": False,
+              "reason": "only channel region supported",
+              }
+         """
+         key = item["key"]
+         if key in data_dict:
+             regexp = re.compile(item["regexp"])
+             matched = bool(regexp.match(data_dict[key]))
+             negate = item.get("regexp-negate", False)
+             valid = matched if not negate else not matched
+             if not valid:
+                 reasons_list.append(item.get("reason", "unknown reason"))
+         elif not item.get("allow-missing-key", False):
+             reasons_list.append(f"Key '{key}' missing in {item['type']}")
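
The validation rules consumed by `validate_applicability` are read from the model metadata under the "validation" key. The sketch below builds one rule following the docstring example and runs the per-item check against mismatching metadata (the metadata value is made up; PyTorch >= 2.3 must be installed for the import to work):

    from dcnum.segm.segm_torch.segm_torch_base import TorchSegmenterBase

    rule = {"type": "meta",
            "key": "setup:region",
            "allow-missing-key": False,
            "regexp": "^channel$",
            "regexp-negate": False,
            "reason": "only channel region supported"}

    reasons = []
    TorchSegmenterBase._validate_applicability_item(
        rule,
        data_dict={"setup:region": "reservoir"},
        reasons_list=reasons)
    print(reasons)  # ['only channel region supported']
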
dcnum/segm/segm_torch/segm_torch_mpo.py ADDED
@@ -0,0 +1,71 @@
+ import numpy as np
+ import torch
+
+ from ..segmenter_mpo import MPOSegmenter
+
+ from .segm_torch_base import TorchSegmenterBase
+ from .torch_model import load_model
+ from .torch_preproc import preprocess_images
+ from .torch_postproc import postprocess_masks
+
+
+ class SegmentTorchMPO(TorchSegmenterBase, MPOSegmenter):
+     """PyTorch segmentation (multiprocessing version)"""
+
+     @staticmethod
+     def segment_algorithm(image, *,
+                           model_file: str = None):
+         """
+         Parameters
+         ----------
+         image: 2d ndarray
+             event image
+         model_file: str
+             path to or name of a dcnum model file (.dcnm); if only a
+             name is provided, then the "torch_model_files" directory
+             paths are searched for the file name
+
+         Returns
+         -------
+         mask: 2d boolean or integer ndarray
+             mask or labeling image for the given index
+         """
+         if model_file is None:
+             raise ValueError("Please specify a .dcnm model file!")
+
+         # Set number of pytorch threads to 1, because dcnum is doing
+         # all the multiprocessing.
+         # https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads
+         torch.set_num_threads(1)
+         device = torch.device("cpu")
+
+         # Load model and metadata
+         model, model_meta = load_model(model_file, device)
+
+         image_preproc = preprocess_images(image[np.newaxis, :, :],
+                                           **model_meta["preprocessing"])
+
+         image_ten = torch.from_numpy(image_preproc)
+
+         # Move image tensors to device
+         image_ten_on_device = image_ten.to(device)
+         # Model inference
+         pred_tensor = model(image_ten_on_device)
+
+         # Convert cuda-tensor into numpy mask array. The `pred_tensor`
+         # array is still of the shape (1, 1, H, W). The `masks`
+         # array is of shape (1, H, W). We can optionally label it
+         # here (we have to if the shapes don't match) or do it in
+         # postprocessing.
+         masks = pred_tensor.detach().cpu().numpy()[0] >= 0.5
+
+         # Perform postprocessing in cases where the image shapes don't match
+         assert len(masks[0].shape) == len(image.shape), "sanity check"
+         if masks[0].shape != image.shape:
+             labels = postprocess_masks(
+                 masks=masks,
+                 original_image_shape=image.shape,
+             )
+             return labels[0]
+         else:
+             return masks[0]
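
A hedged sketch of single-image CPU inference with the new segmenter; the model file name is a placeholder, and a real .dcnm file must exist either as a full path or in a registered "torch_model_files" search path:

    import numpy as np
    from dcnum.segm.segm_torch import SegmentTorchMPO

    image = np.random.randint(0, 255, size=(80, 250)).astype(np.uint8)
    # "unet-channel_1a2b3.dcnm" is hypothetical; its stem must end with the
    # first five hex digits of the file's MD5 sum (see check_md5sum below).
    labels = SegmentTorchMPO.segment_algorithm(
        image, model_file="unet-channel_1a2b3.dcnm")
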
dcnum/segm/segm_torch/segm_torch_sto.py ADDED
@@ -0,0 +1,88 @@
+ from dcnum.segm import STOSegmenter
+ import numpy as np
+ import torch
+
+ from .segm_torch_base import TorchSegmenterBase
+ from .torch_model import load_model
+ from .torch_preproc import preprocess_images
+ from .torch_postproc import postprocess_masks
+
+
+ class SegmentTorchSTO(TorchSegmenterBase, STOSegmenter):
+     """PyTorch segmentation (GPU version)"""
+
+     @staticmethod
+     def _segment_in_batches(imgs_t, model, batch_size, device):
+         """Segment image data in batches"""
+         size = len(imgs_t)
+         # Create empty array to fill up with segmented batches
+         masks = np.empty((len(imgs_t), *imgs_t[0].shape[-2:]),
+                          dtype=bool)
+
+         for start_idx in range(0, size, batch_size):
+             batch = imgs_t[start_idx:start_idx + batch_size]
+             # Move image tensors to cuda
+             batch = torch.tensor(batch, device=device)
+             # Model inference
+             batch_seg = model(batch)
+             # Remove extra dim [B, C, H, W] --> [B, H, W]
+             batch_seg = batch_seg.squeeze(1)
+             # Convert cuda-tensor into numpy arrays
+             batch_seg_np = batch_seg.detach().cpu().numpy()
+             # Fill empty array with segmented batch
+             masks[start_idx:start_idx + batch_size] = batch_seg_np >= 0.5
+
+         return masks
+
+     @staticmethod
+     def segment_algorithm(images, gpu_id=None, batch_size=50, *,
+                           model_file: str = None):
+         """
+         Parameters
+         ----------
+         images: 3d ndarray
+             array of N event images of shape (N, H, W)
+         gpu_id: str
+             optional argument specifying the GPU to use
+         batch_size: int
+             number of images to process in one batch
+         model_file: str
+             path to or name of a dcnum model file (.dcnm); if only a
+             name is provided, then the "torch_model_files" directory
+             paths are searched for the file name
+
+         Returns
+         -------
+         mask: 3d boolean or integer ndarray
+             mask or label images of shape (N, H, W)
+         """
+         if model_file is None:
+             raise ValueError("Please specify a model file!")
+
+         # Determine device to use
+         device = torch.device(gpu_id if gpu_id is not None else "cuda")
+
+         # Load model and metadata
+         model, model_meta = load_model(model_file, device)
+
+         # Preprocess the images
+         image_preproc = preprocess_images(images,
+                                           **model_meta["preprocessing"])
+         # Model inference
+         # The `masks` array has the shape (len(images), H, W), where
+         # H and W may be different from the corresponding axes in `images`.
+         masks = SegmentTorchSTO._segment_in_batches(image_preproc,
+                                                     model,
+                                                     batch_size,
+                                                     device
+                                                     )
+
+         # Perform postprocessing in cases where the image shapes don't match
+         assert len(masks.shape[1:]) == len(images.shape[1:]), "sanity check"
+         if masks.shape[1:] != images.shape[1:]:
+             labels = postprocess_masks(
+                 masks=masks,
+                 original_image_shape=images.shape[1:])
+             return labels
+         else:
+             return masks
dcnum/segm/segm_torch/torch_model.py ADDED
@@ -0,0 +1,95 @@
+ import errno
+ import functools
+ import hashlib
+ import json
+ import logging
+ import os
+ import pathlib
+
+ import torch
+
+ from ...meta import paths
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def check_md5sum(path):
+     """Verify the last five characters of the file stem with its MD5 hash"""
+     md5 = hashlib.md5(path.read_bytes()).hexdigest()
+     if md5[:5] != path.stem.split("_")[-1]:
+         raise ValueError(f"MD5 mismatch for {path} ({md5})! Expected the "
+                          f"input file to end with '{md5[:5]}{path.suffix}'.")
+
+
+ @functools.cache
+ def load_model(path_or_name, device):
+     """Load a PyTorch model + metadata from a TorchScript jit checkpoint
+
+     Parameters
+     ----------
+     path_or_name: str or pathlib.Path
+         jit checkpoint file; For dcnum, these files have the suffix .dcnm
+         and contain a special `_extra_files["dcnum_meta.json"]` extra
+         file that can be loaded via `torch.jit.load` (see below).
+     device: str or torch.device
+         device on which to run the model
+
+     Returns
+     -------
+     model_jit: torch.jit.ScriptModule
+         loaded PyTorch model stored as a TorchScript module
+     model_meta: dict
+         metadata associated with the loaded model
+     """
+     model_path = retrieve_model_file(path_or_name)
+     # define an extra files mapping dictionary that loads the model's metadata
+     extra_files = {"dcnum_meta.json": ""}
+     # load model
+     model_jit = torch.jit.load(model_path,
+                                _extra_files=extra_files,
+                                map_location=device)
+     # load model metadata
+     model_meta = json.loads(extra_files["dcnum_meta.json"])
+     # set model to evaluation mode
+     model_jit.eval()
+     # optimize for inference on device
+     model_jit = torch.jit.optimize_for_inference(model_jit)
+     return model_jit, model_meta
+
+
+ @functools.cache
+ def retrieve_model_file(path_or_name):
+     """Retrieve a dcnum torch model file
+
+     If a path to a model is given, then this path is returned directly.
+     If a file name is given, then look for the file with
+     :func:`dcnum.meta.paths.find_file` using the "torch_model_files"
+     topic.
+     """
+     # Did the user already pass a path?
+     if isinstance(path_or_name, pathlib.Path):
+         if path_or_name.exists():
+             path = path_or_name
+         else:
+             try:
+                 return retrieve_model_file(path_or_name.name)
+             except BaseException:
+                 raise FileNotFoundError(errno.ENOENT,
+                                         os.strerror(errno.ENOENT),
+                                         str(path_or_name))
+     elif isinstance(path_or_name, str):
+         name = path_or_name.strip()
+         # We now have a string for a filename, and we have to figure out what
+         # the path is. There are several options, including cached files.
+         if pathlib.Path(name).exists():
+             path = pathlib.Path(name)
+         else:
+             path = paths.find_file("torch_model_files", name)
+     else:
+         raise ValueError(
+             f"Please pass a string or a path, got {type(path_or_name)}!")
+
+     logger.info(f"Found dcnum model file {path}")
+     check_md5sum(path)
+     return path
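
`check_md5sum` implies a naming convention for .dcnm files: the stem must end with an underscore followed by the first five hex digits of the file's MD5 sum. A small sketch that produces such a name (the helper is invented for illustration):

    import hashlib
    import pathlib

    def dcnm_name(path, base="my-model"):
        """Return a file name that satisfies check_md5sum for `path`."""
        md5 = hashlib.md5(pathlib.Path(path).read_bytes()).hexdigest()
        return f"{base}_{md5[:5]}.dcnm"
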
dcnum/segm/segm_torch/torch_postproc.py ADDED
@@ -0,0 +1,93 @@
+ from typing import Tuple
+
+ from ..segmenter import Segmenter
+
+ import numpy as np
+ from scipy import ndimage as ndi
+
+
+ def postprocess_masks(masks,
+                       original_image_shape: Tuple[int, int]):
+     """Postprocess mask images from ML segmenters
+
+     The transformation includes:
+     - Revert the cropping and padding operations done in
+       :func:`.preprocess_images` by padding with zeros and cropping.
+     - If the original image shape is larger than the mask image shape,
+       also clear borders in an intermediate step
+       (mask postprocessing using :func:`Segmenter.process_mask`).
+
+     Parameters
+     ----------
+     masks: 3d or 4d ndarray
+         Mask data in shape (batch_size, 1, imagex_size, imagey_size)
+         or (batch_size, imagex_size, imagey_size).
+     original_image_shape: tuple of (int, int)
+         The required output mask shape for one event. This is required for
+         doing the inverse of what is done in :func:`.preprocess_images`.
+
+     Returns
+     -------
+     labels_proc: np.ndarray
+         An integer array with the same dimensions as the original image
+         data passed to :func:`.preprocess_images`. The shape of this array
+         is (batch_size, original_image_shape[0], original_image_shape[1]).
+     """
+     # If output of model is 4d, remove channel axis
+     if len(masks.shape) == 4:
+         masks = masks[:, 0, :, :]
+
+     # Label the mask image
+     labels = np.empty(masks.shape, dtype=np.uint16)
+     label_struct = ndi.generate_binary_structure(2, 2)
+     for ii in range(masks.shape[0]):
+         ndi.label(
+             input=masks[ii],
+             output=labels[ii],
+             structure=label_struct)
+
+     batch_size = labels.shape[0]
+
+     # Revert padding and cropping from preprocessing
+     mask_shape_ret = labels.shape[1:]
+     # height
+     s0diff = original_image_shape[0] - mask_shape_ret[0]
+     s0t = abs(s0diff) // 2
+     s0b = abs(s0diff) - s0t
+     # width
+     s1diff = original_image_shape[1] - mask_shape_ret[1]
+     s1l = abs(s1diff) // 2
+     s1r = abs(s1diff) - s1l
+
+     if s0diff > 0 or s1diff > 0:
+         # The masks that we have must be padded. Before we do that, we have
+         # to remove events on the edges, otherwise we will have half-segmented
+         # cell events in the output array.
+         for ii in range(batch_size):
+             labels[ii] = Segmenter.process_mask(labels[ii],
+                                                 clear_border=True,
+                                                 fill_holes=False,
+                                                 closing_disk=0)
+
+     # Crop first, only then pad.
+     if s1diff > 0:
+         labels_pad = np.zeros((batch_size,
+                                labels.shape[1],
+                                original_image_shape[1]),
+                               dtype=np.uint16)
+         labels_pad[:, :, s1l:-s1r] = labels
+         labels = labels_pad
+     elif s1diff < 0:
+         labels = labels[:, :, s1l:-s1r]
+
+     if s0diff > 0:
+         labels_pad = np.zeros((batch_size,
+                                original_image_shape[0],
+                                original_image_shape[1]),
+                               dtype=np.uint16)
+         labels_pad[:, s0t:-s0b, :] = labels
+         labels = labels_pad
+     elif s0diff < 0:
+         labels = labels[:, s0t:-s0b, :]
+
+     return labels
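
A shape-handling sketch for `postprocess_masks`, assuming the `segm_torch` subpackage imports cleanly (PyTorch either absent or >= 2.3): here the model output (64x240) is smaller than the original images (80x250), so border objects are cleared and the labeled masks are padded back with zeros:

    import numpy as np
    from dcnum.segm.segm_torch.torch_postproc import postprocess_masks

    masks = np.zeros((3, 64, 240), dtype=bool)
    masks[:, 30:34, 100:110] = True  # one blob per image, away from the edges
    labels = postprocess_masks(masks, original_image_shape=(80, 250))
    print(labels.shape, labels.dtype)  # (3, 80, 250) uint16
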