spacr 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +3 -2
- spacr/app_classify.py +10 -0
- spacr/app_mask.py +9 -0
- spacr/app_measure.py +9 -0
- spacr/app_sequencing.py +9 -0
- spacr/core.py +172 -1
- spacr/deep_spacr.py +296 -7
- spacr/gui.py +68 -0
- spacr/gui_core.py +319 -10
- spacr/gui_elements.py +772 -13
- spacr/gui_utils.py +304 -12
- spacr/io.py +887 -71
- spacr/logger.py +36 -0
- spacr/measure.py +206 -28
- spacr/ml.py +606 -142
- spacr/plot.py +797 -131
- spacr/sequencing.py +363 -8
- spacr/settings.py +1158 -38
- spacr/sp_stats.py +80 -12
- spacr/spacr_cellpose.py +115 -2
- spacr/submodules.py +747 -19
- spacr/timelapse.py +237 -53
- spacr/toxo.py +132 -6
- spacr/utils.py +2422 -80
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/METADATA +31 -17
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/RECORD +30 -30
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/LICENSE +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/WHEEL +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/entry_points.txt +0 -0
- {spacr-1.0.9.dist-info → spacr-1.1.1.dist-info}/top_level.txt +0 -0
spacr/ml.py
CHANGED
@@ -49,16 +49,56 @@ warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0
 
 
 class QuasiBinomial(Binomial):
-    """
+    """
+    Quasi-Binomial family for generalized linear models with adjustable dispersion.
+
+    Extends the standard Binomial family from `statsmodels` to allow for modeling overdispersion
+    in binomial data, where the observed variance exceeds the nominal binomial variance.
+
+    This is often used in biological and ecological count data where extra-binomial variation is common.
+    """
     def __init__(self, link=logit(), dispersion=1.0):
+        """
+        Initialize the Quasi-Binomial family with a specified link function and dispersion factor.
+
+        Args:
+            link (statsmodels.genmod.families.links): The link function to use. Defaults to logit().
+            dispersion (float): A positive float to scale the variance.
+                Use values >1 to model overdispersion. Defaults to 1.0.
+        """
         super().__init__(link=link)
         self.dispersion = dispersion
 
     def variance(self, mu):
-        """
+        """
+        Compute the variance function for the quasi-binomial model.
+
+        Args:
+            mu (array-like): The mean response values.
+
+        Returns:
+            array-like: Variance scaled by the dispersion parameter.
+        """
        return self.dispersion * super().variance(mu)
 
 def calculate_p_values(X, y, model):
+    """
+    Calculate p-values for model coefficients using residuals and t-statistics.
+
+    Args
+    ----------
+    X : np.ndarray
+        The input feature matrix used for fitting the model.
+    y : np.ndarray
+        The target variable.
+    model : object
+        A fitted sklearn linear model (must have `.predict` and `.coef_` attributes).
+
+    Returns
+    -------
+    np.ndarray
+        An array of p-values corresponding to each coefficient in the model.
+    """
     # Predict y values
     y_pred = model.predict(X)
     # Calculate residuals
@@ -78,6 +118,25 @@ def calculate_p_values(X, y, model):
     return np.array(p_values) # Ensure p_values is a 1-dimensional array
 
 def perform_mixed_model(y, X, groups, alpha=1.0):
+    """
+    Perform mixed effects regression and return the fitted model.
+
+    Args
+    ----------
+    y : pd.Series or np.ndarray
+        Target variable.
+    X : pd.DataFrame
+        Feature matrix.
+    groups : pd.Series or np.ndarray
+        Grouping variable for mixed model.
+    alpha : float, optional
+        Regularization strength for Ridge regression when multicollinearity is detected.
+
+    Returns
+    -------
+    statsmodels.regression.mixed_linear_model.MixedLMResults
+        Fitted mixed model results.
+    """
     # Ensure groups are defined correctly and check for multicollinearity
     if groups is None:
         raise ValueError("Groups must be defined for mixed model regression")
@@ -99,7 +158,25 @@ def perform_mixed_model(y, X, groups, alpha=1.0):
     return result
 
 def create_volcano_filename(csv_path, regression_type, alpha, dst):
-    """
+    """
+    Construct the file name for saving the volcano plot.
+
+    Args
+    ----------
+    csv_path : str
+        Path to the CSV file containing model results.
+    regression_type : str
+        Type of regression performed.
+    alpha : float
+        Alpha value (used for quantile regression naming).
+    dst : str or None
+        Destination directory to save the file.
+
+    Returns
+    -------
+    str
+        Full path for the output volcano plot PDF file.
+    """
     volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
     volcano_filename = f"{regression_type}_{volcano_filename}" if regression_type != 'quantile' else f"{alpha}_{volcano_filename}"
 
@@ -108,7 +185,23 @@ def create_volcano_filename(csv_path, regression_type, alpha, dst):
     return os.path.join(os.path.dirname(csv_path), volcano_filename)
 
 def scale_variables(X, y):
-    """
+    """
+    Min-max scale both independent and dependent variables.
+
+    Args
+    ----------
+    X : pd.DataFrame
+        Feature matrix.
+    y : np.ndarray
+        Target variable.
+
+    Returns
+    -------
+    X_scaled : pd.DataFrame
+        Scaled feature matrix.
+    y_scaled : np.ndarray
+        Scaled target variable.
+    """
     scaler_X = MinMaxScaler()
     scaler_y = MinMaxScaler()
 
