spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/submodules.py CHANGED
@@ -30,7 +30,37 @@ from natsort import natsorted
30
30
  from torch.utils.data import Dataset
31
31
 
32
32
  class CellposeLazyDataset(Dataset):
33
+ """
34
+ A PyTorch Dataset for lazy loading and optional augmentation of images and segmentation masks
35
+ for training Cellpose-based models.
36
+
37
+ Images and labels are loaded from file paths on-the-fly, optionally normalized and resized,
38
+ and can be augmented with basic geometric transformations.
39
+
40
+ Args:
41
+ image_files (list of str): List of file paths to image files.
42
+ label_files (list of str): List of file paths to corresponding label files.
43
+ settings (dict): Dictionary containing dataset settings:
44
+ - 'normalize' (bool): Whether to apply percentile-based intensity normalization.
45
+ - 'percentiles' (list of int): Two-element list specifying lower and upper percentiles (default: [2, 99]).
46
+ - 'target_size' (int): Desired output image size (height and width).
47
+ randomize (bool, optional): Whether to shuffle the dataset order. Defaults to True.
48
+ augment (bool, optional): Whether to apply 8-way geometric data augmentation. Defaults to False.
49
+ """
33
50
  def __init__(self, image_files, label_files, settings, randomize=True, augment=False):
51
+ """
52
+ Initialize the CellposeLazyDataset.
53
+
54
+ Args:
55
+ image_files (list of str): Paths to input image files.
56
+ label_files (list of str): Paths to corresponding label files.
57
+ settings (dict): Configuration dictionary with keys:
58
+ - 'normalize' (bool)
59
+ - 'percentiles' (list of int)
60
+ - 'target_size' (int)
61
+ randomize (bool, optional): Shuffle file order. Defaults to True.
62
+ augment (bool, optional): Enable 8-fold augmentation. Defaults to False.
63
+ """
34
64
  combined = list(zip(image_files, label_files))
35
65
  if randomize:
36
66
  random.shuffle(combined)
@@ -41,9 +71,30 @@ class CellposeLazyDataset(Dataset):
41
71
  self.augment = augment
42
72
 
43
73
  def __len__(self):
74
+ """
75
+ Return the number of samples in the dataset.
76
+
77
+ If augmentation is enabled, each sample contributes 8 variants.
78
+
79
+ Returns:
80
+ int: Total number of samples (augmented if applicable).
81
+ """
44
82
  return len(self.image_files) * (8 if self.augment else 1)
45
83
 
46
84
  def apply_augmentation(self, image, label, aug_idx):
85
+ """
86
+ Apply one of 8 geometric augmentations to an image-label pair.
87
+
88
+ Augmentations include rotations (90°, 180°, 270°) and horizontal/vertical flips.
89
+
90
+ Args:
91
+ image (ndarray): Input image array.
92
+ label (ndarray): Corresponding label array.
93
+ aug_idx (int): Index from 0 to 7 specifying the augmentation to apply.
94
+
95
+ Returns:
96
+ tuple: Augmented (image, label) pair.
97
+ """
47
98
  if aug_idx == 1:
48
99
  return rotate(image, 90, resize=False, preserve_range=True), rotate(label, 90, resize=False, preserve_range=True)
49
100
  elif aug_idx == 2:
@@ -61,6 +112,20 @@ class CellposeLazyDataset(Dataset):
61
112
  return image, label
62
113
 
63
114
  def __getitem__(self, idx):
115
+ """
116
+ Retrieve a sample by index, optionally applying augmentation and preprocessing.
117
+
118
+ Loads the image and label, normalizes intensity if specified, applies augmentation,
119
+ and resizes to the target shape.
120
+
121
+ Args:
122
+ idx (int): Index of the sample to retrieve.
123
+
124
+ Returns:
125
+ tuple:
126
+ - image (np.ndarray): Preprocessed image, shape (target_size, target_size), dtype float32.
127
+ - label (np.ndarray): Resized label mask, shape (target_size, target_size), dtype uint8.
128
+ """
64
129
  base_idx = idx // 8 if self.augment else idx
65
130
  aug_idx = idx % 8 if self.augment else 0
66
131
 
@@ -86,7 +151,32 @@ class CellposeLazyDataset(Dataset):
86
151
  return image, label
87
152
 
88
153
  def train_cellpose(settings):
89
-
154
+ """
155
+ Train a Cellpose model on custom images and masks using specified settings.
156
+
157
+ This function prepares training data from `train/images` and `train/masks` subfolders within
158
+ the provided `settings['src']` directory. It constructs a model name based on training parameters,
159
+ initializes the Cellpose model, and trains it using the specified number of epochs and hyperparameters.
160
+
161
+ The dataset can be augmented up to 8-fold, and training images and masks are matched by filename.
162
+ Training progress is visualized (if possible), and the model is saved in `models/cellpose_model`.
163
+
164
+ Args:
165
+ settings (dict): Dictionary with the following required keys:
166
+ - 'src' (str): Root directory containing `train/images`, `train/masks`, and `models`.
167
+ - 'target_size' (int): Side length to which images/masks are resized.
168
+ - 'model_name' (str): Base name of the model.
169
+ - 'n_epochs' (int): Number of training epochs.
170
+ - 'batch_size' (int): Number of images to train per batch.
171
+ - 'learning_rate' (float): Learning rate for training.
172
+ - 'weight_decay' (float): Weight decay (L2 regularization).
173
+ - 'augment' (bool): Whether to use 8-fold data augmentation.
174
+
175
+ Side Effects:
176
+ - Trains a model and saves it to `models/cellpose_model/`.
177
+ - Writes training settings using `save_settings()`.
178
+ - Optionally visualizes a training batch using `plot_cellpose_batch()`.
179
+ """
90
180
  from .settings import get_train_cellpose_default_settings
91
181
  from .utils import save_settings
92
182
 
@@ -156,7 +246,35 @@ def train_cellpose(settings):
156
246
  print(f"Model saved at: {model_save_path}/{model_name}")
157
247
 
158
248
  def test_cellpose_model(settings):
