spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/timelapse.py CHANGED
@@ -18,7 +18,6 @@ except ImportError:
18
18
 
19
19
  import matplotlib.pyplot as plt
20
20
 
21
-
22
21
  def _npz_to_movie(arrays, filenames, save_path, fps=10):
23
22
  """
24
23
  Convert a list of numpy arrays to a movie file.
@@ -76,55 +75,55 @@ def _npz_to_movie(arrays, filenames, save_path, fps=10):
76
75
  print(f"Movie saved to {save_path}")
77
76
 
78
77
  def _scmovie(folder_paths):
79
- """
80
- Generate movies from a collection of PNG images in the given folder paths.
78
+ """
79
+ Generate movies from a collection of PNG images in the given folder paths.
81
80
 
82
- Args:
83
- folder_paths (list): List of folder paths containing PNG images.
81
+ Args:
82
+ folder_paths (list): List of folder paths containing PNG images.
84
83
 
85
- Returns:
86
- None
87
- """
88
- folder_paths = list(set(folder_paths))
89
- for folder_path in folder_paths:
90
- movie_path = os.path.join(folder_path, 'movies')
91
- os.makedirs(movie_path, exist_ok=True)
92
- # Regular expression to parse the filename
93
- filename_regex = re.compile(r'(\w+)_(\w+)_(\w+)_(\d+)_(\d+).png')
94
- # Dictionary to hold lists of images by plate, well, field, and object number
95
- grouped_images = defaultdict(list)
96
- # Iterate over all PNG files in the folder
97
- for filename in os.listdir(folder_path):
98
- if filename.endswith('.png'):
99
- match = filename_regex.match(filename)
100
- if match:
101
- plate, well, field, time, object_number = match.groups()
102
- key = (plate, well, field, object_number)
103
- grouped_images[key].append((int(time), os.path.join(folder_path, filename)))
104
- for key, images in grouped_images.items():
105
- # Sort images by time using sorted and lambda function for custom sort key
106
- images = sorted(images, key=lambda x: x[0])
107
- _, image_paths = zip(*images)
108
- # Determine the size to which all images should be padded
109
- max_height = max_width = 0
110
- for image_path in image_paths:
111
- image = cv2.imread(image_path)
112
- h, w, _ = image.shape
113
- max_height, max_width = max(max_height, h), max(max_width, w)
114
- # Initialize VideoWriter
115
- plate, well, field, object_number = key
116
- output_filename = f"{plate}_{well}_{field}_{object_number}.mp4"
117
- output_path = os.path.join(movie_path, output_filename)
118
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
119
- video = cv2.VideoWriter(output_path, fourcc, 10, (max_width, max_height))
120
- # Process each image
121
- for image_path in image_paths:
122
- image = cv2.imread(image_path)
123
- h, w, _ = image.shape
124
- padded_image = np.zeros((max_height, max_width, 3), dtype=np.uint8)
125
- padded_image[:h, :w, :] = image
126
- video.write(padded_image)
127
- video.release()
84
+ Returns:
85
+ None
86
+ """
87
+ folder_paths = list(set(folder_paths))
88
+ for folder_path in folder_paths:
89
+ movie_path = os.path.join(folder_path, 'movies')
90
+ os.makedirs(movie_path, exist_ok=True)
91
+ # Regular expression to parse the filename
92
+ filename_regex = re.compile(r'(\w+)_(\w+)_(\w+)_(\d+)_(\d+).png')
93
+ # Dictionary to hold lists of images by plate, well, field, and object number
94
+ grouped_images = defaultdict(list)
95
+ # Iterate over all PNG files in the folder
96
+ for filename in os.listdir(folder_path):
97
+ if filename.endswith('.png'):
98
+ match = filename_regex.match(filename)
99
+ if match:
100
+ plate, well, field, time, object_number = match.groups()
101
+ key = (plate, well, field, object_number)
102
+ grouped_images[key].append((int(time), os.path.join(folder_path, filename)))
103
+ for key, images in grouped_images.items():
104
+ # Sort images by time using sorted and lambda function for custom sort key
105
+ images = sorted(images, key=lambda x: x[0])
106
+ _, image_paths = zip(*images)
107
+ # Determine the size to which all images should be padded
108
+ max_height = max_width = 0
109
+ for image_path in image_paths:
110
+ image = cv2.imread(image_path)
111
+ h, w, _ = image.shape
112
+ max_height, max_width = max(max_height, h), max(max_width, w)
113
+ # Initialize VideoWriter
114
+ plate, well, field, object_number = key
115
+ output_filename = f"{plate}_{well}_{field}_{object_number}.mp4"
116
+ output_path = os.path.join(movie_path, output_filename)
117
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
118
+ video = cv2.VideoWriter(output_path, fourcc, 10, (max_width, max_height))
119
+ # Process each image
120
+ for image_path in image_paths:
121
+ image = cv2.imread(image_path)
122
+ h, w, _ = image.shape
123
+ padded_image = np.zeros((max_height, max_width, 3), dtype=np.uint8)
124
+ padded_image[:h, :w, :] = image
125
+ video.write(padded_image)
126
+ video.release()
128
127
 
129
128
 
130
129
  def _sort_key(file_path):
@@ -262,6 +261,25 @@ def _relabel_masks_based_on_tracks(masks, tracks, mode='btrack'):
262
261
  return relabeled_masks
263
262
 
264
263
  def _prepare_for_tracking(mask_array):
264
+ """
265
+ Prepares object data from a sequence of labeled masks for object tracking.
266
+
267
+ This function extracts region properties from each timepoint in a 3D mask array,
268
+ reformats the data for tracking, and returns a concatenated DataFrame.
269
+
270
+ Args:
271
+ mask_array (numpy.ndarray): A 3D array of shape (T, H, W), where each 2D slice at time t
272
+ contains labeled objects for that frame.
273
+
274
+ Returns:
275
+ pandas.DataFrame: A DataFrame with one row per object per frame, including:
276
+ - 'frame': Time index.
277
+ - 'y', 'x': Object centroid coordinates.
278
+ - 'mass': Object area.
279
+ - 'original_label': Label ID in the input mask.
280
+ - 'bbox-0', 'bbox-1', 'bbox-2', 'bbox-3': Bounding box coordinates (min_row, min_col, max_row, max_col).
281
+ - 'eccentricity': Shape eccentricity of the object.
282
+ """
265
283
  frames = []
266
284
  for t, frame in enumerate(mask_array):
267
285
  props = regionprops_table(
@@ -282,8 +300,21 @@ def _prepare_for_tracking(mask_array):
282
300
 
283
301
  def _track_by_iou(masks, iou_threshold=0.1):
284
302
  """