@@ -117,63 +210,20 @@ def scale_variables(X, y):
 
     return X_scaled, y_scaled
 
-def
-    """
-
-        coefs = model.params
-        p_values = model.pvalues
-
-        coef_df = pd.DataFrame({
-            'feature': coefs.index,
-            'coefficient': coefs.values,
-            'p_value': p_values.values
-        })
-
-    elif regression_type in ['ridge', 'lasso']:
-        coefs = model.coef_.flatten()
-        p_values = calculate_p_values(X, y, model)
-
-        coef_df = pd.DataFrame({
-            'feature': X.columns,
-            'coefficient': coefs,
-            'p_value': p_values
-        })
-
-    else:
-        coefs = model.coef_
-        intercept = model.intercept_
-        feature_names = X.design_info.column_names
+def select_glm_family(y):
+    """
+    Automatically select a GLM family based on the nature of the dependent variable.
 
-
-
-
-
-        coef_df.loc[0, 'coefficient'] += intercept
-        coef_df['p_value'] = np.nan # Placeholder since sklearn doesn't provide p-values
+    Args
+    ----------
+    y : np.ndarray
+        Target variable.
 
-
-
-
-
-
-def check_distribution_v1(y):
-    """Check the type of distribution to recommend a model."""
-    if np.all((y == 0) | (y == 1)):
-        print("Detected binary data.")
-        return 'logit'
-    elif (y > 0).all() and (y < 1).all():
-        print("Detected continuous data between 0 and 1 (excluding 0 and 1).")
-        return 'beta'
-    elif (y >= 0).all() and (y <= 1).all():
-        print("Detected continuous data between 0 and 1 (including 0 or 1).")
-        # Consider quasi-binomial regression
-        return 'quasi_binomial'
-    else:
-        print("Using OLS as a fallback.")
-        return 'ols'
-
-def select_glm_family(y):
-    """Select the appropriate GLM family based on the data."""
+    Returns
+    -------
+    sm.families.Family
+        GLM family appropriate for the data.
+    """
     if np.all((y == 0) | (y == 1)):
         print("Using Binomial family (for binary data).")
         return sm.families.Binomial()
@@ -187,17 +237,77 @@ def select_glm_family(y):
     print("Using Gaussian family (for continuous data).")
     return sm.families.Gaussian()
 
-def prepare_formula(dependent_variable, random_row_column_effects=False):
-    """
+def prepare_formula(dependent_variable, random_row_column_effects=False, regression_level="gene"):
+    """
+    Prepare the formula for regression modeling based on model design.
+
+    Args
+    ----------
+    dependent_variable : str
+        Name of the dependent variable.
+    random_row_column_effects : bool, optional
+        Whether to include row and column IDs as random effects.
+
+    Returns
+    -------
+    str
+        Regression formula string.
+    """
+    #if random_row_column_effects:
+    #    # Random effects for row and column + gene weighted by gene_fraction + grna weighted by fraction
+    #    return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene'
+    #return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene + rowID + columnID'
+
     if random_row_column_effects:
-
-
-
+        if regression_level == "gene":
+            formula = f'{dependent_variable} ~ gene_fraction:gene + plateID + rowID + columnID'
+
+        elif regression_level == "grna":
+            formula = f'{dependent_variable} ~ fraction:grna + plateID + rowID + columnID'
+
+        elif regression_level == "both":
+            formula = f'{dependent_variable} ~ fraction:grna + gene_fraction:gene + plateID + rowID + columnID'
+
+        else:
+            print(f"Unknown regression_level: {regression_level}. Defaulting to gene model.")
+            formula = f'{dependent_variable} ~ gene_fraction:gene + plateID + rowID + columnID'
+    else:
+        if regression_level == "gene":
+            formula = f'{dependent_variable} ~ gene_fraction:gene'
+
+        elif regression_level == "grna":
+            formula = f'{dependent_variable} ~ fraction:grna'
+
+        elif regression_level == "both":
+            formula = f'{dependent_variable} ~ fraction:grna + gene_fraction:gene'
+
+        else:
+            print(f"Unknown type: {regression_level}. Defaulting to gene model.")
+            formula = f'{dependent_variable} ~ gene_fraction:gene'
+
+    return formula
 
 def fit_mixed_model(df, formula, dst):
+    """
+    Fit a mixed linear model with random effects for plate, row, and column.
+
+    Args
+    ----------
+    df : pd.DataFrame
+        Input data.
+    formula : str
+        Regression formula.
+    dst : str
+        Output folder to save diagnostic plots.
+
+    Returns
+    -------
+    mixed_model : statsmodels.regression.mixed_linear_model.MixedLMResults
+        Fitted model.
+    coef_df : pd.DataFrame
+        DataFrame of coefficients and p-values.
+    """
     from .plot import plot_histogram
