spacr 1.0.7__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/app_classify.py +10 -0
- spacr/app_mask.py +9 -0
- spacr/app_measure.py +9 -0
- spacr/app_sequencing.py +9 -0
- spacr/core.py +172 -1
- spacr/deep_spacr.py +296 -7
- spacr/gui.py +68 -0
- spacr/gui_core.py +319 -10
- spacr/gui_elements.py +772 -13
- spacr/gui_utils.py +301 -151
- spacr/io.py +887 -71
- spacr/logger.py +36 -0
- spacr/measure.py +206 -28
- spacr/ml.py +606 -142
- spacr/plot.py +797 -131
- spacr/sequencing.py +363 -8
- spacr/settings.py +1158 -38
- spacr/sp_stats.py +80 -12
- spacr/spacr_cellpose.py +115 -2
- spacr/submodules.py +747 -19
- spacr/timelapse.py +237 -53
- spacr/toxo.py +132 -6
- spacr/utils.py +2422 -80
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/METADATA +31 -17
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/RECORD +29 -29
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/LICENSE +0 -0
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/WHEEL +0 -0
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/entry_points.txt +0 -0
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/top_level.txt +0 -0
spacr/timelapse.py
CHANGED
@@ -18,7 +18,6 @@ except ImportError:
|
|
18
18
|
|
19
19
|
import matplotlib.pyplot as plt
|
20
20
|
|
21
|
-
|
22
21
|
def _npz_to_movie(arrays, filenames, save_path, fps=10):
|
23
22
|
"""
|
24
23
|
Convert a list of numpy arrays to a movie file.
|
@@ -76,55 +75,55 @@ def _npz_to_movie(arrays, filenames, save_path, fps=10):
|
|
76
75
|
print(f"Movie saved to {save_path}")
|
77
76
|
|
78
77
|
def _scmovie(folder_paths, fps=10):
    """
    Generate one MP4 movie per tracked object from the PNG images in each folder.

    PNG files must be named ``<plate>_<well>_<field>_<time>_<object>.png``.
    Files that share plate, well, field and object number are ordered by their
    integer time index, zero-padded on the bottom/right to the largest frame of
    the group, and written to
    ``<folder>/movies/<plate>_<well>_<field>_<object>.mp4``.

    Args:
        folder_paths (list): Folder paths containing PNG images. Duplicate
            entries are processed once, in first-seen order.
        fps (int or float, optional): Frame rate of the generated movies.
            Defaults to 10.

    Returns:
        None
    """
    # Escaping the dot and matching the whole name (fullmatch below) rejects
    # look-alike names such as "a_b_c_1_2_png.png" that the unanchored pattern
    # with a bare "." used to accept.
    filename_regex = re.compile(r'(\w+)_(\w+)_(\w+)_(\d+)_(\d+)\.png')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    for folder_path in dict.fromkeys(folder_paths):
        movie_path = os.path.join(folder_path, 'movies')
        os.makedirs(movie_path, exist_ok=True)

        # (plate, well, field, object_number) -> [(time, image_path), ...]
        grouped_images = defaultdict(list)
        for filename in os.listdir(folder_path):
            match = filename_regex.fullmatch(filename)
            if match is None:
                continue
            plate, well, field, time, object_number = match.groups()
            key = (plate, well, field, object_number)
            grouped_images[key].append((int(time), os.path.join(folder_path, filename)))

        for (plate, well, field, object_number), images in grouped_images.items():
            # Stable sort on the time index only, so ties keep listing order.
            ordered_paths = [path for _, path in sorted(images, key=lambda item: item[0])]

            # First pass: find the canvas size and drop files OpenCV cannot
            # decode (cv2.imread returns None instead of raising).
            max_height = max_width = 0
            readable_paths = []
            for image_path in ordered_paths:
                image = cv2.imread(image_path)
                if image is None:
                    print(f"Skipping unreadable image: {image_path}")
                    continue
                h, w = image.shape[:2]
                max_height, max_width = max(max_height, h), max(max_width, w)
                readable_paths.append(image_path)
            if not readable_paths:
                continue

            output_filename = f"{plate}_{well}_{field}_{object_number}.mp4"
            output_path = os.path.join(movie_path, output_filename)
            video = cv2.VideoWriter(output_path, fourcc, fps, (max_width, max_height))
            if not video.isOpened():
                # A missing codec would otherwise drop every frame silently.
                print(f"Could not open video writer for {output_path}")
                video.release()
                continue

            # Second pass: images are re-read rather than cached so that only
            # one frame is held in memory at a time.
            try:
                for image_path in readable_paths:
                    image = cv2.imread(image_path)
                    if image is None:
                        continue
                    h, w = image.shape[:2]
                    padded_image = np.zeros((max_height, max_width, 3), dtype=np.uint8)
                    padded_image[:h, :w, :] = image
                    video.write(padded_image)
            finally:
                # Always finalize the container, even if a write fails.
                video.release()
|
128
127
|
|
129
128
|
|
130
129
|
def _sort_key(file_path):
|
@@ -262,6 +261,25 @@ def _relabel_masks_based_on_tracks(masks, tracks, mode='btrack'):
|
|
262
261
|
return relabeled_masks
|
263
262
|
|
264
263
|
def _prepare_for_tracking(mask_array):
|
264
|
+
"""
|
265
|
+
Prepares object data from a sequence of labeled masks for object tracking.
|
266
|
+
|
267
|
+
This function extracts region properties from each timepoint in a 3D mask array,
|
268
|
+
reformats the data for tracking, and returns a concatenated DataFrame.
|
269
|
+
|
270
|
+
Args:
|
271
|
+
mask_array (numpy.ndarray): A 3D array of shape (T, H, W), where each 2D slice at time t
|
272
|
+
contains labeled objects for that frame.
|
273
|
+
|
274
|
+
Returns:
|
275
|
+
pandas.DataFrame: A DataFrame with one row per object per frame, including:
|
276
|
+
- 'frame': Time index.
|
277
|
+
- 'y', 'x': Object centroid coordinates.
|
278
|
+
- 'mass': Object area.
|
279
|
+
- 'original_label': Label ID in the input mask.
|
280
|
+
- 'bbox-0', 'bbox-1', 'bbox-2', 'bbox-3': Bounding box coordinates (min_row, min_col, max_row, max_col).
|
281
|
+
- 'eccentricity': Shape eccentricity of the object.
|
282
|
+
"""
|
265
283
|
frames = []
|
266
284
|
for t, frame in enumerate(mask_array):
|
267
285
|
props = regionprops_table(
|
@@ -282,8 +300,21 @@ def _prepare_for_tracking(mask_array):
|
|
282
300
|
|
283
301
|
def _track_by_iou(masks, iou_threshold=0.1):
|
284
302
|
"""
|
285
|
-
|
286
|
-
|
303
|
+
Assigns track IDs to labeled objects across frames by linking objects using Intersection-over-Union (IoU).
|
304
|
+
|
305
|
+
Each unique object label in frame 0 is assigned a unique track ID. Objects in subsequent frames are linked
|
306
|
+
to previous objects if their IoU exceeds the given threshold. Unmatched objects start new tracks.
|
307
|
+
|
308
|
+
Args:
|
309
|
+
masks (np.ndarray): A 3D array of shape (T, H, W) with integer labels in each 2D frame.
|
310
|
+
Background must be labeled as 0.
|
311
|
+
iou_threshold (float): Minimum IoU required to link objects between consecutive frames. Defaults to 0.1.
|
312
|
+
|
313
|
+
Returns:
|
314
|
+
pd.DataFrame: A DataFrame with columns:
|
315
|
+
- 'frame': Frame index.
|
316
|
+
- 'original_label': Label value in the input mask.
|
317
|
+
- 'track_id': Assigned unique track ID.
|
287
318
|
"""
|
288
319
|
n_frames = masks.shape[0]
|
289
320
|
# 1) initialize: every label in frame 0 starts its own track
|
@@ -317,6 +348,17 @@ def _track_by_iou(masks, iou_threshold=0.1):
|
|
317
348
|
return pd.DataFrame(records)
|
318
349
|
|
319
350
|
def link_by_iou(mask_prev, mask_next, iou_threshold=0.1):
|
351
|
+
"""
|
352
|
+
Matches labeled objects between two frames based on Intersection-over-Union (IoU).
|
353
|
+
|
354
|
+
Args:
|
355
|
+
mask_prev (np.ndarray): 2D array of integer labels from the previous frame.
|
356
|
+
mask_next (np.ndarray): 2D array of integer labels from the next frame.
|
357
|
+
iou_threshold (float, optional): Minimum IoU required to consider a match. Defaults to 0.1.
|
358
|
+
|
359
|
+
Returns:
|
360
|
+
list of tuple: List of (label_prev, label_next) pairs for matched objects.
|
361
|
+
"""
|
320
362
|
# Get labels
|
321
363
|
labels_prev = np.unique(mask_prev)[1:]
|
322
364
|
labels_next = np.unique(mask_next)[1:]
|
@@ -393,8 +435,17 @@ def _remove_objects_from_first_frame(masks, percentage=10):
|
|
393
435
|
|
394
436
|
def _track_by_iou(masks, iou_threshold=0.1):
|
395
437
|
"""
|
396
|
-
|
397
|
-
|
438
|
+
Assigns consistent track IDs to labeled objects across time-lapse mask frames based on IoU matching.
|
439
|
+
|
440
|
+
Args:
|
441
|
+
masks (np.ndarray): 3D array of shape (T, H, W), where T is the number of time frames,
|
442
|
+
and each frame contains labeled segmentation masks.
|
443
|
+
iou_threshold (float, optional): The minimum Intersection-over-Union required to consider
|
444
|
+
two labels as matching. Defaults to 0.1.
|
445
|
+
|
446
|
+
Returns:
|
447
|
+
pd.DataFrame: A DataFrame with columns ['frame', 'original_label', 'track_id'], mapping
|
448
|
+
each original label in every frame to a track ID.
|
398
449
|
"""
|
399
450
|
n_frames = masks.shape[0]
|
400
451
|
# 1) initialize: every label in frame 0 starts its own track
|
@@ -427,7 +478,6 @@ def _track_by_iou(masks, iou_threshold=0.1):
|
|
427
478
|
records.append({'frame': frame, 'original_label': label, 'track_id': tid})
|
428
479
|
return pd.DataFrame(records)
|
429
480
|
|
430
|
-
|
431
481
|
def _facilitate_trackin_with_adaptive_removal(masks, search_range=None, max_attempts=5, memory=3, min_mass=50, track_by_iou=False):
|
432
482
|
"""
|
433
483
|
Facilitates object tracking with deterministic initial filtering and
|
@@ -631,9 +681,36 @@ def _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, mas
|
|
631
681
|
return mask_stack
|
632
682
|
|
633
683
|
def exponential_decay(x, a, b, c):
    """
    Evaluate an exponential decay curve with a constant offset.

    Args:
        x (float or np.ndarray): Point(s) at which the curve is evaluated.
        a (float): Amplitude multiplying the exponential term.
        b (float): Decay rate applied to ``x`` inside the exponential.
        c (float): Constant offset added to the scaled exponential.

    Returns:
        float or np.ndarray: The value of ``a * exp(-b * x) + c``.
    """
    decay_term = np.exp(-b * x)
    scaled_term = a * decay_term
    return scaled_term + c
