spacr 1.0.9__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/app_classify.py CHANGED
@@ -1,6 +1,16 @@
1
1
  from .gui import MainApp
2
2
 
3
3
  def start_classify_app():
4
+ """
5
+ Launch the spaCR GUI with the Classify application preloaded.
6
+
7
+ This function initializes the main GUI window with "Classify" set as the default active application.
8
+ It then starts the Tkinter main event loop to display the interface.
9
+
10
+ Typical use case:
11
+ Called from the command line or another script to directly launch the Classify module of spaCR.
12
+
13
+ """
4
14
  app = MainApp(default_app="Classify")
5
15
  app.mainloop()
6
16
 
spacr/app_mask.py CHANGED
@@ -1,6 +1,15 @@
1
1
  from .gui import MainApp
2
2
 
3
3
  def start_mask_app():
4
+ """
5
+ Launch the spaCR GUI with the Mask application preloaded.
6
+
7
+ This function initializes the main GUI window with "Mask" selected as the default active module.
8
+ It is intended for users who want to directly start the application in mask generation mode.
9
+
10
+ Typical use case:
11
+ Called from the command line or another script to launch the mask generation workflow of spaCR.
12
+ """
4
13
  app = MainApp(default_app="Mask")
5
14
  app.mainloop()
6
15
 
spacr/app_measure.py CHANGED
@@ -1,6 +1,15 @@
1
1
  from .gui import MainApp
2
2
 
3
3
  def start_measure_app():
4
+ """
5
+ Launch the spaCR GUI with the Measure application preloaded.
6
+
7
+ This function initializes the main GUI window with "Measure" selected as the default active module.
8
+ It is used to directly open the application in object measurement mode.
9
+
10
+ Typical use case:
11
+ Called from the command line or another script to start spaCR in measurement mode.
12
+ """
4
13
  app = MainApp(default_app="Measure")
5
14
  app.mainloop()
6
15
 
spacr/app_sequencing.py CHANGED
@@ -1,6 +1,15 @@
1
1
  from .gui import MainApp
2
2
 
3
3
  def start_seq_app():
4
+ """
5
+ Launch the spaCR GUI with the Sequencing application preloaded.
6
+
7
+ This function initializes the main GUI window with "Sequencing" selected as the default active module.
8
+ It is used to directly open the application in sequencing mode.
9
+
10
+ Typical use case:
11
+ Called from the command line or another script to start spaCR in sequencing mode.
12
+ """
4
13
  app = MainApp(default_app="Sequencing")
5
14
  app.mainloop()
6
15
 
spacr/core.py CHANGED
@@ -7,6 +7,99 @@ import warnings
7
7
  warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
8
8
 
9
9
  def preprocess_generate_masks(settings):
