spacr 0.5.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -2
- spacr/__main__.py +3 -3
- spacr/core.py +13 -106
- spacr/gui_core.py +2 -77
- spacr/gui_utils.py +1 -13
- spacr/io.py +24 -25
- spacr/mediar.py +12 -8
- spacr/plot.py +50 -135
- spacr/settings.py +42 -30
- spacr/submodules.py +11 -1
- spacr/timelapse.py +7 -79
- spacr/utils.py +152 -61
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/METADATA +62 -62
- spacr-0.9.1.dist-info/RECORD +109 -0
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/WHEEL +1 -1
- spacr/resources/MEDIAR/.gitignore +0 -18
- spacr/resources/MEDIAR/LICENSE +0 -21
- spacr/resources/MEDIAR/README.md +0 -189
- spacr/resources/MEDIAR/SetupDict.py +0 -39
- spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/config/baseline.json +0 -60
- spacr/resources/MEDIAR/config/mediar_example.json +0 -72
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
- spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
- spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
- spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
- spacr/resources/MEDIAR/core/__init__.py +0 -2
- spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/utils.py +0 -40
- spacr/resources/MEDIAR/evaluate.py +0 -71
- spacr/resources/MEDIAR/generate_mapping.py +0 -121
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +0 -125
- spacr/resources/MEDIAR/predict.py +0 -70
- spacr/resources/MEDIAR/requirements.txt +0 -14
- spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
- spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
- spacr/resources/MEDIAR/train_tools/measures.py +0 -200
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
- spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
- spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/utils.py +0 -70
- spacr-0.5.0.dist-info/RECORD +0 -190
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/LICENSE +0 -0
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,208 +0,0 @@
|
|
1
|
-
from torch.utils.data import DataLoader
|
2
|
-
from monai.data import Dataset
|
3
|
-
import pickle
|
4
|
-
|
5
|
-
from .transforms import (
|
6
|
-
train_transforms,
|
7
|
-
public_transforms,
|
8
|
-
valid_transforms,
|
9
|
-
tuning_transforms,
|
10
|
-
unlabeled_transforms,
|
11
|
-
)
|
12
|
-
from .utils import split_train_valid, path_decoder
|
13
|
-
|
14
|
-
# Pickle file read by get_dataloaders_labeled when `amplified=True`; it is
# loaded as a dict of label -> list of indices into the decoded data list.
# NOTE(review): the path is relative, so it resolves against the current
# working directory of the process -- confirm callers always run from the
# MEDIAR root.
DATA_LABEL_DICT_PICKLE_FILE = "./train_tools/data_utils/custom/modalities.pkl"