|
635
697
|
|
636
698
|
def preprocess_pathogen_data(pathogen_df):
|
699
|
+
"""
|
700
|
+
Preprocesses pathogen data by aggregating measurements per host cell and counting parasites.
|
701
|
+
|
702
|
+
Args:
|
703
|
+
pathogen_df (pd.DataFrame): Input DataFrame containing pathogen-level measurements.
|
704
|
+
Must include 'plateID', 'rowID', 'column_name', 'fieldID', 'timeid', and 'pathogen_cell_id' columns.
|
705
|
+
|
706
|
+
Returns:
|
707
|
+
pd.DataFrame: Aggregated DataFrame with one row per pathogen-hosting cell, including:
|
708
|
+
- Mean of numeric features,
|
709
|
+
- First value of non-numeric features,
|
710
|
+
- Count of parasites per cell,
|
711
|
+
- Renamed 'pathogen_cell_id' to 'object_label',
|
712
|
+
- Removal of 'object_label' if originally present.
|
713
|
+
"""
|
637
714
|
# Group by identifiers and count the number of parasites
|
638
715
|
parasite_counts = pathogen_df.groupby(['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']).size().reset_index(name='parasite_count')
|
639
716
|
|
@@ -654,9 +731,41 @@ def preprocess_pathogen_data(pathogen_df):
|
|
654
731
|
return pathogen_agg
|
655
732
|
|
656
733
|
def plot_data(measurement, group, ax, label, marker='o', linestyle='-'):
    """
    Draw one delta-measurement trace against time on an existing Axes.

    Args:
        measurement (str): Measurement name; the plotted column is
            ``'delta_' + measurement``.
        group (pd.DataFrame): Data providing a 'time' column and the
            matching 'delta_<measurement>' column.
        ax (matplotlib.axes.Axes): Axes that receives the line.
        label (str): Legend label for the line.
        marker (str, optional): Marker style. Defaults to 'o'.
        linestyle (str, optional): Line style. Defaults to '-'.

    Returns:
        None
    """
    time_values = group['time']
    delta_values = group['delta_' + measurement]
    ax.plot(time_values, delta_values, marker=marker, linestyle=linestyle, label=label)