10
+
11
+ """
12
+ Preprocess image data and generate Cellpose segmentation masks for cells, nuclei, and pathogens.
13
+
14
+ This function supports preprocessing, metadata conversion, Cellpose-based mask generation, optional
15
+ mask adjustment, result plotting, and intermediate file cleanup. It handles batch operations and
16
+ supports advanced timelapse and channel-specific configurations.
17
+
18
+ Args:
19
+ settings (dict): Dictionary containing the following keys:
20
+
21
+ General settings:
22
+ - src (str or list): Path(s) to input folders. Required.
23
+ - denoise (bool): Apply denoising during preprocessing. Default is False.
24
+ - delete_intermediate (bool): Delete intermediate files after processing. Default is False.
25
+ - preprocess (bool): Perform preprocessing. Default is True.
26
+ - masks (bool): Generate masks using Cellpose. Default is True.
27
+ - save (bool or list of bool): Whether to save outputs per object type. Default is True.
28
+ - consolidate (bool): Consolidate input folder structure. Default is False.
29
+ - batch_size (int): Number of files processed per batch. Default is 50.
30
+ - test_mode (bool): Enable test mode with limited data. Default is False.
31
+ - test_images (int): Number of test images to use. Default is 10.
32
+ - magnification (int): Magnification of input data. Default is 20.
33
+ - custom_regex (str or None): Regex for filename parsing in auto metadata mode.
34
+ - metadata_type (str): Metadata type; "cellvoyager" or "auto". Default is "cellvoyager".
35
+ - n_jobs (int): Number of parallel processes. Default is os.cpu_count() - 4.
36
+ - randomize (bool): Randomize processing order. Default is True.
37
+ - verbose (bool): Print full settings table. Default is True.
38
+
39
+ Channel background correction:
40
+ - remove_background_cell (bool): Remove background from cell channel. Default is False.
41
+ - remove_background_nucleus (bool): Remove background from nucleus channel. Default is False.
42
+ - remove_background_pathogen (bool): Remove background from pathogen channel. Default is False.
43
+
44
+ Channel diameter and index settings:
45
+ - cell_diamiter (float or None): Cell diameter estimate for Cellpose.
46
+ - nucleus_diamiter (float or None): Nucleus diameter estimate for Cellpose.
47
+ - pathogen_diamiter (float or None): Pathogen diameter estimate for Cellpose.
48
+ - cell_channel (int or None): Channel index for cell. Default is None.
49
+ - nucleus_channel (int or None): Channel index for nucleus. Default is None.
50
+ - pathogen_channel (int or None): Channel index for pathogen. Default is None.
51
+ - channels (list): List of channel indices to include. Default is [0, 1, 2, 3].
52
+
53
+ Cellpose parameters:
54
+ - pathogen_background (float): Background intensity for pathogen. Default is 100.
55
+ - pathogen_Signal_to_noise (float): SNR threshold for pathogen. Default is 10.
56
+ - pathogen_CP_prob (float): Cellpose probability threshold for pathogen. Default is 0.
57
+ - cell_background (float): Background intensity for cell. Default is 100.
58
+ - cell_Signal_to_noise (float): SNR threshold for cell. Default is 10.
59
+ - cell_CP_prob (float): Cellpose probability threshold for cell. Default is 0.
60
+ - nucleus_background (float): Background intensity for nucleus. Default is 100.
61
+ - nucleus_Signal_to_noise (float): SNR threshold for nucleus. Default is 10.
62
+ - nucleus_CP_prob (float): Cellpose probability threshold for nucleus. Default is 0.
63
+ - nucleus_FT (float): Intensity scaling factor for nucleus. Default is 1.0.
64
+ - cell_FT (float): Intensity scaling factor for cell. Default is 1.0.
65
+ - pathogen_FT (float): Intensity scaling factor for pathogen. Default is 1.0.
66
+
67
+ Plotting settings:
68
+ - plot (bool): Enable plotting. Default is False.
69
+ - figuresize (int or float): Figure size for plots. Default is 10.
70
+ - cmap (str): Colormap used for plotting. Default is "inferno".
71
+ - normalize (bool): Normalize image intensities before processing. Default is True.
72
+ - normalize_plots (bool): Normalize intensity for plotting. Default is True.
73
+ - examples_to_plot (int): Number of examples to plot. Default is 1.
74
+
75
+ Analysis settings:
76
+ - pathogen_model (str or None): Custom model for pathogen ("toxo_pv_lumen" or "toxo_cyto").
77
+ - merge_pathogens (bool): Whether to merge multiple pathogen types. Default is False.
78
+ - filter (bool): Apply percentile filter. Default is False.
79
+ - lower_percentile (float): Lower percentile for intensity filtering. Default is 2.
80
+
81
+ Timelapse settings:
82
+ - timelapse (bool): Enable timelapse mode. Default is False.
83
+ - fps (int): Frames per second for timelapse export. Default is 2.
84
+ - timelapse_displacement (float or None): Max displacement for object tracking.
85
+ - timelapse_memory (int): Memory for tracking algorithm. Default is 3.
86
+ - timelapse_frame_limits (list): Frame limits for tracking. Default is [5].
87
+ - timelapse_remove_transient (bool): Remove short-lived objects. Default is False.
88
+ - timelapse_mode (str): Tracking algorithm. Default is "trackpy".
89
+ - timelapse_objects (str or None): Object type for tracking.
90
+
91
+ Miscellaneous:
92
+ - all_to_mip (bool): Convert all input to MIP. Default is False.
93
+ - upscale (bool): Upscale images prior to processing. Default is False.
94
+ - upscale_factor (float): Upscaling factor. Default is 2.0.
95
+ - adjust_cells (bool): Adjust cell masks based on nuclei and pathogen. Default is False.
96
+ - use_sam_cell (bool): Use SAM model for cell segmentation. Default is False.
97
+ - use_sam_nucleus (bool): Use SAM model for nucleus segmentation. Default is False.
98
+ - use_sam_pathogen (bool): Use SAM model for pathogen segmentation. Default is False.
99
+
100
+ Returns:
101
+ None: All outputs (masks, merged arrays, plots, databases) are saved to disk under the source folder(s).
102
+ """
10
103
 