# Names exported by `from ... import *`.
# NOTE(review): get_dataloaders_unlabeled_psuedo is defined in this module but
# is not listed here -- confirm whether that omission is intentional.
__all__ = [
    "get_dataloaders_labeled",
    "get_dataloaders_public",
    "get_dataloaders_unlabeled",
]
|
21
|
-
|
22
|
-
|
23
|
-
def get_dataloaders_labeled(
    root,
    mapping_file,
    mapping_file_tuning,
    join_mapping_file=None,
    valid_portion=0.0,
    batch_size=8,
    amplified=False,
    relabel=False,
):
    """Set DataLoaders for labeled datasets.

    Args:
        root (str): root directory joined onto every path of the mapping files.
        mapping_file (str): json file for mapping the labeled dataset.
        mapping_file_tuning (str): json file for mapping the tuning dataset;
            it is decoded without labels.
        join_mapping_file (str, optional): extra json mapping file whose
            entries are appended to the training data. When given, the
            training set uses ``public_transforms`` instead of
            ``train_transforms``. Defaults to None.
        valid_portion (float, optional): portion of valid datasets.
            Defaults to 0.0.
        batch_size (int, optional): batch size of the train loader.
            Defaults to 8.
        amplified (bool, optional): if True, regroup the data by the label
            groups stored in ``DATA_LABEL_DICT_PICKLE_FILE`` and repeat every
            group holding fewer than 50 samples until it contributes 50.
            Defaults to False.
        relabel (bool, optional): if True, redirect the label path of cells
            indexed 340..498 to the "/CellSeg/pretrained_train_ext/"
            directory. Defaults to False.

    Returns:
        dict: data loaders under the keys "train", "valid" and "tuning".
            The train loader shuffles and uses 5 workers; the valid and
            tuning loaders use batch size 1 without shuffling.
    """

    # Get list of data dictionaries from decoded paths
    data_dicts = path_decoder(root, mapping_file)
    tuning_dicts = path_decoder(root, mapping_file_tuning, no_label=True)

    if amplified:
        # NOTE(security): pickle.load can execute arbitrary code; this file
        # must only ever be a trusted, package-provided pickle.
        with open(DATA_LABEL_DICT_PICKLE_FILE, "rb") as f:
            data_label_dict = pickle.load(f)

        min_group_size = 50
        amplified_dicts = []

        for label, data_lst in data_label_dict.items():
            data_points = []

            for d_idx in data_lst:
                try:
                    data_points.append(data_dicts[d_idx])
                except (IndexError, KeyError, TypeError):
                    # Best effort: report the index that cannot be resolved
                    # against the decoded data and keep going.
                    print(label, d_idx)

            # A group without any resolvable sample cannot be repeated
            # (the modulo below would divide by zero), so skip it.
            if not data_points:
                continue

            if len(data_points) >= min_group_size:
                amplified_dicts += data_points
            else:
                # Cycle through the small group until it contributes
                # `min_group_size` samples.
                for i in range(min_group_size):
                    amplified_dicts.append(data_points[i % len(data_points)])

        data_dicts = amplified_dicts

    data_transforms = train_transforms

    if join_mapping_file is not None:
        data_dicts += path_decoder(root, join_mapping_file)
        data_transforms = public_transforms

    if relabel:
        for elem in data_dicts:
            # The cell index is the last "_"-separated token in front of the
            # "_label.tiff" suffix of the label path.
            cell_idx = int(elem["label"].split("_label.tiff")[0].split("_")[-1])
            if cell_idx in range(340, 499):
                new_label = elem["label"].replace(
                    "/data/CellSeg/Official/Train_Labeled/labels/",
                    "/CellSeg/pretrained_train_ext/",
                )
                elem["label"] = new_label

    # Split datasets as Train/Valid
    train_dicts, valid_dicts = split_train_valid(
        data_dicts, valid_portion=valid_portion
    )

    # Obtain datasets with transforms
    trainset = Dataset(train_dicts, transform=data_transforms)
    validset = Dataset(valid_dicts, transform=valid_transforms)
    tuningset = Dataset(tuning_dicts, transform=tuning_transforms)

    # Set dataloader for Trainset
    train_loader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=5
    )

    # Set dataloader for Validset (Batch size is fixed as 1)
    valid_loader = DataLoader(validset, batch_size=1, shuffle=False,)

    # Set dataloader for Tuningset (Batch size is fixed as 1)
    tuning_loader = DataLoader(tuningset, batch_size=1, shuffle=False)

    # Form dataloaders as dictionary
    dataloaders = {
        "train": train_loader,
        "valid": valid_loader,
        "tuning": tuning_loader,
    }

    return dataloaders
|
122
|
-
|
123
|
-
|
124
|
-
def get_dataloaders_public(
    root, mapping_file, valid_portion=0.0, batch_size=8,
):
    """Build the DataLoader for the public dataset.

    Args:
        root (str): root directory
        mapping_file (str): json file for mapping dataset
        valid_portion (float, optional): portion split off as validation
            data; that validation part is discarded here. Defaults to 0.0.
        batch_size (int, optional): batch size. Defaults to 8.

    Returns:
        dict: a single shuffling loader (5 workers) under the key "public".
    """

    # Decode the paths, then keep only the training side of the split.
    decoded = path_decoder(root, mapping_file)
    public_dicts = split_train_valid(decoded, valid_portion=valid_portion)[0]

    public_set = Dataset(public_dicts, transform=public_transforms)
    public_loader = DataLoader(
        public_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=5,
    )

    return {"public": public_loader}
|
158
|
-
|
159
|
-
|
160
|
-
def get_dataloaders_unlabeled(
    root, mapping_file, batch_size=8, shuffle=True, num_workers=5,
):
    """Build the DataLoader for the unlabeled dataset.

    Paths are decoded without labels, passed through ``split_train_valid``
    with a validation portion of 0 and wrapped with ``unlabeled_transforms``.

    Returns:
        dict: the loader under the key "unlabeled".
    """
    decoded = path_decoder(root, mapping_file, no_label=True, unlabeled=True)

    # valid_portion=0 keeps every sample on the training side.
    kept_dicts, _ = split_train_valid(decoded, valid_portion=0)
    dataset = Dataset(kept_dicts, transform=unlabeled_transforms)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return {"unlabeled": loader}