285
- Build a track table by linking masks frame→frame via IoU.
286
- Returns a DataFrame with columns [frame, original_label, track_id].
303
+ Assigns track IDs to labeled objects across frames by linking objects using Intersection-over-Union (IoU).
304
+
305
+ Each unique object label in frame 0 is assigned a unique track ID. Objects in subsequent frames are linked
306
+ to previous objects if their IoU exceeds the given threshold. Unmatched objects start new tracks.
307
+
308
+ Args:
309
+ masks (np.ndarray): A 3D array of shape (T, H, W) with integer labels in each 2D frame.
310
+ Background must be labeled as 0.
311
+ iou_threshold (float): Minimum IoU required to link objects between consecutive frames. Defaults to 0.1.
312
+
313
+ Returns:
314
+ pd.DataFrame: A DataFrame with columns:
315
+ - 'frame': Frame index.
316
+ - 'original_label': Label value in the input mask.
317
+ - 'track_id': Assigned unique track ID.
287
318
  """
288
319
  n_frames = masks.shape[0]
289
320
  # 1) initialize: every label in frame 0 starts its own track
@@ -317,6 +348,17 @@ def _track_by_iou(masks, iou_threshold=0.1):
317
348
  return pd.DataFrame(records)
318
349
 
319
350
  def link_by_iou(mask_prev, mask_next, iou_threshold=0.1):
351
+ """
352
+ Matches labeled objects between two frames based on Intersection-over-Union (IoU).
353
+
354
+ Args:
355
+ mask_prev (np.ndarray): 2D array of integer labels from the previous frame.
356
+ mask_next (np.ndarray): 2D array of integer labels from the next frame.
357
+ iou_threshold (float, optional): Minimum IoU required to consider a match. Defaults to 0.1.
358
+
359
+ Returns:
360
+ list of tuple: List of (label_prev, label_next) pairs for matched objects.
361
+ """
320
362
  # Get labels
321
363
  labels_prev = np.unique(mask_prev)[1:]
322
364
  labels_next = np.unique(mask_next)[1:]
@@ -393,8 +435,17 @@ def _remove_objects_from_first_frame(masks, percentage=10):
393
435
 
394
436
  def _track_by_iou(masks, iou_threshold=0.1):
395
437
  """
