spacr 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +140 -2493
  4. spacr/deep_spacr.py +151 -29
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +624 -44
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +964 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +280 -15
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +129 -43
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +348 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +233 -0
  26. spacr/utils.py +271 -171
  27. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
  28. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/ml.py ADDED
@@ -0,0 +1,964 @@
1
+ import os, shap
2
+ import pandas as pd
3
+ import numpy as np
4
+ from scipy import stats
5
+ from scipy.stats import shapiro
6
+
7
+ import matplotlib.pyplot as plt
8
+ from IPython.display import display
9
+
10
+ import statsmodels.api as sm
11
+ import statsmodels.formula.api as smf
12
+ from statsmodels.regression.mixed_linear_model import MixedLM
13
+ from statsmodels.tools.sm_exceptions import PerfectSeparationError
14
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
15
+
16
+ from sklearn.linear_model import Lasso, Ridge
17
+ from sklearn.preprocessing import FunctionTransformer
18
+ from patsy import dmatrices
19
+
20
+ from sklearn.model_selection import train_test_split
21
+ from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
22
+ from sklearn.linear_model import LogisticRegression
23
+ from sklearn.inspection import permutation_importance
24
+ from sklearn.metrics import classification_report, precision_recall_curve
25
+ from sklearn.preprocessing import StandardScaler
26
+ from sklearn.preprocessing import MinMaxScaler
27
+
28
+ from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
29
+
30
+ from xgboost import XGBClassifier
31
+
32
+ import matplotlib
33
+ matplotlib.use('Agg')
34
+
35
+ import warnings
36
+ warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
37
+
38
def calculate_p_values(X, y, model):
    """Approximate two-sided p-values for the coefficients of a fitted linear model.

    Applies the classical OLS formulas to the residuals of ``model``: residual
    variance times ``pinv(X'X)`` with an added intercept column. For penalised
    estimators (Ridge/Lasso) this is only an approximation, because their
    coefficients are shrunk.

    Args:
        X: design matrix of shape (n_samples, n_features), DataFrame or array.
        y: observed response, shape (n_samples,) or (n_samples, 1).
        model: fitted estimator exposing ``predict`` and ``coef_``.

    Returns:
        numpy.ndarray: 1-D array with one p-value per feature. All values are NaN
        when there are too few samples to estimate the residual variance
        (n_samples <= n_features + 1).
    """
    # Flatten both sides. Otherwise a (n, 1) response minus a (n,) prediction
    # broadcasts to an (n, n) matrix and silently corrupts the residual sum.
    y_pred = np.asarray(model.predict(X), dtype=float).ravel()
    y_obs = np.asarray(y, dtype=float).ravel()
    # coef_ is (1, n_features) for Ridge fitted on a 2-D y and (n_features,) otherwise.
    coefs = np.asarray(model.coef_, dtype=float).ravel()

    X_arr = np.asarray(X, dtype=float)
    n_samples, n_features = X_arr.shape
    dof = n_samples - n_features - 1
    if dof <= 0:
        # The residual variance is undefined, so no meaningful p-values exist.
        return np.full(coefs.shape, np.nan)

    residuals = y_obs - y_pred
    residual_var = np.sum(residuals ** 2) / dof

    X_design = np.hstack((np.ones((n_samples, 1)), X_arr))  # add intercept column
    # pinv instead of inv so singular designs (e.g. an all-zero scaled column) do not raise.
    coef_var_covar = residual_var * np.linalg.pinv(X_design.T @ X_design)
    coef_standard_errors = np.sqrt(np.diag(coef_var_covar))[1:]  # skip intercept

    with np.errstate(divide='ignore', invalid='ignore'):
        t_stats = coefs / coef_standard_errors
    # sf == 1 - cdf, but without catastrophic cancellation for large |t|.
    return 2 * stats.t.sf(np.abs(t_stats), dof)
56
+
57
def perform_mixed_model(y, X, groups, alpha=1.0):
    """Fit a statsmodels MixedLM of ``y`` on ``X`` with the given group labels.

    The VIF of every column of ``X`` is printed first. If any VIF exceeds 10,
    ``X`` is rescaled column-wise by the coefficients of a Ridge fit (penalty
    ``alpha``) before the mixed model is fitted.

    Args:
        y: response.
        X (pandas.DataFrame): fixed-effect design matrix.
        groups: group label per row; required.
        alpha (float): Ridge penalty used only when multicollinearity is found.

    Returns:
        The fitted MixedLM results object.

    Raises:
        ValueError: if ``groups`` is None.
    """
    if groups is None:
        raise ValueError("Groups must be defined for mixed model regression")

    design = X.values
    vif = [variance_inflation_factor(design, col) for col in range(design.shape[1])]
    print(f"VIF: {vif}")

    exog = X
    if any(value > 10 for value in vif):
        print(f"Multicollinearity detected with VIF: {vif}. Applying Ridge regression to the fixed effects.")
        ridge = Ridge(alpha=alpha).fit(X, y)
        # Scale each design column by its Ridge coefficient.
        exog = ridge.coef_ * X

    return MixedLM(y, exog, groups=groups).fit()