|
181
|
-
|
182
|
-
|
183
|
-
def get_dataloaders_unlabeled_psuedo(
    root, mapping_file, batch_size=8, shuffle=True, num_workers=5,
):
    """Build the DataLoader for unlabeled data decoded together with labels.

    Unlike ``get_dataloaders_unlabeled`` the paths are decoded with
    ``no_label=False`` and the dataset uses ``train_transforms``.

    Returns:
        dict: the loader under the key "unlabeled".
    """
    decoded = path_decoder(root, mapping_file, no_label=False, unlabeled=True)

    # valid_portion=0 keeps every sample on the training side.
    kept_dicts, _ = split_train_valid(decoded, valid_portion=0)
    dataset = Dataset(kept_dicts, transform=train_transforms)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return {"unlabeled": loader}
|
@@ -1,148 +0,0 @@
|
|
1
|
-
from .custom import *
|
2
|
-
|
3
|
-
from monai.transforms import *
|
4
|
-
|
5
|
-
# Names exported by `from ... import *`.
__all__ = [
    "train_transforms",
    "public_transforms",
    "valid_transforms",
    "tuning_transforms",
    "unlabeled_transforms",
]

# Pipeline for labeled training samples: load, normalize, random spatial
# augmentation to a 512 crop, then random intensity augmentation on "img".
train_transforms = Compose(
    [
        # >>> Load and refine data --- img: (H, W, 3); label: (H, W)
        CustomLoadImaged(keys=["img", "label"], image_only=True),
        CustomNormalizeImaged(
            keys=["img"],
            allow_missing_keys=True,
            channel_wise=False,
            percentiles=[0.0, 99.5],
        ),
        EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
        RemoveRepeatedChanneld(keys=["label"], repeats=3),  # label: (H, W)
        ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
        # >>> Spatial transforms
        RandZoomd(
            keys=["img", "label"],
            prob=0.5,
            min_zoom=0.25,
            max_zoom=1.5,
            mode=["area", "nearest"],  # one interpolation mode per key: img, label
            keep_size=False,
        ),
        SpatialPadd(keys=["img", "label"], spatial_size=512),
        RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
        RandAxisFlipd(keys=["img", "label"], prob=0.5),
        RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
        IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
        # # >>> Intensity transforms
        RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
        RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
        RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
        RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
        RandGaussianSharpend(keys=["img"], prob=0.25),
        EnsureTyped(keys=["img", "label"]),
    ]
)


# Pipeline for the public dataset: like the training pipeline but without
# zoom / intensity augmentation, with boundary exclusion on the label.
public_transforms = Compose(
    [
        CustomLoadImaged(keys=["img", "label"], image_only=True),
        BoundaryExclusion(keys=["label"]),
        CustomNormalizeImaged(
            keys=["img"],
            allow_missing_keys=True,
            channel_wise=False,
            percentiles=[0.0, 99.5],
        ),
        EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
        RemoveRepeatedChanneld(keys=["label"], repeats=3),  # label: (H, W)
        ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
        # >>> Spatial transforms
        SpatialPadd(keys=["img", "label"], spatial_size=512),
        RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
        RandAxisFlipd(keys=["img", "label"], prob=0.5),
        RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
        # NOTE(review): the next two steps rotate and flip the label only, not
        # the image -- presumably to correct the label orientation of the
        # public data; confirm against the dataset.
        Rotate90d(k=1, keys=["label"], spatial_axes=(0, 1)),
        Flipd(keys=["label"], spatial_axis=0),
        EnsureTyped(keys=["img", "label"]),
    ]
)


# Deterministic pipeline for validation samples; "label" may be missing.
valid_transforms = Compose(
    [
        CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
        CustomNormalizeImaged(
            keys=["img"],
            allow_missing_keys=True,
            channel_wise=False,
            percentiles=[0.0, 99.5],
        ),
        EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
        RemoveRepeatedChanneld(keys=["label"], repeats=3),
        ScaleIntensityd(keys=["img"], allow_missing_keys=True),
        EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
    ]
)