396
- Build a track table by linking masks frame→frame via IoU.
397
- Returns a DataFrame with columns [frame, original_label, track_id].
438
+ Assigns consistent track IDs to labeled objects across time-lapse mask frames based on IoU matching.
439
+
440
+ Args:
441
+ masks (np.ndarray): 3D array of shape (T, H, W), where T is the number of time frames,
442
+ and each frame contains labeled segmentation masks.
443
+ iou_threshold (float, optional): The minimum Intersection-over-Union required to consider
444
+ two labels as matching. Defaults to 0.1.
445
+
446
+ Returns:
447
+ pd.DataFrame: A DataFrame with columns ['frame', 'original_label', 'track_id'], mapping
448
+ each original label in every frame to a track ID.
398
449
  """
399
450
  n_frames = masks.shape[0]
400
451
  # 1) initialize: every label in frame 0 starts its own track
@@ -427,7 +478,6 @@ def _track_by_iou(masks, iou_threshold=0.1):
427
478
  records.append({'frame': frame, 'original_label': label, 'track_id': tid})
428
479
  return pd.DataFrame(records)
429
480
 
430
-
431
481
  def _facilitate_trackin_with_adaptive_removal(masks, search_range=None, max_attempts=5, memory=3, min_mass=50, track_by_iou=False):
432
482
  """
433
483
  Facilitates object tracking with deterministic initial filtering and
@@ -631,9 +681,36 @@ def _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, mas
631
681
  return mask_stack
632
682
 
633
683
  def exponential_decay(x, a, b, c):
684
+ """
685
+ Computes the exponential decay function.
686
+
687
+ Args:
688
+ x (float or np.ndarray): Input value(s).
689
+ a (float): Initial amplitude.
690
+ b (float): Decay rate.
691
+ c (float): Offset.
692
+
693
+ Returns:
694
+ float or np.ndarray: The result of a * exp(-b * x) + c.
695
+ """
634
696
  return a * np.exp(-b * x) + c
635
697
 
636
698
  def preprocess_pathogen_data(pathogen_df):
699
+ """
700
+ Preprocesses pathogen data by aggregating measurements per host cell and counting parasites.
701
+
702
+ Args:
703
+ pathogen_df (pd.DataFrame): Input DataFrame containing pathogen-level measurements.
704
+ Must include 'plateID', 'rowID', 'column_name', 'fieldID', 'timeid', and 'pathogen_cell_id' columns.
705
+
706
+ Returns:
707
+ pd.DataFrame: Aggregated DataFrame with one row per pathogen-hosting cell, including:
708
+ - Mean of numeric features,
709
+ - First value of non-numeric features,
710
+ - Count of parasites per cell,
711
+ - Renamed 'pathogen_cell_id' to 'object_label',
712
+ - Removal of 'object_label' if originally present.
713
+ """
637
714
  # Group by identifiers and count the number of parasites
638
715
  parasite_counts = pathogen_df.groupby(['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']).size().reset_index(name='parasite_count')
639
716
 
@@ -654,9 +731,41 @@ def preprocess_pathogen_data(pathogen_df):
654
731
  return pathogen_agg
655
732
 
656
733
  def plot_data(measurement, group, ax, label, marker='o', linestyle='-'):