-
-    """Fit the mixed model with plate, row_name, and columnID as random effects and return results."""
     # Specify random effects for plate, row, and column
     model = smf.mixedlm(formula,
                         data=df,
@@ -224,8 +334,21 @@ def fit_mixed_model(df, formula, dst):
     return mixed_model, coef_df
 
 def check_and_clean_data(df, dependent_variable):
-    """
-
+    """
+    Validate and preprocess a DataFrame before model fitting.
+
+    Args
+    ----------
+    df : pd.DataFrame
+        Input data.
+    dependent_variable : str
+        Name of the dependent variable column.
+
+    Returns
+    -------
+    pd.DataFrame
+        Cleaned and validated DataFrame, ready for model fitting.
+    """
     def handle_missing_values(df, columns):
         """Handle missing values in specified columns."""
         missing_summary = df[columns].isnull().sum()
@@ -304,25 +427,31 @@ def check_and_clean_data(df, dependent_variable):
     print("Data is ready for model fitting.")
     return df_cleaned
 
-def check_normality_v1(y, variable_name):
-    """Check if the data is normally distributed using the Shapiro-Wilk test."""
-    from scipy.stats import shapiro
-
-    stat, p = shapiro(y)
-    alpha = 0.05
-    if p > alpha:
-        print(f"{variable_name} is normally distributed (fail to reject H0)")
-        return True
-    else:
-        print(f"{variable_name} is not normally distributed (reject H0)")
-        return False
-
 def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance=0.02, smoothing=10, increment=10):
     """
-
-
-
+    Estimate the minimum number of cells required per well to stabilize phenotype measurements.
+
+    This function simulates phenotype score stability by repeatedly subsampling increasing numbers
+    of cells per well, calculating the mean absolute difference from the full-well mean. It identifies
+    the minimal sample size (elbow point) at which the average difference drops below a user-defined
+    tolerance, smoothed across wells.
 
+    Args:
+        settings (dict): A dictionary with keys:
+            - 'score_data' (list or str): CSV path(s) with cell-level data.
+            - 'score_column' (str): Column name of the phenotype score.
+            - 'tolerance' (float or int): Allowed deviation from true mean (e.g. 0.02 for 2%).
+            - 'min_cell_count' (int or None): Optional fixed value for annotation.
+            - 'count_data' (list): Used to define the output folder for figure saving.
+        num_repeats (int, optional): Number of times to resample for each sample size. Default is 10.
+        sample_size (int, optional): Number of top wells to simulate (by cell count). Default is 100.
+        tolerance (float or int, optional): Tolerance threshold to define the elbow point. Default is 0.02.
+        smoothing (int, optional): Window size for smoothing the mean absolute difference curve. Default is 10.
+        increment (int, optional): Step size between tested sample sizes. Default is 10.
+
+    Returns:
+        int: Estimated minimal required number of cells (elbow point).
+    """
     from .utils import correct_metadata_column_names
 
     # Load and process data
@@ -451,7 +580,24 @@ def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance
     return elbow_point['sample_size']
 
 def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
-    """
+    """
+    Extract model coefficients, standard errors, and p-values into a DataFrame with annotations.
+
+    Supports various regression types including beta, OLS, GLM, ridge, lasso, and quasi-binomial.
+    Adds classification labels (nc, pc, control, other) and computes -log10(p-values).
+
+    Args:
+        model: Fitted regression model object.
+        regression_type (str): Type of regression (e.g., 'beta', 'ols', 'glm', 'ridge', 'lasso').
+        X (pd.DataFrame): Feature matrix used for fitting the model.
+        y (np.ndarray): Target values used in model fitting.
+        nc (str): Identifier for negative control features.
+        pc (str): Identifier for positive control features.
+        controls (list): List of gRNAs used as general controls.
+
+    Returns:
+        pd.DataFrame: Table of coefficients, p-values, and annotations.
+    """
 
     if regression_type == 'beta':
         # Extract coefficients and standard errors
@@ -508,7 +654,24 @@ def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
     return coef_df[~coef_df['feature'].str.contains('row|column')]
 
 def check_distribution(y, epsilon=1e-6):
-    """
+    """
+    Analyze distribution of the target variable and recommend appropriate regression model.
+
+    Checks for:
+    - Binary data (logit)
+    - Continuous [0, 1) data (beta)
+    - Continuous [0, 1] data (quasi-binomial)
+    - Normal distribution (OLS)
+    - Beta distribution fit
+    - Default to GLM otherwise
+
+    Args:
+        y (np.ndarray): Dependent variable values.
+        epsilon (float, optional): Threshold for boundary proximity detection. Default is 1e-6.
+
+    Returns:
+        str: Suggested regression model name (e.g., 'logit', 'beta', 'ols', 'glm').
+    """
 
     # Check if the dependent variable is binary (only 0 and 1)
     if np.all((y == 0) | (y == 1)):
@@ -552,7 +715,25 @@ def check_distribution(y, epsilon=1e-6):
     return 'glm'
 
 def pick_glm_family_and_link(y):
-    """
+    """
+    Select the appropriate GLM family and link function based on distribution of y.
+
+    Inspects binary, count, proportion, normal, and overdispersed data and maps to:
+    - Binomial with Logit link
+    - Gaussian with Identity link
+    - Poisson with Log link
+    - Inverse Gaussian with Log link
+    - Negative Binomial with Log link
+
+    Args:
+        y (np.ndarray): Response variable to inspect.
+
+    Returns:
+        sm.families.Family: GLM family object with appropriate link function.
+
+    Raises:
+        ValueError: If data is suitable for Beta regression, which GLM cannot handle.
+    """
     if np.all((y == 0) | (y == 1)):
         print("Binary data detected. Using Binomial family with Logit link.")
         return sm.families.Binomial(link=sm.families.links.Logit())
@@ -587,6 +768,24 @@ def pick_glm_family_and_link(y):
     return sm.families.Gaussian(link=sm.families.links.Identity())
 
 def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
+    """
+    Fit a regression model of the specified type to the data.
+
+    Supports multiple regression types: OLS, GLM, beta, logit, probit, ridge, lasso, and mixed models.
+    Automatically performs hyperparameter tuning (alpha) for lasso and ridge if alpha is 0 or None.
+    For GLMs, prints McFadden's pseudo R². For regularized models, prints MSE and coefficient table.
+
+    Args:
+        X (pd.DataFrame): Design matrix of independent variables.
+        y (pd.Series or np.ndarray): Dependent variable.
+        regression_type (str): Type of regression ('ols', 'glm', 'beta', 'logit', 'probit', 'lasso', 'ridge', 'mixed').
+        groups (array-like, optional): Grouping variable for mixed models.
+        alpha (float): Regularization strength for lasso and ridge.
+        cov_type (str, optional): Covariance estimator type for OLS (e.g., 'HC3').
+
+    Returns:
+        model: Fitted model object (statsmodels or sklearn estimator).
+    """
     def plot_regression_line(X, y, model):
         """Helper to plot regression line for lasso and ridge models."""
         y_pred = model.predict(X)
@@ -652,10 +851,32 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_ty
 
     return model
 
-def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0,
+def regression(df, csv_path, dependent_variable='predictions', regression_type=None, regression_level = "gRNA", alpha=1.0,
                random_row_column_effects=False, nc='233460', pc='220950', controls=[''],
                dst=None, cov_type=None, plot=False):
-
+    """
+    Perform regression analysis on a DataFrame with optional plotting and mixed effects support.
+
+    Handles data cleaning, formula preparation, scaling, model fitting, and coefficient extraction.
+    Automatically determines appropriate regression type if not specified.
+
+    Args:
+        df (pd.DataFrame): Input data.
+        csv_path (str): Path to input CSV used for labeling plots/files.
+        dependent_variable (str): Name of column to regress against.
+        regression_type (str, optional): Type of regression (auto-detected if None).
+        alpha (float): Regularization parameter for lasso/ridge.
+        random_row_column_effects (bool): Whether to model row/column as random effects.
+        nc (str): Identifier for negative controls.
+        pc (str): Identifier for positive controls.
+        controls (list): List of general control gRNAs.
+        dst (str, optional): Destination folder for output files.
+        cov_type (str, optional): Covariance estimator for OLS.
+        plot (bool): Whether to generate volcano plot.
+
+    Returns:
+        tuple: (fitted model, coefficients DataFrame, regression type used)
+    """
     from .plot import volcano_plot, plot_histogram
     #from .ml import create_volcano_filename, check_and_clean_data, prepare_formula, scale_variables
 
@@ -673,12 +894,12 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
     # Handle mixed effects if row/column effect is treated as random
     if random_row_column_effects:
         regression_type = 'mixed'
-        formula = prepare_formula(dependent_variable, random_row_column_effects=True)
+        formula = prepare_formula(dependent_variable, random_row_column_effects=True, regression_level=regression_level)
         mixed_model, coef_df = fit_mixed_model(df, formula, dst)
         model = mixed_model
     else:
         # Prepare the formula
-        formula = prepare_formula(dependent_variable, random_row_column_effects=False)
+        formula = prepare_formula(dependent_variable, random_row_column_effects=False, regression_level=regression_level)
         y, X = dmatrices(formula, data=df, return_type='dataframe')
 
         # Plot histogram of the dependent variable
@@ -707,7 +928,14 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
 
 def save_summary_to_file(model, file_path='summary.csv'):
     """
-    Save the model's summary output to a
+    Save the model's summary output to a text or CSV file.
+
+    Args:
+        model: Fitted statsmodels model with a `.summary()` method.
+        file_path (str): Output path to save the summary (e.g., 'summary.txt').
+
+    Returns:
+        None
     """
     # Get the summary as a string
     summary_str = model.summary().as_text()
@@ -717,7 +945,58 @@ def save_summary_to_file(model, file_path='summary.csv'):
         f.write(summary_str)
 
 def perform_regression(settings):
-
+    """
+    Perform full regression analysis pipeline for pooled CRISPR-Cas9 screens.
+
+    This function integrates data loading, filtering, transformation, regression modeling,
+    statistical significance assessment, visualization, and optional downstream Toxoplasma-specific analysis.
+    It handles both grna-level and gene-level aggregation and outputs results with optional metadata integration.
+
+    Args:
+        settings (dict): Dictionary containing regression settings. Recommended to start with
+            `get_perform_regression_default_settings()` to initialize.
+
+    Required Keys in `settings`:
+        - score_data (str or list): Path(s) to CSV(s) containing phenotypic scores.
+        - count_data (str or list): Path(s) to CSV(s) containing barcode abundance counts.
+        - dependent_variable (str): Column name in `score_data` to use as dependent variable.
+        - regression_type (str or None): Type of regression model to use ('ols', 'glm', 'lasso', etc.).
+        - filter_value (list): Values to exclude from `filter_column`.
+        - filter_column (str): Column in data to apply exclusion filter.
+        - plateID (str): Identifier to use when reconstructing `prc` labels.
+        - alpha (float): Regularization strength for lasso/ridge regressions.
+        - random_row_column_effects (bool): If True, treat row/column as random effects.
+        - negative_control (str): Label for negative controls.
+        - positive_control (str): Label for positive controls.
+        - controls (list): List of control gRNAs for estimating coefficient threshold.
+        - volcano (str): Type of volcano plot to generate ('all', 'gene', 'grna').
+        - transform (bool): Whether to transform dependent variable.
+        - agg_type (str): Aggregation method for per-well phenotypes ('mean', 'median', etc.).
+        - metadata_files (str or list): Metadata CSV(s) to merge with regression results.
+        - toxo (bool): If True, generate Toxoplasma-specific downstream plots.
+        - threshold_method (str): Method to calculate effect size threshold ('std' or 'var').
+        - threshold_multiplier (float): Multiplier for threshold method.
+        - outlier_detection (bool): Whether to remove grnas with outlier well counts.
+        - cov_type (str or None): Covariance type for OLS (e.g., 'HC3').
+        - verbose (bool): Whether to print verbose output.
+        - x_lim, y_lims: Optional axes limits for volcano plots.
+        - tolerance (float): Tolerance for determining minimal cell count via simulation.
+        - min_cell_count (int or None): Minimum number of cells per well for inclusion.
+        - min_n (int): Minimum number of replicates for inclusion in filtered results.
+
+    Returns:
+        dict:
+            'results': Full coefficient table (pandas DataFrame).
+            'significant': Filtered table of statistically significant hits (pandas DataFrame).
+
+    Side Effects:
+        - Saves regression results, metadata-merged results, volcano plots, and phenotype plots to disk.
+        - Optionally generates publication-ready plots for Toxoplasma gene expression and phenotype scores.
+        - Writes cell count, well-grna, and well-gene plots to disk.
+
+    Raises:
+        ValueError: If required keys are missing or regression type is invalid.
+    """
     from .plot import plot_plates, plot_data_from_csv
     from .utils import merge_regression_res_with_metadata, save_settings, calculate_shortest_distance, correct_metadata
     from .settings import get_perform_regression_default_settings
@@ -859,7 +1138,7 @@ def perform_regression(settings):
         Detect outliers in 'outlier_col' of 'df' using the 1.5 × IQR rule,
         and return values from 'return_col' that correspond to those outliers.
 
-
+        Args:
         -----------
         df : pd.DataFrame
             Input DataFrame.
@@ -946,7 +1225,9 @@ def perform_regression(settings):
     if settings['fraction_threshold'] is None:
         settings['fraction_threshold'] = graph_sequencing_stats(settings)
 
-    independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plateID'], filter_column=filter_column, filter_value=filter_value)
+    independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plateID'], filter_column=filter_column, filter_value=filter_value, pc=settings['positive_control'], nc=settings['negative_control'], remove_pc_nc=settings['remove_pc_nc'])
+
+
 
     independent_df, n_grna, n_gene = _count_variable_instances(independent_df, column_1='grna', column_2='gene')
 
@@ -985,9 +1266,10 @@ def perform_regression(settings):
         _, _ = plot_data_from_csv(settings=cell_settings)
 
     final_grna_df, prc_gene_count_df = grna_metricks(merged_df)
-
+    print("outliers",settings['outlier_detection'])
     if settings['outlier_detection']:
         outliers_grna = get_outlier_reference_values(final_grna_df,outlier_col='grna_well_count',return_col='grna')
+        print("outliers", outliers_grna)
         if len (outliers_grna) > 0:
             merged_df = merged_df[~merged_df['grna'].isin(outliers_grna)]
             final_grna_df, prc_gene_count_df = grna_metricks(merged_df)
@@ -1035,9 +1317,9 @@ def perform_regression(settings):
     except Exception as e:
         print(e)
 
-    _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
-
-    model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
+    _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
+
+    model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['regression_level'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
 
     coef_df['grna'] = coef_df['feature'].apply(lambda x: re.search(r'grna\[(.*?)\]', x).group(1) if 'grna' in x else None)
     coef_df['gene'] = coef_df['feature'].apply(lambda x: re.search(r'gene\[(.*?)\]', x).group(1) if 'gene' in x else None)
@@ -1068,6 +1350,7 @@ def perform_regression(settings):
         raise ValueError(f"Unsupported threshold method {settings['threshold_method']}. Supported methods: ['var','variance','std','standard_deveation']")
 
     reg_threshold = mean_coef + (settings['threshold_multiplier'] * coef_mes)
+    print("reg_threshold", reg_threshold)
 
     coef_df.to_csv(results_path, index=False)
     gene_coef_df.to_csv(results_path_gene, index=False)
@@ -1156,8 +1439,28 @@ def perform_regression(settings):
 
     return output
 
-def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filter_value=None):
-
+def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filter_value=None, pc=None, nc=None, remove_pc_nc=False):
+    """
+    Process barcode count data and compute fractional abundance of each gRNA per well.
+
+    This function loads a count table, standardizes metadata columns, computes the
+    relative abundance of each gRNA in a well (`fraction`), and filters based on a
+    minimum threshold. It also parses `grna` strings into gene-level information if
+    applicable.
+
+    Args:
+        csv_path (str or pd.DataFrame): Path to CSV file or preloaded DataFrame with count data.
+        fraction_threshold (float): Minimum fraction of reads per well for inclusion.
+        plate (str): Plate identifier to use if not present in metadata.
+        filter_column (str or list, optional): Columns to filter values from.
+        filter_value (str or list, optional): Values to exclude from filter_column(s).
+
+    Returns:
+        pd.DataFrame: Filtered DataFrame with columns ['prc', 'grna', 'fraction'].
+
+    Raises:
+        ValueError: If required metadata columns are missing.
+    """
     from .utils import correct_metadata
 
     if isinstance(csv_path, pd.DataFrame):