159
-
249
+ """
250
+ Evaluate a pretrained Cellpose model on a test dataset and report segmentation performance.
251
+
252
+ This function loads test images and ground-truth masks from a specified directory structure,
253
+ applies the Cellpose model to predict masks, and compares them to the ground truth using
254
+ object-level metrics and Jaccard index. Results are saved and visualized if specified.
255
+
256
+ Args:
257
+ settings (dict): Dictionary with the following keys:
258
+ - 'src' (str): Root directory containing `test/images` and `test/masks`.
259
+ - 'model_path' (str): Path to the pretrained Cellpose model.
260
+ - 'batch_size' (int): Number of images per inference batch.
261
+ - 'FT' (float): Flow threshold for Cellpose segmentation.
262
+ - 'CP_probability' (float): Cell probability threshold for segmentation.
263
+ - 'target_size' (int): Size to which input images are resized.
264
+ - 'save' (bool): Whether to save segmentation visualizations and results.
265
+
266
+ Side Effects:
267
+ - Loads and processes image and mask files from `test/images` and `test/masks`.
268
+ - Evaluates segmentation and computes the following metrics per image:
269
+ * Jaccard index
270
+ * Number of predicted and ground-truth objects
271
+ * Mean object area (true and predicted)
272
+ * True positives, false positives, false negatives
273
+ * Precision, recall, F1-score, accuracy
274
+ - Saves figures showing input, true mask, predicted mask, probability map, and flow.
275
+ - Optionally saves results to `results/test_results.csv`.
276
+ - Prints performance summary to stdout.
277
+ """
160
278
  from .utils import save_settings, print_progress
161
279
  from .settings import get_default_test_cellpose_model_settings
162
280
 
@@ -342,7 +460,32 @@ def test_cellpose_model(settings):
342
460
  df_results.to_csv(os.path.join(results_dir, 'test_results.csv'), index=False)
343
461
 
344
462
  def apply_cellpose_model(settings):
345
-
463
+ """
464
+ Apply a pretrained Cellpose model to a folder of images and extract object-level measurements.
465
+
466
+ This function processes all `.tif` images in the specified source directory using a pretrained
467
+ Cellpose model, generates segmentation masks, and computes region properties (e.g., area) for each
468
+ detected object. It optionally applies circular masking to limit the field of view and saves results
469
+ and figures if enabled.
470
+
471
+ Args:
472
+ settings (dict): Dictionary with the following required keys:
473
+ - 'src' (str): Directory containing the input `.tif` images.
474
+ - 'model_path' (str): Path to the pretrained Cellpose model file.
475
+ - 'batch_size' (int): Number of images processed per batch.
476
+ - 'FT' (float): Flow threshold for Cellpose segmentation.
477
+ - 'CP_probability' (float): Cell probability threshold for segmentation.
478
+ - 'target_size' (int): Resize target for input images.
479
+ - 'save' (bool): If True, saves visualizations and CSV results.
480
+ - 'circularize' (bool, optional): If True, applies a circular mask to each prediction.
481
+
482
+ Side Effects:
483
+ - Generates and saves segmented mask visualizations (if `save=True`) to `results/`.
484
+ - Extracts per-object measurements (area) and saves:
485
+ * `results/measurements.csv`: one row per object.
486
+ * `results/summary.csv`: average object area and count per image.
487
+ - Displays progress and timing for each batch.
488
+ """
346
489
  from .settings import get_default_apply_cellpose_model_settings
347
490
  from .utils import save_settings, print_progress
348
491
 
@@ -462,6 +605,17 @@ def apply_cellpose_model(settings):
462
605
  print("Saved object count and average area to summary.csv")
463
606
 
464
607
  def plot_cellpose_batch(images, labels):
608
+ """
609
+ Display a batch of input images and their corresponding label masks.
610
+
611
+ This function plots two rows of subplots:
612
+ - Top row: grayscale input images.
613
+ - Bottom row: corresponding label masks with randomly colored regions.
614
+
615
+ Args:
616
+ images (list of np.ndarray): List of 2D grayscale input images.
617
+ labels (list of np.ndarray): List of 2D label masks corresponding to the images.
618
+ """
465
619
  from .plot import generate_mask_random_cmap
466
620
 
467
621
  cmap_lbl = generate_mask_random_cmap(labels)
@@ -477,6 +631,31 @@ def plot_cellpose_batch(images, labels):
477
631
  plt.show()
478
632
 
479
633
  def analyze_percent_positive(settings):
634
+ """
635
+ Analyze the fraction of objects above a threshold per well and annotate them accordingly.
636
+
637
+ This function loads merged object-level measurements from a SQLite database, applies optional
638
+ filtering, and annotates each object as 'above' or 'below' a specified threshold on a value column.
639
+ It then summarizes the number and fraction of above-threshold objects per condition and well.
640
+ Results are merged with well metadata and saved as a CSV.
641
+
642
+ Args:
643
+ settings (dict): Dictionary with the following required keys:
644
+ - 'src' (str): Root directory containing `measurements.db` and `rename_log.csv`.
645
+ - 'tables' (list of str): Table names to extract from the SQLite database.
646
+ - 'value_col' (str): Column to apply the threshold to.
647
+ - 'threshold' (float): Threshold value to classify objects.
648
+ - 'filter_1' (tuple or None): Optional filter in the form (column, min_value), or None.
649
+
650
+ Returns:
651
+ pd.DataFrame: Merged DataFrame with annotation summary and metadata, also saved to `result.csv`.
652
+
653
+ Side Effects:
654
+ - Reads measurement data from `measurements/measurements.db`.
655
+ - Reads well metadata from `rename_log.csv`.
656
+ - Writes annotated results to `result.csv` in the same directory.
657
+ - Displays the final DataFrame in notebook or console.
658
+ """
480
659
  from .io import _read_and_merge_data
481
660
  from .utils import save_settings
482
661
  from .settings import default_settings_analyze_percent_positive
@@ -484,6 +663,26 @@ def analyze_percent_positive(settings):
484
663
  settings = default_settings_analyze_percent_positive(settings)
485
664
 
486
665
  def translate_well_in_df(csv_loc):
666
+ """
667
+ Extract and translate well metadata from a CSV file containing renamed TIFF filenames.
668
+
669
+ This function parses the 'Renamed TIFF' column to extract plate and well information,
670
+ generates plate-well identifiers, and translates well positions to row and column IDs.
671
+ It also constructs a unique `prc` (plate_row_column) identifier for each well.
672
+
673
+ Args:
674
+ csv_loc (str): Path to a CSV file containing a 'Renamed TIFF' column.
675
+
676
+ Returns:
677
+ pd.DataFrame: DataFrame with columns:
678
+ - 'plateID': extracted plate identifier
679
+ - 'well': extracted well identifier (e.g., A01)
680
+ - 'plate_well': combined plate and well string
681
+ - 'rowID': converted row index (e.g., 'r1' for row A)
682
+ - 'column_name': converted column index (e.g., 'c1' for column 01)
683
+ - 'fieldID': field identifier (default: 'f1')
684
+ - 'prc': combined identifier in the form 'p<plate>_r<row>_c<column>'
685
+ """
487
686
  # Load and extract metadata