734
+ """
735
+ Plots a time series of delta measurements on a given Axes object.
736
+
737
+ Args:
738
+ measurement (str): The name of the measurement to plot (e.g., 'intensity').
739
+ group (pd.DataFrame): A DataFrame containing time series data with a 'time' column and
740
+ a 'delta_<measurement>' column.
741
+ ax (matplotlib.axes.Axes): The Axes object to plot on.
742
+ label (str): The label for the plotted line (used in the legend).
743
+ marker (str, optional): Marker style for the plot. Defaults to 'o'.
744
+ linestyle (str, optional): Line style for the plot. Defaults to '-'.
745
+
746
+ Returns:
747
+ None
748
+ """
657
749
  ax.plot(group['time'], group['delta_' + measurement], marker=marker, linestyle=linestyle, label=label)
658
750
 
659
751
  def infected_vs_noninfected(result_df, measurement):
752
+ """
753
+ Plots time series data comparing infected and uninfected cells based on a specified measurement.
754
+
755
+ Args:
756
+ result_df (pd.DataFrame): DataFrame containing cell tracking data including
757
+ 'parasite_count', 'plate_row_column_field_object', and 'time' columns.
758
+ measurement (str): Name of the measurement to be visualized. Assumes a
759
+ 'delta_<measurement>' column exists in the DataFrame.
760
+
761
+ Returns:
762
+ None
763
+
764
+ Notes:
765
+ - Cells are considered infected if their maximum parasite count across all time points is > 0.
766
+ - Two subplots are generated: one for infected cells and one for uninfected cells.
767
+ - Time series are plotted for each unique cell using `plot_data()`.
768
+ """
660
769
  # Separate the merged dataframe into two groups based on pathogen_count
661
770
  infected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]
662
771
  uninfected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]
@@ -688,6 +797,17 @@ def infected_vs_noninfected(result_df, measurement):
688
797
  plt.show()
689
798
 
690
799
  def save_figure(fig, src, figure_number):
800
+ """
801
+ Saves a figure as a PDF in the 'results' directory relative to the given source path.
802
+
803
+ Args:
804
+ fig (matplotlib.figure.Figure): The figure object to be saved.
805
+ src (str): Path to a file or directory used to determine the save location.
806
+ figure_number (int or str): Identifier number for the figure file name.
807
+
808
+ Returns:
809
+ None
810
+ """
691
811
  source = os.path.dirname(src)
692
812
  results_fldr = os.path.join(source,'results')
693
813
  os.makedirs(results_fldr, exist_ok=True)
@@ -696,6 +816,17 @@ def save_figure(fig, src, figure_number):
696
816
  print(f'Saved figure:{fig_loc}')
697
817
 
698
818
  def save_results_dataframe(df, src, results_name):
819
+ """
820
+ Saves a DataFrame as a CSV file in the 'results' directory relative to the given source path.
821
+
822
+ Args:
823
+ df (pandas.DataFrame): The DataFrame to save.
824
+ src (str): Path used to determine the base directory for saving results.
825
+ results_name (str): Name of the CSV file (without extension).
826
+
827
+ Returns:
828
+ None
829
+ """
699
830
  source = os.path.dirname(src)
700
831
  results_fldr = os.path.join(source,'results')
701
832
  os.makedirs(results_fldr, exist_ok=True)
@@ -704,6 +835,22 @@ def save_results_dataframe(df, src, results_name):
704
835
  print(f'Saved results:{csv_loc}')
705
836
 
706
837
  def summarize_per_well(peak_details_df):
838
+ """
839
+ Summarizes peak data per well by extracting metadata, filtering valid entries, and computing statistics.
840
+
841
+ Args:
842
+ peak_details_df (pandas.DataFrame): DataFrame containing a column 'ID' and numeric measurement columns,
843
+ including 'amplitude'.
844
+
845
+ Returns:
846
+ pandas.DataFrame: Summary DataFrame with the following columns:
847
+ - 'well_ID': Unique well identifier ('rowID_columnID').
848
+ - 'peaks_per_well': Total number of peaks per well.
849
+ - 'unique_IDs_with_amplitude': Number of unique object IDs with valid amplitude.
850
+ - 'cells_per_well': Number of unique objects per well.
851
+ - 'peaks_per_cell': Ratio of peaks to cells per well.
852
+ - Mean values of all numeric columns (excluding ID components) aggregated per well.
853
+ """
707
854
  # Step 1: Split the 'ID' column