@@ -1201,6 +1504,16 @@ def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filte
 
     # Create the prc column
     csv_df['prc'] = csv_df['plateID'] + '_' + csv_df['rowID'] + '_' + csv_df['columnID']
+
+    display("Before",csv_df)
+
+    if remove_pc_nc:
+        if pc is not None:
+            csv_df = csv_df[csv_df['grna'] != pc]
+        if nc is not None:
+            csv_df = csv_df[csv_df['grna'] != nc]
+
+    display("After",csv_df)
 
     # Group by prc and calculate the sum of counts
     grouped_df = csv_df.groupby('prc')['count'].sum().reset_index()
@@ -1229,6 +1542,18 @@ def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filte
     return merged_df
 
 def apply_transformation(X, transform):
+    """
+    Apply a mathematical transformation to a variable.
+
+    Supported transformations include logarithm, square root, and square.
+
+    Args:
+        X (array-like): Input variable or DataFrame column.
+        transform (str): Type of transformation ('log', 'sqrt', or 'square').
+
+    Returns:
+        FunctionTransformer or None: Transformer object, or None if transform is unrecognized.
+    """
     if transform == 'log':
         transformer = FunctionTransformer(np.log1p, validate=True)
     elif transform == 'sqrt':
@@ -1240,7 +1565,17 @@ def apply_transformation(X, transform):
     return transformer
 
 def check_normality(data, variable_name, verbose=False):