488
687
  df = pd.read_csv(csv_loc)
489
688
  df[['plateID', 'well']] = df['Renamed TIFF'].str.replace('.tif', '', regex=False).str.split('_', expand=True)[[0, 1]]
@@ -564,15 +763,39 @@ def analyze_percent_positive(settings):
564
763
 
565
764
  def analyze_recruitment(settings):
566
765
  """
567
- Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
568
-
569
- Parameters:
570
- settings (dict): settings.
766
+ Analyze host protein recruitment to the pathogen vacuole across experimental conditions.
767
+
768
+ This function loads object-level measurements, applies size and intensity-based filtering, computes
769
+ recruitment scores as the ratio of pathogen-to-cytoplasm intensities, and aggregates results at the
770
+ cell and well level. Recruitment is annotated and visualized, and filtered results are saved.
771
+
772
+ Args:
773
+ settings (dict): Dictionary with the following required keys:
774
+ - 'src' (str): Path to the root folder or `measurements.db` file.
775
+ - 'cell_types', 'pathogen_types', 'treatments' (list of str): Experimental conditions to annotate.
776
+ - 'cell_plate_metadata', 'pathogen_plate_metadata', 'treatment_plate_metadata' (str): Metadata sources.
777
+ - 'channel_of_interest' (int): Channel index for recruitment quantification.
778
+ - 'channel_dims' (list of int): Channels to calculate recruitment for.
779
+ - 'cell_chann_dim', 'nucleus_chann_dim', 'pathogen_chann_dim' (int or None): Mask channel assignments.
780
+ - 'cell_size_range', 'nucleus_size_range', 'pathogen_size_range' (list): Min/max object size filters.
781
+ - 'cell_intensity_range', etc. (list): Intensity-based object filters.
782
+ - 'target_intensity_min' (float): Minimum required target channel intensity in cell.
783
+ - 'nuclei_limit', 'pathogen_limit' (int or None): Max number of nuclei/pathogens to load.
784
+ - 'cells_per_well' (int): Minimum number of cells per well to retain.
785
+ - 'plot' (bool): Whether to plot cell outlines on merged images.
786
+ - 'plot_control' (bool): Whether to plot control comparisons.
787
+ - 'plot_nr' (int): Maximum number of outlines to plot.
788
+ - 'figuresize' (float): Size of recruitment plots in inches.
571
789
 
572
790
  Returns:
573
- None
791
+ list: [cell-level DataFrame, well-level DataFrame] after filtering and recruitment analysis.
792
+
793
+ Side Effects:
794
+ - Reads from and writes to files in the source directory.
795
+ - Plots recruitment distributions and controls.
796
+ - Saves `cell_level_results.csv` and `well_level_results.csv` via `_results_to_csv`.
797
+ - Displays status messages and summary statistics.
574
798
  """
575
-
576
799
  from .io import _read_and_merge_data, _results_to_csv
577
800
  from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
578
801
  from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
@@ -686,7 +909,30 @@ def analyze_recruitment(settings):
686
909
  return [cells,wells]
687
910
 
688
911
  def analyze_plaques(settings):
689
-
912
+ """
913
+ Analyze plaque-like structures from microscopy images using a pretrained Cellpose model.
914
+
915
+ This function applies a custom-trained Cellpose model to identify and segment plaques in `.tif` images.
916
+ It calculates object-level statistics including object count and area, and saves the results into a
917
+ SQLite database with three tables: `summary`, `details`, and `stats`.
918
+
919
+ Args:
920
+ settings (dict): Dictionary with the following keys:
921
+ - 'src' (str): Source folder containing input `.tif` images.
922
+ - 'masks' (bool): If True, (re)generate segmentation masks using Cellpose.
923
+ - 'custom_model' (str, optional): Will be set internally to the plaque model path.
924
+ - Any keys required by `get_analyze_plaque_settings()` or `identify_masks_finetune()`.
925
+
926
+ Side Effects:
927
+ - Downloads pretrained plaque segmentation model if not available.
928
+ - Saves generated masks (if `masks=True`) to `src/masks/`.
929
+ - Analyzes all `.tif` masks in the folder.
930
+ - Creates and saves the following tables to `plaques_analysis.db`:
931
+ * `summary`: file name, number of plaques, average size
932
+ * `details`: file name, individual plaque sizes
933
+ * `stats`: file name, plaque count, average size, and standard deviation
934
+ - Prints progress and completion messages.
935
+ """
690
936
  from .spacr_cellpose import identify_masks_finetune
691
937
  from .settings import get_analyze_plaque_settings
692
938
  from .utils import save_settings, download_models
@@ -751,6 +997,29 @@ def analyze_plaques(settings):
751
997
  print(f"Analysis completed and saved to database '{db_name}'.")
752
998
 
753
999
  def count_phenotypes(settings):
1000
+ """
1001
+ Count and summarize annotated phenotypes in a measurement database.
1002
+
1003
+ This function reads annotation data from the `measurements.db` database and performs the following:
1004
+ - Counts the number of unique phenotype labels in the specified annotation column.
1005
+ - Computes how many unique labels occur per well (defined by plateID, rowID, columnID).
1006
+ - Computes the count of each phenotype value per well.
1007
+ - Outputs a CSV file with the per-well phenotype counts.
1008
+
1009
+ Args:
1010
+ settings (dict): Dictionary with the following keys:
1011
+ - 'src' (str): Path to a measurement database or directory containing it.
1012
+ - 'annotation_column' (str): Column name containing the phenotype annotations.
1013
+
1014
+ Returns:
1015
+ None
1016
+
1017
+ Side Effects:
1018
+ - Displays number of unique phenotype values.
1019
+ - Displays a table of unique value counts per well.
1020
+ - Saves a CSV file `phenotype_counts.csv` summarizing value counts per well
1021
+ (one row per well, one column per phenotype).
1022
+ """
754
1023
  from .io import _read_db
755
1024
 
756
1025
  if not settings['src'].endswith('/measurements/measurements.db'):
@@ -794,8 +1063,63 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
794
1063
  pc_grna='TGGT1_220950_1', nc_grna='TGGT1_233460_4',
795
1064
  y_columns=['class_1_fraction', 'TGGT1_220950_1_fraction', 'nc_fraction'],
796
1065
  column='columnID', value='c3', plate=None, save_paths=None):
