spacr 0.5.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. spacr/__init__.py +0 -2
  2. spacr/__main__.py +3 -3
  3. spacr/core.py +13 -106
  4. spacr/gui_core.py +2 -2
  5. spacr/gui_utils.py +1 -13
  6. spacr/io.py +24 -25
  7. spacr/mediar.py +12 -8
  8. spacr/plot.py +50 -13
  9. spacr/settings.py +45 -6
  10. spacr/submodules.py +11 -1
  11. spacr/timelapse.py +21 -3
  12. spacr/utils.py +154 -15
  13. {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/METADATA +62 -62
  14. spacr-0.9.0.dist-info/RECORD +109 -0
  15. {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/WHEEL +1 -1
  16. spacr/resources/MEDIAR/.gitignore +0 -18
  17. spacr/resources/MEDIAR/LICENSE +0 -21
  18. spacr/resources/MEDIAR/README.md +0 -189
  19. spacr/resources/MEDIAR/SetupDict.py +0 -39
  20. spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
  21. spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
  22. spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
  23. spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
  24. spacr/resources/MEDIAR/config/baseline.json +0 -60
  25. spacr/resources/MEDIAR/config/mediar_example.json +0 -72
  26. spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
  27. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
  28. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
  29. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
  30. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
  31. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
  32. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
  33. spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
  34. spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
  35. spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
  36. spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
  37. spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
  38. spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
  39. spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
  40. spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
  41. spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
  42. spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
  43. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
  44. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
  45. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
  46. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
  47. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
  48. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
  49. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
  50. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
  51. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
  52. spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
  53. spacr/resources/MEDIAR/core/__init__.py +0 -2
  54. spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
  55. spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
  56. spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
  57. spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
  58. spacr/resources/MEDIAR/core/utils.py +0 -40
  59. spacr/resources/MEDIAR/evaluate.py +0 -71
  60. spacr/resources/MEDIAR/generate_mapping.py +0 -121
  61. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  62. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  63. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  64. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  65. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  66. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  67. spacr/resources/MEDIAR/main.py +0 -125
  68. spacr/resources/MEDIAR/predict.py +0 -70
  69. spacr/resources/MEDIAR/requirements.txt +0 -14
  70. spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
  71. spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
  72. spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
  73. spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
  74. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
  75. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  76. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
  77. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
  78. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
  79. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
  80. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
  81. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
  82. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
  83. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
  84. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
  85. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
  86. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
  87. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  88. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
  89. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
  90. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
  91. spacr/resources/MEDIAR/train_tools/measures.py +0 -200
  92. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
  93. spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
  94. spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
  95. spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
  96. spacr/resources/MEDIAR/train_tools/utils.py +0 -70
  97. spacr-0.5.0.dist-info/RECORD +0 -190
  98. {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/LICENSE +0 -0
  99. {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/entry_points.txt +0 -0
  100. {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/top_level.txt +0 -0
spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py
@@ -1,208 +0,0 @@
- from torch.utils.data import DataLoader
- from monai.data import Dataset
- import pickle
-
- from .transforms import (
-     train_transforms,
-     public_transforms,
-     valid_transforms,
-     tuning_transforms,
-     unlabeled_transforms,
- )
- from .utils import split_train_valid, path_decoder
-
- DATA_LABEL_DICT_PICKLE_FILE = "./train_tools/data_utils/custom/modalities.pkl"
-
- __all__ = [
-     "get_dataloaders_labeled",
-     "get_dataloaders_public",
-     "get_dataloaders_unlabeled",
- ]
-
-
- def get_dataloaders_labeled(
-     root,
-     mapping_file,
-     mapping_file_tuning,
-     join_mapping_file=None,
-     valid_portion=0.0,
-     batch_size=8,
-     amplified=False,
-     relabel=False,
- ):
-     """Set DataLoaders for labeled datasets.
-
-     Args:
-         root (str): root directory
-         mapping_file (str): json file for mapping dataset
-         valid_portion (float, optional): portion of valid datasets. Defaults to 0.1.
-         batch_size (int, optional): batch size. Defaults to 8.
-         shuffle (bool, optional): shuffles dataloader. Defaults to True.
-         num_workers (int, optional): number of workers for each datalaoder. Defaults to 5.
-
-     Returns:
-         dict: dictionary of data loaders.
-     """
-
-     # Get list of data dictionaries from decoded paths
-     data_dicts = path_decoder(root, mapping_file)
-     tuning_dicts = path_decoder(root, mapping_file_tuning, no_label=True)
-
-     if amplified:
-         with open(DATA_LABEL_DICT_PICKLE_FILE, "rb") as f:
-             data_label_dict = pickle.load(f)
-
-         data_point_dict = {}
-
-         for label, data_lst in data_label_dict.items():
-             data_point_dict[label] = []
-
-             for d_idx in data_lst:
-                 try:
-                     data_point_dict[label].append(data_dicts[d_idx])
-                 except:
-                     print(label, d_idx)
-
-         data_dicts = []
-
-         for label, data_points in data_point_dict.items():
-             len_data_points = len(data_points)
-
-             if len_data_points >= 50:
-                 data_dicts += data_points
-             else:
-                 for i in range(50):
-                     data_dicts.append(data_points[i % len_data_points])
-
-     data_transforms = train_transforms
-
-     if join_mapping_file is not None:
-         data_dicts += path_decoder(root, join_mapping_file)
-         data_transforms = public_transforms
-
-     if relabel:
-         for elem in data_dicts:
-             cell_idx = int(elem["label"].split("_label.tiff")[0].split("_")[-1])
-             if cell_idx in range(340, 499):
-                 new_label = elem["label"].replace(
-                     "/data/CellSeg/Official/Train_Labeled/labels/",
-                     "/CellSeg/pretrained_train_ext/",
-                 )
-                 elem["label"] = new_label
-
-     # Split datasets as Train/Valid
-     train_dicts, valid_dicts = split_train_valid(
-         data_dicts, valid_portion=valid_portion
-     )
-
-     # Obtain datasets with transforms
-     trainset = Dataset(train_dicts, transform=data_transforms)
-     validset = Dataset(valid_dicts, transform=valid_transforms)
-     tuningset = Dataset(tuning_dicts, transform=tuning_transforms)
-
-     # Set dataloader for Trainset
-     train_loader = DataLoader(
-         trainset, batch_size=batch_size, shuffle=True, num_workers=5
-     )
-
-     # Set dataloader for Validset (Batch size is fixed as 1)
-     valid_loader = DataLoader(validset, batch_size=1, shuffle=False,)
-
-     # Set dataloader for Tuningset (Batch size is fixed as 1)
-     tuning_loader = DataLoader(tuningset, batch_size=1, shuffle=False)
-
-     # Form dataloaders as dictionary
-     dataloaders = {
-         "train": train_loader,
-         "valid": valid_loader,
-         "tuning": tuning_loader,
-     }
-
-     return dataloaders
-
-
- def get_dataloaders_public(
-     root, mapping_file, valid_portion=0.0, batch_size=8,
- ):
-     """Set DataLoaders for labeled datasets.
-
-     Args:
-         root (str): root directory
-         mapping_file (str): json file for mapping dataset
-         valid_portion (float, optional): portion of valid datasets. Defaults to 0.1.
-         batch_size (int, optional): batch size. Defaults to 8.
-         shuffle (bool, optional): shuffles dataloader. Defaults to True.
-
-     Returns:
-         dict: dictionary of data loaders.
-     """
-
-     # Get list of data dictionaries from decoded paths
-     data_dicts = path_decoder(root, mapping_file)
-
-     # Split datasets as Train/Valid
-     train_dicts, _ = split_train_valid(data_dicts, valid_portion=valid_portion)
-
-     trainset = Dataset(train_dicts, transform=public_transforms)
-     # Set dataloader for Trainset
-     train_loader = DataLoader(
-         trainset, batch_size=batch_size, shuffle=True, num_workers=5
-     )
-
-     # Form dataloaders as dictionary
-     dataloaders = {
-         "public": train_loader,
-     }
-
-     return dataloaders
-
-
- def get_dataloaders_unlabeled(
-     root, mapping_file, batch_size=8, shuffle=True, num_workers=5,
- ):
-     """Set dataloaders for unlabeled dataset."""
-     # Get list of data dictionaries from decoded paths
-     unlabeled_dicts = path_decoder(root, mapping_file, no_label=True, unlabeled=True)
-
-     # Obtain datasets with transforms
-     unlabeled_dicts, _ = split_train_valid(unlabeled_dicts, valid_portion=0)
-     unlabeled_set = Dataset(unlabeled_dicts, transform=unlabeled_transforms)
-
-     # Set dataloader for Unlabeled dataset
-     unlabeled_loader = DataLoader(
-         unlabeled_set, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
-     )
-
-     dataloaders = {
-         "unlabeled": unlabeled_loader,
-     }
-
-     return dataloaders
-
-
- def get_dataloaders_unlabeled_psuedo(
-     root, mapping_file, batch_size=8, shuffle=True, num_workers=5,
- ):
-
-     # Get list of data dictionaries from decoded paths
-     unlabeled_psuedo_dicts = path_decoder(
-         root, mapping_file, no_label=False, unlabeled=True
-     )
-
-     # Obtain datasets with transforms
-     unlabeled_psuedo_dicts, _ = split_train_valid(
-         unlabeled_psuedo_dicts, valid_portion=0
-     )
-     unlabeled_psuedo_set = Dataset(unlabeled_psuedo_dicts, transform=train_transforms)
-
-     # Set dataloader for Unlabeled dataset
-     unlabeled_psuedo_loader = DataLoader(
-         unlabeled_psuedo_set,
-         batch_size=batch_size,
-         shuffle=shuffle,
-         num_workers=num_workers,
-     )
-
-     dataloaders = {"unlabeled": unlabeled_psuedo_loader}
-
-     return dataloaders
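For orientation, the deleted loader factory above was consumed roughly as sketched below. The root directory and mapping-file names are placeholders rather than values taken from this diff, and the import assumes that train_tools.data_utils re-exported the datasetter helpers.

# Hypothetical usage sketch of the removed dataloader API (all paths are placeholders).
from train_tools.data_utils import get_dataloaders_labeled

dataloaders = get_dataloaders_labeled(
    root="/data/CellSeg",                       # placeholder dataset root
    mapping_file="mapping_labeled.json",        # placeholder mapping files
    mapping_file_tuning="mapping_tuning.json",
    valid_portion=0.1,
    batch_size=8,
)

for batch in dataloaders["train"]:
    imgs, labels = batch["img"], batch["label"]  # dict-style MONAI batches
    break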
spacr/resources/MEDIAR/train_tools/data_utils/transforms.py
@@ -1,148 +0,0 @@
- from .custom import *
-
- from monai.transforms import *
-
- __all__ = [
-     "train_transforms",
-     "public_transforms",
-     "valid_transforms",
-     "tuning_transforms",
-     "unlabeled_transforms",
- ]
-
- train_transforms = Compose(
-     [
-         # >>> Load and refine data --- img: (H, W, 3); label: (H, W)
-         CustomLoadImaged(keys=["img", "label"], image_only=True),
-         CustomNormalizeImaged(
-             keys=["img"],
-             allow_missing_keys=True,
-             channel_wise=False,
-             percentiles=[0.0, 99.5],
-         ),
-         EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
-         RemoveRepeatedChanneld(keys=["label"], repeats=3),  # label: (H, W)
-         ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
-         # >>> Spatial transforms
-         RandZoomd(
-             keys=["img", "label"],
-             prob=0.5,
-             min_zoom=0.25,
-             max_zoom=1.5,
-             mode=["area", "nearest"],
-             keep_size=False,
-         ),
-         SpatialPadd(keys=["img", "label"], spatial_size=512),
-         RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
-         RandAxisFlipd(keys=["img", "label"], prob=0.5),
-         RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
-         IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
-         # # >>> Intensity transforms
-         RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
-         RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
-         RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
-         RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
-         RandGaussianSharpend(keys=["img"], prob=0.25),
-         EnsureTyped(keys=["img", "label"]),
-     ]
- )
-
-
- public_transforms = Compose(
-     [
-         CustomLoadImaged(keys=["img", "label"], image_only=True),
-         BoundaryExclusion(keys=["label"]),
-         CustomNormalizeImaged(
-             keys=["img"],
-             allow_missing_keys=True,
-             channel_wise=False,
-             percentiles=[0.0, 99.5],
-         ),
-         EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
-         RemoveRepeatedChanneld(keys=["label"], repeats=3),  # label: (H, W)
-         ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
-         # >>> Spatial transforms
-         SpatialPadd(keys=["img", "label"], spatial_size=512),
-         RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
-         RandAxisFlipd(keys=["img", "label"], prob=0.5),
-         RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
-         Rotate90d(k=1, keys=["label"], spatial_axes=(0, 1)),
-         Flipd(keys=["label"], spatial_axis=0),
-         EnsureTyped(keys=["img", "label"]),
-     ]
- )
-
-
- valid_transforms = Compose(
-     [
-         CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
-         CustomNormalizeImaged(
-             keys=["img"],
-             allow_missing_keys=True,
-             channel_wise=False,
-             percentiles=[0.0, 99.5],
-         ),
-         EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
-         RemoveRepeatedChanneld(keys=["label"], repeats=3),
-         ScaleIntensityd(keys=["img"], allow_missing_keys=True),
-         EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
-     ]
- )
-
- tuning_transforms = Compose(
-     [
-         CustomLoadImaged(keys=["img"], image_only=True),
-         CustomNormalizeImaged(
-             keys=["img"],
-             allow_missing_keys=True,
-             channel_wise=False,
-             percentiles=[0.0, 99.5],
-         ),
-         EnsureChannelFirstd(keys=["img"], channel_dim=-1),
-         ScaleIntensityd(keys=["img"]),
-         EnsureTyped(keys=["img"]),
-     ]
- )
-
- unlabeled_transforms = Compose(
-     [
-         # >>> Load and refine data --- img: (H, W, 3); label: (H, W)
-         CustomLoadImaged(keys=["img"], image_only=True),
-         CustomNormalizeImaged(
-             keys=["img"],
-             allow_missing_keys=True,
-             channel_wise=False,
-             percentiles=[0.0, 99.5],
-         ),
-         EnsureChannelFirstd(keys=["img"], channel_dim=-1),
-         RandZoomd(
-             keys=["img"],
-             prob=0.5,
-             min_zoom=0.25,
-             max_zoom=1.25,
-             mode=["area"],
-             keep_size=False,
-         ),
-         ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
-         # >>> Spatial transforms
-         SpatialPadd(keys=["img"], spatial_size=512),
-         RandSpatialCropd(keys=["img"], roi_size=512, random_size=False),
-         EnsureTyped(keys=["img"]),
-     ]
- )
-
-
- def get_pred_transforms():
-     """Prediction preprocessing"""
-     pred_transforms = Compose(
-         [
-             # >>> Load and refine data
-             CustomLoadImage(image_only=True),
-             CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
-             EnsureChannelFirst(channel_dim=-1),  # image: (3, H, W)
-             ScaleIntensity(),
-             EnsureType(data_type="tensor"),
-         ]
-     )
-
-     return pred_transforms
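The prediction pipeline built by get_pred_transforms above appears to turn a single input image into a channel-first tensor scaled to [0, 1]; a minimal sketch, assuming CustomLoadImage accepts a file path the way MONAI's LoadImage does (the path below is a placeholder):

# Minimal sketch of the removed prediction preprocessing (placeholder path, removed custom transforms).
from train_tools.data_utils.transforms import get_pred_transforms

pred_transforms = get_pred_transforms()
img_tensor = pred_transforms("example_image.tiff")  # expected: torch tensor of shape (3, H, W)
print(img_tensor.shape)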
spacr/resources/MEDIAR/train_tools/data_utils/utils.py
@@ -1,84 +0,0 @@
- import os
- import json
- import numpy as np
-
- __all__ = ["split_train_valid", "path_decoder"]
-
-
- def split_train_valid(data_dicts, valid_portion=0.1):
-     """Split train/validata data according to the given proportion"""
-
-     train_dicts, valid_dicts = data_dicts, []
-     if valid_portion > 0:
-
-         # Obtain & shuffle data indices
-         num_data_dicts = len(data_dicts)
-         indices = np.arange(num_data_dicts)
-         np.random.shuffle(indices)
-
-         # Divide train/valid indices by the proportion
-         valid_size = int(num_data_dicts * valid_portion)
-         train_indices = indices[valid_size:]
-         valid_indices = indices[:valid_size]
-
-         # Assign data dicts by split indices
-         train_dicts = [data_dicts[idx] for idx in train_indices]
-         valid_dicts = [data_dicts[idx] for idx in valid_indices]
-
-     print(
-         "\n(DataLoaded) Training data size: %d, Validation data size: %d\n"
-         % (len(train_dicts), len(valid_dicts))
-     )
-
-     return train_dicts, valid_dicts
-
-
- def path_decoder(root, mapping_file, no_label=False, unlabeled=False):
-     """Decode img/label file paths from root & mapping directory.
-
-     Args:
-         root (str):
-         mapping_file (str): json file containing image & label file paths.
-         no_label (bool, optional): whether to include "label" key. Defaults to False.
-
-     Returns:
-         list: list of dictionary. (ex. [{"img": img_path, "label": label_path}, ...])
-     """
-
-     data_dicts = []
-
-     with open(mapping_file, "r") as file:
-         data = json.load(file)
-
-         for map_key in data.keys():
-
-             # If no_label, assign "img" key only
-             if no_label:
-                 data_dict_item = [
-                     {"img": os.path.join(root, elem["img"]),} for elem in data[map_key]
-                 ]
-
-             # If label exists, assign both "img" and "label" keys
-             else:
-                 data_dict_item = [
-                     {
-                         "img": os.path.join(root, elem["img"]),
-                         "label": os.path.join(root, elem["label"]),
-                     }
-                     for elem in data[map_key]
-                 ]
-
-             # Add refined datasets to be returned
-             data_dicts += data_dict_item
-
-     if unlabeled:
-         refined_data_dicts = []
-
-         # Exclude the corrupted image to prevent errror
-         for data_dict in data_dicts:
-             if "00504" not in data_dict["img"]:
-                 refined_data_dicts.append(data_dict)
-
-         data_dicts = refined_data_dicts
-
-     return data_dicts
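path_decoder above implies a mapping-file layout of the form {group: [{"img": ..., "label": ...}, ...]} with paths given relative to root. A minimal sketch that builds such a file (file and directory names are placeholders):

# Build a minimal mapping file in the shape path_decoder expects (names are placeholders).
import json

mapping = {
    "official": [
        {"img": "images/cell_00001.tiff", "label": "labels/cell_00001_label.tiff"},
        {"img": "images/cell_00002.tiff", "label": "labels/cell_00002_label.tiff"},
    ]
}
with open("mapping_labeled.json", "w") as f:
    json.dump(mapping, f, indent=2)

# path_decoder("/data/CellSeg", "mapping_labeled.json") would then yield
# [{"img": "/data/CellSeg/images/cell_00001.tiff", "label": "/data/CellSeg/labels/cell_00001_label.tiff"}, ...]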
spacr/resources/MEDIAR/train_tools/measures.py
@@ -1,200 +0,0 @@
- """
- Adapted from the following references:
- [1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/compute_metric.py
- [2] https://github.com/stardist/stardist/blob/master/stardist/matching.py
-
- """
-
- import numpy as np
- from skimage import segmentation
- from scipy.optimize import linear_sum_assignment
- from numba import jit
-
- __all__ = ["evaluate_f1_score_cellseg", "evaluate_f1_score"]
-
-
- def evaluate_f1_score_cellseg(masks_true, masks_pred, threshold=0.5):
-     """
-     Get confusion elements for cell segmentation results.
-     Boundary pixels are not considered during evaluation.
-     """
-
-     if np.prod(masks_true.shape) < (5000 * 5000):
-         masks_true = _remove_boundary_cells(masks_true.astype(np.int32))
-         masks_pred = _remove_boundary_cells(masks_pred.astype(np.int32))
-
-         tp, fp, fn = get_confusion(masks_true, masks_pred, threshold)
-
-     # Compute by Patch-based way for large images
-     else:
-         H, W = masks_true.shape
-         roi_size = 2000
-
-         # Get patch grid by roi_size
-         if H % roi_size != 0:
-             n_H = H // roi_size + 1
-             new_H = roi_size * n_H
-         else:
-             n_H = H // roi_size
-             new_H = H
-
-         if W % roi_size != 0:
-             n_W = W // roi_size + 1
-             new_W = roi_size * n_W
-         else:
-             n_W = W // roi_size
-             new_W = W
-
-         # Allocate values on the grid
-         gt_pad = np.zeros((new_H, new_W), dtype=masks_true.dtype)
-         pred_pad = np.zeros((new_H, new_W), dtype=masks_true.dtype)
-         gt_pad[:H, :W] = masks_true
-         pred_pad[:H, :W] = masks_pred
-
-         tp, fp, fn = 0, 0, 0
-
-         # Calculate confusion elements for each patch
-         for i in range(n_H):
-             for j in range(n_W):
-                 gt_roi = _remove_boundary_cells(
-                     gt_pad[
-                         roi_size * i : roi_size * (i + 1),
-                         roi_size * j : roi_size * (j + 1),
-                     ]
-                 )
-                 pred_roi = _remove_boundary_cells(
-                     pred_pad[
-                         roi_size * i : roi_size * (i + 1),
-                         roi_size * j : roi_size * (j + 1),
-                     ]
-                 )
-                 tp_i, fp_i, fn_i = get_confusion(gt_roi, pred_roi, threshold)
-                 tp += tp_i
-                 fp += fp_i
-                 fn += fn_i
-
-     # Calculate f1 score
-     precision, recall, f1_score = evaluate_f1_score(tp, fp, fn)
-
-     return precision, recall, f1_score
-
-
- def evaluate_f1_score(tp, fp, fn):
-     """Evaluate F1-score for the given confusion elements"""
-
-     # Do not Compute on trivial results
-     if tp == 0:
-         precision, recall, f1_score = 0, 0, 0
-
-     else:
-         precision = tp / (tp + fp)
-         recall = tp / (tp + fn)
-         f1_score = 2 * (precision * recall) / (precision + recall)
-
-     return precision, recall, f1_score
-
-
- def _remove_boundary_cells(mask):
-     """Remove cells on the boundary from the mask"""
-
-     # Identify boundary cells
-     W, H = mask.shape
-     bd = np.ones((W, H))
-     bd[2 : W - 2, 2 : H - 2] = 0
-     bd_cells = np.unique(mask * bd)
-
-     # Remove cells on the boundary
-     for i in bd_cells[1:]:
-         mask[mask == i] = 0
-
-     # Allocate labels as sequential manner
-     new_label, _, _ = segmentation.relabel_sequential(mask)
-
-     return new_label
-
-
- def get_confusion(masks_true, masks_pred, threshold=0.5):
-     """Calculate confusion matrix elements: (TP, FP, FN)"""
-     num_gt_instances = np.max(masks_true)
-     num_pred_instances = np.max(masks_pred)
-
-     if num_pred_instances == 0:
-         print("No segmentation results!")
-         tp, fp, fn = 0, 0, 0
-
-     else:
-         # Calculate IoU and exclude background label (0)
-         iou = _get_iou(masks_true, masks_pred)
-         iou = iou[1:, 1:]
-
-         # Calculate true positives
-         tp = _get_true_positive(iou, threshold)
-         fp = num_pred_instances - tp
-         fn = num_gt_instances - tp
-
-     return tp, fp, fn
-
-
- def _get_true_positive(iou, threshold=0.5):
-     """Get true positive (TP) pixels at the given threshold"""
-
-     # Number of instances to be matched
-     num_matched = min(iou.shape[0], iou.shape[1])
-
-     # Find optimal matching by using IoU as tie-breaker
-     costs = -(iou >= threshold).astype(np.float) - iou / (2 * num_matched)
-     matched_gt_label, matched_pred_label = linear_sum_assignment(costs)
-
-     # Consider as the same instance only if the IoU is above the threshold
-     match_ok = iou[matched_gt_label, matched_pred_label] >= threshold
-     tp = match_ok.sum()
-
-     return tp
-
-
- def _get_iou(masks_true, masks_pred):
-     """Get the iou between masks_true and masks_pred"""
-
-     # Get overlap matrix (GT Instances Num, Pred Instance Num)
-     overlap = _label_overlap(masks_true, masks_pred)
-
-     # Predicted instance pixels
-     n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
-
-     # GT instance pixels
-     n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
-
-     # Calculate intersection of union (IoU)
-     union = n_pixels_pred + n_pixels_true - overlap
-     iou = overlap / union
-
-     # Ensure numerical values
-     iou[np.isnan(iou)] = 0.0
-
-     return iou
-
-
- @jit(nopython=True)
- def _label_overlap(x, y):
-     """Get pixel overlaps between two masks
-
-     Parameters
-     ------------
-     x, y (np array; dtype int): 0=NO masks; 1,2... are mask labels
-
-     Returns
-     ------------
-     overlap (np array; dtype int): Overlaps of size [x.max()+1, y.max()+1]
-     """
-
-     # Make as 1D array
-     x, y = x.ravel(), y.ravel()
-
-     # Preallocate a Contact Map matrix
-     overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint)
-
-     # Calculate the number of shared pixels for each label
-     for i in range(len(x)):
-         overlap[x[i], y[i]] += 1
-
-     return overlap
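As a sanity check of the removed metric code, the functions above can be exercised on two tiny synthetic label masks. Note that _get_true_positive relies on the np.float alias, which was removed in NumPy 1.24, so this sketch only runs on older NumPy:

# Toy example for the removed F1 metric (requires NumPy < 1.24 because of np.float above).
import numpy as np
from train_tools.measures import evaluate_f1_score_cellseg

gt = np.zeros((64, 64), dtype=np.int32)
pred = np.zeros((64, 64), dtype=np.int32)
gt[10:20, 10:20] = 1      # one ground-truth instance
pred[11:21, 10:20] = 1    # one prediction, shifted by a pixel (IoU ~ 0.82)

precision, recall, f1 = evaluate_f1_score_cellseg(gt, pred, threshold=0.5)
print(precision, recall, f1)  # match above the 0.5 IoU threshold -> 1.0 1.0 1.0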