-    """
+    """
+    Check if data follows a normal distribution using the Shapiro-Wilk test.
+
+    Args:
+        data (array-like): Data to test for normality.
+        variable_name (str): Name of the variable (used in print statements).
+        verbose (bool): Whether to print test results and interpretation.
+
+    Returns:
+        bool: True if data is normally distributed (p > 0.05), False otherwise.
+    """
     stat, p_value = shapiro(data)
     if verbose:
         print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {stat}, P-value: {p_value}")
@@ -1254,6 +1589,17 @@ def check_normality(data, variable_name, verbose=False):
     return False
 
 def clean_controls(df,values, column):
+    """
+    Remove rows from a DataFrame based on specified control values in a given column.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame.
+        values (list): List of control values to remove.
+        column (str): Column from which to remove control values.
+
+    Returns:
+        pd.DataFrame: Filtered DataFrame with control rows removed.
+    """
     if column in df.columns:
         if isinstance(values, list):
             for value in values:
@@ -1262,6 +1608,30 @@ def clean_controls(df,values, column):
         return df
 
 def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
+    """
+    Aggregate and transform single-cell phenotype scores for regression input.
+
+    Groups phenotypic measurements by well (`prc`) and applies optional aggregation
+    (mean, median, quantile). Filters out wells with low cell counts. Optionally applies
+    mathematical transformation and checks for normality of the dependent variable.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame with single-cell phenotype scores.
+        dependent_variable (str): Name of the phenotype column to model.
+        plate (str): Plate ID to assign if not present in metadata.
+        min_cell_count (int): Minimum number of cells per well to retain.
+        agg_type (str): Aggregation method ('mean', 'median', 'quantile', None).
+        transform (str or None): Optional transformation ('log', 'sqrt', or 'square').
+        regression_type (str): Type of regression (affects aggregation logic for 'poisson').
+
+    Returns:
+        tuple: (aggregated_df, dependent_variable_name)
+            - aggregated_df: Aggregated and filtered phenotype DataFrame.
+            - dependent_variable_name: Possibly transformed column name used for modeling.
+
+    Raises:
+        ValueError: If required metadata columns are missing or aggregation type is invalid.
+    """
     from .utils import calculate_shortest_distance, correct_metadata
     df = df.reset_index(drop=True)
     if 'prcfo' in df.columns:
@@ -1339,7 +1709,51 @@ def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='m
     return dependent_df, dependent_variable
 
 def generate_ml_scores(settings):
-
+    """
+    Perform machine learning analysis on single-cell measurement data from one or more screen sources.
+
+    This function loads and merges measurement databases, computes additional features (e.g., recruitment
+    scores, shortest distance between compartments), handles annotations (for supervised classification),
+    trains a machine learning model (e.g., XGBoost), computes SHAP values, and generates output files
+    including predictions, importance scores, SHAP plots, and well-level heatmaps.
+
+    Args:
+        settings (dict): Dictionary of analysis parameters and options. Required keys include:
+            - 'src' (str or list): Path(s) to screen folders with measurement databases.
+            - 'channel_of_interest' (int): Channel index used for computing recruitment scores.
+            - 'location_column' (str): Column used to group wells for training (e.g., 'columnID').
+            - 'positive_control' (str): Value of `location_column` representing positive control wells.
+            - 'negative_control' (str): Value of `location_column` representing negative control wells.
+            - 'exclude' (list): List of feature names to exclude.
+            - 'n_repeats' (int): Number of repetitions for permutation importance.
+            - 'top_features' (int): Number of top features to keep and plot.
+            - 'reg_alpha', 'reg_lambda' (float): Regularization strength for XGBoost.
+            - 'learning_rate' (float): Learning rate for XGBoost.
+            - 'n_estimators' (int): Number of boosting rounds for the model.
+            - 'test_size' (float): Fraction of labeled data to use for testing.
+            - 'model_type_ml' (str): Type of model to use ('xgboost', 'random_forest', etc.).
+            - 'n_jobs' (int): Number of parallel jobs.
+            - 'remove_low_variance_features' (bool): Whether to drop near-constant features.
+            - 'remove_highly_correlated_features' (bool): Whether to drop redundant features.
+            - 'prune_features' (bool): Whether to perform SelectKBest-based pruning.
+            - 'cross_validation' (bool): Whether to use 5-fold cross-validation.
+            - 'heatmap_feature' (str): Feature to plot as a plate heatmap.
+            - 'grouping' (str): Column used to group wells in the heatmap (e.g., 'prc').
+            - 'min_max' (tuple): Min/max range for heatmap normalization.
+            - 'minimum_cell_count' (int): Minimum number of cells per well for heatmap inclusion.
+            - 'cmap' (str): Colormap for the heatmap.
+            - 'save_to_db' (bool): If True, update the database with prediction column.
+            - 'verbose' (bool): If True, print detailed logs.
+            - 'annotation_column' (str or None): If provided, annotate data from the `png_list` table.
+
+    Returns:
+        list:
+            - output (list): Includes DataFrames and model objects from the `ml_analysis` function.
+            - plate_heatmap (matplotlib.Figure): Plate-level heatmap of the selected feature.
+
+    Raises:
+        ValueError: If required columns are missing or specified features are not found.
+    """
     from .io import _read_and_merge_data, _read_db
     from .plot import plot_plates
     from .utils import get_ml_results_paths, add_column_to_database, calculate_shortest_distance
@@ -1475,31 +1889,53 @@ def generate_ml_scores(settings):
     return [output, plate_heatmap]
 
 def ml_analysis(df, channel_of_interest=3, location_column='columnID', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, reg_alpha=0.1, reg_lambda=1.0, learning_rate=0.00001, n_estimators=1000, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, prune_features=False, cross_validation=False, verbose=False):
-
     """