708
855
  split_columns = peak_details_df['ID'].str.split('_', expand=True)
709
856
  peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns
@@ -735,6 +882,21 @@ def summarize_per_well(peak_details_df):
735
882
  return summary_df
736
883
 
737
884
  def summarize_per_well_inf_non_inf(peak_details_df):
885
+ """
886
+ Summarizes peak data per well stratified by infection status.
887
+
888
+ Args:
889
+ peak_details_df (pandas.DataFrame): DataFrame containing peak data with the following required columns:
890
+ - 'ID': Composite string identifier ('plateID_rowID_columnID_fieldID_object_number').
891
+ - 'infected': Integer or boolean indicator (values > 0 denote infection).
892
+
893
+ Returns:
894
+ pandas.DataFrame: Summary statistics grouped by 'well_ID' and 'infected_status', including:
895
+ - 'cells_per_well': Unique cell count per group.
896
+ - 'peaks_per_well': Total number of peaks per group.
897
+ - 'peaks_per_cell': Ratio of peaks to cells.
898
+ - Mean of all numeric columns aggregated per group.
899
+ """
738
900
  # Step 1: Split the 'ID' column
739
901
  split_columns = peak_details_df['ID'].str.split('_', expand=True)
740
902
  peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns
@@ -762,6 +924,28 @@ def summarize_per_well_inf_non_inf(peak_details_df):
762
924
  return summary_df
763
925
 
764
926
  def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
927
+ """
928
+ Analyze calcium oscillations in single-cell time-lapse data by detecting peaks in fluorescence intensity changes.
929
+
930
+ Args:
931
+ db_loc (str): Path to the SQLite database containing 'cell', optionally 'pathogen' and/or 'cytoplasm' tables.
932
+ measurement (str): Column name in the 'cell' table to analyze (default: 'cell_channel_1_mean_intensity').
933
+ size_filter (str): Column used to filter out unstable cells by size fluctuation (default: 'cell_area').
934
+ fluctuation_threshold (float): Relative threshold for cell size fluctuation (std/mean) to retain cells (default: 0.25).
935
+ num_lines (int or None): Number of single-cell traces to show in the output plot. If None, plot all (default: None).
936
+ peak_height (float): Minimum peak height in delta-fluorescence for peak detection (default: 0.01).
937
+ pathogen (bool or None): If True, merge with 'pathogen' table and use parasite counts to annotate infection (default: None).
938
+ cytoplasm (bool or None): If True, merge with 'cytoplasm' table (default: None).
939
+ remove_transient (bool): Whether to discard cells tracked in less than a fraction of total timepoints (default: True).
940
+ verbose (bool): Print detailed progress and filtering info (default: False).
941
+ transience_threshold (float): Minimum fraction of timepoints a cell must be tracked (default: 0.9).
942
+
943
+ Returns:
944
+ tuple: (result_df, peak_details_df, fig)
945
+ - result_df (DataFrame): All analyzed per-frame data with delta fluorescence.
946
+ - peak_details_df (DataFrame): Detected peaks and derived features per peak.
947
+ - fig (Matplotlib figure): Line plot of normalized fluorescence deltas over time.
948
+ """
765
949
  # Load data
766
950
  conn = sqlite3.connect(db_loc)
767
951
  # Load cell table
spacr/toxo.py CHANGED
@@ -9,6 +9,44 @@ from sklearn.metrics import mean_absolute_error
9
9
  from matplotlib.gridspec import GridSpec
10
10
 
11
11
  def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',point_size=50, figsize=20, threshold=0,save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
