spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
spacr/ml.py CHANGED
@@ -49,16 +49,56 @@ warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0
49
49
 
50
50
 
51
51
  class QuasiBinomial(Binomial):
52
- """Custom Quasi-Binomial family with adjustable variance."""
52
+ """
53
+ Quasi-Binomial family for generalized linear models with adjustable dispersion.
54
+
55
+ Extends the standard Binomial family from `statsmodels` to allow for modeling overdispersion
56
+ in binomial data, where the observed variance exceeds the nominal binomial variance.
57
+
58
+ This is often used in biological and ecological count data where extra-binomial variation is common.
59
+ """
53
60
  def __init__(self, link=logit(), dispersion=1.0):
61
+ """
62
+ Initialize the Quasi-Binomial family with a specified link function and dispersion factor.
63
+
64
+ Args:
65
+ link (statsmodels.genmod.families.links): The link function to use. Defaults to logit().
66
+ dispersion (float): A positive float to scale the variance.
67
+ Use values >1 to model overdispersion. Defaults to 1.0.
68
+ """
54
69
  super().__init__(link=link)
55
70
  self.dispersion = dispersion
56
71
 
57
72
  def variance(self, mu):
58
- """Adjust the variance with the dispersion parameter."""
73
+ """
74
+ Compute the variance function for the quasi-binomial model.
75
+
76
+ Args:
77
+ mu (array-like): The mean response values.
78
+
79
+ Returns:
80
+ array-like: Variance scaled by the dispersion parameter.
81
+ """
59
82
  return self.dispersion * super().variance(mu)
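For context, a minimal usage sketch of this family with a statsmodels GLM (synthetic data; assumes QuasiBinomial is imported from spacr.ml):

    import numpy as np
    import statsmodels.api as sm
    from spacr.ml import QuasiBinomial

    rng = np.random.default_rng(0)
    X = sm.add_constant(rng.normal(size=(100, 1)))   # intercept + one predictor
    y = rng.uniform(0.05, 0.95, size=100)            # proportions in (0, 1)

    # dispersion is intended to scale the binomial variance (see the class docstring)
    model = sm.GLM(y, X, family=QuasiBinomial(dispersion=2.0)).fit()
    print(model.summary())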
60
83
 
61
84
  def calculate_p_values(X, y, model):
85
+ """
86
+ Calculate p-values for model coefficients using residuals and t-statistics.
87
+
88
+ Parameters
89
+ ----------
90
+ X : np.ndarray
91
+ The input feature matrix used for fitting the model.
92
+ y : np.ndarray
93
+ The target variable.
94
+ model : object
95
+ A fitted sklearn linear model (must have `.predict` and `.coef_` attributes).
96
+
97
+ Returns
98
+ -------
99
+ np.ndarray
100
+ An array of p-values corresponding to each coefficient in the model.
101
+ """
62
102
  # Predict y values
63
103
  y_pred = model.predict(X)
64
104
  # Calculate residuals
@@ -78,6 +118,25 @@ def calculate_p_values(X, y, model):
78
118
  return np.array(p_values) # Ensure p_values is a 1-dimensional array
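An illustrative call with an sklearn estimator on synthetic data (the model and values here are examples only):

    import numpy as np
    from sklearn.linear_model import Ridge
    from spacr.ml import calculate_p_values

    rng = np.random.default_rng(1)
    X = rng.normal(size=(200, 5))
    y = X @ np.array([1.0, 0.0, 0.5, 0.0, -2.0]) + rng.normal(scale=0.1, size=200)

    model = Ridge(alpha=1.0).fit(X, y)
    p_values = calculate_p_values(X, y, model)   # one approximate p-value per coefficient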
79
119
 
80
120
  def perform_mixed_model(y, X, groups, alpha=1.0):
121
+ """
122
+ Perform mixed effects regression and return the fitted model.
123
+
124
+ Parameters
125
+ ----------
126
+ y : pd.Series or np.ndarray
127
+ Target variable.
128
+ X : pd.DataFrame
129
+ Feature matrix.
130
+ groups : pd.Series or np.ndarray
131
+ Grouping variable for mixed model.
132
+ alpha : float, optional
133
+ Regularization strength for Ridge regression when multicollinearity is detected.
134
+
135
+ Returns
136
+ -------
137
+ statsmodels.regression.mixed_linear_model.MixedLMResults
138
+ Fitted mixed model results.
139
+ """
81
140
  # Ensure groups are defined correctly and check for multicollinearity
82
141
  if groups is None:
83
142
  raise ValueError("Groups must be defined for mixed model regression")
@@ -99,7 +158,25 @@ def perform_mixed_model(y, X, groups, alpha=1.0):
99
158
  return result
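A minimal sketch of a call with a plate grouping (synthetic data; assumes the feature matrix is well-conditioned so the MixedLM branch, rather than the Ridge fallback, is used):

    import numpy as np
    import pandas as pd
    from spacr.ml import perform_mixed_model

    rng = np.random.default_rng(2)
    df = pd.DataFrame({'score': rng.random(60),
                       'x1': rng.random(60),
                       'plate': np.repeat(['p1', 'p2', 'p3'], 20)})
    result = perform_mixed_model(y=df['score'], X=df[['x1']], groups=df['plate'])
    print(result.summary())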
100
159
 
101
160
  def create_volcano_filename(csv_path, regression_type, alpha, dst):
102
- """Create and return the volcano plot filename based on regression type and alpha."""
161
+ """
162
+ Construct the file name for saving the volcano plot.
163
+
164
+ Parameters
165
+ ----------
166
+ csv_path : str
167
+ Path to the CSV file containing model results.
168
+ regression_type : str
169
+ Type of regression performed.
170
+ alpha : float
171
+ Alpha value (used for quantile regression naming).
172
+ dst : str or None
173
+ Destination directory to save the file.
174
+
175
+ Returns
176
+ -------
177
+ str
178
+ Full path for the output volcano plot PDF file.
179
+ """
103
180
  volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
104
181
  volcano_filename = f"{regression_type}_{volcano_filename}" if regression_type != 'quantile' else f"{alpha}_{volcano_filename}"
105
182
 
@@ -108,7 +185,23 @@ def create_volcano_filename(csv_path, regression_type, alpha, dst):
108
185
  return os.path.join(os.path.dirname(csv_path), volcano_filename)
109
186
 
110
187
  def scale_variables(X, y):
111
- """Scale independent (X) and dependent (y) variables using MinMaxScaler."""
188
+ """
189
+ Min-max scale both independent and dependent variables.
190
+
191
+ Parameters
192
+ ----------
193
+ X : pd.DataFrame
194
+ Feature matrix.
195
+ y : np.ndarray
196
+ Target variable.
197
+
198
+ Returns
199
+ -------
200
+ X_scaled : pd.DataFrame
201
+ Scaled feature matrix.
202
+ y_scaled : np.ndarray
203
+ Scaled target variable.
204
+ """
112
205
  scaler_X = MinMaxScaler()
113
206
  scaler_y = MinMaxScaler()
114
207
 
@@ -117,63 +210,20 @@ def scale_variables(X, y):
117
210
 
118
211
  return X_scaled, y_scaled
119
212
 
120
- def process_model_coefficients_v1(model, regression_type, X, y, nc, pc, controls):
121
- """Return DataFrame of model coefficients and p-values."""
122
- if regression_type in ['ols', 'gls', 'wls', 'rlm', 'glm', 'mixed', 'quantile', 'logit', 'probit', 'poisson']:
123
- coefs = model.params
124
- p_values = model.pvalues
125
-
126
- coef_df = pd.DataFrame({
127
- 'feature': coefs.index,
128
- 'coefficient': coefs.values,
129
- 'p_value': p_values.values
130
- })
131
-
132
- elif regression_type in ['ridge', 'lasso']:
133
- coefs = model.coef_.flatten()
134
- p_values = calculate_p_values(X, y, model)
135
-
136
- coef_df = pd.DataFrame({
137
- 'feature': X.columns,
138
- 'coefficient': coefs,
139
- 'p_value': p_values
140
- })
141
-
142
- else:
143
- coefs = model.coef_
144
- intercept = model.intercept_
145
- feature_names = X.design_info.column_names
213
+ def select_glm_family(y):
214
+ """
215
+ Automatically select a GLM family based on the nature of the dependent variable.
146
216
 
147
- coef_df = pd.DataFrame({
148
- 'feature': feature_names,
149
- 'coefficient': coefs
150
- })
151
- coef_df.loc[0, 'coefficient'] += intercept
152
- coef_df['p_value'] = np.nan # Placeholder since sklearn doesn't provide p-values
217
+ Parameters
218
+ ----------
219
+ y : np.ndarray
220
+ Target variable.
153
221
 
154
- coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
155
- coef_df['grna'] = coef_df['feature'].str.extract(r'\[(.*?)\]')[0]
156
- coef_df['condition'] = coef_df.apply(lambda row: 'nc' if nc in row['feature'] else 'pc' if pc in row['feature'] else ('control' if row['grna'] in controls else 'other'),axis=1)
157
- return coef_df[~coef_df['feature'].str.contains('row|column')]
158
-
159
- def check_distribution_v1(y):
160
- """Check the type of distribution to recommend a model."""
161
- if np.all((y == 0) | (y == 1)):
162
- print("Detected binary data.")
163
- return 'logit'
164
- elif (y > 0).all() and (y < 1).all():
165
- print("Detected continuous data between 0 and 1 (excluding 0 and 1).")
166
- return 'beta'
167
- elif (y >= 0).all() and (y <= 1).all():
168
- print("Detected continuous data between 0 and 1 (including 0 or 1).")
169
- # Consider quasi-binomial regression
170
- return 'quasi_binomial'
171
- else:
172
- print("Using OLS as a fallback.")
173
- return 'ols'
174
-
175
- def select_glm_family(y):
176
- """Select the appropriate GLM family based on the data."""
222
+ Returns
223
+ -------
224
+ sm.families.Family
225
+ GLM family appropriate for the data.
226
+ """
177
227
  if np.all((y == 0) | (y == 1)):
178
228
  print("Using Binomial family (for binary data).")
179
229
  return sm.families.Binomial()
@@ -187,17 +237,77 @@ def select_glm_family(y):
187
237
  print("Using Gaussian family (for continuous data).")
188
238
  return sm.families.Gaussian()
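Typical use is to let the data choose the family and pass it straight to a GLM, along these lines (synthetic data):

    import numpy as np
    import statsmodels.api as sm
    from spacr.ml import select_glm_family

    rng = np.random.default_rng(3)
    X = sm.add_constant(rng.normal(size=(100, 2)))
    y = rng.normal(loc=5, size=100)          # generic continuous response

    family = select_glm_family(y)            # prints the chosen family
    model = sm.GLM(y, X, family=family).fit()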
189
239
 
190
- def prepare_formula(dependent_variable, random_row_column_effects=False):
191
- """Return the regression formula using random effects for plate, row, and column."""
240
+ def prepare_formula(dependent_variable, random_row_column_effects=False, regression_level="gene"):
241
+ """
242
+ Prepare the formula for regression modeling based on model design.
243
+
244
+ Parameters
245
+ ----------
246
+ dependent_variable : str
247
+ Name of the dependent variable.
248
+ random_row_column_effects : bool, optional
249
+ Whether to include row and column IDs as random effects.
+ regression_level : str, optional
+ Level of the regression: "gene", "grna", or "both". Unrecognized values fall back to the gene-level formula.
250
+
251
+ Returns
252
+ -------
253
+ str
254
+ Regression formula string.
255
+ """
256
+ #if random_row_column_effects:
257
+ # # Random effects for row and column + gene weighted by gene_fraction + grna weighted by fraction
258
+ # return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene'
259
+ #return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene + rowID + columnID'
260
+
192
261
  if random_row_column_effects:
193
- # Random effects for row and column + gene weighted by gene_fraction + grna weighted by fraction
194
- return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene'
195
- return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene + rowID + columnID'
262
+ if regression_level == "gene":
263
+ formula = f'{dependent_variable} ~ gene_fraction:gene + plateID + rowID + columnID'
264
+
265
+ elif regression_level == "grna":
266
+ formula = f'{dependent_variable} ~ fraction:grna + plateID + rowID + columnID'
267
+
268
+ elif regression_level == "both":
269
+ formula = f'{dependent_variable} ~ fraction:grna + gene_fraction:gene + plateID + rowID + columnID'
270
+
271
+ else:
272
+ print(f"Unknown regression_level: {regression_level}. Defaulting to gene model.")
273
+ formula = f'{dependent_variable} ~ gene_fraction:gene + plateID + rowID + columnID'
274
+ else:
275
+ if regression_level == "gene":
276
+ formula = f'{dependent_variable} ~ gene_fraction:gene'
277
+
278
+ elif regression_level == "grna":
279
+ formula = f'{dependent_variable} ~ fraction:grna'
280
+
281
+ elif regression_level == "both":
282
+ formula = f'{dependent_variable} ~ fraction:grna + gene_fraction:gene'
283
+
284
+ else:
285
+ print(f"Unknown type: {regression_level}. Defaulting to gene model.")
286
+ formula = f'{dependent_variable} ~ gene_fraction:gene'
287
+
288
+ return formula
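For reference, the branches above produce formula strings such as the following (illustrative calls):

    prepare_formula('score', random_row_column_effects=False, regression_level='both')
    # -> 'score ~ fraction:grna + gene_fraction:gene'

    prepare_formula('score', random_row_column_effects=True, regression_level='gene')
    # -> 'score ~ gene_fraction:gene + plateID + rowID + columnID'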
196
289
 
197
290
  def fit_mixed_model(df, formula, dst):
291
+ """
292
+ Fit a mixed linear model with random effects for plate, row, and column.
293
+
294
+ Args
295
+ ----------
296
+ df : pd.DataFrame
297
+ Input data.
298
+ formula : str
299
+ Regression formula.
300
+ dst : str
301
+ Output folder to save diagnostic plots.
302
+
303
+ Returns
304
+ -------
305
+ mixed_model : statsmodels.regression.mixed_linear_model.MixedLMResults
306
+ Fitted model.
307
+ coef_df : pd.DataFrame
308
+ DataFrame of coefficients and p-values.
309
+ """
198
310
  from .plot import plot_histogram
199
-
200
- """Fit the mixed model with plate, row_name, and columnID as random effects and return results."""
201
311
  # Specify random effects for plate, row, and column
202
312
  model = smf.mixedlm(formula,
203
313
  data=df,
@@ -224,8 +334,21 @@ def fit_mixed_model(df, formula, dst):
224
334
  return mixed_model, coef_df
225
335
 
226
336
  def check_and_clean_data(df, dependent_variable):
227
- """Check for collinearity, missing values, or invalid types in relevant columns. Clean data accordingly."""
228
-
337
+ """
338
+ Validate and preprocess a DataFrame before model fitting.
339
+
340
+ Args
341
+ ----------
342
+ df : pd.DataFrame
343
+ Input data.
344
+ dependent_variable : str
345
+ Name of the dependent variable column.
346
+
347
+ Returns
348
+ -------
349
+ pd.DataFrame
350
+ Cleaned and validated DataFrame, ready for model fitting.
351
+ """
229
352
  def handle_missing_values(df, columns):
230
353
  """Handle missing values in specified columns."""
231
354
  missing_summary = df[columns].isnull().sum()
@@ -304,25 +427,31 @@ def check_and_clean_data(df, dependent_variable):
304
427
  print("Data is ready for model fitting.")
305
428
  return df_cleaned
306
429
 
307
- def check_normality_v1(y, variable_name):
308
- """Check if the data is normally distributed using the Shapiro-Wilk test."""
309
- from scipy.stats import shapiro
310
-
311
- stat, p = shapiro(y)
312
- alpha = 0.05
313
- if p > alpha:
314
- print(f"{variable_name} is normally distributed (fail to reject H0)")
315
- return True
316
- else:
317
- print(f"{variable_name} is not normally distributed (reject H0)")
318
- return False
319
-
320
430
  def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance=0.02, smoothing=10, increment=10):
321
431
  """
322
- Plot the mean absolute difference with standard deviation as shaded area vs. sample size.
323
- Detect and mark the elbow point (inflection) with smoothing and tolerance control.
324
- """
432
+ Estimate the minimum number of cells required per well to stabilize phenotype measurements.
433
+
434
+ This function simulates phenotype score stability by repeatedly subsampling increasing numbers
435
+ of cells per well, calculating the mean absolute difference from the full-well mean. It identifies
436
+ the minimal sample size (elbow point) at which the average difference drops below a user-defined
437
+ tolerance, smoothed across wells.
325
438
 
439
+ Args:
440
+ settings (dict): A dictionary with keys:
441
+ - 'score_data' (list or str): CSV path(s) with cell-level data.
442
+ - 'score_column' (str): Column name of the phenotype score.
443
+ - 'tolerance' (float or int): Allowed deviation from true mean (e.g. 0.02 for 2%).
444
+ - 'min_cell_count' (int or None): Optional fixed value for annotation.
445
+ - 'count_data' (list): Used to define the output folder for figure saving.
446
+ num_repeats (int, optional): Number of times to resample for each sample size. Default is 10.
447
+ sample_size (int, optional): Number of top wells to simulate (by cell count). Default is 100.
448
+ tolerance (float or int, optional): Tolerance threshold to define the elbow point. Default is 0.02.
449
+ smoothing (int, optional): Window size for smoothing the mean absolute difference curve. Default is 10.
450
+ increment (int, optional): Step size between tested sample sizes. Default is 10.
451
+
452
+ Returns:
453
+ int: Estimated minimal required number of cells (elbow point).
454
+ """
326
455
  from .utils import correct_metadata_column_names
327
456
 
328
457
  # Load and process data
@@ -451,7 +580,24 @@ def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance
451
580
  return elbow_point['sample_size']
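The resampling idea behind the elbow estimate, reduced to a conceptual sketch (the helper below is hypothetical and not the function's implementation):

    import numpy as np

    def mean_abs_diff(cell_scores, sample_size, num_repeats=10, seed=0):
        """Average |subsample mean - full-well mean| over repeated subsamples."""
        rng = np.random.default_rng(seed)
        full_mean = cell_scores.mean()
        diffs = [abs(rng.choice(cell_scores, size=sample_size, replace=False).mean() - full_mean)
                 for _ in range(num_repeats)]
        return float(np.mean(diffs))

    # Scanning increasing sample sizes, the elbow is the first size whose smoothed
    # mean absolute difference stays below the chosen tolerance.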
452
581
 
453
582
  def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
454
- """Return DataFrame of model coefficients, standard errors, and p-values."""
583
+ """
584
+ Extract model coefficients, standard errors, and p-values into a DataFrame with annotations.
585
+
586
+ Supports various regression types including beta, OLS, GLM, ridge, lasso, and quasi-binomial.
587
+ Adds classification labels (nc, pc, control, other) and computes -log10(p-values).
588
+
589
+ Args:
590
+ model: Fitted regression model object.
591
+ regression_type (str): Type of regression (e.g., 'beta', 'ols', 'glm', 'ridge', 'lasso').
592
+ X (pd.DataFrame): Feature matrix used for fitting the model.
593
+ y (np.ndarray): Target values used in model fitting.
594
+ nc (str): Identifier for negative control features.
595
+ pc (str): Identifier for positive control features.
596
+ controls (list): List of gRNAs used as general controls.
597
+
598
+ Returns:
599
+ pd.DataFrame: Table of coefficients, p-values, and annotations.
600
+ """
455
601
 
456
602
  if regression_type == 'beta':
457
603
  # Extract coefficients and standard errors
@@ -508,7 +654,24 @@ def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
508
654
  return coef_df[~coef_df['feature'].str.contains('row|column')]
509
655
 
510
656
  def check_distribution(y, epsilon=1e-6):
511
- """Check the distribution of y and recommend an appropriate model."""
657
+ """
658
+ Analyze the distribution of the target variable and recommend an appropriate regression model.
659
+
660
+ Checks for:
661
+ - Binary data (logit)
662
+ - Continuous (0, 1) data (beta)
663
+ - Continuous [0, 1] data (quasi-binomial)
664
+ - Normal distribution (OLS)
665
+ - Beta distribution fit
666
+ - Default to GLM otherwise
667
+
668
+ Args:
669
+ y (np.ndarray): Dependent variable values.
670
+ epsilon (float, optional): Threshold for boundary proximity detection. Default is 1e-6.
671
+
672
+ Returns:
673
+ str: Suggested regression model name (e.g., 'logit', 'beta', 'ols', 'glm').
674
+ """
512
675
 
513
676
  # Check if the dependent variable is binary (only 0 and 1)
514
677
  if np.all((y == 0) | (y == 1)):
@@ -552,7 +715,25 @@ def check_distribution(y, epsilon=1e-6):
552
715
  return 'glm'
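Illustrative use on synthetic data (the printed recommendation depends on which of the checks above fires):

    import numpy as np
    from spacr.ml import check_distribution

    rng = np.random.default_rng(4)
    y = rng.beta(2, 5, size=500)             # strictly inside (0, 1)
    suggested = check_distribution(y)        # e.g. 'beta' for open-interval data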
553
716
 
554
717
  def pick_glm_family_and_link(y):
555
- """Select the appropriate GLM family and link function based on data."""
718
+ """
719
+ Select the appropriate GLM family and link function based on distribution of y.
720
+
721
+ Inspects binary, count, proportion, normal, and overdispersed data and maps to:
722
+ - Binomial with Logit link
723
+ - Gaussian with Identity link
724
+ - Poisson with Log link
725
+ - Inverse Gaussian with Log link
726
+ - Negative Binomial with Log link
727
+
728
+ Args:
729
+ y (np.ndarray): Response variable to inspect.
730
+
731
+ Returns:
732
+ sm.families.Family: GLM family object with appropriate link function.
733
+
734
+ Raises:
735
+ ValueError: If data is suitable for Beta regression, which GLM cannot handle.
736
+ """
556
737
  if np.all((y == 0) | (y == 1)):
557
738
  print("Binary data detected. Using Binomial family with Logit link.")
558
739
  return sm.families.Binomial(link=sm.families.links.Logit())
@@ -587,6 +768,24 @@ def pick_glm_family_and_link(y):
587
768
  return sm.families.Gaussian(link=sm.families.links.Identity())
588
769
 
589
770
  def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
771
+ """
772
+ Fit a regression model of the specified type to the data.
773
+
774
+ Supports multiple regression types: OLS, GLM, beta, logit, probit, ridge, lasso, and mixed models.
775
+ Automatically performs hyperparameter tuning (alpha) for lasso and ridge if alpha is 0 or None.
776
+ For GLMs, prints McFadden's pseudo R². For regularized models, prints MSE and coefficient table.
777
+
778
+ Args:
779
+ X (pd.DataFrame): Design matrix of independent variables.
780
+ y (pd.Series or np.ndarray): Dependent variable.
781
+ regression_type (str): Type of regression ('ols', 'glm', 'beta', 'logit', 'probit', 'lasso', 'ridge', 'mixed').
782
+ groups (array-like, optional): Grouping variable for mixed models.
783
+ alpha (float): Regularization strength for lasso and ridge.
784
+ cov_type (str, optional): Covariance estimator type for OLS (e.g., 'HC3').
785
+
786
+ Returns:
787
+ model: Fitted model object (statsmodels or sklearn estimator).
788
+ """
590
789
  def plot_regression_line(X, y, model):
591
790
  """Helper to plot regression line for lasso and ridge models."""
592
791
  y_pred = model.predict(X)
@@ -652,10 +851,32 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_ty
652
851
 
653
852
  return model
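A short usage sketch covering a plain OLS fit and a tuned lasso fit (synthetic data; per the docstring, alpha=0 or None triggers automatic tuning for lasso/ridge):

    import numpy as np
    import pandas as pd
    from spacr.ml import regression_model

    rng = np.random.default_rng(5)
    X = pd.DataFrame(rng.random((100, 3)), columns=['f1', 'f2', 'f3'])
    y = rng.random(100)

    ols_fit = regression_model(X, y, regression_type='ols')
    lasso_fit = regression_model(X, y, regression_type='lasso', alpha=0)   # alpha tuned automatically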
654
853
 
655
- def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0,
854
+ def regression(df, csv_path, dependent_variable='predictions', regression_type=None, regression_level = "gRNA", alpha=1.0,
656
855
  random_row_column_effects=False, nc='233460', pc='220950', controls=[''],
657
856
  dst=None, cov_type=None, plot=False):
658
-
857
+ """
858
+ Perform regression analysis on a DataFrame with optional plotting and mixed effects support.
859
+
860
+ Handles data cleaning, formula preparation, scaling, model fitting, and coefficient extraction.
861
+ Automatically determines appropriate regression type if not specified.
862
+
863
+ Args:
864
+ df (pd.DataFrame): Input data.
865
+ csv_path (str): Path to input CSV used for labeling plots/files.
866
+ dependent_variable (str): Name of column to regress against.
867
+ regression_type (str, optional): Type of regression (auto-detected if None).
+ regression_level (str): Level at which coefficients are modeled ('gene', 'grna', or 'both').
868
+ alpha (float): Regularization parameter for lasso/ridge.
869
+ random_row_column_effects (bool): Whether to model row/column as random effects.
870
+ nc (str): Identifier for negative controls.
871
+ pc (str): Identifier for positive controls.
872
+ controls (list): List of general control gRNAs.
873
+ dst (str, optional): Destination folder for output files.
874
+ cov_type (str, optional): Covariance estimator for OLS.
875
+ plot (bool): Whether to generate volcano plot.
876
+
877
+ Returns:
878
+ tuple: (fitted model, coefficients DataFrame, regression type used)
879
+ """
659
880
  from .plot import volcano_plot, plot_histogram
660
881
  #from .ml import create_volcano_filename, check_and_clean_data, prepare_formula, scale_variables
661
882
 
@@ -673,12 +894,12 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
673
894
  # Handle mixed effects if row/column effect is treated as random
674
895
  if random_row_column_effects:
675
896
  regression_type = 'mixed'
676
- formula = prepare_formula(dependent_variable, random_row_column_effects=True)
897
+ formula = prepare_formula(dependent_variable, random_row_column_effects=True, regression_level=regression_level)
677
898
  mixed_model, coef_df = fit_mixed_model(df, formula, dst)
678
899
  model = mixed_model
679
900
  else:
680
901
  # Prepare the formula
681
- formula = prepare_formula(dependent_variable, random_row_column_effects=False)
902
+ formula = prepare_formula(dependent_variable, random_row_column_effects=False, regression_level=regression_level)
682
903
  y, X = dmatrices(formula, data=df, return_type='dataframe')
683
904
 
684
905
  # Plot histogram of the dependent variable
@@ -707,7 +928,14 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
707
928
 
708
929
  def save_summary_to_file(model, file_path='summary.csv'):
709
930
  """
710
- Save the model's summary output to a CSV or text file.
931
+ Save the model's summary output to a text or CSV file.
932
+
933
+ Args:
934
+ model: Fitted statsmodels model with a `.summary()` method.
935
+ file_path (str): Output path to save the summary (e.g., 'summary.txt').
936
+
937
+ Returns:
938
+ None
711
939
  """
712
940
  # Get the summary as a string
713
941
  summary_str = model.summary().as_text()
@@ -717,7 +945,58 @@ def save_summary_to_file(model, file_path='summary.csv'):
717
945
  f.write(summary_str)
718
946
 
719
947
  def perform_regression(settings):
720
-
948
+ """
949
+ Perform full regression analysis pipeline for pooled CRISPR-Cas9 screens.
950
+
951
+ This function integrates data loading, filtering, transformation, regression modeling,
952
+ statistical significance assessment, visualization, and optional downstream Toxoplasma-specific analysis.
953
+ It handles both grna-level and gene-level aggregation and outputs results with optional metadata integration.
954
+
955
+ Args:
956
+ settings (dict): Dictionary containing regression settings. Recommended to start with
957
+ `get_perform_regression_default_settings()` to initialize.
958
+
959
+ Required Keys in `settings`:
960
+ - score_data (str or list): Path(s) to CSV(s) containing phenotypic scores.
961
+ - count_data (str or list): Path(s) to CSV(s) containing barcode abundance counts.
962
+ - dependent_variable (str): Column name in `score_data` to use as dependent variable.
963
+ - regression_type (str or None): Type of regression model to use ('ols', 'glm', 'lasso', etc.).
964
+ - filter_value (list): Values to exclude from `filter_column`.
965
+ - filter_column (str): Column in data to apply exclusion filter.
966
+ - plateID (str): Identifier to use when reconstructing `prc` labels.
967
+ - alpha (float): Regularization strength for lasso/ridge regressions.
968
+ - random_row_column_effects (bool): If True, treat row/column as random effects.
969
+ - negative_control (str): Label for negative controls.
970
+ - positive_control (str): Label for positive controls.
971
+ - controls (list): List of control gRNAs for estimating coefficient threshold.
972
+ - volcano (str): Type of volcano plot to generate ('all', 'gene', 'grna').
973
+ - transform (bool): Whether to transform dependent variable.
974
+ - agg_type (str): Aggregation method for per-well phenotypes ('mean', 'median', etc.).
975
+ - metadata_files (str or list): Metadata CSV(s) to merge with regression results.
976
+ - toxo (bool): If True, generate Toxoplasma-specific downstream plots.
977
+ - threshold_method (str): Method to calculate effect size threshold ('std' or 'var').
978
+ - threshold_multiplier (float): Multiplier for threshold method.
979
+ - outlier_detection (bool): Whether to remove grnas with outlier well counts.
980
+ - cov_type (str or None): Covariance type for OLS (e.g., 'HC3').
981
+ - verbose (bool): Whether to print verbose output.
982
+ - x_lim, y_lims: Optional axes limits for volcano plots.
983
+ - tolerance (float): Tolerance for determining minimal cell count via simulation.
984
+ - min_cell_count (int or None): Minimum number of cells per well for inclusion.
985
+ - min_n (int): Minimum number of replicates for inclusion in filtered results.
986
+
987
+ Returns:
988
+ dict:
989
+ 'results': Full coefficient table (pandas DataFrame).
990
+ 'significant': Filtered table of statistically significant hits (pandas DataFrame).
991
+
992
+ Side Effects:
993
+ - Saves regression results, metadata-merged results, volcano plots, and phenotype plots to disk.
994
+ - Optionally generates publication-ready plots for Toxoplasma gene expression and phenotype scores.
995
+ - Writes cell count, well-grna, and well-gene plots to disk.
996
+
997
+ Raises:
998
+ ValueError: If required keys are missing or regression type is invalid.
999
+ """
721
1000
  from .plot import plot_plates, plot_data_from_csv
722
1001
  from .utils import merge_regression_res_with_metadata, save_settings, calculate_shortest_distance, correct_metadata
723
1002
  from .settings import get_perform_regression_default_settings
@@ -859,7 +1138,7 @@ def perform_regression(settings):
859
1138
  Detect outliers in 'outlier_col' of 'df' using the 1.5 × IQR rule,
860
1139
  and return values from 'return_col' that correspond to those outliers.
861
1140
 
862
- Parameters:
1141
+ Args:
863
1142
  -----------
864
1143
  df : pd.DataFrame
865
1144
  Input DataFrame.
@@ -946,7 +1225,9 @@ def perform_regression(settings):
946
1225
  if settings['fraction_threshold'] is None:
947
1226
  settings['fraction_threshold'] = graph_sequencing_stats(settings)
948
1227
 
949
- independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plateID'], filter_column=filter_column, filter_value=filter_value)
1228
+ independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plateID'], filter_column=filter_column, filter_value=filter_value, pc=settings['positive_control'], nc=settings['negative_control'], remove_pc_nc=settings['remove_pc_nc'])
1229
+
1230
+
950
1231
 
951
1232
  independent_df, n_grna, n_gene = _count_variable_instances(independent_df, column_1='grna', column_2='gene')
952
1233
 
@@ -985,9 +1266,10 @@ def perform_regression(settings):
985
1266
  _, _ = plot_data_from_csv(settings=cell_settings)
986
1267
 
987
1268
  final_grna_df, prc_gene_count_df = grna_metricks(merged_df)
988
-
1269
+ print("outliers",settings['outlier_detection'])
989
1270
  if settings['outlier_detection']:
990
1271
  outliers_grna = get_outlier_reference_values(final_grna_df,outlier_col='grna_well_count',return_col='grna')
1272
+ print("outliers", outliers_grna)
991
1273
  if len (outliers_grna) > 0:
992
1274
  merged_df = merged_df[~merged_df['grna'].isin(outliers_grna)]
993
1275
  final_grna_df, prc_gene_count_df = grna_metricks(merged_df)
@@ -1035,9 +1317,9 @@ def perform_regression(settings):
1035
1317
  except Exception as e:
1036
1318
  print(e)
1037
1319
 
1038
- _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
1039
-
1040
- model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
1320
+ _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
1321
+
1322
+ model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['regression_level'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
1041
1323
 
1042
1324
  coef_df['grna'] = coef_df['feature'].apply(lambda x: re.search(r'grna\[(.*?)\]', x).group(1) if 'grna' in x else None)
1043
1325
  coef_df['gene'] = coef_df['feature'].apply(lambda x: re.search(r'gene\[(.*?)\]', x).group(1) if 'gene' in x else None)
@@ -1068,6 +1350,7 @@ def perform_regression(settings):
1068
1350
  raise ValueError(f"Unsupported threshold method {settings['threshold_method']}. Supported methods: ['var','variance','std','standard_deveation']")
1069
1351
 
1070
1352
  reg_threshold = mean_coef + (settings['threshold_multiplier'] * coef_mes)
1353
+ print("reg_threshold", reg_threshold)
1071
1354
 
1072
1355
  coef_df.to_csv(results_path, index=False)
1073
1356
  gene_coef_df.to_csv(results_path_gene, index=False)
@@ -1156,8 +1439,28 @@ def perform_regression(settings):
1156
1439
 
1157
1440
  return output
1158
1441
 
1159
- def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filter_value=None):
1160
-
1442
+ def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filter_value=None, pc=None, nc=None, remove_pc_nc=False):
1443
+ """
1444
+ Process barcode count data and compute fractional abundance of each gRNA per well.
1445
+
1446
+ This function loads a count table, standardizes metadata columns, computes the
1447
+ relative abundance of each gRNA in a well (`fraction`), and filters based on a
1448
+ minimum threshold. It also parses `grna` strings into gene-level information if
1449
+ applicable.
1450
+
1451
+ Args:
1452
+ csv_path (str or pd.DataFrame): Path to CSV file or preloaded DataFrame with count data.
1453
+ fraction_threshold (float): Minimum fraction of reads per well for inclusion.
1454
+ plate (str): Plate identifier to use if not present in metadata.
1455
+ filter_column (str or list, optional): Columns to filter values from.
1456
+ filter_value (str or list, optional): Values to exclude from filter_column(s).
1457
+
1458
+ Returns:
1459
+ pd.DataFrame: Filtered DataFrame with columns ['prc', 'grna', 'fraction'].
1460
+
1461
+ Raises:
1462
+ ValueError: If required metadata columns are missing.
1463
+ """
1161
1464
  from .utils import correct_metadata
1162
1465
 
1163
1466
  if isinstance(csv_path, pd.DataFrame):
@@ -1201,6 +1504,16 @@ def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filte
1201
1504
 
1202
1505
  # Create the prc column
1203
1506
  csv_df['prc'] = csv_df['plateID'] + '_' + csv_df['rowID'] + '_' + csv_df['columnID']
1507
+
1508
+ display("Before",csv_df)
1509
+
1510
+ if remove_pc_nc:
1511
+ if pc is not None:
1512
+ csv_df = csv_df[csv_df['grna'] != pc]
1513
+ if nc is not None:
1514
+ csv_df = csv_df[csv_df['grna'] != nc]
1515
+
1516
+ display("After",csv_df)
1204
1517
 
1205
1518
  # Group by prc and calculate the sum of counts
1206
1519
  grouped_df = csv_df.groupby('prc')['count'].sum().reset_index()
@@ -1229,6 +1542,18 @@ def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filte
1229
1542
  return merged_df
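The per-well fraction computed here reduces to a standard pandas pattern, sketched with toy counts (column names as used in this file):

    import pandas as pd

    counts = pd.DataFrame({'prc':   ['p1_r1_c1', 'p1_r1_c1', 'p1_r1_c2'],
                           'grna':  ['g1', 'g2', 'g1'],
                           'count': [90, 10, 50]})

    well_totals = counts.groupby('prc')['count'].transform('sum')
    counts['fraction'] = counts['count'] / well_totals
    kept = counts[counts['fraction'] >= 0.1]     # analogue of fraction_threshold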
1230
1543
 
1231
1544
  def apply_transformation(X, transform):
1545
+ """
1546
+ Apply a mathematical transformation to a variable.
1547
+
1548
+ Supported transformations include logarithm, square root, and square.
1549
+
1550
+ Args:
1551
+ X (array-like): Input variable or DataFrame column.
1552
+ transform (str): Type of transformation ('log', 'sqrt', or 'square').
1553
+
1554
+ Returns:
1555
+ FunctionTransformer or None: Transformer object, or None if transform is unrecognized.
1556
+ """
1232
1557
  if transform == 'log':
1233
1558
  transformer = FunctionTransformer(np.log1p, validate=True)
1234
1559
  elif transform == 'sqrt':
@@ -1240,7 +1565,17 @@ def apply_transformation(X, transform):
1240
1565
  return transformer
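Illustrative use of the returned transformer on synthetic data:

    import numpy as np
    from spacr.ml import apply_transformation

    X = np.random.rand(100, 1)
    transformer = apply_transformation(X, transform='log')   # FunctionTransformer(np.log1p)
    X_log = transformer.fit_transform(X)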
1241
1566
 
1242
1567
  def check_normality(data, variable_name, verbose=False):
1243
- """Check if the data is normally distributed using the Shapiro-Wilk test."""
1568
+ """
1569
+ Check if data follows a normal distribution using the Shapiro-Wilk test.
1570
+
1571
+ Args:
1572
+ data (array-like): Data to test for normality.
1573
+ variable_name (str): Name of the variable (used in print statements).
1574
+ verbose (bool): Whether to print test results and interpretation.
1575
+
1576
+ Returns:
1577
+ bool: True if data is normally distributed (p > 0.05), False otherwise.
1578
+ """
1244
1579
  stat, p_value = shapiro(data)
1245
1580
  if verbose:
1246
1581
  print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {stat}, P-value: {p_value}")
@@ -1254,6 +1589,17 @@ def check_normality(data, variable_name, verbose=False):
1254
1589
  return False
1255
1590
 
1256
1591
  def clean_controls(df,values, column):
1592
+ """
1593
+ Remove rows from a DataFrame based on specified control values in a given column.
1594
+
1595
+ Args:
1596
+ df (pd.DataFrame): Input DataFrame.
1597
+ values (list): List of control values to remove.
1598
+ column (str): Column from which to remove control values.
1599
+
1600
+ Returns:
1601
+ pd.DataFrame: Filtered DataFrame with control rows removed.
1602
+ """
1257
1603
  if column in df.columns:
1258
1604
  if isinstance(values, list):
1259
1605
  for value in values:
@@ -1262,6 +1608,30 @@ def clean_controls(df,values, column):
1262
1608
  return df
1263
1609
 
1264
1610
  def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
1611
+ """
1612
+ Aggregate and transform single-cell phenotype scores for regression input.
1613
+
1614
+ Groups phenotypic measurements by well (`prc`) and applies optional aggregation
1615
+ (mean, median, quantile). Filters out wells with low cell counts. Optionally applies
1616
+ mathematical transformation and checks for normality of the dependent variable.
1617
+
1618
+ Args:
1619
+ df (pd.DataFrame): Input DataFrame with single-cell phenotype scores.
1620
+ dependent_variable (str): Name of the phenotype column to model.
1621
+ plate (str): Plate ID to assign if not present in metadata.
1622
+ min_cell_count (int): Minimum number of cells per well to retain.
1623
+ agg_type (str): Aggregation method ('mean', 'median', 'quantile', None).
1624
+ transform (str or None): Optional transformation ('log', 'sqrt', or 'square').
1625
+ regression_type (str): Type of regression (affects aggregation logic for 'poisson').
1626
+
1627
+ Returns:
1628
+ tuple: (aggregated_df, dependent_variable_name)
1629
+ - aggregated_df: Aggregated and filtered phenotype DataFrame.
1630
+ - dependent_variable_name: Possibly transformed column name used for modeling.
1631
+
1632
+ Raises:
1633
+ ValueError: If required metadata columns are missing or aggregation type is invalid.
1634
+ """
1265
1635
  from .utils import calculate_shortest_distance, correct_metadata
1266
1636
  df = df.reset_index(drop=True)
1267
1637
  if 'prcfo' in df.columns:
@@ -1339,7 +1709,51 @@ def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='m
1339
1709
  return dependent_df, dependent_variable
1340
1710
 
1341
1711
  def generate_ml_scores(settings):
1342
-
1712
+ """
1713
+ Perform machine learning analysis on single-cell measurement data from one or more screen sources.
1714
+
1715
+ This function loads and merges measurement databases, computes additional features (e.g., recruitment
1716
+ scores, shortest distance between compartments), handles annotations (for supervised classification),
1717
+ trains a machine learning model (e.g., XGBoost), computes SHAP values, and generates output files
1718
+ including predictions, importance scores, SHAP plots, and well-level heatmaps.
1719
+
1720
+ Args:
1721
+ settings (dict): Dictionary of analysis parameters and options. Required keys include:
1722
+ - 'src' (str or list): Path(s) to screen folders with measurement databases.
1723
+ - 'channel_of_interest' (int): Channel index used for computing recruitment scores.
1724
+ - 'location_column' (str): Column used to group wells for training (e.g., 'columnID').
1725
+ - 'positive_control' (str): Value of `location_column` representing positive control wells.
1726
+ - 'negative_control' (str): Value of `location_column` representing negative control wells.
1727
+ - 'exclude' (list): List of feature names to exclude.
1728
+ - 'n_repeats' (int): Number of repetitions for permutation importance.
1729
+ - 'top_features' (int): Number of top features to keep and plot.
1730
+ - 'reg_alpha', 'reg_lambda' (float): Regularization strength for XGBoost.
1731
+ - 'learning_rate' (float): Learning rate for XGBoost.
1732
+ - 'n_estimators' (int): Number of boosting rounds for the model.
1733
+ - 'test_size' (float): Fraction of labeled data to use for testing.
1734
+ - 'model_type_ml' (str): Type of model to use ('xgboost', 'random_forest', etc.).
1735
+ - 'n_jobs' (int): Number of parallel jobs.
1736
+ - 'remove_low_variance_features' (bool): Whether to drop near-constant features.
1737
+ - 'remove_highly_correlated_features' (bool): Whether to drop redundant features.
1738
+ - 'prune_features' (bool): Whether to perform SelectKBest-based pruning.
1739
+ - 'cross_validation' (bool): Whether to use 5-fold cross-validation.
1740
+ - 'heatmap_feature' (str): Feature to plot as a plate heatmap.
1741
+ - 'grouping' (str): Column used to group wells in the heatmap (e.g., 'prc').
1742
+ - 'min_max' (tuple): Min/max range for heatmap normalization.
1743
+ - 'minimum_cell_count' (int): Minimum number of cells per well for heatmap inclusion.
1744
+ - 'cmap' (str): Colormap for the heatmap.
1745
+ - 'save_to_db' (bool): If True, update the database with prediction column.
1746
+ - 'verbose' (bool): If True, print detailed logs.
1747
+ - 'annotation_column' (str or None): If provided, annotate data from the `png_list` table.
1748
+
1749
+ Returns:
1750
+ list:
1751
+ - output (list): Includes DataFrames and model objects from the `ml_analysis` function.
1752
+ - plate_heatmap (matplotlib.Figure): Plate-level heatmap of the selected feature.
1753
+
1754
+ Raises:
1755
+ ValueError: If required columns are missing or specified features are not found.
1756
+ """
1343
1757
  from .io import _read_and_merge_data, _read_db
1344
1758
  from .plot import plot_plates
1345
1759
  from .utils import get_ml_results_paths, add_column_to_database, calculate_shortest_distance
@@ -1475,31 +1889,53 @@ def generate_ml_scores(settings):
1475
1889
  return [output, plate_heatmap]
1476
1890
 
1477
1891
  def ml_analysis(df, channel_of_interest=3, location_column='columnID', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, reg_alpha=0.1, reg_lambda=1.0, learning_rate=0.00001, n_estimators=1000, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, prune_features=False, cross_validation=False, verbose=False):
1478
-
1479
1892
  """
1480
- Calculates permutation importance for numerical features in the dataframe,
1481
- comparing groups based on specified column values and uses the model to predict
1482
- the class for all other rows in the dataframe.
1893
+ Train a machine learning classifier to distinguish between positive and negative control wells,
1894
+ compute feature importance via permutation and model-based methods, and assign predictions to all rows.
1895
+
1896
+ This function supports several classifier types (XGBoost, random forest, logistic regression, etc.),
1897
+ with options for feature selection, cross-validation, and SHAP-ready output. It returns predictions,
1898
+ probability estimates, and feature importances, and computes classification metrics.
1483
1899
 
1484
1900
  Args:
1485
- df (pandas.DataFrame): The DataFrame containing the data.
1486
- feature_string (str): String to filter features that contain this substring.
1487
- location_column (str): Column name to use for comparing groups.
1488
- positive_control, negative_control (str): Values in location_column to create subsets for comparison.
1489
- exclude (list or str, optional): Columns to exclude from features.
1490
- n_repeats (int): Number of repeats for permutation importance.
1491
- top_features (int): Number of top features to plot based on permutation importance.
1492
- n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
1493
- test_size (float): Proportion of the dataset to include in the test split.
1494
- random_state (int): Random seed for reproducibility.
1495
- model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
1496
- n_jobs (int): Number of jobs to run in parallel for applicable models.
1901
+ df (pd.DataFrame): Input DataFrame with features and metadata.
1902
+ channel_of_interest (int): Channel index used for filtering relevant features.
1903
+ location_column (str): Column that defines grouping (e.g., 'columnID') to select control wells.
1904
+ positive_control (str or list): Identifier(s) in `location_column` for positive control group.
1905
+ negative_control (str or list): Identifier(s) in `location_column` for negative control group.
1906
+ exclude (list or str, optional): Feature names or substrings to exclude from analysis.
1907
+ n_repeats (int): Number of repetitions for permutation importance calculation.
1908
+ top_features (int): Number of top features to retain and visualize.
1909
+ reg_alpha (float): L1 regularization term (only for XGBoost).
1910
+ reg_lambda (float): L2 regularization term (only for XGBoost).
1911
+ learning_rate (float): Learning rate for gradient-based models (e.g., XGBoost).
1912
+ n_estimators (int): Number of trees or boosting iterations.
1913
+ test_size (float): Fraction of labeled data used for testing.
1914
+ model_type (str): Model type to train: 'xgboost', 'random_forest', 'logistic_regression', or 'gradient_boosting'.
1915
+ n_jobs (int): Number of CPU cores to use (where supported).
1916
+ remove_low_variance_features (bool): If True, drop near-constant features.
1917
+ remove_highly_correlated_features (bool): If True, drop features with high pairwise correlation.
1918
+ prune_features (bool): If True, apply SelectKBest to reduce features before training.
1919
+ cross_validation (bool): If True, use 5-fold stratified cross-validation instead of single split.
1920
+ verbose (bool): If True, print detailed logs and show diagnostic plots.
1497
1921
 
1498
1922
  Returns:
1499
- pandas.DataFrame: The original dataframe with added prediction and data usage columns.
1500
- pandas.DataFrame: DataFrame containing the importances and standard deviations.
1923
+ list:
1924
+ - [0] df (pd.DataFrame): Original DataFrame with added predictions, probabilities, and metadata.
1925
+ - [1] permutation_df (pd.DataFrame): Permutation importance for top features.
1926
+ - [2] feature_importance_df (pd.DataFrame): Model-based feature importance scores (if supported).
1927
+ - [3] model (sklearn-compatible): Trained model object.
1928
+ - [4] X_train (pd.DataFrame): Training feature matrix.
1929
+ - [5] X_test (pd.DataFrame): Test feature matrix.
1930
+ - [6] y_train (pd.Series): Training target labels.
1931
+ - [7] y_test (pd.Series): Test target labels.
1932
+ - [8] metrics_df (pd.DataFrame): Summary classification metrics (precision, recall, f1).
1933
+ - [9] features (list): Final list of features used in training.
1934
+
1935
+ list:
1936
+ - [0] permutation_fig (matplotlib.Figure): Bar plot of permutation importances.
1937
+ - [1] feature_importance_fig (matplotlib.Figure): Bar plot of model importances.
1501
1938
  """
1502
-
1503
1939
  from .utils import filter_dataframe_features
1504
1940
  from .plot import plot_permutation, plot_feature_importance
1505
1941
 
@@ -1709,18 +2145,17 @@ def ml_analysis(df, channel_of_interest=3, location_column='columnID', positive_
1709
2145
  return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df, features], [permutation_fig, feature_importance_fig]
1710
2146
 
1711
2147
  def shap_analysis(model, X_train, X_test):
1712
-
1713
2148
  """
1714
- Performs SHAP analysis on the given model and data.
2149
+ Compute and return a SHAP summary plot figure for a trained model.
1715
2150
 
1716
2151
  Args:
1717
- model: The trained model.
1718
- X_train (pandas.DataFrame): Training feature set.
1719
- X_test (pandas.DataFrame): Testing feature set.
2152
+ model: A trained machine learning model compatible with SHAP.
2153
+ X_train (pd.DataFrame): Training features used to initialize the SHAP explainer.
2154
+ X_test (pd.DataFrame): Test features for which SHAP values will be computed.
2155
+
1720
2156
  Returns:
1721
- fig: Matplotlib figure object containing the SHAP summary plot.
2157
+ matplotlib.figure.Figure: SHAP summary plot figure.
1722
2158
  """
1723
-
1724
2159
  explainer = shap.Explainer(model, X_train)
1725
2160
  shap_values = explainer(X_test)
1726
2161
  # Create a new figure
@@ -1734,14 +2169,15 @@ def shap_analysis(model, X_train, X_test):
1734
2169
 
1735
2170
  def find_optimal_threshold(y_true, y_pred_proba):
1736
2171
  """
1737
- Find the optimal threshold for binary classification based on the F1-score.
2172
+ Determine the optimal decision threshold that maximizes the F1-score
2173
+ based on predicted probabilities.
1738
2174
 
1739
2175
  Args:
1740
- y_true (array-like): True binary labels.
1741
- y_pred_proba (array-like): Predicted probabilities for the positive class.
2176
+ y_true (array-like): Ground-truth binary labels.
2177
+ y_pred_proba (array-like): Predicted class probabilities (positive class).
1742
2178
 
1743
2179
  Returns:
1744
- float: The optimal threshold.
2180
+ float: Threshold that yields the highest F1-score.
1745
2181
  """
1746
2182
  precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
1747
2183
  f1_scores = 2 * (precision * recall) / (precision + recall)
@@ -1751,17 +2187,19 @@ def find_optimal_threshold(y_true, y_pred_proba):
1751
2187
 
1752
2188
  def _calculate_similarity(df, features, col_to_compare, val1, val2):
1753
2189
  """
1754
- Calculate similarity scores of each well to the positive and negative controls using various metrics.
1755
-
2190
+ Calculate similarity scores between each sample and the mean profile of two control groups
2191
+ using multiple distance metrics (Euclidean, Cosine, Mahalanobis, Manhattan, Minkowski, Chebyshev, Bray-Curtis).
2192
+
1756
2193
  Args:
1757
- df (pandas.DataFrame): DataFrame containing the data.
1758
- features (list): List of feature columns to use for similarity calculation.
1759
- col_to_compare (str): Column name to use for comparing groups.
1760
- val1, val2 (str): Values in col_to_compare to create subsets for comparison.
2194
+ df (pd.DataFrame): DataFrame containing feature data and control labels.
2195
+ features (list of str): Feature columns to use in similarity calculations.
2196
+ col_to_compare (str): Column name used to identify control group membership.
2197
+ val1 (str or list): Value(s) indicating the positive control group.
2198
+ val2 (str or list): Value(s) indicating the negative control group.
1761
2199
 
1762
2200
  Returns:
1763
- pandas.DataFrame: DataFrame with similarity scores.
1764
- """
2201
+ pd.DataFrame: Input DataFrame with appended similarity score columns.
2202
+ """
1765
2203
  # Separate positive and negative control wells
1766
2204
  if isinstance(val1, str):
1767
2205
  pos_control = df[df[col_to_compare] == val1][features].mean()
@@ -1814,7 +2252,33 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
1814
2252
  return df
1815
2253
 
1816
2254
  def interperate_vision_model(settings={}):
1817
-
2255
+ """
2256
+ Perform feature interpretation on vision model predictions using feature importance,
2257
+ permutation importance, and SHAP analysis.
2258
+
2259
+ Steps:
2260
+ 1. Loads and merges measurement and score data.
2261
+ 2. Computes Random Forest feature importances.
2262
+ 3. Computes permutation-based feature importances.
2263
+ 4. Runs SHAP analysis on selected features.
2264
+ 5. Aggregates SHAP values by compartment and channel and visualizes them using radar plots.
2265
+
2266
+ Args:
2267
+ settings (dict): Dictionary containing configuration options, including:
2268
+ - src (str): Path to measurement database.
2269
+ - tables (list): List of tables to include.
2270
+ - score_column (str): Column with predicted or measured scores.
2271
+ - top_features (int): Number of features to retain and plot.
2272
+ - shap (bool): Whether to perform SHAP analysis.
2273
+ - shap_sample (bool): Whether to subsample data before SHAP.
2274
+ - feature_importance (bool): Whether to compute Random Forest importance.
2275
+ - permutation_importance (bool): Whether to compute permutation importance.
2276
+ - n_jobs (int): Number of parallel jobs for model training and permutation.
2277
+ - save (bool): Whether to save result CSVs.
2278
+
2279
+ Returns:
2280
+ pd.DataFrame: Merged and scored dataset used for interpretation.
2281
+ """
1818
2282
  from .io import _read_and_merge_data, _results_to_csv
1819
2283
  from .settings import set_interperate_vision_model_defaults
1820
2284
  from .utils import save_settings