-
-
-
+    Train a machine learning classifier to distinguish between positive and negative control wells,
+    compute feature importance via permutation and model-based methods, and assign predictions to all rows.
+
+    This function supports several classifier types (XGBoost, random forest, logistic regression, etc.),
+    with options for feature selection, cross-validation, and SHAP-ready output. It returns predictions,
+    probability estimates, and feature importances, and computes classification metrics.
 
     Args:
-
-
-
-
-
-
-
-
-
-
-
-
+        df (pd.DataFrame): Input DataFrame with features and metadata.
+        channel_of_interest (int): Channel index used for filtering relevant features.
+        location_column (str): Column that defines grouping (e.g., 'columnID') to select control wells.
+        positive_control (str or list): Identifier(s) in `location_column` for positive control group.
+        negative_control (str or list): Identifier(s) in `location_column` for negative control group.
+        exclude (list or str, optional): Feature names or substrings to exclude from analysis.
+        n_repeats (int): Number of repetitions for permutation importance calculation.
+        top_features (int): Number of top features to retain and visualize.
+        reg_alpha (float): L1 regularization term (only for XGBoost).
+        reg_lambda (float): L2 regularization term (only for XGBoost).
+        learning_rate (float): Learning rate for gradient-based models (e.g., XGBoost).
+        n_estimators (int): Number of trees or boosting iterations.
+        test_size (float): Fraction of labeled data used for testing.
+        model_type (str): Model type to train: 'xgboost', 'random_forest', 'logistic_regression', or 'gradient_boosting'.
+        n_jobs (int): Number of CPU cores to use (where supported).
+        remove_low_variance_features (bool): If True, drop near-constant features.
+        remove_highly_correlated_features (bool): If True, drop features with high pairwise correlation.
+        prune_features (bool): If True, apply SelectKBest to reduce features before training.
+        cross_validation (bool): If True, use 5-fold stratified cross-validation instead of single split.
+        verbose (bool): If True, print detailed logs and show diagnostic plots.
 
     Returns:
