spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +3 -2
- spacr/app_classify.py +10 -0
- spacr/app_mask.py +9 -0
- spacr/app_measure.py +9 -0
- spacr/app_sequencing.py +9 -0
- spacr/core.py +172 -1
- spacr/deep_spacr.py +296 -7
- spacr/gui.py +68 -0
- spacr/gui_core.py +319 -10
- spacr/gui_elements.py +772 -13
- spacr/gui_utils.py +304 -12
- spacr/io.py +887 -71
- spacr/logger.py +36 -0
- spacr/measure.py +206 -28
- spacr/ml.py +606 -142
- spacr/plot.py +797 -131
- spacr/sequencing.py +363 -8
- spacr/settings.py +1158 -38
- spacr/sp_stats.py +80 -12
- spacr/spacr_cellpose.py +115 -2
- spacr/submodules.py +747 -19
- spacr/timelapse.py +237 -53
- spacr/toxo.py +132 -6
- spacr/utils.py +2422 -80
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/METADATA +31 -17
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/RECORD +30 -30
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/LICENSE +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/WHEEL +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/entry_points.txt +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/top_level.txt +0 -0
spacr/submodules.py
CHANGED
@@ -30,7 +30,37 @@ from natsort import natsorted
|
|
30
30
|
from torch.utils.data import Dataset
|
31
31
|
|
32
32
|
class CellposeLazyDataset(Dataset):
|
33
|
+
"""
|
34
|
+
A PyTorch Dataset for lazy loading and optional augmentation of images and segmentation masks
|
35
|
+
for training Cellpose-based models.
|
36
|
+
|
37
|
+
Images and labels are loaded from file paths on-the-fly, optionally normalized and resized,
|
38
|
+
and can be augmented with basic geometric transformations.
|
39
|
+
|
40
|
+
Args:
|
41
|
+
image_files (list of str): List of file paths to image files.
|
42
|
+
label_files (list of str): List of file paths to corresponding label files.
|
43
|
+
settings (dict): Dictionary containing dataset settings:
|
44
|
+
- 'normalize' (bool): Whether to apply percentile-based intensity normalization.
|
45
|
+
- 'percentiles' (list of int): Two-element list specifying lower and upper percentiles (default: [2, 99]).
|
46
|
+
- 'target_size' (int): Desired output image size (height and width).
|
47
|
+
randomize (bool, optional): Whether to shuffle the dataset order. Defaults to True.
|
48
|
+
augment (bool, optional): Whether to apply 8-way geometric data augmentation. Defaults to False.
|
49
|
+
"""
|
33
50
|
def __init__(self, image_files, label_files, settings, randomize=True, augment=False):
|
51
|
+
"""
|
52
|
+
Initialize the CellposeLazyDataset.
|
53
|
+
|
54
|
+
Args:
|
55
|
+
image_files (list of str): Paths to input image files.
|
56
|
+
label_files (list of str): Paths to corresponding label files.
|
57
|
+
settings (dict): Configuration dictionary with keys:
|
58
|
+
- 'normalize' (bool)
|
59
|
+
- 'percentiles' (list of int)
|
60
|
+
- 'target_size' (int)
|
61
|
+
randomize (bool, optional): Shuffle file order. Defaults to True.
|
62
|
+
augment (bool, optional): Enable 8-fold augmentation. Defaults to False.
|
63
|
+
"""
|
34
64
|
combined = list(zip(image_files, label_files))
|
35
65
|
if randomize:
|
36
66
|
random.shuffle(combined)
|
@@ -41,9 +71,30 @@ class CellposeLazyDataset(Dataset):
|
|
41
71
|
self.augment = augment
|
42
72
|
|
43
73
|
def __len__(self):
|
74
|
+
"""
|
75
|
+
Return the number of samples in the dataset.
|
76
|
+
|
77
|
+
If augmentation is enabled, each sample contributes 8 variants.
|
78
|
+
|
79
|
+
Returns:
|
80
|
+
int: Total number of samples (augmented if applicable).
|
81
|
+
"""
|
44
82
|
return len(self.image_files) * (8 if self.augment else 1)
|
45
83
|
|
46
84
|
def apply_augmentation(self, image, label, aug_idx):
|
85
|
+
"""
|
86
|
+
Apply one of 8 geometric augmentations to an image-label pair.
|
87
|
+
|
88
|
+
Augmentations include rotations (90°, 180°, 270°) and horizontal/vertical flips.
|
89
|
+
|
90
|
+
Args:
|
91
|
+
image (ndarray): Input image array.
|
92
|
+
label (ndarray): Corresponding label array.
|
93
|
+
aug_idx (int): Index from 0 to 7 specifying the augmentation to apply.
|
94
|
+
|
95
|
+
Returns:
|
96
|
+
tuple: Augmented (image, label) pair.
|
97
|
+
"""
|
47
98
|
if aug_idx == 1:
|
48
99
|
return rotate(image, 90, resize=False, preserve_range=True), rotate(label, 90, resize=False, preserve_range=True)
|
49
100
|
elif aug_idx == 2:
|
@@ -61,6 +112,20 @@ class CellposeLazyDataset(Dataset):
|
|
61
112
|
return image, label
|
62
113
|
|
63
114
|
def __getitem__(self, idx):
|
115
|
+
"""
|
116
|
+
Retrieve a sample by index, optionally applying augmentation and preprocessing.
|
117
|
+
|
118
|
+
Loads the image and label, normalizes intensity if specified, applies augmentation,
|
119
|
+
and resizes to the target shape.
|
120
|
+
|
121
|
+
Args:
|
122
|
+
idx (int): Index of the sample to retrieve.
|
123
|
+
|
124
|
+
Returns:
|
125
|
+
tuple:
|
126
|
+
- image (np.ndarray): Preprocessed image, shape (target_size, target_size), dtype float32.
|
127
|
+
- label (np.ndarray): Resized label mask, shape (target_size, target_size), dtype uint8.
|
128
|
+
"""
|
64
129
|
base_idx = idx // 8 if self.augment else idx
|
65
130
|
aug_idx = idx % 8 if self.augment else 0
|
66
131
|
|
@@ -86,7 +151,32 @@ class CellposeLazyDataset(Dataset):
|
|
86
151
|
return image, label
|
87
152
|
|
88
153
|
def train_cellpose(settings):
|
89
|
-
|
154
|
+
"""
|
155
|
+
Train a Cellpose model on custom images and masks using specified settings.
|
156
|
+
|
157
|
+
This function prepares training data from `train/images` and `train/masks` subfolders within
|
158
|
+
the provided `settings['src']` directory. It constructs a model name based on training parameters,
|
159
|
+
initializes the Cellpose model, and trains it using the specified number of epochs and hyperparameters.
|
160
|
+
|
161
|
+
The dataset can be augmented up to 8-fold, and training images and masks are matched by filename.
|
162
|
+
Training progress is visualized (if possible), and the model is saved in `models/cellpose_model`.
|
163
|
+
|
164
|
+
Args:
|
165
|
+
settings (dict): Dictionary with the following required keys:
|
166
|
+
- 'src' (str): Root directory containing `train/images`, `train/masks`, and `models`.
|
167
|
+
- 'target_size' (int): Side length to which images/masks are resized.
|
168
|
+
- 'model_name' (str): Base name of the model.
|
169
|
+
- 'n_epochs' (int): Number of training epochs.
|
170
|
+
- 'batch_size' (int): Number of images to train per batch.
|
171
|
+
- 'learning_rate' (float): Learning rate for training.
|
172
|
+
- 'weight_decay' (float): Weight decay (L2 regularization).
|
173
|
+
- 'augment' (bool): Whether to use 8-fold data augmentation.
|
174
|
+
|
175
|
+
Side Effects:
|
176
|
+
- Trains a model and saves it to `models/cellpose_model/`.
|
177
|
+
- Writes training settings using `save_settings()`.
|
178
|
+
- Optionally visualizes a training batch using `plot_cellpose_batch()`.
|
179
|
+
"""
|
90
180
|
from .settings import get_train_cellpose_default_settings
|
91
181
|
from .utils import save_settings
|
92
182
|
|
@@ -156,7 +246,35 @@ def train_cellpose(settings):
|
|
156
246
|
print(f"Model saved at: {model_save_path}/{model_name}")
|
157
247
|
|
158
248
|
def test_cellpose_model(settings):
|
159
|
-
|
249
|
+
"""
|
250
|
+
Evaluate a pretrained Cellpose model on a test dataset and report segmentation performance.
|
251
|
+
|
252
|
+
This function loads test images and ground-truth masks from a specified directory structure,
|
253
|
+
applies the Cellpose model to predict masks, and compares them to the ground truth using
|
254
|
+
object-level metrics and Jaccard index. Results are saved and visualized if specified.
|
255
|
+
|
256
|
+
Args:
|
257
|
+
settings (dict): Dictionary with the following keys:
|
258
|
+
- 'src' (str): Root directory containing `test/images` and `test/masks`.
|
259
|
+
- 'model_path' (str): Path to the pretrained Cellpose model.
|
260
|
+
- 'batch_size' (int): Number of images per inference batch.
|
261
|
+
- 'FT' (float): Flow threshold for Cellpose segmentation.
|
262
|
+
- 'CP_probability' (float): Cell probability threshold for segmentation.
|
263
|
+
- 'target_size' (int): Size to which input images are resized.
|
264
|
+
- 'save' (bool): Whether to save segmentation visualizations and results.
|
265
|
+
|
266
|
+
Side Effects:
|
267
|
+
- Loads and processes image and mask files from `test/images` and `test/masks`.
|
268
|
+
- Evaluates segmentation and computes the following metrics per image:
|
269
|
+
* Jaccard index
|
270
|
+
* Number of predicted and ground-truth objects
|
271
|
+
* Mean object area (true and predicted)
|
272
|
+
* True positives, false positives, false negatives
|
273
|
+
* Precision, recall, F1-score, accuracy
|
274
|
+
- Saves figures showing input, true mask, predicted mask, probability map, and flow.
|
275
|
+
- Optionally saves results to `results/test_results.csv`.
|
276
|
+
- Prints performance summary to stdout.
|
277
|
+
"""
|
160
278
|
from .utils import save_settings, print_progress
|
161
279
|
from .settings import get_default_test_cellpose_model_settings
|
162
280
|
|
@@ -342,7 +460,32 @@ def test_cellpose_model(settings):
|
|
342
460
|
df_results.to_csv(os.path.join(results_dir, 'test_results.csv'), index=False)
|
343
461
|
|
344
462
|
def apply_cellpose_model(settings):
|
345
|
-
|
463
|
+
"""
|
464
|
+
Apply a pretrained Cellpose model to a folder of images and extract object-level measurements.
|
465
|
+
|
466
|
+
This function processes all `.tif` images in the specified source directory using a pretrained
|
467
|
+
Cellpose model, generates segmentation masks, and computes region properties (e.g., area) for each
|
468
|
+
detected object. It optionally applies circular masking to limit the field of view and saves results
|
469
|
+
and figures if enabled.
|
470
|
+
|
471
|
+
Args:
|
472
|
+
settings (dict): Dictionary with the following required keys:
|
473
|
+
- 'src' (str): Directory containing the input `.tif` images.
|
474
|
+
- 'model_path' (str): Path to the pretrained Cellpose model file.
|
475
|
+
- 'batch_size' (int): Number of images processed per batch.
|
476
|
+
- 'FT' (float): Flow threshold for Cellpose segmentation.
|
477
|
+
- 'CP_probability' (float): Cell probability threshold for segmentation.
|
478
|
+
- 'target_size' (int): Resize target for input images.
|
479
|
+
- 'save' (bool): If True, saves visualizations and CSV results.
|
480
|
+
- 'circularize' (bool, optional): If True, applies a circular mask to each prediction.
|
481
|
+
|
482
|
+
Side Effects:
|
483
|
+
- Generates and saves segmented mask visualizations (if `save=True`) to `results/`.
|
484
|
+
- Extracts per-object measurements (area) and saves:
|
485
|
+
* `results/measurements.csv`: one row per object.
|
486
|
+
* `results/summary.csv`: average object area and count per image.
|
487
|
+
- Displays progress and timing for each batch.
|
488
|
+
"""
|
346
489
|
from .settings import get_default_apply_cellpose_model_settings
|
347
490
|
from .utils import save_settings, print_progress
|
348
491
|
|
@@ -462,6 +605,17 @@ def apply_cellpose_model(settings):
|
|
462
605
|
print("Saved object count and average area to summary.csv")
|
463
606
|
|
464
607
|
def plot_cellpose_batch(images, labels):
|
608
|
+
"""
|
609
|
+
Display a batch of input images and their corresponding label masks.
|
610
|
+
|
611
|
+
This function plots two rows of subplots:
|
612
|
+
- Top row: grayscale input images.
|
613
|
+
- Bottom row: corresponding label masks with randomly colored regions.
|
614
|
+
|
615
|
+
Args:
|
616
|
+
images (list of np.ndarray): List of 2D grayscale input images.
|
617
|
+
labels (list of np.ndarray): List of 2D label masks corresponding to the images.
|
618
|
+
"""
|
465
619
|
from .plot import generate_mask_random_cmap
|
466
620
|
|
467
621
|
cmap_lbl = generate_mask_random_cmap(labels)
|
@@ -477,6 +631,31 @@ def plot_cellpose_batch(images, labels):
|
|
477
631
|
plt.show()
|
478
632
|
|
479
633
|
def analyze_percent_positive(settings):
|
634
|
+
"""
|
635
|
+
Analyze the fraction of objects above a threshold per well and annotate them accordingly.
|
636
|
+
|
637
|
+
This function loads merged object-level measurements from a SQLite database, applies optional
|
638
|
+
filtering, and annotates each object as 'above' or 'below' a specified threshold on a value column.
|
639
|
+
It then summarizes the number and fraction of above-threshold objects per condition and well.
|
640
|
+
Results are merged with well metadata and saved as a CSV.
|
641
|
+
|
642
|
+
Args:
|
643
|
+
settings (dict): Dictionary with the following required keys:
|
644
|
+
- 'src' (str): Root directory containing `measurements/measurements.db` and `rename_log.csv`.
|
645
|
+
- 'tables' (list of str): Table names to extract from the SQLite database.
|
646
|
+
- 'value_col' (str): Column to apply the threshold to.
|
647
|
+
- 'threshold' (float): Threshold value to classify objects.
|
648
|
+
- 'filter_1' (tuple or None): Optional filter in the form (column, min_value), or None.
|
649
|
+
|
650
|
+
Returns:
|
651
|
+
pd.DataFrame: Merged DataFrame with annotation summary and metadata, also saved to `result.csv`.
|
652
|
+
|
653
|
+
Side Effects:
|
654
|
+
- Reads measurement data from `measurements/measurements.db`.
|
655
|
+
- Reads well metadata from `rename_log.csv`.
|
656
|
+
- Writes annotated results to `result.csv` in the same directory.
|
657
|
+
- Displays the final DataFrame in notebook or console.
|
658
|
+
"""
|
480
659
|
from .io import _read_and_merge_data
|
481
660
|
from .utils import save_settings
|
482
661
|
from .settings import default_settings_analyze_percent_positive
|
@@ -484,6 +663,26 @@ def analyze_percent_positive(settings):
|
|
484
663
|
settings = default_settings_analyze_percent_positive(settings)
|
485
664
|
|
486
665
|
def translate_well_in_df(csv_loc):
|
666
|
+
"""
|
667
|
+
Extract and translate well metadata from a CSV file containing renamed TIFF filenames.
|
668
|
+
|
669
|
+
This function parses the 'Renamed TIFF' column to extract plate and well information,
|
670
|
+
generates plate-well identifiers, and translates well positions to row and column IDs.
|
671
|
+
It also constructs a unique `prc` (plate_row_column) identifier for each well.
|
672
|
+
|
673
|
+
Args:
|
674
|
+
csv_loc (str): Path to a CSV file containing a 'Renamed TIFF' column.
|
675
|
+
|
676
|
+
Returns:
|
677
|
+
pd.DataFrame: DataFrame with columns:
|
678
|
+
- 'plateID': extracted plate identifier
|
679
|
+
- 'well': extracted well identifier (e.g., A01)
|
680
|
+
- 'plate_well': combined plate and well string
|
681
|
+
- 'rowID': converted row index (e.g., 'r1' for row A)
|
682
|
+
- 'column_name': converted column index (e.g., 'c1' for column 01)
|
683
|
+
- 'fieldID': field identifier (default: 'f1')
|
684
|
+
- 'prc': combined identifier in the form 'p<plate>_r<row>_c<column>'
|
685
|
+
"""
|
487
686
|
# Load and extract metadata
|
488
687
|
df = pd.read_csv(csv_loc)
|
489
688
|
df[['plateID', 'well']] = df['Renamed TIFF'].str.replace('.tif', '', regex=False).str.split('_', expand=True)[[0, 1]]
|
@@ -564,15 +763,39 @@ def analyze_percent_positive(settings):
|
|
564
763
|
|
565
764
|
def analyze_recruitment(settings):
|
566
765
|
"""
|
567
|
-
Analyze
|
568
|
-
|
569
|
-
|
570
|
-
|
766
|
+
Analyze host protein recruitment to the pathogen vacuole across experimental conditions.
|
767
|
+
|
768
|
+
This function loads object-level measurements, applies size and intensity-based filtering, computes
|
769
|
+
recruitment scores as the ratio of pathogen-to-cytoplasm intensities, and aggregates results at the
|
770
|
+
cell and well level. Recruitment is annotated and visualized, and filtered results are saved.
|
771
|
+
|
772
|
+
Args:
|
773
|
+
settings (dict): Dictionary with the following required keys:
|
774
|
+
- 'src' (str): Path to the root folder or `measurements.db` file.
|
775
|
+
- 'cell_types', 'pathogen_types', 'treatments' (list of str): Experimental conditions to annotate.
|
776
|
+
- 'cell_plate_metadata', 'pathogen_plate_metadata', 'treatment_plate_metadata' (str): Metadata sources.
|
777
|
+
- 'channel_of_interest' (int): Channel index for recruitment quantification.
|
778
|
+
- 'channel_dims' (list of int): Channels to calculate recruitment for.
|
779
|
+
- 'cell_chann_dim', 'nucleus_chann_dim', 'pathogen_chann_dim' (int or None): Mask channel assignments.
|
780
|
+
- 'cell_size_range', 'nucleus_size_range', 'pathogen_size_range' (list): Min/max object size filters.
|
781
|
+
- 'cell_intensity_range', etc. (list): Intensity-based object filters.
|
782
|
+
- 'target_intensity_min' (float): Minimum required target channel intensity in cell.
|
783
|
+
- 'nuclei_limit', 'pathogen_limit' (int or None): Max number of nuclei/pathogens to load.
|
784
|
+
- 'cells_per_well' (int): Minimum number of cells per well to retain.
|
785
|
+
- 'plot' (bool): Whether to plot cell outlines on merged images.
|
786
|
+
- 'plot_control' (bool): Whether to plot control comparisons.
|
787
|
+
- 'plot_nr' (int): Maximum number of outlines to plot.
|
788
|
+
- 'figuresize' (float): Size of recruitment plots in inches.
|
571
789
|
|
572
790
|
Returns:
|
573
|
-
|
791
|
+
list: [cell-level DataFrame, well-level DataFrame] after filtering and recruitment analysis.
|
792
|
+
|
793
|
+
Side Effects:
|
794
|
+
- Reads from and writes to files in the source directory.
|
795
|
+
- Plots recruitment distributions and controls.
|
796
|
+
- Saves `cell_level_results.csv` and `well_level_results.csv` via `_results_to_csv`.
|
797
|
+
- Displays status messages and summary statistics.
|
574
798
|
"""
|
575
|
-
|
576
799
|
from .io import _read_and_merge_data, _results_to_csv
|
577
800
|
from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
|
578
801
|
from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
|
@@ -686,7 +909,30 @@ def analyze_recruitment(settings):
|
|
686
909
|
return [cells,wells]
|
687
910
|
|
688
911
|
def analyze_plaques(settings):
|
689
|
-
|
912
|
+
"""
|
913
|
+
Analyze plaque-like structures from microscopy images using a pretrained Cellpose model.
|
914
|
+
|
915
|
+
This function applies a custom-trained Cellpose model to identify and segment plaques in `.tif` images.
|
916
|
+
It calculates object-level statistics including object count and area, and saves the results into a
|
917
|
+
SQLite database with three tables: `summary`, `details`, and `stats`.
|
918
|
+
|
919
|
+
Args:
|
920
|
+
settings (dict): Dictionary with the following keys:
|
921
|
+
- 'src' (str): Source folder containing input `.tif` images.
|
922
|
+
- 'masks' (bool): If True, (re)generate segmentation masks using Cellpose.
|
923
|
+
- 'custom_model' (str, optional): Will be set internally to the plaque model path.
|
924
|
+
- Any keys required by `get_analyze_plaque_settings()` or `identify_masks_finetune()`.
|
925
|
+
|
926
|
+
Side Effects:
|
927
|
+
- Downloads pretrained plaque segmentation model if not available.
|
928
|
+
- Saves generated masks (if `masks=True`) to `src/masks/`.
|
929
|
+
- Analyzes all `.tif` masks in the folder.
|
930
|
+
- Creates and saves the following tables to `plaques_analysis.db`:
|
931
|
+
* `summary`: file name, number of plaques, average size
|
932
|
+
* `details`: file name, individual plaque sizes
|
933
|
+
* `stats`: file name, plaque count, average size, and standard deviation
|
934
|
+
- Prints progress and completion messages.
|
935
|
+
"""
|
690
936
|
from .spacr_cellpose import identify_masks_finetune
|
691
937
|
from .settings import get_analyze_plaque_settings
|
692
938
|
from .utils import save_settings, download_models
|
@@ -751,6 +997,29 @@ def analyze_plaques(settings):
|
|
751
997
|
print(f"Analysis completed and saved to database '{db_name}'.")
|
752
998
|
|
753
999
|
def count_phenotypes(settings):
|
1000
|
+
"""
|
1001
|
+
Count and summarize annotated phenotypes in a measurement database.
|
1002
|
+
|
1003
|
+
This function reads annotation data from the `measurements.db` database and performs the following:
|
1004
|
+
- Counts the number of unique phenotype labels in the specified annotation column.
|
1005
|
+
- Computes how many unique labels occur per well (defined by plateID, rowID, columnID).
|
1006
|
+
- Computes the count of each phenotype value per well.
|
1007
|
+
- Outputs a CSV file with the per-well phenotype counts.
|
1008
|
+
|
1009
|
+
Args:
|
1010
|
+
settings (dict): Dictionary with the following keys:
|
1011
|
+
- 'src' (str): Path to a measurement database or directory containing it.
|
1012
|
+
- 'annotation_column' (str): Column name containing the phenotype annotations.
|
1013
|
+
|
1014
|
+
Returns:
|
1015
|
+
None
|
1016
|
+
|
1017
|
+
Side Effects:
|
1018
|
+
- Displays number of unique phenotype values.
|
1019
|
+
- Displays a table of unique value counts per well.
|
1020
|
+
- Saves a CSV file `phenotype_counts.csv` summarizing value counts per well
|
1021
|
+
(one row per well, one column per phenotype).
|
1022
|
+
"""
|
754
1023
|
from .io import _read_db
|
755
1024
|
|
756
1025
|
if not settings['src'].endswith('/measurements/measurements.db'):
|
@@ -794,8 +1063,63 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
|
|
794
1063
|
pc_grna='TGGT1_220950_1', nc_grna='TGGT1_233460_4',
|
795
1064
|
y_columns=['class_1_fraction', 'TGGT1_220950_1_fraction', 'nc_fraction'],
|
796
1065
|
column='columnID', value='c3', plate=None, save_paths=None):
|
1066
|
+
"""
|
1067
|
+
Compare Cellpose-based phenotypic classification scores with sequencing read distributions
|
1068
|
+
from positive and negative control gRNAs across plate wells.
|
1069
|
+
|
1070
|
+
This function merges phenotype classification scores and read count data, calculates gRNA
|
1071
|
+
fractions and class predictions per well, and overlays empirical expectations for controls
|
1072
|
+
to assess model calibration. It generates and saves line plots for selected y-axis columns.
|
1073
|
+
|
1074
|
+
Args:
|
1075
|
+
reads_csv (str or list of str): Path(s) to CSV file(s) containing read counts and gRNA names.
|
1076
|
+
scores_csv (str or list of str): Path(s) to CSV file(s) with classification results per object.
|
1077
|
+
empirical_dict (dict): Mapping of rowID to (positive, negative) control counts.
|
1078
|
+
pc_grna (str): Name of the positive control gRNA.
|
1079
|
+
nc_grna (str): Name of the negative control gRNA.
|
1080
|
+
y_columns (list of str): Columns in the merged dataframe to plot on y-axis.
|
1081
|
+
column (str): Column to subset on (e.g., 'columnID').
|
1082
|
+
value (str): Value to filter the column on (e.g., 'c3').
|
1083
|
+
plate (str, optional): Plate name to assign if not present in the CSVs.
|
1084
|
+
save_paths (list of str, optional): List of two paths to save the plots as PDF.
|
797
1085
|
|
1086
|
+
Returns:
|
1087
|
+
list: Two matplotlib Figure objects, one for pc_fraction and one for nc_fraction x-axis plots.
|
1088
|
+
|
1089
|
+
Side Effects:
|
1090
|
+
- Reads and merges input CSVs.
|
1091
|
+
- Computes gRNA read fractions and classification fractions per well.
|
1092
|
+
- Merges empirical expectations based on rowID.
|
1093
|
+
- Displays merged DataFrame.
|
1094
|
+
- Saves two line plots (if `save_paths` is specified).
|
1095
|
+
"""
|
1096
|
+
|
798
1097
|
def calculate_well_score_fractions(df, class_columns='cv_predictions'):
|
1098
|
+
"""
|
1099
|
+
Calculate the fraction of predicted classes (e.g., class_0 and class_1) per well.
|
1100
|
+
|
1101
|
+
This function groups a DataFrame by plate, row, and column IDs to compute the total number
|
1102
|
+
of objects and the number of objects predicted in each class per well. It returns a summary
|
1103
|
+
DataFrame that includes the absolute and relative frequencies of class 0 and class 1 predictions.
|
1104
|
+
|
1105
|
+
Args:
|
1106
|
+
df (pd.DataFrame): Input DataFrame containing at least:
|
1107
|
+
- 'plateID', 'rowID', 'columnID': to define wells
|
1108
|
+
- A column specified by `class_columns` with predicted class labels (e.g., 0, 1)
|
1109
|
+
class_columns (str): Name of the column containing predicted class labels. Defaults to 'cv_predictions'.
|
1110
|
+
|
1111
|
+
Returns:
|
1112
|
+
pd.DataFrame: Summary DataFrame with columns:
|
1113
|
+
- 'plateID', 'rowID', 'columnID', 'prc'
|
1114
|
+
- 'total_rows': total number of entries per well
|
1115
|
+
- 'class_0': number of class 0 predictions
|
1116
|
+
- 'class_1': number of class 1 predictions
|
1117
|
+
- 'class_0_fraction': class_0 / total_rows
|
1118
|
+
- 'class_1_fraction': class_1 / total_rows
|
1119
|
+
|
1120
|
+
Raises:
|
1121
|
+
ValueError: If required columns ('plateID', 'rowID', 'columnID') are missing from `df`.
|
1122
|
+
"""
|
799
1123
|
if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
|
800
1124
|
df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
|
801
1125
|
else:
|
@@ -814,9 +1138,31 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
|
|
814
1138
|
def plot_line(df, x_column, y_columns, group_column=None, xlabel=None, ylabel=None,
|
815
1139
|
title=None, figsize=(10, 6), save_path=None, theme='deep'):
|
816
1140
|
"""
|
817
|
-
Create a line plot
|
818
|
-
|
1141
|
+
Create a customizable line plot for one or more y-columns against a common x-axis.
|
1142
|
+
|
1143
|
+
This function supports grouped or ungrouped line plots and includes options for custom axis
|
1144
|
+
labels, title, figure size, and color themes. When multiple y-columns are specified, each is
|
1145
|
+
plotted as an individual line. Optionally saves the plot as a high-resolution PDF.
|
1146
|
+
|
1147
|
+
Args:
|
1148
|
+
df (pd.DataFrame): Input DataFrame containing the plotting data.
|
1149
|
+
x_column (str): Column name to use for the x-axis.
|
1150
|
+
y_columns (str or list of str): Column(s) to use for the y-axis. If a list, each column is plotted as a separate line.
|
1151
|
+
group_column (str, optional): Column name for grouping lines (only used if `y_columns` is a single string).
|
1152
|
+
xlabel (str, optional): Label for the x-axis. Defaults to `x_column` name.
|
1153
|
+
ylabel (str, optional): Label for the y-axis. Defaults to 'Value'.
|
1154
|
+
title (str, optional): Title of the plot. Defaults to 'Line Plot'.
|
1155
|
+
figsize (tuple, optional): Figure size in inches. Defaults to (10, 6).
|
1156
|
+
save_path (str, optional): If provided, saves the figure as a PDF to this path.
|
1157
|
+
theme (str, optional): Seaborn color palette theme. Defaults to 'deep'.
|
819
1158
|
|
1159
|
+
Returns:
|
1160
|
+
matplotlib.figure.Figure: The generated figure object.
|
1161
|
+
|
1162
|
+
Side Effects:
|
1163
|
+
- Displays the plot.
|
1164
|
+
- Saves the figure as a PDF if `save_path` is specified.
|
1165
|
+
"""
|
820
1166
|
def _set_theme(theme):
|
821
1167
|
"""Set the Seaborn theme and reorder colors if necessary."""
|
822
1168
|
|
@@ -880,6 +1226,30 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
|
|
880
1226
|
return fig
|
881
1227
|
|
882
1228
|
def calculate_grna_fraction_ratio(df, grna1='TGGT1_220950_1', grna2='TGGT1_233460_4'):
|
1229
|
+
"""
|
1230
|
+
Calculate the ratio and fractional abundance of two gRNAs per well.
|
1231
|
+
|
1232
|
+
This function filters the input DataFrame for two specified gRNAs, aggregates their read
|
1233
|
+
fractions and counts per well (`prc`), and computes:
|
1234
|
+
- the fraction ratio (grna1 / grna2),
|
1235
|
+
- total read count per well,
|
1236
|
+
- individual gRNA read fractions per well.
|
1237
|
+
|
1238
|
+
Args:
|
1239
|
+
df (pd.DataFrame): Input DataFrame with columns ['prc', 'grna_name', 'fraction', 'count'].
|
1240
|
+
grna1 (str): Name of the positive control or reference gRNA.
|
1241
|
+
grna2 (str): Name of the negative control or comparison gRNA.
|
1242
|
+
|
1243
|
+
Returns:
|
1244
|
+
pd.DataFrame: A DataFrame with columns:
|
1245
|
+
- 'prc': well identifier (plate_row_column)
|
1246
|
+
- '<grna1>_count': total read count of grna1 per well
|
1247
|
+
- '<grna2>_count': total read count of grna2 per well
|
1248
|
+
- 'fraction_ratio': grna1 fraction / grna2 fraction
|
1249
|
+
- 'total_reads': total reads of both gRNAs
|
1250
|
+
- '<grna1>_fraction': grna1 reads / total_reads
|
1251
|
+
- '<grna2>_fraction': grna2 reads / total_reads
|
1252
|
+
"""
|
883
1253
|
# Filter relevant grna_names within each prc and group them
|
884
1254
|
grouped = df[df['grna_name'].isin([grna1, grna2])] \
|
885
1255
|
.groupby(['prc', 'grna_name']) \
|
@@ -901,6 +1271,28 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
|
|
901
1271
|
return result
|
902
1272
|
|
903
1273
|
def calculate_well_read_fraction(df, count_column='count'):
|
1274
|
+
"""
|
1275
|
+
Calculate the fractional read count for each gRNA within a well.
|
1276
|
+
|
1277
|
+
This function computes the total number of reads per well (identified by plateID, rowID, columnID),
|
1278
|
+
and then computes the fraction of reads contributed by each gRNA relative to the total reads
|
1279
|
+
in that well.
|
1280
|
+
|
1281
|
+
Args:
|
1282
|
+
df (pd.DataFrame): DataFrame containing at least the columns:
|
1283
|
+
- 'plateID', 'rowID', 'columnID' (to define wells)
|
1284
|
+
- `count_column` (default: 'count'), the read count per gRNA instance
|
1285
|
+
count_column (str): Name of the column holding raw read counts. Defaults to 'count'.
|
1286
|
+
|
1287
|
+
Returns:
|
1288
|
+
pd.DataFrame: Original DataFrame with added columns:
|
1289
|
+
- 'prc': combined plate_row_column identifier
|
1290
|
+
- 'total_counts': total reads per well
|
1291
|
+
- 'fraction': count / total_counts for each row
|
1292
|
+
|
1293
|
+
Raises:
|
1294
|
+
ValueError: If any of 'plateID', 'rowID', or 'columnID' are missing from the input DataFrame.
|
1295
|
+
"""
|
904
1296
|
if all(col in df.columns for col in ['plateID', 'rowID', 'columnID']):
|
905
1297
|
df['prc'] = df['plateID'] + '_' + df['rowID'] + '_' + df['columnID']
|
906
1298
|
else:
|
@@ -969,7 +1361,49 @@ def compare_reads_to_scores(reads_csv, scores_csv, empirical_dict={'r1':(90,10),
|
|
969
1361
|
return [fig_1, fig_2]
|
970
1362
|
|
971
1363
|
def interperate_vision_model(settings={}):
|
972
|
-
|
1364
|
+
"""
|
1365
|
+
Interpret a vision-based machine learning model using multiple feature attribution strategies.
|
1366
|
+
|
1367
|
+
This function reads merged object-level data and model scores, engineers relative features, and applies:
|
1368
|
+
- Random Forest feature importance
|
1369
|
+
- Permutation importance
|
1370
|
+
- SHAP (SHapley Additive exPlanations) values
|
1371
|
+
|
1372
|
+
The model interpretation is performed across cellular compartments and imaging channels, and
|
1373
|
+
results can be visualized or saved for downstream analysis.
|
1374
|
+
|
1375
|
+
Args:
|
1376
|
+
settings (dict): Dictionary of configuration parameters including:
|
1377
|
+
- 'src' (str): Path to the data directory.
|
1378
|
+
- 'scores' (str): Path to the score CSV file.
|
1379
|
+
- 'tables' (list of str): Feature group keywords for compartmental grouping.
|
1380
|
+
- 'channels' (list of str): Feature group keywords for channel grouping.
|
1381
|
+
- 'score_column' (str): Column name for predicted class scores.
|
1382
|
+
- 'nuclei_limit' (int): Optional filter on number of nuclei to load.
|
1383
|
+
- 'pathogen_limit' (int): Optional filter on number of pathogens to load.
|
1384
|
+
- 'feature_importance' (bool): Whether to compute Random Forest feature importances.
|
1385
|
+
- 'permutation_importance' (bool): Whether to compute permutation-based importances.
|
1386
|
+
- 'shap' (bool): Whether to compute SHAP values.
|
1387
|
+
- 'shap_sample' (bool): Whether to subsample SHAP input for speed.
|
1388
|
+
- 'top_features' (int): Number of top features to plot.
|
1389
|
+
- 'include_all' (bool): Whether to include a row summing all feature groups.
|
1390
|
+
- 'n_jobs' (int): Number of parallel jobs to use for model training or SHAP.
|
1391
|
+
- 'save' (bool): Whether to save output CSV files.
|
1392
|
+
|
1393
|
+
Returns:
|
1394
|
+
dict: A dictionary of interpretation outputs with the following possible keys:
|
1395
|
+
- 'feature_importance': Raw feature importances (Random Forest)
|
1396
|
+
- 'feature_importance_compartment': Grouped by compartment
|
1397
|
+
- 'feature_importance_channel': Grouped by imaging channel
|
1398
|
+
- 'permutation_importance': Permutation importance values
|
1399
|
+
- 'shap': SHAP values for each object-feature pair
|
1400
|
+
|
1401
|
+
Side Effects:
|
1402
|
+
- Displays importance plots.
|
1403
|
+
- Optionally saves CSVs to `<src>/results/`.
|
1404
|
+
- Merges and preprocesses input data, including relative feature construction.
|
1405
|
+
"""
|
1406
|
+
|
973
1407
|
from .io import _read_and_merge_data
|
974
1408
|
|
975
1409
|
def generate_comparison_columns(df, compartments=['cell', 'nucleus', 'pathogen', 'cytoplasm']):
|
@@ -1016,7 +1450,26 @@ def interperate_vision_model(settings={}):
|
|
1016
1450
|
return df, comparison_dict
|
1017
1451
|
|
1018
1452
|
def group_feature_class(df, feature_groups=['cell', 'cytoplasm', 'nucleus', 'pathogen'], name='compartment', include_all=False):
|
1453
|
+
"""
|
1454
|
+
Group features by compartment or channel and sum their importance scores.
|
1455
|
+
|
1456
|
+
This function identifies the compartment or channel associated with each feature based on
|
1457
|
+
string matching, and computes the total importance of each group. Optionally, it includes
|
1458
|
+
a row summing the total importance across all groups.
|
1019
1459
|
|
1460
|
+
Args:
|
1461
|
+
df (pd.DataFrame): DataFrame with at least the columns:
|
1462
|
+
- 'feature': feature names to classify
|
1463
|
+
- 'importance': importance values associated with each feature
|
1464
|
+
feature_groups (list of str): Keywords used to assign features to groups (e.g., compartments or channels).
|
1465
|
+
name (str): Column name to assign group classification (e.g., 'compartment' or 'channel').
|
1466
|
+
include_all (bool): If True, include a summary row for total importance across all groups.
|
1467
|
+
|
1468
|
+
Returns:
|
1469
|
+
pd.DataFrame: DataFrame with columns:
|
1470
|
+
- `name` (e.g., 'compartment', 'channel')
|
1471
|
+
- `<name>_importance_sum`: total importance per group
|
1472
|
+
"""
|
1020
1473
|
# Function to determine compartment based on multiple matches
|
1021
1474
|
def find_feature_class(feature, compartments):
|
1022
1475
|
matches = [compartment for compartment in compartments if re.search(compartment, feature)]
|
@@ -1049,6 +1502,23 @@ def interperate_vision_model(settings={}):
|
|
1049
1502
|
|
1050
1503
|
# Function to create radar plot for individual and combined values
|
1051
1504
|
def create_extended_radar_plot(values, labels, title):
|
1505
|
+
"""
|
1506
|
+
Create a radar (spider) plot to visualize multivariate values on a circular axis.
|
1507
|
+
|
1508
|
+
This function plots a closed polygon where each axis represents a different category
|
1509
|
+
(e.g., feature group), allowing for intuitive comparison of relative magnitudes.
|
1510
|
+
|
1511
|
+
Args:
|
1512
|
+
values (list or array-like): List of values to plot. Should match the length of `labels`.
|
1513
|
+
labels (list of str): Labels for each axis in the radar chart.
|
1514
|
+
title (str): Title of the plot.
|
1515
|
+
|
1516
|
+
Returns:
|
1517
|
+
None
|
1518
|
+
|
1519
|
+
Side Effects:
|
1520
|
+
- Displays the radar chart using matplotlib.
|
1521
|
+
"""
|
1052
1522
|
values = list(values) + [values[0]] # Close the loop for radar chart
|
1053
1523
|
angles = [n / float(len(labels)) * 2 * pi for n in range(len(labels))]
|
1054
1524
|
angles += angles[:1]
|
@@ -1063,6 +1533,22 @@ def interperate_vision_model(settings={}):
|
|
1063
1533
|
plt.show()
|
1064
1534
|
|
1065
1535
|
def extract_compartment_channel(feature_name):
|
1536
|
+
"""
|
1537
|
+
Extract the compartment and imaging channel from a feature name.
|
1538
|
+
|
1539
|
+
This function parses a feature name string to identify:
|
1540
|
+
- The compartment (first substring before an underscore, with 'cells' mapped to 'cell')
|
1541
|
+
- The imaging channel(s), based on the presence of substrings like 'channel_0', etc.
|
1542
|
+
|
1543
|
+
Args:
|
1544
|
+
feature_name (str): The name of the feature to parse.
|
1545
|
+
|
1546
|
+
Returns:
|
1547
|
+
tuple:
|
1548
|
+
- compartment (str): The identified compartment (e.g., 'cell', 'nucleus').
|
1549
|
+
- channel (str): Channel label or 'morphology' if no channel found.
|
1550
|
+
If multiple channels are present, they are joined by ' + '.
|
1551
|
+
"""
|
1066
1552
|
# Identify compartment as the first part before an underscore
|
1067
1553
|
compartment = feature_name.split('_')[0]
|
1068
1554
|
|
@@ -1089,7 +1575,34 @@ def interperate_vision_model(settings={}):
|
|
1089
1575
|
return (compartment, channel)
|
1090
1576
|
|
1091
1577
|
def read_and_preprocess_data(settings):
|
1578
|
+
"""
|
1579
|
+
Load, preprocess, and merge measurement data and classification scores.
|
1580
|
+
|
1581
|
+
This function:
|
1582
|
+
- Reads merged object-level measurements from a SQLite database.
|
1583
|
+
- Adds relative comparison features between compartments (e.g., nucleus/cytoplasm).
|
1584
|
+
- Loads classification scores and aligns them with the measurement data.
|
1585
|
+
- Ensures consistent formatting of join keys (plateID, rowID, etc.).
|
1586
|
+
- Merges the two datasets on object identifiers and returns the feature matrix and labels.
|
1587
|
+
|
1588
|
+
Args:
|
1589
|
+
settings (dict): Dictionary with required keys:
|
1590
|
+
- 'src' (str): Base path to the dataset directory.
|
1591
|
+
- 'tables' (list of str): Tables to load from the database.
|
1592
|
+
- 'nuclei_limit' (int or None): Optional limit on number of nuclei to load.
|
1593
|
+
- 'pathogen_limit' (int or None): Optional limit on number of pathogens to load.
|
1594
|
+
- 'scores' (str): Path to the CSV file with prediction/classification scores.
|
1595
|
+
- 'score_column' (str): Column name in scores CSV used as prediction target.
|
1596
|
+
|
1597
|
+
Returns:
|
1598
|
+
tuple:
|
1599
|
+
- X (pd.DataFrame): Numerical feature matrix (excluding the score column).
|
1600
|
+
- y (pd.Series): Target values corresponding to the score column.
|
1601
|
+
- merged_df (pd.DataFrame): Full merged DataFrame including both features and scores.
|
1092
1602
|
|
1603
|
+
Side Effects:
|
1604
|
+
- Prints the number of columns in the final feature-expanded DataFrame.
|
1605
|
+
"""
|
1093
1606
|
df, _ = _read_and_merge_data(
|
1094
1607
|
locs=[settings['src']+'/measurements/measurements.db'],
|
1095
1608
|
tables=settings['tables'],
|
@@ -1267,6 +1780,38 @@ def interperate_vision_model(settings={}):
|
|
1267
1780
|
return output
|
1268
1781
|
|
1269
1782
|
def _plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_column='prc', level='object'):
|
1783
|
+
"""
|
1784
|
+
Generate a stacked bar plot of bin proportions by group and perform a chi-squared test for association.
|
1785
|
+
|
1786
|
+
Depending on the `level` parameter, this function can:
|
1787
|
+
- Plot mean ± SD of binned proportions across wells (`level='well'`).
|
1788
|
+
- Plot object-level proportions per group (`level='object'`).
|
1789
|
+
|
1790
|
+
The function also calculates and prints the chi-squared test statistic and p-value
|
1791
|
+
for independence between group and bin category distributions.
|
1792
|
+
|
1793
|
+
Args:
|
1794
|
+
settings (dict): Dictionary with settings, must include:
|
1795
|
+
- 'um_per_px' (float or None): Micron-per-pixel scaling factor. If None, volume is in px³.
|
1796
|
+
df (pd.DataFrame): DataFrame containing binned volume information.
|
1797
|
+
group_column (str): Column name indicating group assignment (e.g., treatment, genotype).
|
1798
|
+
bin_column (str): Column name for the binned volume categories.
|
1799
|
+
prc_column (str): Column name specifying well ID. Only used if level='well'. Default is 'prc'.
|
1800
|
+
level (str): Either 'object' (default) or 'well', determines plot granularity.
|
1801
|
+
|
1802
|
+
Returns:
|
1803
|
+
tuple:
|
1804
|
+
- chi2 (float): Chi-squared test statistic from raw data.
|
1805
|
+
- p (float): p-value from the chi-squared test.
|
1806
|
+
- dof (int): Degrees of freedom from the test.
|
1807
|
+
- expected (ndarray): Expected frequencies under independence.
|
1808
|
+
- raw_counts (pd.DataFrame): Contingency table used for chi-squared test.
|
1809
|
+
- fig (matplotlib.figure.Figure): The figure object containing the plot.
|
1810
|
+
|
1811
|
+
Side Effects:
|
1812
|
+
- Displays a stacked bar plot using matplotlib.
|
1813
|
+
- Prints the chi-squared test result.
|
1814
|
+
"""
|
1270
1815
|
# Always calculate chi-squared on raw data
|
1271
1816
|
raw_counts = df.groupby([group_column, bin_column]).size().unstack(fill_value=0)
|
1272
1817
|
chi2, p, dof, expected = chi2_contingency(raw_counts)
|
@@ -1316,7 +1861,42 @@ def _plot_proportion_stacked_bars(settings, df, group_column, bin_column, prc_co
|
|
1316
1861
|
return chi2, p, dof, expected, raw_counts, fig
|
1317
1862
|
|
1318
1863
|
def analyze_endodyogeny(settings):
|
1319
|
-
|
1864
|
+
"""
|
1865
|
+
Analyze intracellular pathogen replication (endodyogeny) based on compartment volume measurements.
|
1866
|
+
|
1867
|
+
This function bins single-cell or single-object volume estimates (e.g., pathogen area^1.5) into
|
1868
|
+
discrete categories, annotates experimental conditions, and visualizes the volume distribution
|
1869
|
+
across groups as a stacked bar plot. It also performs a chi-squared test to assess whether the
|
1870
|
+
volume distribution is significantly different between groups.
|
1871
|
+
|
1872
|
+
Args:
|
1873
|
+
settings (dict): Configuration dictionary with required keys:
|
1874
|
+
- 'src' (str or list): Path(s) to directories containing measurements.db.
|
1875
|
+
- 'tables' (list of str): Tables to extract from the SQLite database.
|
1876
|
+
- 'compartment' (str): Compartment to analyze (e.g., 'pathogen').
|
1877
|
+
- 'min_area_bin' (float): Minimum 2D area threshold to consider (in px² or µm²).
|
1878
|
+
- 'max_bins' (int or None): Maximum number of volume bins to display.
|
1879
|
+
- 'group_column' (str): Column defining experimental groups (e.g., genotype or treatment).
|
1880
|
+
- 'cell_types', 'pathogen_types', 'treatments' (list): Experimental groups for annotation.
|
1881
|
+
- 'cell_plate_metadata', 'pathogen_plate_metadata', 'treatment_plate_metadata' (str): Metadata source paths.
|
1882
|
+
- 'nuclei_limit', 'pathogen_limit' (int or None): Max number of objects to load per type.
|
1883
|
+
- 'change_plate' (bool): Whether to modify plate names for uniqueness.
|
1884
|
+
- 'verbose' (bool): Whether to print debug output.
|
1885
|
+
- 'um_per_px' (float or None): Pixel-to-micrometer conversion (if None, volume remains in px³).
|
1886
|
+
- 'level' (str): Either 'object' or 'well' for aggregation level.
|
1887
|
+
- 'cmap' (str): Colormap to use in the plot.
|
1888
|
+
- 'save' (bool): Whether to save results and figures.
|
1889
|
+
|
1890
|
+
Returns:
|
1891
|
+
dict: A dictionary with:
|
1892
|
+
- 'data': Annotated and binned DataFrame.
|
1893
|
+
- 'chi_squared': Chi-squared test summary DataFrame.
|
1894
|
+
|
1895
|
+
Side Effects:
|
1896
|
+
- Displays a stacked bar plot showing volume bin proportions by group.
|
1897
|
+
- Performs chi-squared test and prints p-value.
|
1898
|
+
- Saves data and plots to `results/analyze_endodyogeny/` if `save=True`.
|
1899
|
+
"""
|
1320
1900
|
from .utils import annotate_conditions, save_settings
|
1321
1901
|
from .io import _read_and_merge_data
|
1322
1902
|
from .settings import set_analyze_endodyogeny_defaults
|
@@ -1449,7 +2029,42 @@ def analyze_endodyogeny(settings):
|
|
1449
2029
|
return output
|
1450
2030
|
|
1451
2031
|
def analyze_class_proportion(settings):
|
1452
|
-
|
2032
|
+
"""
|
2033
|
+
Analyze class frequency distributions across experimental groups and perform statistical tests.
|
2034
|
+
|
2035
|
+
This function compares the proportion of classification outcomes (e.g., phenotypic classes)
|
2036
|
+
across specified experimental groups. It generates stacked bar plots, performs chi-squared
|
2037
|
+
and parametric/non-parametric statistical tests, and saves annotated data and results.
|
2038
|
+
|
2039
|
+
Args:
|
2040
|
+
settings (dict): Dictionary of parameters, including:
|
2041
|
+
- 'src' (str or list): Path(s) to measurement directories.
|
2042
|
+
- 'tables' (list): Tables to load from measurements.db.
|
2043
|
+
- 'class_column' (str): Column indicating class or category.
|
2044
|
+
- 'group_column' (str): Column indicating grouping variable (e.g., genotype or treatment).
|
2045
|
+
- 'cell_types', 'pathogen_types', 'treatments' (list): Experimental group values.
|
2046
|
+
- 'cell_plate_metadata', 'pathogen_plate_metadata', 'treatment_plate_metadata' (str): Metadata paths.
|
2047
|
+
- 'nuclei_limit', 'pathogen_limit' (int or None): Optional load limits.
|
2048
|
+
- 'verbose' (bool): Whether to print debug information.
|
2049
|
+
- 'level' (str): Aggregation level for chi-squared test ('object' or 'well').
|
2050
|
+
- 'save' (bool): Whether to save results and figures.
|
2051
|
+
|
2052
|
+
Returns:
|
2053
|
+
dict: A dictionary containing:
|
2054
|
+
- 'data': Annotated DataFrame.
|
2055
|
+
- 'chi_squared': Chi-squared summary results.
|
2056
|
+
|
2057
|
+
Side Effects:
|
2058
|
+
- Annotates DataFrame with experimental conditions.
|
2059
|
+
- Performs chi-squared test on class distribution and plots stacked bars.
|
2060
|
+
- Plots class heatmap across plates.
|
2061
|
+
- Performs and prints:
|
2062
|
+
* Shapiro-Wilk normality test
|
2063
|
+
* Levene’s test for equal variances
|
2064
|
+
* ANOVA or Kruskal-Wallis test
|
2065
|
+
* Post hoc pairwise comparisons
|
2066
|
+
- Optionally saves all results and figures in a `results/analyze_class_proportion/` subfolder.
|
2067
|
+
"""
|
1453
2068
|
from .utils import annotate_conditions, save_settings
|
1454
2069
|
from .io import _read_and_merge_data
|
1455
2070
|
from .settings import set_analyze_class_proportion_defaults
|
@@ -1553,9 +2168,56 @@ def analyze_class_proportion(settings):
|
|
1553
2168
|
return output
|
1554
2169
|
|
1555
2170
|
def generate_score_heatmap(settings):
|
1556
|
-
|
2171
|
+
"""
|
2172
|
+
Generate a heatmap comparing predicted classification scores from multiple models
|
2173
|
+
with empirical gRNA-based fraction measurements across wells.
|
2174
|
+
|
2175
|
+
This function:
|
2176
|
+
- Aggregates prediction scores from subfolders.
|
2177
|
+
- Calculates the gRNA-based empirical fraction from sequencing data.
|
2178
|
+
- Merges predictions with empirical scores and computes MAE (mean absolute error).
|
2179
|
+
- Displays and optionally saves a multi-channel heatmap and the corresponding MAE results.
|
2180
|
+
|
2181
|
+
Args:
|
2182
|
+
settings (dict): Configuration dictionary with the following keys:
|
2183
|
+
- 'folders' (list or str): Folder(s) containing model output subfolders with CSVs.
|
2184
|
+
- 'csv_name' (str): Name of the CSV file to search for in each model output subfolder.
|
2185
|
+
- 'data_column' (str): Column in CSVs containing the prediction scores.
|
2186
|
+
- 'cv_csv' (str): CSV with cross-validated model scores.
|
2187
|
+
- 'data_column_cv' (str): Column in `cv_csv` with scores to merge.
|
2188
|
+
- 'csv' (str): CSV with gRNA counts used for computing empirical fractions.
|
2189
|
+
- 'control_sgrnas' (list): List of two control gRNA names used to compute the fraction.
|
2190
|
+
- 'fraction_grna' (str): One of the control gRNAs to extract the empirical fraction from.
|
2191
|
+
- 'plateID' (int or str): Plate ID used for labeling and filtering.
|
2192
|
+
- 'columnID' (str): Column ID to filter on (e.g., 'c3').
|
2193
|
+
- 'cmap' (str): Colormap to use for the heatmap.
|
2194
|
+
- 'dst' (str or None): If provided, save outputs to this directory.
|
2195
|
+
|
2196
|
+
Returns:
|
2197
|
+
pd.DataFrame: Merged DataFrame containing all prediction and empirical score columns.
|
2198
|
+
|
2199
|
+
Side Effects:
|
2200
|
+
- Displays a multi-channel heatmap of prediction scores.
|
2201
|
+
- Computes and prints MAE per well per channel.
|
2202
|
+
- Saves merged data, MAE scores, and the heatmap as PDF/CSV if `dst` is set.
|
2203
|
+
"""
|
1557
2204
|
def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
|
1558
|
-
|
2205
|
+
"""
|
2206
|
+
Group cross-validation prediction scores by plate, row, and column.
|
2207
|
+
|
2208
|
+
This function reads a CSV file containing predictions, filters by a specific column (e.g., 'c3'),
|
2209
|
+
assigns a plate ID, and computes the mean prediction score per well.
|
2210
|
+
|
2211
|
+
Args:
|
2212
|
+
csv (str): Path to the CSV file containing predictions.
|
2213
|
+
plate (int or str, optional): Plate identifier to assign. Defaults to 1.
|
2214
|
+
column (str, optional): Column ID to filter by (e.g., 'c3'). Defaults to 'c3'.
|
2215
|
+
data_column (str, optional): Name of the column with prediction scores. Defaults to 'pred'.
|
2216
|
+
|
2217
|
+
Returns:
|
2218
|
+
pd.DataFrame: Aggregated DataFrame with mean scores and a unique 'prc' well identifier.
|
2219
|
+
Columns: ['plateID', 'rowID', 'columnID', data_column, 'prc']
|
2220
|
+
"""
|
1559
2221
|
df = pd.read_csv(csv)
|
1560
2222
|
if 'columnID' in df.columns:
|
1561
2223
|
df = df[df['columnID']==column]
|
@@ -1569,6 +2231,25 @@ def generate_score_heatmap(settings):
|
|
1569
2231
|
return grouped_df
|
1570
2232
|
|
1571
2233
|
def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
|
2234
|
+
"""
|
2235
|
+
Calculate the fraction of control sgRNAs in each well for a mixed condition.
|
2236
|
+
|
2237
|
+
This function filters a CSV containing sgRNA counts for a specified column and control sgRNAs,
|
2238
|
+
computes the total count per well, and calculates the fraction of each sgRNA within that well.
|
2239
|
+
A unique identifier ('prc') is added for each well.
|
2240
|
+
|
2241
|
+
Args:
|
2242
|
+
csv (str): Path to the input CSV file containing sgRNA data.
|
2243
|
+
plate (int or str, optional): Plate identifier to assign if not present. Defaults to 1.
|
2244
|
+
column (str, optional): Column name to filter by (e.g., 'c3'). Defaults to 'c3'.
|
2245
|
+
control_sgrnas (list of str, optional): List of two control sgRNA names to include in the calculation.
|
2246
|
+
Defaults to ['TGGT1_220950_1', 'TGGT1_233460_4'].
|
2247
|
+
|
2248
|
+
Returns:
|
2249
|
+
pd.DataFrame: DataFrame with fraction per sgRNA in each well and a unique 'prc' well identifier.
|
2250
|
+
Includes columns: ['plateID', 'rowID', 'column_name', 'grna_name', 'count',
|
2251
|
+
'total_count', 'fraction', 'prc']
|
2252
|
+
"""
|
1572
2253
|
df = pd.read_csv(csv)
|
1573
2254
|
df = df[df['column_name']==column]
|
1574
2255
|
if plate not in df.columns:
|
@@ -1631,6 +2312,24 @@ def generate_score_heatmap(settings):
|
|
1631
2312
|
|
1632
2313
|
|
1633
2314
|
def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
|
2315
|
+
"""
|
2316
|
+
Combine classification scores from multiple CSV files across folders.
|
2317
|
+
|
2318
|
+
This function traverses given directories, locates specified CSV files within subfolders,
|
2319
|
+
filters by a target column, aggregates prediction scores by well position, and merges
|
2320
|
+
the results into a single DataFrame with each file contributing a new column of scores.
|
2321
|
+
|
2322
|
+
Args:
|
2323
|
+
folders (str or list of str): Path(s) to parent folders containing subfolders with CSV files.
|
2324
|
+
csv_name (str): Name of the CSV file to locate in each subfolder.
|
2325
|
+
data_column (str): Name of the column in each CSV containing prediction scores to aggregate.
|
2326
|
+
plate (int or str, optional): Identifier for the plate to add as 'plateID'. Defaults to 1.
|
2327
|
+
column (str, optional): Column name to filter data by (e.g., 'c3'). Defaults to 'c3'.
|
2328
|
+
|
2329
|
+
Returns:
|
2330
|
+
pd.DataFrame: Combined DataFrame with mean prediction scores per well from all found CSV files.
|
2331
|
+
Includes a 'prc' column for well-level identifiers and one score column per source.
|
2332
|
+
"""
|
1634
2333
|
# Ensure `folders` is a list
|
1635
2334
|
if isinstance(folders, str):
|
1636
2335
|
folders = [folders]
|
@@ -1678,7 +2377,20 @@ def generate_score_heatmap(settings):
|
|
1678
2377
|
|
1679
2378
|
def calculate_mae(df):
|
1680
2379
|
"""
|
1681
|
-
Calculate
|
2380
|
+
Calculate mean absolute error (MAE) between prediction scores and reference fractions.
|
2381
|
+
|
2382
|
+
This function computes the MAE between each numeric prediction column and the
|
2383
|
+
reference 'fraction' column on a per-row basis. The results are returned as a
|
2384
|
+
long-form DataFrame with one row per (channel, row) pair.
|
2385
|
+
|
2386
|
+
Args:
|
2387
|
+
df (pd.DataFrame): Input DataFrame containing a 'fraction' column as ground truth,
|
2388
|
+
a 'prc' column as row identifier, and one or more numeric
|
2389
|
+
prediction columns.
|
2390
|
+
|
2391
|
+
Returns:
|
2392
|
+
pd.DataFrame: DataFrame with columns ['Channel', 'MAE', 'Row'], reporting the
|
2393
|
+
MAE between each prediction channel and the fraction per row.
|
1682
2394
|
"""
|
1683
2395
|
# Extract numeric columns excluding 'fraction' and 'prc'
|
1684
2396
|
channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
|
@@ -1721,7 +2433,23 @@ def generate_score_heatmap(settings):
|
|
1721
2433
|
return merged_df
|
1722
2434
|
|
1723
2435
|
def post_regression_analysis(csv_file, grna_dict, grna_list, save=False):
|
1724
|
-
|
2436
|
+
"""
|
2437
|
+
Perform post-regression analysis to assess gRNA correlation and infer relative effect sizes.
|
2438
|
+
|
2439
|
+
This function loads a CSV file containing gRNA abundance data (with columns including
|
2440
|
+
'grna', 'fraction', and 'prc'), filters it to the specified gRNAs of interest, computes
|
2441
|
+
a correlation matrix, visualizes it, and estimates inferred effect sizes for all gRNAs
|
2442
|
+
based on a subset of known fixed effect sizes.
|
2443
|
+
|
2444
|
+
Args:
|
2445
|
+
csv_file (str): Path to the CSV file containing gRNA data with columns ['grna', 'fraction', 'prc'].
|
2446
|
+
grna_dict (dict): Dictionary mapping selected gRNAs to their known effect sizes.
|
2447
|
+
grna_list (list): List of gRNAs to include in correlation and effect size inference.
|
2448
|
+
save (bool, optional): Whether to save outputs including figures and CSVs. Defaults to False.
|
2449
|
+
|
2450
|
+
Returns:
|
2451
|
+
None. Saves plots and results if `save=True`.
|
2452
|
+
"""
|
1725
2453
|
def _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save=False):
|
1726
2454
|
"""
|
1727
2455
|
Analyze and visualize the correlation matrix of gRNAs based on their fractions and overlap.
|