spacr 0.3.47__py3-none-any.whl → 0.3.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/chat_bot.py +31 -0
- spacr/gui_elements.py +33 -7
- spacr/ml.py +453 -141
- spacr/plot.py +460 -29
- spacr/sequencing.py +5 -2
- spacr/settings.py +1 -1
- spacr/toxo.py +267 -158
- spacr/utils.py +12 -4
- {spacr-0.3.47.dist-info → spacr-0.3.50.dist-info}/METADATA +2 -1
- {spacr-0.3.47.dist-info → spacr-0.3.50.dist-info}/RECORD +14 -13
- {spacr-0.3.47.dist-info → spacr-0.3.50.dist-info}/LICENSE +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.50.dist-info}/WHEEL +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.50.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.50.dist-info}/top_level.txt +0 -0
spacr/ml.py
CHANGED
```diff
@@ -4,15 +4,24 @@ import numpy as np
 from scipy import stats
 from scipy.stats import shapiro
 
+from sklearn.linear_model import Lasso, Ridge, LassoCV, RidgeCV
+from sklearn.metrics import mean_squared_error
+import numpy as np
+
 import matplotlib.pyplot as plt
 from IPython.display import display
-
+import scipy.stats as st
 import statsmodels.api as sm
 import statsmodels.formula.api as smf
+from statsmodels.tools import add_constant
 from statsmodels.regression.mixed_linear_model import MixedLM
 from statsmodels.tools.sm_exceptions import PerfectSeparationError
 from statsmodels.stats.outliers_influence import variance_inflation_factor
-
+from statsmodels.genmod.families import Binomial
+from statsmodels.genmod.families.links import logit
+from statsmodels.othermod.betareg import BetaModel
+from scipy.optimize import minimize
+from scipy.special import gammaln, psi, expit
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import FunctionTransformer
 from patsy import dmatrices
@@ -24,17 +33,30 @@ from sklearn.inspection import permutation_importance
 from sklearn.metrics import classification_report, precision_recall_curve
 from sklearn.preprocessing import StandardScaler
 from sklearn.preprocessing import MinMaxScaler
-
 from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
-
 from xgboost import XGBClassifier
 
+import numpy as np
+from scipy.stats import kstest, normaltest
+import statsmodels.api as sm
+
 import matplotlib
 matplotlib.use('Agg')
 
 import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
+
+class QuasiBinomial(Binomial):
+    """Custom Quasi-Binomial family with adjustable variance."""
+    def __init__(self, link=logit(), dispersion=1.0):
+        super().__init__(link=link)
+        self.dispersion = dispersion
+
+    def variance(self, mu):
+        """Adjust the variance with the dispersion parameter."""
+        return self.dispersion * super().variance(mu)
+
 def calculate_p_values(X, y, model):
     # Predict y values
     y_pred = model.predict(X)
```
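The new `QuasiBinomial` family scales the binomial variance by a dispersion factor, which lets a `statsmodels` GLM absorb over- or under-dispersion in proportion data. A minimal sketch of how it might be used (synthetic data; assumes spacr >= 0.3.50 is installed):

```python
# Hypothetical usage sketch of the new QuasiBinomial family (not part of the diff).
import numpy as np
import statsmodels.api as sm
from spacr.ml import QuasiBinomial

rng = np.random.default_rng(0)
X = sm.add_constant(rng.normal(size=(200, 2)))          # design matrix with intercept
eta = X @ np.array([0.5, 1.0, -0.8])                    # linear predictor
mu = 1.0 / (1.0 + np.exp(-eta))                         # mean on the logit scale
y = np.clip(mu + rng.normal(scale=0.05, size=200), 0.0, 1.0)  # noisy proportions in [0, 1]

# dispersion > 1 inflates the binomial variance to model overdispersion
result = sm.GLM(y, X, family=QuasiBinomial(dispersion=1.5)).fit()
print(result.summary())
```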
```diff
@@ -75,46 +97,6 @@ def perform_mixed_model(y, X, groups, alpha=1.0):
     result = model.fit()
     return result
 
-def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
-
-    def plot_regression_line(X, y, model):
-        """Helper to plot regression line for lasso and ridge models."""
-        y_pred = model.predict(X)
-        plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
-        plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
-        plt.xlabel('Features')
-        plt.ylabel('Dependent Variable')
-        plt.legend()
-        plt.show()
-
-    # Define the dictionary with callables (lambdas) to delay evaluation
-    model_map = {
-        'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
-        'gls': lambda: sm.GLS(y, X).fit(),
-        'wls': lambda: sm.WLS(y, X, weights=1 / np.sqrt(X.iloc[:, 1])).fit(),
-        'rlm': lambda: sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit(),
-        'glm': lambda: sm.GLM(y, X, family=sm.families.Gaussian()).fit(),
-        'quantile': lambda: sm.QuantReg(y, X).fit(q=alpha),
-        'logit': lambda: sm.Logit(y, X).fit(),
-        'probit': lambda: sm.Probit(y, X).fit(),
-        'poisson': lambda: sm.Poisson(y, X).fit(),
-        'lasso': lambda: Lasso(alpha=alpha).fit(X, y),
-        'ridge': lambda: Ridge(alpha=alpha).fit(X, y)
-    }
-
-    # Call the appropriate model only when needed
-    if regression_type in model_map:
-        model = model_map[regression_type]()
-    elif regression_type == 'mixed':
-        model = perform_mixed_model(y, X, groups, alpha=alpha)
-    else:
-        raise ValueError(f"Unsupported regression type {regression_type}")
-
-    if regression_type in ['lasso', 'ridge']:
-        plot_regression_line(X, y, model)
-
-    return model
-
 def create_volcano_filename(csv_path, regression_type, alpha, dst):
     """Create and return the volcano plot filename based on regression type and alpha."""
     volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
@@ -173,6 +155,41 @@ def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
     coef_df['condition'] = coef_df.apply(lambda row: 'nc' if nc in row['feature'] else 'pc' if pc in row['feature'] else ('control' if row['grna'] in controls else 'other'),axis=1)
     return coef_df[~coef_df['feature'].str.contains('row|column')]
 
+
+
+
+
+def check_distribution(y):
+    """Check the type of distribution to recommend a model."""
+    if np.all((y == 0) | (y == 1)):
+        print("Detected binary data.")
+        return 'logit'
+    elif (y > 0).all() and (y < 1).all():
+        print("Detected continuous data between 0 and 1 (excluding 0 and 1).")
+        return 'beta'
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Detected continuous data between 0 and 1 (including 0 or 1).")
+        # Consider quasi-binomial regression
+        return 'quasi_binomial'
+    else:
+        print("Using OLS as a fallback.")
+        return 'ols'
+
+def select_glm_family(y):
+    """Select the appropriate GLM family based on the data."""
+    if np.all((y == 0) | (y == 1)):
+        print("Using Binomial family (for binary data).")
+        return sm.families.Binomial()
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Using Quasi-Binomial family (for proportion data including 0 and 1).")
+        return QuasiBinomial()
+    elif np.all(y.astype(int) == y) and (y >= 0).all():
+        print("Using Poisson family (for count data).")
+        return sm.families.Poisson()
+    else:
+        print("Using Gaussian family (for continuous data).")
+        return sm.families.Gaussian()
+
 def prepare_formula(dependent_variable, random_row_column_effects=False):
     """Return the regression formula using random effects for plate, row, and column."""
     if random_row_column_effects:
```
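These helpers pick a model (or GLM family) from the range and support of the dependent variable. Note that a second, extended `check_distribution(y, epsilon=...)` added later in this diff shadows this one at import time. A small sketch with synthetic data:

```python
# Illustrative only: exercising the model-recommendation helpers on synthetic data.
# (check_distribution resolves to the later, epsilon-aware definition in ml.py.)
import numpy as np
from spacr.ml import check_distribution, select_glm_family

y = np.random.default_rng(1).uniform(0.05, 0.95, size=500)  # proportions strictly inside (0, 1)
print(check_distribution(y))   # prints its reasoning; returns 'beta' for interior data
print(select_glm_family(y))    # returns QuasiBinomial() for data bounded in [0, 1]
```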
```diff
@@ -290,25 +307,345 @@ def check_and_clean_data(df, dependent_variable):
     print("Data is ready for model fitting.")
     return df_cleaned
 
-def
-
+def check_normality(y, variable_name):
+    """Check if the data is normally distributed using the Shapiro-Wilk test."""
+    from scipy.stats import shapiro
 
-
-
+    stat, p = shapiro(y)
+    alpha = 0.05
+    if p > alpha:
+        print(f"{variable_name} is normally distributed (fail to reject H0)")
+        return True
+    else:
+        print(f"{variable_name} is not normally distributed (reject H0)")
+        return False
+
+def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance=0.02, smoothing=10, increment=10):
+    """
+    Plot the mean absolute difference with standard deviation as shaded area vs. sample size.
+    Detect and mark the elbow point (inflection) with smoothing and tolerance control.
+    """
 
-
-    is_normal = check_normality(df[dependent_variable], dependent_variable)
+    from spacr.utils import correct_metadata_column_names
 
-
-
-
+    # Load and process data
+    if isinstance(settings['score_data'], str):
+        settings['score_data'] = [settings['score_data']]
+
+    dfs = []
+    for i, score_data in enumerate(settings['score_data']):
+        df = pd.read_csv(score_data)
+        df = correct_metadata_column_names(df)
+        df['plate'] = f'plate{i + 1}'
+        df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
+        dfs.append(df)
+
+    df = pd.concat(dfs, axis=0)
+
+    # Compute the number of cells per well and select the top 100 wells by cell count
+    cell_counts = df.groupby('prc').size().reset_index(name='cell_count')
+    top_wells = cell_counts.nlargest(sample_size, 'cell_count')['prc']
+
+    # Filter the data to include only the top 100 wells
+    df = df[df['prc'].isin(top_wells)]
+
+    # Initialize storage for absolute difference data
+    diff_data = []
+
+    # Group by wells and iterate over them
+    for i, (prc, group) in enumerate(df.groupby('prc')):
+        original_mean = group[settings['score_column']].mean()  # Original full-well mean
+        max_cells = len(group)
+        sample_sizes = np.arange(2, max_cells + 1, increment)  # Sample sizes from 2 to max cells
+
+        # Iterate over sample sizes and compute absolute difference
+        for sample_size in sample_sizes:
+            abs_diffs = []
+
+            # Perform multiple random samples to reduce noise
+            for _ in range(num_repeats):
+                sample = group.sample(n=sample_size, replace=False)
+                sampled_mean = sample[settings['score_column']].mean()
+                abs_diff = abs(sampled_mean - original_mean)  # Absolute difference
+                abs_diffs.append(abs_diff)
+
+            # Compute the average absolute difference across all repeats
+            avg_abs_diff = np.mean(abs_diffs)
+
+            # Store the result for plotting
+            diff_data.append((sample_size, avg_abs_diff))
+
+    # Convert absolute difference data to DataFrame for plotting
+    diff_df = pd.DataFrame(diff_data, columns=['sample_size', 'avg_abs_diff'])
+
+    # Group by sample size to calculate mean and standard deviation
+    summary_df = diff_df.groupby('sample_size').agg(
+        mean_abs_diff=('avg_abs_diff', 'mean'),
+        std_abs_diff=('avg_abs_diff', 'std')
+    ).reset_index()
+
+    # Apply smoothing using a rolling window
+    summary_df['smoothed_mean_abs_diff'] = summary_df['mean_abs_diff'].rolling(window=smoothing, min_periods=1).mean()
+
+    # Detect the elbow point (where mean_abs_diff < tolerance)
+    elbow_df = summary_df[summary_df['smoothed_mean_abs_diff'] <= tolerance]
+
+    # Select the first occurrence if it exists; otherwise, use the last point
+    if not elbow_df.empty:
+        elbow_point = elbow_df.iloc[0]  # First point where the condition is met
     else:
-
-
+        elbow_point = summary_df.iloc[-1]  # Fallback to the last point
+
+    # Plot the mean absolute difference with standard deviation as shaded area
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.plot(
+        summary_df['sample_size'], summary_df['smoothed_mean_abs_diff'], color='teal', label='Smoothed Mean Absolute Difference'
+    )
+    ax.fill_between(
+        summary_df['sample_size'],
+        summary_df['smoothed_mean_abs_diff'] - summary_df['std_abs_diff'],
+        summary_df['smoothed_mean_abs_diff'] + summary_df['std_abs_diff'],
+        color='teal', alpha=0.3, label='±1 Std. Dev.'
+    )
+
+    # Mark the elbow point (inflection) on the plot
+    ax.axvline(elbow_point['sample_size'], color='black', linestyle='--', label='Elbow Point')
+
+    # Formatting the plot
+    ax.set_xlabel('Sample Size')
+    ax.set_ylabel('Mean Absolute Difference')
+    ax.set_title('Mean Absolute Difference vs. Sample Size with Standard Deviation')
+    ax.legend().remove()
+
+    # Save the plot if a destination is provided
+    dst = os.path.dirname(settings['count_data'][0])
+    if dst is not None:
+        fig_path = os.path.join(dst, 'results')
+        os.makedirs(fig_path, exist_ok=True)
+        fig_file_path = os.path.join(fig_path, 'cell_min_threshold.pdf')
+        fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
+        print(f"Saved {fig_file_path}")
+
+    plt.show()
+    return elbow_point['sample_size']
+
+def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
+    """Return DataFrame of model coefficients, standard errors, and p-values."""
+
+    if regression_type == 'beta':
+        # Extract coefficients and standard errors
+        coefs = model.params
+        std_err = model.bse
+
+        # Compute Wald test (coefficient / standard error)
+        wald_stats = coefs / std_err
+
+        # Calculate two-tailed p-values
+        p_values = 2 * (1 - st.norm.cdf(np.abs(wald_stats)))
+
+        coef_df = pd.DataFrame({
+            'feature': coefs.index,
+            'coefficient': coefs.values,
+            'std_err': std_err.values,
+            'wald_stat': wald_stats.values,
+            'p_value': p_values
+        })
+
+    elif regression_type in ['ols', 'glm', 'logit', 'probit', 'quasi_binomial']:
+        coefs = model.params
+        p_values = model.pvalues
+
+        coef_df = pd.DataFrame({
+            'feature': coefs.index,
+            'coefficient': coefs.values,
+            'p_value': p_values.values
+        })
+
+    elif regression_type in ['ridge', 'lasso']:
+        coefs = model.coef_.flatten()
+        p_values = calculate_p_values(X, y, model)
+
+        coef_df = pd.DataFrame({
+            'feature': X.columns,
+            'coefficient': coefs,
+            'p_value': p_values
+        })
+
+    else:
+        raise ValueError(f"Unsupported regression type: {regression_type}")
+
+    # Additional formatting
+    coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
+    coef_df['grna'] = coef_df['feature'].str.extract(r'\[(.*?)\]')[0]
+    coef_df['condition'] = coef_df.apply(
+        lambda row: 'nc' if nc in row['feature'] else
+                    'pc' if pc in row['feature'] else
+                    ('control' if row['grna'] in controls else 'other'),
+        axis=1
+    )
+
+    return coef_df[~coef_df['feature'].str.contains('row|column')]
+
+def check_distribution(y, epsilon=1e-6):
+    """Check the distribution of y and recommend an appropriate model."""
+
+    # Check if the dependent variable is binary (only 0 and 1)
+    if np.all((y == 0) | (y == 1)):
+        print("Detected binary data.")
+        return 'logit'
+
+    # Continuous data between 0 and 1 (excluding exact 0 and 1)
+    elif (y > 0).all() and (y < 1).all():
+        # Check if the data is close to 0 or 1 (boundary issues)
+        if np.any((y < epsilon) | (y > 1 - epsilon)):
+            print("Detected continuous data near 0 or 1. Using quasi-binomial.")
+            return 'quasi_binomial'
+        else:
+            print("Detected continuous data between 0 and 1 (no boundary issues). Using beta regression.")
+            return 'beta'
+
+    # Continuous data between 0 and 1 (including exact 0 or 1)
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Detected continuous data with boundary values (0 or 1). Using quasi-binomial.")
+        return 'quasi_binomial'
+
+    # Check if the data is normally distributed for OLS suitability
+    stat, p_value = stats.normaltest(y)  # D'Agostino and Pearson's test for normality
+    print(f"Normality test p-value: {p_value:.4f}")
+
+    if p_value > 0.05:
+        print("Detected normally distributed data. Using OLS.")
+        return 'ols'
+
+    # Check if the data fits a Beta distribution
+    if stats.kstest(y, 'beta', args=(2, 2)).pvalue > 0.05:
+        # Check if the data is close to 0 or 1 (boundary issues)
+        if np.any((y < epsilon) | (y > 1 - epsilon)):
+            print("Detected continuous data near 0 or 1. Using quasi-binomial.")
+            return 'quasi_binomial'
+        else:
+            print("Detected continuous data between 0 and 1 (no boundary issues). Using beta regression.")
+            return 'beta'
+
+    print("Detected non-normally distributed data. Using GLM.")
+    return 'glm'
+
+def pick_glm_family_and_link(y):
+    """Select the appropriate GLM family and link function based on data."""
+    if np.all((y == 0) | (y == 1)):
+        print("Binary data detected. Using Binomial family with Logit link.")
+        return sm.families.Binomial(link=sm.families.links.Logit())
+
+    elif (y > 0).all() and (y < 1).all():
+        print("Data strictly between 0 and 1. Beta regression recommended.")
+        raise ValueError("Use BetaModel for this data; GLM is not applicable.")
+
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Data between 0 and 1 (including boundaries). Using Quasi-Binomial.")
+        return sm.families.Binomial(link=sm.families.links.Logit())
+
+    stat, p_value = normaltest(y)
+    print(f"Normality test p-value: {p_value:.4f}")
+    if p_value > 0.05:
+        print("Normally distributed data detected. Using Gaussian with Identity link.")
+        return sm.families.Gaussian(link=sm.families.links.Identity())
+
+    if (y >= 0).all() and np.all(y.astype(int) == y):
+        print("Count data detected. Using Poisson with Log link.")
+        return sm.families.Poisson(link=sm.families.links.Log())
+
+    if (y > 0).all() and kstest(y, 'invgauss', args=(1,)).pvalue > 0.05:
+        print("Inverse Gaussian distribution detected. Using InverseGaussian with Log link.")
+        return sm.families.InverseGaussian(link=sm.families.links.Log())
+
+    if (y >= 0).all():
+        print("Overdispersed count data detected. Using Negative Binomial with Log link.")
+        return sm.families.NegativeBinomial(link=sm.families.links.Log())
+
+    print("Using default Gaussian family with Identity link.")
+    return sm.families.Gaussian(link=sm.families.links.Identity())
+
+def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
+    def plot_regression_line(X, y, model):
+        """Helper to plot regression line for lasso and ridge models."""
+        y_pred = model.predict(X)
+        plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
+        plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
+        plt.xlabel('Features')
+        plt.ylabel('Dependent Variable')
+        plt.legend()
+        plt.show()
+
+    def find_best_alpha(model_cls):
+        """Find optimal alpha using cross-validation."""
+        alphas = np.logspace(-5, 5, 100)  # Search over a range of alphas
+        if model_cls == 'lasso':
+            model_cv = LassoCV(alphas=alphas, cv=5).fit(X, y)
+        elif model_cls == 'ridge':
+            model_cv = RidgeCV(alphas=alphas, cv=5).fit(X, y)
+        print(f"Optimal alpha for {model_cls}: {model_cv.alpha_}")
+        return model_cv
+
+    # Dictionary of models
+    model_map = {
+        'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
+        'glm': lambda: sm.GLM(y, X, family=pick_glm_family_and_link(y)).fit(),
+        'beta': lambda: BetaModel(endog=y, exog=X).fit(),
+        'logit': lambda: sm.Logit(y, X).fit(),
+        'probit': lambda: sm.Probit(y, X).fit(),
+        'lasso': lambda: find_best_alpha('lasso') if alpha in [0, None] else Lasso(alpha=alpha).fit(X, y),
+        'ridge': lambda: find_best_alpha('ridge') if alpha in [0, None] else Ridge(alpha=alpha).fit(X, y)
+    }
+
+    # Select the model based on regression_type
+    if regression_type in model_map:
+        model = model_map[regression_type]()
+    elif regression_type == 'mixed':
+        model = perform_mixed_model(y, X, groups, alpha=alpha)
+    else:
+        raise ValueError(f"Unsupported regression type {regression_type}")
+
+    # Plot regression line for Lasso and Ridge
+    if regression_type in ['lasso', 'ridge']:
+        plot_regression_line(X, y, model)
+
+    # Handle GLM-specific statistics
+    if regression_type == 'glm':
+        llf_model = model.llf  # Log-likelihood of the fitted model
+        llf_null = model.null_deviance / -2  # Log-likelihood of the null model
+        mcfadden_r2 = 1 - (llf_model / llf_null)
+        print(f"McFadden's R²: {mcfadden_r2:.4f}")
+        print(model.summary())
+
+    if regression_type in ['lasso', 'ridge']:
+        # Calculate the Mean Squared Error (MSE)
+        mse = mean_squared_error(y, model.predict(X))
+        print(f"{regression_type.capitalize()} Regression MSE: {mse:.4f}")
+
+        # Display coefficients
+        coef_df = pd.DataFrame({
+            'Feature': X.columns,
+            'Coefficient': model.coef_
+        })
+        print(coef_df)
+
+    return model
+
+def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0,
+               random_row_column_effects=False, nc='233460', pc='220950', controls=[''],
+               dst=None, cov_type=None, plot=False):
+
+    from spacr.plot import volcano_plot, plot_histogram
+    from spacr.ml import create_volcano_filename, check_and_clean_data, prepare_formula, scale_variables
+
+    # Generate the volcano filename
+    volcano_path = create_volcano_filename(csv_path, regression_type, alpha, dst)
 
     # Determine regression type if not specified
     if regression_type is None:
-        regression_type =
+        regression_type = check_distribution(df[dependent_variable])
+
+    print(f"Using regression type: {regression_type}")
 
     df = check_and_clean_data(df, dependent_variable)
 
```
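`minimum_cell_simulation` repeatedly subsamples each well, tracks how far the subsample mean drifts from the full-well mean, and reports the elbow where the smoothed drift falls below `tolerance`. A sketch of a call (file paths and the score column below are hypothetical):

```python
# Hypothetical call; the 'score_data'/'count_data' paths and 'score_column' are placeholders.
from spacr.ml import minimum_cell_simulation

settings = {
    'score_data': ['/data/screen/plate1_scores.csv'],   # per-cell scores with row/column metadata
    'count_data': ['/data/screen/plate1_counts.csv'],   # only used to derive the output folder
    'score_column': 'prediction_probability_class_1',
}
min_cells = minimum_cell_simulation(settings, num_repeats=10, tolerance=0.02)
print(f"Estimated minimum cells per well: {min_cells}")
```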
```diff
@@ -319,102 +656,51 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
         mixed_model, coef_df = fit_mixed_model(df, formula, dst)
         model = mixed_model
     else:
-        #
+        # Prepare the formula
         formula = prepare_formula(dependent_variable, random_row_column_effects=False)
         y, X = dmatrices(formula, data=df, return_type='dataframe')
-
+
         # Plot histogram of the dependent variable
         plot_histogram(y, dependent_variable, dst=dst)
+        plot_histogram(df, 'fraction', dst=dst)
 
         # Scale the independent variables and dependent variable
-
+        if regression_type in ['beta', 'quasi_binomial', 'logit']:
+            print('Data will not be scaled')
+        else:
+            X, y = scale_variables(X, y)
 
     # Perform the regression
     groups = df['prc'] if regression_type == 'mixed' else None
-    print(f'
-
+    print(f'Performing {regression_type} regression')
+
     model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, cov_type=cov_type)
 
     # Process the model coefficients
     coef_df = process_model_coefficients(model, regression_type, X, y, nc, pc, controls)
-
+    display(coef_df)
     if plot:
         volcano_plot(coef_df, volcano_path)
+
+    return model, coef_df, regression_type
 
-
-
-
-
-
-
-def _line_plot(df, x, y, log_x=False, log_y=False, title=""):
-    fig, ax = plt.subplots(figsize=(10, 6))
-    ax.plot(df[x], df[y], linestyle='-', color=(0, 0.6, 0.6), label=f"{y}")
-    ax.set_xlabel(x)
-    ax.set_ylabel(y)
-    ax.set_title(title)
-    ax.legend()
-    if log_x:
-        ax.set_xscale('log')
-    if log_y:
-        ax.set_yscale('log')
-    plt.show()
-
-    if isinstance(settings['score_data'], str):
-        settings['score_data'] = [settings['score_data']]
-
-    dfs = []
-    for i, score_data in enumerate(settings['score_data']):
-        df = pd.read_csv(score_data)
-        df = correct_metadata_column_names(df)
-        df['plate'] = f'plate{i+1}'
-        df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
-        dfs.append(df)
-
-    df = pd.concat(dfs, axis=0)
-
-    # Compute the number of cells (or scores) per well
-    cell_counts = df.groupby('prc').size().reset_index(name='cell_count')
-
-    # Merge the cell counts back into the original DataFrame
-    df = df.merge(cell_counts, on='prc')
-
-    # Generate a range of thresholds
-    thresholds = np.arange(1, df['cell_count'].max() + 1)
-    results = []
-
-    # Iterate over thresholds and compute score mean and variance
-    for threshold in thresholds:
-        filtered_df = df[df['cell_count'] >= threshold]
-        score_mean = filtered_df.groupby('prc')[settings['score_column']].mean().mean()
-        score_variance = filtered_df.groupby('prc')[settings['score_column']].mean().var()
-        results.append((threshold, score_mean, score_variance))
-
-    results_df = pd.DataFrame(results, columns=['cell_count_threshold', 'score_mean', 'score_variance'])
-
-    if results_df.empty:
-        raise ValueError("No valid results were found. Check your data and thresholds.")
-
-    closest_threshold = results_df['score_variance'].diff().abs().argmin()
-    optimal_threshold = results_df.iloc[closest_threshold]
-
-    print(f"Optimal Threshold: {optimal_threshold['cell_count_threshold']}")
-    print(f"Score Mean at Optimal Threshold: {optimal_threshold['score_mean']}")
-    print(f"Score Variance at Optimal Threshold: {optimal_threshold['score_variance']}")
-
-    _line_plot(results_df, x='cell_count_threshold', y='score_mean',
-               title='Mean Well Score vs. Cell Count Threshold')
-    _line_plot(results_df, x='cell_count_threshold', y='score_variance',
-               title='Score Variance vs. Cell Count Threshold')
+def save_summary_to_file(model, file_path='summary.csv'):
+    """
+    Save the model's summary output to a CSV or text file.
+    """
+    # Get the summary as a string
+    summary_str = model.summary().as_text()
 
-
+    # Save it as a plain text file or CSV
+    with open(file_path, 'w') as f:
+        f.write(summary_str)
 
 def perform_regression(settings):
 
     from .plot import plot_plates
     from .utils import merge_regression_res_with_metadata, save_settings
     from .settings import get_perform_regression_default_settings
-    from .toxo import go_term_enrichment_by_column, custom_volcano_plot
+    from .toxo import go_term_enrichment_by_column, custom_volcano_plot, plot_gene_phenotypes, plot_gene_heatmaps
    from .sequencing import graph_sequencing_stats
 
     def _perform_regression_read_data(settings):
```
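The rewritten `regression_model` adds `beta` and family-aware `glm` options, and when `alpha` is `0` or `None` the lasso/ridge paths now select `alpha` by five-fold cross-validation via `LassoCV`/`RidgeCV`. A sketch of the cross-validated path on synthetic data:

```python
# Synthetic-data sketch of the new cross-validated alpha path (alpha=None).
import numpy as np
import pandas as pd
from spacr.ml import regression_model

rng = np.random.default_rng(2)
X = pd.DataFrame(rng.normal(size=(150, 3)), columns=['f0', 'f1', 'f2'])
y = 2.0 * X['f0'] - X['f1'] + rng.normal(scale=0.1, size=150)

# Prints the optimal alpha, the regression MSE, and the coefficient table.
model = regression_model(X, y, regression_type='lasso', alpha=None)
```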
```diff
@@ -460,7 +746,7 @@ def perform_regression(settings):
     if not settings['class_1_threshold'] is None:
         score_data_df['predictions'] = (score_data_df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)
 
-    reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
+    reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge', None]
     if settings['regression_type'] not in reg_types:
         print(f'Possible regression types: {reg_types}')
         raise ValueError(f"Unsupported regression type {settings['regression_type']}")
@@ -468,7 +754,7 @@ def perform_regression(settings):
     return count_data_df, score_data_df
 
 def _perform_regression_set_paths(settings):
-
+
     if isinstance(settings['score_data'], list):
         score_data = settings['score_data'][0]
     else:
@@ -484,7 +770,11 @@ def perform_regression(settings):
     csv_path = settings['count_data']
 
     settings['src'] = src
-
+
+    if settings['regression_type'] is None:
+        res_folder = os.path.join(src, 'results', score_source, 'auto')
+    else:
+        res_folder = os.path.join(src, 'results', score_source, settings['regression_type'])
 
     if isinstance(settings['count_data'], list):
         res_folder = os.path.join(res_folder, 'list')
@@ -536,7 +826,10 @@ def perform_regression(settings):
     print(f"Dependent variable after clean_controls: {len(score_data_df)}")
 
     if settings['min_cell_count'] is None:
-        settings['min_cell_count'] =
+        settings['min_cell_count'] = minimum_cell_simulation(settings, tolerance=settings['tolerance'])
+        print(f"Minimum cell count: {settings['min_cell_count']}")
+
+    orig_dv = settings['dependent_variable']
 
     dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
     print(f"Dependent variable after process_scores: {len(dependent_df)}")
@@ -551,15 +844,16 @@ def perform_regression(settings):
 
     merged_df = pd.merge(independent_df, dependent_df, on='prc')
 
+    os.makedirs(res_folder, exist_ok=True)
     data_path = os.path.join(res_folder, 'regression_data.csv')
     merged_df.to_csv(data_path, index=False)
+    print(f"Saved regression data to {data_path}")
 
     merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)
 
-
-    _ = plot_plates(score_data_df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'], dst = res_folder)
+    _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
 
-    model, coef_df = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
+    model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
 
     coef_df['grna'] = coef_df['feature'].apply(lambda x: re.search(r'grna\[(.*?)\]', x).group(1) if 'grna' in x else None)
     coef_df['gene'] = coef_df['feature'].apply(lambda x: re.search(r'gene\[(.*?)\]', x).group(1) if 'gene' in x else None)
@@ -592,7 +886,7 @@ def perform_regression(settings):
     gene_coef_df.to_csv(results_path_gene, index=False)
     grna_coef_df.to_csv(results_path_grna, index=False)
 
-    if
+    if regression_type == 'lasso':
         significant = coef_df[coef_df['coefficient'] > 0]
 
     else:
```
```diff
@@ -604,8 +898,9 @@ def perform_regression(settings):
     significant.sort_values(by='coefficient', ascending=False, inplace=True)
     significant = significant[~significant['feature'].str.contains('row|column')]
 
-    if
+    if regression_type in ['ols', 'beta']:
         print(model.summary())
+        save_summary_to_file(model, file_path=f'{res_folder}/mode_summary.csv')
 
     significant.to_csv(hits_path, index=False)
 
```
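`save_summary_to_file` simply writes `model.summary().as_text()` to disk, so it works with any fitted statsmodels result. A minimal sketch:

```python
# Sketch: persisting a fitted statsmodels result with the new helper
# (the OLS fit and output path here are illustrative, not from the diff).
import numpy as np
import statsmodels.api as sm
from spacr.ml import save_summary_to_file

rng = np.random.default_rng(3)
X = sm.add_constant(rng.normal(size=(50, 1)))
y = X @ np.array([1.0, 2.0]) + rng.normal(scale=0.1, size=50)

result = sm.OLS(y, X).fit()
save_summary_to_file(result, file_path='model_summary.txt')  # writes the fixed-width summary table
```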
```diff
@@ -633,15 +928,32 @@ def perform_regression(settings):
     base_dir = os.path.dirname(os.path.abspath(__file__))
     metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')
 
+    display(data_path)
+
     if settings['volcano'] == 'all':
         print('all')
-        custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold,
+        gene_list = custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+        display(gene_list)
     elif settings['volcano'] == 'gene':
         print('gene')
-        custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold,
+        gene_list = custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+        display(gene_list)
     elif settings['volcano'] == 'grna':
         print('grna')
-        custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold,
+        gene_list = custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+        display(gene_list)
+    phenotype_plot = os.path.join(res_folder,'phenotype_plot.pdf')
+    transcription_heatmap = os.path.join(res_folder,'transcription_heatmap.pdf')
+    data_GT1 = pd.read_csv(settings['metadata_files'][1], low_memory=False)
+    data_ME49 = pd.read_csv(settings['metadata_files'][0], low_memory=False)
+
+    columns = ['sense - Tachyzoites', 'sense - Tissue cysts', 'sense - EES1', 'sense - EES2', 'sense - EES3', 'sense - EES4', 'sense - EES5']
+
+    print('Plotting gene phenotypes and heatmaps')
+    print(gene_list)
+
+    plot_gene_phenotypes(data=data_GT1, gene_list=gene_list, save_path=phenotype_plot)
+    plot_gene_heatmaps(data=data_ME49, gene_list=gene_list, columns=columns, x_column='Gene ID', normalize=True, save_path=transcription_heatmap)
 
     #if len(significant) > 2:
     #    metadata_path = os.path.join(base_dir, 'resources', 'data', 'toxoplasma_metadata.csv')
```