dcnum 0.17.0__py3-none-any.whl → 0.23.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dcnum might be problematic. Click here for more details.
- dcnum/_version.py +2 -2
- dcnum/feat/__init__.py +1 -1
- dcnum/feat/event_extractor_manager_thread.py +34 -25
- dcnum/feat/feat_background/base.py +22 -26
- dcnum/feat/feat_background/bg_copy.py +18 -12
- dcnum/feat/feat_background/bg_roll_median.py +20 -10
- dcnum/feat/feat_background/bg_sparse_median.py +55 -7
- dcnum/feat/feat_brightness/bright_all.py +41 -6
- dcnum/feat/feat_contour/__init__.py +4 -0
- dcnum/feat/{feat_moments/mt_legacy.py → feat_contour/moments.py} +32 -8
- dcnum/feat/feat_contour/volume.py +174 -0
- dcnum/feat/feat_texture/tex_all.py +28 -1
- dcnum/feat/gate.py +2 -2
- dcnum/feat/queue_event_extractor.py +30 -9
- dcnum/logic/ctrl.py +199 -49
- dcnum/logic/job.py +63 -2
- dcnum/logic/json_encoder.py +2 -0
- dcnum/meta/ppid.py +17 -3
- dcnum/read/__init__.py +1 -0
- dcnum/read/cache.py +100 -78
- dcnum/read/const.py +6 -4
- dcnum/read/hdf5_data.py +146 -23
- dcnum/read/mapped.py +87 -0
- dcnum/segm/__init__.py +6 -3
- dcnum/segm/segm_thresh.py +6 -18
- dcnum/segm/segm_torch/__init__.py +19 -0
- dcnum/segm/segm_torch/segm_torch_base.py +125 -0
- dcnum/segm/segm_torch/segm_torch_mpo.py +71 -0
- dcnum/segm/segm_torch/segm_torch_sto.py +88 -0
- dcnum/segm/segm_torch/torch_model.py +95 -0
- dcnum/segm/segm_torch/torch_postproc.py +93 -0
- dcnum/segm/segm_torch/torch_preproc.py +114 -0
- dcnum/segm/segmenter.py +181 -80
- dcnum/segm/segmenter_manager_thread.py +38 -30
- dcnum/segm/{segmenter_cpu.py → segmenter_mpo.py} +116 -44
- dcnum/segm/segmenter_sto.py +110 -0
- dcnum/write/__init__.py +2 -1
- dcnum/write/deque_writer_thread.py +9 -1
- dcnum/write/queue_collector_thread.py +8 -14
- dcnum/write/writer.py +128 -5
- {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/METADATA +4 -2
- dcnum-0.23.1.dist-info/RECORD +55 -0
- {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/WHEEL +1 -1
- dcnum/feat/feat_moments/__init__.py +0 -4
- dcnum/segm/segmenter_gpu.py +0 -64
- dcnum-0.17.0.dist-info/RECORD +0 -46
- /dcnum/feat/{feat_moments/ct_opencv.py → feat_contour/contour.py} +0 -0
- {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/LICENSE +0 -0
- {dcnum-0.17.0.dist-info → dcnum-0.23.1.dist-info}/top_level.txt +0 -0
dcnum/read/mapped.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import numbers
|
|
4
|
+
|
|
5
|
+
import h5py
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MappedHDF5Dataset:
    def __init__(self,
                 h5ds: "h5py.Dataset",
                 mapping_indices: np.ndarray):
        """An index-mapped object for accessing an HDF5 dataset

        Item ``i`` of this object corresponds to item
        ``mapping_indices[i]`` of `h5ds`.

        Parameters
        ----------
        h5ds: h5py.Dataset
            HDF5 dataset from which to map data
        mapping_indices: np.ndarray
            numpy indexing array containing integer indices
        """
        self.h5ds = h5ds
        self.mapping_indices = mapping_indices
        # The first axis is replaced by the number of mapped indices, all
        # other axes are those of the underlying dataset.
        self.shape = (mapping_indices.size,) + h5ds.shape[1:]

    def __len__(self):
        """Return the number of mapped items (first entry of `shape`)"""
        return self.shape[0]

    def __getitem__(self, idx):
        """Return the data of the underlying dataset for mapped `idx`

        Integer indices are mapped to a single item. Any other index
        (slice, list, array) is applied to `mapping_indices` first and the
        resulting indices are read from the dataset in the requested order,
        including duplicates.
        """
        if isinstance(idx, numbers.Integral):
            return self.h5ds[self.mapping_indices[idx]]

        idx_mapped = np.asarray(self.mapping_indices[idx])
        if (idx_mapped.ndim == 1
                and idx_mapped.size > 1
                and not np.all(idx_mapped[1:] > idx_mapped[:-1])):
            # h5py fancy indexing requires indices in increasing order and
            # does not support duplicates. Read every required item once in
            # sorted order and restore the requested order in memory.
            unique, inverse = np.unique(idx_mapped, return_inverse=True)
            return self.h5ds[unique][inverse]
        return self.h5ds[idx_mapped]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_mapping_indices(
|
|
35
|
+
index_mapping: numbers.Integral | slice | list | np.ndarray
|
|
36
|
+
):
|
|
37
|
+
"""Return integer numpy array with mapping indices for a range
|
|
38
|
+
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
index_mapping: numbers.Integral | slice | list | np.ndarray
|
|
42
|
+
Several options you have here:
|
|
43
|
+
- integer: results in np.arrange(integer)
|
|
44
|
+
- slice: results in np.arrange(slice.start, slice.stop, slice.step)
|
|
45
|
+
- list or np.ndarray: returns the input as unit32 array
|
|
46
|
+
"""
|
|
47
|
+
if isinstance(index_mapping, numbers.Integral):
|
|
48
|
+
return _get_mapping_indices_cached(index_mapping)
|
|
49
|
+
elif isinstance(index_mapping, slice):
|
|
50
|
+
return _get_mapping_indices_cached(
|
|
51
|
+
(index_mapping.start, index_mapping.stop, index_mapping.step))
|
|
52
|
+
elif isinstance(index_mapping, (np.ndarray, list)):
|
|
53
|
+
return np.array(index_mapping, dtype=np.uint32)
|
|
54
|
+
else:
|
|
55
|
+
raise ValueError(f"Invalid type for `index_mapping`: "
|
|
56
|
+
f"{type(index_mapping)} ({index_mapping})")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@functools.lru_cache(maxsize=100)
|
|
60
|
+
def _get_mapping_indices_cached(
|
|
61
|
+
index_mapping: numbers.Integral | tuple
|
|
62
|
+
):
|
|
63
|
+
if isinstance(index_mapping, numbers.Integral):
|
|
64
|
+
return np.arange(index_mapping)
|
|
65
|
+
elif isinstance(index_mapping, tuple):
|
|
66
|
+
im_slice = slice(*index_mapping)
|
|
67
|
+
if im_slice.stop is None or im_slice.start is None:
|
|
68
|
+
raise NotImplementedError(
|
|
69
|
+
"Slices must have start and stop defined")
|
|
70
|
+
return np.arange(im_slice.start, im_slice.stop, im_slice.step)
|
|
71
|
+
elif isinstance(index_mapping, list):
|
|
72
|
+
return np.array(index_mapping, dtype=np.uint32)
|
|
73
|
+
else:
|
|
74
|
+
raise ValueError(f"Invalid type for cached `index_mapping`: "
|
|
75
|
+
f"{type(index_mapping)} ({index_mapping})")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def get_mapped_object(obj, index_mapping=None):
    """Wrap `obj` so that it is accessed through `index_mapping`

    Parameters
    ----------
    obj: h5py.Dataset
        object to wrap; currently only HDF5 datasets are supported
    index_mapping:
        anything accepted by :func:`get_mapping_indices`; if None,
        `obj` is returned unchanged

    Returns
    -------
    mapped: object
        `obj` itself or a :class:`MappedHDF5Dataset` wrapping it

    Raises
    ------
    ValueError
        if a mapping is requested for an unsupported object type
    """
    if index_mapping is None:
        # No mapping requested; hand back the original object untouched
        return obj
    if not isinstance(obj, h5py.Dataset):
        raise ValueError(f"No recipe to convert object of type {type(obj)} "
                         f"({obj}) to an index-mapped object")
    indices = get_mapping_indices(index_mapping)
    return MappedHDF5Dataset(obj, mapping_indices=indices)
|
dcnum/segm/__init__.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
# flake8: noqa: F401
|
|
2
|
-
from .segmenter import
|
|
3
|
-
|
|
4
|
-
|
|
2
|
+
from .segmenter import (
|
|
3
|
+
Segmenter, SegmenterNotApplicableError, get_available_segmenters
|
|
4
|
+
)
|
|
5
|
+
from .segmenter_mpo import MPOSegmenter
|
|
6
|
+
from .segmenter_sto import STOSegmenter
|
|
5
7
|
from .segmenter_manager_thread import SegmenterManagerThread
|
|
6
8
|
from . import segm_thresh
|
|
9
|
+
from . import segm_torch
|
dcnum/segm/segm_thresh.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .segmenter_mpo import MPOSegmenter
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
class SegmentThresh(
|
|
4
|
+
class SegmentThresh(MPOSegmenter):
|
|
5
5
|
mask_postprocessing = True
|
|
6
6
|
mask_default_kwargs = {
|
|
7
7
|
"clear_border": True,
|
|
@@ -10,22 +10,10 @@ class SegmentThresh(CPUSegmenter):
|
|
|
10
10
|
}
|
|
11
11
|
requires_background_correction = True
|
|
12
12
|
|
|
13
|
-
def __init__(self, thresh=-6, *args, **kwargs):
|
|
14
|
-
"""Simple image thresholding segmentation
|
|
15
|
-
|
|
16
|
-
Parameters
|
|
17
|
-
----------
|
|
18
|
-
thresh: int
|
|
19
|
-
grayscale threhold value for creating the mask image;
|
|
20
|
-
For a background-corrected image, pixels with values below
|
|
21
|
-
this value are considered to be part of the mask.
|
|
22
|
-
"""
|
|
23
|
-
super(SegmentThresh, self).__init__(thresh=thresh, *args, **kwargs)
|
|
24
|
-
|
|
25
13
|
@staticmethod
|
|
26
|
-
def
|
|
27
|
-
|
|
28
|
-
"""Mask retrieval
|
|
14
|
+
def segment_algorithm(image, *,
|
|
15
|
+
thresh: float = -6):
|
|
16
|
+
"""Mask retrieval using basic thresholding
|
|
29
17
|
|
|
30
18
|
Parameters
|
|
31
19
|
----------
|
|
@@ -39,7 +27,7 @@ class SegmentThresh(CPUSegmenter):
|
|
|
39
27
|
Returns
|
|
40
28
|
-------
|
|
41
29
|
mask: 2d boolean ndarray
|
|
42
|
-
Mask image for the
|
|
30
|
+
Mask image for the given index
|
|
43
31
|
"""
|
|
44
32
|
assert thresh < 0, "threshold values above zero not supported!"
|
|
45
33
|
return image < thresh
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
|
|
3
|
+
try:
    torch = importlib.import_module("torch")
except ImportError:
    # PyTorch is optional; without it, no torch segmenters are exposed.
    pass
else:
    # Reject PyTorch releases older than the minimum supported version.
    req_maj = 2
    req_min = 3
    ver_tuple = torch.__version__.split(".")
    act_maj = int(ver_tuple[0])
    act_min = int(ver_tuple[1])
    if (act_maj, act_min) < (req_maj, req_min):
        raise ValueError(f"Your PyTorch version {act_maj}.{act_min} is not "
                         f"supported, please update to at least "
                         f"{req_maj}.{req_min}")
    # The CPU (multiprocessing) segmenter only needs torch itself.
    from .segm_torch_mpo import SegmentTorchMPO  # noqa: F401
    # The GPU segmenter is only exposed when CUDA is available.
    if torch.cuda.is_available():
        from .segm_torch_sto import SegmentTorchSTO  # noqa: F401
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import pathlib
|
|
3
|
+
import re
|
|
4
|
+
from typing import Dict
|
|
5
|
+
|
|
6
|
+
from ...meta import paths
|
|
7
|
+
|
|
8
|
+
from ..segmenter import Segmenter, SegmenterNotApplicableError
|
|
9
|
+
|
|
10
|
+
from .torch_model import load_model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TorchSegmenterBase(Segmenter):
    """Torch segmenters that use a pretrained model for segmentation"""
    requires_background_correction = False
    mask_postprocessing = True
    mask_default_kwargs = {
        "clear_border": True,
        "fill_holes": True,
        "closing_disk": 0,
    }

    @classmethod
    def get_ppid_from_ppkw(cls, kwargs, kwargs_mask=None):
        """Return the pipeline identifier for the given keyword arguments

        Same as the parent implementation, except that a `model_file`
        keyword argument is reduced to its file name so that the pipeline
        identifier does not contain the full path. If `model_file` points
        to an existing file, its parent directory is registered as a
        search path for the "torch_model_files" topic.
        """
        kwargs_new = kwargs.copy()
        # Make sure that the `model_file` kwarg is actually just a filename
        # so that the pipeline identifier only contains the name, but not
        # the full path.
        if "model_file" in kwargs:
            model_file = kwargs["model_file"]
            mpath = pathlib.Path(model_file)
            if mpath.exists():
                # register the location of the file in the search path
                # registry so other threads/processes will find it.
                paths.register_search_path("torch_model_files", mpath.parent)
            kwargs_new["model_file"] = mpath.name
        return super(TorchSegmenterBase, cls).get_ppid_from_ppkw(kwargs_new,
                                                                 kwargs_mask)

    @classmethod
    def validate_applicability(cls,
                               segmenter_kwargs: Dict,
                               meta: Dict = None,
                               logs: Dict = None):
        """Validate the applicability of this segmenter for a dataset

        The applicability is defined by the metadata in the segmentation
        model.

        Parameters
        ----------
        segmenter_kwargs: dict
            Keyword arguments for the segmenter
        meta: dict
            Dictionary of metadata from an :class:`HDF5Data` instance;
            None is treated as an empty dictionary
        logs: dict
            Dictionary of logs from an :class:`HDF5Data` instance;
            None is treated as an empty dictionary

        Returns
        -------
        applicable: bool
            True if the segmenter is applicable to the dataset

        Raises
        ------
        ValueError
            If no `model_file` is given in `segmenter_kwargs`
        SegmenterNotApplicableError
            If the segmenter is not applicable to the dataset
        """
        if "model_file" not in segmenter_kwargs:
            raise ValueError("A `model_file` must be provided in the "
                             "`segmenter_kwargs` to validate applicability")

        # `meta` and `logs` default to None; treat them as empty so that
        # validation reports missing keys instead of crashing.
        if meta is None:
            meta = {}
        if logs is None:
            logs = {}

        model_file = segmenter_kwargs["model_file"]
        _, model_meta = load_model(model_file, device="cpu")

        reasons_list = []
        validators = {
            "meta": functools.partial(
                cls._validate_applicability_item,
                data_dict=meta,
                reasons_list=reasons_list),
            "logs": functools.partial(
                cls._validate_applicability_item,
                # convert logs to strings
                data_dict={key: "\n".join(val) for key, val in logs.items()},
                reasons_list=reasons_list)
        }
        for item in model_meta.get("validation", []):
            it = item["type"]
            if it in validators:
                validators[it](item)
            else:
                reasons_list.append(
                    f"invalid validation type {it} in {model_file}")

        if reasons_list:
            raise SegmenterNotApplicableError(segmenter_class=cls,
                                              reasons_list=reasons_list)

        return True

    @staticmethod
    def _validate_applicability_item(item, data_dict, reasons_list):
        """Populate `reasons_list` with invalid entries

        Example `item`::

            {"type": "meta",
             "key": "setup:region",
             "allow-missing-key": False,
             "regexp": "^channel$",
             "regexp-negate": False,
             "reason": "only channel region supported",
             }

        Parameters
        ----------
        item: dict
            validation item from the "validation" list of the model metadata
        data_dict: dict
            dictionary (metadata or logs) in which `item["key"]` is looked up
        reasons_list: list
            list to which reasons for non-applicability are appended
        """
        key = item["key"]
        if key in data_dict:
            regexp = re.compile(item["regexp"])
            value = data_dict[key]
            # Metadata values may not be strings (e.g. numbers), but
            # `re.match` only accepts strings.
            if not isinstance(value, str):
                value = str(value)
            matched = bool(regexp.match(value))
            negate = item.get("regexp-negate", False)
            valid = matched if not negate else not matched
            if not valid:
                reasons_list.append(item.get("reason", "unknown reason"))
        elif not item.get("allow-missing-key", False):
            reasons_list.append(f"Key '{key}' missing in {item['type']}")
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from ..segmenter_mpo import MPOSegmenter
|
|
5
|
+
|
|
6
|
+
from .segm_torch_base import TorchSegmenterBase
|
|
7
|
+
from .torch_model import load_model
|
|
8
|
+
from .torch_preproc import preprocess_images
|
|
9
|
+
from .torch_postproc import postprocess_masks
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SegmentTorchMPO(TorchSegmenterBase, MPOSegmenter):
    """PyTorch segmentation (multiprocessing version)"""

    @staticmethod
    def segment_algorithm(image, *,
                          model_file: str = None):
        """Segment a single event image with a PyTorch model on the CPU

        Parameters
        ----------
        image: 2d ndarray
            event image
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 2d boolean or integer ndarray
            mask or labeling image for the given image

        Raises
        ------
        ValueError
            if no `model_file` is specified
        """
        if model_file is None:
            raise ValueError("Please specify a .dcnm model file!")

        # Set number of pytorch threads to 1, because dcnum is doing
        # all the multiprocessing.
        # https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads
        torch.set_num_threads(1)
        device = torch.device("cpu")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # Add a leading batch axis of size 1 before preprocessing
        image_preproc = preprocess_images(image[np.newaxis, :, :],
                                          **model_meta["preprocessing"])

        image_ten = torch.from_numpy(image_preproc)

        # Move image tensors to device
        image_ten_on_device = image_ten.to(device)

        # Model inference. Gradients are never needed here, so autograd is
        # disabled to avoid recording the computation graph.
        with torch.no_grad():
            pred_tensor = model(image_ten_on_device)

        # Convert the prediction tensor into a numpy mask array. The
        # `pred_tensor` array is still of the shape (1, 1, H, W). The `masks`
        # array is of shape (1, H, W). We can optionally label it
        # here (we have to if the shapes don't match) or do it in
        # postprocessing.
        masks = pred_tensor.detach().cpu().numpy()[0] >= 0.5

        # Perform postprocessing in cases where the image shapes don't match
        assert len(masks[0].shape) == len(image.shape), "sanity check"
        if masks[0].shape != image.shape:
            labels = postprocess_masks(
                masks=masks,
                original_image_shape=image.shape,
            )
            return labels[0]
        else:
            return masks[0]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from dcnum.segm import STOSegmenter
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .segm_torch_base import TorchSegmenterBase
|
|
6
|
+
from .torch_model import load_model
|
|
7
|
+
from .torch_preproc import preprocess_images
|
|
8
|
+
from .torch_postproc import postprocess_masks
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SegmentTorchSTO(TorchSegmenterBase, STOSegmenter):
    """PyTorch segmentation (GPU version)"""

    @staticmethod
    def _segment_in_batches(imgs_t, model, batch_size, device):
        """Segment preprocessed image data in batches

        Parameters
        ----------
        imgs_t: np.ndarray
            preprocessed image data; the last two axes are the image axes
        model:
            model returned by :func:`load_model`
        batch_size: int
            number of images passed to the model at once
        device: torch.device
            device on which inference is performed

        Returns
        -------
        masks: np.ndarray
            boolean array of shape (len(imgs_t), H, W), obtained by
            thresholding the model output at 0.5
        """
        size = len(imgs_t)
        # Create empty array to fill up with segmented batches. The shape
        # is taken from the array itself (not from `imgs_t[0]`) so that an
        # empty input does not raise an IndexError.
        masks = np.empty((size, *imgs_t.shape[-2:]), dtype=bool)

        # Inference only: disable autograd to avoid recording the
        # computation graph (saves GPU memory and time).
        with torch.no_grad():
            for start_idx in range(0, size, batch_size):
                batch = imgs_t[start_idx:start_idx + batch_size]
                # Move image tensors to the device
                batch = torch.tensor(batch, device=device)
                # Model inference
                batch_seg = model(batch)
                # Remove extra dim [B, C, H, W] --> [B, H, W]
                batch_seg = batch_seg.squeeze(1)
                # Convert device tensor into numpy arrays
                batch_seg_np = batch_seg.detach().cpu().numpy()
                # Fill empty array with segmented batch
                masks[start_idx:start_idx + batch_size] = batch_seg_np >= 0.5

        return masks

    @staticmethod
    def segment_algorithm(images, gpu_id=None, batch_size=50, *,
                          model_file: str = None):
        """Segment a stack of event images with a PyTorch model on the GPU

        Parameters
        ----------
        images: 3d ndarray
            array of N event images of shape (N, H, W)
        gpu_id: str
            optional argument specifying the GPU to use
        batch_size: int
            number of images to process in one batch
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 3d boolean or integer ndarray
            mask or label images of shape (N, H, W)

        Raises
        ------
        ValueError
            if no `model_file` is specified
        """
        if model_file is None:
            raise ValueError("Please specify a model file!")

        # Determine device to use
        device = torch.device(gpu_id if gpu_id is not None else "cuda")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # Preprocess the images
        image_preproc = preprocess_images(images,
                                          **model_meta["preprocessing"])
        # Model inference
        # The `masks` array has the shape (len(images), H, W), where
        # H and W may be different from the corresponding axes in `images`.
        masks = SegmentTorchSTO._segment_in_batches(image_preproc,
                                                    model,
                                                    batch_size,
                                                    device
                                                    )

        # Perform postprocessing in cases where the image shapes don't match
        assert len(masks.shape[1:]) == len(images.shape[1:]), "sanity check"
        if masks.shape[1:] != images.shape[1:]:
            labels = postprocess_masks(
                masks=masks,
                original_image_shape=images.shape[1:])
            return labels
        else:
            return masks
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import errno
|
|
2
|
+
import functools
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import pathlib
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ...meta import paths
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def check_md5sum(path):
    """Verify that the file stem of `path` ends with its MD5 checksum

    The part of the file stem after the last underscore must equal the
    first five hex characters of the file's MD5 digest, e.g.
    ``name_1a2b3.dcnm``.

    Parameters
    ----------
    path: pathlib.Path or str
        path to the file to verify

    Raises
    ------
    ValueError
        if the first five characters of the MD5 digest do not match the
        last underscore-separated part of the file stem
    """
    path = pathlib.Path(path)
    # MD5 serves as a checksum here, not for security; flagging it as such
    # keeps it usable on FIPS-restricted Python builds.
    hasher = hashlib.md5(usedforsecurity=False)
    # Hash in chunks so that large model files are not loaded into memory
    # all at once.
    with path.open("rb") as fd:
        for chunk in iter(lambda: fd.read(2**20), b""):
            hasher.update(chunk)
    md5 = hasher.hexdigest()
    if md5[:5] != path.stem.split("_")[-1]:
        raise ValueError(f"MD5 mismatch for {path} ({md5})! Expected the "
                         f"input file to end with '{md5[:5]}{path.suffix}'.")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@functools.cache
def load_model(path_or_name, device):
    """Load a TorchScript dcnum model together with its metadata

    Results are cached per (`path_or_name`, `device`) combination.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        jit checkpoint file or its name (resolved with
        :func:`retrieve_model_file`); for dcnum, these files have the
        suffix .dcnm and embed the extra file "dcnum_meta.json"
    device: str or torch.device
        device on which to run the model

    Returns
    -------
    model_jit: torch.jit.ScriptModule
        model in evaluation mode, optimized for inference
    model_meta: dict
        metadata parsed from the embedded "dcnum_meta.json"
    """
    # `torch.jit.load` fills the values of this mapping with the contents
    # of the correspondingly named extra files stored in the checkpoint.
    extra = {"dcnum_meta.json": ""}
    scripted = torch.jit.load(retrieve_model_file(path_or_name),
                              _extra_files=extra,
                              map_location=device)
    metadata = json.loads(extra["dcnum_meta.json"])
    # Evaluation mode must be set before optimizing for inference
    scripted.eval()
    return torch.jit.optimize_for_inference(scripted), metadata
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@functools.cache
def retrieve_model_file(path_or_name):
    """Retrieve a dcnum torch model file

    If a path to an existing model file is given, then this path is used
    directly. If the given path does not exist, its file name is looked
    up instead. If a string is given, it is used as a path if it exists;
    otherwise the file is searched with :func:`dcnum.meta.paths.find_file`
    using the "torch_model_files" topic.

    The MD5 checksum of the retrieved file is verified with
    :func:`check_md5sum` before the path is returned.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        path to or file name of the model file

    Returns
    -------
    path: pathlib.Path
        path to the model file (as returned by `find_file` if the file
        was looked up)

    Raises
    ------
    FileNotFoundError
        if a non-existent `pathlib.Path` is given and its file name could
        not be retrieved either
    ValueError
        if `path_or_name` is neither a string nor a path, or if the MD5
        checksum does not match the file name
    """
    # Did the user already pass a path?
    if isinstance(path_or_name, pathlib.Path):
        if path_or_name.exists():
            path = path_or_name
        else:
            # Fall back to looking up the bare file name
            try:
                return retrieve_model_file(path_or_name.name)
            except Exception as exc:
                # `Exception` (not `BaseException`) so that
                # KeyboardInterrupt and SystemExit propagate unchanged.
                raise FileNotFoundError(errno.ENOENT,
                                        os.strerror(errno.ENOENT),
                                        str(path_or_name)) from exc
    elif isinstance(path_or_name, str):
        name = path_or_name.strip()
        # We now have a string for a filename, and we have to figure out what
        # the path is. There are several options, including cached files.
        if pathlib.Path(name).exists():
            path = pathlib.Path(name)
        else:
            path = paths.find_file("torch_model_files", name)
    else:
        raise ValueError(
            f"Please pass a string or a path, got {type(path_or_name)}!")

    logger.info("Found dcnum model file %s", path)
    check_md5sum(path)
    return path
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
from ..segmenter import Segmenter
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy import ndimage as ndi
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def postprocess_masks(masks,
                      original_image_shape: Tuple[int, int]):
    """Postprocess mask images from ML segmenters

    The mask images are labeled (8-connectivity) and then brought back to
    `original_image_shape`, reverting the cropping/padding done in
    :func:`.preprocess_images`:

    - Along each image axis that is too small, the labels are zero-padded
      symmetrically. Before padding, objects touching the image border
      are removed with :func:`Segmenter.process_mask`, so that no
      half-segmented events appear in the output.
    - Along each image axis that is too large, the labels are cropped
      symmetrically.

    Parameters
    ----------
    masks: 3d or 4d ndarray
        Mask data of shape (batch_size, 1, height, width) or
        (batch_size, height, width).
    original_image_shape: tuple of (int, int)
        The required output shape (height, width) for one event, needed
        to invert what :func:`.preprocess_images` did.

    Returns
    -------
    labels_proc: np.ndarray
        uint16 label array of shape
        (batch_size, original_image_shape[0], original_image_shape[1]).
    """
    # Drop the singleton channel axis of 4d model output
    if masks.ndim == 4:
        masks = masks[:, 0, :, :]

    # Label each mask; iterating over `labels` yields writable views, so
    # `ndi.label` fills the output array in place.
    connectivity = ndi.generate_binary_structure(2, 2)
    labels = np.empty(masks.shape, dtype=np.uint16)
    for mask, label_out in zip(masks, labels):
        ndi.label(mask, structure=connectivity, output=label_out)

    target_height = original_image_shape[0]
    target_width = original_image_shape[1]
    needs_height_padding = target_height > labels.shape[1]
    needs_width_padding = target_width > labels.shape[2]

    if needs_height_padding or needs_width_padding:
        # Clear objects on the border before padding, otherwise they would
        # end up as cut-off events inside the larger output image.
        for idx in range(labels.shape[0]):
            labels[idx] = Segmenter.process_mask(labels[idx],
                                                 clear_border=True,
                                                 fill_holes=False,
                                                 closing_disk=0)

    # Fit the width first, then the height.
    labels = _fit_axis(labels, axis=2, size=target_width)
    labels = _fit_axis(labels, axis=1, size=target_height)
    return labels


def _fit_axis(labels, axis, size):
    """Symmetrically zero-pad or crop `labels` along `axis` to `size`

    If the difference is odd, the extra pixel is added/removed at the end
    of the axis. Cropping returns a view; padding returns a new uint16
    array.
    """
    diff = size - labels.shape[axis]
    if diff == 0:
        return labels
    lead = abs(diff) // 2
    trail = abs(diff) - lead  # always >= 1 here, so `-trail` is valid
    window = [slice(None)] * labels.ndim
    window[axis] = slice(lead, -trail)
    window = tuple(window)
    if diff < 0:
        return labels[window]
    padded_shape = list(labels.shape)
    padded_shape[axis] = size
    padded = np.zeros(padded_shape, dtype=np.uint16)
    padded[window] = labels
    return padded
|