77
+
78
def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
    """Fit the regression selected by ``regression_type`` and return the fitted model.

    Args:
        X (pandas.DataFrame): design matrix. Its second positional column is used
            for the 'wls' weights and as the x-axis of the lasso/ridge plot.
        y: response.
        regression_type (str): one of 'ols', 'gls', 'wls', 'rlm', 'glm',
            'quantile', 'logit', 'probit', 'poisson', 'lasso', 'ridge', 'mixed'.
        groups: group labels, used only by 'mixed'.
        alpha (float): penalty for lasso/ridge/mixed, or the quantile ``q`` for 'quantile'.
        cov_type (str or None): covariance estimator passed to ``OLS.fit`` when truthy.

    Returns:
        A statsmodels results object, or a fitted sklearn estimator for lasso/ridge.

    Raises:
        ValueError: for an unsupported ``regression_type``.
    """

    def _show_fit(features, target, estimator):
        """Scatter the second design column against the response and overlay the predictions."""
        fitted = estimator.predict(features)
        x_axis = features.iloc[:, 1]
        plt.scatter(x_axis, target, color='blue', label='Data')
        plt.plot(x_axis, fitted, color='red', label='Regression line')
        plt.xlabel('Features')
        plt.ylabel('Dependent Variable')
        plt.legend()
        plt.show()

    def _fit_ols():
        ols = sm.OLS(y, X)
        return ols.fit(cov_type=cov_type) if cov_type else ols.fit()

    # Zero-argument callables, so only the requested model is ever built.
    fitters = {
        'ols': _fit_ols,
        'gls': lambda: sm.GLS(y, X).fit(),
        'wls': lambda: sm.WLS(y, X, weights=1 / np.sqrt(X.iloc[:, 1])).fit(),
        'rlm': lambda: sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit(),
        'glm': lambda: sm.GLM(y, X, family=sm.families.Gaussian()).fit(),
        'quantile': lambda: sm.QuantReg(y, X).fit(q=alpha),
        'logit': lambda: sm.Logit(y, X).fit(),
        'probit': lambda: sm.Probit(y, X).fit(),
        'poisson': lambda: sm.Poisson(y, X).fit(),
        'lasso': lambda: Lasso(alpha=alpha).fit(X, y),
        'ridge': lambda: Ridge(alpha=alpha).fit(X, y),
    }

    if regression_type in fitters:
        fitted_model = fitters[regression_type]()
    elif regression_type == 'mixed':
        fitted_model = perform_mixed_model(y, X, groups, alpha=alpha)
    else:
        raise ValueError(f"Unsupported regression type {regression_type}")

    if regression_type in ('lasso', 'ridge'):
        _show_fit(X, y, fitted_model)

    return fitted_model
117
+
118
def create_volcano_filename(csv_path, regression_type, alpha, dst):
    """Build the output path of the volcano-plot PDF.

    The file name is ``<prefix>_<csv stem>_volcano_plot.pdf``. The prefix is
    ``alpha`` for quantile regression and ``regression_type`` otherwise. The
    file goes into ``dst`` when it is truthy, else into the folder of ``csv_path``.
    """
    stem, _ = os.path.splitext(os.path.basename(csv_path))
    prefix = alpha if regression_type == 'quantile' else regression_type
    file_name = f"{prefix}_{stem}_volcano_plot.pdf"
    folder = dst if dst else os.path.dirname(csv_path)
    return os.path.join(folder, file_name)
126
+
127
def scale_variables(X, y):
    """Min-max scale the design matrix and the response independently.

    Returns:
        tuple: (``X`` scaled, as a DataFrame with the original column names and a
        fresh default index; ``y`` scaled, as the numpy array returned by
        ``MinMaxScaler.fit_transform``).
    """
    x_scaler, y_scaler = MinMaxScaler(), MinMaxScaler()
    scaled_features = x_scaler.fit_transform(X)
    X_scaled = pd.DataFrame(scaled_features, columns=X.columns)
    return X_scaled, y_scaler.fit_transform(y)
136
+
137
def process_model_coefficients(model, regression_type, X, y, highlight):
    """Collect the coefficients and p-values of a fitted model into a DataFrame.

    Args:
        model: fitted statsmodels results (``params``/``pvalues``) or sklearn
            estimator (``coef_``/``intercept_``).
        regression_type (str): name used to pick the extraction path.
        X: design matrix the model was fitted on (used for feature names and for
            the Ridge/Lasso p-value approximation).
        y: response the model was fitted on.
        highlight: substring marking features of interest. Non-string values are
            converted with ``str``; None highlights nothing.

    Returns:
        pandas.DataFrame: columns 'feature', 'coefficient', 'p_value',
        '-log10(p_value)' and 'highlight'. Features containing 'row' or 'column'
        are removed.
    """
    statsmodels_types = ['ols', 'gls', 'wls', 'rlm', 'glm', 'mixed', 'quantile', 'logit', 'probit', 'poisson']
    if regression_type in statsmodels_types:
        coefs = model.params
        p_values = model.pvalues
        coef_df = pd.DataFrame({
            'feature': coefs.index,
            'coefficient': coefs.values,
            'p_value': p_values.values,
        })

    elif regression_type in ['ridge', 'lasso']:
        coefs = np.asarray(model.coef_).ravel()
        # calculate_p_values may return shape (1, n_features) when the model was
        # fitted on a 2-D y; pandas requires 1-D columns.
        p_values = np.asarray(calculate_p_values(X, y, model)).ravel()
        coef_df = pd.DataFrame({
            'feature': X.columns,
            'coefficient': coefs,
            'p_value': p_values,
        })

    else:
        coefs = np.asarray(model.coef_).ravel()
        # A patsy DesignMatrix carries design_info; a plain DataFrame only has columns.
        design_info = getattr(X, 'design_info', None)
        feature_names = design_info.column_names if design_info is not None else list(X.columns)
        coef_df = pd.DataFrame({
            'feature': feature_names,
            'coefficient': coefs,
        })
        coef_df.loc[0, 'coefficient'] += float(np.ravel(model.intercept_)[0])
        coef_df['p_value'] = np.nan  # sklearn estimators do not provide p-values

    coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])

    if highlight is None:
        coef_df['highlight'] = False
    else:
        needle = str(highlight)
        coef_df['highlight'] = coef_df['feature'].apply(lambda name: needle in str(name))

    return coef_df[~coef_df['feature'].str.contains('row|column')]
175
+
176
def prepare_formula(dependent_variable, random_row_column_effects=False):
    """Return the patsy/statsmodels formula for the screen regression.

    The fixed part is always ``fraction:grna + gene_fraction:gene``. When
    ``random_row_column_effects`` is False, ``row`` and ``column`` are added as
    fixed-effect terms. When it is True they are left out of the formula and
    are expected to be modelled as random effects by the caller.
    """
    terms = ['fraction:grna', 'gene_fraction:gene']
    if not random_row_column_effects:
        terms += ['row', 'column']
    return f"{dependent_variable} ~ {' + '.join(terms)}"
