spacr 0.2.81__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. spacr/__init__.py +2 -1
  2. spacr/core.py +107 -12
  3. spacr/gui.py +3 -2
  4. spacr/gui_core.py +8 -4
  5. spacr/gui_utils.py +4 -1
  6. spacr/io.py +13 -13
  7. spacr/measure.py +4 -4
  8. spacr/mediar.py +364 -0
  9. spacr/plot.py +5 -2
  10. spacr/resources/MEDIAR/.git +1 -0
  11. spacr/resources/MEDIAR/.gitignore +18 -0
  12. spacr/resources/MEDIAR/LICENSE +21 -0
  13. spacr/resources/MEDIAR/README.md +189 -0
  14. spacr/resources/MEDIAR/SetupDict.py +39 -0
  15. spacr/resources/MEDIAR/config/baseline.json +60 -0
  16. spacr/resources/MEDIAR/config/mediar_example.json +72 -0
  17. spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
  18. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
  19. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
  20. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
  21. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
  22. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
  23. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
  24. spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
  25. spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
  26. spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
  27. spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
  28. spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
  29. spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
  30. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
  31. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
  32. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
  33. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
  34. spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
  35. spacr/resources/MEDIAR/core/__init__.py +2 -0
  36. spacr/resources/MEDIAR/core/utils.py +40 -0
  37. spacr/resources/MEDIAR/evaluate.py +71 -0
  38. spacr/resources/MEDIAR/generate_mapping.py +121 -0
  39. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  40. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  41. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  42. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  43. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  44. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  45. spacr/resources/MEDIAR/main.py +125 -0
  46. spacr/resources/MEDIAR/predict.py +70 -0
  47. spacr/resources/MEDIAR/requirements.txt +14 -0
  48. spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
  49. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
  50. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
  51. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
  52. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
  53. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
  54. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  55. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
  56. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
  57. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
  58. spacr/resources/MEDIAR/train_tools/measures.py +200 -0
  59. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
  60. spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
  61. spacr/resources/MEDIAR/train_tools/utils.py +70 -0
  62. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  63. spacr/resources/icons/.DS_Store +0 -0
  64. spacr/resources/icons/plaque.png +0 -0
  65. spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
  66. spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
  67. spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
  68. spacr/settings.py +3 -1
  69. spacr/utils.py +15 -13
  70. {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/METADATA +9 -1
  71. {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/RECORD +75 -16
  72. {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/LICENSE +0 -0
  73. {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/WHEEL +0 -0
  74. {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/entry_points.txt +0 -0
  75. {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/top_level.txt +0 -0
spacr/mediar.py ADDED
@@ -0,0 +1,364 @@
+ import os, sys, gdown, cv2, torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from monai.inferers import sliding_window_inference
+ import skimage.io as io
+
+ # Path to the MEDIAR directory
+ mediar_path = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR')
+
+ print('mediar path', mediar_path)
+
+ # Temporarily create __init__.py to make MEDIAR a package
+ init_file = os.path.join(mediar_path, '__init__.py')
+ if not os.path.exists(init_file):
+     with open(init_file, 'w'): # Create the __init__.py file
+         pass
+
+ # Add MEDIAR to sys.path
+ sys.path.insert(0, mediar_path)
+
+ try:
+     # Now import the dependencies from MEDIAR
+     from core.MEDIAR import Predictor, EnsemblePredictor
+     from train_tools.models import MEDIARFormer
+ finally:
+     # Remove the temporary __init__.py file after the import
+     if os.path.exists(init_file):
+         os.remove(init_file) # Remove the __init__.py file
+
+ def display_imgs_in_list(lists_of_imgs, cmaps=None):
+     """
+     Displays images from multiple lists side by side.
+     Each row will display one image from each list (lists_of_imgs[i][j] is the j-th image in the i-th list).
+
+     :param lists_of_imgs: A list of lists, where each inner list contains images.
+     :param cmaps: List of colormaps to use for each list (optional). If not provided, defaults to 'gray' for all lists.
+     """
+     num_lists = len(lists_of_imgs)
+     num_images = len(lists_of_imgs[0])
+
+     # Ensure that all lists have the same number of images
+     for img_list in lists_of_imgs:
+         assert len(img_list) == num_images, "All inner lists must have the same number of images"
+
+     # Use 'gray' as the default colormap if cmaps are not provided
+     if cmaps is None:
+         cmaps = ['gray'] * num_lists
+     else:
+         assert len(cmaps) == num_lists, "The number of colormaps must match the number of lists"
+
+     plt.figure(figsize=(15, 5 * num_images))
+
+     for j in range(num_images):
+         for i, img_list in enumerate(lists_of_imgs):
+             img = img_list[j]
+             plt.subplot(num_images, num_lists, j * num_lists + i + 1)
+
+             if len(img.shape) == 2: # Grayscale image
+                 plt.imshow(img, cmap=cmaps[i])
+             elif len(img.shape) == 3 and img.shape[0] == 3: # 3-channel image (C, H, W)
+                 plt.imshow(img.transpose(1, 2, 0)) # Change shape to (H, W, C) for displaying
+             else:
+                 plt.imshow(img)
+
+             plt.axis('off')
+             plt.title(f'Image {j+1} from list {i+1}')
+
+     plt.tight_layout()
+     plt.show()
+
+ def get_weights(finetuned_weights=False):
+     if finetuned_weights:
+         model_path1 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'from_phase1.pth')
+         if not os.path.exists(model_path1):
+             print("Downloading finetuned model 1...")
+             gdown.download('https://drive.google.com/uc?id=1JJ2-QKTCk-G7sp5ddkqcifMxgnyOrXjx', model_path1, quiet=False)
+     else:
+         model_path1 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'phase1.pth')
+         if not os.path.exists(model_path1):
+             print("Downloading model 1...")
+             gdown.download('https://drive.google.com/uc?id=1v5tYYJDqiwTn_mV0KyX5UEonlViSNx4i', model_path1, quiet=False)
+
+     if finetuned_weights:
+         model_path2 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'from_phase2.pth')
+         if not os.path.exists(model_path2):
+             print("Downloading finetuned model 2...")
+             gdown.download('https://drive.google.com/uc?id=168MtudjTMLoq9YGTyoD2Rjl_d3Gy6c_L', model_path2, quiet=False)
+     else:
+         model_path2 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'phase2.pth')
+         if not os.path.exists(model_path2):
+             print("Downloading model 2...")
+             gdown.download('https://drive.google.com/uc?id=1NHDaYvsYz3G0OCqzegT-bkNcly2clPGR', model_path2, quiet=False)
+
+     return model_path1, model_path2
+
+ def normalize_image(image, lower_percentile=0.0, upper_percentile=99.5):
+     """
+     Normalize an image based on the 0.0 and 99.5 percentiles.
+
+     :param image: Input image (numpy array).
+     :param lower_percentile: Lower percentile (default is 0.0).
+     :param upper_percentile: Upper percentile (default is 99.5).
+     :return: Normalized image (numpy array).
+     """
+     lower_bound = np.percentile(image, lower_percentile)
+     upper_bound = np.percentile(image, upper_percentile)
+
+     # Clip image values to the calculated percentiles
+     image = np.clip(image, lower_bound, upper_bound)
+
+     # Normalize to [0, 1]
+     image = (image - lower_bound) / (upper_bound - lower_bound + 1e-5) # Add small epsilon to avoid division by zero
+
+     return image
+
+ class MEDIARPredictor:
+     def __init__(self, input_path=None, output_path=None, device=None, model="ensemble", roi_size=512, overlap=0.6, finetuned_weights=False, test=False, use_tta=False, normalize=True, quantiles=[0.0, 99.5]):
+         if device is None:
+             device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+         self.test = test
+         self.model = model
+         self.normalize = normalize
+         self.quantiles = quantiles
+
+         # Paths to model weights
+         self.model1_path, self.model2_path = get_weights(finetuned_weights)
+
+         # Load main models
+         self.model1 = self.load_model(self.model1_path, device=self.device)
+         self.model2 = self.load_model(self.model2_path, device=self.device) if model == "ensemble" or model == "model2" else None
+         if self.test:
+             # Define input and output paths for running test
+             self.input_path = os.path.join(os.path.dirname(__file__), 'resources/images')
+             self.output_path = os.path.join(os.path.dirname(__file__), 'resources/MEDIAR/results')
+         else:
+             self.input_path = input_path
+             self.output_path = output_path
+
+         # If using a single model
+         if self.model == "model1":
+             self.predictor = Predictor(
+                 model=self.model1,
+                 device=self.device,
+                 input_path=self.input_path,
+                 output_path=self.output_path,
+                 algo_params={"use_tta": use_tta}
+             )
+
+         # If using a single model
+         if self.model == "model2":
+             self.predictor = Predictor(
+                 model=self.model2,
+                 device=self.device,
+                 input_path=self.input_path,
+                 output_path=self.output_path,
+                 algo_params={"use_tta": use_tta}
+             )
+
+         # If using two models
+         elif self.model == "ensemble":
+             self.predictor = EnsemblePredictor(
+                 model=self.model1, # Pass model1 as model
+                 model_aux=self.model2, # Pass model2 as model_aux
+                 device=self.device,
+                 input_path=self.input_path,
+                 output_path=self.output_path,
+                 algo_params={"use_tta": use_tta}
+             )
+
+         if self.test:
+             self.run_test()
+
+         if not self.model in ["model1", "model2", "ensemble"]:
+             raise ValueError("Invalid model type. Choose from 'model1', 'model2', or 'ensemble'.")
+
+     def load_model(self, model_path, device):
+         model_args = {
+             "classes": 3,
+             "decoder_channels": [1024, 512, 256, 128, 64],
+             "decoder_pab_channels": 256,
+             "encoder_name": 'mit_b5',
+             "in_channels": 3
+         }
+         model = MEDIARFormer(**model_args)
+         weights = torch.load(model_path, map_location=device)
+         model.load_state_dict(weights, strict=False)
+         model.to(device)
+         model.eval()
+         return model
+
+     def display_image_and_mask(self, img, mask):
+
+         from .plot import generate_mask_random_cmap
+         """
+         Displays the normalized input image and the predicted mask side by side.
+         """
+         # If img is a tensor, convert it to NumPy for display
+         if isinstance(img, torch.Tensor):
+             img = img.cpu().numpy()
+
+         # If mask is a tensor, convert it to NumPy for display
+         if isinstance(mask, torch.Tensor):
+             mask = mask.cpu().numpy()
+
+         # Transpose the image to have (H, W, C) format for display if needed
+         if len(img.shape) == 3 and img.shape[0] == 3:
+             img = img.transpose(1, 2, 0)
+
+         # Scale the normalized image back to [0, 255] for proper display
+         img_display = (img * 255).astype(np.uint8)
+
+         plt.figure(figsize=(10, 5))
+
+         # Display normalized image
+         plt.subplot(1, 2, 1)
+         plt.imshow(img_display)
+         plt.title("Normalized Image")
+         plt.axis("off")
+
+         r_cmap = generate_mask_random_cmap(mask)
+
+         # Display predicted mask
+         plt.subplot(1, 2, 2)
+         plt.imshow(mask, cmap=r_cmap)
+         plt.title("Predicted Mask")
+         plt.axis("off")
+
+         plt.tight_layout()
+         plt.show()
+
+     def predict_batch(self, imgs):
+         """
+         Predict masks for a batch of images.
+
+         :param imgs: List of input images as NumPy arrays (each in (H, W, C) format).
+         :return: List of predicted masks as NumPy arrays.
+         """
+         processed_imgs = []
+
+         # Preprocess and normalize each image
+         for img in imgs:
+             if self.normalize:
+                 # Normalize the image using the specified quantiles
+                 img_normalized = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])
+             else:
+                 img_normalized = img
+
+             # Convert image to tensor and send to device
+             img_tensor = torch.tensor(img_normalized.astype(np.float32).transpose(2, 0, 1)).to(self.device) # (C, H, W)
+             processed_imgs.append(img_tensor)
+
+         # Stack all processed images into a batch tensor
+         batch_tensor = torch.stack(processed_imgs)
+
+         # Run inference to get predicted masks
+         pred_masks = self.predictor._inference(batch_tensor)
+
+         # Ensure pred_masks is always treated as a batch
+         if len(pred_masks.shape) == 3: # If single image, add batch dimension
+             pred_masks = pred_masks.unsqueeze(0)
+
+         # Convert predictions to NumPy arrays and post-process each mask
+         predicted_masks = []
+         for pred_mask in pred_masks:
+             pred_mask_np = pred_mask.cpu().numpy()
+
+             # Extract dP and cellprob from pred_mask
+             dP = pred_mask_np[:2] # First two channels as dP (displacement field)
+             cellprob = pred_mask_np[2] # Third channel as cell probability
+
+             # Concatenate dP and cellprob along axis 0 to pass a single array
+             combined_pred_mask = np.concatenate([dP, np.expand_dims(cellprob, axis=0)], axis=0)
+
+             # Post-process the predicted mask
+             mask = self.predictor._post_process(combined_pred_mask)
+
+             # Append the processed mask to the list
+             predicted_masks.append(mask.astype(np.uint16))
+
+         return predicted_masks
+
+     def run_test(self):
+         """
+         Run the model on test images if the test flag is True.
+         """
+         # List of input images
+         imgs = []
+         img_names = []
+
+         for img_file in os.listdir(self.input_path):
+             img_path = os.path.join(self.input_path, img_file)
+             img = io.imread(img_path)
+
+             # Check if the image is grayscale (2D) or RGB (3D), and convert grayscale to RGB
+             if len(img.shape) == 2: # Grayscale image (H, W)
+                 img = np.repeat(img[:, :, np.newaxis], 3, axis=2) # Convert grayscale to RGB
+
+             # Normalize the image if the normalize flag is True
+             if self.normalize:
+                 img_normalized = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])
+             else:
+                 img_normalized = img
+
+             # Convert image to tensor and send directly to device
+             img_tensor = torch.tensor(img_normalized.astype(np.float32).transpose(2, 0, 1)).to(self.device) # (C, H, W)
+
+             imgs.append(img_tensor)
+             img_names.append(os.path.splitext(img_file)[0])
+
+         # Stack all images into a batch (ensure it's always treated as a batch)
+         batch_tensor = torch.stack(imgs)
+
+         # Predict using the predictor (or ensemble predictor)
+         pred_masks = self.predictor._inference(batch_tensor)
+
+         # Ensure pred_masks is always treated as a batch
+         if len(pred_masks.shape) == 3: # If single image, add batch dimension
+             pred_masks = pred_masks.unsqueeze(0)
+
+         # Convert predictions to NumPy arrays and post-process each mask
+         for i, pred_mask in enumerate(pred_masks):
+             # Ensure the dimensions of pred_mask remain consistent
+             pred_mask_np = pred_mask.cpu().numpy()
+
+             # Extract dP and cellprob from pred_mask
+             dP = pred_mask_np[:2] # First two channels as dP (displacement field)
+             cellprob = pred_mask_np[2] # Third channel as cell probability
+
+             # Concatenate dP and cellprob along axis 0 to pass a single array
+             combined_pred_mask = np.concatenate([dP, np.expand_dims(cellprob, axis=0)], axis=0)
+
+             # Post-process the predicted mask
+             mask = self.predictor._post_process(combined_pred_mask)
+
+             # Convert the mask to 16-bit format (ensure values fit into 16-bit range)
+             mask_to_save = mask.astype(np.uint16)
+
+             # Save the post-processed mask as a .tif file using cv2
+             mask_output_path = os.path.join(self.output_path, f"{img_names[i]}_mask.tiff")
+             cv2.imwrite(mask_output_path, mask_to_save)
+
+             print(f"Predicted mask saved at: {mask_output_path}")
+
+             self.display_image_and_mask(imgs[i].cpu().numpy(), mask)
+
+         print(f"Test predictions saved in {self.output_path}")
+
+     def preprocess_image(self, img):
+         """
+         Preprocess input image (numpy array) for compatibility with the model.
+         """
+         if isinstance(img, np.ndarray): # Check if the input is a numpy array
+             if len(img.shape) == 2: # Grayscale image (H, W)
+                 img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
+
+             elif img.shape[2] == 1: # Single channel grayscale (H, W, 1)
+                 img = np.repeat(img, 3, axis=2) # Convert to 3-channel RGB
+
+             img_tensor = torch.tensor(img.astype(np.float32).transpose(2, 0, 1)) # Change shape to (C, H, W)
+         else:
+             img_tensor = img # If it's already a tensor, assume it's in (C, H, W) format
+
+         return img_tensor.float()
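Taken together, the new `spacr/mediar.py` wraps the vendored MEDIAR code behind a single `MEDIARPredictor` class. Below is a minimal usage sketch, assuming the wheel is installed and a local image file exists; the file name is hypothetical, while the constructor arguments and the `predict_batch` contract come from the code above:

```python
import numpy as np
import skimage.io as io
from spacr.mediar import MEDIARPredictor

# Instantiating downloads the phase1/phase2 weights via gdown on first use.
predictor = MEDIARPredictor(model="ensemble",   # or "model1" / "model2"
                            use_tta=False,
                            normalize=True,     # percentile normalization
                            quantiles=[0.0, 99.5])

img = io.imread("example.tif")                  # hypothetical input image
if img.ndim == 2:                               # grayscale -> (H, W, 3)
    img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

masks = predictor.predict_batch([img])          # list of uint16 label masks
print(masks[0].shape, masks[0].max())           # mask size, object count
```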
spacr/plot.py CHANGED
@@ -123,7 +123,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
 
     fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
 
-    return
+    return fig
 
 def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
     """
@@ -274,6 +274,7 @@ def _generate_mask_random_cmap(mask):
     return random_cmap
 
 def _get_colours_merged(outline_color):
+
     """
     Get the merged outline colors based on the specified outline color format.
 
@@ -283,6 +284,7 @@ def _get_colours_merged(outline_color):
     Returns:
         list: A list of merged outline colors based on the specified format.
     """
+
     if outline_color == 'rgb':
         outline_colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rgb
     elif outline_color == 'bgr':
@@ -296,6 +298,7 @@ def _get_colours_merged(outline_color):
     return outline_colors
 
 def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, threshold=1000, extensions=['.npy', '.tif', '.tiff', '.png'], overlay=False, max_nr=None, randomize=True):
+
     """
     Plot images and arrays from the given folders.
 
@@ -1055,7 +1058,7 @@ def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figu
     plt.show()
 
     columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
-    columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']
+    #columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']
 
     width = figuresize*2
     columns_per_row = math.ceil(len(columns) / 2)
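The only behavioral change in `spacr/plot.py` is that `plot_image_mask_overlay` now returns the figure instead of `None`, so callers can keep or save it. A hedged sketch of what this enables; the hunk header above truncates the parameter list at `patho`, so the final keyword below is an assumption, and all argument values are placeholders:

```python
from spacr.plot import plot_image_mask_overlay

# All values are illustrative; `pathogen_channel` completes the truncated
# signature above and is an assumption, not confirmed by this diff.
fig = plot_image_mask_overlay("merged_image.npy", channels=[0, 1, 2],
                              cell_channel=0, nucleus_channel=1,
                              pathogen_channel=2)
fig.savefig("overlay.png")  # possible now that the figure is returned
```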
spacr/resources/MEDIAR/.git ADDED
@@ -0,0 +1 @@
+ gitdir: ../../../.git/modules/spacr/resources/MEDIAR
spacr/resources/MEDIAR/.gitignore ADDED
@@ -0,0 +1,18 @@
+ config/
+ *.log
+ *.ipynb
+ *.ipynb_checkpoints/
+ __pycache__/
+ results/
+ weights/
+ wandb/
+ data/
+ submissions/
+ /.vscode
+ *.npy
+ *.pth
+ *.sh
+ *.json
+ *.out
+ *.zip
+ *.tiff
spacr/resources/MEDIAR/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 opcrisis
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
spacr/resources/MEDIAR/README.md ADDED
@@ -0,0 +1,189 @@
+
+ # **MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy**
+ ![1-2](https://user-images.githubusercontent.com/12638561/207771867-0b1414f2-cf48-4747-9cda-3304e6d86bfd.png)
+
+
+ This repository provides an official implementation of [MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy](https://arxiv.org/abs/2212.03465), which won ***1st place*** in the [NeurIPS-2022 Cell Segmentation Challenge](https://neurips22-cellseg.grand-challenge.org/).
+
+ To access and try MEDIAR directly, please see the links below.
+ - <a href="https://colab.research.google.com/drive/1iFnGu6A_p-5s_eATjNtfjb9-MR5L3pLB?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
+ - [Huggingface Space](https://huggingface.co/spaces/ghlee94/MEDIAR?logs=build)
+ - [Napari Plugin](https://github.com/joonkeekim/mediar-napari)
+ - [Docker Image](https://hub.docker.com/repository/docker/joonkeekim/mediar/general)
+
+ # 1. MEDIAR Overview
+ <img src="./image/mediar_framework.png" width="1200"/>
+
+ MEDIAR is a framework for efficient cell instance segmentation of multi-modality microscopy images. The figure above illustrates an overview of our approach. MEDIAR harmonizes data-centric and model-centric approaches as its learning and inference strategies, achieving a **0.9067** mean F1-score on the validation datasets. We provide a brief description of the methods combined in MEDIAR. Please refer to our paper for more information.
+ # 2. Methods
+
+ ## **Data-Centric**
+ - **Cell-Aware Augmentation**: We apply two novel cell-aware augmentations: *cell-wise intensity randomization* (Cell Intensity Diversification) and *cell-wise exclusion of boundary pixels* in the label. The boundary exclusion is adopted only in the pre-training phase.
+
+ - **Two-phase Pretraining and Fine-tuning**: To extract knowledge from large public datasets, we first pretrain our model on public sets, then fine-tune it.
+     - Pretraining: We use 7,242 labeled images from four public datasets for pretraining: OmniPose, CellPose, LiveCell, and DataScienceBowl-2018. MEDIAR takes two different phases for the pretraining, with the MEDIAR-Former model's encoder parameters initialized from ImageNet-1k pretraining.
+
+     - Fine-tuning: We use two different models for the ensemble. The first model is fine-tuned for 200 epochs using the target datasets. The second model is fine-tuned for 25 epochs using both the target and public datasets.
+
+ - **Modality Discovery & Amplified Sampling**: To balance towards the latent modalities in the datasets, we conduct K-means clustering and discover 40 modalities. In the training phase, we over-sample the minor cluster samples.
+
+ - **Cell Memory Replay**: We concatenate a small portion of data from the public dataset to the batch and train with boundary-excluded labels.
+
+ ## **Model-Centric**
+ - **MEDIAR-Former Architecture**: MEDIAR-Former follows the design paradigm of U-Net, but uses SegFormer and MA-Net as the encoder and decoder. The two heads of MEDIAR-Former predict cell probability and gradient flow.
+
+ <img src="./image/mediar_model.PNG" width="1200"/>
+
+ - **Gradient Flow Tracking**: We utilize the gradient flow tracking proposed by [CellPose](https://github.com/MouseLand/cellpose).
+
+ - **Ensemble with Stochastic TTA**: During inference, MEDIAR conducts prediction in a sliding-window manner with an importance map generated by a Gaussian filter. We use the two fine-tuned models from phase1 and phase2 pretraining, and ensemble their outputs by summation. For each output, test-time augmentation is used.
+ # 3. Experiments
+
+ ### **Dataset**
+ - Official Dataset
+     - We are provided with the target dataset from [Weakly Supervised Cell Segmentation in Multi-modality High-Resolution Microscopy Images](https://neurips22-cellseg.grand-challenge.org/). It consists of 1,000 labeled images, 1,712 unlabeled images, and 13 unlabeled whole-slide images from various microscopy types, tissue types, and staining types. The validation set is given with 101 images, including 1 whole-slide image.
+
+ - Public Dataset
+     - [OmniPose](http://www.cellpose.org/dataset_omnipose): contains mixtures of 14 bacterial species. We only use 611 bacterial cell microscopy images and discard 118 worm images.
+     - [CellPose](https://www.cellpose.org/dataset): includes cytoplasm, cellular microscopy, and fluorescent cell images. We used 551 images, discarding 58 non-microscopy images. We convert all images to gray-scale.
+     - [LiveCell](https://github.com/sartorius-research/LIVECell): a large-scale dataset with 5,239 images containing 1,686,352 individual cells annotated by trained crowdsourcers from 8 distinct cell types.
+     - [DataScienceBowl 2018](https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview): 841 images containing 37,333 cells from 22 cell types, 15 image resolutions, and five visually similar groups.
+
+ ### **Testing steps**
+ - **Ensemble Prediction with TTA**: MEDIAR uses sliding-window inference with an overlap of 0.6 between adjacent patches and a Gaussian importance map. To predict different views on the image, MEDIAR uses Test-Time Augmentation (TTA) for the model prediction and ensembles the two models described in **Two-phase Pretraining and Fine-tuning**.
+
+ - **Inference time**: MEDIAR processes most images in less than 1 second, depending on the image size and the number of cells, even with ensemble prediction and TTA. Detailed evaluation-time results are in the paper.
+
+ ### **Preprocessing & Augmentations**
+ | Strategy | Type | Probability |
+ |----------|:-------------|------|
+ | `Clip` | Pre-processing | . |
+ | `Normalization` | Pre-processing | . |
+ | `Scale Intensity` | Pre-processing | . |
+ | `Zoom` | Spatial Augmentation | 0.5 |
+ | `Spatial Crop` | Spatial Augmentation | 1.0 |
+ | `Axis Flip` | Spatial Augmentation | 0.5 |
+ | `Rotation` | Spatial Augmentation | 0.5 |
+ | `Cell-Aware Intensity` | Intensity Augmentation | 0.25 |
+ | `Gaussian Noise` | Intensity Augmentation | 0.25 |
+ | `Contrast Adjustment` | Intensity Augmentation | 0.25 |
+ | `Gaussian Smoothing` | Intensity Augmentation | 0.25 |
+ | `Histogram Shift` | Intensity Augmentation | 0.25 |
+ | `Gaussian Sharpening` | Intensity Augmentation | 0.25 |
+ | `Boundary Exclusion` | Others | . |
+
+
+ | Learning Setups | Pretraining | Fine-tuning |
+ |----------------------------------------------------------------------|---------------------------------------------------------|---------------------------------------------------------|
+ | Initialization (Encoder) | ImageNet-1k pretrained | from Pretraining |
+ | Initialization (Decoder, Head) | He normal initialization | from Pretraining |
+ | Batch size | 9 | 9 |
+ | Total epochs | 80 (60) | 200 (25) |
+ | Optimizer | AdamW | AdamW |
+ | Initial learning rate (lr) | 5e-5 | 2e-5 |
+ | Lr decay schedule | Cosine scheduler (100 interval) | Cosine scheduler (100 interval) |
+ | Loss function | MSE, BCE | MSE, BCE |
+
+ # 4. Results
+ ### **Validation Dataset**
+ - Quantitative Evaluation
+     - Our MEDIAR achieved a **0.9067** validation mean F1-score.
+ - Qualitative Evaluation
+ <img src="./image/mediar_results.png" width="1200"/>
+
+ - Failure Cases
+ <img src="./image/failure_cases.png" width="1200"/>
+
+ ### **Test Dataset**
+ ![F1_osilab](https://user-images.githubusercontent.com/12638561/207772559-2185b79c-8288-4556-a3b4-9bd1d359fceb.png)
+ ![RunningTime_osilab](https://user-images.githubusercontent.com/12638561/207772555-c3b29071-6e03-4985-837a-da7b3dd3b65d.png)
+
+
+ # 5. Reproducing
+
+ ### **Our Environment**
+ | Computing Infrastructure | |
+ |-------------------------|----------------------------------------------------------------------|
+ | System | Ubuntu 18.04.5 LTS |
+ | CPU | AMD EPYC 7543 32-Core Processor CPU@2.26GHz |
+ | RAM | 500GB; 3.125MT/s |
+ | GPU (number and type) | NVIDIA A5000 (24GB) 2ea |
+ | CUDA version | 11.7 |
+ | Programming language | Python 3.9 |
+ | Deep learning framework | PyTorch (v1.12, with torchvision v0.13.1) |
+ | Code dependencies | MONAI (v0.9.0), Segmentation Models (v0.3.0) |
+ | Specific dependencies | None |
+
+ To install requirements:
+
+ ```
+ pip install -r requirements.txt
+ wandb off
+ ```
+
+ ## Dataset
+ - The dataset directories under the root should have the following structure:
+
+ ```
+ Root
+ ├── Datasets
+ │   ├── images (images can have various extensions: .tif, .tiff, .png, .bmp ...)
+ │   │   ├── cell_00001.png
+ │   │   ├── cell_00002.tif
+ │   │   ├── cell_00003.xxx
+ │   │   ├── ...
+ │   └── labels (labels must have the .tiff extension.)
+ │   │   ├── cell_00001_label.tiff
+ │   │   ├── cell_00002_label.tiff
+ │   │   ├── cell_00003_label.tiff
+ │   │   ├── ...
+ └── ...
+ ```
+
+ Before executing the code, run the following command to generate the path-mapping json file:
+
+ ```
+ python ./generate_mapping.py --root=<path_to_data>
+ ```
+
+ ## Training
+
+ To train the model(s) in the paper, run the following command:
+
+ ```
+ python ./main.py --config_path=<path_to_config>
+ ```
+ Configuration files are in `./config/*`. We provide the pretraining, fine-tuning, and prediction configs. You can refer to the configuration options in `./config/mediar_example.json`. We also implemented the official challenge baseline code in our framework. You can run the baseline with `./config/baseline.json`.
+
+ ## Inference
+
+ To conduct prediction on the testing cases, run the following command:
+
+ ```
+ python predict.py --config_path=<path_to_config>
+ ```
+
+ ## Evaluation
+ If you have the labels, run the following command for evaluation:
+
+ ```
+ python ./evaluate.py --pred_path=<path_to_prediction_results> --gt_path=<path_to_ground_truth_labels>
+ ```
+
+ The configuration files for `predict.py` are slightly different. Please refer to the config files in `./config/step3_prediction/*`.
+ ## Trained Models
+
+ You can download the MEDIAR pretrained and fine-tuned models here:
+
+ - [Google Drive Link](https://drive.google.com/drive/folders/1RgMxHIT7WsKNjir3wXSl7BrzlpS05S18?usp=share_link).
+
+ ## Citation of this Work
+ ```
+ @article{lee2022mediar,
+   title={Mediar: Harmony of data-centric and model-centric for multi-modality microscopy},
+   author={Lee, Gihun and Kim, SangMook and Kim, Joonkee and Yun, Se-Young},
+   journal={arXiv preprint arXiv:2212.03465},
+   year={2022}
+ }
+ ```
+
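The sliding-window inference with a Gaussian importance map that the README describes corresponds to MONAI's `sliding_window_inference`, which `spacr/mediar.py` imports above. Here is a minimal sketch of that pattern, assuming two loaded models and an `(N, 3, H, W)` float tensor; the window size (512) and overlap (0.6) echo the `MEDIARPredictor` defaults, while `sw_batch_size=4` and the summation ensemble are illustrative, not copied from MEDIAR's predictor code:

```python
import torch
from monai.inferers import sliding_window_inference

def ensemble_infer(model1, model2, batch):
    """Sum the sliding-window outputs of two models (README: ensemble by summation)."""
    with torch.no_grad():
        out1 = sliding_window_inference(batch, roi_size=512, sw_batch_size=4,
                                        predictor=model1, overlap=0.6,
                                        mode="gaussian")  # Gaussian importance map
        out2 = sliding_window_inference(batch, roi_size=512, sw_batch_size=4,
                                        predictor=model2, overlap=0.6,
                                        mode="gaussian")
    return out1 + out2
```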
spacr/resources/MEDIAR/SetupDict.py ADDED
@@ -0,0 +1,39 @@
+ import torch.optim as optim
+ import torch.optim.lr_scheduler as lr_scheduler
+ import monai
+
+ import core
+ from train_tools import models
+ from train_tools.models import *
+
+ __all__ = ["TRAINER", "OPTIMIZER", "SCHEDULER"]
+
+ TRAINER = {
+     "baseline": core.Baseline.Trainer,
+     "mediar": core.MEDIAR.Trainer,
+ }
+
+ PREDICTOR = {
+     "baseline": core.Baseline.Predictor,
+     "mediar": core.MEDIAR.Predictor,
+     "ensemble_mediar": core.MEDIAR.EnsemblePredictor,
+ }
+
+ MODELS = {
+     "unet": monai.networks.nets.UNet,
+     "unetr": monai.networks.nets.unetr.UNETR,
+     "swinunetr": monai.networks.nets.SwinUNETR,
+     "mediar-former": models.MEDIARFormer,
+ }
+
+ OPTIMIZER = {
+     "sgd": optim.SGD,
+     "adam": optim.Adam,
+     "adamw": optim.AdamW,
+ }
+
+ SCHEDULER = {
+     "step": lr_scheduler.StepLR,
+     "multistep": lr_scheduler.MultiStepLR,
+     "cosine": lr_scheduler.CosineAnnealingLR,
+ }
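`SetupDict.py` is a plain registry module: config strings select trainer, predictor, model, optimizer, and scheduler classes. A self-contained sketch of how such a registry is typically consumed (the config values echo the README's pretraining setup; the lookup code is illustrative, not MEDIAR's actual `main.py`):

```python
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

# Re-created subset of the registries above, so the lookup pattern
# runs without the vendored `core` package.
OPTIMIZER = {"sgd": optim.SGD, "adam": optim.Adam, "adamw": optim.AdamW}
SCHEDULER = {"step": lr_scheduler.StepLR,
             "multistep": lr_scheduler.MultiStepLR,
             "cosine": lr_scheduler.CosineAnnealingLR}

model = nn.Conv2d(3, 3, kernel_size=3)        # placeholder model
cfg = {"optimizer": "adamw", "lr": 5e-5,      # lr from the README's
       "scheduler": "cosine", "t_max": 100}   # pretraining setup

opt = OPTIMIZER[cfg["optimizer"]](model.parameters(), lr=cfg["lr"])
sched = SCHEDULER[cfg["scheduler"]](opt, T_max=cfg["t_max"])
```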