# Deterministic, image-only pipeline for the tuning set.
tuning_transforms = Compose(
    [
        CustomLoadImaged(keys=["img"], image_only=True),
        CustomNormalizeImaged(
            keys=["img"],
            allow_missing_keys=True,
            channel_wise=False,
            percentiles=[0.0, 99.5],
        ),
        EnsureChannelFirstd(keys=["img"], channel_dim=-1),
        ScaleIntensityd(keys=["img"]),
        EnsureTyped(keys=["img"]),
    ]
)

# Image-only pipeline for unlabeled samples: random zoom, then pad / random
# crop to 512.
unlabeled_transforms = Compose(
    [
        # >>> Load and refine data --- img: (H, W, 3); label: (H, W)
        CustomLoadImaged(keys=["img"], image_only=True),
        CustomNormalizeImaged(
            keys=["img"],
            allow_missing_keys=True,
            channel_wise=False,
            percentiles=[0.0, 99.5],
        ),
        EnsureChannelFirstd(keys=["img"], channel_dim=-1),
        RandZoomd(
            keys=["img"],
            prob=0.5,
            min_zoom=0.25,
            max_zoom=1.25,
            mode=["area"],
            keep_size=False,
        ),
        ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
        # >>> Spatial transforms
        SpatialPadd(keys=["img"], spatial_size=512),
        RandSpatialCropd(keys=["img"], roi_size=512, random_size=False),
        EnsureTyped(keys=["img"]),
    ]
)
|
133
|
-
|
134
|
-
|
135
|
-
def get_pred_transforms():
    """Return a freshly composed preprocessing pipeline for prediction.

    The pipeline works on a bare image (non-dictionary transforms): load,
    percentile-normalize, move channels first, scale the intensity and
    convert to a tensor.
    """
    steps = [
        # >>> Load and refine data
        CustomLoadImage(image_only=True),
        CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
        EnsureChannelFirst(channel_dim=-1),  # image: (3, H, W)
        ScaleIntensity(),
        EnsureType(data_type="tensor"),
    ]

    return Compose(steps)
|
@@ -1,84 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
import json
|
3
|
-
import numpy as np
|
4
|
-
|
5
|
-
__all__ = ["split_train_valid", "path_decoder"]
|
6
|
-
|
7
|
-
|
8
|
-
def split_train_valid(data_dicts, valid_portion=0.1):
    """Randomly split ``data_dicts`` into train and validation lists.

    With ``valid_portion <= 0`` nothing is shuffled: the input list itself is
    returned as the train split together with an empty validation list.
    Otherwise the indices are shuffled with ``np.random.shuffle`` and the
    first ``int(len(data_dicts) * valid_portion)`` of them form the
    validation split. The resulting sizes are printed in both cases.

    Returns:
        tuple: (train_dicts, valid_dicts)
    """
    if valid_portion > 0:
        total = len(data_dicts)

        # Shuffled positions into data_dicts
        order = np.arange(total)
        np.random.shuffle(order)

        # Leading slice -> validation, remainder -> training
        cut = int(total * valid_portion)
        valid_dicts = [data_dicts[pos] for pos in order[:cut]]
        train_dicts = [data_dicts[pos] for pos in order[cut:]]
    else:
        train_dicts, valid_dicts = data_dicts, []

    print(
        "\n(DataLoaded) Training data size: %d, Validation data size: %d\n"
        % (len(train_dicts), len(valid_dicts))
    )

    return train_dicts, valid_dicts
|
34
|
-
|
35
|
-
|
36
|
-
def path_decoder(root, mapping_file, no_label=False, unlabeled=False):
    """Decode img/label file paths from root & mapping file.

    Args:
        root (str): directory joined in front of every path of the mapping.
        mapping_file (str): json file containing image & label file paths,
            as a dict whose values are lists of {"img": ..., "label": ...}.
        no_label (bool, optional): if True, only the "img" key is emitted.
            Defaults to False.
        unlabeled (bool, optional): if True, entries whose image path
            contains "00504" (a corrupted image) are dropped.
            Defaults to False.

    Returns:
        list: list of dictionary. (ex. [{"img": img_path, "label": label_path}, ...])
    """
    with open(mapping_file, "r") as handle:
        mapping = json.load(handle)

    # Keys copied from every mapping entry into the decoded dictionaries
    wanted_keys = ("img",) if no_label else ("img", "label")

    decoded = []
    for entries in mapping.values():
        for entry in entries:
            decoded.append(
                {key: os.path.join(root, entry[key]) for key in wanted_keys}
            )

    if unlabeled:
        # Exclude the corrupted image to prevent error
        decoded = [item for item in decoded if "00504" not in item["img"]]

    return decoded