182
+
183
def fit_mixed_model(df, formula, dst):
    """Fit a linear mixed model grouped by plate and return it with a coefficient table.

    The model is ``smf.mixedlm(formula, data=df, groups=df['plate'])``. It uses
    ``re_formula="1 + row + column"`` (per-plate random intercept plus row and
    column terms) and variance components for row and column
    (``vc_formula``).

    Side effects:
        Adds a 'residuals' column to ``df`` in place and plots its histogram
        with ``plot_histogram(..., dst=dst)``.

    Args:
        df (pandas.DataFrame): data containing the formula variables and
            'plate', 'row' and 'column'.
        formula (str): fixed-effects formula.
        dst: output location forwarded to ``plot_histogram``.

    Returns:
        tuple: (fitted mixed-model results, DataFrame with 'feature',
        'coefficient' and 'p_value').
    """
    # Local import; the docstring must precede it to be attached to the function.
    from .plot import plot_histogram

    model = smf.mixedlm(formula,
                        data=df,
                        groups=df['plate'],
                        re_formula="1 + row + column",
                        vc_formula={"row": "0 + row", "column": "0 + column"})

    mixed_model = model.fit()

    # Plot residuals.
    df['residuals'] = mixed_model.resid
    plot_histogram(df, 'residuals', dst=dst)

    coefs = mixed_model.params
    p_values = mixed_model.pvalues

    coef_df = pd.DataFrame({
        'feature': coefs.index,
        'coefficient': coefs.values,
        'p_value': p_values.values,
    })

    return mixed_model, coef_df
211
+
212
def check_and_clean_data(df, dependent_variable):
    """Validate and prepare the merged read/score table for regression.

    Steps, in order:
      1. Drop rows with missing 'fraction' or ``dependent_variable``.
      2. Cast 'grna', 'gene', 'plate', 'row', 'column' and 'prc' to pandas
         categoricals.
      3. Report variance inflation factors (VIF) for 'fraction' and the
         dependent variable.
      4. Re-attach the grouping columns and add 'gene_fraction', the sum of
         'fraction' per ('prc', 'gene').

    Args:
        df (pandas.DataFrame): must contain 'fraction', ``dependent_variable``,
            'grna', 'gene', 'plate', 'row', 'column' and 'prc'.
        dependent_variable (str): name of the response column.

    Returns:
        pandas.DataFrame: numeric 'fraction' and dependent-variable columns plus
        the categorical grouping columns and 'gene_fraction'.
    """

    def handle_missing_values(df, columns):
        """Drop rows with missing values in ``columns`` and report how many were removed."""
        missing_summary = df[columns].isnull().sum()
        print("Missing values summary:")
        print(missing_summary)

        # .copy() detaches the result, so the categorical casts below do not
        # trigger SettingWithCopyWarning when rows were dropped.
        df_cleaned = df.dropna(subset=columns).copy()
        if df_cleaned.shape[0] < df.shape[0]:
            print(f"Dropped {df.shape[0] - df_cleaned.shape[0]} rows with missing values in {columns}.")
        return df_cleaned

    def ensure_valid_types(df, columns):
        """Convert each of ``columns`` to a pandas categorical if it is not one already."""
        for col in columns:
            # isinstance check replaces pd.api.types.is_categorical_dtype (deprecated in pandas 2.1).
            if not isinstance(df[col].dtype, pd.CategoricalDtype):
                df[col] = pd.Categorical(df[col])
                print(f"Converted {col} to categorical type.")
        return df

    def check_collinearity(df, columns):
        """Report VIFs for ``columns`` and return them as a numeric DataFrame.

        Columns are not dropped for a high VIF. Every column passed in is
        referenced by the regression formula, so dropping one only makes the
        later fit fail.
        """
        print("Checking for collinearity...")

        # Ensure all data is numeric.
        df_encoded = df[columns].apply(pd.to_numeric, errors='coerce')

        # Rank deficiency from duplicated column names (e.g. dependent_variable == 'fraction').
        if np.linalg.matrix_rank(df_encoded.values) < df_encoded.shape[1]:
            print("Warning: Perfect multicollinearity detected! Dropping correlated columns.")
            df_encoded = df_encoded.loc[:, ~df_encoded.columns.duplicated()]

        vif_data = pd.DataFrame()
        vif_data["Feature"] = df_encoded.columns
        try:
            vif_data["VIF"] = [variance_inflation_factor(df_encoded.values, i) for i in range(df_encoded.shape[1])]
        except np.linalg.LinAlgError:
            print("LinAlgError: Unable to compute VIF due to matrix singularity.")
            return df_encoded

        print("Variance Inflation Factor (VIF) for each feature:")
        print(vif_data)

        high_vif_columns = vif_data[vif_data["VIF"] > 10]["Feature"].tolist()
        if high_vif_columns:
            # No constant column is included, so these VIFs are uncentred and two
            # all-positive columns can easily exceed 10.
            print(f"Warning: high VIF (> 10) for {high_vif_columns}; keeping them because the regression formula requires them.")

        return df_encoded

    # Step 1: handle missing values in the fields the formula needs.
    df = handle_missing_values(df, ['fraction', dependent_variable])

    # Step 2: grouping/design columns become categoricals.
    df = ensure_valid_types(df, ['grna', 'gene', 'plate', 'row', 'column', 'prc'])

    # Step 3: collinearity report for fraction and the dependent variable.
    df_cleaned = check_collinearity(df, ['fraction', dependent_variable])

    # Re-attach the grouping columns (index-aligned with df) needed by the formulas.
    for col in ['gene', 'grna', 'prc', 'plate', 'row', 'column']:
        df_cleaned[col] = df[col]

    # Sum of gRNA fractions per gene within each well (prc). observed=True only
    # silences the categorical-grouper FutureWarning; transform output is unchanged.
    df_cleaned['gene_fraction'] = df_cleaned.groupby(['prc', 'gene'], observed=True)['fraction'].transform('sum')

    print("Data is ready for model fitting.")
    return df_cleaned