|
658
750
|
|
659
751
|
def infected_vs_noninfected(result_df, measurement):
|
752
|
+
"""
|
753
|
+
Plots time series data comparing infected and uninfected cells based on a specified measurement.
|
754
|
+
|
755
|
+
Args:
|
756
|
+
result_df (pd.DataFrame): DataFrame containing cell tracking data including
|
757
|
+
'parasite_count', 'plate_row_column_field_object', and 'time' columns.
|
758
|
+
measurement (str): Name of the measurement to be visualized. Assumes a
|
759
|
+
'delta_<measurement>' column exists in the DataFrame.
|
760
|
+
|
761
|
+
Returns:
|
762
|
+
None
|
763
|
+
|
764
|
+
Notes:
|
765
|
+
- Cells are considered infected if their maximum parasite count across all time points is > 0.
|
766
|
+
- Two subplots are generated: one for infected cells and one for uninfected cells.
|
767
|
+
- Time series are plotted for each unique cell using `plot_data()`.
|
768
|
+
"""
|
660
769
|
# Separate the merged dataframe into two groups based on each cell's maximum parasite_count
|
661
770
|
infected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]
|
662
771
|
uninfected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]
|
@@ -688,6 +797,17 @@ def infected_vs_noninfected(result_df, measurement):
|
|
688
797
|
plt.show()
|
689
798
|
|
690
799
|
def save_figure(fig, src, figure_number):
|
800
|
+
"""
|
801
|
+
Saves a figure as a PDF in the 'results' directory relative to the given source path.
|
802
|
+
|
803
|
+
Args:
|
804
|
+
fig (matplotlib.figure.Figure): The figure object to be saved.
|
805
|
+
src (str): Path to a file or directory; the 'results' folder is created in its parent directory (os.path.dirname(src)).
|
806
|
+
figure_number (int or str): Identifier number for the figure file name.
|
807
|
+
|
808
|
+
Returns:
|
809
|
+
None
|
810
|
+
"""
|
691
811
|
source = os.path.dirname(src)
|
692
812
|
results_fldr = os.path.join(source,'results')
|
693
813
|
os.makedirs(results_fldr, exist_ok=True)
|
@@ -696,6 +816,17 @@ def save_figure(fig, src, figure_number):
|
|
696
816
|
print(f'Saved figure:{fig_loc}')
|
697
817
|
|
698
818
|
def save_results_dataframe(df, src, results_name):
|
819
|
+
"""
|
820
|
+
Saves a DataFrame as a CSV file in the 'results' directory relative to the given source path.
|
821
|
+
|
822
|
+
Args:
|
823
|
+
df (pandas.DataFrame): The DataFrame to save.
|
824
|
+
src (str): Path whose parent directory (os.path.dirname(src)) will contain the 'results' folder.
|
825
|
+
results_name (str): Name of the CSV file (without extension).
|
826
|
+
|
827
|
+
Returns:
|
828
|
+
None
|
829
|
+
"""
|
699
830
|
source = os.path.dirname(src)
|
700
831
|
results_fldr = os.path.join(source,'results')
|
701
832
|
os.makedirs(results_fldr, exist_ok=True)
|
@@ -704,6 +835,22 @@ def save_results_dataframe(df, src, results_name):
|
|
704
835
|
print(f'Saved results:{csv_loc}')
|
705
836
|
|
706
837
|
def summarize_per_well(peak_details_df):
|
838
|
+
"""
|
839
|
+
Summarizes peak data per well by extracting metadata, filtering valid entries, and computing statistics.
|
840
|
+
|
841
|
+
Args:
|
842
|
+
peak_details_df (pandas.DataFrame): DataFrame containing a column 'ID' and numeric measurement columns,
|
843
|
+
including 'amplitude'.
|
844
|
+
|
845
|
+
Returns:
|
846
|
+
pandas.DataFrame: Summary DataFrame with the following columns:
|
847
|
+
- 'well_ID': Unique well identifier ('rowID_columnID').
|
848
|
+
- 'peaks_per_well': Total number of peaks per well.
|
849
|
+
- 'unique_IDs_with_amplitude': Number of unique object IDs with valid amplitude.
|
850
|
+
- 'cells_per_well': Number of unique objects per well.
|
851
|
+
- 'peaks_per_cell': Ratio of peaks to cells per well.
|
852
|
+
- Mean values of all numeric columns (excluding ID components) aggregated per well.
|
853
|
+
"""
|
707
854
|
# Step 1: Split the 'ID' column
|
708
855
|
split_columns = peak_details_df['ID'].str.split('_', expand=True)
|
709
856
|
peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns
|
@@ -735,6 +882,21 @@ def summarize_per_well(peak_details_df):
|
|
735
882
|
return summary_df
|
736
883
|
|
737
884
|
def summarize_per_well_inf_non_inf(peak_details_df):
|
885
|
+
"""
|
886
|
+
Summarizes peak data per well stratified by infection status.
|
887
|
+
|
888
|
+
Args:
|
889
|
+
peak_details_df (pandas.DataFrame): DataFrame containing peak data with the following required columns:
|
890
|
+
- 'ID': Composite string identifier ('plateID_rowID_columnID_fieldID_object_number').
|
891
|
+
- 'infected': Integer or boolean indicator (values > 0 denote infection).
|
892
|
+
|
893
|
+
Returns:
|
894
|
+
pandas.DataFrame: Summary statistics grouped by 'well_ID' and 'infected_status', including:
|
895
|
+
- 'cells_per_well': Unique cell count per group.
|
896
|
+
- 'peaks_per_well': Total number of peaks per group.
|
897
|
+
- 'peaks_per_cell': Ratio of peaks to cells.
|
898
|
+
- Mean of all numeric columns aggregated per group.
|
899
|
+
"""
|
738
900
|
# Step 1: Split the 'ID' column
|
739
901
|
split_columns = peak_details_df['ID'].str.split('_', expand=True)
|
740
902
|
peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns
|
@@ -762,6 +924,28 @@ def summarize_per_well_inf_non_inf(peak_details_df):
|
|
762
924
|
return summary_df
|
763
925
|
|
764
926
|
def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
|
927
|
+
"""
|
928
|
+
Analyze calcium oscillations in single-cell time-lapse data by detecting peaks in fluorescence intensity changes.
|
929
|
+
|
930
|
+
Args:
|
931
|
+
db_loc (str): Path to the SQLite database containing a 'cell' table and, optionally, 'pathogen' and/or 'cytoplasm' tables.
|
932
|
+
measurement (str): Column name in the 'cell' table to analyze (default: 'cell_channel_1_mean_intensity').
|
933
|
+
size_filter (str): Column used to filter out unstable cells by size fluctuation (default: 'cell_area').
|
934
|
+
fluctuation_threshold (float): Relative threshold for cell size fluctuation (std/mean) to retain cells (default: 0.25).
|
935
|
+
num_lines (int or None): Number of single-cell traces to show in the output plot. If None, plot all (default: None).
|
936
|
+
peak_height (float): Minimum peak height in delta-fluorescence for peak detection (default: 0.01).
|
937
|
+
pathogen (bool or None): If True, merge with 'pathogen' table and use parasite counts to annotate infection (default: None).
|
938
|
+
cytoplasm (bool or None): If True, merge with 'cytoplasm' table (default: None).
|
939
|
+
remove_transient (bool): Whether to discard cells tracked in fewer than `transience_threshold` of all timepoints (default: True).
|
940
|
+
verbose (bool): Print detailed progress and filtering info (default: False).
|
941
|
+
transience_threshold (float): Minimum fraction of timepoints a cell must be tracked (default: 0.9).
|
942
|
+
|
943
|
+
Returns:
|
944
|
+
tuple: (result_df, peak_details_df, fig)
|
945
|
+
- result_df (DataFrame): All analyzed per-frame data with delta fluorescence.
|
946
|
+
- peak_details_df (DataFrame): Detected peaks and derived features per peak.
|
947
|
+
- fig (Matplotlib figure): Line plot of normalized fluorescence deltas over time.
|
948
|
+
"""
|
765
949
|
# Load data
|
766
950
|
conn = sqlite3.connect(db_loc)
|
767
951
|
# Load cell table
|
spacr/toxo.py
CHANGED
@@ -9,6 +9,44 @@ from sklearn.metrics import mean_absolute_error
|
|
9
9
|
from matplotlib.gridspec import GridSpec
|
10
10
|
|
11
11
|
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',point_size=50, figsize=20, threshold=0,save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
|
12
|
+
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
|
13
|
+
"""
|
14
|
+
Creates a custom volcano plot with two subplots (upper and lower) sharing the x-axis,
|
15
|
+
where points are colored based on metadata categories.
|
16
|
+
Parameters:
|
17
|
+
-----------
|
18
|
+
data_path : str or pandas.DataFrame
|
19
|
+
Path to the CSV file containing the data or a pandas DataFrame. The data should
|
20
|
+
include columns 'feature', 'p_value', and 'coefficient'.
|
21
|
+
metadata_path : str or pandas.DataFrame
|
22
|
+
Path to the CSV file containing metadata or a pandas DataFrame. The metadata
|
23
|
+
should include a column matching `metadata_column` and 'gene_nr'.
|
24
|
+
metadata_column : str, optional
|
25
|
+
The column in the metadata file used to categorize and color points (default is 'tagm_location').
|
26
|
+
point_size : int, optional
|
27
|
+
Size of the scatter plot points (default is 50).
|
28
|
+
figsize : int, optional
|
29
|
+
Size of the figure (default is 20).
|
30
|
+
threshold : float, optional
|
31
|
+
Threshold for the absolute value of the coefficient to consider a point significant (default is 0).
|
32
|
+
save_path : str, optional
|
33
|
+
Path to save the plot as a PDF. If None, the plot is not saved (default is None).
|
34
|
+
x_lim : list of float, optional
|
35
|
+
Limits for the x-axis (default is [-0.5, 0.5]).
|
36
|
+
y_lims : list of lists, optional
|
37
|
+
Limits for the y-axis for the lower and upper subplots (default is [[0, 6], [9, 20]]).
|
38
|
+
Returns:
|
39
|
+
--------
|
40
|
+
hit_list : list of str
|
41
|
+
List of variables (hits) that meet the significance thresholds for p-value and coefficient.
|
42
|
+
Notes:
|
43
|
+
------
|
44
|
+
- The function uses a color dictionary to assign colors to points based on the `metadata_column` values.
|
45
|
+
- Points with p-value <= 0.05 and absolute coefficient >= `threshold` are considered significant.
|
46
|
+
- Significant points are annotated with their variable names, and overlapping annotations are adjusted.
|
47
|
+
- A legend is optionally added to indicate the mapping of metadata categories to colors.
|
48
|
+
- The plot is divided into two subplots to handle a wide range of p-values effectively.
|
49
|
+
"""
|
12
50
|
|
13
51
|
# Dictionary mapping compartment to color
|
14
52
|
|
@@ -326,12 +364,20 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
|
|
326
364
|
|
327
365
|
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID', data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype',error_column='T.gondii GT1 CRISPR Phenotype - Standard Error', save_path=None):
|
328
366
|
"""
|
329
|
-
Plot
|
330
|
-
|
367
|
+
Plot gene phenotype means with error bars, highlighting selected genes.
|
368
|
+
|
331
369
|
Args:
|
332
|
-
data (pd.DataFrame):
|
333
|
-
gene_list (list):
|
370
|
+
data (pd.DataFrame): DataFrame with phenotype data.
|
371
|
+
gene_list (list of str): List of gene identifiers to highlight.
|
372
|
+
x_column (str, optional): Column with gene identifiers. Default: 'Gene ID'.
|
373
|
+
data_column (str, optional): Column with mean phenotype values.
|
374
|
+
error_column (str, optional): Column with standard error values.
|
375
|
+
save_path (str or None, optional): Path to save PDF. If None, plot is not saved.
|
376
|
+
|
377
|
+
Returns:
|
378
|
+
None
|
334
379
|
"""
|
380
|
+
|
335
381
|
# Ensure x_column is properly processed
|
336
382
|
def extract_gene_id(gene):
|
337
383
|
if isinstance(gene, str) and '_' in gene:
|
@@ -471,9 +517,48 @@ def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=F
|
|
471
517
|
plt.show()
|
472
518
|
|
473
519
|
def generate_score_heatmap(settings):
|
520
|
+
"""
|
521
|
+
Generates a score comparison heatmap and calculates MAE between classification scores and a reference fraction.
|
474
522
|
|
523
|
+
This function:
|
524
|
+
- Combines classification scores from multiple folders.
|
525
|
+
- Calculates the control gRNA fractions for a specific gRNA.
|
526
|
+
- Merges the fraction data with classification and cross-validation scores.
|
527
|
+
- Plots a heatmap of scores.
|
528
|
+
- Calculates Mean Absolute Error (MAE) between the prediction scores and the true fraction.
|
529
|
+
- Optionally saves the resulting data and figure to disk.
|
530
|
+
|
531
|
+
Args:
|
532
|
+
settings (dict): A dictionary containing:
|
533
|
+
- 'folders': List of folders to search for classification CSVs.
|
534
|
+
- 'csv_name': Name of the classification CSV file in each folder.
|
535
|
+
- 'data_column': Column with predicted values to be compared.
|
536
|
+
- 'plateID': Plate identifier (e.g., 1).
|
537
|
+
- 'columnID': Column identifier (e.g., 'c3').
|
538
|
+
- 'csv': Path to the control gRNA CSV file.
|
539
|
+
- 'control_sgrnas': List of two control gRNA names.
|
540
|
+
- 'fraction_grna': gRNA name for which the true fraction will be extracted.
|
541
|
+
- 'cv_csv': Path to the CSV file with cross-validation predictions.
|
542
|
+
- 'data_column_cv': Column name for cross-validation predictions.
|
543
|
+
- 'dst': Output directory path (or None to disable saving).
|
544
|
+
|
545
|
+
Returns:
|
546
|
+
pandas.DataFrame: The merged DataFrame with 'fraction', predicted scores, and cross-validation scores.
|
547
|
+
"""
|
548
|
+
|
475
549
|
def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
|
476
|
-
|
550
|
+
"""
|
551
|
+
Group and average cross-validation scores (or other predictions) for each well.
|
552
|
+
|
553
|
+
Args:
|
554
|
+
csv (str): Path to the input CSV file.
|
555
|
+
plate (int, optional): Plate number to assign if not present in the file. Defaults to 1.
|
556
|
+
column (str, optional): Plate column to keep, matched against 'column_name' when that column is present, e.g., 'c3'. Defaults to 'c3'.
|
557
|
+
data_column (str, optional): Column to average, e.g., prediction or score column. Defaults to 'pred'.
|
558
|
+
|
559
|
+
Returns:
|
560
|
+
pandas.DataFrame: DataFrame with grouped average values and a 'prc' identifier (plate_row_column).
|
561
|
+
"""
|
477
562
|
df = pd.read_csv(csv)
|
478
563
|
if 'column_name' in df.columns:
|
479
564
|
df = df[df['column_name']==column]
|
@@ -487,6 +572,21 @@ def generate_score_heatmap(settings):
|
|
487
572
|
return grouped_df
|
488
573
|
|
489
574
|
def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
|
575
|
+
"""
|
576
|
+
Calculate the fraction of control_sgrnas in a specified column and plate.
|
577
|
+
|
578
|
+
Args:
|
579
|
+
csv (str): Path to the input CSV file.
|
580
|
+
plate (int, optional): Plate number to filter on. Defaults to 1.
|
581
|
+
column (str, optional): Column name (e.g., 'c3') to filter on. Defaults to 'c3'.
|
582
|
+
control_sgrnas (list of str, optional): List of two gRNA identifiers to include in the fraction calculation.
|
583
|
+
Defaults to ['TGGT1_220950_1', 'TGGT1_233460_4'].
|
584
|
+
|
585
|
+
Returns:
|
586
|
+
pandas.DataFrame: DataFrame with columns including:
|
587
|
+
- 'fraction': Fraction of each control_sgrna in its well.
|
588
|
+
- 'prc': Plate-row-column identifier.
|
589
|
+
"""
|
490
590
|
df = pd.read_csv(csv)
|
491
591
|
df = df[df['column_name']==column]
|
492
592
|
if plate not in df.columns:
|
@@ -549,6 +649,22 @@ def generate_score_heatmap(settings):
|
|
549
649
|
|
550
650
|
|
551
651
|
def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
|
652
|
+
"""
|
653
|
+
Combines classification scores from multiple CSV files across subfolders.
|
654
|
+
|
655
|
+
Args:
|
656
|
+
folders (str or list of str): A folder or list of folders containing subdirectories
|
657
|
+
with CSV files named `csv_name`.
|
658
|
+
csv_name (str): The name of the CSV file to search for in each subdirectory.
|
659
|
+
data_column (str): The column name in the CSV to extract and average.
|
660
|
+
plate (int, optional): Plate number to assign to all rows. Defaults to 1.
|
661
|
+
column (str, optional): Column name to filter on, e.g., 'c3'. Defaults to 'c3'.
|
662
|
+
|
663
|
+
Returns:
|
664
|
+
pandas.DataFrame: Combined DataFrame with averaged `data_column` values from each source file.
|
665
|
+
Each source contributes a uniquely named column, and the result includes a
|
666
|
+
'prc' column (plate_row_column identifier).
|
667
|
+
"""
|
552
668
|
# Ensure `folders` is a list
|
553
669
|
if isinstance(folders, str):
|
554
670
|
folders = [folders]
|
@@ -596,7 +712,17 @@ def generate_score_heatmap(settings):
|
|
596
712
|
|
597
713
|
def calculate_mae(df):
|
598
714
|
"""
|
599
|
-
|
715
|
+
Calculates the Mean Absolute Error (MAE) between the 'fraction' column and all other numeric columns
|
716
|
+
(excluding 'fraction' and 'prc') for each row in the DataFrame.
|
717
|
+
|
718
|
+
Args:
|
719
|
+
df (pandas.DataFrame): Input DataFrame containing a 'fraction' column and prediction columns to compare.
|
720
|
+
|
721
|
+
Returns:
|
722
|
+
pandas.DataFrame: A DataFrame with columns ['Channel', 'MAE', 'Row'], where:
|
723
|
+
- 'Channel' indicates the compared prediction column,
|
724
|
+
- 'MAE' is the mean absolute error between that column and 'fraction',
|
725
|
+
- 'Row' contains the 'prc' identifier for the row.
|
600
726
|
"""
|
601
727
|
# Extract numeric columns excluding 'fraction' and 'prc'
|
602
728
|
channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
|