1066
+ """
1067
+ Compare Cellpose-based phenotypic classification scores with sequencing read distributions
1068
+ from positive and negative control gRNAs across plate wells.
1069
+
1070
+ This function merges phenotype classification scores and read count data, calculates gRNA
1071
+ fractions and class predictions per well, and overlays empirical expectations for controls
1072
+ to assess model calibration. It generates and saves line plots for selected y-axis columns.
1073
+
1074
+ Args:
1075
+ reads_csv (str or list of str): Path(s) to CSV file(s) containing read counts and gRNA names.
1076
+ scores_csv (str or list of str): Path(s) to CSV file(s) with classification results per object.
1077
+ empirical_dict (dict): Mapping of rowID to (positive, negative) control counts.
1078
+ pc_grna (str): Name of the positive control gRNA.
1079
+ nc_grna (str): Name of the negative control gRNA.
1080
+ y_columns (list of str): Columns in the merged dataframe to plot on y-axis.
1081
+ column (str): Column to subset on (e.g., 'columnID').
1082
+ value (str): Value to filter the column on (e.g., 'c3').
1083
+ plate (str, optional): Plate name to assign if not present in the CSVs.
1084
+ save_paths (list of str, optional): List of two paths to save the plots as PDF.
797
1085
 
1086
+ Returns:
1087
+ list: Two matplotlib Figure objects, one for pc_fraction and one for nc_fraction x-axis plots.
1088
+
1089
+ Side Effects:
1090
+ - Reads and merges input CSVs.
1091
+ - Computes gRNA read fractions and classification fractions per well.
1092
+ - Merges empirical expectations based on rowID.
1093
+ - Displays merged DataFrame.
1094
+ - Saves two line plots (if `save_paths` is specified).
1095
+ """
1096
+
798
1097
  def calculate_well_score_fractions(df, class_columns='cv_predictions'):
1098
+ """
1099
+ Calculate the fraction of predicted classes (e.g., class_0 and class_1) per well.
1100
+
1101
+ This function groups a DataFrame by plate, row, and column IDs to compute the total number
1102
+ of objects and the number of objects predicted in each class per well. It returns a summary
1103
+ DataFrame that includes the absolute and relative frequencies of class 0 and class 1 predictions.
1104
+
1105
+ Args:
1106
+ df (pd.DataFrame): Input DataFrame containing at least:
1107
+ - 'plateID', 'rowID', 'columnID': to define wells
1108
+ - A column specified by `class_columns` with predicted class labels (e.g., 0, 1)
1109
+ class_columns (str): Name of the column containing predicted class labels. Defaults to 'cv_predictions'.
1110
+
1111
+ Returns:
1112
+ pd.DataFrame: Summary DataFrame with columns:
1113
+ - 'plateID', 'rowID', 'columnID', 'prc'
1114
+ - 'total_rows': total number of entries per well
1115
+ - 'class_0': number of class 0 predictions
1116
+ - 'class_1': number of class 1 predictions
1117
+ - 'class_0_fraction': class_0 / total_rows
1118
+ - 'class_1_fraction': class_1 / total_rows
1119
+
1120
+ Raises:
1121
+ ValueError: If required columns ('plateID', 'rowID', 'columnID') are missing in `df`.
1122
+ """
799
1123
  if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
800
1124
  df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
801
1125
  else:
@@ -814,9 +1138,31 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
814
1138
  def plot_line(df, x_column, y_columns, group_column=None, xlabel=None, ylabel=None,
815
1139
  title=None, figsize=(10, 6), save_path=None, theme='deep'):
816
1140
  """
817
- Create a line plot that can handle multiple y-columns, each becoming a separate line.
818
- """
1141
+ Create a customizable line plot for one or more y-columns against a common x-axis.
1142
+
1143
+ This function supports grouped or ungrouped line plots and includes options for custom axis
1144
+ labels, title, figure size, and color themes. When multiple y-columns are specified, each is
1145
+ plotted as an individual line. Optionally saves the plot as a high-resolution PDF.
1146
+
1147
+ Args:
1148
+ df (pd.DataFrame): Input DataFrame containing the plotting data.
1149
+ x_column (str): Column name to use for the x-axis.
1150
+ y_columns (str or list of str): Column(s) to use for the y-axis. If a list, each column is plotted as a separate line.
1151
+ group_column (str, optional): Column name for grouping lines (only used if `y_columns` is a single string).
1152
+ xlabel (str, optional): Label for the x-axis. Defaults to `x_column` name.
1153
+ ylabel (str, optional): Label for the y-axis. Defaults to 'Value'.
1154
+ title (str, optional): Title of the plot. Defaults to 'Line Plot'.
1155
+ figsize (tuple, optional): Figure size in inches. Defaults to (10, 6).
1156
+ save_path (str, optional): If provided, saves the figure as a PDF to this path.
1157
+ theme (str, optional): Seaborn color palette theme. Defaults to 'deep'.
819
1158
 