294
+
295
def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, random_row_column_effects=False, highlight='220950', dst=None, cov_type=None):
    """Fit the screen regression and draw its volcano plot.

    Args:
        df (pandas.DataFrame): merged per-well table; cleaned with
            ``check_and_clean_data`` before fitting.
        csv_path (str): path whose base name names the volcano-plot file.
        dependent_variable (str): response column.
        regression_type (str or None): model name (see ``regression_model``).
            None selects 'ols' if the response passes the Shapiro-Wilk check and
            'glm' otherwise.
        alpha (float): penalty / quantile forwarded to the model.
        random_row_column_effects (bool): if True, fit ``fit_mixed_model`` with
            plate groups, and the type becomes 'mixed'.
        highlight: substring marking features to highlight in the coefficient table.
        dst: output folder for plots.
        cov_type (str or None): OLS covariance type.

    Returns:
        tuple: (fitted model, coefficient DataFrame).
    """
    from .plot import volcano_plot, plot_histogram

    # NOTE(review): the file name uses regression_type as passed in, i.e. before
    # the None / 'mixed' overrides below are applied.
    volcano_path = create_volcano_filename(csv_path, regression_type, alpha, dst)

    response_is_normal = check_normality(df[dependent_variable], dependent_variable)
    if response_is_normal:
        advice = ("To avoid violating assumptions, it is recommended to use a regression model that assumes normality.",
                  "Recommended regression type: ols (Ordinary Least Squares)")
    else:
        advice = ("To avoid violating assumptions, it is recommended to use a regression model that does not assume normality.",
                  "Recommended regression type: glm (Generalized Linear Model)")
    for message in advice:
        print(message)

    if regression_type is None:
        regression_type = 'ols' if response_is_normal else 'glm'

    df = check_and_clean_data(df, dependent_variable)

    if not random_row_column_effects:
        # Fixed-effects models on a scaled patsy design matrix.
        formula = prepare_formula(dependent_variable, random_row_column_effects=False)
        y, X = dmatrices(formula, data=df, return_type='dataframe')
        plot_histogram(y, dependent_variable, dst=dst)
        X, y = scale_variables(X, y)
        groups = df['prc'] if regression_type == 'mixed' else None
        print(f'performing {regression_type} regression')
        fitted = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, cov_type=cov_type)
        coef_df = process_model_coefficients(fitted, regression_type, X, y, highlight)
    else:
        # Row/column treated as random effects within plates.
        regression_type = 'mixed'
        formula = prepare_formula(dependent_variable, random_row_column_effects=True)
        fitted, coef_df = fit_mixed_model(df, formula, dst)

    volcano_plot(coef_df, volcano_path)
    return fitted, coef_df
349
+
350
def perform_regression(settings):
    """Run the gene/gRNA regression workflow for a pooled screen.

    Steps:
      1. Load the count (reads) and score CSVs. Each may be a single path or a
         list of paths; lists are concatenated and each file gets a
         'plate_name' of 'plate1', 'plate2', ...
      2. Validate ``regression_type`` and ``dependent_variable``.
      3. Create ``results_<regression_type>[_list][_<alpha>]`` in the folder of
         the (first) count file, apply default settings and save them.
      4. Remove control columns, optionally threshold
         'prediction_probability_class_1', aggregate scores per well
         (``process_scores``), compute gRNA fractions (``process_reads``),
         merge both on 'prc' and write 'regression_data.csv'.
      5. Fit the model (``regression``), write 'results.csv' and
         'results_significant.csv', and merge both with each metadata file.
      6. If ``settings['toxo']``, draw the LOPIT-annotated volcano plot and
         run GO-term enrichment on the significant hits.

    Args:
        settings (dict): workflow settings (keys read below, e.g. 'score_data',
            'count_data', 'regression_type', 'dependent_variable', 'alpha').

    Returns:
        dict: {'results': all coefficients, 'significant': filtered hits}.

    Raises:
        ValueError: unsupported regression type, or the dependent variable is
            missing from the score data.
    """
    from .plot import plot_plates
    from .utils import merge_regression_res_with_metadata, save_settings
    from .settings import get_perform_regression_default_settings
    from .toxo import go_term_enrichment_by_column, custom_volcano_plot

    def _read_csvs(source, label):
        """Read one CSV path, or concatenate a list of them tagged with plate_name 'plate<i>'."""
        if not isinstance(source, list):
            return pd.read_csv(source)
        frames = []
        for i, path in enumerate(source):
            frame = pd.read_csv(path)
            frame['plate_name'] = f'plate{i+1}'
            frames.append(frame)
        combined = pd.concat(frames)
        print(f'{label}:', len(combined))
        return combined

    if isinstance(settings['score_data'], list) and isinstance(settings['count_data'], list):
        # Plate labels come from the files (plate_name), not from a fixed setting.
        settings['plate'] = None
        # Unwrap single-element lists so they are read as plain CSV paths.
        if len(settings['score_data']) == 1:
            settings['score_data'] = settings['score_data'][0]
        if len(settings['count_data']) == 1:
            settings['count_data'] = settings['count_data'][0]

    # Each source is read on its own, so any mix of a single path and a list
    # works (both always end up defined).
    count_data_df = _read_csvs(settings['count_data'], 'Count data')
    score_data_df = _read_csvs(settings['score_data'], 'Score data')

    reg_types = ['ols', 'gls', 'wls', 'rlm', 'glm', 'mixed', 'quantile', 'logit', 'probit', 'poisson', 'lasso', 'ridge']
    if settings['regression_type'] not in reg_types:
        print(f'Possible regression types: {reg_types}')
        raise ValueError(f"Unsupported regression type {settings['regression_type']}")

    if settings['dependent_variable'] not in score_data_df.columns:
        print('Columns in DataFrame:')
        for col in score_data_df.columns:
            print(col)
        raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")

    if isinstance(settings['count_data'], list):
        csv_path = settings['count_data'][0]
    else:
        csv_path = settings['count_data']
    src = os.path.dirname(csv_path)
    settings['src'] = src

    fldr = 'results_' + settings['regression_type']
    if isinstance(settings['count_data'], list):
        fldr = fldr + '_list'
    if settings['regression_type'] == 'quantile':
        fldr = fldr + '_' + str(settings['alpha'])

    res_folder = os.path.join(src, fldr)
    os.makedirs(res_folder, exist_ok=True)
    results_path = os.path.join(res_folder, 'results.csv')
    hits_path = os.path.join(res_folder, 'results_significant.csv')

    # NOTE(review): defaults are applied only here, after several settings keys
    # were already read above. Confirm whether this call can move to the top.
    settings = get_perform_regression_default_settings(settings)
    save_settings(settings, name='regression', show=True)

    score_data_df = clean_controls(score_data_df, settings['pc'], settings['nc'], settings['other'])

    if 'prediction_probability_class_1' in score_data_df.columns and settings['class_1_threshold'] is not None:
        score_data_df['predictions'] = (score_data_df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)

    dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])

    independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plate'])

    merged_df = pd.merge(independent_df, dependent_df, on='prc')

    data_path = os.path.join(res_folder, 'regression_data.csv')
    merged_df.to_csv(data_path, index=False)

    merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)

    if settings['transform'] is None:
        _ = plot_plates(score_data_df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'], dst=res_folder)

    model, coef_df = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], highlight=settings['highlight'], dst=res_folder, cov_type=settings['cov_type'])

    coef_df.to_csv(results_path, index=False)

    if settings['regression_type'] == 'lasso':
        significant = coef_df[coef_df['coefficient'] > 0]
    else:
        significant = coef_df[coef_df['p_value'] <= 0.05]
        # Reassign instead of sorting in place on a filtered frame.
        significant = significant.sort_values(by='coefficient', ascending=False)
        significant = significant[~significant['feature'].str.contains('row|column')]

    if settings['regression_type'] == 'ols':
        print(model.summary())

    significant.to_csv(hits_path, index=False)

    if isinstance(settings['metadata_files'], str):
        settings['metadata_files'] = [settings['metadata_files']]

    for metadata_file in settings['metadata_files']:
        filename, _ = os.path.splitext(os.path.basename(metadata_file))
        _ = merge_regression_res_with_metadata(hits_path, metadata_file, name=filename)
        merged_df = merge_regression_res_with_metadata(results_path, metadata_file, name=filename)

    if settings['toxo']:
        # NOTE(review): merged_df is the merge for the last metadata file, or the
        # regression table itself if metadata_files is empty. Confirm this is intended.
        base_dir = os.path.dirname(os.path.abspath(__file__))
        metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')
        custom_volcano_plot(merged_df, metadata_path, metadata_column='tagm_location', string_list=[settings['highlight']], point_size=50, figsize=20)

        metadata_path = os.path.join(base_dir, 'resources', 'data', 'toxoplasma_metadata.csv')
        go_term_enrichment_by_column(significant, metadata_path)

    print('Significant Genes')
    display(significant)

    return {'results': coef_df,
            'significant': significant}