11
104
  from .io import preprocess_img_data, _load_and_concatenate_arrays, convert_to_yokogawa, convert_separate_files_to_yokogawa
12
105
  from .plot import plot_image_mask_overlay, plot_arrays
@@ -194,7 +287,85 @@ def preprocess_generate_masks(settings):
194
287
  return
195
288
 
196
289
  def generate_cellpose_masks(src, settings, object_type):
197
-
290
+ """
291
+ Generate segmentation masks for a specific object type using Cellpose.
292
+
293
+ This function applies a Cellpose-based segmentation pipeline to images in `.npz` format, using settings
294
+ for batch size, object type (cell, nucleus, pathogen), filtering, plotting, and timelapse options.
295
+ Masks are optionally filtered, saved, tracked (for timelapse), and summarized into a SQLite database.
296
+
297
+ Args:
298
+ src (str): Path to the source folder containing `.npz` files with image stacks.
299
+ settings (dict): Dictionary of settings used to control preprocessing and segmentation. Includes:
300
+
301
+ General settings:
302
+ - src (str): Source directory.
303
+ - denoise (bool): Apply denoising before processing.
304
+ - delete_intermediate (bool): Remove intermediate files after processing.
305
+ - preprocess (bool): Enable preprocessing.
306
+ - masks (bool): Enable mask generation.
307
+ - save (bool): Save mask outputs.
308
+ - consolidate (bool): Consolidate image folders.
309
+ - batch_size (int): Batch size for processing.
310
+ - test_mode (bool): Enable test mode with limited image count.
311
+ - test_images (int): Number of test images to process.
312
+ - magnification (int): Image magnification level.
313
+ - custom_regex (str or None): Regex pattern for file parsing (metadata_type = 'auto').
314
+ - metadata_type (str): One of "cellvoyager" or "auto".
315
+ - n_jobs (int): Number of parallel workers.
316
+ - randomize (bool): Shuffle file order before processing.
317
+ - verbose (bool): Print full settings to console.
318
+
319
+ Channel/background/cellpose settings:
320
+ - remove_background_cell/nucleus/pathogen (bool): Whether to subtract background from channel.
321
+ - cell_diamiter / nucleus_diamiter / pathogen_diamiter (float or None): Estimated diameter.
322
+ - cell_channel / nucleus_channel / pathogen_channel (int or None): Channel index.
323
+ - channels (list): List of channels to include in stack.
324
+ - cell/background/SNR/CP_prob/FT (float): Intensity/cellpose thresholds and scaling.
325
+ - pathogen_model (str or None): Custom model for pathogen segmentation (e.g. "toxo_pv_lumen").
326
+
327
+ Plotting:
328
+ - plot (bool): Plot masks or overlay visualizations.
329
+ - figuresize (int): Matplotlib figure size.
330
+ - cmap (str): Colormap to use (e.g. "inferno").
331
+ - normalize (bool): Normalize input intensities.
332
+ - normalize_plots (bool): Normalize for plots.
333
+ - examples_to_plot (int): How many examples to plot.
334
+
335
+ Filtering and merging:
336
+ - merge_pathogens (bool): Whether to merge pathogen objects.
337
+ - filter (bool): Apply filtering on masks.
338
+ - lower_percentile (float): Intensity filter threshold.
339
+ - merge (bool): Merge adjacent objects if needed.
340
+
341
+ Timelapse:
342
+ - timelapse (bool): Enable object tracking across timepoints.
343
+ - timelapse_displacement (float or None): Max tracking displacement.
344
+ - timelapse_memory (int): Trackpy memory.
345
+ - timelapse_frame_limits (list): Frames to include in timelapse batch.
346
+ - timelapse_remove_transient (bool): Remove transient objects.
347
+ - timelapse_mode (str): One of "trackpy", "btrack", or "iou".
348
+ - timelapse_objects (list or None): Subset of ['cell', 'nucleus', 'pathogen'] to track.
349
+
350
+ Miscellaneous:
351
+ - all_to_mip (bool): Convert Z-stacks to max projections.
352
+ - upscale (bool): Apply upscaling.
353
+ - upscale_factor (float): Upscaling factor.
354
+ - adjust_cells (bool): Refine cell masks with nucleus/pathogen.
355
+ - use_sam_cell/nucleus/pathogen (bool): Use SAM for mask generation.
356
+
357
+ object_type (str): One of 'cell', 'nucleus', or 'pathogen'. Determines which mask to generate.
358
+
359
+ Returns:
360
+ None. Outputs are saved to disk:
361
+ - Generated masks are stored in a `*_mask_stack/` folder.
362
+ - Object counts are written to `measurements/measurements.db`.
363
+ - Optional overlay plots are saved if enabled.
364
+ - Optional timelapse movies are saved in `movies/`.
365
+
366
+ Raises:
367
+ ValueError: If the object_type is missing from the computed channel map, or if invalid tracking settings are provided.
368
+ """
198
369
  from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, all_elements_match, prepare_batch_for_segmentation