1159
+ Returns:
1160
+ matplotlib.figure.Figure: The generated figure object.
1161
+
1162
+ Side Effects:
1163
+ - Displays the plot.
1164
+ - Saves the figure as a PDF if `save_path` is specified.
1165
+ """
820
1166
  def _set_theme(theme):
821
1167
  """Set the Seaborn theme and reorder colors if necessary."""
822
1168
 
@@ -880,6 +1226,30 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
880
1226
  return fig
881
1227
 
882
1228
  def calculate_grna_fraction_ratio(df, grna1='TGGT1_220950_1', grna2='TGGT1_233460_4'):
1229
+ """
1230
+ Calculate the ratio and fractional abundance of two gRNAs per well.
1231
+
1232
+ This function filters the input DataFrame for two specified gRNAs, aggregates their read
1233
+ fractions and counts per well (`prc`), and computes:
1234
+ - the fraction ratio (grna1 / grna2),
1235
+ - total read count per well,
1236
+ - individual gRNA read fractions per well.
1237
+
1238
+ Args:
1239
+ df (pd.DataFrame): Input DataFrame with columns ['prc', 'grna_name', 'fraction', 'count'].
1240
+ grna1 (str): Name of the positive control or reference gRNA.
1241
+ grna2 (str): Name of the negative control or comparison gRNA.
1242
+
1243
+ Returns:
1244
+ pd.DataFrame: A DataFrame with columns:
1245
+ - 'prc': well identifier (plate_row_column)
1246
+ - '<grna1>_count': total read count of grna1 per well
1247
+ - '<grna2>_count': total read count of grna2 per well
1248
+ - 'fraction_ratio': grna1 fraction / grna2 fraction
1249
+ - 'total_reads': total reads of both gRNAs
1250
+ - '<grna1>_fraction': grna1 reads / total_reads
1251
+ - '<grna2>_fraction': grna2 reads / total_reads
1252
+ """
883
1253
  # Filter relevant grna_names within each prc and group them
884
1254
  grouped = df[df['grna_name'].isin([grna1, grna2])] \
885
1255
  .groupby(['prc', 'grna_name']) \
@@ -901,6 +1271,28 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
901
1271
  return result
902
1272
 
903
1273
  def calculate_well_read_fraction(df, count_column='count'):
1274
+ """
1275
+ Calculate the fractional read count for each gRNA within a well.
1276
+
1277
+ This function computes the total number of reads per well (identified by plateID, rowID, columnID),
1278
+ and then computes the fraction of reads contributed by each gRNA relative to the total reads
1279
+ in that well.
1280
+
1281
+ Args:
1282
+ df (pd.DataFrame): DataFrame containing at least the columns:
1283
+ - 'plateID', 'rowID', 'columnID' (to define wells)
1284
+ - `count_column` (default: 'count'), the read count per gRNA instance
1285
+ count_column (str): Name of the column holding raw read counts. Defaults to 'count'.
1286
+
1287
+ Returns:
1288
+ pd.DataFrame: Original DataFrame with added columns:
1289
+ - 'prc': combined plate_row_column identifier
1290
+ - 'total_counts': total reads per well
1291
+ - 'fraction': count / total_counts for each row
1292
+
1293
+ Raises:
1294
+ ValueError: If any of 'plateID', 'rowID', or 'columnID' are missing in the input DataFrame.
1295
+ """
904
1296
  if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
905
1297
  df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
906
1298
  else:
@@ -969,7 +1361,49 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
969
1361
  return [fig_1, fig_2]
970
1362
 
971
1363
  def interperate_vision_model(settings={}):
972
-
1364
+ """
1365
+ Interpret a vision-based machine learning model using multiple feature attribution strategies.
1366
+
1367
+ This function reads merged object-level data and model scores, engineers relative features, and applies:
1368
+ - Random Forest feature importance
1369
+ - Permutation importance
1370
+ - SHAP (SHapley Additive exPlanations) values
1371
+
1372
+ The model interpretation is performed across cellular compartments and imaging channels, and
1373
+ results can be visualized or saved for downstream analysis.
1374
+
1375
+ Args:
1376
+ settings (dict): Dictionary of configuration parameters including:
1377
+ - 'src' (str): Path to the data directory.
1378
+ - 'scores' (str): Path to the score CSV file.
1379
+ - 'tables' (list of str): Feature group keywords for compartmental grouping.
1380
+ - 'channels' (list of str): Feature group keywords for channel grouping.
1381
+ - 'score_column' (str): Column name for predicted class scores.
1382
+ - 'nuclei_limit' (int): Optional filter on number of nuclei to load.
1383
+ - 'pathogen_limit' (int): Optional filter on number of pathogens to load.
1384
+ - 'feature_importance' (bool): Whether to compute Random Forest feature importances.
1385
+ - 'permutation_importance' (bool): Whether to compute permutation-based importances.
1386
+ - 'shap' (bool): Whether to compute SHAP values.
1387
+ - 'shap_sample' (bool): Whether to subsample SHAP input for speed.
1388
+ - 'top_features' (int): Number of top features to plot.
1389
+ - 'include_all' (bool): Whether to include a row summing all feature groups.
1390
+ - 'n_jobs' (int): Number of parallel jobs to use for model training or SHAP.
1391
+ - 'save' (bool): Whether to save output CSV files.
1392
+
1393
+ Returns:
1394
+ dict: A dictionary of interpretation outputs with the following possible keys:
1395
+ - 'feature_importance': Raw feature importances (Random Forest)
1396
+ - 'feature_importance_compartment': Grouped by compartment
1397
+ - 'feature_importance_channel': Grouped by imaging channel
1398
+ - 'permutation_importance': Permutation importance values
1399
+ - 'shap': SHAP values for each object-feature pair
1400
+
1401
+ Side Effects:
1402
+ - Displays importance plots.
1403
+ - Optionally saves CSVs to `<src>/results/`.
1404
+ - Merges and preprocesses input data, including relative feature construction.
1405
+ """
1406
+
973
1407
  from .io import _read_and_merge_data
974
1408
 
975
1409
  def generate_comparison_columns(df, compartments=['cell', 'nucleus', 'pathogen', 'cytoplasm']):
@@ -1016,7 +1450,26 @@ def interperate_vision_model(settings={}):
1016
1450
  return df, comparison_dict
1017
1451
 
1018
1452
  def group_feature_class(df, feature_groups=['cell', 'cytoplasm', 'nucleus', 'pathogen'], name='compartment', include_all=False):
1453
+ """
1454
+ Group features by compartment or channel and sum their importance scores.
1455
+
1456
+ This function identifies the compartment or channel associated with each feature based on
1457
+ string matching, and computes the total importance of each group. Optionally, it includes
1458
+ a row summing the total importance across all groups.
1019
1459
 
1460
+ Args:
1461
+ df (pd.DataFrame): DataFrame with at least the columns:
1462
+ - 'feature': feature names to classify
1463
+ - 'importance': importance values associated with each feature
1464
+ feature_groups (list of str): Keywords used to assign features to groups (e.g., compartments or channels).
1465
+ name (str): Column name to assign group classification (e.g., 'compartment' or 'channel').
1466
+ include_all (bool): If True, include a summary row for total importance across all groups.
1467
+
1468
+ Returns:
1469
+ pd.DataFrame: DataFrame with columns:
1470
+ - `name` (e.g., 'compartment', 'channel')
1471
+ - `<name>_importance_sum`: total importance per group
1472
+ """
1020
1473
  # Function to determine compartment based on multiple matches
1021
1474
  def find_feature_class(feature, compartments):
1022
1475
  matches = [compartment for compartment in compartments if re.search(compartment, feature)]
@@ -1049,6 +1502,23 @@ def interperate_vision_model(settings={}):
1049
1502
 
1050
1503
  # Function to create radar plot for individual and combined values
1051
1504
  def create_extended_radar_plot(values, labels, title):
1505
+ """
1506
+ Create a radar (spider) plot to visualize multivariate values on a circular axis.
1507
+
1508
+ This function plots a closed polygon where each axis represents a different category
1509
+ (e.g., feature group), allowing for intuitive comparison of relative magnitudes.
1510
+
1511
+ Args:
1512
+ values (list or array-like): List of values to plot. Should match the length of `labels`.
1513
+ labels (list of str): Labels for each axis in the radar chart.
1514
+ title (str): Title of the plot.
1515
+
1516
+ Returns:
1517
+ None
1518
+
1519
+ Side Effects:
1520
+ - Displays the radar chart using matplotlib.
1521
+ """
1052
1522
  values = list(values) + [values[0]] # Close the loop for radar chart
1053
1523
  angles = [n / float(len(labels)) * 2 * pi for n in range(len(labels))]
1054
1524
  angles += angles[:1]
@@ -1063,6 +1533,22 @@ def interperate_vision_model(settings={}):
1063
1533
  plt.show()
1064
1534
 
1065
1535
  def extract_compartment_channel(feature_name):
1536
+ """
1537
+ Extract the compartment and imaging channel from a feature name.
1538
+
1539
+ This function parses a feature name string to identify:
1540
+ - The compartment (first substring before an underscore, with 'cells' mapped to 'cell')
1541
+ - The imaging channel(s), based on the presence of substrings like 'channel_0', etc.
1542
+
1543
+ Args:
1544
+ feature_name (str): The name of the feature to parse.
1545
+
1546
+ Returns:
1547
+ tuple:
1548
+ - compartment (str): The identified compartment (e.g., 'cell', 'nucleus').
1549
+ - channel (str): Channel label or 'morphology' if no channel found.
1550
+ If multiple channels are present, they are joined by ' + '.
1551
+ """
1066
1552
  # Identify compartment as the first part before an underscore
1067
1553
  compartment = feature_name.split('_')[0]
1068
1554
 
@@ -1089,7 +1575,34 @@ def interperate_vision_model(settings={}):
1089
1575
  return (compartment, channel)
1090
1576
 
1091
1577
  def read_and_preprocess_data(settings):
1578
+ """
1579
+ Load, preprocess, and merge measurement data and classification scores.
1580
+
1581
+ This function:
1582
+ - Reads merged object-level measurements from a SQLite database.
1583
+ - Adds relative comparison features between compartments (e.g., nucleus/cytoplasm).
1584
+ - Loads classification scores and aligns them with the measurement data.
1585
+ - Ensures consistent formatting of join keys (plateID, rowID, etc.).
1586
+ - Merges the two datasets on object identifiers and returns the feature matrix and labels.
1587
+
1588
+ Args:
1589
+ settings (dict): Dictionary with required keys:
1590
+ - 'src' (str): Base path to the dataset directory.
1591
+ - 'tables' (list of str): Tables to load from the database.
1592
+ - 'nuclei_limit' (int or None): Optional limit on number of nuclei to load.
1593
+ - 'pathogen_limit' (int or None): Optional limit on number of pathogens to load.
1594
+ - 'scores' (str): Path to the CSV file with prediction/classification scores.
1595
+ - 'score_column' (str): Column name in scores CSV used as prediction target.
1596
+
1597
+ Returns:
1598
+ tuple:
1599
+ - X (pd.DataFrame): Numerical feature matrix (excluding the score column).
1600
+ - y (pd.Series): Target values corresponding to the score column.
1601
+ - merged_df (pd.DataFrame): Full merged DataFrame including both features and scores.
1092
1602
 
