spacr 1.0.9__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/app_classify.py +10 -0
- spacr/app_mask.py +9 -0
- spacr/app_measure.py +9 -0
- spacr/app_sequencing.py +9 -0
- spacr/core.py +172 -1
- spacr/deep_spacr.py +296 -7
- spacr/gui.py +68 -0
- spacr/gui_core.py +319 -10
- spacr/gui_elements.py +772 -13
- spacr/gui_utils.py +304 -12
- spacr/io.py +887 -71
- spacr/logger.py +36 -0
- spacr/measure.py +206 -28
- spacr/ml.py +606 -142
- spacr/plot.py +797 -131
- spacr/sequencing.py +363 -8
- spacr/settings.py +1158 -38
- spacr/sp_stats.py +80 -12
- spacr/spacr_cellpose.py +115 -2
- spacr/submodules.py +747 -19
- spacr/timelapse.py +237 -53
- spacr/toxo.py +132 -6
- spacr/utils.py +2422 -80
- {spacr-1.0.9.dist-info → spacr-1.1.0.dist-info}/METADATA +31 -17
- {spacr-1.0.9.dist-info → spacr-1.1.0.dist-info}/RECORD +29 -29
- {spacr-1.0.9.dist-info → spacr-1.1.0.dist-info}/LICENSE +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.0.dist-info}/WHEEL +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.0.dist-info}/entry_points.txt +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.0.dist-info}/top_level.txt +0 -0
spacr/app_classify.py
CHANGED
@@ -1,6 +1,16 @@
 from .gui import MainApp
 
 def start_classify_app():
+    """
+    Launch the spaCR GUI with the Classify application preloaded.
+
+    This function initializes the main GUI window with "Classify" set as the default active application.
+    It then starts the Tkinter main event loop to display the interface.
+
+    Typical use case:
+        Called from the command line or another script to directly launch the Classify module of spaCR.
+
+    """
     app = MainApp(default_app="Classify")
     app.mainloop()
 
spacr/app_mask.py
CHANGED
@@ -1,6 +1,15 @@
 from .gui import MainApp
 
 def start_mask_app():
+    """
+    Launch the spaCR GUI with the Mask application preloaded.
+
+    This function initializes the main GUI window with "Mask" selected as the default active module.
+    It is intended for users who want to directly start the application in mask generation mode.
+
+    Typical use case:
+        Called from the command line or another script to launch the mask generation workflow of spaCR.
+    """
     app = MainApp(default_app="Mask")
     app.mainloop()
 
spacr/app_measure.py
CHANGED
@@ -1,6 +1,15 @@
 from .gui import MainApp
 
 def start_measure_app():
+    """
+    Launch the spaCR GUI with the Measure application preloaded.
+
+    This function initializes the main GUI window with "Measure" selected as the default active module.
+    It is used to directly open the application in object measurement mode.
+
+    Typical use case:
+        Called from the command line or another script to start spaCR in measurement mode.
+    """
     app = MainApp(default_app="Measure")
     app.mainloop()
 
spacr/app_sequencing.py
CHANGED
@@ -1,6 +1,15 @@
 from .gui import MainApp
 
 def start_seq_app():
+    """
+    Launch the spaCR GUI with the Measure application preloaded.
+
+    This function initializes the main GUI window with "Measure" selected as the default active module.
+    It is used to directly open the application in object measurement mode.
+
+    Typical use case:
+        Called from the command line or another script to start spaCR in measurement mode.
+    """
     app = MainApp(default_app="Sequencing")
     app.mainloop()
 
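The four launcher modules above add near-identical docstrings around thin wrappers of `MainApp`. A minimal usage sketch (not part of the diff) of how one of these entry points would be invoked from Python:

# Minimal sketch: launch one spaCR GUI module from Python.
# The call blocks until the Tkinter window is closed.
from spacr.app_mask import start_mask_app

if __name__ == "__main__":
    start_mask_app()  # opens the spaCR GUI with "Mask" as the active application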
spacr/core.py
CHANGED
@@ -7,6 +7,99 @@ import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
 def preprocess_generate_masks(settings):
+
+    """
+    Preprocess image data and generate Cellpose segmentation masks for cells, nuclei, and pathogens.
+
+    This function supports preprocessing, metadata conversion, Cellpose-based mask generation, optional
+    mask adjustment, result plotting, and intermediate file cleanup. It handles batch operations and
+    supports advanced timelapse and channel-specific configurations.
+
+    Args:
+        settings (dict): Dictionary containing the following keys:
+
+            General settings:
+                - src (str or list): Path(s) to input folders. Required.
+                - denoise (bool): Apply denoising during preprocessing. Default is False.
+                - delete_intermediate (bool): Delete intermediate files after processing. Default is False.
+                - preprocess (bool): Perform preprocessing. Default is True.
+                - masks (bool): Generate masks using Cellpose. Default is True.
+                - save (bool or list of bool): Whether to save outputs per object type. Default is True.
+                - consolidate (bool): Consolidate input folder structure. Default is False.
+                - batch_size (int): Number of files processed per batch. Default is 50.
+                - test_mode (bool): Enable test mode with limited data. Default is False.
+                - test_images (int): Number of test images to use. Default is 10.
+                - magnification (int): Magnification of input data. Default is 20.
+                - custom_regex (str or None): Regex for filename parsing in auto metadata mode.
+                - metadata_type (str): Metadata type; "cellvoyager" or "auto". Default is "cellvoyager".
+                - n_jobs (int): Number of parallel processes. Default is os.cpu_count() - 4.
+                - randomize (bool): Randomize processing order. Default is True.
+                - verbose (bool): Print full settings table. Default is True.
+
+            Channel background correction:
+                - remove_background_cell (bool): Remove background from cell channel. Default is False.
+                - remove_background_nucleus (bool): Remove background from nucleus channel. Default is False.
+                - remove_background_pathogen (bool): Remove background from pathogen channel. Default is False.
+
+            Channel diameter and index settings:
+                - cell_diamiter (float or None): Cell diameter estimate for Cellpose.
+                - nucleus_diamiter (float or None): Nucleus diameter estimate for Cellpose.
+                - pathogen_diamiter (float or None): Pathogen diameter estimate for Cellpose.
+                - cell_channel (int or None): Channel index for cell. Default is None.
+                - nucleus_channel (int or None): Channel index for nucleus. Default is None.
+                - pathogen_channel (int or None): Channel index for pathogen. Default is None.
+                - channels (list): List of channel indices to include. Default is [0, 1, 2, 3].
+
+            Cellpose parameters:
+                - pathogen_background (float): Background intensity for pathogen. Default is 100.
+                - pathogen_Signal_to_noise (float): SNR threshold for pathogen. Default is 10.
+                - pathogen_CP_prob (float): Cellpose probability threshold for pathogen. Default is 0.
+                - cell_background (float): Background intensity for cell. Default is 100.
+                - cell_Signal_to_noise (float): SNR threshold for cell. Default is 10.
+                - cell_CP_prob (float): Cellpose probability threshold for cell. Default is 0.
+                - nucleus_background (float): Background intensity for nucleus. Default is 100.
+                - nucleus_Signal_to_noise (float): SNR threshold for nucleus. Default is 10.
+                - nucleus_CP_prob (float): Cellpose probability threshold for nucleus. Default is 0.
+                - nucleus_FT (float): Intensity scaling factor for nucleus. Default is 1.0.
+                - cell_FT (float): Intensity scaling factor for cell. Default is 1.0.
+                - pathogen_FT (float): Intensity scaling factor for pathogen. Default is 1.0.
+
+            Plotting settings:
+                - plot (bool): Enable plotting. Default is False.
+                - figuresize (int or float): Figure size for plots. Default is 10.
+                - cmap (str): Colormap used for plotting. Default is "inferno".
+                - normalize (bool): Normalize image intensities before processing. Default is True.
+                - normalize_plots (bool): Normalize intensity for plotting. Default is True.
+                - examples_to_plot (int): Number of examples to plot. Default is 1.
+
+            Analysis settings:
+                - pathogen_model (str or None): Custom model for pathogen ("toxo_pv_lumen" or "toxo_cyto").
+                - merge_pathogens (bool): Whether to merge multiple pathogen types. Default is False.
+                - filter (bool): Apply percentile filter. Default is False.
+                - lower_percentile (float): Lower percentile for intensity filtering. Default is 2.
+
+            Timelapse settings:
+                - timelapse (bool): Enable timelapse mode. Default is False.
+                - fps (int): Frames per second for timelapse export. Default is 2.
+                - timelapse_displacement (float or None): Max displacement for object tracking.
+                - timelapse_memory (int): Memory for tracking algorithm. Default is 3.
+                - timelapse_frame_limits (list): Frame limits for tracking. Default is [5].
+                - timelapse_remove_transient (bool): Remove short-lived objects. Default is False.
+                - timelapse_mode (str): Tracking algorithm. Default is "trackpy".
+                - timelapse_objects (str or None): Object type for tracking.
+
+            Miscellaneous:
+                - all_to_mip (bool): Convert all input to MIP. Default is False.
+                - upscale (bool): Upscale images prior to processing. Default is False.
+                - upscale_factor (float): Upscaling factor. Default is 2.0.
+                - adjust_cells (bool): Adjust cell masks based on nuclei and pathogen. Default is False.
+                - use_sam_cell (bool): Use SAM model for cell segmentation. Default is False.
+                - use_sam_nucleus (bool): Use SAM model for nucleus segmentation. Default is False.
+                - use_sam_pathogen (bool): Use SAM model for pathogen segmentation. Default is False.
+
+    Returns:
+        None: All outputs (masks, merged arrays, plots, databases) are saved to disk under the source folder(s).
+    """
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays, convert_to_yokogawa, convert_separate_files_to_yokogawa
     from .plot import plot_image_mask_overlay, plot_arrays
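A hedged usage sketch for the newly documented `preprocess_generate_masks`: the keys below come from the docstring added in this hunk, the paths and channel indices are placeholders, and keys that are omitted are assumed to fall back to the documented defaults.

from spacr.core import preprocess_generate_masks

# Illustrative settings only; keys follow the docstring above, values are placeholders.
settings = {
    "src": "/data/experiment_01",      # input folder (placeholder path)
    "metadata_type": "cellvoyager",
    "magnification": 20,
    "batch_size": 50,
    "channels": [0, 1, 2, 3],
    "nucleus_channel": 0,
    "pathogen_channel": 2,
    "cell_channel": 3,
    "cell_background": 100,
    "cell_Signal_to_noise": 10,
    "cell_CP_prob": 0,
    "save": True,
    "plot": False,
}

preprocess_generate_masks(settings)   # masks, merged arrays, and databases are written under src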
@@ -194,7 +287,85 @@ def preprocess_generate_masks(settings):
     return
 
 def generate_cellpose_masks(src, settings, object_type):
-
+    """
+    Generate segmentation masks for a specific object type using Cellpose.
+
+    This function applies a Cellpose-based segmentation pipeline to images in `.npz` format, using settings
+    for batch size, object type (cell, nucleus, pathogen), filtering, plotting, and timelapse options.
+    Masks are optionally filtered, saved, tracked (for timelapse), and summarized into a SQLite database.
+
+    Args:
+        src (str): Path to the source folder containing `.npz` files with image stacks.
+        settings (dict): Dictionary of settings used to control preprocessing and segmentation. Includes:
+
+            General settings:
+                - src (str): Source directory.
+                - denoise (bool): Apply denoising before processing.
+                - delete_intermediate (bool): Remove intermediate files after processing.
+                - preprocess (bool): Enable preprocessing.
+                - masks (bool): Enable mask generation.
+                - save (bool): Save mask outputs.
+                - consolidate (bool): Consolidate image folders.
+                - batch_size (int): Batch size for processing.
+                - test_mode (bool): Enable test mode with limited image count.
+                - test_images (int): Number of test images to process.
+                - magnification (int): Image magnification level.
+                - custom_regex (str or None): Regex pattern for file parsing (metadata_type = 'auto').
+                - metadata_type (str): One of "cellvoyager" or "auto".
+                - n_jobs (int): Number of parallel workers.
+                - randomize (bool): Shuffle file order before processing.
+                - verbose (bool): Print full settings to console.
+
+            Channel/background/cellpose settings:
+                - remove_background_cell/nucleus/pathogen (bool): Whether to subtract background from channel.
+                - cell_diamiter / nucleus_diamiter / pathogen_diamiter (float or None): Estimated diameter.
+                - cell_channel / nucleus_channel / pathogen_channel (int or None): Channel index.
+                - channels (list): List of channels to include in stack.
+                - cell/background/SNR/CP_prob/FT (float): Intensity/cellpose thresholds and scaling.
+                - pathogen_model (str or None): Custom model for pathogen segmentation (e.g. "toxo_pv_lumen").
+
+            Plotting:
+                - plot (bool): Plot masks or overlay visualizations.
+                - figuresize (int): Matplotlib figure size.
+                - cmap (str): Colormap to use (e.g. "inferno").
+                - normalize (bool): Normalize input intensities.
+                - normalize_plots (bool): Normalize for plots.
+                - examples_to_plot (int): How many examples to plot.
+
+            Filtering and merging:
+                - merge_pathogens (bool): Whether to merge pathogen objects.
+                - filter (bool): Apply filtering on masks.
+                - lower_percentile (float): Intensity filter threshold.
+                - merge (bool): Merge adjacent objects if needed.
+
+            Timelapse:
+                - timelapse (bool): Enable object tracking across timepoints.
+                - timelapse_displacement (float or None): Max tracking displacement.
+                - timelapse_memory (int): Trackpy memory.
+                - timelapse_frame_limits (list): Frames to include in timelapse batch.
+                - timelapse_remove_transient (bool): Remove transient objects.
+                - timelapse_mode (str): One of "trackpy", "btrack", or "iou".
+                - timelapse_objects (list or None): Subset of ['cell', 'nucleus', 'pathogen'] to track.
+
+            Miscellaneous:
+                - all_to_mip (bool): Convert Z-stacks to max projections.
+                - upscale (bool): Apply upscaling.
+                - upscale_factor (float): Upscaling factor.
+                - adjust_cells (bool): Refine cell masks with nucleus/pathogen.
+                - use_sam_cell/nucleus/pathogen (bool): Use SAM for mask generation.
+
+        object_type (str): One of 'cell', 'nucleus', or 'pathogen'. Determines which mask to generate.
+
+    Returns:
+        None. Outputs are saved to disk:
+            - Generated masks are stored in a `*_mask_stack/` folder.
+            - Object counts are written to `measurements/measurements.db`.
+            - Optional overlay plots are saved if enabled.
+            - Optional timelapse movies are saved in `movies/`.
+
+    Raises:
+        ValueError: If the object_type is missing from the computed channel map, or if invalid tracking settings are provided.
+    """
     from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, all_elements_match, prepare_batch_for_segmentation
     from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
     from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
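`generate_cellpose_masks` is normally driven by `preprocess_generate_masks`, but based on the signature and docstring above a direct call would look roughly like this sketch (the folder path is a placeholder):

from spacr.core import generate_cellpose_masks

# Sketch only: src should point at the folder of .npz image stacks produced by preprocessing.
generate_cellpose_masks(
    src="/data/experiment_01/stacks",   # placeholder path to .npz stacks
    settings=settings,                  # the same settings dict shown in the earlier sketch
    object_type="cell",                 # one of 'cell', 'nucleus', 'pathogen'
)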
spacr/deep_spacr.py
CHANGED
@@ -16,7 +16,34 @@ from torchvision import transforms
 from torch.utils.data import DataLoader
 
 def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
-
+    """
+    Apply a trained binary classification model to a folder of images.
+
+    Loads a PyTorch model and applies it to images in the specified folder using batch inference.
+    Supports optional normalization and GPU acceleration. Outputs prediction probabilities and
+    saves results as a CSV file alongside the model.
+
+    Args:
+        src (str): Path to a folder containing input images (e.g., PNG, JPG).
+        model_path (str): Path to a trained PyTorch model file (.pt or .pth).
+        image_size (int, optional): Size to center-crop input images to. Default is 224.
+        batch_size (int, optional): Number of images to process per batch. Default is 64.
+        normalize (bool, optional): If True, normalize images to [-1, 1] using ImageNet-style transform. Default is True.
+        n_jobs (int, optional): Number of subprocesses to use for data loading. Default is 10.
+
+    Returns:
+        pandas.DataFrame: A DataFrame with two columns:
+            - "path": Filenames of processed images.
+            - "pred": Model output probabilities (sigmoid of logits).
+
+    Saves:
+        A CSV file named like <model_path><YYMMDD>_<ext>_test_result.csv, containing the prediction results.
+
+    Notes:
+        - Uses GPU if available, otherwise runs on CPU.
+        - Assumes model outputs raw logits for binary classification (sigmoid is applied).
+        - The input folder must contain only images readable by `PIL.Image.open`.
+    """
     from .io import NoClassDataset
     from .utils import print_progress
 
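Per the docstring added above, a hedged inference sketch for `apply_model`; paths are placeholders:

from spacr.deep_spacr import apply_model

df = apply_model(
    src="/data/classification/pngs",       # folder of images (placeholder)
    model_path="/models/spacr_model.pth",  # trained checkpoint (placeholder)
    image_size=224,
    batch_size=64,
    normalize=True,
    n_jobs=10,
)
print(df.head())   # columns: "path" and "pred"; a result CSV is also written next to the model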
@@ -71,7 +98,31 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
     return df
 
 def apply_model_to_tar(settings={}):
-
+    """
+    Apply a trained model to images stored inside a tar archive.
+
+    Loads a model and applies it to images within a `.tar` archive using batch inference. Results are
+    filtered by a probability threshold and saved to a CSV. Supports GPU acceleration and normalization.
+
+    Args:
+        settings (dict): Dictionary with the following keys:
+            - tar_path (str): Path to the tar archive with input images.
+            - model_path (str): Path to the trained PyTorch model (.pt/.pth).
+            - image_size (int): Center crop size for input images. Default is 224.
+            - batch_size (int): Batch size for DataLoader. Default is 64.
+            - normalize (bool): Apply normalization to [-1, 1]. Default is True.
+            - n_jobs (int): Number of workers for data loading. Default is system CPU count - 4.
+            - verbose (bool): If True, print progress and model details.
+            - score_threshold (float): Probability threshold for positive classification (used in result filtering).
+
+    Returns:
+        pandas.DataFrame: DataFrame with:
+            - "path": Filenames inside the tar archive.
+            - "pred": Model prediction scores (sigmoid output).
+
+    Saves:
+        A CSV file with prediction results to the same directory as the tar file.
+    """
     from .io import TarImageDataset
     from .utils import process_vision_results, print_progress
 
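The tar-based variant follows the same pattern; a sketch using the documented keys, with placeholder values:

from spacr.deep_spacr import apply_model_to_tar

settings = {
    "tar_path": "/data/classification/cells.tar",   # placeholder
    "model_path": "/models/spacr_model.pth",        # placeholder
    "image_size": 224,
    "batch_size": 64,
    "normalize": True,
    "score_threshold": 0.5,
    "verbose": True,
}
result_df = apply_model_to_tar(settings)   # also saves a CSV next to the tar file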
@@ -172,7 +223,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
     """
     Calculate classification metrics for binary classification.
 
-
+    Args:
     - all_labels (list): List of true labels.
     - prediction_pos_probs (list): List of predicted positive probabilities.
     - loader_name (str): Name of the data loader.
@@ -256,7 +307,27 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
     return data_dict, [prediction_pos_probs, all_labels]
 
 def test_model_core(model, loader, loader_name, epoch, loss_type):
-
+    """
+    Evaluate a trained model on a test DataLoader and return performance metrics and predictions.
+
+    This function evaluates a binary classification model using a specified loss function, computes
+    classification metrics, and logs predictions, targets, and file-level results.
+
+    Args:
+        model (torch.nn.Module): The trained PyTorch model to evaluate.
+        loader (torch.utils.data.DataLoader): DataLoader providing test data and labels.
+        loader_name (str): Identifier name for the loader (used for logging/debugging).
+        epoch (int): Current epoch number (used for metric tracking).
+        loss_type (str): Type of loss function to use for reporting (e.g., 'bce', 'focal').
+
+    Returns:
+        tuple:
+            - data_df (pd.DataFrame): DataFrame containing classification metrics for the test set.
+            - prediction_pos_probs (list): List of predicted probabilities for the positive class.
+            - all_labels (list): Ground truth binary labels.
+            - results_df (pd.DataFrame): Per-sample results, including filename, true label, predicted label,
+              and probability for class 1.
+    """
     from .utils import calculate_loss, classification_metrics
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model.eval()
@@ -612,7 +683,50 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
     return model, model_path
 
 def generate_activation_map(settings):
-
+    """
+    Generate activation maps (Grad-CAM or saliency) from a trained model applied to a dataset stored in a tar archive.
+
+    This function loads a model, computes class activation maps or saliency maps for each input image, and saves the
+    results as images. Optionally, it plots batch-wise grids of maps and stores correlation results and image metadata
+    into an SQL database.
+
+    Args:
+        settings (dict): Dictionary of parameters controlling activation map generation. Key fields include:
+
+            Required paths:
+                - dataset (str): Path to the `.tar` archive containing images.
+                - model_path (str): Path to the trained PyTorch model (.pt or .pth).
+
+            Model and method:
+                - model_type (str): Model architecture used (e.g., 'maxvit').
+                - cam_type (str): One of ['gradcam', 'gradcam_pp', 'saliency_image', 'saliency_channel'].
+                - target_layer (str or None): Name of the target layer for Grad-CAM (optional, required for Grad-CAM variants).
+
+            Input transforms:
+                - image_size (int): Size to center-crop images to (e.g., 224).
+                - normalize_input (bool): Whether to normalize images to [-1, 1] range.
+                - channels (list): Channel indices to select from input data (e.g., [0,1,2]).
+
+            Inference:
+                - batch_size (int): Number of images per inference batch.
+                - shuffle (bool): Whether to shuffle image order in DataLoader.
+                - n_jobs (int): Number of parallel DataLoader workers (default is CPU count - 4).
+
+            Output control:
+                - save (bool): If True, saves individual activation maps to disk.
+                - plot (bool): If True, generates and saves batch-wise PDF grid plots.
+                - overlay (bool): If True, overlays activation maps on input images.
+                - correlation (bool): If True, computes activation correlation features (e.g., Manders').
+
+            Correlation-specific:
+                - manders_thresholds (list or float): Threshold(s) for calculating Manders' coefficients.
+
+    Returns:
+        None. The following outputs are saved:
+            - PNG or JPEG activation maps organized by predicted class and well.
+            - PDF files with batch-wise overlay plots if `plot=True`.
+            - Activation image metadata and correlations saved to SQL database if `save=True`.
+    """
     from .utils import SaliencyMapGenerator, GradCAMGenerator, SelectChannels, activation_maps_to_database, activation_correlations_to_database
     from .utils import print_progress, save_settings, calculate_activation_correlations
     from .io import TarImageDataset
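A sketch of an activation-map run based on the documented keys; the saliency mode is chosen here so that `target_layer` can legitimately stay `None`, and all paths are placeholders:

from spacr.deep_spacr import generate_activation_map

settings = {
    "dataset": "/data/classification/cells.tar",   # placeholder
    "model_path": "/models/spacr_model.pth",       # placeholder
    "model_type": "maxvit",
    "cam_type": "saliency_image",    # avoids needing a Grad-CAM target_layer
    "target_layer": None,
    "image_size": 224,
    "channels": [0, 1, 2],
    "normalize_input": True,
    "batch_size": 64,
    "shuffle": False,
    "save": True,
    "plot": True,
    "overlay": True,
    "correlation": False,
}
generate_activation_map(settings)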
@@ -771,7 +885,18 @@ def generate_activation_map(settings):
     print("Activation map generation complete.")
 
 def visualize_classes(model, dtype, class_names, **kwargs):
+    """
+    Visualize synthetic input images that maximize class activation.
 
+    Args:
+        model (torch.nn.Module): The trained classification model.
+        dtype (str): Data type or domain tag used for visualization.
+        class_names (list): List of class names (length 2 assumed for binary classification).
+        **kwargs: Additional keyword arguments passed to `class_visualization()`.
+
+    Returns:
+        None. Displays matplotlib plots of class visualizations.
+    """
     from .utils import class_visualization
 
     for target_y in range(2): # Assuming binary classification
@@ -783,7 +908,22 @@ def visualize_classes(model, dtype, class_names, **kwargs):
     plt.show()
 
 def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_size=224, channels=[1,2,3], normalize=True, save_integrated_grads=False, save_dir='integrated_grads'):
+    """
+    Visualize integrated gradients for PNG images in a directory.
+
+    Args:
+        src (str): Directory containing `.png` images.
+        model_path (str): Path to the trained PyTorch model.
+        target_label_idx (int): Index of the target class label.
+        image_size (int): Image size after preprocessing (center crop).
+        channels (list): List of channels to extract (1-indexed).
+        normalize (bool): Whether to normalize image input to [-1, 1].
+        save_integrated_grads (bool): Whether to save integrated gradient maps.
+        save_dir (str): Directory to save integrated gradient outputs.
 
+    Returns:
+        None. Displays overlays and optionally saves saliency maps.
+    """
     from .utils import IntegratedGradients, preprocess_image
 
     use_cuda = torch.cuda.is_available()
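A direct-call sketch matching the signature and docstring above (placeholder paths):

from spacr.deep_spacr import visualize_integrated_gradients

visualize_integrated_gradients(
    src="/data/classification/example_pngs",   # placeholder folder of .png images
    model_path="/models/spacr_model.pth",      # placeholder
    target_label_idx=0,
    image_size=224,
    channels=[1, 2, 3],
    normalize=True,
    save_integrated_grads=True,
    save_dir="integrated_grads",
)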
@@ -832,6 +972,15 @@ def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_si
         integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
 
 class SmoothGrad:
+    """
+    Compute SmoothGrad saliency maps from a trained model.
+
+    Args:
+        model (torch.nn.Module): Trained classification model.
+        n_samples (int): Number of noise samples to average.
+        stdev_spread (float): Standard deviation of noise relative to input range.
+    """
+
     def __init__(self, model, n_samples=50, stdev_spread=0.15):
         self.model = model
         self.n_samples = n_samples
@@ -855,7 +1004,22 @@ class SmoothGrad:
         return avg_gradients.abs()
 
 def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, channels=[1,2,3], normalize=True, save_smooth_grad=False, save_dir='smooth_grad'):
+    """
+    Visualize SmoothGrad maps for PNG images in a folder.
+
+    Args:
+        src (str): Path to directory containing `.png` images.
+        model_path (str): Path to trained PyTorch model file.
+        target_label_idx (int): Index of the class to explain.
+        image_size (int): Size for center cropping during preprocessing.
+        channels (list): Channel indices to extract from images.
+        normalize (bool): Whether to normalize inputs to [-1, 1].
+        save_smooth_grad (bool): If True, saves saliency maps to disk.
+        save_dir (str): Folder where smooth grad maps are saved.
 
+    Returns:
+        None. Displays overlay figures and optionally saves maps to disk.
+    """
     from .utils import preprocess_image
 
     use_cuda = torch.cuda.is_available()
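The SmoothGrad counterpart mirrors the integrated-gradients call; a sketch with the same placeholder layout:

from spacr.deep_spacr import visualize_smooth_grad

visualize_smooth_grad(
    src="/data/classification/example_pngs",   # placeholder folder of .png images
    model_path="/models/spacr_model.pth",      # placeholder
    target_label_idx=0,
    image_size=224,
    channels=[1, 2, 3],
    normalize=True,
    save_smooth_grad=True,
    save_dir="smooth_grad",
)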
@@ -904,6 +1068,78 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
         smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
 
 def deep_spacr(settings={}):
+    """
+    Run deep learning-based classification workflow on microscopy data using SpaCr.
+
+    This function handles dataset generation, model training, and inference using a trained model on tar-archived image datasets.
+    Settings are filled using `deep_spacr_defaults`.
+
+    Args:
+        settings (dict): Dictionary of settings with the following keys:
+
+            General:
+                - src (str): Path to the input dataset.
+                - dataset (str): Path to a dataset archive.
+                - dataset_mode (str): Dataset generation mode. Typically 'metadata'.
+                - file_type (str): Type of input files (e.g., 'cell_png').
+                - file_metadata (str or None): Path to file-level metadata, if available.
+                - sample (int or None): Limit to N random samples for development/testing.
+                - experiment (str): Experiment name prefix. Default is 'exp.'.
+
+            Annotation and class mapping:
+                - annotation_column (str): Metadata column containing class annotations.
+                - annotated_classes (list): List of class IDs used for training (e.g., [1, 2]).
+                - classes (list): Class labels (e.g., ['nc', 'pc']).
+                - class_metadata (list of lists): Mapping of classes to metadata terms (e.g., [['c1'], ['c2']]).
+                - metadata_type_by (str): How to interpret metadata structure. Typically 'columnID'.
+
+            Image processing:
+                - channel_of_interest (int): Channel index to use for classification.
+                - png_type (str): Type of image format (e.g., 'cell_png').
+                - image_size (int): Input size (e.g., 224 for 224x224 crop).
+                - train_channels (list): Channels to use for training (e.g., ['r', 'g', 'b']).
+                - normalize (bool): Whether to normalize input images. Default is True.
+                - augment (bool): Whether to apply data augmentation.
+
+            Model and training:
+                - model_type (str): Model architecture (e.g., 'maxvit_t').
+                - optimizer_type (str): Optimizer (e.g., 'adamw').
+                - schedule (str): Learning rate scheduler ('reduce_lr_on_plateau' or 'step_lr').
+                - loss_type (str): Loss function ('focal_loss' or 'binary_cross_entropy_with_logits').
+                - dropout_rate (float): Dropout probability.
+                - init_weights (bool): Initialize model with pretrained weights.
+                - amsgrad (bool): Use AMSGrad variant of AdamW optimizer.
+                - use_checkpoint (bool): Enable checkpointing.
+                - intermedeate_save (bool): Save intermediate models during training.
+
+            Training control:
+                - train (bool): Enable training phase.
+                - test (bool): Enable evaluation on test set.
+                - train_DL_model (bool): Enable deep learning model training.
+                - generate_training_dataset (bool): Enable generation of train/test splits.
+                - test_split (float): Proportion of data used for testing.
+                - val_split (float): Fraction of training set used for validation.
+                - epochs (int): Number of training epochs.
+                - batch_size (int): Batch size for training and inference.
+                - learning_rate (float): Learning rate.
+                - weight_decay (float): L2 regularization strength.
+                - gradient_accumulation (bool): Accumulate gradients over multiple steps.
+                - gradient_accumulation_steps (int): Number of steps per gradient update.
+
+            Inference:
+                - apply_model_to_dataset (bool): Run prediction on tar dataset.
+                - tar_path (str): Path to tar file for inference input.
+                - model_path (str): Path to trained model file.
+                - score_threshold (float): Probability threshold for binary classification.
+
+            Execution:
+                - n_jobs (int): Number of parallel workers.
+                - pin_memory (bool): Whether to use pinned memory in DataLoader.
+                - verbose (bool): Print training and evaluation progress.
+
+    Returns:
+        None. All outputs (trained models, predictions, settings) are saved to disk.
+    """
     from .settings import deep_spacr_defaults
     from .io import generate_training_dataset, generate_dataset
     from .utils import save_settings
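An end-to-end sketch for `deep_spacr`; the docstring states that missing keys are filled from `deep_spacr_defaults`, so only a subset is shown and all values are illustrative:

from spacr.deep_spacr import deep_spacr

settings = {
    "src": "/data/experiment_01",        # placeholder
    "generate_training_dataset": True,
    "train_DL_model": True,
    "test": True,
    "model_type": "maxvit_t",
    "optimizer_type": "adamw",
    "loss_type": "focal_loss",
    "image_size": 224,
    "batch_size": 64,
    "epochs": 100,
    "learning_rate": 1e-4,
    "classes": ["nc", "pc"],
    "apply_model_to_dataset": False,
}
deep_spacr(settings)   # trained models, predictions, and settings are saved to disk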
@@ -937,7 +1173,26 @@ def deep_spacr(settings={}):
     apply_model_to_tar(settings)
 
 def model_knowledge_transfer(teacher_paths, student_save_path, data_loader, device='cpu', student_model_name='maxvit_t', pretrained=True, dropout_rate=None, use_checkpoint=False, alpha=0.5, temperature=2.0, lr=1e-4, epochs=10):
+    """
+    Perform knowledge distillation from one or more teacher models to a student model.
 
+    Args:
+        teacher_paths (list of str): Paths to pretrained teacher model files (.pth).
+        student_save_path (str): Output path for the saved student model.
+        data_loader (torch.utils.data.DataLoader): DataLoader for training data.
+        device (str): Device to use ('cpu' or 'cuda').
+        student_model_name (str): Name of the student model architecture (e.g., 'maxvit_t').
+        pretrained (bool): Whether to initialize the student model with ImageNet weights.
+        dropout_rate (float or None): Dropout rate for the student model.
+        use_checkpoint (bool): Whether to use gradient checkpointing.
+        alpha (float): Weighting factor between cross-entropy and distillation loss.
+        temperature (float): Temperature scaling for soft targets.
+        lr (float): Learning rate for optimizer.
+        epochs (int): Number of training epochs.
+
+    Returns:
+        TorchModel: The trained student model after knowledge distillation.
+    """
     from .utils import TorchModel
 
     # Adjust filename to reflect knowledge-distillation if desired
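A distillation sketch based on the documented signature; the checkpoint paths are placeholders, and the dummy DataLoader below only makes the sketch self-contained and should be replaced by the project's real labeled image loader (its exact batch format is an assumption here):

import torch
from torch.utils.data import DataLoader, TensorDataset
from spacr.deep_spacr import model_knowledge_transfer

# Stand-in loader for illustration only: 8 random RGB images with binary labels.
dummy = TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 2, (8,)))
train_loader = DataLoader(dummy, batch_size=4)

student = model_knowledge_transfer(
    teacher_paths=["/models/teacher_a.pth", "/models/teacher_b.pth"],  # placeholders
    student_save_path="/models/student_maxvit_t.pth",                  # placeholder
    data_loader=train_loader,
    device="cpu",
    student_model_name="maxvit_t",
    alpha=0.5,
    temperature=2.0,
    lr=1e-4,
    epochs=10,
)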
@@ -1041,7 +1296,22 @@ def model_knowledge_transfer(teacher_paths, student_save_path, data_loader, devi
     return student_model
 
 def model_fusion(model_paths,save_path,device='cpu',model_name='maxvit_t',pretrained=True,dropout_rate=None,use_checkpoint=False,aggregator='mean'):
+    """
+    Fuse multiple trained models by combining their parameters using a specified aggregation method.
 
+    Args:
+        model_paths (list of str): List of paths to model checkpoints to be fused.
+        save_path (str): Path where the fused model will be saved.
+        device (str): Device to use ('cpu' or 'cuda').
+        model_name (str): Model architecture to use when initializing.
+        pretrained (bool): Whether to initialize with pretrained weights.
+        dropout_rate (float or None): Dropout rate to apply to the model.
+        use_checkpoint (bool): Whether to use gradient checkpointing.
+        aggregator (str): Aggregation strategy to combine weights. One of {'mean', 'geomean', 'median', 'sum', 'max', 'min'}.
+
+    Returns:
+        TorchModel: The fused model with combined weights.
+    """
     from .utils import TorchModel
 
     if save_path.endswith('.pth'):
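A weight-fusion sketch using the documented aggregator options (checkpoint paths are placeholders):

from spacr.deep_spacr import model_fusion

fused = model_fusion(
    model_paths=["/models/run_1.pth", "/models/run_2.pth", "/models/run_3.pth"],  # placeholders
    save_path="/models/fused_maxvit_t.pth",
    device="cpu",
    model_name="maxvit_t",
    aggregator="mean",    # or 'geomean', 'median', 'sum', 'max', 'min'
)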
@@ -1141,14 +1411,33 @@ def model_fusion(model_paths,save_path,device='cpu',model_name='maxvit_t',pretra
     return fused_model
 
 def annotate_filter_vision(settings):
-
+    """
+    Annotate and filter a CSV file with experimental metadata and optionally remove training samples.
+
+    Args:
+        settings (dict): Configuration dictionary with keys:
+            - 'src' (str or list): Paths to CSV annotation files.
+            - 'cells' (dict): Mapping of cell types to annotation labels.
+            - 'cell_loc' (str): Column name for cell type annotations.
+            - 'pathogens' (dict): Mapping of pathogens to annotation labels.
+            - 'pathogen_loc' (str): Column name for pathogen annotations.
+            - 'treatments' (dict): Mapping of treatments to annotation labels.
+            - 'treatment_loc' (str): Column name for treatment annotations.
+            - 'filter_column' (str or None): Column to filter on.
+            - 'upper_threshold' (float): Upper bound for filtering.
+            - 'lower_threshold' (float): Lower bound for filtering.
+            - 'remove_train' (bool): If True, removes rows present in training folders.
+
+    Returns:
+        None. Saves filtered and annotated CSVs to disk.
+    """
     from .utils import annotate_conditions, correct_metadata
 
     def filter_csv_by_png(csv_file):
         """
         Filters a DataFrame by removing rows that match PNG filenames in a folder.
 
-
+        Args:
         csv_file (str): Path to the CSV file.
 
         Returns: