spacr 0.2.81__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -1
- spacr/core.py +106 -11
- spacr/gui.py +3 -2
- spacr/gui_core.py +8 -4
- spacr/gui_utils.py +4 -1
- spacr/io.py +1 -1
- spacr/measure.py +4 -4
- spacr/mediar.py +366 -0
- spacr/plot.py +4 -1
- spacr/resources/MEDIAR/.git +1 -0
- spacr/resources/MEDIAR/.gitignore +18 -0
- spacr/resources/MEDIAR/LICENSE +21 -0
- spacr/resources/MEDIAR/README.md +189 -0
- spacr/resources/MEDIAR/SetupDict.py +39 -0
- spacr/resources/MEDIAR/config/baseline.json +60 -0
- spacr/resources/MEDIAR/config/mediar_example.json +72 -0
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
- spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
- spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
- spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
- spacr/resources/MEDIAR/core/__init__.py +2 -0
- spacr/resources/MEDIAR/core/utils.py +40 -0
- spacr/resources/MEDIAR/evaluate.py +71 -0
- spacr/resources/MEDIAR/generate_mapping.py +121 -0
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +125 -0
- spacr/resources/MEDIAR/predict.py +70 -0
- spacr/resources/MEDIAR/requirements.txt +14 -0
- spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
- spacr/resources/MEDIAR/train_tools/measures.py +200 -0
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
- spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/utils.py +70 -0
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/plaque.png +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
- spacr/settings.py +3 -1
- spacr/utils.py +10 -10
- {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/METADATA +9 -1
- {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/RECORD +75 -16
- {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/LICENSE +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/WHEEL +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.0.dist-info}/top_level.txt +0 -0
spacr/__init__.py
CHANGED
@@ -23,9 +23,9 @@ from . import app_measure
 from . import app_classify
 from . import app_sequencing
 from . import app_umap
+from . import mediar
 from . import logger
 
-
 __all__ = [
     "core",
     "io",
@@ -48,6 +48,7 @@ __all__ = [
     "app_classify",
     "app_sequencing",
     "app_umap",
+    "mediar",
     "logger"
 ]
 
spacr/core.py
CHANGED
@@ -50,8 +50,6 @@ import random
 from PIL import Image
 from torchvision.transforms import ToTensor
 
-
-
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
@@ -1674,7 +1672,10 @@ def preprocess_generate_masks(src, settings={}):
         time_ls=[]
         if check_mask_folder(src, 'cell_mask_stack'):
             start = time.time()
-
+            if settings['segmantation_model'] == 'cellpose':
+                generate_cellpose_masks(mask_src, settings, 'cell')
+            elif settings['segmantation_model'] == 'mediar':
+                generate_mediar_masks(mask_src, settings, 'cell')
             stop = time.time()
             duration = (stop - start)
             time_ls.append(duration)
@@ -1685,7 +1686,10 @@ def preprocess_generate_masks(src, settings={}):
         time_ls=[]
         if check_mask_folder(src, 'nucleus_mask_stack'):
             start = time.time()
-
+            if settings['segmantation_model'] == 'cellpose':
+                generate_cellpose_masks(mask_src, settings, 'nucleus')
+            elif settings['segmantation_model'] == 'mediar':
+                generate_mediar_masks(mask_src, settings, 'nucleus')
             stop = time.time()
             duration = (stop - start)
             time_ls.append(duration)
@@ -1696,7 +1700,10 @@ def preprocess_generate_masks(src, settings={}):
         time_ls=[]
         if check_mask_folder(src, 'pathogen_mask_stack'):
             start = time.time()
-
+            if settings['segmantation_model'] == 'cellpose':
+                generate_cellpose_masks(mask_src, settings, 'pathogen')
+            elif settings['segmantation_model'] == 'mediar':
+                generate_mediar_masks(mask_src, settings, 'pathogen')
             stop = time.time()
             duration = (stop - start)
             time_ls.append(duration)
@@ -1898,7 +1905,7 @@ def all_elements_match(list1, list2):
     # Check if all elements in list1 are in list2
     return all(element in list2 for element in list1)
 
-def
+def prepare_batch_for_segmentation(batch):
     # Ensure the batch is of dtype float32
     if batch.dtype != np.float32:
         batch = batch.astype(np.float32)
@@ -2021,7 +2028,7 @@ def generate_cellpose_masks(src, settings, object_type):
         if batch.size == 0:
             continue
 
-        batch =
+        batch = prepare_batch_for_segmentation(batch)
 
         if timelapse:
             movie_path = os.path.join(os.path.dirname(src), 'movies')
@@ -2180,7 +2187,7 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
         image_files = all_image_files[i:i+batch_size]
 
         if normalize:
-            images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
+            images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise, target_height, target_width)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
             orig_dims = [(image.shape[0], image.shape[1]) for image in images]
         else:
@@ -2300,7 +2307,7 @@ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
     from .io import _read_mask
 
     dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
-    dirs.sort()
+    dirs.sort()
     conditions = [os.path.basename(d) for d in dirs]
 
     # Get common files in all directories
@@ -2316,7 +2323,7 @@ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
 
     # Filter out None results (from skipped files)
     results = [res for res in results if res is not None]
-
+    print(results)
     if verbose:
         for result in results:
             filename = result['filename']
@@ -3102,4 +3109,92 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
     else:
         plt.show()
 
-    return
+    return
+
+def generate_mediar_masks(src, settings, object_type):
+    """
+    Generates masks using the MEDIARPredictor.
+
+    :param src: Source folder containing images or npz files.
+    :param settings: Dictionary of settings for generating masks.
+    :param object_type: Type of object to detect (e.g., 'cell', 'nucleus', etc.).
+    """
+    from .mediar import MEDIARPredictor
+    from .io import _create_database, _save_object_counts_to_database
+    from .plot import plot_masks
+    from .settings import set_default_settings_preprocess_generate_masks, _get_object_settings
+
+    # Clear CUDA cache and check if CUDA is available
+    gc.collect()
+    if not torch.cuda.is_available():
+        print(f'Torch CUDA is not available, using CPU')
+
+    # Preprocess settings
+    settings = set_default_settings_preprocess_generate_masks(src, settings)
+
+    if settings['verbose']:
+        settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+        settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+        display(settings_df)
+
+    figuresize = 10
+    timelapse = settings['timelapse']
+    batch_size = settings['batch_size']
+
+    # Get object settings and initialize MEDIARPredictor
+    mediar_predictor = MEDIARPredictor(input_path=None, output_path=None, normalize=settings['normalize'], use_tta=False)
+
+    # Paths to input npz files
+    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
+
+    # Initialize a database for saving measurements
+    count_loc = os.path.join(os.path.dirname(src), 'measurements', 'measurements.db')
+    os.makedirs(os.path.dirname(src) + '/measurements', exist_ok=True)
+    _create_database(count_loc)
+
+    for file_index, path in enumerate(paths):
+        name = os.path.basename(path)
+        name, ext = os.path.splitext(name)
+        output_folder = os.path.join(os.path.dirname(path), f'{object_type}_mask_stack')
+        os.makedirs(output_folder, exist_ok=True)
+
+        with np.load(path) as data:
+            stack = data['data']
+            filenames = data['filenames']
+
+        for i, filename in enumerate(filenames):
+            output_path = os.path.join(output_folder, filename)
+            if os.path.exists(output_path):
+                print(f"File {filename} already exists. Skipping...")
+                continue
+
+        # Process each batch of images in the stack
+        for i in range(0, stack.shape[0], batch_size):
+            batch = stack[i: i + batch_size]
+            batch_filenames = filenames[i: i + batch_size]
+
+            # Prepare batch for MEDIARPredictor (optional)
+            batch = prepare_batch_for_segmentation(batch)
+
+            # Predict masks using MEDIARPredictor
+            predicted_masks = mediar_predictor.predict_batch(batch)
+
+            # Save predicted masks
+            for j, mask in enumerate(predicted_masks):
+                output_filename = os.path.join(output_folder, batch_filenames[j])
+                mask = mask.astype(np.uint16)
+                np.save(output_filename, mask)
+
+            # Optional: Plot the masks
+            if settings['plot']:
+                for idx, mask in enumerate(predicted_masks):
+                    plot_masks(batch[idx], mask, cmap='inferno', figuresize=figuresize)
+
+            # Save object counts to database
+            _save_object_counts_to_database(predicted_masks, object_type, batch_filenames, count_loc)
+
+        # Clear CUDA cache after each file
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    print("Mask generation completed.")
spacr/gui.py
CHANGED
@@ -32,7 +32,7 @@ class MainApp(tk.Tk):
         # Set the window size to the dimensions of the monitor where it is located
         self.geometry(f"{width}x{height}")
         self.title("SpaCr GUI Collection")
-        self.configure(bg='#333333')
+        self.configure(bg='#333333')
 
         style = ttk.Style()
         self.color_settings = set_dark_style(style, parent_frame=self)
@@ -55,7 +55,8 @@ class MainApp(tk.Tk):
             "Cellpose All": (lambda frame: initiate_root(self, 'cellpose_all'), "Run Cellpose on all images."),
             "Map Barcodes": (lambda frame: initiate_root(self, 'map_barcodes'), "Map barcodes to data."),
             "Regression": (lambda frame: initiate_root(self, 'regression'), "Perform regression analysis."),
-            "Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data.")
+            "Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data."),
+            "Plaque": (lambda frame: initiate_root(self, 'analyze_plaques'), "Analyze plaque data.")
         }
 
         self.selected_app = tk.StringVar()