1603
+ Side Effects:
1604
+ - Prints the number of columns in the final feature-expanded DataFrame.
1605
+ """
1093
1606
  df, _ = _read_and_merge_data(
1094
1607
  locs=[settings['src']+'/measurements/measurements.db'],
1095
1608
  tables=settings['tables'],
@@ -1267,6 +1780,38 @@ def interperate_vision_model(settings={}):
1267
1780
  return output
1268
1781
 
1269
1782
  def _plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_column='prc', level='object'):
1783
+ """
1784
+ Generate a stacked bar plot of bin proportions by group and perform a chi-squared test for association.
1785
+
1786
+ Depending on the `level` parameter, this function can:
1787
+ - Plot mean ± SD of binned proportions across wells (`level='well'`).
1788
+ - Plot object-level proportions per group (`level='object'`).
1789
+
1790
+ The function also calculates and prints the chi-squared test statistic and p-value
1791
+ for independence between group and bin category distributions.
1792
+
1793
+ Args:
1794
+ settings (dict): Dictionary with settings, must include:
1795
+ - 'um_per_px' (float or None): Micron-per-pixel scaling factor. If None, volume is in px³.
1796
+ df (pd.DataFrame): DataFrame containing binned volume information.
1797
+ group_column (str): Column name indicating group assignment (e.g., treatment, genotype).
1798
+ bin_column (str): Column name for the binned volume categories.
1799
+ prc_column (str): Column name specifying well ID. Only used if level='well'. Default is 'prc'.
1800
+ level (str): Either 'object' (default) or 'well', determines plot granularity.
1801
+
1802
+ Returns:
1803
+ tuple:
1804
+ - chi2 (float): Chi-squared test statistic from raw data.
1805
+ - p (float): p-value from the chi-squared test.
1806
+ - dof (int): Degrees of freedom from the test.
1807
+ - expected (ndarray): Expected frequencies under independence.
1808
+ - raw_counts (pd.DataFrame): Contingency table used for chi-squared test.
1809
+ - fig (matplotlib.figure.Figure): The figure object containing the plot.
1810
+
1811
+ Side Effects:
1812
+ - Displays a stacked bar plot using matplotlib.
1813
+ - Prints the chi-squared test result.
1814
+ """
1270
1815
  # Always calculate chi-squared on raw data
1271
1816
  raw_counts = df.groupby([group_column, bin_column]).size().unstack(fill_value=0)
1272
1817
  chi2, p, dof, expected = chi2_contingency(raw_counts)
@@ -1316,7 +1861,42 @@ def _plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_co
1316
1861
  return chi2, p, dof, expected, raw_counts, fig
1317
1862
 
1318
1863
  def analyze_endodyogeny(settings):
1319
-
1864
+ """
1865
+ Analyze intracellular pathogen replication (endodyogeny) based on compartment volume measurements.
1866
+
1867
+ This function bins single-cell or single-object volume estimates (e.g., pathogen area^1.5) into
1868
+ discrete categories, annotates experimental conditions, and visualizes the volume distribution
1869
+ across groups as a stacked bar plot. It also performs a chi-squared test to assess whether the
1870
+ volume distribution is significantly different between groups.
1871
+
1872
+ Args:
1873
+ settings (dict): Configuration dictionary with required keys:
1874
+ - 'src' (str or list): Path(s) to directories containing measurements.db.
1875
+ - 'tables' (list of str): Tables to extract from the SQLite database.
1876
+ - 'compartment' (str): Compartment to analyze (e.g., 'pathogen').
1877
+ - 'min_area_bin' (float): Minimum 2D area threshold to consider (in px² or µm²).
1878
+ - 'max_bins' (int or None): Maximum number of volume bins to display.
1879
+ - 'group_column' (str): Column defining experimental groups (e.g., genotype or treatment).
1880
+ - 'cell_types', 'pathogen_types', 'treatments' (list): Experimental groups for annotation.
1881
+ - 'cell_plate_metadata', 'pathogen_plate_metadata', 'treatment_plate_metadata' (str): Metadata source paths.
1882
+ - 'nuclei_limit', 'pathogen_limit' (int or None): Max number of objects to load per type.
1883
+ - 'change_plate' (bool): Whether to modify plate names for uniqueness.
1884
+ - 'verbose' (bool): Whether to print debug output.
1885
+ - 'um_per_px' (float or None): Pixel-to-micrometer conversion (if None, volume remains in px³).
1886
+ - 'level' (str): Either 'object' or 'well' for aggregation level.
1887
+ - 'cmap' (str): Colormap to use in the plot.
1888
+ - 'save' (bool): Whether to save results and figures.
1889
+
1890
+ Returns:
1891
+ dict: A dictionary with:
1892
+ - 'data': Annotated and binned DataFrame.
1893
+ - 'chi_squared': Chi-squared test summary DataFrame.
1894
+
1895
+ Side Effects:
1896
+ - Displays a stacked bar plot showing volume bin proportions by group.
1897
+ - Performs chi-squared test and prints p-value.
1898
+ - Saves data and plots to `results/analyze_endodyogeny/` if `save=True`.
1899
+ """
1320
1900
  from .utils import annotate_conditions, save_settings
1321
1901
  from .io import _read_and_merge_data
1322
1902
  from .settings import set_analyze_endodyogeny_defaults
@@ -1449,7 +2029,42 @@ def analyze_endodyogeny(settings):
1449
2029
  return output
1450
2030
 
1451
2031
  def analyze_class_proportion(settings):
1452
-
2032
+ """
2033
+ Analyze class frequency distributions across experimental groups and perform statistical tests.
2034
+
2035
+ This function compares the proportion of classification outcomes (e.g., phenotypic classes)
2036
+ across specified experimental groups. It generates stacked bar plots, performs chi-squared
2037
+ and parametric/non-parametric statistical tests, and saves annotated data and results.
2038
+
2039
+ Args:
2040
+ settings (dict): Dictionary of parameters, including:
2041
+ - 'src' (str or list): Path(s) to measurement directories.
2042
+ - 'tables' (list): Tables to load from measurements.db.
2043
+ - 'class_column' (str): Column indicating class or category.
2044
+ - 'group_column' (str): Column indicating grouping variable (e.g., genotype or treatment).
2045
+ - 'cell_types', 'pathogen_types', 'treatments' (list): Experimental group values.
2046
+ - 'cell_plate_metadata', 'pathogen_plate_metadata', 'treatment_plate_metadata' (str): Metadata paths.
2047
+ - 'nuclei_limit', 'pathogen_limit' (int or None): Optional load limits.
2048
+ - 'verbose' (bool): Whether to print debug information.
2049
+ - 'level' (str): Aggregation level for chi-squared test ('object' or 'well').
2050
+ - 'save' (bool): Whether to save results and figures.
2051
+
2052
+ Returns:
2053
+ dict: A dictionary containing:
2054
+ - 'data': Annotated DataFrame.
2055
+ - 'chi_squared': Chi-squared summary results.
2056
+
2057
+ Side Effects:
2058
+ - Annotates DataFrame with experimental conditions.
2059
+ - Performs chi-squared test on class distribution and plots stacked bars.
2060
+ - Plots class heatmap across plates.
2061
+ - Performs and prints:
2062
+ * Shapiro-Wilk normality test
2063
+ * Levene’s test for equal variances
2064
+ * ANOVA or Kruskal-Wallis test
2065
+ * Post hoc pairwise comparisons
2066
+ - Optionally saves all results and figures in a `results/analyze_class_proportion/` subfolder.
2067
+ """
1453
2068
  from .utils import annotate_conditions, save_settings
1454
2069
  from .io import _read_and_merge_data
1455
2070
  from .settings import set_analyze_class_proportion_defaults
@@ -1553,9 +2168,56 @@ def analyze_class_proportion(settings):
1553
2168
  return output
1554
2169
 
1555
2170
  def generate_score_heatmap(settings):
1556
-
2171
+ """
2172
+ Generate a heatmap comparing predicted classification scores from multiple models
2173
+ with empirical gRNA-based fraction measurements across wells.
2174
+
2175
+ This function:
2176
+ - Aggregates prediction scores from subfolders.
2177
+ - Calculates the gRNA-based empirical fraction from sequencing data.
2178
+ - Merges predictions with empirical scores and computes MAE (mean absolute error).
2179
+ - Displays and optionally saves a multi-channel heatmap and the corresponding MAE results.
2180
+
2181
+ Args:
2182
+ settings (dict): Configuration dictionary with the following keys:
2183
+ - 'folders' (list or str): Folder(s) containing model output subfolders with CSVs.
2184
+ - 'csv_name' (str): Name of the CSV file to search for in each model output subfolder.
2185
+ - 'data_column' (str): Column in CSVs containing the prediction scores.
2186
+ - 'cv_csv' (str): CSV with cross-validated model scores.
2187
+ - 'data_column_cv' (str): Column in `cv_csv` with scores to merge.
2188
+ - 'csv' (str): CSV with gRNA counts used for computing empirical fractions.
2189
+ - 'control_sgrnas' (list): List of two control gRNA names used to compute the fraction.
2190
+ - 'fraction_grna' (str): One of the control gRNAs to extract the empirical fraction from.
2191
+ - 'plateID' (int or str): Plate ID used for labeling and filtering.
2192
+ - 'columnID' (str): Column ID to filter on (e.g., 'c3').
2193
+ - 'cmap' (str): Colormap to use for the heatmap.
2194
+ - 'dst' (str or None): If provided, save outputs to this directory.
2195
+
2196
+ Returns:
2197
+ pd.DataFrame: Merged DataFrame containing all prediction and empirical score columns.
2198
+
2199
+ Side Effects:
2200
+ - Displays a multi-channel heatmap of prediction scores.
2201
+ - Computes and prints MAE per well per channel.
2202
+ - Saves merged data, MAE scores, and the heatmap as PDF/CSV if `dst` is set.
2203
+ """
1557
2204
  def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
1558
-
2205
+ """
2206
+ Group cross-validation prediction scores by plate, row, and column.
2207
+
2208
+ This function reads a CSV file containing predictions, filters by a specific column (e.g., 'c3'),
2209
+ assigns a plate ID, and computes the mean prediction score per well.
2210
+
2211
+ Args:
2212
+ csv (str): Path to the CSV file containing predictions.
2213
+ plate (int or str, optional): Plate identifier to assign. Defaults to 1.
2214
+ column (str, optional): Column ID to filter by (e.g., 'c3'). Defaults to 'c3'.
2215
+ data_column (str, optional): Name of the column with prediction scores. Defaults to 'pred'.
2216
+
2217
+ Returns:
2218
+ pd.DataFrame: Aggregated DataFrame with mean scores and a unique 'prc' well identifier.
2219
+ Columns: ['plateID', 'rowID', 'columnID', data_column, 'prc']
2220
+ """
1559
2221
  df = pd.read_csv(csv)
1560
2222
  if 'columnID' in df.columns:
1561
2223
  df = df[df['columnID']==column]
@@ -1569,6 +2231,25 @@ def generate_score_heatmap(settings):
1569
2231
  return grouped_df
1570
2232
 
1571
2233
  def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
2234
+ """
2235
+ Calculate the fraction of control sgRNAs in each well for a mixed condition.
2236
+
2237
+ This function filters a CSV containing sgRNA counts for a specified column and control sgRNAs,
2238
+ computes the total count per well, and calculates the fraction of each sgRNA within that well.
2239
+ A unique identifier ('prc') is added for each well.
2240
+
2241
+ Args:
2242
+ csv (str): Path to the input CSV file containing sgRNA data.
2243
+ plate (int or str, optional): Plate identifier to assign if not present. Defaults to 1.
2244
+ column (str, optional): Column name to filter by (e.g., 'c3'). Defaults to 'c3'.
2245
+ control_sgrnas (list of str, optional): List of two control sgRNA names to include in the calculation.
2246
+ Defaults to ['TGGT1_220950_1', 'TGGT1_233460_4'].
2247
+
2248
+ Returns:
2249
+ pd.DataFrame: DataFrame with fraction per sgRNA in each well and a unique 'prc' well identifier.
2250
+ Includes columns: ['plateID', 'rowID', 'column_name', 'grna_name', 'count',
2251
+ 'total_count', 'fraction', 'prc']
2252
+ """
1572
2253
  df = pd.read_csv(csv)
1573
2254
  df = df[df['column_name']==column]
1574
2255
  if plate not in df.columns:
@@ -1631,6 +2312,24 @@ def generate_score_heatmap(settings):
1631
2312
 
1632
2313
 
1633
2314
  def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
2315
+ """
2316
+ Combine classification scores from multiple CSV files across folders.
2317
+
2318
+ This function traverses given directories, locates specified CSV files within subfolders,
2319
+ filters by a target column, aggregates prediction scores by well position, and merges
2320
+ the results into a single DataFrame with each file contributing a new column of scores.
2321
+
2322
+ Args:
2323
+ folders (str or list of str): Path(s) to parent folders containing subfolders with CSV files.
2324
+ csv_name (str): Name of the CSV file to locate in each subfolder.
2325
+ data_column (str): Name of the column in each CSV containing prediction scores to aggregate.
2326
+ plate (int or str, optional): Identifier for the plate to add as 'plateID'. Defaults to 1.
2327
+ column (str, optional): Column name to filter data by (e.g., 'c3'). Defaults to 'c3'.
2328
+
2329
+ Returns:
2330
+ pd.DataFrame: Combined DataFrame with mean prediction scores per well from all found CSV files.
2331
+ Includes a 'prc' column for well-level identifiers and one score column per source.
2332
+ """
1634
2333
  # Ensure `folders` is a list
1635
2334
  if isinstance(folders, str):
1636
2335
  folders = [folders]
@@ -1678,7 +2377,20 @@ def generate_score_heatmap(settings):
1678
2377
 
1679
2378
  def calculate_mae(df):
1680
2379
  """