483
+
484
def process_reads(csv_path, fraction_threshold, plate):
    """Turn per-gRNA read counts into per-well gRNA fractions.

    Args:
        csv_path (str or pandas.DataFrame): CSV path or an already loaded frame
            (copied, never modified). It needs 'row', 'column', 'grna' and
            'count' columns. The aliases 'row_name', 'column_name', 'grna_name'
            and 'plate_name' are accepted, and 'plate_row' is split on '_'
            into plate and row.
        fraction_threshold (float or None): rows with a smaller fraction are removed.
        plate: plate label used when the data has no plate column
            (default 'plate1').

    Returns:
        pandas.DataFrame: columns 'prc' (plate_row_column), 'grna', 'fraction'
        and, when 'grna' splits on '_' into org/gene/guide, 'gene'. In that case
        'grna' becomes '<gene>_<guide>'.

    Raises:
        ValueError: if a required column is missing.
    """
    if isinstance(csv_path, pd.DataFrame):
        # Copy so the column edits below never mutate the caller's DataFrame.
        csv_df = csv_path.copy()
    else:
        csv_df = pd.read_csv(csv_path)

    csv_df = csv_df.rename(columns={
        'plate_name': 'plate',
        'column_name': 'column',
        'row_name': 'row',
        'grna_name': 'grna',
    })
    if 'plate_row' in csv_df.columns:
        csv_df[['plate', 'row']] = csv_df['plate_row'].str.split('_', expand=True)
    if 'plate' not in csv_df.columns:
        csv_df['plate'] = plate if plate is not None else 'plate1'

    if not all(col in csv_df.columns for col in ['row', 'column', 'grna', 'count']):
        raise ValueError("The CSV file must contain 'grna', 'count', 'row', and 'column' columns.")

    csv_df['prc'] = csv_df['plate'] + '_' + csv_df['row'] + '_' + csv_df['column']

    # Fraction of each gRNA's reads among all reads of its well.
    totals = csv_df.groupby('prc')['count'].sum().reset_index()
    totals = totals.rename(columns={'count': 'total_counts'})
    merged_df = pd.merge(csv_df, totals, on='prc')
    merged_df['fraction'] = merged_df['count'] / merged_df['total_counts']

    if fraction_threshold is not None:
        observations_before = len(merged_df)
        merged_df = merged_df[merged_df['fraction'] >= fraction_threshold]
        removed = observations_before - len(merged_df)
        print(f'Removed {removed} observation below fraction threshold: {fraction_threshold}')

    # .copy(): the column assignments below must act on an independent frame, not a slice.
    merged_df = merged_df[['prc', 'grna', 'fraction']].copy()

    # 'gene' is never among the projected columns, so always try to derive it
    # from names of the form '<org>_<gene>_<guide>'.
    try:
        merged_df[['org', 'gene', 'grna']] = merged_df['grna'].str.split('_', expand=True)
        merged_df = merged_df.drop(columns=['org'])
        merged_df['grna'] = merged_df['gene'] + '_' + merged_df['grna']
    except (ValueError, TypeError, AttributeError, KeyError) as exc:
        # Best effort: keep the unsplit names if they do not follow the pattern.
        print(f'Error splitting grna into org, gene, grna. ({exc})')

    return merged_df
540
+
541
def apply_transformation(X, transform):
    """Return an unfitted ``FunctionTransformer`` for ``transform``, or None.

    Supported names: 'log' (``np.log1p``), 'sqrt' (``np.sqrt``) and 'square'
    (``np.square``). Any other value gives None. ``X`` is not used.
    """
    options = (('log', np.log1p), ('sqrt', np.sqrt), ('square', np.square))
    for name, func in options:
        if transform == name:
            return FunctionTransformer(func, validate=True)
    return None