-
-
+        list:
+            - [0] df (pd.DataFrame): Original DataFrame with added predictions, probabilities, and metadata.
+            - [1] permutation_df (pd.DataFrame): Permutation importance for top features.
+            - [2] feature_importance_df (pd.DataFrame): Model-based feature importance scores (if supported).
+            - [3] model (sklearn-compatible): Trained model object.
+            - [4] X_train (pd.DataFrame): Training feature matrix.
+            - [5] X_test (pd.DataFrame): Test feature matrix.
+            - [6] y_train (pd.Series): Training target labels.
+            - [7] y_test (pd.Series): Test target labels.
+            - [8] metrics_df (pd.DataFrame): Summary classification metrics (precision, recall, f1).
+            - [9] features (list): Final list of features used in training.
+
+        list:
+            - [0] permutation_fig (matplotlib.Figure): Bar plot of permutation importances.
+            - [1] feature_importance_fig (matplotlib.Figure): Bar plot of model importances.
     """
-
     from .utils import filter_dataframe_features
     from .plot import plot_permutation, plot_feature_importance
 
@@ -1709,18 +2145,17 @@ def ml_analysis(df, channel_of_interest=3, location_column='columnID', positive_
     return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df, features], [permutation_fig, feature_importance_fig]
 
 def shap_analysis(model, X_train, X_test):
-
     """