12
+ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
13
+ """
14
+ Creates a custom volcano plot with two subplots (upper and lower) sharing the x-axis,
15
+ where points are colored based on metadata categories.
16
+ Parameters:
17
+ -----------
18
+ data_path : str or pandas.DataFrame
19
+ Path to the CSV file containing the data or a pandas DataFrame. The data should
20
+ include columns 'feature', 'p_value', and 'coefficient'.
21
+ metadata_path : str or pandas.DataFrame
22
+ Path to the CSV file containing metadata or a pandas DataFrame. The metadata
23
+ should include a column matching `metadata_column` and 'gene_nr'.
24
+ metadata_column : str, optional
25
+ The column in the metadata file used to categorize and color points (default is 'tagm_location').
26
+ point_size : int, optional
27
+ Size of the scatter plot points (default is 50).
28
+ figsize : int, optional
29
+ Size of the figure (default is 20).
30
+ threshold : float, optional
31
+ Threshold for the absolute value of the coefficient to consider a point significant (default is 0).
32
+ save_path : str, optional
33
+ Path to save the plot as a PDF. If None, the plot is not saved (default is None).
34
+ x_lim : list of float, optional
35
+ Limits for the x-axis (default is [-0.5, 0.5]).
36
+ y_lims : list of lists, optional
37
+ Limits for the y-axis for the lower and upper subplots (default is [[0, 6], [9, 20]]).
38
+ Returns:
39
+ --------
40
+ hit_list : list of str
41
+ List of variables (hits) that meet the significance thresholds for p-value and coefficient.
42
+ Notes:
43
+ ------
44
+ - The function uses a color dictionary to assign colors to points based on the `metadata_column` values.
45
+ - Points with p-value <= 0.05 and absolute coefficient >= `threshold` are considered significant.
46
+ - Significant points are annotated with their variable names, and overlapping annotations are adjusted.
47
+ - A legend is optionally added to indicate the mapping of metadata categories to colors.
48
+ - The plot is divided into two subplots to handle a wide range of p-values effectively.
49
+ """
12
50
 
13
51
  # Dictionary mapping compartment to color
14
52
 
@@ -326,12 +364,20 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
326
364
 