551
+
552
def check_normality(data, variable_name, verbose=False):
    """Return True if the Shapiro-Wilk test does not reject normality (p > 0.05).

    Args:
        data: 1-D sample passed to ``scipy.stats.shapiro``.
        variable_name (str): label used in the verbose messages.
        verbose (bool): print the statistic, p-value and verdict.

    Returns:
        bool: True when p > 0.05, otherwise False (including a NaN p-value).
    """
    statistic, p = shapiro(data)
    looks_normal = p > 0.05
    if verbose:
        print(f"Shapiro-Wilk Test for {variable_name}:\nStatistic: {statistic}, P-value: {p}")
        verdict = "is normally distributed" if looks_normal else "is not normally distributed"
        print(f"Normal distribution: The data for {variable_name} {verdict}.")
    return True if looks_normal else False
565
+
566
def clean_controls(df, pc, nc, other):
    """Drop rows whose 'column' value is one of the given control labels.

    If ``df`` has a 'col' column, its values are copied into 'column' in the
    returned frame. The caller's ``df`` is not modified.

    Args:
        df (pandas.DataFrame): score table.
        pc, nc, other: control labels to remove. Each may be None, a single
            label, or a list/tuple/set of labels.

    Returns:
        pandas.DataFrame: ``df`` without the control rows.
    """
    if 'col' in df.columns:
        # assign returns a new frame, so the caller's DataFrame is left untouched.
        df = df.assign(column=df['col'])

    excluded = []
    for labels in (nc, pc, other):
        if labels is None:
            continue
        if isinstance(labels, (list, tuple, set)):
            excluded.extend(labels)
        else:
            excluded.append(labels)

    if excluded:
        df = df[~df['column'].isin(excluded)]
    print(f'Removed data from {nc, pc, other}')
    return df
577
+
578
def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
    """Aggregate per-object scores into one value per well ('prc').

    Args:
        df (pandas.DataFrame): per-object scores with 'row', 'column' or 'col',
            'plate' or 'plate_name', and ``dependent_variable``. NOTE: ``df``
            may be modified in place. 'plate' is dropped when 'plate_name'
            exists; otherwise 'plate', 'col' and 'prc' may be added or overwritten.
        dependent_variable (str): score column to aggregate.
        plate: if not None, overrides the plate label of every row.
        min_cell_count (int): wells with fewer objects are dropped.
        agg_type: 'mean', 'median', 'quantile' (75th percentile) or None
            (no aggregation; every object row is kept with its well's count).
        transform: None, 'log' (log1p), 'sqrt' or 'square', applied after aggregation.
        regression_type (str): 'poisson' sums the scores per well instead of using agg_type.

    Returns:
        tuple: (DataFrame with 'prc', the score column and 'cell_count';
        name of the possibly transformed score column, e.g. 'log_<name>').

    Raises:
        ValueError: for an unsupported ``agg_type`` or ``transform``.
    """
    if 'plate_name' in df.columns:
        # In-place drop kept from the original (callers may observe it); the
        # existence check avoids a KeyError when there is no 'plate' column.
        if 'plate' in df.columns:
            df.drop(columns=['plate'], inplace=True)
        df = df.rename(columns={'plate_name': 'plate'})

    if plate is not None:
        df['plate'] = plate

    if 'col' not in df.columns:
        df['col'] = df['column']

    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
    df = df[['prc', dependent_variable]]

    grouped = df.groupby('prc')[dependent_variable]

    if regression_type == 'poisson':
        agg_type = 'count'
        print(f'Using agg_type: {agg_type} for poisson regression')
        dependent_df = grouped.sum().reset_index()
    else:
        print(f'Using agg_type: {agg_type}')
        if agg_type == 'median':
            dependent_df = grouped.median().reset_index()
        elif agg_type == 'mean':
            dependent_df = grouped.mean().reset_index()
        elif agg_type == 'quantile':
            dependent_df = grouped.quantile(0.75).reset_index()
        elif agg_type is None:
            dependent_df = df.reset_index()
            if 'prcfo' in dependent_df.columns:
                dependent_df = dependent_df.drop(columns=['prcfo'])
        else:
            raise ValueError(f"Unsupported aggregation type {agg_type}")

    # Objects per well. Both frames come from the same groupby, so for aggregated
    # data the rows are in the same prc order and can be aligned by position.
    cell_count = grouped.size().reset_index(name='cell_count')

    if agg_type is None:
        dependent_df = pd.merge(dependent_df, cell_count, on='prc')
    else:
        dependent_df['cell_count'] = cell_count['cell_count']

    # .copy() so the transformed column below is added to an independent frame.
    dependent_df = dependent_df[dependent_df['cell_count'] >= min_cell_count].copy()

    is_normal = check_normality(dependent_df[dependent_variable], dependent_variable)

    if transform is not None:
        transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
        if transformer is None:
            raise ValueError(f"Unsupported transform {transform}; expected 'log', 'sqrt', 'square' or None")
        transformed_var = f'{transform}_{dependent_variable}'
        dependent_df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
        dependent_variable = transformed_var
        is_normal = check_normality(dependent_df[transformed_var], transformed_var)

    if not is_normal:
        print(f'{dependent_variable} is not normally distributed')
    else:
        print(f'{dependent_variable} is normally distributed')

    return dependent_df, dependent_variable