199
370
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
200
371
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
spacr/deep_spacr.py CHANGED
@@ -16,7 +16,34 @@ from torchvision import transforms
16
16
  from torch.utils.data import DataLoader
17
17
 
18
18
  def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
19
-
19
+ """
20
+ Apply a trained binary classification model to a folder of images.
21
+
22
+ Loads a PyTorch model and applies it to images in the specified folder using batch inference.
23
+ Supports optional normalization and GPU acceleration. Outputs prediction probabilities and
24
+ saves results as a CSV file alongside the model.
25
+
26
+ Args:
27
+ src (str): Path to a folder containing input images (e.g., PNG, JPG).
28
+ model_path (str): Path to a trained PyTorch model file (.pt or .pth).
29
+ image_size (int, optional): Size to center-crop input images to. Default is 224.
30
+ batch_size (int, optional): Number of images to process per batch. Default is 64.
31
+ normalize (bool, optional): If True, normalize images to [-1, 1] using ImageNet-style transform. Default is True.
32
+ n_jobs (int, optional): Number of subprocesses to use for data loading. Default is 10.
33
+
34
+ Returns:
35
+ pandas.DataFrame: A DataFrame with two columns:
36
+ - "path": Filenames of processed images.
37
+ - "pred": Model output probabilities (sigmoid of logits).
38
+
39
+ Saves:
40
+ A CSV file named like <model_path><YYMMDD>_<ext>_test_result.csv, containing the prediction results.
41
+
42
+ Notes:
43
+ - Uses GPU if available, otherwise runs on CPU.
44
+ - Assumes model outputs raw logits for binary classification (sigmoid is applied).
45
+ - The input folder must contain only images readable by `PIL.Image.open`.
46
+ """
20
47
  from .io import NoClassDataset
21
48
  from .utils import print_progress
22
49
 
@@ -71,7 +98,31 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
71
98
  return df
72
99
 
73
100
  def apply_model_to_tar(settings={}):
74
-
101
+ """
102
+ Apply a trained model to images stored inside a tar archive.
103
+
104
+ Loads a model and applies it to images within a `.tar` archive using batch inference. Results are
105
+ filtered by a probability threshold and saved to a CSV. Supports GPU acceleration and normalization.
106
+
107
+ Args:
108
+ settings (dict): Dictionary with the following keys:
109
+ - tar_path (str): Path to the tar archive with input images.
110
+ - model_path (str): Path to the trained PyTorch model (.pt/.pth).
111
+ - image_size (int): Center crop size for input images. Default is 224.
112
+ - batch_size (int): Batch size for DataLoader. Default is 64.
113
+ - normalize (bool): Apply normalization to [-1, 1]. Default is True.
114
+ - n_jobs (int): Number of workers for data loading. Default is system CPU count - 4.
115
+ - verbose (bool): If True, print progress and model details.
116
+ - score_threshold (float): Probability threshold for positive classification (used in result filtering).
117
+
118
+ Returns:
119
+ pandas.DataFrame: DataFrame with:
120
+ - "path": Filenames inside the tar archive.
121
+ - "pred": Model prediction scores (sigmoid output).
122
+
123
+ Saves:
124
+ A CSV file with prediction results to the same directory as the tar file.
125
+ """
75
126
  from .io import TarImageDataset
76
127
  from .utils import process_vision_results, print_progress
77
128
 
