spacr 0.2.81__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. spacr/__init__.py +2 -1
  2. spacr/core.py +106 -11
  3. spacr/gui.py +3 -2
  4. spacr/gui_core.py +8 -4
  5. spacr/gui_utils.py +4 -1
  6. spacr/io.py +1 -1
  7. spacr/measure.py +4 -4
  8. spacr/mediar.py +366 -0
  9. spacr/plot.py +4 -1
  10. spacr/resources/MEDIAR/.git +1 -0
  11. spacr/resources/MEDIAR/.gitignore +18 -0
  12. spacr/resources/MEDIAR/LICENSE +21 -0
  13. spacr/resources/MEDIAR/README.md +189 -0
  14. spacr/resources/MEDIAR/SetupDict.py +39 -0
  15. spacr/resources/MEDIAR/config/baseline.json +60 -0
  16. spacr/resources/MEDIAR/config/mediar_example.json +72 -0
  17. spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
  18. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
  19. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
  20. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
  21. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
  22. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
  23. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
  24. spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
  25. spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
  26. spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
  27. spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
  28. spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
  29. spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
  30. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
  31. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
  32. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
  33. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
  34. spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
  35. spacr/resources/MEDIAR/core/__init__.py +2 -0
  36. spacr/resources/MEDIAR/core/utils.py +40 -0
  37. spacr/resources/MEDIAR/evaluate.py +71 -0
  38. spacr/resources/MEDIAR/generate_mapping.py +121 -0
  39. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  40. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  41. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  42. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  43. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  44. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  45. spacr/resources/MEDIAR/main.py +125 -0
  46. spacr/resources/MEDIAR/predict.py +70 -0
  47. spacr/resources/MEDIAR/requirements.txt +14 -0
  48. spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
  49. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
  50. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
  51. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
  52. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
  53. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
  54. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  55. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
  56. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
  57. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
  58. spacr/resources/MEDIAR/train_tools/measures.py +200 -0
  59. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
  60. spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
  61. spacr/resources/MEDIAR/train_tools/utils.py +70 -0
  62. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  63. spacr/resources/icons/.DS_Store +0 -0
  64. spacr/resources/icons/plaque.png +0 -0
  65. spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
  66. spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
  67. spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
  68. spacr/settings.py +3 -1
  69. spacr/utils.py +10 -10
  70. {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/METADATA +9 -1
  71. {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/RECORD +75 -16
  72. {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/LICENSE +0 -0
  73. {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/WHEEL +0 -0
  74. {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/entry_points.txt +0 -0
  75. {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/top_level.txt +0 -0
spacr/__init__.py CHANGED
@@ -23,9 +23,9 @@ from . import app_measure
23
23
  from . import app_classify
24
24
  from . import app_sequencing
25
25
  from . import app_umap
26
+ from . import mediar
26
27
  from . import logger
27
28
 
28
-
29
29
  __all__ = [
30
30
  "core",
31
31
  "io",
@@ -48,6 +48,7 @@ __all__ = [
48
48
  "app_classify",
49
49
  "app_sequencing",
50
50
  "app_umap",
51
+ "mediar",
51
52
  "logger"
52
53
  ]
53
54
 
spacr/core.py CHANGED
@@ -50,8 +50,6 @@ import random
50
50
  from PIL import Image
51
51
  from torchvision.transforms import ToTensor
52
52
 
53
-
54
-
55
53
  def analyze_plaques(folder):
56
54
  summary_data = []
57
55
  details_data = []
@@ -1674,7 +1672,10 @@ def preprocess_generate_masks(src, settings={}):
1674
1672
  time_ls=[]
1675
1673
  if check_mask_folder(src, 'cell_mask_stack'):
1676
1674
  start = time.time()
1677
- generate_cellpose_masks(mask_src, settings, 'cell')
1675
+ if settings['segmantation_model'] == 'cellpose':
1676
+ generate_cellpose_masks(mask_src, settings, 'cell')
1677
+ elif settings['segmantation_model'] == 'mediar':
1678
+ generate_mediar_masks(mask_src, settings, 'cell')
1678
1679
  stop = time.time()
1679
1680
  duration = (stop - start)
1680
1681
  time_ls.append(duration)
@@ -1685,7 +1686,10 @@ def preprocess_generate_masks(src, settings={}):
1685
1686
  time_ls=[]
1686
1687
  if check_mask_folder(src, 'nucleus_mask_stack'):
1687
1688
  start = time.time()
1688
- generate_cellpose_masks(mask_src, settings, 'nucleus')
1689
+ if settings['segmantation_model'] == 'cellpose':
1690
+ generate_cellpose_masks(mask_src, settings, 'nucleus')
1691
+ elif settings['segmantation_model'] == 'mediar':
1692
+ generate_mediar_masks(mask_src, settings, 'nucleus')
1689
1693
  stop = time.time()
1690
1694
  duration = (stop - start)
1691
1695
  time_ls.append(duration)
@@ -1696,7 +1700,10 @@ def preprocess_generate_masks(src, settings={}):
1696
1700
  time_ls=[]
1697
1701
  if check_mask_folder(src, 'pathogen_mask_stack'):
1698
1702
  start = time.time()
1699
- generate_cellpose_masks(mask_src, settings, 'pathogen')
1703
+ if settings['segmantation_model'] == 'cellpose':
1704
+ generate_cellpose_masks(mask_src, settings, 'pathogen')
1705
+ elif settings['segmantation_model'] == 'mediar':
1706
+ generate_mediar_masks(mask_src, settings, 'pathogen')
1700
1707
  stop = time.time()
1701
1708
  duration = (stop - start)
1702
1709
  time_ls.append(duration)
@@ -1898,7 +1905,7 @@ def all_elements_match(list1, list2):
1898
1905
  # Check if all elements in list1 are in list2
1899
1906
  return all(element in list2 for element in list1)
1900
1907
 
1901
- def prepare_batch_for_cellpose(batch):
1908
+ def prepare_batch_for_segmentation(batch):
1902
1909
  # Ensure the batch is of dtype float32
1903
1910
  if batch.dtype != np.float32:
1904
1911
  batch = batch.astype(np.float32)
@@ -2021,7 +2028,7 @@ def generate_cellpose_masks(src, settings, object_type):
2021
2028
  if batch.size == 0:
2022
2029
  continue
2023
2030
 
2024
- batch = prepare_batch_for_cellpose(batch)
2031
+ batch = prepare_batch_for_segmentation(batch)
2025
2032
 
2026
2033
  if timelapse:
2027
2034
  movie_path = os.path.join(os.path.dirname(src), 'movies')
@@ -2180,7 +2187,7 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
2180
2187
  image_files = all_image_files[i:i+batch_size]
2181
2188
 
2182
2189
  if normalize:
2183
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
2190
+ images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise, target_height, target_width)
2184
2191
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2185
2192
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2186
2193
  else:
@@ -2300,7 +2307,7 @@ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
2300
2307
  from .io import _read_mask
2301
2308
 
2302
2309
  dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
2303
- dirs.sort() # Optional: sort directories if needed
2310
+ dirs.sort()
2304
2311
  conditions = [os.path.basename(d) for d in dirs]
2305
2312
 
2306
2313
  # Get common files in all directories
@@ -2316,7 +2323,7 @@ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
2316
2323
 
2317
2324
  # Filter out None results (from skipped files)
2318
2325
  results = [res for res in results if res is not None]
2319
- #print(results)
2326
+ print(results)
2320
2327
  if verbose:
2321
2328
  for result in results:
2322
2329
  filename = result['filename']
@@ -3102,4 +3109,92 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
3102
3109
  else:
3103
3110
  plt.show()
3104
3111
 
3105
- return
3112
+ return
3113
+
3114
def generate_mediar_masks(src, settings, object_type):
    """
    Generate segmentation masks for every ``.npz`` stack in ``src`` with MEDIARPredictor.

    Each ``.npz`` file is read for a ``data`` image stack and a matching
    ``filenames`` array. Masks are written with ``np.save`` into
    ``<src>/<object_type>_mask_stack`` and object counts are stored in
    ``<parent of src>/measurements/measurements.db``. Images whose mask file
    already exists are skipped.

    :param src: Source folder containing npz files.
    :param settings: Dictionary of settings for generating masks; completed with
        ``set_default_settings_preprocess_generate_masks``. Keys read here:
        'verbose', 'batch_size', 'normalize', 'plot'.
    :param object_type: Type of object to detect (e.g., 'cell', 'nucleus', 'pathogen').
    """
    from .mediar import MEDIARPredictor
    from .io import _create_database, _save_object_counts_to_database
    from .plot import plot_masks
    from .settings import set_default_settings_preprocess_generate_masks

    # Clear CUDA cache and check if CUDA is available
    gc.collect()
    if not torch.cuda.is_available():
        print(f'Torch CUDA is not available, using CPU')

    # Fill in any missing settings with their defaults
    settings = set_default_settings_preprocess_generate_masks(src, settings)

    if settings['verbose']:
        settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
        settings_df['setting_value'] = settings_df['setting_value'].apply(str)
        display(settings_df)

    figuresize = 10
    batch_size = settings['batch_size']

    mediar_predictor = MEDIARPredictor(input_path=None, output_path=None, normalize=settings['normalize'], use_tta=False)

    # Input npz stacks, sorted so the processing order is deterministic
    paths = sorted(os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz'))

    # Initialize a database for saving measurements
    measurements_dir = os.path.join(os.path.dirname(src), 'measurements')
    count_loc = os.path.join(measurements_dir, 'measurements.db')
    os.makedirs(measurements_dir, exist_ok=True)
    _create_database(count_loc)

    for path in paths:
        output_folder = os.path.join(os.path.dirname(path), f'{object_type}_mask_stack')
        os.makedirs(output_folder, exist_ok=True)

        with np.load(path) as data:
            stack = data['data']
            filenames = data['filenames']

        for start in range(0, stack.shape[0], batch_size):
            batch = stack[start: start + batch_size]
            batch_filenames = filenames[start: start + batch_size]

            # Drop images whose mask was already written by an earlier run.
            # np.save appends '.npy' unless the name already ends with it, so check
            # the file it actually produces.
            keep = []
            for j, filename in enumerate(batch_filenames):
                output_path = os.path.join(output_folder, filename)
                saved_path = output_path if output_path.endswith('.npy') else output_path + '.npy'
                if os.path.exists(saved_path):
                    print(f"File {filename} already exists. Skipping...")
                else:
                    keep.append(j)
            if not keep:
                continue
            batch = batch[keep]
            batch_filenames = batch_filenames[keep]

            batch = prepare_batch_for_segmentation(batch)

            predicted_masks = mediar_predictor.predict_batch(batch)

            for filename, mask in zip(batch_filenames, predicted_masks):
                output_filename = os.path.join(output_folder, filename)
                np.save(output_filename, mask.astype(np.uint16))

            if settings['plot']:
                for img, mask in zip(batch, predicted_masks):
                    plot_masks(img, mask, cmap='inferno', figuresize=figuresize)

            _save_object_counts_to_database(predicted_masks, object_type, batch_filenames, count_loc)

        # Release memory after each npz file
        gc.collect()
        torch.cuda.empty_cache()

    print("Mask generation completed.")
spacr/gui.py CHANGED
@@ -32,7 +32,7 @@ class MainApp(tk.Tk):
32
32
  # Set the window size to the dimensions of the monitor where it is located
33
33
  self.geometry(f"{width}x{height}")
34
34
  self.title("SpaCr GUI Collection")
35
- self.configure(bg='#333333') # Set window background to dark gray
35
+ self.configure(bg='#333333')
36
36
 
37
37
  style = ttk.Style()
38
38
  self.color_settings = set_dark_style(style, parent_frame=self)
@@ -55,7 +55,8 @@ class MainApp(tk.Tk):
55
55
  "Cellpose All": (lambda frame: initiate_root(self, 'cellpose_all'), "Run Cellpose on all images."),
56
56
  "Map Barcodes": (lambda frame: initiate_root(self, 'map_barcodes'), "Map barcodes to data."),
57
57
  "Regression": (lambda frame: initiate_root(self, 'regression'), "Perform regression analysis."),
58
- "Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data.")
58
+ "Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data."),
59
+ "Plaque": (lambda frame: initiate_root(self, 'analyze_plaques'), "Analyze plaque data.")
59
60
  }
60
61
 
61
62
  self.selected_app = tk.StringVar()
spacr/gui_core.py CHANGED
@@ -440,6 +440,8 @@ def import_settings(settings_type='mask'):
440
440
  settings = set_default_umap_image_settings(settings={})
441
441
  elif settings_type == 'recruitment':
442
442
  settings = get_analyze_recruitment_default_settings(settings={})
443
+ elif settings_type == 'analyze_plaques':
444
+ settings = {}
443
445
  else:
444
446
  raise ValueError(f"Invalid settings type: {settings_type}")
445
447
 
@@ -493,6 +495,8 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
493
495
  settings = get_perform_regression_default_settings(settings={})
494
496
  elif settings_type == 'recruitment':
495
497
  settings = get_analyze_recruitment_default_settings(settings={})
498
+ elif settings_type == 'analyze_plaques':
499
+ settings = {}
496
500
  else:
497
501
  raise ValueError(f"Invalid settings type: {settings_type}")
498
502
 
@@ -645,7 +649,7 @@ def setup_usage_panel(horizontal_container, btn_col, uppdate_frequency):
645
649
  widgets = [usage_scrollable_frame.scrollable_frame]
646
650
 
647
651
  usage_bars = []
648
- max_elements_per_column = 6
652
+ max_elements_per_column = 5
649
653
  row = 0
650
654
  col = 0
651
655
 
@@ -773,9 +777,9 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
773
777
  initialize_cuda()
774
778
 
775
779
  process_args = (settings_type, settings, q, fig_queue, stop_requested)
776
- if settings_type in ['mask', 'umap', 'measure', 'simulation', 'sequencing',
777
- 'classify', 'cellpose_dataset', 'train_cellpose', 'ml_analyze', 'cellpose_masks', 'cellpose_all', 'map_barcodes',
778
- 'regression', 'recruitment', 'plaques', 'cellpose_compare', 'vision_scores', 'vision_dataset']:
780
+ if settings_type in ['mask', 'umap', 'measure', 'simulation', 'sequencing', 'classify', 'analyze_plaques',
781
+ 'cellpose_dataset', 'train_cellpose', 'ml_analyze', 'cellpose_masks', 'cellpose_all', 'map_barcodes',
782
+ 'regression', 'recruitment', 'cellpose_compare', 'vision_scores', 'vision_dataset']:
779
783
 
780
784
  # Start the process
781
785
  process = Process(target=run_function_gui, args=process_args)
spacr/gui_utils.py CHANGED
@@ -486,7 +486,7 @@ def function_gui_wrapper(function=None, settings={}, q=None, fig_queue=None, imp
486
486
  def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
487
487
 
488
488
  from .gui_utils import process_stdout_stderr
489
- from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, compare_cellpose_masks, analyze_plaques, generate_dataset, apply_model_to_tar
489
+ from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, analyze_plaques, compare_cellpose_masks, generate_dataset, apply_model_to_tar
490
490
  from .io import generate_cellpose_train_test
491
491
  from .measure import measure_crop
492
492
  from .sim import run_multiple_simulations
@@ -532,6 +532,9 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
532
532
  elif settings_type == 'umap':
533
533
  function = generate_image_umap
534
534
  imports = 1
535
+ elif settings_type == 'analyze_plaques':
536
+ function = analyze_plaques
537
+ imports = 1
535
538
  else:
536
539
  raise ValueError(f"Invalid settings type: {settings_type}")
537
540
  try:
spacr/io.py CHANGED
@@ -26,7 +26,7 @@ import atexit
26
26
 
27
27
  from .logger import log_function_call
28
28
 
29
- def _load_images_and_labels(image_files, label_files, circular=False, invert=False, image_extension="*.tif", label_extension="*.tif"):
29
+ def _load_images_and_labels(image_files, label_files, circular=False, invert=False):
30
30
 
31
31
  from .utils import invert_image, apply_mask
32
32
 
spacr/measure.py CHANGED
@@ -998,10 +998,10 @@ def measure_crop(settings):
998
998
  src_fldr = os.path.join(src_fldr, 'merged')
999
999
  print(f"Changed source folder to: {src_fldr}")
1000
1000
 
1001
- if settings['save_measurements']:
1002
- source_folder = os.path.dirname(settings['src'])
1003
- os.makedirs(source_folder+'/measurements', exist_ok=True)
1004
- _create_database(source_folder+'/measurements/measurements.db')
1001
+ #if settings['save_measurements']:
1002
+ #source_folder = os.path.dirname(settings['src'])
1003
+ #os.makedirs(source_folder+'/measurements', exist_ok=True)
1004
+ #_create_database(source_folder+'/measurements/measurements.db')
1005
1005
 
1006
1006
  if settings['cell_mask_dim'] is None:
1007
1007
  settings['include_uninfected'] = True
spacr/mediar.py ADDED
@@ -0,0 +1,366 @@
1
+ import os, sys, gdown, cv2, torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from monai.inferers import sliding_window_inference
5
+ import skimage.io as io
6
+
7
# Path to the bundled MEDIAR source tree (spacr/resources/MEDIAR)
mediar_path = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR')

print('mediar path', mediar_path)

# Temporarily create __init__.py to make MEDIAR a package. Only a file created
# here is removed afterwards, so a pre-existing __init__.py is never deleted.
# Creation is best-effort: on a read-only install it must not make `import spacr`
# fail.
init_file = os.path.join(mediar_path, '__init__.py')
_created_init_file = False
if not os.path.exists(init_file):
    try:
        with open(init_file, 'w'):
            pass
        _created_init_file = True
    except OSError as err:
        print(f'Could not create {init_file}: {err}')

# Add MEDIAR to sys.path so its top-level packages (core, train_tools) import
sys.path.insert(0, mediar_path)

try:
    # Now import the dependencies from MEDIAR
    from core.MEDIAR import Predictor, EnsemblePredictor
    from train_tools.models import MEDIARFormer

    print("Imports successful.")
finally:
    # Remove the marker file only if this module created it above
    if _created_init_file and os.path.exists(init_file):
        os.remove(init_file)
31
+
32
def display_imgs_in_list(lists_of_imgs, cmaps=None):
    """
    Display images from several lists side by side in a grid.

    Row ``j`` shows the ``j``-th image of every list; column ``i`` corresponds to
    ``lists_of_imgs[i]``.

    :param lists_of_imgs: A list of lists, where each inner list contains images.
        All inner lists must have the same length.
    :param cmaps: Optional list of colormaps, one per inner list, applied to 2-D
        (grayscale) images. Defaults to 'gray' for all lists.
    :raises ValueError: If the inner lists differ in length, or if ``cmaps`` does
        not have one entry per list.
    """
    num_lists = len(lists_of_imgs)
    if num_lists == 0:
        return  # nothing to display

    num_images = len(lists_of_imgs[0])

    # Explicit exceptions instead of assert: asserts are stripped under `python -O`.
    if any(len(img_list) != num_images for img_list in lists_of_imgs):
        raise ValueError("All inner lists must have the same number of images")

    if cmaps is None:
        cmaps = ['gray'] * num_lists
    elif len(cmaps) != num_lists:
        raise ValueError("The number of colormaps must match the number of lists")

    if num_images == 0:
        return  # avoid creating a zero-height figure

    plt.figure(figsize=(15, 5 * num_images))

    for j in range(num_images):
        for i, img_list in enumerate(lists_of_imgs):
            img = img_list[j]
            plt.subplot(num_images, num_lists, j * num_lists + i + 1)

            if len(img.shape) == 2:  # Grayscale image
                plt.imshow(img, cmap=cmaps[i])
            elif len(img.shape) == 3 and img.shape[0] == 3:  # channel-first (C, H, W)
                plt.imshow(img.transpose(1, 2, 0))  # imshow wants (H, W, C)
            else:
                plt.imshow(img)

            plt.axis('off')
            plt.title(f'Image {j+1} from list {i+1}')

    plt.tight_layout()
    plt.show()
72
+
73
def get_weights(finetuned_weights=False, weights_dir=None):
    """
    Return local paths to the two MEDIAR checkpoints, downloading any that are missing.

    :param finetuned_weights: If True use the fine-tuned checkpoints
        (from_phase1.pth / from_phase2.pth); otherwise the pretrained ones
        (phase1.pth / phase2.pth).
    :param weights_dir: Directory that holds the checkpoints. Defaults to
        ``resources/MEDIAR_weights`` next to this module.
    :return: Tuple ``(model_path1, model_path2)``.
    :raises FileNotFoundError: If a checkpoint is still missing after the download attempt.
    """
    if weights_dir is None:
        weights_dir = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights')

    # (file name, Google Drive URL, progress message) for model 1 and model 2
    if finetuned_weights:
        specs = [
            ('from_phase1.pth', 'https://drive.google.com/uc?id=1JJ2-QKTCk-G7sp5ddkqcifMxgnyOrXjx', "Downloading finetuned model 1..."),
            ('from_phase2.pth', 'https://drive.google.com/uc?id=168MtudjTMLoq9YGTyoD2Rjl_d3Gy6c_L', "Downloading finetuned model 2..."),
        ]
    else:
        specs = [
            ('phase1.pth', 'https://drive.google.com/uc?id=1v5tYYJDqiwTn_mV0KyX5UEonlViSNx4i', "Downloading model 1..."),
            ('phase2.pth', 'https://drive.google.com/uc?id=1NHDaYvsYz3G0OCqzegT-bkNcly2clPGR', "Downloading model 2..."),
        ]

    paths = []
    for filename, url, message in specs:
        path = os.path.join(weights_dir, filename)
        if not os.path.exists(path):
            # The download cannot be written into a directory that does not exist.
            os.makedirs(weights_dir, exist_ok=True)
            print(message)
            gdown.download(url, path, quiet=False)
            if not os.path.exists(path):
                raise FileNotFoundError(f"Failed to download MEDIAR weights to {path}")
        paths.append(path)

    return paths[0], paths[1]
97
+
98
def normalize_image(image, lower_percentile=0.0, upper_percentile=99.5):
    """
    Rescale image intensities to roughly [0, 1] using percentile bounds.

    Values are clipped to the ``lower_percentile`` / ``upper_percentile`` values
    of the whole array (all channels together), then shifted and scaled.

    :param image: Input image (numpy array).
    :param lower_percentile: Lower percentile (default is 0.0).
    :param upper_percentile: Upper percentile (default is 99.5).
    :return: Normalized image (numpy array).
    """
    low, high = np.percentile(image, [lower_percentile, upper_percentile])

    clipped = np.clip(image, low, high)
    # The small epsilon keeps a constant image from dividing by zero.
    denominator = high - low + 1e-5
    return (clipped - low) / denominator
117
+
118
class MEDIARPredictor:
    """
    Wrapper around the bundled MEDIAR predictors for cell segmentation.

    Loads one or both MEDIARFormer checkpoints returned by ``get_weights`` and
    wraps them in a MEDIAR ``Predictor`` (single model) or ``EnsemblePredictor``
    (two models).

    :param input_path: Input folder handed to the MEDIAR predictor (replaced by the
        bundled example images when ``test`` is True).
    :param output_path: Output folder handed to the MEDIAR predictor (replaced by
        ``resources/MEDIAR/results`` when ``test`` is True).
    :param device: Torch device; defaults to 'cuda:0' when CUDA is available, else 'cpu'.
    :param model: 'model1', 'model2' or 'ensemble'.
    :param roi_size: Not used by this class.
    :param overlap: Not used by this class.
    :param finetuned_weights: Use the fine-tuned checkpoints instead of the pretrained ones.
    :param test: If True, call ``run_test`` at the end of construction.
    :param use_tta: Passed to the MEDIAR predictor as ``algo_params['use_tta']``.
    :param normalize: Percentile-normalize each image before inference.
    :param quantiles: (lower, upper) percentiles for normalization; defaults to [0.0, 99.5].
    :raises ValueError: If ``model`` is not one of the supported choices.
    """

    # Accepted values for the ``model`` argument
    _VALID_MODELS = ("model1", "model2", "ensemble")

    def __init__(self, input_path=None, output_path=None, device=None, model="ensemble", roi_size=512, overlap=0.6, finetuned_weights=False, test=False, use_tta=False, normalize=True, quantiles=None):
        # Validate first so a bad name fails before any weights are downloaded or loaded.
        if model not in self._VALID_MODELS:
            raise ValueError("Invalid model type. Choose from 'model1', 'model2', or 'ensemble'.")

        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.test = test
        self.model = model
        self.normalize = normalize
        # Per-instance copy instead of a shared mutable default argument
        self.quantiles = [0.0, 99.5] if quantiles is None else list(quantiles)

        # Paths to model weights
        self.model1_path, self.model2_path = get_weights(finetuned_weights)

        # Load only the checkpoints the selected mode uses; an unused one stays None.
        uses_model1 = model in ("model1", "ensemble")
        uses_model2 = model in ("model2", "ensemble")
        self.model1 = self.load_model(self.model1_path, device=self.device) if uses_model1 else None
        self.model2 = self.load_model(self.model2_path, device=self.device) if uses_model2 else None

        if self.test:
            # Bundled example images and a results folder inside the package
            package_dir = os.path.dirname(__file__)
            self.input_path = os.path.join(package_dir, 'resources', 'images')
            self.output_path = os.path.join(package_dir, 'resources', 'MEDIAR', 'results')
        else:
            self.input_path = input_path
            self.output_path = output_path

        algo_params = {"use_tta": use_tta}
        if self.model == "ensemble":
            self.predictor = EnsemblePredictor(
                model=self.model1,      # Pass model1 as model
                model_aux=self.model2,  # Pass model2 as model_aux
                device=self.device,
                input_path=self.input_path,
                output_path=self.output_path,
                algo_params=algo_params,
            )
        else:
            self.predictor = Predictor(
                model=self.model1 if self.model == "model1" else self.model2,
                device=self.device,
                input_path=self.input_path,
                output_path=self.output_path,
                algo_params=algo_params,
            )

        if self.test:
            self.run_test()

    def load_model(self, model_path, device):
        """
        Build a MEDIARFormer and load the checkpoint at ``model_path`` onto ``device``.

        :return: The model, moved to ``device`` and set to eval mode.
        """
        model_args = {
            "classes": 3,
            "decoder_channels": [1024, 512, 256, 128, 64],
            "decoder_pab_channels": 256,
            "encoder_name": 'mit_b5',
            "in_channels": 3
        }
        model = MEDIARFormer(**model_args)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        weights = torch.load(model_path, map_location=device)
        # strict=False silently ignores missing/unexpected keys. TODO: confirm the
        # checkpoints match this architecture exactly.
        model.load_state_dict(weights, strict=False)
        model.to(device)
        model.eval()
        return model

    def display_image_and_mask(self, img, mask):
        """
        Display the (normalized) input image and the predicted mask side by side.

        :param img: Image as a NumPy array or tensor, (H, W[, C]) or channel-first (3, H, W).
        :param mask: Label mask as a NumPy array or tensor.
        """
        from .plot import generate_mask_random_cmap

        # Tensors are moved to host memory for plotting
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()

        # imshow expects (H, W, C)
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = img.transpose(1, 2, 0)

        # Values are expected in [0, 1]; clip so un-normalized input does not wrap in uint8.
        img_display = (np.clip(img, 0, 1) * 255).astype(np.uint8)

        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.imshow(img_display)
        plt.title("Normalized Image")
        plt.axis("off")

        r_cmap = generate_mask_random_cmap(mask)

        plt.subplot(1, 2, 2)
        plt.imshow(mask, cmap=r_cmap)
        plt.title("Predicted Mask")
        plt.axis("off")

        plt.tight_layout()
        plt.show()

    def _image_to_tensor(self, img):
        """
        Convert one image into a float32 (C, H, W) tensor on ``self.device``.

        (H, W) and (H, W, 1) inputs are repeated to three channels, and the image
        is percentile-normalized when ``self.normalize`` is set.
        """
        img = np.asarray(img)
        if img.ndim == 2:  # Grayscale (H, W)
            img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
        elif img.shape[2] == 1:  # Single channel (H, W, 1)
            img = np.repeat(img, 3, axis=2)

        if self.normalize:
            img = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])

        return torch.tensor(img.astype(np.float32).transpose(2, 0, 1)).to(self.device)

    def _postprocess_predictions(self, pred_masks):
        """Turn raw network output from ``predictor._inference`` into uint16 label masks, one per image."""
        # A single image may come back without a batch dimension
        if len(pred_masks.shape) == 3:
            pred_masks = pred_masks.unsqueeze(0)

        masks = []
        for pred_mask in pred_masks:
            pred_mask_np = pred_mask.cpu().numpy()
            # First two channels are dP (displacement field), the third is cell probability.
            # Pass a copy so _post_process cannot alter the tensor's shared buffer.
            mask = self.predictor._post_process(pred_mask_np[:3].copy())
            masks.append(mask.astype(np.uint16))
        return masks

    def predict_batch(self, imgs):
        """
        Predict masks for a batch of images.

        :param imgs: Sequence of NumPy images, each (H, W), (H, W, 1) or (H, W, C).
            All images must share the same height and width.
        :return: List of predicted uint16 masks as NumPy arrays (empty for an empty batch).
        """
        if len(imgs) == 0:
            return []

        batch_tensor = torch.stack([self._image_to_tensor(img) for img in imgs])
        pred_masks = self.predictor._inference(batch_tensor)
        return self._postprocess_predictions(pred_masks)

    def run_test(self):
        """
        Segment every image file in ``self.input_path`` and save each mask as
        ``<name>_mask.tiff`` in ``self.output_path``, displaying image and mask.
        """
        image_exts = ('.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp')
        img_files = sorted(f for f in os.listdir(self.input_path) if f.lower().endswith(image_exts))
        if not img_files:
            print(f"No test images found in {self.input_path}")
            return

        imgs = [self._image_to_tensor(io.imread(os.path.join(self.input_path, f))) for f in img_files]
        img_names = [os.path.splitext(f)[0] for f in img_files]

        pred_masks = self.predictor._inference(torch.stack(imgs))
        masks = self._postprocess_predictions(pred_masks)

        # cv2.imwrite fails silently when the target directory is missing
        os.makedirs(self.output_path, exist_ok=True)

        for img_tensor, name, mask in zip(imgs, img_names, masks):
            mask_output_path = os.path.join(self.output_path, f"{name}_mask.tiff")
            if cv2.imwrite(mask_output_path, mask):
                print(f"Predicted mask saved at: {mask_output_path}")
            else:
                print(f"Failed to save mask at: {mask_output_path}")

            self.display_image_and_mask(img_tensor.cpu().numpy(), mask)

        print(f"Test predictions saved in {self.output_path}")

    def preprocess_image(self, img):
        """
        Convert an input image into a float (C, H, W) tensor for the model.

        NumPy (H, W) and (H, W, 1) images are repeated to three channels and
        transposed to channel-first; any other input is assumed to already be a
        (C, H, W) tensor and is only cast to float.
        """
        if isinstance(img, np.ndarray):
            if len(img.shape) == 2:  # Grayscale image (H, W)
                img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

            elif img.shape[2] == 1:  # Single channel grayscale (H, W, 1)
                img = np.repeat(img, 3, axis=2)

            img_tensor = torch.tensor(img.astype(np.float32).transpose(2, 0, 1))
        else:
            img_tensor = img

        return img_tensor.float()