643
+
644
def generate_ml_scores(settings):
    """Train the ML model on measurement data and save scores, tables and figures.

    Loads the 'cell', 'nucleus', 'pathogen' and 'cytoplasm' tables from
    ``<src>/measurements/measurements.db``. For channels 0-3 it adds a
    'recruitment' ratio (pathogen / cytoplasm mean intensity). It then runs
    ``ml_analysis`` and ``shap_analysis``, plots a plate heatmap of
    ``heatmap_feature``, and writes the settings, data, permutation importance,
    feature importance, metrics and figures to the paths returned by
    ``get_ml_results_paths``.

    Args:
        settings (dict): analysis settings, completed by ``set_default_analyze_screen``.

    Returns:
        list: [ml_analysis output tuple, plate heatmap figure].

    Raises:
        ValueError: if ``heatmap_feature`` is not a numeric column of the scored data.
    """
    from .io import _read_and_merge_data
    from .plot import plot_plates
    from .utils import get_ml_results_paths
    from .settings import set_default_analyze_screen

    settings = set_default_analyze_screen(settings)
    src = settings['src']

    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    display(settings_df)

    database_paths = [src + '/measurements/measurements.db']
    table_names = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
    nuclei_limit = settings['nuclei_limit']
    pathogen_limit = settings['pathogen_limit']
    uninfected = settings['uninfected']

    df, _ = _read_and_merge_data(database_paths, table_names, settings['verbose'], nuclei_limit, pathogen_limit, uninfected)

    channel = settings['channel_of_interest']
    if channel in [0, 1, 2, 3]:
        pathogen_intensity = df[f"pathogen_channel_{channel}_mean_intensity"]
        cytoplasm_intensity = df[f"cytoplasm_channel_{channel}_mean_intensity"]
        df['recruitment'] = pathogen_intensity / cytoplasm_intensity

    output, figs = ml_analysis(df,
                               settings['channel_of_interest'],
                               settings['location_column'],
                               settings['positive_control'],
                               settings['negative_control'],
                               settings['exclude'],
                               settings['n_repeats'],
                               settings['top_features'],
                               settings['n_estimators'],
                               settings['test_size'],
                               settings['model_type_ml'],
                               settings['n_jobs'],
                               settings['remove_low_variance_features'],
                               settings['remove_highly_correlated_features'],
                               settings['verbose'])

    shap_fig = shap_analysis(output[3], output[4], output[5])

    numeric_columns = output[0].select_dtypes(include=[np.number]).columns.tolist()
    if settings['heatmap_feature'] not in numeric_columns:
        raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {numeric_columns}")

    plate_heatmap = plot_plates(df=output[0],
                                variable=settings['heatmap_feature'],
                                grouping=settings['grouping'],
                                min_max=settings['min_max'],
                                cmap=settings['cmap'],
                                min_count=settings['minimum_cell_count'],
                                verbose=settings['verbose'])

    (data_path, permutation_path, feature_importance_path, metrics_path,
     permutation_fig_path, feature_importance_fig_path, shap_fig_path,
     plate_heatmap_path, settings_csv) = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
    scored_df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output

    settings_df.to_csv(settings_csv, index=False)
    for table, path in ((scored_df, data_path),
                        (permutation_df, permutation_path),
                        (feature_importance_df, feature_importance_path),
                        (metrics_df, metrics_path)):
        table.to_csv(path, mode='w', encoding='utf-8')

    plate_heatmap.savefig(plate_heatmap_path, format='pdf')
    figs[0].savefig(permutation_fig_path, format='pdf')
    figs[1].savefig(feature_importance_fig_path, format='pdf')
    shap_fig.savefig(shap_fig_path, format='pdf')

    return [output, plate_heatmap]
720
+
721
def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
    """
    Train a classifier separating negative- from positive-control rows, then score every row.

    Rows whose ``location_column`` equals ``negative_control`` get target 0 and rows equal to
    ``positive_control`` get target 1. A model is fit on a train split of those control rows.
    Permutation importances are computed on the training data. Native feature importances are
    also computed when the fitted model exposes ``feature_importances_``. The model then predicts
    a class and class probabilities for every row of the (filtered) dataframe, and distance
    columns to the control centroids are added via ``_calculate_similarity``.

    Args:
        df (pandas.DataFrame): Input data. A ``cells_per_well`` column, if present, is dropped.
        channel_of_interest (int): Forwarded to ``filter_dataframe_features``.
        location_column (str): Column whose values identify the control groups.
        positive_control (str): Value of ``location_column`` marking positive-control rows.
        negative_control (str): Value of ``location_column`` marking negative-control rows.
        exclude (list or str, optional): Forwarded to ``filter_dataframe_features``.
        n_repeats (int): Number of repeats for permutation importance.
        top_features (int): Number of top features kept in the importance tables.
        n_estimators (int): Number of trees (random forest / XGBoost) or boosting
            iterations (``max_iter`` for gradient boosting).
        test_size (float): Fraction of the control rows held out as the test split.
        model_type (str): One of 'random_forest', 'logistic_regression',
            'gradient_boosting', 'xgboost'.
        n_jobs (int): Parallel jobs for the models and permutation importance that accept it.
        remove_low_variance_features (bool): Forwarded to ``filter_dataframe_features``.
        remove_highly_correlated_features (bool): Forwarded to ``filter_dataframe_features``.
        verbose (bool): Print diagnostics and show the figures.

    Returns:
        tuple(list, list):
            - ``[df, permutation_df, feature_importance_df, model, X_train, X_test,
              y_train, y_test, metrics_df]``. ``df`` gains ``data_usage``, ``predictions``,
              ``prediction_probability_class_<i>``, the similarity columns and
              ``prcfo``/``plate``/``row``/``col``/``field``/``object``/``prc``.
            - ``[permutation_fig, feature_importance_fig]``. ``feature_importance_fig``
              is ``None`` (and ``feature_importance_df`` is empty) when the model has no
              ``feature_importances_`` attribute, e.g. logistic regression or
              ``HistGradientBoostingClassifier``.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """

    from .utils import filter_dataframe_features
    from .plot import plot_permutation, plot_feature_importance

    random_state = 42

    if 'cells_per_well' in df.columns:
        df = df.drop(columns=['cells_per_well'])

    # Keep the grouping column aside; it is re-attached after feature filtering,
    # which may remove it from df.
    df_metadata = df[[location_column]].copy()

    df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)

    if verbose:
        print('After filtration:', len(df))
        print(f'Found {len(features)} numerical features in the dataframe')
        print(f'Features used in training: {features}')
    df = pd.concat([df, df_metadata[location_column]], axis=1)

    # Subset the dataframe into the two control groups and label them.
    df1 = df[df[location_column] == negative_control].copy()
    df2 = df[df[location_column] == positive_control].copy()
    df1['target'] = 0  # Negative control
    df2['target'] = 1  # Positive control

    # Combine the subsets for training/evaluation.
    combined_df = pd.concat([df1, df2])
    combined_df = combined_df.drop(columns=[location_column])
    if verbose:
        print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')

    X = combined_df[features]
    y = combined_df['target']

    if verbose:
        print(X)
        print(y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Record which rows were used for training/testing; all other rows are 'not_used'.
    combined_df['data_usage'] = 'train'
    combined_df.loc[X_test.index, 'data_usage'] = 'test'
    df['data_usage'] = 'not_used'
    df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']

    if model_type == 'random_forest':
        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'logistic_regression':
        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'gradient_boosting':
        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state)
    elif model_type == 'xgboost':
        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    model.fit(X_train, y_train)

    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)

    # Ascending order of mean importance; tail() keeps the most important features.
    order = perm_importance.importances_mean.argsort()
    permutation_df = pd.DataFrame({
        'feature': [features[i] for i in order],
        'importance_mean': perm_importance.importances_mean[order],
        'importance_std': perm_importance.importances_std[order],
    }).tail(top_features)

    permutation_fig = plot_permutation(permutation_df)
    if verbose:
        permutation_fig.show()

    # Native importances only exist on some estimators (RandomForest, XGBoost);
    # HistGradientBoostingClassifier and LogisticRegression do not provide them.
    feature_importance_fig = None
    if hasattr(model, 'feature_importances_'):
        feature_importance_df = pd.DataFrame({
            'feature': features,
            'importance': model.feature_importances_,
        }).sort_values(by='importance', ascending=False).head(top_features)

        feature_importance_fig = plot_feature_importance(feature_importance_df)
        if verbose:
            feature_importance_fig.show()
    else:
        feature_importance_df = pd.DataFrame()

    # Test-set predictions and probabilities.
    predictions_test = model.predict(X_test)
    combined_df.loc[X_test.index, 'predictions'] = predictions_test
    prediction_probabilities_test = model.predict_proba(X_test)

    optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
    if verbose:
        print(f'Optimal threshold: {optimal_threshold}')

    # Predict every row in the dataframe, controls included.
    X_all = df[features]
    df['predictions'] = model.predict(X_all)

    prediction_probabilities = model.predict_proba(X_all)
    for i in range(prediction_probabilities.shape[1]):
        df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
    if verbose:
        print("\nClassification Report:")
        print(classification_report(y_test, predictions_test))
    report_dict = classification_report(y_test, predictions_test, output_dict=True)
    metrics_df = pd.DataFrame(report_dict).transpose()

    df = _calculate_similarity(df, features, location_column, positive_control, negative_control)

    # The index is split on '_' into exactly five parts; this assignment fails if the
    # index has a different number of '_'-separated parts.
    df['prcfo'] = df.index.astype(str)
    df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']

    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