1681
- Calculate the MAE between each channel's predictions and the fraction column for all rows.
2380
+ Calculate mean absolute error (MAE) between prediction scores and reference fractions.
2381
+
2382
+ This function computes the MAE between each numeric prediction column and the
2383
+ reference 'fraction' column on a per-row basis. The results are returned as a
2384
+ long-form DataFrame with one row per (channel, row) pair.
2385
+
2386
+ Args:
2387
+ df (pd.DataFrame): Input DataFrame containing a 'fraction' column as ground truth,
2388
+ a 'prc' column as row identifier, and one or more numeric
2389
+ prediction columns.
2390
+
2391
+ Returns:
2392
+ pd.DataFrame: DataFrame with columns ['Channel', 'MAE', 'Row'], reporting the
2393
+ MAE between each prediction channel and the fraction per row.
1682
2394
  """
1683
2395
  # Extract numeric columns excluding 'fraction' and 'prc'
1684
2396
  channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
@@ -1721,7 +2433,23 @@ def generate_score_heatmap(settings):
1721
2433
  return merged_df
1722
2434
 
1723
2435
  def post_regression_analysis(csv_file, grna_dict, grna_list, save=False):
1724
-
2436
+ """
2437
+ Perform post-regression analysis to assess gRNA correlation and infer relative effect sizes.
2438
+
2439
+ This function loads a CSV file containing gRNA abundance data (with columns including
2440
+ 'grna', 'fraction', and 'prc'), filters it to the specified gRNAs of interest, computes
2441
+ a correlation matrix, visualizes it, and estimates inferred effect sizes for all gRNAs
2442
+ based on a subset of known fixed effect sizes.
2443
+
2444
+ Args:
2445
+ csv_file (str): Path to the CSV file containing gRNA data with columns ['grna', 'fraction', 'prc'].
2446
+ grna_dict (dict): Dictionary mapping selected gRNAs to their known effect sizes.
2447
+ grna_list (list): List of gRNAs to include in correlation and effect size inference.
2448
+ save (bool, optional): Whether to save outputs including figures and CSVs. Defaults to False.
2449
+
2450
+ Returns:
2451
+ None. Saves plots and results if `save=True`.
2452
+ """
1725
2453
  def _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save=False):
1726
2454
  """
1727
2455
  Analyze and visualize the correlation matrix of gRNAs based on their fractions and overlap.