spacr 0.2.81__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -1
- spacr/core.py +107 -12
- spacr/gui.py +3 -2
- spacr/gui_core.py +8 -4
- spacr/gui_utils.py +4 -1
- spacr/io.py +13 -13
- spacr/measure.py +4 -4
- spacr/mediar.py +364 -0
- spacr/plot.py +5 -2
- spacr/resources/MEDIAR/.git +1 -0
- spacr/resources/MEDIAR/.gitignore +18 -0
- spacr/resources/MEDIAR/LICENSE +21 -0
- spacr/resources/MEDIAR/README.md +189 -0
- spacr/resources/MEDIAR/SetupDict.py +39 -0
- spacr/resources/MEDIAR/config/baseline.json +60 -0
- spacr/resources/MEDIAR/config/mediar_example.json +72 -0
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
- spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
- spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
- spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
- spacr/resources/MEDIAR/core/__init__.py +2 -0
- spacr/resources/MEDIAR/core/utils.py +40 -0
- spacr/resources/MEDIAR/evaluate.py +71 -0
- spacr/resources/MEDIAR/generate_mapping.py +121 -0
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +125 -0
- spacr/resources/MEDIAR/predict.py +70 -0
- spacr/resources/MEDIAR/requirements.txt +14 -0
- spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
- spacr/resources/MEDIAR/train_tools/measures.py +200 -0
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
- spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/utils.py +70 -0
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/plaque.png +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
- spacr/settings.py +3 -1
- spacr/utils.py +15 -13
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/METADATA +9 -1
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/RECORD +75 -16
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/LICENSE +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/WHEEL +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/top_level.txt +0 -0
spacr/mediar.py
ADDED
@@ -0,0 +1,364 @@
```python
import os, sys, gdown, cv2, torch
import numpy as np
import matplotlib.pyplot as plt
from monai.inferers import sliding_window_inference
import skimage.io as io

# Path to the bundled MEDIAR directory
mediar_path = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR')

print('mediar path', mediar_path)

# Temporarily create __init__.py so MEDIAR can be imported as a package
init_file = os.path.join(mediar_path, '__init__.py')
if not os.path.exists(init_file):
    with open(init_file, 'w'):  # Create the __init__.py file
        pass

# Add MEDIAR to sys.path
sys.path.insert(0, mediar_path)

try:
    # Now import the dependencies from MEDIAR
    from core.MEDIAR import Predictor, EnsemblePredictor
    from train_tools.models import MEDIARFormer
finally:
    # Remove the temporary __init__.py file after the import
    if os.path.exists(init_file):
        os.remove(init_file)

def display_imgs_in_list(lists_of_imgs, cmaps=None):
    """
    Displays images from multiple lists side by side.
    Each row will display one image from each list (lists_of_imgs[i][j] is the j-th image in the i-th list).

    :param lists_of_imgs: A list of lists, where each inner list contains images.
    :param cmaps: List of colormaps to use for each list (optional). If not provided, defaults to 'gray' for all lists.
    """
    num_lists = len(lists_of_imgs)
    num_images = len(lists_of_imgs[0])

    # Ensure that all lists have the same number of images
    for img_list in lists_of_imgs:
        assert len(img_list) == num_images, "All inner lists must have the same number of images"

    # Use 'gray' as the default colormap if cmaps are not provided
    if cmaps is None:
        cmaps = ['gray'] * num_lists
    else:
        assert len(cmaps) == num_lists, "The number of colormaps must match the number of lists"

    plt.figure(figsize=(15, 5 * num_images))

    for j in range(num_images):
        for i, img_list in enumerate(lists_of_imgs):
            img = img_list[j]
            plt.subplot(num_images, num_lists, j * num_lists + i + 1)

            if len(img.shape) == 2:  # Grayscale image
                plt.imshow(img, cmap=cmaps[i])
            elif len(img.shape) == 3 and img.shape[0] == 3:  # 3-channel image (C, H, W)
                plt.imshow(img.transpose(1, 2, 0))  # Change shape to (H, W, C) for displaying
            else:
                plt.imshow(img)

            plt.axis('off')
            plt.title(f'Image {j+1} from list {i+1}')

    plt.tight_layout()
    plt.show()

def get_weights(finetuned_weights=False):
    if finetuned_weights:
        model_path1 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'from_phase1.pth')
        if not os.path.exists(model_path1):
            print("Downloading finetuned model 1...")
            gdown.download('https://drive.google.com/uc?id=1JJ2-QKTCk-G7sp5ddkqcifMxgnyOrXjx', model_path1, quiet=False)
    else:
        model_path1 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'phase1.pth')
        if not os.path.exists(model_path1):
            print("Downloading model 1...")
            gdown.download('https://drive.google.com/uc?id=1v5tYYJDqiwTn_mV0KyX5UEonlViSNx4i', model_path1, quiet=False)

    if finetuned_weights:
        model_path2 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'from_phase2.pth')
        if not os.path.exists(model_path2):
            print("Downloading finetuned model 2...")
            gdown.download('https://drive.google.com/uc?id=168MtudjTMLoq9YGTyoD2Rjl_d3Gy6c_L', model_path2, quiet=False)
    else:
        model_path2 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'phase2.pth')
        if not os.path.exists(model_path2):
            print("Downloading model 2...")
            gdown.download('https://drive.google.com/uc?id=1NHDaYvsYz3G0OCqzegT-bkNcly2clPGR', model_path2, quiet=False)

    return model_path1, model_path2

def normalize_image(image, lower_percentile=0.0, upper_percentile=99.5):
    """
    Normalize an image to [0, 1] based on the given lower and upper percentiles.

    :param image: Input image (numpy array).
    :param lower_percentile: Lower percentile (default is 0.0).
    :param upper_percentile: Upper percentile (default is 99.5).
    :return: Normalized image (numpy array).
    """
    lower_bound = np.percentile(image, lower_percentile)
    upper_bound = np.percentile(image, upper_percentile)

    # Clip image values to the calculated percentiles
    image = np.clip(image, lower_bound, upper_bound)

    # Normalize to [0, 1]; add a small epsilon to avoid division by zero
    image = (image - lower_bound) / (upper_bound - lower_bound + 1e-5)

    return image

class MEDIARPredictor:
    def __init__(self, input_path=None, output_path=None, device=None, model="ensemble", roi_size=512, overlap=0.6, finetuned_weights=False, test=False, use_tta=False, normalize=True, quantiles=[0.0, 99.5]):
        # Validate the model choice up front, before loading weights
        if model not in ["model1", "model2", "ensemble"]:
            raise ValueError("Invalid model type. Choose from 'model1', 'model2', or 'ensemble'.")

        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.test = test
        self.model = model
        self.normalize = normalize
        self.quantiles = quantiles

        # Paths to model weights
        self.model1_path, self.model2_path = get_weights(finetuned_weights)

        # Load main models
        self.model1 = self.load_model(self.model1_path, device=self.device)
        self.model2 = self.load_model(self.model2_path, device=self.device) if model in ("ensemble", "model2") else None

        if self.test:
            # Define input and output paths for running the test
            self.input_path = os.path.join(os.path.dirname(__file__), 'resources/images')
            self.output_path = os.path.join(os.path.dirname(__file__), 'resources/MEDIAR/results')
        else:
            self.input_path = input_path
            self.output_path = output_path

        # Single model (phase-1 weights)
        if self.model == "model1":
            self.predictor = Predictor(
                model=self.model1,
                device=self.device,
                input_path=self.input_path,
                output_path=self.output_path,
                algo_params={"use_tta": use_tta}
            )

        # Single model (phase-2 weights)
        elif self.model == "model2":
            self.predictor = Predictor(
                model=self.model2,
                device=self.device,
                input_path=self.input_path,
                output_path=self.output_path,
                algo_params={"use_tta": use_tta}
            )

        # Two-model ensemble
        elif self.model == "ensemble":
            self.predictor = EnsemblePredictor(
                model=self.model1,      # Pass model1 as model
                model_aux=self.model2,  # Pass model2 as model_aux
                device=self.device,
                input_path=self.input_path,
                output_path=self.output_path,
                algo_params={"use_tta": use_tta}
            )

        if self.test:
            self.run_test()

    def load_model(self, model_path, device):
        model_args = {
            "classes": 3,
            "decoder_channels": [1024, 512, 256, 128, 64],
            "decoder_pab_channels": 256,
            "encoder_name": 'mit_b5',
            "in_channels": 3
        }
        model = MEDIARFormer(**model_args)
        weights = torch.load(model_path, map_location=device)
        model.load_state_dict(weights, strict=False)
        model.to(device)
        model.eval()
        return model

    def display_image_and_mask(self, img, mask):
        """
        Displays the normalized input image and the predicted mask side by side.
        """
        from .plot import generate_mask_random_cmap

        # If img is a tensor, convert it to NumPy for display
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()

        # If mask is a tensor, convert it to NumPy for display
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()

        # Transpose the image to (H, W, C) format for display if needed
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = img.transpose(1, 2, 0)

        # Scale the normalized image back to [0, 255] for proper display
        img_display = (img * 255).astype(np.uint8)

        plt.figure(figsize=(10, 5))

        # Display normalized image
        plt.subplot(1, 2, 1)
        plt.imshow(img_display)
        plt.title("Normalized Image")
        plt.axis("off")

        r_cmap = generate_mask_random_cmap(mask)

        # Display predicted mask
        plt.subplot(1, 2, 2)
        plt.imshow(mask, cmap=r_cmap)
        plt.title("Predicted Mask")
        plt.axis("off")

        plt.tight_layout()
        plt.show()

    def predict_batch(self, imgs):
        """
        Predict masks for a batch of images.

        :param imgs: List of input images as NumPy arrays (each in (H, W, C) format).
        :return: List of predicted masks as NumPy arrays.
        """
        processed_imgs = []

        # Preprocess and normalize each image
        for img in imgs:
            if self.normalize:
                # Normalize the image using the specified quantiles
                img_normalized = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])
            else:
                img_normalized = img

            # Convert image to tensor and send to device
            img_tensor = torch.tensor(img_normalized.astype(np.float32).transpose(2, 0, 1)).to(self.device)  # (C, H, W)
            processed_imgs.append(img_tensor)

        # Stack all processed images into a batch tensor
        batch_tensor = torch.stack(processed_imgs)

        # Run inference to get predicted masks
        pred_masks = self.predictor._inference(batch_tensor)

        # Ensure pred_masks is always treated as a batch
        if len(pred_masks.shape) == 3:  # If single image, add batch dimension
            pred_masks = pred_masks.unsqueeze(0)

        # Convert predictions to NumPy arrays and post-process each mask
        predicted_masks = []
        for pred_mask in pred_masks:
            pred_mask_np = pred_mask.cpu().numpy()

            # Extract dP and cellprob from pred_mask
            dP = pred_mask_np[:2]       # First two channels as dP (displacement field)
            cellprob = pred_mask_np[2]  # Third channel as cell probability

            # Concatenate dP and cellprob along axis 0 to pass a single array
            combined_pred_mask = np.concatenate([dP, np.expand_dims(cellprob, axis=0)], axis=0)

            # Post-process the predicted mask
            mask = self.predictor._post_process(combined_pred_mask)

            # Append the processed mask to the list
            predicted_masks.append(mask.astype(np.uint16))

        return predicted_masks

    def run_test(self):
        """
        Run the model on test images if the test flag is True.
        """
        # Lists of input images and their names
        imgs = []
        img_names = []

        for img_file in os.listdir(self.input_path):
            img_path = os.path.join(self.input_path, img_file)
            img = io.imread(img_path)

            # Check if the image is grayscale (2D) or RGB (3D), and convert grayscale to RGB
            if len(img.shape) == 2:  # Grayscale image (H, W)
                img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

            # Normalize the image if the normalize flag is True
            if self.normalize:
                img_normalized = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])
            else:
                img_normalized = img

            # Convert image to tensor and send directly to device
            img_tensor = torch.tensor(img_normalized.astype(np.float32).transpose(2, 0, 1)).to(self.device)  # (C, H, W)

            imgs.append(img_tensor)
            img_names.append(os.path.splitext(img_file)[0])

        # Stack all images into a batch (ensure it's always treated as a batch)
        batch_tensor = torch.stack(imgs)

        # Predict using the predictor (or ensemble predictor)
        pred_masks = self.predictor._inference(batch_tensor)

        # Ensure pred_masks is always treated as a batch
        if len(pred_masks.shape) == 3:  # If single image, add batch dimension
            pred_masks = pred_masks.unsqueeze(0)

        # Convert predictions to NumPy arrays and post-process each mask
        for i, pred_mask in enumerate(pred_masks):
            # Ensure the dimensions of pred_mask remain consistent
            pred_mask_np = pred_mask.cpu().numpy()

            # Extract dP and cellprob from pred_mask
            dP = pred_mask_np[:2]       # First two channels as dP (displacement field)
            cellprob = pred_mask_np[2]  # Third channel as cell probability

            # Concatenate dP and cellprob along axis 0 to pass a single array
            combined_pred_mask = np.concatenate([dP, np.expand_dims(cellprob, axis=0)], axis=0)

            # Post-process the predicted mask
            mask = self.predictor._post_process(combined_pred_mask)

            # Convert the mask to 16-bit format (ensure values fit into the 16-bit range)
            mask_to_save = mask.astype(np.uint16)

            # Save the post-processed mask as a .tiff file using cv2
            mask_output_path = os.path.join(self.output_path, f"{img_names[i]}_mask.tiff")
            cv2.imwrite(mask_output_path, mask_to_save)

            print(f"Predicted mask saved at: {mask_output_path}")

            self.display_image_and_mask(imgs[i].cpu().numpy(), mask)

        print(f"Test predictions saved in {self.output_path}")

    def preprocess_image(self, img):
        """
        Preprocess input image (numpy array) for compatibility with the model.
        """
        if isinstance(img, np.ndarray):  # Check if the input is a numpy array
            if len(img.shape) == 2:  # Grayscale image (H, W)
                img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
            elif img.shape[2] == 1:  # Single-channel grayscale (H, W, 1)
                img = np.repeat(img, 3, axis=2)  # Convert to 3-channel RGB

            img_tensor = torch.tensor(img.astype(np.float32).transpose(2, 0, 1))  # Change shape to (C, H, W)
        else:
            img_tensor = img  # If it's already a tensor, assume it's in (C, H, W) format

        return img_tensor.float()
```
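For orientation, here is a minimal usage sketch of the `MEDIARPredictor` class added above. It is not part of the package: the dummy image and the `model="ensemble"` choice are illustrative, and instantiation downloads the phase1/phase2 weights via `gdown` on first use.

```python
import numpy as np
from spacr.mediar import MEDIARPredictor

# Dummy 512x512 RGB frame standing in for a real microscopy image.
img = np.random.randint(0, 65535, (512, 512, 3), dtype=np.uint16)

# Weights are fetched automatically the first time this runs.
predictor = MEDIARPredictor(model="ensemble", normalize=True, quantiles=[0.0, 99.5])

masks = predictor.predict_batch([img])  # list of (H, W) uint16 instance masks
print(masks[0].shape, masks[0].max())   # max label ~ number of segmented objects
```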
spacr/plot.py
CHANGED
```diff
@@ -123,7 +123,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
 
     fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
 
-    return
+    return fig
 
 def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
     """
@@ -274,6 +274,7 @@ def _generate_mask_random_cmap(mask):
     return random_cmap
 
 def _get_colours_merged(outline_color):
+
     """
     Get the merged outline colors based on the specified outline color format.
 
@@ -283,6 +284,7 @@ def _get_colours_merged(outline_color):
     Returns:
         list: A list of merged outline colors based on the specified format.
     """
+
     if outline_color == 'rgb':
         outline_colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rgb
     elif outline_color == 'bgr':
@@ -296,6 +298,7 @@ def _get_colours_merged(outline_color):
     return outline_colors
 
 def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=['.npy', '.tif', '.tiff', '.png'], overlay=False, max_nr=None, randomize=True):
+
     """
     Plot images and arrays from the given folders.
 
@@ -1055,7 +1058,7 @@ def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figu
     plt.show()
 
     columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
-    columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']
+    #columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']
 
     width = figuresize*2
     columns_per_row = math.ceil(len(columns) / 2)
```
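The `return fig` change in the first hunk means callers can now persist the overlay figure instead of relying on an implicit display. A hedged sketch (the argument list is elided because the full signature is truncated in the hunk header):

```python
from spacr.plot import plot_image_mask_overlay

# `args`/`kwargs` stand in for the function's full argument list.
fig = plot_image_mask_overlay(*args, **kwargs)
fig.savefig("overlay.pdf", bbox_inches="tight")  # now possible with `return fig`
```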
spacr/resources/MEDIAR/.git
ADDED
@@ -0,0 +1 @@
+gitdir: ../../../.git/modules/spacr/resources/MEDIAR
spacr/resources/MEDIAR/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 opcrisis

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
spacr/resources/MEDIAR/README.md
ADDED
@@ -0,0 +1,189 @@

# **MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy**
![1st](...)

This repository provides an official implementation of [MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy](https://arxiv.org/abs/2212.03465), which achieved ***1st place*** in the [NeurIPS-2022 Cell Segmentation Challenge](https://neurips22-cellseg.grand-challenge.org/).

To access and try MEDIAR directly, see the links below.
- <a href="https://colab.research.google.com/drive/1iFnGu6A_p-5s_eATjNtfjb9-MR5L3pLB?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
- [Huggingface Space](https://huggingface.co/spaces/ghlee94/MEDIAR?logs=build)
- [Napari Plugin](https://github.com/joonkeekim/mediar-napari)
- [Docker Image](https://hub.docker.com/repository/docker/joonkeekim/mediar/general)

# 1. MEDIAR Overview
<img src="./image/mediar_framework.png" width="1200"/>

MEDIAR is a framework for efficient cell instance segmentation of multi-modality microscopy images. The figure above illustrates an overview of our approach. MEDIAR harmonizes data-centric and model-centric approaches as its learning and inference strategies, achieving a **0.9067** mean F1-score on the validation datasets. Below is a brief description of the methods combined in MEDIAR; please refer to our paper for more information.

# 2. Methods

## **Data-Centric**
- **Cell-Aware Augmentation**: We apply two novel cell-aware augmentations: *cell-wise intensity randomization* (Cell Intensity Diversification) and *cell-wise exclusion of boundary pixels* in the label. The boundary exclusion is adopted only in the pre-training phase.

- **Two-phase Pretraining and Fine-tuning**: To extract knowledge from large public datasets, we first pretrain our model on public sets, then fine-tune it.
  - Pretraining: We use 7,242 labeled images from four public datasets for pretraining: OmniPose, CellPose, LiveCell and DataScienceBowl-2018. MEDIAR takes two different phases for pretraining, with the MEDIAR-Former model's encoder parameters initialized from ImageNet-1k pretraining.

  - Fine-tuning: We use two different models for the ensemble. The first model is fine-tuned for 200 epochs on the target datasets. The second model is fine-tuned for 25 epochs on both the target and public datasets.

- **Modality Discovery & Amplified Sampling**: To balance the latent modalities in the datasets, we conduct K-means clustering and discover 40 modalities. In the training phase, we over-sample the minor cluster samples.

- **Cell Memory Replay**: We concatenate a small portion of data from the public dataset to the batch and train with boundary-excluded labels.

## **Model-Centric**
- **MEDIAR-Former Architecture**: MEDIAR-Former follows the design paradigm of U-Net, but uses SegFormer and MA-Net for the encoder and decoder. The two heads of MEDIAR-Former predict cell probability and gradient flow.

<img src="./image/mediar_model.PNG" width="1200"/>

- **Gradient Flow Tracking**: We utilize the gradient flow tracking proposed by [CellPose](https://github.com/MouseLand/cellpose).

- **Ensemble with Stochastic TTA**: During inference, MEDIAR conducts prediction in a sliding-window manner with an importance map generated by a Gaussian filter. We use the two fine-tuned models from phase1 and phase2 pretraining and ensemble their outputs by summation. Test-time augmentation is used for each output.

# 3. Experiments

### **Dataset**
- Official Dataset
  - The target dataset is provided by [Weakly Supervised Cell Segmentation in Multi-modality High-Resolution Microscopy Images](https://neurips22-cellseg.grand-challenge.org/). It consists of 1,000 labeled images, 1,712 unlabeled images, and 13 unlabeled whole-slide images from various microscopy types, tissue types, and staining types. The validation set contains 101 images, including 1 whole-slide image.

- Public Dataset
  - [OmniPose](http://www.cellpose.org/dataset_omnipose): contains mixtures of 14 bacterial species. We use only the 611 bacterial cell microscopy images and discard the 118 worm images.
  - [CellPose](https://www.cellpose.org/dataset): includes cytoplasm, cellular microscopy, and fluorescent cell images. We used 551 images, discarding 58 non-microscopy images, and converted all images to gray-scale.
  - [LiveCell](https://github.com/sartorius-research/LIVECell): a large-scale dataset with 5,239 images containing 1,686,352 individual cells from 8 distinct cell types, annotated by trained crowdsourcers.
  - [DataScienceBowl 2018](https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview): 841 images containing 37,333 cells from 22 cell types, 15 image resolutions, and five visually similar groups.

### **Testing steps**
- **Ensemble Prediction with TTA**: MEDIAR uses sliding-window inference with an overlap of 0.6 between adjacent patches and a Gaussian importance map. To predict different views of the image, MEDIAR uses Test-Time Augmentation (TTA) for the model prediction and ensembles the two models described in **Two-phase Pretraining and Fine-tuning**.

- **Inference time**: MEDIAR processes most images in less than 1 second, even with ensemble prediction and TTA; the time depends on the image size and the number of cells. Detailed evaluation-time results are in the paper.

### **Preprocessing & Augmentations**
| Strategy | Type | Probability |
|----------|:-------------|------|
| `Clip` | Pre-processing | . |
| `Normalization` | Pre-processing | . |
| `Scale Intensity` | Pre-processing | . |
| `Zoom` | Spatial Augmentation | 0.5 |
| `Spatial Crop` | Spatial Augmentation | 1.0 |
| `Axis Flip` | Spatial Augmentation | 0.5 |
| `Rotation` | Spatial Augmentation | 0.5 |
| `Cell-Aware Intensity` | Intensity Augmentation | 0.25 |
| `Gaussian Noise` | Intensity Augmentation | 0.25 |
| `Contrast Adjustment` | Intensity Augmentation | 0.25 |
| `Gaussian Smoothing` | Intensity Augmentation | 0.25 |
| `Histogram Shift` | Intensity Augmentation | 0.25 |
| `Gaussian Sharpening` | Intensity Augmentation | 0.25 |
| `Boundary Exclusion` | Others | . |


| Learning Setups | Pretraining | Fine-tuning |
|----------------------------------------------------------------------|---------------------------------------------------------|---------------------------------------------------------|
| Initialization (Encoder) | ImageNet-1k pretrained | from Pretraining |
| Initialization (Decoder, Head) | He normal initialization | from Pretraining |
| Batch size | 9 | 9 |
| Total epochs | 80 (60) | 200 (25) |
| Optimizer | AdamW | AdamW |
| Initial learning rate (lr) | 5e-5 | 2e-5 |
| Lr decay schedule | Cosine scheduler (100 interval) | Cosine scheduler (100 interval) |
| Loss function | MSE, BCE | MSE, BCE |

# 4. Results
### **Validation Dataset**
- Quantitative Evaluation
  - Our MEDIAR achieved a **0.9067** validation mean F1-score.
- Qualitative Evaluation

<img src="./image/mediar_results.png" width="1200"/>

- Failure Cases

<img src="./image/failure_cases.png" width="1200"/>

### **Test Dataset**
![](...)
![](...)


# 5. Reproducing

### **Our Environment**
| Computing Infrastructure| |
|-------------------------|----------------------------------------------------------------------|
| System | Ubuntu 18.04.5 LTS |
| CPU | AMD EPYC 7543 32-Core Processor CPU@2.26GHz |
| RAM | 500GB; 3.125MT/s |
| GPU (number and type) | NVIDIA A5000 (24GB) 2ea |
| CUDA version | 11.7 |
| Programming language | Python 3.9 |
| Deep learning framework | PyTorch (v1.12, with torchvision v0.13.1) |
| Code dependencies | MONAI (v0.9.0), Segmentation Models (v0.3.0) |
| Specific dependencies | None |

To install requirements:

```
pip install -r requirements.txt
wandb off
```

## Dataset
- The dataset directories under the root should have the following structure:

```
Root
├── Datasets
│   ├── images  (images can have various extensions: .tif, .tiff, .png, .bmp ...)
│   │   ├── cell_00001.png
│   │   ├── cell_00002.tif
│   │   ├── cell_00003.xxx
│   │   └── ...
│   └── labels  (labels must have the .tiff extension)
│       ├── cell_00001_label.tiff
│       ├── cell_00002_label.tiff
│       ├── cell_00003_label.tiff
│       └── ...
└── ...
```

Before executing the code, run the following command to generate the path-mapping json file:

```
python ./generate_mapping.py --root=<path_to_data>
```

## Training

To train the model(s) in the paper, run the following command:

```
python ./main.py --config_path=<path_to_config>
```
Configuration files are in `./config/*`. We provide the pretraining, fine-tuning, and prediction configs. You can refer to the configuration options in `./config/mediar_example.json`. We also implemented the official challenge baseline code in our framework; you can run it via `./config/baseline.json`.

## Inference

To conduct prediction on the test cases, run the following command:

```
python predict.py --config_path=<path_to_config>
```

The configuration files for `predict.py` are slightly different; please refer to the config files in `./config/step3_prediction/*`.

## Evaluation
If you have the ground-truth labels, run the following command for evaluation:

```
python ./evaluate.py --pred_path=<path_to_prediction_results> --gt_path=<path_to_ground_truth_labels>
```

## Trained Models

You can download the MEDIAR pretrained and fine-tuned models here:

- [Google Drive Link](https://drive.google.com/drive/folders/1RgMxHIT7WsKNjir3wXSl7BrzlpS05S18?usp=share_link)

## Citation of this Work
```
@article{lee2022mediar,
  title={Mediar: Harmony of data-centric and model-centric for multi-modality microscopy},
  author={Lee, Gihun and Kim, SangMook and Kim, Joonkee and Yun, Se-Young},
  journal={arXiv preprint arXiv:2212.03465},
  year={2022}
}
```
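The sliding-window strategy described in the README (0.6 overlap, Gaussian importance map) maps directly onto MONAI's public inferer, which `spacr/mediar.py` already imports. A minimal illustrative sketch, not MEDIAR's exact call; the stand-in `model` and `batch` below are placeholders:

```python
import torch
from monai.inferers import sliding_window_inference

# Stand-ins: any callable mapping (B, C, h, w) -> (B, 3, h, w) works here.
model = lambda x: torch.zeros(x.shape[0], 3, *x.shape[2:])
batch = torch.rand(1, 3, 1024, 1024)  # (B, C, H, W) image tensor

with torch.no_grad():
    pred = sliding_window_inference(
        inputs=batch,
        roi_size=(512, 512),  # patch size slid across the image
        sw_batch_size=4,      # patches scored per forward pass
        predictor=model,
        overlap=0.6,          # overlap between adjacent patches
        mode="gaussian",      # Gaussian importance map for blending
    )
print(pred.shape)  # (1, 3, 1024, 1024): dP (2 channels) + cell probability (1)
```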
spacr/resources/MEDIAR/SetupDict.py
ADDED
@@ -0,0 +1,39 @@
```python
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import monai

import core
from train_tools import models
from train_tools.models import *

__all__ = ["TRAINER", "OPTIMIZER", "SCHEDULER"]

TRAINER = {
    "baseline": core.Baseline.Trainer,
    "mediar": core.MEDIAR.Trainer,
}

PREDICTOR = {
    "baseline": core.Baseline.Predictor,
    "mediar": core.MEDIAR.Predictor,
    "ensemble_mediar": core.MEDIAR.EnsemblePredictor,
}

MODELS = {
    "unet": monai.networks.nets.UNet,
    "unetr": monai.networks.nets.unetr.UNETR,
    "swinunetr": monai.networks.nets.SwinUNETR,
    "mediar-former": models.MEDIARFormer,
}

OPTIMIZER = {
    "sgd": optim.SGD,
    "adam": optim.Adam,
    "adamw": optim.AdamW,
}

SCHEDULER = {
    "step": lr_scheduler.StepLR,
    "multistep": lr_scheduler.MultiStepLR,
    "cosine": lr_scheduler.CosineAnnealingLR,
}
```
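These registries let a JSON config select components by name. A hedged sketch of the typical lookup pattern; the `config` values and the toy model are illustrative, not from the package:

```python
import torch.nn as nn
from SetupDict import OPTIMIZER, SCHEDULER  # assumes MEDIAR's root is on sys.path

config = {"optimizer": "adamw", "scheduler": "cosine", "lr": 5e-5}
model = nn.Linear(8, 3)  # stand-in for MEDIARFormer

optimizer = OPTIMIZER[config["optimizer"]](model.parameters(), lr=config["lr"])  # torch.optim.AdamW
scheduler = SCHEDULER[config["scheduler"]](optimizer, T_max=100)                 # CosineAnnealingLR
```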