870
+
871
def shap_analysis(model, X_train, X_test):
    """
    Build a SHAP summary plot for a fitted model.

    An explainer is constructed from ``model`` with ``X_train`` as background data,
    and SHAP values are computed for ``X_test``.

    Args:
        model: The trained model.
        X_train (pandas.DataFrame): Background/training feature set for the explainer.
        X_test (pandas.DataFrame): Feature set whose predictions are explained.

    Returns:
        matplotlib.figure.Figure: The (closed) figure holding the SHAP summary plot.
    """
    shap_explainer = shap.Explainer(model, X_train)
    explanation = shap_explainer(X_test)

    # Start a fresh figure so the summary plot does not draw over an existing one.
    plt.subplots()
    shap.summary_plot(explanation, X_test, show=False)

    # summary_plot renders onto the active figure; grab that one and close it so it is
    # not displayed, while the returned object can still be saved by the caller.
    summary_figure = plt.gcf()
    plt.close(summary_figure)
    return summary_figure
894
+
895
def find_optimal_threshold(y_true, y_pred_proba):
    """
    Find the decision threshold that maximizes F1 for binary classification.

    Args:
        y_true (array-like): True binary labels.
        y_pred_proba (array-like): Predicted probabilities for the positive class.

    Returns:
        float: The threshold from ``precision_recall_curve`` with the highest F1-score.
            Points where precision and recall are both 0 count as F1 = 0 rather than NaN.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
    # precision/recall have one more entry than thresholds (the final P=1, R=0 point has
    # no threshold); drop it so indices line up with `thresholds`.
    precision = precision[:-1]
    recall = recall[:-1]
    denominator = precision + recall
    # Guard 0/0: a NaN F1 would otherwise win np.argmax and select a meaningless threshold.
    f1_scores = np.divide(
        2 * precision * recall,
        denominator,
        out=np.zeros_like(denominator, dtype=float),
        where=denominator > 0,
    )
    optimal_idx = int(np.argmax(f1_scores))
    return thresholds[optimal_idx]
911
+
912
+ def _calculate_similarity(df, features, col_to_compare, val1, val2):
913
+ """
914
+ Calculate similarity scores of each well to the positive and negative controls using various metrics.
915
+
916
+ Args:
917
+ df (pandas.DataFrame): DataFrame containing the data.
918
+ features (list): List of feature columns to use for similarity calculation.
919
+ col_to_compare (str): Column name to use for comparing groups.
920
+ val1, val2 (str): Values in col_to_compare to create subsets for comparison.
921
+
922
+ Returns:
923
+ pandas.DataFrame: DataFrame with similarity scores.
924
+ """
925
+ # Separate positive and negative control wells
926
+ pos_control = df[df[col_to_compare] == val1][features].mean()
927
+ neg_control = df[df[col_to_compare] == val2][features].mean()
928
+
929
+ # Standardize features for Mahalanobis distance
930
+ scaler = StandardScaler()
931
+ scaled_features = scaler.fit_transform(df[features])
932
+
933
+ # Regularize the covariance matrix to avoid singularity
934
+ cov_matrix = np.cov(scaled_features, rowvar=False)
935
+ inv_cov_matrix = None
936
+ try:
937
+ inv_cov_matrix = np.linalg.inv(cov_matrix)
938
+ except np.linalg.LinAlgError:
939
+ # Add a small value to the diagonal elements for regularization
940
+ epsilon = 1e-5
941
+ inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
942
+
943
+ # Calculate similarity scores
944
+ df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
945
+ df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
946
+ df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
947
+ df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
948
+ df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
949
+ df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
950
+ df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
951
+ df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
952
+ df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
953
+ df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
954
+ df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
955
+ df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
956
+ df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
957
+ df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
958
+ df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
959
+ df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
960
+ df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
961
+ df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
962
+
963
+ return df
964
+