@@ -172,7 +223,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
172
223
  """
173
224
  Calculate classification metrics for binary classification.
174
225
 
175
- Parameters:
226
+ Args:
176
227
  - all_labels (list): List of true labels.
177
228
  - prediction_pos_probs (list): List of predicted positive probabilities.
178
229
  - loader_name (str): Name of the data loader.
@@ -256,7 +307,27 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
256
307
  return data_dict, [prediction_pos_probs, all_labels]
257
308
 
258
309
  def test_model_core(model, loader, loader_name, epoch, loss_type):
259
-
310
+ """
311
+ Evaluate a trained model on a test DataLoader and return performance metrics and predictions.
312
+
313
+ This function evaluates a binary classification model using a specified loss function, computes
314
+ classification metrics, and logs predictions, targets, and file-level results.
315
+
316
+ Args:
317
+ model (torch.nn.Module): The trained PyTorch model to evaluate.
318
+ loader (torch.utils.data.DataLoader): DataLoader providing test data and labels.
319
+ loader_name (str): Identifier name for the loader (used for logging/debugging).
320
+ epoch (int): Current epoch number (used for metric tracking).
321
+ loss_type (str): Type of loss function to use for reporting (e.g., 'bce', 'focal').
322
+
323
+ Returns:
324
+ tuple:
325
+ - data_df (pd.DataFrame): DataFrame containing classification metrics for the test set.
326
+ - prediction_pos_probs (list): List of predicted probabilities for the positive class.
327
+ - all_labels (list): Ground truth binary labels.
328
+ - results_df (pd.DataFrame): Per-sample results, including filename, true label, predicted label,
329
+ and probability for class 1.
330
+ """
260
331
  from .utils import calculate_loss, classification_metrics
261
332
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
262
333
  model.eval()
@@ -612,7 +683,50 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
612
683
  return model, model_path
613
684
 
614
685
  def generate_activation_map(settings):
615
-
686
+ """
687
+ Generate activation maps (Grad-CAM or saliency) from a trained model applied to a dataset stored in a tar archive.
688
+
689
+ This function loads a model, computes class activation maps or saliency maps for each input image, and saves the
690
+ results as images. Optionally, it plots batch-wise grids of maps and stores correlation results and image metadata
691
+ into an SQL database.
692
+
693
+ Args:
694
+ settings (dict): Dictionary of parameters controlling activation map generation. Key fields include:
695
+
696
+ Required paths:
697
+ - dataset (str): Path to the `.tar` archive containing images.
698
+ - model_path (str): Path to the trained PyTorch model (.pt or .pth).
699
+
700
+ Model and method:
701
+ - model_type (str): Model architecture used (e.g., 'maxvit').
702
+ - cam_type (str): One of ['gradcam', 'gradcam_pp', 'saliency_image', 'saliency_channel'].
703
+ - target_layer (str or None): Name of the target layer for Grad-CAM (optional, required for Grad-CAM variants).
704
+
705
+ Input transforms:
706
+ - image_size (int): Size to center-crop images to (e.g., 224).
707
+ - normalize_input (bool): Whether to normalize images to [-1, 1] range.
708
+ - channels (list): Channel indices to select from input data (e.g., [0,1,2]).
709
+
710
+ Inference:
711
+ - batch_size (int): Number of images per inference batch.
712
+ - shuffle (bool): Whether to shuffle image order in DataLoader.
713
+ - n_jobs (int): Number of parallel DataLoader workers (default is CPU count - 4).
714
+
715
+ Output control:
716
+ - save (bool): If True, saves individual activation maps to disk.
717
+ - plot (bool): If True, generates and saves batch-wise PDF grid plots.
718
+ - overlay (bool): If True, overlays activation maps on input images.
719
+ - correlation (bool): If True, computes activation correlation features (e.g., Manders').
720
+
721
+ Correlation-specific:
722
+ - manders_thresholds (list or float): Threshold(s) for calculating Manders' coefficients.
723
+
724
+ Returns:
725
+ None. The following outputs are saved:
726
+ - PNG or JPEG activation maps organized by predicted class and well.
727
+ - PDF files with batch-wise overlay plots if `plot=True`.
728
+ - Activation image metadata and correlations saved to SQL database if `save=True`.
729
+ """
616
730
  from .utils import SaliencyMapGenerator, GradCAMGenerator, SelectChannels, activation_maps_to_database, activation_correlations_to_database
617
731
  from .utils import print_progress, save_settings, calculate_activation_correlations
618
732
  from .io import TarImageDataset
@@ -771,7 +885,18 @@ def generate_activation_map(settings):
771
885
  print("Activation map generation complete.")
772
886
 
773
887
  def visualize_classes(model, dtype, class_names, **kwargs):
888
+ """
889
+ Visualize synthetic input images that maximize class activation.
774
890
 
891
+ Args:
892
+ model (torch.nn.Module): The trained classification model.
893
+ dtype (str): Data type or domain tag used for visualization.
894
+ class_names (list): List of class names (length 2 assumed for binary classification).
895
+ **kwargs: Additional keyword arguments passed to `class_visualization()`.
896
+
897
+ Returns:
898
+ None. Displays matplotlib plots of class visualizations.
899
+ """
775
900
  from .utils import class_visualization
776
901
 
777
902
  for target_y in range(2): # Assuming binary classification
@@ -783,7 +908,22 @@ def visualize_classes(model, dtype, class_names, **kwargs):
783
908
  plt.show()
784
909
 
785
910
  def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_size=224, channels=[1,2,3], normalize=True, save_integrated_grads=False, save_dir='integrated_grads'):
911
+ """
912
+ Visualize integrated gradients for PNG images in a directory.
913
+
914
+ Args:
915
+ src (str): Directory containing `.png` images.
916
+ model_path (str): Path to the trained PyTorch model.
917
+ target_label_idx (int): Index of the target class label.
918
+ image_size (int): Image size after preprocessing (center crop).
919
+ channels (list): List of channels to extract (1-indexed).
920
+ normalize (bool): Whether to normalize image input to [-1, 1].
921
+ save_integrated_grads (bool): Whether to save integrated gradient maps.
922
+ save_dir (str): Directory to save integrated gradient outputs.
786
923
 
924
+ Returns:
925
+ None. Displays overlays and optionally saves saliency maps.
926
+ """
787
927
  from .utils import IntegratedGradients, preprocess_image
788
928
 
789
929
  use_cuda = torch.cuda.is_available()
@@ -832,6 +972,15 @@ def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_si
832
972
  integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
833
973
 
834
974
  class SmoothGrad:
975
+ """
976
+ Compute SmoothGrad saliency maps from a trained model.
977
+
978
+ Args:
979
+ model (torch.nn.Module): Trained classification model.
980
+ n_samples (int): Number of noise samples to average.
981
+ stdev_spread (float): Standard deviation of noise relative to input range.
982
+ """
983
+
835
984
  def __init__(self, model, n_samples=50, stdev_spread=0.15):
836
985
  self.model = model
837
986
  self.n_samples = n_samples
@@ -855,7 +1004,22 @@ class SmoothGrad:
855
1004
  return avg_gradients.abs()
856
1005
 
857
1006
  def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, channels=[1,2,3], normalize=True, save_smooth_grad=False, save_dir='smooth_grad'):
1007
+ """
1008
+ Visualize SmoothGrad maps for PNG images in a folder.
1009
+
1010
+ Args:
1011
+ src (str): Path to directory containing `.png` images.
1012
+ model_path (str): Path to trained PyTorch model file.
1013
+ target_label_idx (int): Index of the class to explain.
1014
+ image_size (int): Size for center cropping during preprocessing.
1015
+ channels (list): Channel indices to extract from images.
1016
+ normalize (bool): Whether to normalize inputs to [-1, 1].
1017
+ save_smooth_grad (bool): If True, saves saliency maps to disk.
1018
+ save_dir (str): Folder where smooth grad maps are saved.
858
1019
 
1020
+ Returns:
1021
+ None. Displays overlay figures and optionally saves maps to disk.
1022
+ """
859
1023
  from .utils import preprocess_image
860
1024
 
861
1025
  use_cuda = torch.cuda.is_available()
@@ -904,6 +1068,78 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
904
1068
  smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
905
1069
 
906
1070
  def deep_spacr(settings={}):
1071
+ """
1072
+ Run a deep learning-based classification workflow on microscopy data using spaCR.
1073
+
1074
+ This function handles dataset generation, model training, and inference using a trained model on tar-archived image datasets.
1075
+ Settings are filled using `deep_spacr_defaults`.
1076
+
1077
+ Args:
1078
+ settings (dict): Dictionary of settings with the following keys:
1079
+
1080
+ General:
1081
+ - src (str): Path to the input dataset.
1082
+ - dataset (str): Path to a dataset archive.
1083
+ - dataset_mode (str): Dataset generation mode. Typically 'metadata'.
1084
+ - file_type (str): Type of input files (e.g., 'cell_png').
1085
+ - file_metadata (str or None): Path to file-level metadata, if available.
1086
+ - sample (int or None): Limit to N random samples for development/testing.
1087
+ - experiment (str): Experiment name prefix. Default is 'exp.'.
1088
+
1089
+ Annotation and class mapping:
1090
+ - annotation_column (str): Metadata column containing class annotations.
1091
+ - annotated_classes (list): List of class IDs used for training (e.g., [1, 2]).
1092
+ - classes (list): Class labels (e.g., ['nc', 'pc']).
1093
+ - class_metadata (list of lists): Mapping of classes to metadata terms (e.g., [['c1'], ['c2']]).
1094
+ - metadata_type_by (str): How to interpret metadata structure. Typically 'columnID'.
1095
+
1096
+ Image processing:
1097
+ - channel_of_interest (int): Channel index to use for classification.
1098
+ - png_type (str): Type of image format (e.g., 'cell_png').
1099
+ - image_size (int): Input size (e.g., 224 for 224x224 crop).
1100
+ - train_channels (list): Channels to use for training (e.g., ['r', 'g', 'b']).
1101
+ - normalize (bool): Whether to normalize input images. Default is True.
1102
+ - augment (bool): Whether to apply data augmentation.
1103
+
1104
+ Model and training:
1105
+ - model_type (str): Model architecture (e.g., 'maxvit_t').
1106
+ - optimizer_type (str): Optimizer (e.g., 'adamw').
1107
+ - schedule (str): Learning rate scheduler ('reduce_lr_on_plateau' or 'step_lr').
1108
+ - loss_type (str): Loss function ('focal_loss' or 'binary_cross_entropy_with_logits').
1109
+ - dropout_rate (float): Dropout probability.
1110
+ - init_weights (bool): Initialize model with pretrained weights.
1111
+ - amsgrad (bool): Use AMSGrad variant of AdamW optimizer.
1112
+ - use_checkpoint (bool): Enable checkpointing.
1113
+ - intermedeate_save (bool): Save intermediate models during training.
1114
+
1115
+ Training control:
1116
+ - train (bool): Enable training phase.
1117
+ - test (bool): Enable evaluation on test set.
1118
+ - train_DL_model (bool): Enable deep learning model training.
1119
+ - generate_training_dataset (bool): Enable generation of train/test splits.
1120
+ - test_split (float): Proportion of data used for testing.
1121
+ - val_split (float): Fraction of training set used for validation.
1122
+ - epochs (int): Number of training epochs.
1123
+ - batch_size (int): Batch size for training and inference.
1124
+ - learning_rate (float): Learning rate.
1125
+ - weight_decay (float): L2 regularization strength.
1126
+ - gradient_accumulation (bool): Accumulate gradients over multiple steps.
1127
+ - gradient_accumulation_steps (int): Number of steps per gradient update.
1128
+
1129
+ Inference:
1130
+ - apply_model_to_dataset (bool): Run prediction on tar dataset.
1131
+ - tar_path (str): Path to tar file for inference input.
1132
+ - model_path (str): Path to trained model file.
1133
+ - score_threshold (float): Probability threshold for binary classification.
1134
+
1135
+ Execution:
1136
+ - n_jobs (int): Number of parallel workers.
1137
+ - pin_memory (bool): Whether to use pinned memory in DataLoader.
1138
+ - verbose (bool): Print training and evaluation progress.
1139
+
1140
+ Returns:
1141
+ None. All outputs (trained models, predictions, settings) are saved to disk.
1142
+ """
907
1143
  from .settings import deep_spacr_defaults
908
1144
  from .io import generate_training_dataset, generate_dataset
909
1145
  from .utils import save_settings
@@ -937,7 +1173,26 @@ def deep_spacr(settings={}):
937
1173
  apply_model_to_tar(settings)
938
1174
 
939
1175
  def model_knowledge_transfer(teacher_paths, student_save_path, data_loader, device='cpu', student_model_name='maxvit_t', pretrained=True, dropout_rate=None, use_checkpoint=False, alpha=0.5, temperature=2.0, lr=1e-4, epochs=10):
1176
+ """
1177
+ Perform knowledge distillation from one or more teacher models to a student model.
940
1178
 
1179
+ Args:
1180
+ teacher_paths (list of str): Paths to pretrained teacher model files (.pth).
1181
+ student_save_path (str): Output path for the saved student model.
1182
+ data_loader (torch.utils.data.DataLoader): DataLoader for training data.
1183
+ device (str): Device to use ('cpu' or 'cuda').
1184
+ student_model_name (str): Name of the student model architecture (e.g., 'maxvit_t').
1185
+ pretrained (bool): Whether to initialize the student model with ImageNet weights.
1186
+ dropout_rate (float or None): Dropout rate for the student model.
1187
+ use_checkpoint (bool): Whether to use gradient checkpointing.
1188
+ alpha (float): Weighting factor between cross-entropy and distillation loss.
1189
+ temperature (float): Temperature scaling for soft targets.
1190
+ lr (float): Learning rate for optimizer.
1191
+ epochs (int): Number of training epochs.
1192
+
1193
+ Returns:
1194
+ TorchModel: The trained student model after knowledge distillation.
1195
+ """
941
1196
  from .utils import TorchModel
942
1197
 
943
1198
  # Adjust filename to reflect knowledge-distillation if desired
@@ -1041,7 +1296,22 @@ def model_knowledge_transfer(teacher_paths, student_save_path, data_loader, devi
1041
1296
  return student_model
1042
1297
 
1043
1298
  def model_fusion(model_paths,save_path,device='cpu',model_name='maxvit_t',pretrained=True,dropout_rate=None,use_checkpoint=False,aggregator='mean'):
1299
+ """
1300
+ Fuse multiple trained models by combining their parameters using a specified aggregation method.
1044
1301
 
1302
+ Args:
1303
+ model_paths (list of str): List of paths to model checkpoints to be fused.
1304
+ save_path (str): Path where the fused model will be saved.
1305
+ device (str): Device to use ('cpu' or 'cuda').
1306
+ model_name (str): Model architecture to use when initializing.
1307
+ pretrained (bool): Whether to initialize with pretrained weights.
1308
+ dropout_rate (float or None): Dropout rate to apply to the model.
1309
+ use_checkpoint (bool): Whether to use gradient checkpointing.
1310
+ aggregator (str): Aggregation strategy to combine weights. One of {'mean', 'geomean', 'median', 'sum', 'max', 'min'}.
1311
+
1312
+ Returns:
1313
+ TorchModel: The fused model with combined weights.
1314
+ """
1045
1315
  from .utils import TorchModel
1046
1316
 
1047
1317
  if save_path.endswith('.pth'):
@@ -1141,14 +1411,33 @@ def model_fusion(model_paths,save_path,device='cpu',model_name='maxvit_t',pretra
1141
1411
  return fused_model
1142
1412
 
1143
1413
  def annotate_filter_vision(settings):
1144
-
1414
+ """
1415
+ Annotate and filter a CSV file with experimental metadata and optionally remove training samples.
1416
+
1417
+ Args:
1418
+ settings (dict): Configuration dictionary with keys:
1419
+ - 'src' (str or list): Paths to CSV annotation files.
1420
+ - 'cells' (dict): Mapping of cell types to annotation labels.
1421
+ - 'cell_loc' (str): Column name for cell type annotations.
1422
+ - 'pathogens' (dict): Mapping of pathogens to annotation labels.
1423
+ - 'pathogen_loc' (str): Column name for pathogen annotations.
1424
+ - 'treatments' (dict): Mapping of treatments to annotation labels.
1425
+ - 'treatment_loc' (str): Column name for treatment annotations.
1426
+ - 'filter_column' (str or None): Column to filter on.
1427
+ - 'upper_threshold' (float): Upper bound for filtering.
1428
+ - 'lower_threshold' (float): Lower bound for filtering.
1429
+ - 'remove_train' (bool): If True, removes rows present in training folders.
1430
+
1431
+ Returns:
1432
+ None. Saves filtered and annotated CSVs to disk.
1433
+ """
1145
1434
  from .utils import annotate_conditions, correct_metadata
1146
1435
 
1147
1436
  def filter_csv_by_png(csv_file):
1148
1437
  """
1149
1438
  Filters a DataFrame by removing rows that match PNG filenames in a folder.
1150
1439
 
1151
- Parameters:
1440
+ Args:
1152
1441
  csv_file (str): Path to the CSV file.
1153
1442
 
1154
1443
  Returns: