dcnum 0.13.2__py3-none-any.whl → 0.23.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dcnum might be problematic. Click here for more details.
- dcnum/_version.py +2 -2
- dcnum/feat/__init__.py +2 -1
- dcnum/feat/event_extractor_manager_thread.py +67 -33
- dcnum/feat/feat_background/__init__.py +3 -12
- dcnum/feat/feat_background/base.py +80 -65
- dcnum/feat/feat_background/bg_copy.py +31 -0
- dcnum/feat/feat_background/bg_roll_median.py +38 -30
- dcnum/feat/feat_background/bg_sparse_median.py +96 -45
- dcnum/feat/feat_brightness/__init__.py +1 -0
- dcnum/feat/feat_brightness/bright_all.py +41 -6
- dcnum/feat/feat_contour/__init__.py +4 -0
- dcnum/feat/{feat_moments/mt_legacy.py → feat_contour/moments.py} +32 -8
- dcnum/feat/feat_contour/volume.py +174 -0
- dcnum/feat/feat_texture/__init__.py +1 -0
- dcnum/feat/feat_texture/tex_all.py +28 -1
- dcnum/feat/gate.py +92 -70
- dcnum/feat/queue_event_extractor.py +139 -70
- dcnum/logic/__init__.py +5 -0
- dcnum/logic/ctrl.py +794 -0
- dcnum/logic/job.py +184 -0
- dcnum/logic/json_encoder.py +19 -0
- dcnum/meta/__init__.py +1 -0
- dcnum/meta/paths.py +30 -0
- dcnum/meta/ppid.py +66 -9
- dcnum/read/__init__.py +1 -0
- dcnum/read/cache.py +109 -77
- dcnum/read/const.py +6 -4
- dcnum/read/hdf5_data.py +190 -31
- dcnum/read/mapped.py +87 -0
- dcnum/segm/__init__.py +6 -15
- dcnum/segm/segm_thresh.py +7 -14
- dcnum/segm/segm_torch/__init__.py +19 -0
- dcnum/segm/segm_torch/segm_torch_base.py +125 -0
- dcnum/segm/segm_torch/segm_torch_mpo.py +71 -0
- dcnum/segm/segm_torch/segm_torch_sto.py +88 -0
- dcnum/segm/segm_torch/torch_model.py +95 -0
- dcnum/segm/segm_torch/torch_postproc.py +93 -0
- dcnum/segm/segm_torch/torch_preproc.py +114 -0
- dcnum/segm/segmenter.py +245 -96
- dcnum/segm/segmenter_manager_thread.py +39 -28
- dcnum/segm/{segmenter_cpu.py → segmenter_mpo.py} +137 -43
- dcnum/segm/segmenter_sto.py +110 -0
- dcnum/write/__init__.py +3 -1
- dcnum/write/deque_writer_thread.py +15 -5
- dcnum/write/queue_collector_thread.py +14 -17
- dcnum/write/writer.py +225 -55
- {dcnum-0.13.2.dist-info → dcnum-0.23.1.dist-info}/METADATA +4 -2
- dcnum-0.23.1.dist-info/RECORD +55 -0
- {dcnum-0.13.2.dist-info → dcnum-0.23.1.dist-info}/WHEEL +1 -1
- dcnum/feat/feat_moments/__init__.py +0 -3
- dcnum/segm/segmenter_gpu.py +0 -45
- dcnum-0.13.2.dist-info/RECORD +0 -40
- /dcnum/feat/{feat_moments/ct_opencv.py → feat_contour/contour.py} +0 -0
- {dcnum-0.13.2.dist-info → dcnum-0.23.1.dist-info}/LICENSE +0 -0
- {dcnum-0.13.2.dist-info → dcnum-0.23.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import pathlib
|
|
3
|
+
import re
|
|
4
|
+
from typing import Dict
|
|
5
|
+
|
|
6
|
+
from ...meta import paths
|
|
7
|
+
|
|
8
|
+
from ..segmenter import Segmenter, SegmenterNotApplicableError
|
|
9
|
+
|
|
10
|
+
from .torch_model import load_model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TorchSegmenterBase(Segmenter):
    """Torch segmenters that use a pretrained model for segmentation"""
    requires_background_correction = False
    mask_postprocessing = True
    mask_default_kwargs = {
        "clear_border": True,
        "fill_holes": True,
        "closing_disk": 0,
    }

    @classmethod
    def get_ppid_from_ppkw(cls, kwargs, kwargs_mask=None):
        """Return the pipeline identifier for the given segmenter kwargs

        If `kwargs["model_file"]` points to an existing file, its parent
        directory is registered in the "torch_model_files" search path and
        only the file name is forwarded, so that the pipeline identifier
        does not contain the full path. The input `kwargs` dictionary is
        not modified (a copy is used).
        """
        kwargs_new = kwargs.copy()
        # Make sure that the `model_file` kwarg is actually just a filename
        # so that the pipeline identifier only contains the name, but not
        # the full path.
        if "model_file" in kwargs:
            mpath = pathlib.Path(kwargs["model_file"])
            if mpath.exists():
                # register the location of the file in the search path
                # registry so other threads/processes will find it.
                paths.register_search_path("torch_model_files", mpath.parent)
                kwargs_new["model_file"] = mpath.name
        return super(TorchSegmenterBase, cls).get_ppid_from_ppkw(kwargs_new,
                                                                 kwargs_mask)

    @classmethod
    def validate_applicability(cls,
                               segmenter_kwargs: Dict,
                               meta: Dict = None,
                               logs: Dict = None):
        """Validate the applicability of this segmenter for a dataset

        The applicability is defined by the metadata in the segmentation
        model (the "validation" entries of the model metadata).

        Parameters
        ----------
        segmenter_kwargs: dict
            Keyword arguments for the segmenter; must contain `model_file`
        meta: dict
            Dictionary of metadata from an :class:`HDF5Data` instance;
            None is treated as an empty dictionary
        logs: dict
            Dictionary of logs from an :class:`HDF5Data` instance; each
            value is an iterable of strings that is joined with newlines
            before matching. None is treated as an empty dictionary

        Returns
        -------
        applicable: bool
            True if the segmenter is applicable to the dataset

        Raises
        ------
        ValueError
            If `segmenter_kwargs` does not contain `model_file`
        SegmenterNotApplicableError
            If the segmenter is not applicable to the dataset
        """
        if "model_file" not in segmenter_kwargs:
            raise ValueError("A `model_file` must be provided in the "
                             "`segmenter_kwargs` to validate applicability")

        model_file = segmenter_kwargs["model_file"]
        _, model_meta = load_model(model_file, device="cpu")

        # Both arguments default to None; without these fallbacks
        # `logs.items()` raised an AttributeError and a None `meta` broke
        # the key lookup in `_validate_applicability_item`.
        if meta is None:
            meta = {}
        if logs is None:
            logs = {}

        reasons_list = []
        validators = {
            "meta": functools.partial(
                cls._validate_applicability_item,
                data_dict=meta,
                reasons_list=reasons_list),
            "logs": functools.partial(
                cls._validate_applicability_item,
                # convert logs to strings
                data_dict={key: "\n".join(val) for key, val in logs.items()},
                reasons_list=reasons_list)
        }
        for item in model_meta.get("validation", []):
            it = item["type"]
            if it in validators:
                validators[it](item)
            else:
                reasons_list.append(
                    f"invalid validation type {it} in {model_file}")

        if reasons_list:
            raise SegmenterNotApplicableError(segmenter_class=cls,
                                              reasons_list=reasons_list)

        return True

    @staticmethod
    def _validate_applicability_item(item, data_dict, reasons_list):
        """Populate `reasons_list` with invalid entries

        Example validation `item`::

            {"type": "meta",
             "key": "setup:region",
             "allow-missing-key": False,
             "regexp": "^channel$",
             "regexp-negate": False,
             "reason": "only channel region supported",
             }

        Parameters
        ----------
        item: dict
            Validation entry from the model metadata (see above)
        data_dict: dict
            Dictionary whose value at `item["key"]` is matched against
            `item["regexp"]` using :func:`re.match` (anchored at the start);
            non-string values are converted with :func:`str` first
        reasons_list: list
            List to which reasons for non-applicability are appended
        """
        key = item["key"]
        if key in data_dict:
            value = data_dict[key]
            if not isinstance(value, str):
                # Metadata values may be non-strings (e.g. numbers), which
                # `re.match` would reject with a TypeError.
                value = str(value)
            regexp = re.compile(item["regexp"])
            matched = bool(regexp.match(value))
            negate = item.get("regexp-negate", False)
            valid = matched if not negate else not matched
            if not valid:
                reasons_list.append(item.get("reason", "unknown reason"))
        elif not item.get("allow-missing-key", False):
            reasons_list.append(f"Key '{key}' missing in {item['type']}")
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from ..segmenter_mpo import MPOSegmenter
|
|
5
|
+
|
|
6
|
+
from .segm_torch_base import TorchSegmenterBase
|
|
7
|
+
from .torch_model import load_model
|
|
8
|
+
from .torch_preproc import preprocess_images
|
|
9
|
+
from .torch_postproc import postprocess_masks
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SegmentTorchMPO(TorchSegmenterBase, MPOSegmenter):
    """PyTorch segmentation (multiprocessing version)"""

    @staticmethod
    def segment_algorithm(image, *,
                          model_file: str = None):
        """Segment one event image with a dcnum torch model on the CPU

        Parameters
        ----------
        image: 2d ndarray
            event image
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 2d boolean or integer ndarray
            boolean mask, or labeling image if the model's image shape
            differs from `image.shape`

        Raises
        ------
        ValueError
            if no `model_file` is given
        """
        if model_file is None:
            raise ValueError("Please specify a .dcnm model file!")

        # dcnum already distributes the work over processes, so torch
        # itself is limited to a single thread.
        # https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads
        torch.set_num_threads(1)
        cpu = torch.device("cpu")

        model, model_meta = load_model(model_file, cpu)

        # Present the single image as a batch of one to the preprocessing.
        preprocessed = preprocess_images(image[np.newaxis, :, :],
                                         **model_meta["preprocessing"])
        prediction = model(torch.from_numpy(preprocessed).to(cpu))

        # `prediction` has shape (1, 1, H, W); selecting the first batch
        # entry and thresholding gives a boolean array of shape (1, H, W).
        masks = prediction.detach().cpu().numpy()[0] >= 0.5
        first_mask = masks[0]

        assert len(first_mask.shape) == len(image.shape), "sanity check"
        if first_mask.shape == image.shape:
            return first_mask

        # Model shape differs from the image shape: undo the cropping and
        # padding of the preprocessing (this yields a label image).
        labels = postprocess_masks(
            masks=masks,
            original_image_shape=image.shape,
        )
        return labels[0]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from dcnum.segm import STOSegmenter
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .segm_torch_base import TorchSegmenterBase
|
|
6
|
+
from .torch_model import load_model
|
|
7
|
+
from .torch_preproc import preprocess_images
|
|
8
|
+
from .torch_postproc import postprocess_masks
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SegmentTorchSTO(TorchSegmenterBase, STOSegmenter):
    """PyTorch segmentation (GPU version)"""

    @staticmethod
    def _segment_in_batches(imgs_t, model, batch_size, device):
        """Segment image data in batches

        Parameters
        ----------
        imgs_t: ndarray
            preprocessed image stack as returned by `preprocess_images`,
            i.e. of shape (N, 1, H, W)
        model: torch.jit.ScriptModule
            segmentation model
        batch_size: int
            number of images passed to the model at once
        device: torch.device
            device on which inference is performed

        Returns
        -------
        masks: 3d boolean ndarray
            masks of shape (N, H, W), obtained by thresholding the model
            output at 0.5
        """
        size = len(imgs_t)
        # Take the frame shape from the array itself rather than from
        # `imgs_t[0]`, so that an empty stack does not raise an IndexError.
        masks = np.empty((size, *imgs_t.shape[-2:]), dtype=bool)

        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            for start_idx in range(0, size, batch_size):
                stop_idx = start_idx + batch_size
                # Move image tensors to the device
                batch = torch.tensor(imgs_t[start_idx:stop_idx],
                                     device=device)
                # Model inference; remove extra dim [B, C, H, W] -> [B, H, W]
                batch_seg = model(batch).squeeze(1)
                # Threshold on the device so only booleans are copied back,
                # then fill the corresponding slice of the output array.
                masks[start_idx:stop_idx] = (batch_seg >= 0.5).cpu().numpy()

        return masks

    @staticmethod
    def segment_algorithm(images, gpu_id=None, batch_size=50, *,
                          model_file: str = None):
        """Segment a stack of event images

        Parameters
        ----------
        images: 3d ndarray
            array of N event images of shape (N, H, W)
        gpu_id: str
            optional device specification passed to :class:`torch.device`
            (e.g. "cuda:1"); defaults to "cuda"
        batch_size: int
            number of images to process in one batch
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        masks: 3d boolean or integer ndarray
            mask or label images of shape (N, H, W); label images are
            returned if the model's image shape differs from (H, W)

        Raises
        ------
        ValueError
            if no `model_file` is given
        """
        if model_file is None:
            raise ValueError("Please specify a model file!")

        # Determine device to use
        device = torch.device(gpu_id if gpu_id is not None else "cuda")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # Preprocess the images
        image_preproc = preprocess_images(images,
                                          **model_meta["preprocessing"])
        # Model inference
        # The `masks` array has the shape (len(images), H, W), where
        # H and W may be different from the corresponding axes in `images`.
        masks = SegmentTorchSTO._segment_in_batches(image_preproc,
                                                    model,
                                                    batch_size,
                                                    device
                                                    )

        # Perform postprocessing in cases where the image shapes don't match
        assert len(masks.shape[1:]) == len(images.shape[1:]), "sanity check"
        if masks.shape[1:] != images.shape[1:]:
            labels = postprocess_masks(
                masks=masks,
                original_image_shape=images.shape[1:])
            return labels
        else:
            return masks
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import errno
|
|
2
|
+
import functools
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import pathlib
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ...meta import paths
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Module-level logger (used, e.g., to report which dcnum model file was found)
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def check_md5sum(path):
    """Verify the last five characters of the file stem with its MD5 hash

    dcnum model files are expected to be named like ``name_abcde.dcnm``,
    where ``abcde`` are the first five hex digits of the file's MD5 sum.

    Parameters
    ----------
    path: str or os.PathLike
        path to the file to check

    Raises
    ------
    ValueError
        if the first five characters of the MD5 hex digest do not match
        the part of the file stem after the last underscore
    """
    path = pathlib.Path(path)
    # MD5 serves as an integrity check here, not for security, so mark it
    # accordingly (otherwise it can be refused on FIPS-restricted builds).
    hasher = hashlib.md5(usedforsecurity=False)
    # Hash in chunks to avoid loading the entire model file into memory.
    with path.open("rb") as fd:
        for chunk in iter(lambda: fd.read(1024 * 1024), b""):
            hasher.update(chunk)
    md5 = hasher.hexdigest()
    if md5[:5] != path.stem.split("_")[-1]:
        raise ValueError(f"MD5 mismatch for {path} ({md5})! Expected the "
                         f"input file to end with '{md5[:5]}{path.suffix}'.")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@functools.cache
def load_model(path_or_name, device):
    """Load a PyTorch model + metadata from a TorchScript jit checkpoint

    Results are cached per ``(path_or_name, device)`` argument tuple, so
    repeated calls return the same model object.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        jit checkpoint file; For dcnum, these files have the suffix .dcnm
        and contain a special `_extra_files["dcnum_meta.json"]` extra
        file that can be loaded via `torch.jit.load` (see below).
    device: str or torch.device
        device on which to run the model

    Returns
    -------
    model_jit: torch.jit.ScriptModule
        loaded PyTorch model stored as a TorchScript module, set to
        evaluation mode and optimized for inference
    model_meta: dict
        metadata associated with the loaded model

    Raises
    ------
    ValueError
        if the checkpoint does not provide (non-empty) "dcnum_meta.json"
        metadata
    """
    model_path = retrieve_model_file(path_or_name)
    # define an extra files mapping dictionary that loads the model's metadata
    extra_files = {"dcnum_meta.json": ""}
    # load model
    model_jit = torch.jit.load(model_path,
                               _extra_files=extra_files,
                               map_location=device)
    # load model metadata; if the checkpoint lacks the extra file, the
    # placeholder stays empty and `json.loads` would fail cryptically
    meta_raw = extra_files["dcnum_meta.json"]
    if not meta_raw:
        raise ValueError(f"The model file {model_path} does not contain the "
                         f"'dcnum_meta.json' metadata required by dcnum!")
    model_meta = json.loads(meta_raw)
    # set model to evaluation mode
    model_jit.eval()
    # optimize for inference on device
    model_jit = torch.jit.optimize_for_inference(model_jit)
    return model_jit, model_meta
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@functools.cache
def retrieve_model_file(path_or_name):
    """Retrieve a dcnum torch model file

    If a path to an existing model file is given, then this path is used
    directly. If a file name is given (or a `pathlib.Path` that does not
    exist), then look for the file with :func:`dcnum.meta.paths.find_file`
    using the "torch_model_files" topic. The MD5 sum of the located file
    is verified with :func:`check_md5sum`.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        path to or name of a dcnum model file (.dcnm)

    Returns
    -------
    path: pathlib.Path
        path to the verified model file

    Raises
    ------
    FileNotFoundError
        if a `pathlib.Path` was given and neither it nor its file name
        could be located (if a str name cannot be located, the exception
        raised by :func:`dcnum.meta.paths.find_file` propagates)
    ValueError
        if `path_or_name` is neither a str nor a `pathlib.Path`, or if the
        MD5 check of the located file fails
    """
    path = _locate_model_file(path_or_name)
    logger.info("Found dcnum model file %s", path)
    # The MD5 check is done here, outside the lookup's error translation,
    # so that a checksum mismatch is not reported as a missing file.
    check_md5sum(path)
    return path


def _locate_model_file(path_or_name):
    """Resolve `path_or_name` to a model file path (without MD5 check)"""
    # Did the user already pass a path?
    if isinstance(path_or_name, pathlib.Path):
        if path_or_name.exists():
            return path_or_name
        # Fall back to searching for the bare file name. Catch `Exception`
        # (not `BaseException`) so KeyboardInterrupt/SystemExit propagate.
        try:
            return _locate_model_file(path_or_name.name)
        except Exception as exc:
            raise FileNotFoundError(errno.ENOENT,
                                    os.strerror(errno.ENOENT),
                                    str(path_or_name)) from exc
    elif isinstance(path_or_name, str):
        name = path_or_name.strip()
        # We now have a string for a filename, and we have to figure out what
        # the path is. There are several options, including cached files.
        if pathlib.Path(name).exists():
            return pathlib.Path(name)
        return paths.find_file("torch_model_files", name)
    else:
        raise ValueError(
            f"Please pass a string or a path, got {type(path_or_name)}!")
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
from ..segmenter import Segmenter
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy import ndimage as ndi
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def postprocess_masks(masks,
                      original_image_shape: Tuple[int, int]):
    """Postprocess mask images from ML segmenters

    Steps performed:

    - Label every mask image (8-connectivity).
    - If the original image shape is larger than the mask image shape along
      any axis, clear objects touching the border first (mask
      postprocessing via :func:`Segmenter.process_mask`). After padding
      they would no longer lie on the border and would show up as
      half-segmented events.
    - Revert the cropping and padding operations done in
      :func:`.preprocess_images` by cropping and by padding with zeros.

    Parameters
    ----------
    masks: 3d or 4d ndarray
        Mask data in shape (batch_size, 1, height, width)
        or (batch_size, height, width).
    original_image_shape: tuple of (int, int)
        The required output mask shape (height, width) for one event. This
        is required for inverting what is done in
        :func:`.preprocess_images`.

    Returns
    -------
    labels_proc: np.ndarray
        A uint16 label array with the same dimensions as the original image
        data passed to :func:`.preprocess_images`. The shape of this array
        is (batch_size, original_image_shape[0], original_image_shape[1]).
    """
    # Drop the channel axis of 4D model output
    if masks.ndim == 4:
        masks = masks[:, 0, :, :]

    num_events = masks.shape[0]
    connectivity = ndi.generate_binary_structure(2, 2)
    labels = np.empty(masks.shape, dtype=np.uint16)
    for idx in range(num_events):
        ndi.label(input=masks[idx], output=labels[idx],
                  structure=connectivity)

    # Positive differences mean padding is required, negative cropping.
    height_diff = original_image_shape[0] - labels.shape[1]
    width_diff = original_image_shape[1] - labels.shape[2]
    top = abs(height_diff) // 2
    bottom = abs(height_diff) - top
    left = abs(width_diff) // 2
    right = abs(width_diff) - left

    if height_diff > 0 or width_diff > 0:
        for idx in range(num_events):
            labels[idx] = Segmenter.process_mask(labels[idx],
                                                 clear_border=True,
                                                 fill_holes=False,
                                                 closing_disk=0)

    # Width first ...
    if width_diff < 0:
        labels = labels[:, :, left:-right]
    elif width_diff > 0:
        widened = np.zeros((num_events,
                            labels.shape[1],
                            original_image_shape[1]),
                           dtype=np.uint16)
        widened[:, :, left:-right] = labels
        labels = widened

    # ... then height (width already matches `original_image_shape[1]`).
    if height_diff < 0:
        labels = labels[:, top:-bottom, :]
    elif height_diff > 0:
        heightened = np.zeros((num_events,
                               original_image_shape[0],
                               original_image_shape[1]),
                              dtype=np.uint16)
        heightened[:, top:-bottom, :] = labels
        labels = heightened

    return labels
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def preprocess_images(images: np.ndarray,
|
|
7
|
+
norm_mean: float | None,
|
|
8
|
+
norm_std: float | None,
|
|
9
|
+
image_shape: Tuple[int, int] = None,
|
|
10
|
+
):
|
|
11
|
+
"""Transform image data to something torch models expect
|
|
12
|
+
|
|
13
|
+
The transformation includes:
|
|
14
|
+
- normalization (division by 255, subtraction of mean, division by std)
|
|
15
|
+
- cropping and padding of the input images to `image_shape`. For padding,
|
|
16
|
+
the median of each *individual* image is used.
|
|
17
|
+
- casting the input images to four dimensions
|
|
18
|
+
(batch_size, 1, height, width) where the second axis is "channels"
|
|
19
|
+
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
images:
|
|
23
|
+
Input image array (batch_size, height_in, width_in). If this is a
|
|
24
|
+
2D image, it will be reshaped to a 3D image with a batch_size of 1.
|
|
25
|
+
norm_mean:
|
|
26
|
+
Mean value used for standard score data normalization, i.e.
|
|
27
|
+
`normalized = `(images / 255 - norm_mean) / norm_std`; Set
|
|
28
|
+
to None to disable normalization.
|
|
29
|
+
norm_std:
|
|
30
|
+
Standard deviation used for standard score data normalization;
|
|
31
|
+
Set to None to disable normalization (see above).
|
|
32
|
+
image_shape
|
|
33
|
+
Image shape for which the model was created (height, width).
|
|
34
|
+
If the image shape does not match the input image shape, then
|
|
35
|
+
the input images are padded/cropped to fit the image shape of
|
|
36
|
+
the model.
|
|
37
|
+
|
|
38
|
+
Returns
|
|
39
|
+
-------
|
|
40
|
+
image_proc:
|
|
41
|
+
3D array with preprocessed image data of shape
|
|
42
|
+
(batch_size, 1, height, width)
|
|
43
|
+
"""
|
|
44
|
+
if len(images.shape) == 2:
|
|
45
|
+
# Insert indexing axis (batch dimension)
|
|
46
|
+
images = images[np.newaxis, :, :]
|
|
47
|
+
|
|
48
|
+
batch_size = images.shape[0]
|
|
49
|
+
|
|
50
|
+
# crop and pad the images based on what the model expects
|
|
51
|
+
image_shape_act = images.shape[1:]
|
|
52
|
+
if image_shape is None:
|
|
53
|
+
# model fits perfectly to input data
|
|
54
|
+
image_shape = image_shape_act
|
|
55
|
+
|
|
56
|
+
# height
|
|
57
|
+
hdiff = image_shape_act[0] - image_shape[0]
|
|
58
|
+
ht = abs(hdiff) // 2
|
|
59
|
+
hb = abs(hdiff) - ht
|
|
60
|
+
# width
|
|
61
|
+
wdiff = image_shape_act[1] - image_shape[1]
|
|
62
|
+
wl = abs(wdiff) // 2
|
|
63
|
+
wr = abs(wdiff) - wl
|
|
64
|
+
# helper variables
|
|
65
|
+
wpad = wdiff < 0
|
|
66
|
+
wcrp = wdiff > 0
|
|
67
|
+
hpad = hdiff < 0
|
|
68
|
+
hcrp = hdiff > 0
|
|
69
|
+
|
|
70
|
+
# The easy part is the cropping
|
|
71
|
+
if hcrp or wcrp:
|
|
72
|
+
# define slices for width and height
|
|
73
|
+
slice_hc = slice(ht, -hb) if hcrp else slice(None, None)
|
|
74
|
+
slice_wc = slice(wl, -wr) if wcrp else slice(None, None)
|
|
75
|
+
img_proc = images[:, slice_hc, slice_wc]
|
|
76
|
+
else:
|
|
77
|
+
img_proc = images
|
|
78
|
+
|
|
79
|
+
# The hard part is the padding
|
|
80
|
+
if hpad or wpad:
|
|
81
|
+
# compute median for each original input image
|
|
82
|
+
img_med = np.median(images, axis=(1, 2))
|
|
83
|
+
# broadcast the median array from 1D to 3D
|
|
84
|
+
img_med = img_med[:, None, None]
|
|
85
|
+
|
|
86
|
+
# define slices for width and height
|
|
87
|
+
slice_hp = slice(ht, -hb) if hpad else slice(None, None)
|
|
88
|
+
slice_wp = slice(wl, -wr) if wpad else slice(None, None)
|
|
89
|
+
|
|
90
|
+
# empty padding image stack with the shape required for the model
|
|
91
|
+
img_pad = np.empty(shape=(batch_size, image_shape[0], image_shape[1]),
|
|
92
|
+
dtype=np.float32)
|
|
93
|
+
# fill in original data
|
|
94
|
+
img_pad[:, slice_hp, slice_wp] = img_proc
|
|
95
|
+
# fill in background data for height
|
|
96
|
+
if hpad:
|
|
97
|
+
img_pad[:, :ht, :] = img_med
|
|
98
|
+
img_pad[:, -hb:, :] = img_med
|
|
99
|
+
# fill in background data for width
|
|
100
|
+
if wpad:
|
|
101
|
+
img_pad[:, :, :wl] = img_med
|
|
102
|
+
img_pad[:, :, -wr:] = img_med
|
|
103
|
+
# Replace img_norm
|
|
104
|
+
img_proc = img_pad
|
|
105
|
+
|
|
106
|
+
if norm_mean is None or norm_std is None:
|
|
107
|
+
# convert to float32
|
|
108
|
+
img_norm = img_proc.astype(np.float32)
|
|
109
|
+
else:
|
|
110
|
+
# normalize images
|
|
111
|
+
img_norm = (img_proc.astype(np.float32) / 255 - norm_mean) / norm_std
|
|
112
|
+
|
|
113
|
+
# Add a "channels" axis for the ML models.
|
|
114
|
+
return img_norm[:, np.newaxis, :, :]
|