-
+    Compute and return a SHAP summary plot figure for a trained model.
 
     Args:
-
-
-
+        model: A trained machine learning model compatible with SHAP.
+        X_train (pd.DataFrame): Training features used to initialize the SHAP explainer.
+        X_test (pd.DataFrame): Test features for which SHAP values will be computed.
+
     Returns:
-
+        matplotlib.figure.Figure: SHAP summary plot figure.
     """
-
     explainer = shap.Explainer(model, X_train)
     shap_values = explainer(X_test)
     # Create a new figure
@@ -1734,14 +2169,15 @@ def shap_analysis(model, X_train, X_test):
 
 def find_optimal_threshold(y_true, y_pred_proba):
     """
-
+    Determine the optimal decision threshold that maximizes the F1-score
+    based on predicted probabilities.
 
     Args:
-
-
+        y_true (array-like): Ground-truth binary labels.
+        y_pred_proba (array-like): Predicted class probabilities (positive class).
 
     Returns:
-
+        float: Threshold that yields the highest F1-score.
     """
     precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
     f1_scores = 2 * (precision * recall) / (precision + recall)
@@ -1751,17 +2187,19 @@ def find_optimal_threshold(y_true, y_pred_proba):
 
 def _calculate_similarity(df, features, col_to_compare, val1, val2):
     """
-    Calculate similarity scores
-
+    Calculate similarity scores between each sample and the mean profile of two control groups
+    using multiple distance metrics (Euclidean, Cosine, Mahalanobis, Manhattan, Minkowski, Chebyshev, Bray-Curtis).
+
     Args:
-
-
-
-
+        df (pd.DataFrame): DataFrame containing feature data and control labels.
+        features (list of str): Feature columns to use in similarity calculations.
+        col_to_compare (str): Column name used to identify control group membership.
+        val1 (str or list): Value(s) indicating the positive control group.
+        val2 (str or list): Value(s) indicating the negative control group.
 
     Returns:
-
-    """
+        pd.DataFrame: Input DataFrame with appended similarity score columns.
+    """
     # Separate positive and negative control wells
     if isinstance(val1, str):
         pos_control = df[df[col_to_compare] == val1][features].mean()
@@ -1814,7 +2252,33 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
     return df
 
 def interperate_vision_model(settings={}):
-
+    """
+    Perform feature interpretation on vision model predictions using feature importance,
+    permutation importance, and SHAP analysis.
+
+    Steps:
+        1. Loads and merges measurement and score data.
+        2. Computes Random Forest feature importances.
+        3. Computes permutation-based feature importances.
+        4. Runs SHAP analysis on selected features.
+        5. Aggregates SHAP values by compartment and channel and visualizes them using radar plots.
+
+    Args:
+        settings (dict): Dictionary containing configuration options, including:
+            - src (str): Path to measurement database.
+            - tables (list): List of tables to include.
+            - score_column (str): Column with predicted or measured scores.
+            - top_features (int): Number of features to retain and plot.
+            - shap (bool): Whether to perform SHAP analysis.
+            - shap_sample (bool): Whether to subsample data before SHAP.
+            - feature_importance (bool): Whether to compute Random Forest importance.
+            - permutation_importance (bool): Whether to compute permutation importance.
+            - n_jobs (int): Number of parallel jobs for model training and permutation.
+            - save (bool): Whether to save result CSVs.
+
+    Returns:
+        pd.DataFrame: Merged and scored dataset used for interpretation.
+    """
     from .io import _read_and_merge_data, _results_to_csv
     from .settings import set_interperate_vision_model_defaults
     from .utils import save_settings