spacr/gui_core.py
CHANGED
@@ -440,6 +440,8 @@ def import_settings(settings_type='mask'):
         settings = set_default_umap_image_settings(settings={})
     elif settings_type == 'recruitment':
         settings = get_analyze_recruitment_default_settings(settings={})
+    elif settings_type == 'analyze_plaques':
+        settings = {}
     else:
         raise ValueError(f"Invalid settings type: {settings_type}")
 
@@ -493,6 +495,8 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
         settings = get_perform_regression_default_settings(settings={})
     elif settings_type == 'recruitment':
         settings = get_analyze_recruitment_default_settings(settings={})
+    elif settings_type == 'analyze_plaques':
+        settings = {}
     else:
         raise ValueError(f"Invalid settings type: {settings_type}")
 
@@ -645,7 +649,7 @@ def setup_usage_panel(horizontal_container, btn_col, uppdate_frequency):
     widgets = [usage_scrollable_frame.scrollable_frame]
 
     usage_bars = []
-    max_elements_per_column =
+    max_elements_per_column = 5
     row = 0
     col = 0
 
@@ -773,9 +777,9 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
     initialize_cuda()
 
     process_args = (settings_type, settings, q, fig_queue, stop_requested)
-    if settings_type in ['mask', 'umap', 'measure', 'simulation', 'sequencing',
-                         '
-                         'regression', 'recruitment', '
+    if settings_type in ['mask', 'umap', 'measure', 'simulation', 'sequencing', 'classify', 'analyze_plaques',
+                         'cellpose_dataset', 'train_cellpose', 'ml_analyze', 'cellpose_masks', 'cellpose_all', 'map_barcodes',
+                         'regression', 'recruitment', 'cellpose_compare', 'vision_scores', 'vision_dataset']:
 
         # Start the process
         process = Process(target=run_function_gui, args=process_args)
spacr/gui_utils.py
CHANGED
@@ -486,7 +486,7 @@ def function_gui_wrapper(function=None, settings={}, q=None, fig_queue=None, imp
 def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
 
     from .gui_utils import process_stdout_stderr
-    from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose,
+    from .core import generate_image_umap, preprocess_generate_masks, generate_ml_scores, identify_masks_finetune, check_cellpose_models, analyze_recruitment, train_cellpose, analyze_plaques, compare_cellpose_masks, generate_dataset, apply_model_to_tar
     from .io import generate_cellpose_train_test
     from .measure import measure_crop
     from .sim import run_multiple_simulations
@@ -532,6 +532,9 @@ def run_function_gui(settings_type, settings, q, fig_queue, stop_requested):
     elif settings_type == 'umap':
         function = generate_image_umap
         imports = 1
+    elif settings_type == 'analyze_plaques':
+        function = analyze_plaques
+        imports = 1
     else:
         raise ValueError(f"Invalid settings type: {settings_type}")
     try:
spacr/io.py
CHANGED
@@ -26,7 +26,7 @@ import atexit
 
 from .logger import log_function_call
 
-def _load_images_and_labels(image_files, label_files, circular=False, invert=False
+def _load_images_and_labels(image_files, label_files, circular=False, invert=False):
 
     from .utils import invert_image, apply_mask
 
spacr/measure.py
CHANGED
@@ -998,10 +998,10 @@ def measure_crop(settings):
         src_fldr = os.path.join(src_fldr, 'merged')
         print(f"Changed source folder to: {src_fldr}")
 
-    if settings['save_measurements']:
-        source_folder = os.path.dirname(settings['src'])
-        os.makedirs(source_folder+'/measurements', exist_ok=True)
-        _create_database(source_folder+'/measurements/measurements.db')
+    #if settings['save_measurements']:
+        #source_folder = os.path.dirname(settings['src'])
+        #os.makedirs(source_folder+'/measurements', exist_ok=True)
+        #_create_database(source_folder+'/measurements/measurements.db')
 
     if settings['cell_mask_dim'] is None:
         settings['include_uninfected'] = True
spacr/mediar.py
ADDED
@@ -0,0 +1,366 @@
+import os, sys, gdown, cv2, torch
+import numpy as np
+import matplotlib.pyplot as plt
+from monai.inferers import sliding_window_inference
+import skimage.io as io
+
+# Path to the MEDIAR directory
+mediar_path = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR')
+
+print('mediar path', mediar_path)
+
+# Temporarily create __init__.py to make MEDIAR a package
+init_file = os.path.join(mediar_path, '__init__.py')
+if not os.path.exists(init_file):
+    with open(init_file, 'w'):  # Create the __init__.py file
+        pass
+
+# Add MEDIAR to sys.path
+sys.path.insert(0, mediar_path)
+
+try:
+    # Now import the dependencies from MEDIAR
+    from core.MEDIAR import Predictor, EnsemblePredictor
+    from train_tools.models import MEDIARFormer
+
+    print("Imports successful.")
+finally:
+    # Remove the temporary __init__.py file after the import
+    if os.path.exists(init_file):
+        os.remove(init_file)  # Remove the __init__.py file
+
+def display_imgs_in_list(lists_of_imgs, cmaps=None):
+    """
+    Displays images from multiple lists side by side.
+    Each row will display one image from each list (lists_of_imgs[i][j] is the j-th image in the i-th list).
+
+    :param lists_of_imgs: A list of lists, where each inner list contains images.
+    :param cmaps: List of colormaps to use for each list (optional). If not provided, defaults to 'gray' for all lists.
+    """
+    num_lists = len(lists_of_imgs)
+    num_images = len(lists_of_imgs[0])
+
+    # Ensure that all lists have the same number of images
+    for img_list in lists_of_imgs:
+        assert len(img_list) == num_images, "All inner lists must have the same number of images"
+
+    # Use 'gray' as the default colormap if cmaps are not provided
+    if cmaps is None:
+        cmaps = ['gray'] * num_lists
+    else:
+        assert len(cmaps) == num_lists, "The number of colormaps must match the number of lists"
+
+    plt.figure(figsize=(15, 5 * num_images))
+
+    for j in range(num_images):
+        for i, img_list in enumerate(lists_of_imgs):
+            img = img_list[j]
+            plt.subplot(num_images, num_lists, j * num_lists + i + 1)
+
+            if len(img.shape) == 2:  # Grayscale image
+                plt.imshow(img, cmap=cmaps[i])
+            elif len(img.shape) == 3 and img.shape[0] == 3:  # 3-channel image (C, H, W)
+                plt.imshow(img.transpose(1, 2, 0))  # Change shape to (H, W, C) for displaying
+            else:
+                plt.imshow(img)
+
+            plt.axis('off')
+            plt.title(f'Image {j+1} from list {i+1}')
+
+    plt.tight_layout()
+    plt.show()
+
+def get_weights(finetuned_weights=False):
+    if finetuned_weights:
+        model_path1 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'from_phase1.pth')
+        if not os.path.exists(model_path1):
+            print("Downloading finetuned model 1...")
+            gdown.download('https://drive.google.com/uc?id=1JJ2-QKTCk-G7sp5ddkqcifMxgnyOrXjx', model_path1, quiet=False)
+    else:
+        model_path1 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'phase1.pth')
+        if not os.path.exists(model_path1):
+            print("Downloading model 1...")
+            gdown.download('https://drive.google.com/uc?id=1v5tYYJDqiwTn_mV0KyX5UEonlViSNx4i', model_path1, quiet=False)
+
+    if finetuned_weights:
+        model_path2 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'from_phase2.pth')
+        if not os.path.exists(model_path2):
+            print("Downloading finetuned model 2...")
+            gdown.download('https://drive.google.com/uc?id=168MtudjTMLoq9YGTyoD2Rjl_d3Gy6c_L', model_path2, quiet=False)
+    else:
+        model_path2 = os.path.join(os.path.dirname(__file__), 'resources', 'MEDIAR_weights', 'phase2.pth')
+        if not os.path.exists(model_path2):
+            print("Downloading model 2...")
+            gdown.download('https://drive.google.com/uc?id=1NHDaYvsYz3G0OCqzegT-bkNcly2clPGR', model_path2, quiet=False)
+
+    return model_path1, model_path2
+
+def normalize_image(image, lower_percentile=0.0, upper_percentile=99.5):
+    """
+    Normalize an image based on the 0.0 and 99.5 percentiles.
+
+    :param image: Input image (numpy array).
+    :param lower_percentile: Lower percentile (default is 0.0).
+    :param upper_percentile: Upper percentile (default is 99.5).
+    :return: Normalized image (numpy array).
+    """
+    lower_bound = np.percentile(image, lower_percentile)
+    upper_bound = np.percentile(image, upper_percentile)
+
+    # Clip image values to the calculated percentiles
+    image = np.clip(image, lower_bound, upper_bound)
+
+    # Normalize to [0, 1]
+    image = (image - lower_bound) / (upper_bound - lower_bound + 1e-5)  # Add small epsilon to avoid division by zero
+
+    return image
+
+class MEDIARPredictor:
+    def __init__(self, input_path=None, output_path=None, device=None, model="ensemble", roi_size=512, overlap=0.6, finetuned_weights=False, test=False, use_tta=False, normalize=True, quantiles=[0.0, 99.5]):
+        if device is None:
+            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        self.device = device
+        self.test = test
+        self.model = model
+        self.normalize = normalize
+        self.quantiles = quantiles
+
+        # Paths to model weights
+        self.model1_path, self.model2_path = get_weights(finetuned_weights)
+
+        # Load main models
+        self.model1 = self.load_model(self.model1_path, device=self.device)
+        self.model2 = self.load_model(self.model2_path, device=self.device) if model == "ensemble" or model == "model2" else None
+        if self.test:
+            # Define input and output paths for running test
+            self.input_path = os.path.join(os.path.dirname(__file__), 'resources/images')
+            self.output_path = os.path.join(os.path.dirname(__file__), 'resources/MEDIAR/results')
+        else:
+            self.input_path = input_path
+            self.output_path = output_path
+
+        # If using a single model
+        if self.model == "model1":
+            self.predictor = Predictor(
+                model=self.model1,
+                device=self.device,
+                input_path=self.input_path,
+                output_path=self.output_path,
+                algo_params={"use_tta": use_tta}
+            )
+
+        # If using a single model
+        if self.model == "model2":
+            self.predictor = Predictor(
+                model=self.model2,
+                device=self.device,
+                input_path=self.input_path,
+                output_path=self.output_path,
+                algo_params={"use_tta": use_tta}
+            )
+
+        # If using two models
+        elif self.model == "ensemble":
+            self.predictor = EnsemblePredictor(
+                model=self.model1,  # Pass model1 as model
+                model_aux=self.model2,  # Pass model2 as model_aux
+                device=self.device,
+                input_path=self.input_path,
+                output_path=self.output_path,
+                algo_params={"use_tta": use_tta}
+            )
+
+        if self.test:
+            self.run_test()
+
+        if not self.model in ["model1", "model2", "ensemble"]:
+            raise ValueError("Invalid model type. Choose from 'model1', 'model2', or 'ensemble'.")
+
+    def load_model(self, model_path, device):
+        model_args = {
+            "classes": 3,
+            "decoder_channels": [1024, 512, 256, 128, 64],
+            "decoder_pab_channels": 256,
+            "encoder_name": 'mit_b5',
+            "in_channels": 3
+        }
+        model = MEDIARFormer(**model_args)
+        weights = torch.load(model_path, map_location=device)
+        model.load_state_dict(weights, strict=False)
+        model.to(device)
+        model.eval()
+        return model
+
+    def display_image_and_mask(self, img, mask):
+
+        from .plot import generate_mask_random_cmap
+        """
+        Displays the normalized input image and the predicted mask side by side.
+        """
+        # If img is a tensor, convert it to NumPy for display
+        if isinstance(img, torch.Tensor):
+            img = img.cpu().numpy()
+
+        # If mask is a tensor, convert it to NumPy for display
+        if isinstance(mask, torch.Tensor):
+            mask = mask.cpu().numpy()
+
+        # Transpose the image to have (H, W, C) format for display if needed
+        if len(img.shape) == 3 and img.shape[0] == 3:
+            img = img.transpose(1, 2, 0)
+
+        # Scale the normalized image back to [0, 255] for proper display
+        img_display = (img * 255).astype(np.uint8)
+
+        plt.figure(figsize=(10, 5))
+
+        # Display normalized image
+        plt.subplot(1, 2, 1)
+        plt.imshow(img_display)
+        plt.title("Normalized Image")
+        plt.axis("off")
+
+        r_cmap = generate_mask_random_cmap(mask)
+
+        # Display predicted mask
+        plt.subplot(1, 2, 2)
+        plt.imshow(mask, cmap=r_cmap)
+        plt.title("Predicted Mask")
+        plt.axis("off")
+
+        plt.tight_layout()
+        plt.show()
+
+    def predict_batch(self, imgs):
+        """
+        Predict masks for a batch of images.
+
+        :param imgs: List of input images as NumPy arrays (each in (H, W, C) format).
+        :return: List of predicted masks as NumPy arrays.
+        """
+        processed_imgs = []
+
+        # Preprocess and normalize each image
+        for img in imgs:
+            if self.normalize:
+                # Normalize the image using the specified quantiles
+                img_normalized = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])
+            else:
+                img_normalized = img
+
+            # Convert image to tensor and send to device
+            img_tensor = torch.tensor(img_normalized.astype(np.float32).transpose(2, 0, 1)).to(self.device)  # (C, H, W)
+            processed_imgs.append(img_tensor)
+
+        # Stack all processed images into a batch tensor
+        batch_tensor = torch.stack(processed_imgs)
+
+        # Run inference to get predicted masks
+        pred_masks = self.predictor._inference(batch_tensor)
+
+        # Ensure pred_masks is always treated as a batch
+        if len(pred_masks.shape) == 3:  # If single image, add batch dimension
+            pred_masks = pred_masks.unsqueeze(0)
+
+        # Convert predictions to NumPy arrays and post-process each mask
+        predicted_masks = []
+        for pred_mask in pred_masks:
+            pred_mask_np = pred_mask.cpu().numpy()
+
+            # Extract dP and cellprob from pred_mask
+            dP = pred_mask_np[:2]  # First two channels as dP (displacement field)
+            cellprob = pred_mask_np[2]  # Third channel as cell probability
+
+            # Concatenate dP and cellprob along axis 0 to pass a single array
+            combined_pred_mask = np.concatenate([dP, np.expand_dims(cellprob, axis=0)], axis=0)
+
+            # Post-process the predicted mask
+            mask = self.predictor._post_process(combined_pred_mask)
+
+            # Append the processed mask to the list
+            predicted_masks.append(mask.astype(np.uint16))
+
+        return predicted_masks
+
+    def run_test(self):
+        """
+        Run the model on test images if the test flag is True.
+        """
+        # List of input images
+        imgs = []
+        img_names = []
+
+        for img_file in os.listdir(self.input_path):
+            img_path = os.path.join(self.input_path, img_file)
+            img = io.imread(img_path)
+
+            # Check if the image is grayscale (2D) or RGB (3D), and convert grayscale to RGB
+            if len(img.shape) == 2:  # Grayscale image (H, W)
+                img = np.repeat(img[:, :, np.newaxis], 3, axis=2)  # Convert grayscale to RGB
+
+            # Normalize the image if the normalize flag is True
+            if self.normalize:
+                img_normalized = normalize_image(img, lower_percentile=self.quantiles[0], upper_percentile=self.quantiles[1])
+            else:
+                img_normalized = img
+
+            # Convert image to tensor and send directly to device
+            img_tensor = torch.tensor(img_normalized.astype(np.float32).transpose(2, 0, 1)).to(self.device)  # (C, H, W)
+
+            imgs.append(img_tensor)
+            img_names.append(os.path.splitext(img_file)[0])
+
+        # Stack all images into a batch (ensure it's always treated as a batch)
+        batch_tensor = torch.stack(imgs)
+
+        # Predict using the predictor (or ensemble predictor)
+        pred_masks = self.predictor._inference(batch_tensor)
+
+        # Ensure pred_masks is always treated as a batch
+        if len(pred_masks.shape) == 3:  # If single image, add batch dimension
+            pred_masks = pred_masks.unsqueeze(0)
+
+        # Convert predictions to NumPy arrays and post-process each mask
+        for i, pred_mask in enumerate(pred_masks):
+            # Ensure the dimensions of pred_mask remain consistent
+            pred_mask_np = pred_mask.cpu().numpy()
+
+            # Extract dP and cellprob from pred_mask
+            dP = pred_mask_np[:2]  # First two channels as dP (displacement field)
+            cellprob = pred_mask_np[2]  # Third channel as cell probability
+
+            # Concatenate dP and cellprob along axis 0 to pass a single array
+            combined_pred_mask = np.concatenate([dP, np.expand_dims(cellprob, axis=0)], axis=0)
+
+            # Post-process the predicted mask
+            mask = self.predictor._post_process(combined_pred_mask)
+
+            # Convert the mask to 16-bit format (ensure values fit into 16-bit range)
+            mask_to_save = mask.astype(np.uint16)
+
+            # Save the post-processed mask as a .tif file using cv2
+            mask_output_path = os.path.join(self.output_path, f"{img_names[i]}_mask.tiff")
+            cv2.imwrite(mask_output_path, mask_to_save)
+
+            print(f"Predicted mask saved at: {mask_output_path}")
+
+            self.display_image_and_mask(imgs[i].cpu().numpy(), mask)
+
+        print(f"Test predictions saved in {self.output_path}")
+
+    def preprocess_image(self, img):
+        """
+        Preprocess input image (numpy array) for compatibility with the model.
+        """
+        if isinstance(img, np.ndarray):  # Check if the input is a numpy array
+            if len(img.shape) == 2:  # Grayscale image (H, W)
+                img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
+
+            elif img.shape[2] == 1:  # Single channel grayscale (H, W, 1)
+                img = np.repeat(img, 3, axis=2)  # Convert to 3-channel RGB
+
+            img_tensor = torch.tensor(img.astype(np.float32).transpose(2, 0, 1))  # Change shape to (C, H, W)
+        else:
+            img_tensor = img  # If it's already a tensor, assume it's in (C, H, W) format
+
+        return img_tensor.float()