|
@@ -1,200 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Adapted from the following references:
|
3
|
-
[1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/compute_metric.py
|
4
|
-
[2] https://github.com/stardist/stardist/blob/master/stardist/matching.py
|
5
|
-
|
6
|
-
"""
|
7
|
-
|
8
|
-
import numpy as np
|
9
|
-
from skimage import segmentation
|
10
|
-
from scipy.optimize import linear_sum_assignment
|
11
|
-
from numba import jit
|
12
|
-
|
13
|
-
__all__ = ["evaluate_f1_score_cellseg", "evaluate_f1_score"]
|
14
|
-
|
15
|
-
|
16
|
-
def evaluate_f1_score_cellseg(masks_true, masks_pred, threshold=0.5):
    """
    Compute (precision, recall, f1_score) for cell segmentation results.
    Cells touching the boundary are removed before matching. Masks with at
    least 5000 * 5000 pixels are evaluated patch by patch (2000 x 2000) and
    the confusion counts of all patches are summed.
    """
    whole_image_limit = 5000 * 5000

    if np.prod(masks_true.shape) < whole_image_limit:
        gt_clean = _remove_boundary_cells(masks_true.astype(np.int32))
        pred_clean = _remove_boundary_cells(masks_pred.astype(np.int32))
        tp, fp, fn = get_confusion(gt_clean, pred_clean, threshold)

    # Compute by Patch-based way for large images
    else:
        H, W = masks_true.shape
        roi_size = 2000

        # Number of patches per axis (ceiling division) and padded extent
        n_H = -(-H // roi_size)
        n_W = -(-W // roi_size)
        new_H, new_W = roi_size * n_H, roi_size * n_W

        # Zero-pad both masks up to a whole number of patches
        gt_pad = np.zeros((new_H, new_W), dtype=masks_true.dtype)
        pred_pad = np.zeros((new_H, new_W), dtype=masks_true.dtype)
        gt_pad[:H, :W] = masks_true
        pred_pad[:H, :W] = masks_pred

        tp, fp, fn = 0, 0, 0

        # Accumulate confusion elements over every patch
        for i in range(n_H):
            rows = slice(roi_size * i, roi_size * (i + 1))
            for j in range(n_W):
                cols = slice(roi_size * j, roi_size * (j + 1))
                gt_roi = _remove_boundary_cells(gt_pad[rows, cols])
                pred_roi = _remove_boundary_cells(pred_pad[rows, cols])
                tp_i, fp_i, fn_i = get_confusion(gt_roi, pred_roi, threshold)
                tp += tp_i
                fp += fp_i
                fn += fn_i

    # Calculate f1 score
    return evaluate_f1_score(tp, fp, fn)
|
80
|
-
|
81
|
-
|
82
|
-
def evaluate_f1_score(tp, fp, fn):
    """Return (precision, recall, f1_score) for the given confusion counts.

    When there is no true positive, all three scores are reported as 0.
    """
    # Trivial result: nothing matched, so skip the divisions entirely
    if tp == 0:
        return 0, 0, 0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score
|
95
|
-
|
96
|
-
|
97
|
-
def _remove_boundary_cells(mask):
    """Remove cells on the boundary from the mask.

    A cell counts as a boundary cell when any of its pixels lies within the
    2-pixel-wide frame along the border of ``mask``. The remaining labels
    are renumbered sequentially.

    The input array is left untouched; a cleaned copy is returned.

    Args:
        mask (np.ndarray): 2D integer label image, 0 being the background.

    Returns:
        np.ndarray: relabeled mask without the boundary cells.
    """

    # Identify boundary cells: labels present in the 2-pixel border frame
    W, H = mask.shape
    border = np.ones((W, H), dtype=bool)
    border[2 : W - 2, 2 : H - 2] = False
    bd_cells = np.unique(mask[border])

    # Drop the background explicitly rather than assuming it is the first
    # unique value: on small, fully labeled masks there may be no 0 at all.
    bd_cells = bd_cells[bd_cells != 0]

    # Remove cells on the boundary (work on a copy, not the caller's array)
    cleaned = mask.copy()
    cleaned[np.isin(cleaned, bd_cells)] = 0

    # Allocate labels as sequential manner
    new_label, _, _ = segmentation.relabel_sequential(cleaned)

    return new_label
|
114
|
-
|
115
|
-
|
116
|
-
def get_confusion(masks_true, masks_pred, threshold=0.5):
    """Calculate confusion matrix elements: (TP, FP, FN)

    The instance counts are taken as the maximum label of each mask.

    NOTE(review): when the prediction holds no instance, FN is reported as 0
    rather than the number of ground-truth instances -- confirm this is the
    intended metric definition.
    """
    num_gt_instances = np.max(masks_true)
    num_pred_instances = np.max(masks_pred)

    # Nothing predicted: report an all-zero confusion
    if num_pred_instances == 0:
        print("No segmentation results!")
        return 0, 0, 0

    # IoU between every GT / predicted pair, background label (0) excluded
    iou = _get_iou(masks_true, masks_pred)[1:, 1:]

    tp = _get_true_positive(iou, threshold)

    return tp, num_pred_instances - tp, num_gt_instances - tp
|
136
|
-
|
137
|
-
|
138
|
-
def _get_true_positive(iou, threshold=0.5):
    """Count the true positives (matched instances) at the given threshold.

    Ground-truth and predicted instances are matched one-to-one with
    ``linear_sum_assignment``; a matched pair counts as a true positive only
    if its IoU reaches ``threshold``.

    Args:
        iou (np.ndarray): IoU matrix of shape (num GT, num predicted).
        threshold (float, optional): minimum IoU of a valid match.
            Defaults to 0.5.

    Returns:
        int: number of true positives.
    """

    # Number of instances to be matched
    num_matched = min(iou.shape[0], iou.shape[1])

    # Nothing to match on either side; also avoids dividing by zero below
    if num_matched == 0:
        return 0

    # Find optimal matching by using IoU as tie-breaker.
    # The builtin `float` replaces the `np.float` alias removed from NumPy.
    costs = -(iou >= threshold).astype(float) - iou / (2 * num_matched)
    matched_gt_label, matched_pred_label = linear_sum_assignment(costs)

    # Consider as the same instance only if the IoU is above the threshold
    match_ok = iou[matched_gt_label, matched_pred_label] >= threshold
    tp = match_ok.sum()

    return tp
|
153
|
-
|
154
|
-
|
155
|
-
def _get_iou(masks_true, masks_pred):
    """Return the IoU matrix between the labels of both masks.

    Rows follow the labels of ``masks_true``, columns the labels of
    ``masks_pred`` (label 0 included). Undefined ratios (0 / 0) are set to 0.
    """

    # Shared pixel counts, shape (GT Instances Num, Pred Instance Num)
    overlap = _label_overlap(masks_true, masks_pred)

    # Pixels per predicted label (column sums) and per GT label (row sums)
    pred_area = np.sum(overlap, axis=0, keepdims=True)
    true_area = np.sum(overlap, axis=1, keepdims=True)

    # Intersection over union for every label pair
    iou = overlap / (pred_area + true_area - overlap)

    # Ensure numerical values
    iou[np.isnan(iou)] = 0.0

    return iou
|
175
|
-
|
176
|
-
|
177
|
-
@jit(nopython=True)
def _label_overlap(x, y):
    """Count the pixels shared by every pair of labels of two masks

    Parameters
    ------------
    x, y (np array; dtype int): 0=NO masks; 1,2... are mask labels

    Returns
    ------------
    overlap (np array; dtype int): Overlaps of size [x.max()+1, y.max()+1]
    """

    # Flatten both masks so they can be walked pixel by pixel
    flat_x = x.ravel()
    flat_y = y.ravel()

    # One row per label of x and one column per label of y
    n_rows = 1 + flat_x.max()
    n_cols = 1 + flat_y.max()
    overlap = np.zeros((n_rows, n_cols), dtype=np.uint)

    # Every pixel votes for its (label in x, label in y) pair
    for k in range(len(flat_x)):
        overlap[flat_x[k], flat_y[k]] += 1

    return overlap
|