327
365
  def plot_gene_phenotypes(data, gene_list, x_column='Gene ID', data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype',error_column='T.gondii GT1 CRISPR Phenotype - Standard Error', save_path=None):
328
366
  """
329
- Plot a line graph for the mean phenotype with standard error shading and highlighted genes.
330
-
367
+ Plot gene phenotype means with error bars, highlighting selected genes.
368
+
331
369
  Args:
332
- data (pd.DataFrame): The input DataFrame containing gene data.
333
- gene_list (list): A list of gene names to highlight on the plot.
370
+ data (pd.DataFrame): DataFrame with phenotype data.
371
+ gene_list (list of str): List of gene identifiers to highlight.
372
+ x_column (str, optional): Column with gene identifiers. Default: 'Gene ID'.
373
+ data_column (str, optional): Column with mean phenotype values.
374
+ error_column (str, optional): Column with standard error values.
375
+ save_path (str or None, optional): Path to save PDF. If None, plot is not saved.
376
+
377
+ Returns:
378
+ None
334
379
  """
380
+
335
381
  # Ensure x_column is properly processed
336
382
  def extract_gene_id(gene):
337
383
  if isinstance(gene, str) and '_' in gene:
@@ -471,9 +517,48 @@ def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=F
471
517
  plt.show()
472
518
 
473
519
  def generate_score_heatmap(settings):
520
+ """
521
+ Generates a score comparison heatmap and calculates MAE between classification scores and a reference fraction.
474
522
 
523
+ This function:
524
+ - Combines classification scores from multiple folders.
525
+ - Calculates the control gRNA fractions for a specific gRNA.
526
+ - Merges the fraction data with classification and cross-validation scores.
527
+ - Plots a heatmap of scores.
528
+ - Calculates Mean Absolute Error (MAE) between the prediction scores and the true fraction.
529
+ - Optionally saves the resulting data and figure to disk.
530
+
531
+ Args:
532
+ settings (dict): A dictionary containing:
533
+ - 'folders': List of folders to search for classification CSVs.
534
+ - 'csv_name': Name of the classification CSV file in each folder.
535
+ - 'data_column': Column with predicted values to be compared.
536
+ - 'plateID': Plate identifier (e.g., 1).
537
+ - 'columnID': Column identifier (e.g., 'c3').
538
+ - 'csv': Path to the control gRNA CSV file.
539
+ - 'control_sgrnas': List of two control gRNA names.
540
+ - 'fraction_grna': gRNA name for which the true fraction will be extracted.
541
+ - 'cv_csv': Path to the CSV file with cross-validation predictions.
542
+ - 'data_column_cv': Column name for cross-validation predictions.
543
+ - 'dst': Output directory path (or None to disable saving).
544
+
545
+ Returns:
546
+ pandas.DataFrame: The merged DataFrame with 'fraction', predicted scores, and cross-validation scores.
547
+ """
548
+
475
549
  def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
476
-
550
+ """
551
+ Group and average cross-validation scores (or other predictions) for each well.
552
+
553
+ Args:
554
+ csv (str): Path to the input CSV file.
555
+ plate (int, optional): Plate number to assign if not present in the file. Defaults to 1.
556
+ column (str, optional): Column (well) to filter on, e.g., 'c3'. Defaults to 'c3'.
557
+ data_column (str, optional): Column to average, e.g., prediction or score column. Defaults to 'pred'.
558
+
559
+ Returns:
560
+ pandas.DataFrame: DataFrame with grouped average values and a 'prc' identifier (plate_row_column).
561
+ """
477
562
  df = pd.read_csv(csv)
478
563
  if 'column_name' in df.columns:
479
564
  df = df[df['column_name']==column]
@@ -487,6 +572,21 @@ def generate_score_heatmap(settings):
487
572
  return grouped_df
488
573
 
489
574
  def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
575
+ """
576
+ Calculate the fraction of control_sgrnas in a specified column and plate.
577
+
578
+ Args:
579
+ csv (str): Path to the input CSV file.
580
+ plate (int, optional): Plate number to filter on. Defaults to 1.
581
+ column (str, optional): Column name (e.g., 'c3') to filter on. Defaults to 'c3'.
582
+ control_sgrnas (list of str, optional): List of two gRNA identifiers to include in the fraction calculation.
583
+ Defaults to ['TGGT1_220950_1', 'TGGT1_233460_4'].
584
+
585
+ Returns:
586
+ pandas.DataFrame: DataFrame with columns including:
587
+ - 'fraction': Fraction of each control_sgrna in its well.
588
+ - 'prc': Plate-row-column identifier.
589
+ """
490
590
  df = pd.read_csv(csv)
491
591
  df = df[df['column_name']==column]
492
592
  if plate not in df.columns:
@@ -549,6 +649,22 @@ def generate_score_heatmap(settings):
549
649
 
550
650
 
551
651
  def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
652
+ """
653
+ Combines classification scores from multiple CSV files across subfolders.
654
+
655
+ Args:
656
+ folders (str or list of str): A folder or list of folders containing subdirectories
657
+ with CSV files named `csv_name`.
658
+ csv_name (str): The name of the CSV file to search for in each subdirectory.
659
+ data_column (str): The column name in the CSV to extract and average.
660
+ plate (int, optional): Plate number to assign to all rows. Defaults to 1.
661
+ column (str, optional): Column name to filter on, e.g., 'c3'. Defaults to 'c3'.
662
+
663
+ Returns:
664
+ pandas.DataFrame: Combined DataFrame with averaged `data_column` values from each source file.
665
+ Each source contributes a uniquely named column, and the result includes a
666
+ 'prc' column (plate_row_column identifier).
667
+ """
552
668
  # Ensure `folders` is a list
553
669
  if isinstance(folders, str):
554
670
  folders = [folders]
@@ -596,7 +712,17 @@ def generate_score_heatmap(settings):
596
712
 
597
713
  def calculate_mae(df):
598
714
  """
599
- Calculate the MAE between each channel's predictions and the fraction column for all rows.
715
+ Calculates the Mean Absolute Error (MAE) between the 'fraction' column and all other numeric columns
716
+ (excluding 'fraction' and 'prc') for each row in the DataFrame.
717
+
718
+ Args:
719
+ df (pandas.DataFrame): Input DataFrame containing a 'fraction' column and prediction columns to compare.
720
+
721
+ Returns:
722
+ pandas.DataFrame: A DataFrame with columns ['Channel', 'MAE', 'Row'], where:
723
+ - 'Channel' indicates the compared prediction column,
724
+ - 'MAE' is the mean absolute error between that column and 'fraction',
725
+ - 'Row' contains the 'prc' identifier for the row.
600
726
  """
601
727
  # Extract numeric columns excluding 'fraction' and 'prc'
602
728
  channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])