spacr 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +142 -2495
- spacr/deep_spacr.py +151 -29
- spacr/gui.py +1 -0
- spacr/gui_core.py +74 -63
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +346 -6
- spacr/io.py +631 -51
- spacr/logger.py +28 -9
- spacr/measure.py +107 -95
- spacr/mediar.py +0 -5
- spacr/ml.py +964 -0
- spacr/openai.py +37 -0
- spacr/plot.py +281 -16
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +129 -43
- spacr/sim.py +0 -2
- spacr/submodules.py +348 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +233 -0
- spacr/utils.py +275 -173
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/ml.py
ADDED
@@ -0,0 +1,964 @@
import os, shap
import pandas as pd
import numpy as np
from scipy import stats
from scipy.stats import shapiro

import matplotlib.pyplot as plt
from IPython.display import display

import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.regression.mixed_linear_model import MixedLM
from statsmodels.tools.sm_exceptions import PerfectSeparationError
from statsmodels.stats.outliers_influence import variance_inflation_factor

from sklearn.linear_model import Lasso, Ridge
from sklearn.preprocessing import FunctionTransformer
from patsy import dmatrices

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import permutation_importance
from sklearn.metrics import classification_report, precision_recall_curve
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis

from xgboost import XGBClassifier

import matplotlib
matplotlib.use('Agg')

import warnings
warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")

def calculate_p_values(X, y, model):
    # Predict y values
    y_pred = model.predict(X)
    # Calculate residuals
    residuals = y - y_pred
    # Calculate the standard error of the residuals
    dof = X.shape[0] - X.shape[1] - 1
    residual_std_error = np.sqrt(np.sum(residuals ** 2) / dof)
    # Calculate the standard error of the coefficients
    X_design = np.hstack((np.ones((X.shape[0], 1)), X))  # Add intercept
    # Use pseudoinverse instead of inverse to handle singular matrices
    coef_var_covar = residual_std_error ** 2 * np.linalg.pinv(X_design.T @ X_design)
    coef_standard_errors = np.sqrt(np.diag(coef_var_covar))
    # Calculate t-statistics
    t_stats = model.coef_ / coef_standard_errors[1:]  # Skip intercept error
    # Calculate p-values
    p_values = [2 * (1 - stats.t.cdf(np.abs(t), dof)) for t in t_stats]
    return np.array(p_values)  # Ensure p_values is a 1-dimensional array
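
Note — illustrative usage, not part of the package diff: a minimal sketch of how calculate_p_values could be applied to a fitted scikit-learn Ridge model; the synthetic data, column names, and alpha value are assumptions.

# Hypothetical usage sketch (not part of spacr/ml.py)
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from spacr.ml import calculate_p_values

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(100, 3)), columns=['f1', 'f2', 'f3'])
y = 2.0 * X['f1'] - 1.0 * X['f2'] + rng.normal(scale=0.5, size=100)

model = Ridge(alpha=1.0).fit(X, y)
p_values = calculate_p_values(X, y, model)  # one approximate p-value per coefficient
print(dict(zip(X.columns, p_values)))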

def perform_mixed_model(y, X, groups, alpha=1.0):
    # Ensure groups are defined correctly and check for multicollinearity
    if groups is None:
        raise ValueError("Groups must be defined for mixed model regression")

    # Check for multicollinearity by calculating the VIF for each feature
    X_np = X.values
    vif = [variance_inflation_factor(X_np, i) for i in range(X_np.shape[1])]
    print(f"VIF: {vif}")
    if any(v > 10 for v in vif):
        print(f"Multicollinearity detected with VIF: {vif}. Applying Ridge regression to the fixed effects.")
        ridge = Ridge(alpha=alpha)
        ridge.fit(X, y)
        X_ridge = ridge.coef_ * X  # Adjust X with Ridge coefficients
        model = MixedLM(y, X_ridge, groups=groups)
    else:
        model = MixedLM(y, X, groups=groups)

    result = model.fit()
    return result

def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):

    def plot_regression_line(X, y, model):
        """Helper to plot regression line for lasso and ridge models."""
        y_pred = model.predict(X)
        plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
        plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
        plt.xlabel('Features')
        plt.ylabel('Dependent Variable')
        plt.legend()
        plt.show()

    # Define the dictionary with callables (lambdas) to delay evaluation
    model_map = {
        'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
        'gls': lambda: sm.GLS(y, X).fit(),
        'wls': lambda: sm.WLS(y, X, weights=1 / np.sqrt(X.iloc[:, 1])).fit(),
        'rlm': lambda: sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit(),
        'glm': lambda: sm.GLM(y, X, family=sm.families.Gaussian()).fit(),
        'quantile': lambda: sm.QuantReg(y, X).fit(q=alpha),
        'logit': lambda: sm.Logit(y, X).fit(),
        'probit': lambda: sm.Probit(y, X).fit(),
        'poisson': lambda: sm.Poisson(y, X).fit(),
        'lasso': lambda: Lasso(alpha=alpha).fit(X, y),
        'ridge': lambda: Ridge(alpha=alpha).fit(X, y)
    }

    # Call the appropriate model only when needed
    if regression_type in model_map:
        model = model_map[regression_type]()
    elif regression_type == 'mixed':
        model = perform_mixed_model(y, X, groups, alpha=alpha)
    else:
        raise ValueError(f"Unsupported regression type {regression_type}")

    if regression_type in ['lasso', 'ridge']:
        plot_regression_line(X, y, model)

    return model
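
Note — illustrative usage, not part of the package diff: a sketch of calling regression_model directly on a design matrix with an intercept column; the synthetic data and column names are assumptions. The 'wls' weights and the lasso/ridge plot helper index X.iloc[:, 1], so X is expected to carry the intercept in its first column.

# Hypothetical usage sketch (not part of spacr/ml.py)
import numpy as np
import pandas as pd
import statsmodels.api as sm
from spacr.ml import regression_model

rng = np.random.default_rng(1)
x = rng.uniform(1, 10, size=200)
y = pd.Series(3.0 + 0.8 * x + rng.normal(scale=1.0, size=200), name='response')
X = sm.add_constant(pd.DataFrame({'x': x}))  # intercept in column 0, feature in column 1

ols_fit = regression_model(X, y, regression_type='ols')
print(ols_fit.params)

robust_fit = regression_model(X, y, regression_type='ols', cov_type='HC3')  # robust SEs
print(robust_fit.bse)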

def create_volcano_filename(csv_path, regression_type, alpha, dst):
    """Create and return the volcano plot filename based on regression type and alpha."""
    volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
    volcano_filename = f"{regression_type}_{volcano_filename}" if regression_type != 'quantile' else f"{alpha}_{volcano_filename}"

    if dst:
        return os.path.join(dst, volcano_filename)
    return os.path.join(os.path.dirname(csv_path), volcano_filename)

def scale_variables(X, y):
    """Scale independent (X) and dependent (y) variables using MinMaxScaler."""
    scaler_X = MinMaxScaler()
    scaler_y = MinMaxScaler()

    X_scaled = pd.DataFrame(scaler_X.fit_transform(X), columns=X.columns)
    y_scaled = scaler_y.fit_transform(y)

    return X_scaled, y_scaled

def process_model_coefficients(model, regression_type, X, y, highlight):
    """Return DataFrame of model coefficients and p-values."""
    if regression_type in ['ols', 'gls', 'wls', 'rlm', 'glm', 'mixed', 'quantile', 'logit', 'probit', 'poisson']:
        coefs = model.params
        p_values = model.pvalues

        coef_df = pd.DataFrame({
            'feature': coefs.index,
            'coefficient': coefs.values,
            'p_value': p_values.values
        })

    elif regression_type in ['ridge', 'lasso']:
        coefs = model.coef_.flatten()
        p_values = calculate_p_values(X, y, model)

        coef_df = pd.DataFrame({
            'feature': X.columns,
            'coefficient': coefs,
            'p_value': p_values
        })

    else:
        coefs = model.coef_
        intercept = model.intercept_
        feature_names = X.design_info.column_names

        coef_df = pd.DataFrame({
            'feature': feature_names,
            'coefficient': coefs
        })
        coef_df.loc[0, 'coefficient'] += intercept
        coef_df['p_value'] = np.nan  # Placeholder since sklearn doesn't provide p-values

    coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
    coef_df['highlight'] = coef_df['feature'].apply(lambda x: highlight in x)

    return coef_df[~coef_df['feature'].str.contains('row|column')]

def prepare_formula(dependent_variable, random_row_column_effects=False):
    """Return the regression formula using random effects for plate, row, and column."""
    if random_row_column_effects:
        # Random effects for row and column + gene weighted by gene_fraction + grna weighted by fraction
        return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene'
    return f'{dependent_variable} ~ fraction:grna + gene_fraction:gene + row + column'

def fit_mixed_model(df, formula, dst):
    from .plot import plot_histogram

    """Fit the mixed model with plate, row, and column as random effects and return results."""
    # Specify random effects for plate, row, and column
    model = smf.mixedlm(formula,
                        data=df,
                        groups=df['plate'],
                        re_formula="1 + row + column",
                        vc_formula={"row": "0 + row", "column": "0 + column"})

    mixed_model = model.fit()

    # Plot residuals
    df['residuals'] = mixed_model.resid
    plot_histogram(df, 'residuals', dst=dst)

    # Return coefficients and p-values
    coefs = mixed_model.params
    p_values = mixed_model.pvalues

    coef_df = pd.DataFrame({
        'feature': coefs.index,
        'coefficient': coefs.values,
        'p_value': p_values.values
    })

    return mixed_model, coef_df

def check_and_clean_data(df, dependent_variable):
    """Check for collinearity, missing values, or invalid types in relevant columns. Clean data accordingly."""

    def handle_missing_values(df, columns):
        """Handle missing values in specified columns."""
        missing_summary = df[columns].isnull().sum()
        print("Missing values summary:")
        print(missing_summary)

        # Drop rows with missing values in these fields
        df_cleaned = df.dropna(subset=columns)
        if df_cleaned.shape[0] < df.shape[0]:
            print(f"Dropped {df.shape[0] - df_cleaned.shape[0]} rows with missing values in {columns}.")
        return df_cleaned

    def ensure_valid_types(df, columns):
        """Ensure that specified columns are categorical."""
        for col in columns:
            if not pd.api.types.is_categorical_dtype(df[col]):
                df[col] = pd.Categorical(df[col])
                print(f"Converted {col} to categorical type.")
        return df

    def check_collinearity(df, columns):
        """Check for collinearity using VIF (Variance Inflation Factor)."""
        print("Checking for collinearity...")

        # Only include fraction and the dependent variable for collinearity check
        df_encoded = df[columns]

        # Ensure all data in df_encoded is numeric
        df_encoded = df_encoded.apply(pd.to_numeric, errors='coerce')

        # Check for perfect multicollinearity (i.e., rank deficiency)
        if np.linalg.matrix_rank(df_encoded.values) < df_encoded.shape[1]:
            print("Warning: Perfect multicollinearity detected! Dropping correlated columns.")
            df_encoded = df_encoded.loc[:, ~df_encoded.columns.duplicated()]

        # Calculate VIF for each feature
        vif_data = pd.DataFrame()
        vif_data["Feature"] = df_encoded.columns
        try:
            vif_data["VIF"] = [variance_inflation_factor(df_encoded.values, i) for i in range(df_encoded.shape[1])]
        except np.linalg.LinAlgError:
            print("LinAlgError: Unable to compute VIF due to matrix singularity.")
            return df_encoded

        print("Variance Inflation Factor (VIF) for each feature:")
        print(vif_data)

        # Drop columns with VIF > 10 (a common threshold to identify multicollinearity)
        high_vif_columns = vif_data[vif_data["VIF"] > 10]["Feature"].tolist()
        if high_vif_columns:
            print(f"Dropping columns with high VIF: {high_vif_columns}")
            df_encoded.drop(columns=high_vif_columns, inplace=True)

        return df_encoded

    # Step 1: Handle missing values in relevant fields
    df = handle_missing_values(df, ['fraction', dependent_variable])

    # Step 2: Ensure grna, gene, plate, row, column, and prc are categorical types
    df = ensure_valid_types(df, ['grna', 'gene', 'plate', 'row', 'column', 'prc'])

    # Step 3: Check for multicollinearity in fraction and the dependent variable
    df_cleaned = check_collinearity(df, ['fraction', dependent_variable])

    # Ensure that the prc, plate, row, and column columns are still included for random effects
    df_cleaned['gene'] = df['gene']
    df_cleaned['grna'] = df['grna']
    df_cleaned['prc'] = df['prc']
    df_cleaned['plate'] = df['plate']
    df_cleaned['row'] = df['row']
    df_cleaned['column'] = df['column']

    #display(df_cleaned)

    # Create a new column 'gene_fraction' that sums the fractions by gene within the same well
    df_cleaned['gene_fraction'] = df_cleaned.groupby(['prc', 'gene'])['fraction'].transform('sum')

    print("Data is ready for model fitting.")
    return df_cleaned

def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, random_row_column_effects=False, highlight='220950', dst=None, cov_type=None):
    from .plot import volcano_plot, plot_histogram

    # Generate the volcano filename
    volcano_path = create_volcano_filename(csv_path, regression_type, alpha, dst)

    # Check if the data is normally distributed
    is_normal = check_normality(df[dependent_variable], dependent_variable)

    if is_normal:
        print(f"To avoid violating assumptions, it is recommended to use a regression model that assumes normality.")
        print(f"Recommended regression type: ols (Ordinary Least Squares)")
    else:
        print(f"To avoid violating assumptions, it is recommended to use a regression model that does not assume normality.")
        print(f"Recommended regression type: glm (Generalized Linear Model)")

    # Determine regression type if not specified
    if regression_type is None:
        regression_type = 'ols' if is_normal else 'glm'

    #display('before check_and_clean_data:',df)
    df = check_and_clean_data(df, dependent_variable)
    #display('after check_and_clean_data:',df)

    # Handle mixed effects if row/column effect is treated as random
    if random_row_column_effects:
        regression_type = 'mixed'
        formula = prepare_formula(dependent_variable, random_row_column_effects=True)
        mixed_model, coef_df = fit_mixed_model(df, formula, dst)
        model = mixed_model
    else:
        # Regular regression models
        formula = prepare_formula(dependent_variable, random_row_column_effects=False)
        y, X = dmatrices(formula, data=df, return_type='dataframe')

        # Plot histogram of the dependent variable
        plot_histogram(y, dependent_variable, dst=dst)

        # Scale the independent variables and dependent variable
        X, y = scale_variables(X, y)

        # Perform the regression
        groups = df['prc'] if regression_type == 'mixed' else None
        print(f'performing {regression_type} regression')

        model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, cov_type=cov_type)

        # Process the model coefficients
        coef_df = process_model_coefficients(model, regression_type, X, y, highlight)

    # Plot the volcano plot
    volcano_plot(coef_df, volcano_path)

    return model, coef_df

def perform_regression(settings):

    from .plot import plot_plates
    from .utils import merge_regression_res_with_metadata, save_settings
    from .settings import get_perform_regression_default_settings
    from .toxo import go_term_enrichment_by_column, custom_volcano_plot

    if isinstance(settings['score_data'], list) and isinstance(settings['count_data'], list):
        settings['plate'] = None
        if len(settings['score_data']) == 1:
            settings['score_data'] = settings['score_data'][0]
        if len(settings['count_data']) == 1:
            settings['count_data'] = settings['count_data'][0]
        else:
            count_data_df = pd.DataFrame()
            for i, count_data in enumerate(settings['count_data']):
                df = pd.read_csv(count_data)
                df['plate_name'] = f'plate{i+1}'
                count_data_df = pd.concat([count_data_df, df])
            print('Count data:', len(count_data_df))

            score_data_df = pd.DataFrame()
            for i, score_data in enumerate(settings['score_data']):
                df = pd.read_csv(score_data)
                df['plate_name'] = f'plate{i+1}'
                score_data_df = pd.concat([score_data_df, df])
            print('Score data:', len(score_data_df))
    else:
        count_data_df = pd.read_csv(settings['count_data'])
        score_data_df = pd.read_csv(settings['score_data'])

    reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
    if settings['regression_type'] not in reg_types:
        print(f'Possible regression types: {reg_types}')
        raise ValueError(f"Unsupported regression type {settings['regression_type']}")

    if settings['dependent_variable'] not in score_data_df.columns:
        print(f'Columns in DataFrame:')
        for col in score_data_df.columns:
            print(col)
        raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")

    if isinstance(settings['count_data'], list):
        src = os.path.dirname(settings['count_data'][0])
        csv_path = settings['count_data'][0]
    else:
        src = os.path.dirname(settings['count_data'])
        csv_path = settings['count_data']

    settings['src'] = src
    fldr = 'results_' + settings['regression_type']
    if isinstance(settings['count_data'], list):
        fldr = fldr + '_list'

    if settings['regression_type'] == 'quantile':
        fldr = fldr + '_' + str(settings['alpha'])

    res_folder = os.path.join(src, fldr)
    os.makedirs(res_folder, exist_ok=True)
    results_filename = 'results.csv'
    hits_filename = 'results_significant.csv'
    results_path = os.path.join(res_folder, results_filename)
    hits_path = os.path.join(res_folder, hits_filename)

    settings = get_perform_regression_default_settings(settings)
    save_settings(settings, name='regression', show=True)

    score_data_df = clean_controls(score_data_df, settings['pc'], settings['nc'], settings['other'])

    if 'prediction_probability_class_1' in score_data_df.columns:
        if not settings['class_1_threshold'] is None:
            score_data_df['predictions'] = (score_data_df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)

    dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])

    independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plate'])

    merged_df = pd.merge(independent_df, dependent_df, on='prc')

    data_path = os.path.join(res_folder, 'regression_data.csv')
    merged_df.to_csv(data_path, index=False)

    merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)

    if settings['transform'] is None:
        _ = plot_plates(score_data_df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'], dst=res_folder)

    model, coef_df = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], highlight=settings['highlight'], dst=res_folder, cov_type=settings['cov_type'])

    coef_df.to_csv(results_path, index=False)

    if settings['regression_type'] == 'lasso':
        significant = coef_df[coef_df['coefficient'] > 0]

    else:
        significant = coef_df[coef_df['p_value'] <= 0.05]
        #significant = significant[significant['coefficient'] > 0.1]
        significant.sort_values(by='coefficient', ascending=False, inplace=True)
        significant = significant[~significant['feature'].str.contains('row|column')]

    if settings['regression_type'] == 'ols':
        print(model.summary())

    significant.to_csv(hits_path, index=False)

    if isinstance(settings['metadata_files'], str):
        settings['metadata_files'] = [settings['metadata_files']]

    for metadata_file in settings['metadata_files']:
        file = os.path.basename(metadata_file)
        filename, _ = os.path.splitext(file)
        _ = merge_regression_res_with_metadata(hits_path, metadata_file, name=filename)
        merged_df = merge_regression_res_with_metadata(results_path, metadata_file, name=filename)

    if settings['toxo']:

        data_path = merged_df
        base_dir = os.path.dirname(os.path.abspath(__file__))
        metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')

        custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', string_list=[settings['highlight']], point_size=50, figsize=20)

        metadata_path = os.path.join(base_dir, 'resources', 'data', 'toxoplasma_metadata.csv')

        go_term_enrichment_by_column(significant, metadata_path)

    print('Significant Genes')
    display(significant)

    output = {'results': coef_df,
              'significant': significant}

    return output
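
Note — illustrative settings sketch, not part of the package diff: the keys below are ones perform_regression reads directly; the paths and values are placeholders, and get_perform_regression_default_settings is expected to fill the remaining keys (plate, controls, thresholds, aggregation, etc.) with defaults.

# Hypothetical settings for perform_regression (paths and values are placeholders)
settings = {
    'count_data': '/path/to/sequencing_counts.csv',   # per-well gRNA counts
    'score_data': '/path/to/well_scores.csv',         # per-cell or per-well scores
    'dependent_variable': 'predictions',
    'regression_type': 'ols',
    'alpha': 1.0,
    'metadata_files': ['/path/to/gene_metadata.csv'],
    'toxo': False,
}
# output = perform_regression(settings)  # returns {'results': ..., 'significant': ...}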

def process_reads(csv_path, fraction_threshold, plate):

    if isinstance(csv_path, pd.DataFrame):
        csv_df = csv_path
    else:
        # Read the CSV file into a DataFrame
        csv_df = pd.read_csv(csv_path)

    if 'plate_name' in csv_df.columns:
        csv_df = csv_df.rename(columns={'plate_name': 'plate'})
    if 'column_name' in csv_df.columns:
        csv_df = csv_df.rename(columns={'column_name': 'column'})
    if 'row_name' in csv_df.columns:
        csv_df = csv_df.rename(columns={'row_name': 'row'})
    if 'grna_name' in csv_df.columns:
        csv_df = csv_df.rename(columns={'grna_name': 'grna'})
    if 'plate_row' in csv_df.columns:
        csv_df[['plate', 'row']] = csv_df['plate_row'].str.split('_', expand=True)
    if not 'plate' in csv_df.columns:
        if not plate is None:
            csv_df['plate'] = plate
        else:
            csv_df['plate'] = 'plate1'

    # Ensure the necessary columns are present
    if not all(col in csv_df.columns for col in ['row', 'column', 'grna', 'count']):
        raise ValueError("The CSV file must contain 'grna', 'count', 'row', and 'column' columns.")

    # Create the prc column
    csv_df['prc'] = csv_df['plate'] + '_' + csv_df['row'] + '_' + csv_df['column']

    # Group by prc and calculate the sum of counts
    grouped_df = csv_df.groupby('prc')['count'].sum().reset_index()
    grouped_df = grouped_df.rename(columns={'count': 'total_counts'})
    merged_df = pd.merge(csv_df, grouped_df, on='prc')
    merged_df['fraction'] = merged_df['count'] / merged_df['total_counts']

    # Filter rows with fraction under the threshold
    if fraction_threshold is not None:
        observations_before = len(merged_df)
        merged_df = merged_df[merged_df['fraction'] >= fraction_threshold]
        observations_after = len(merged_df)
        removed = observations_before - observations_after
        print(f'Removed {removed} observations below fraction threshold: {fraction_threshold}')

    merged_df = merged_df[['prc', 'grna', 'fraction']]

    if not all(col in merged_df.columns for col in ['grna', 'gene']):
        try:
            merged_df[['org', 'gene', 'grna']] = merged_df['grna'].str.split('_', expand=True)
            merged_df = merged_df.drop(columns=['org'])
            merged_df['grna'] = merged_df['gene'] + '_' + merged_df['grna']
        except:
            print('Error splitting grna into org, gene, grna.')

    return merged_df
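
Note — illustrative usage, not part of the package diff: process_reads expects 'row', 'column', 'grna', and 'count' columns (plus an optional plate); the toy counts and gRNA naming below are assumptions.

# Hypothetical usage sketch (not part of spacr/ml.py): per-well gRNA fractions
import pandas as pd
from spacr.ml import process_reads

counts = pd.DataFrame({
    'row':    ['r1', 'r1', 'r1', 'r2'],
    'column': ['c1', 'c1', 'c2', 'c1'],
    'grna':   ['TGGT1_220950_g1', 'TGGT1_233460_g2', 'TGGT1_220950_g1', 'TGGT1_233460_g2'],
    'count':  [90, 10, 40, 60],
})
fractions = process_reads(counts, fraction_threshold=0.05, plate='plate1')
print(fractions)  # columns: prc, grna, fraction, gene (after the org_gene_grna split)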

def apply_transformation(X, transform):
    if transform == 'log':
        transformer = FunctionTransformer(np.log1p, validate=True)
    elif transform == 'sqrt':
        transformer = FunctionTransformer(np.sqrt, validate=True)
    elif transform == 'square':
        transformer = FunctionTransformer(np.square, validate=True)
    else:
        transformer = None
    return transformer
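
Note — illustrative usage, not part of the package diff: apply_transformation returns a scikit-learn FunctionTransformer (or None); the toy column below is an assumption.

# Hypothetical usage sketch (not part of spacr/ml.py): log-transform a skewed column
import pandas as pd
from spacr.ml import apply_transformation

df = pd.DataFrame({'score': [0.1, 1.0, 10.0, 100.0]})
transformer = apply_transformation(df['score'], transform='log')  # np.log1p under the hood
df['log_score'] = transformer.fit_transform(df[['score']]).ravel()
print(df)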

def check_normality(data, variable_name, verbose=False):
    """Check if the data is normally distributed using the Shapiro-Wilk test."""
    stat, p_value = shapiro(data)
    if verbose:
        print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {stat}, P-value: {p_value}")
    if p_value > 0.05:
        if verbose:
            print(f"Normal distribution: The data for {variable_name} is normally distributed.")
        return True
    else:
        if verbose:
            print(f"Normal distribution: The data for {variable_name} is not normally distributed.")
        return False
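
Note — illustrative usage, not part of the package diff: check_normality wraps scipy's Shapiro-Wilk test; the samples below are assumptions.

# Hypothetical usage sketch (not part of spacr/ml.py): Shapiro-Wilk normality check
import numpy as np
from spacr.ml import check_normality

rng = np.random.default_rng(2)
normal_sample = rng.normal(loc=0.0, scale=1.0, size=200)
skewed_sample = rng.exponential(scale=1.0, size=200)

print(check_normality(normal_sample, 'normal_sample', verbose=True))   # usually True
print(check_normality(skewed_sample, 'skewed_sample', verbose=True))   # usually False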

def clean_controls(df, pc, nc, other):
    if 'col' in df.columns:
        df['column'] = df['col']
    if nc != None:
        df = df[~df['column'].isin([nc])]
    if pc != None:
        df = df[~df['column'].isin([pc])]
    if other != None:
        df = df[~df['column'].isin([other])]
    print(f'Removed data from {nc, pc, other}')
    return df

def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):

    if 'plate_name' in df.columns:
        df.drop(columns=['plate'], inplace=True)
        df = df.rename(columns={'plate_name': 'plate'})

    if plate is not None:
        df['plate'] = plate

    if 'col' not in df.columns:
        df['col'] = df['column']

    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
    df = df[['prc', dependent_variable]]

    # Group by prc and calculate the mean and count of the dependent_variable
    grouped = df.groupby('prc')[dependent_variable]

    if regression_type != 'poisson':

        print(f'Using agg_type: {agg_type}')

        if agg_type == 'median':
            dependent_df = grouped.median().reset_index()
        elif agg_type == 'mean':
            dependent_df = grouped.mean().reset_index()
        elif agg_type == 'quantile':
            dependent_df = grouped.quantile(0.75).reset_index()
        elif agg_type == None:
            dependent_df = df.reset_index()
            if 'prcfo' in dependent_df.columns:
                dependent_df = dependent_df.drop(columns=['prcfo'])
        else:
            raise ValueError(f"Unsupported aggregation type {agg_type}")

    if regression_type == 'poisson':
        agg_type = 'count'
        print(f'Using agg_type: {agg_type} for poisson regression')
        dependent_df = grouped.sum().reset_index()

    # Calculate cell_count for all cases
    cell_count = grouped.size().reset_index(name='cell_count')

    if agg_type is None:
        dependent_df = pd.merge(dependent_df, cell_count, on='prc')
    else:
        dependent_df['cell_count'] = cell_count['cell_count']

    dependent_df = dependent_df[dependent_df['cell_count'] >= min_cell_count]

    is_normal = check_normality(dependent_df[dependent_variable], dependent_variable)

    if not transform is None:
        transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
        transformed_var = f'{transform}_{dependent_variable}'
        dependent_df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
        dependent_variable = transformed_var
        is_normal = check_normality(dependent_df[transformed_var], transformed_var)

    if not is_normal:
        print(f'{dependent_variable} is not normally distributed')
    else:
        print(f'{dependent_variable} is normally distributed')

    return dependent_df, dependent_variable
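
Note — illustrative usage, not part of the package diff: process_scores expects per-object scores with 'plate'/'row'/'column' (or 'col') metadata; the toy table and min_cell_count below are assumptions.

# Hypothetical usage sketch (not part of spacr/ml.py): aggregate per-cell scores to wells
import numpy as np
import pandas as pd
from spacr.ml import process_scores

rng = np.random.default_rng(3)
cells = pd.DataFrame({
    'plate':  ['plate1'] * 90,
    'row':    ['r1'] * 30 + ['r2'] * 30 + ['r3'] * 30,
    'column': ['c1'] * 90,
    'predictions': rng.uniform(0, 1, size=90),
})
well_scores, dep_var = process_scores(cells, 'predictions', plate=None, min_cell_count=25, agg_type='mean')
print(well_scores)  # one row per prc with the mean score and cell_count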

def generate_ml_scores(settings):

    from .io import _read_and_merge_data
    from .plot import plot_plates
    from .utils import get_ml_results_paths
    from .settings import set_default_analyze_screen

    settings = set_default_analyze_screen(settings)

    src = settings['src']

    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    display(settings_df)

    db_loc = [src+'/measurements/measurements.db']
    tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

    nuclei_limit, pathogen_limit, uninfected = settings['nuclei_limit'], settings['pathogen_limit'], settings['uninfected']

    df, _ = _read_and_merge_data(db_loc,
                                 tables,
                                 settings['verbose'],
                                 nuclei_limit,
                                 pathogen_limit,
                                 uninfected)

    if settings['channel_of_interest'] in [0, 1, 2, 3]:

        df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]

    output, figs = ml_analysis(df,
                               settings['channel_of_interest'],
                               settings['location_column'],
                               settings['positive_control'],
                               settings['negative_control'],
                               settings['exclude'],
                               settings['n_repeats'],
                               settings['top_features'],
                               settings['n_estimators'],
                               settings['test_size'],
                               settings['model_type_ml'],
                               settings['n_jobs'],
                               settings['remove_low_variance_features'],
                               settings['remove_highly_correlated_features'],
                               settings['verbose'])

    shap_fig = shap_analysis(output[3], output[4], output[5])

    features = output[0].select_dtypes(include=[np.number]).columns.tolist()

    if not settings['heatmap_feature'] in features:
        raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")

    plate_heatmap = plot_plates(df=output[0],
                                variable=settings['heatmap_feature'],
                                grouping=settings['grouping'],
                                min_max=settings['min_max'],
                                cmap=settings['cmap'],
                                min_count=settings['minimum_cell_count'],
                                verbose=settings['verbose'])

    data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
    df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output

    settings_df.to_csv(settings_csv, index=False)
    df.to_csv(data_path, mode='w', encoding='utf-8')
    permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
    feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
    metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')

    plate_heatmap.savefig(plate_heatmap_path, format='pdf')
    figs[0].savefig(permutation_fig_path, format='pdf')
    figs[1].savefig(feature_importance_fig_path, format='pdf')
    shap_fig.savefig(shap_fig_path, format='pdf')

    return [output, plate_heatmap]

def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
    """
    Calculates permutation importance for numerical features in the dataframe,
    comparing groups based on specified column values and uses the model to predict
    the class for all other rows in the dataframe.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        feature_string (str): String to filter features that contain this substring.
        location_column (str): Column name to use for comparing groups.
        positive_control, negative_control (str): Values in location_column to create subsets for comparison.
        exclude (list or str, optional): Columns to exclude from features.
        n_repeats (int): Number of repeats for permutation importance.
        top_features (int): Number of top features to plot based on permutation importance.
        n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
        test_size (float): Proportion of the dataset to include in the test split.
        random_state (int): Random seed for reproducibility.
        model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
        n_jobs (int): Number of jobs to run in parallel for applicable models.

    Returns:
        pandas.DataFrame: The original dataframe with added prediction and data usage columns.
        pandas.DataFrame: DataFrame containing the importances and standard deviations.
    """

    from .utils import filter_dataframe_features
    from .plot import plot_permutation, plot_feature_importance

    random_state = 42

    if 'cells_per_well' in df.columns:
        df = df.drop(columns=['cells_per_well'])

    df_metadata = df[[location_column]].copy()

    df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
    print('After filtration:', len(df))

    if verbose:
        print(f'Found {len(features)} numerical features in the dataframe')
        print(f'Features used in training: {features}')
    df = pd.concat([df, df_metadata[location_column]], axis=1)

    # Subset the dataframe based on specified column values
    df1 = df[df[location_column] == negative_control].copy()
    df2 = df[df[location_column] == positive_control].copy()

    # Create target variable
    df1['target'] = 0  # Negative control
    df2['target'] = 1  # Positive control

    # Combine the subsets for analysis
    combined_df = pd.concat([df1, df2])
    combined_df = combined_df.drop(columns=[location_column])
    if verbose:
        print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')

    X = combined_df[features]
    y = combined_df['target']

    print(X)
    print(y)

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Add data usage labels
    combined_df['data_usage'] = 'train'
    combined_df.loc[X_test.index, 'data_usage'] = 'test'
    df['data_usage'] = 'not_used'
    df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']

    # Initialize the model based on model_type
    if model_type == 'random_forest':
        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'logistic_regression':
        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'gradient_boosting':
        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state)  # Supports n_jobs internally
    elif model_type == 'xgboost':
        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    model.fit(X_train, y_train)

    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)

    # Create a DataFrame for permutation importances
    permutation_df = pd.DataFrame({
        'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
        'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
        'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
    }).tail(top_features)

    permutation_fig = plot_permutation(permutation_df)
    if verbose:
        permutation_fig.show()

    # Feature importance for models that support it
    if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
        feature_importances = model.feature_importances_
        feature_importance_df = pd.DataFrame({
            'feature': features,
            'importance': feature_importances
        }).sort_values(by='importance', ascending=False).head(top_features)

        feature_importance_fig = plot_feature_importance(feature_importance_df)
        if verbose:
            feature_importance_fig.show()

    else:
        feature_importance_df = pd.DataFrame()

    # Predicting the target variable for the test set
    predictions_test = model.predict(X_test)
    combined_df.loc[X_test.index, 'predictions'] = predictions_test

    # Get prediction probabilities for the test set
    prediction_probabilities_test = model.predict_proba(X_test)

    # Find the optimal threshold
    optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
    if verbose:
        print(f'Optimal threshold: {optimal_threshold}')

    # Predicting the target variable for all other rows in the dataframe
    X_all = df[features]
    all_predictions = model.predict(X_all)
    df['predictions'] = all_predictions

    # Get prediction probabilities for all rows in the dataframe
    prediction_probabilities = model.predict_proba(X_all)
    for i in range(prediction_probabilities.shape[1]):
        df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
    if verbose:
        print("\nClassification Report:")
        print(classification_report(y_test, predictions_test))
    report_dict = classification_report(y_test, predictions_test, output_dict=True)
    metrics_df = pd.DataFrame(report_dict).transpose()

    df = _calculate_similarity(df, features, location_column, positive_control, negative_control)

    df['prcfo'] = df.index.astype(str)
    df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']

    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]

def shap_analysis(model, X_train, X_test):

    """
    Performs SHAP analysis on the given model and data.

    Args:
        model: The trained model.
        X_train (pandas.DataFrame): Training feature set.
        X_test (pandas.DataFrame): Testing feature set.
    Returns:
        fig: Matplotlib figure object containing the SHAP summary plot.
    """

    explainer = shap.Explainer(model, X_train)
    shap_values = explainer(X_test)
    # Create a new figure
    fig, ax = plt.subplots()
    # Summary plot
    shap.summary_plot(shap_values, X_test, show=False)
    # Save the current figure (the one that SHAP just created)
    fig = plt.gcf()
    plt.close(fig)  # Close the figure to prevent it from displaying immediately
    return fig

def find_optimal_threshold(y_true, y_pred_proba):
    """
    Find the optimal threshold for binary classification based on the F1-score.

    Args:
        y_true (array-like): True binary labels.
        y_pred_proba (array-like): Predicted probabilities for the positive class.

    Returns:
        float: The optimal threshold.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
    f1_scores = 2 * (precision * recall) / (precision + recall)
    optimal_idx = np.argmax(f1_scores)
    optimal_threshold = thresholds[optimal_idx]
    return optimal_threshold
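
Note — illustrative usage, not part of the package diff: find_optimal_threshold picks the probability cut-off that maximizes F1; the toy labels and probabilities below are assumptions.

# Hypothetical usage sketch (not part of spacr/ml.py): F1-optimal probability threshold
import numpy as np
from spacr.ml import find_optimal_threshold

y_true = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1, 0])
y_proba = np.array([0.1, 0.3, 0.45, 0.55, 0.6, 0.8, 0.9, 0.2, 0.7, 0.35])
print(find_optimal_threshold(y_true, y_proba))  # 0.55 for this perfectly separable toy data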

def _calculate_similarity(df, features, col_to_compare, val1, val2):
    """
    Calculate similarity scores of each well to the positive and negative controls using various metrics.

    Args:
        df (pandas.DataFrame): DataFrame containing the data.
        features (list): List of feature columns to use for similarity calculation.
        col_to_compare (str): Column name to use for comparing groups.
        val1, val2 (str): Values in col_to_compare to create subsets for comparison.

    Returns:
        pandas.DataFrame: DataFrame with similarity scores.
    """
    # Separate positive and negative control wells
    pos_control = df[df[col_to_compare] == val1][features].mean()
    neg_control = df[df[col_to_compare] == val2][features].mean()

    # Standardize features for Mahalanobis distance
    scaler = StandardScaler()
    scaled_features = scaler.fit_transform(df[features])

    # Regularize the covariance matrix to avoid singularity
    cov_matrix = np.cov(scaled_features, rowvar=False)
    inv_cov_matrix = None
    try:
        inv_cov_matrix = np.linalg.inv(cov_matrix)
    except np.linalg.LinAlgError:
        # Add a small value to the diagonal elements for regularization
        epsilon = 1e-5
        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)

    # Calculate similarity scores
    df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
    df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
    df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
    df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
    df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
    df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
    df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
    df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
    df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
    df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
    df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
    df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
    df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
    df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
    df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
    df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
    df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
    df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)

    return df
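
Note — illustrative usage, not part of the package diff: _calculate_similarity appends per-row distance columns toward the positive- and negative-control means; the toy wells, feature names, and control labels below are assumptions.

# Hypothetical usage sketch (not part of spacr/ml.py): similarity to control wells
import numpy as np
import pandas as pd
from spacr.ml import _calculate_similarity

rng = np.random.default_rng(4)
toy = pd.DataFrame({
    'col': ['c1'] * 10 + ['c2'] * 10 + ['c3'] * 10,
    'feat_a': rng.normal(size=30),
    'feat_b': rng.normal(size=30),
})
scored = _calculate_similarity(toy, ['feat_a', 'feat_b'], 'col', val1='c2', val2='c1')
print([c for c in scored.columns if c.startswith('similarity_to_')])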