spacr-0.3.46-py3-none-any.whl → spacr-0.3.50-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/chat_bot.py +31 -0
- spacr/gui_elements.py +33 -7
- spacr/ml.py +478 -76
- spacr/plot.py +488 -47
- spacr/sequencing.py +122 -1
- spacr/settings.py +2 -1
- spacr/toxo.py +266 -147
- spacr/utils.py +27 -4
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/METADATA +2 -1
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/RECORD +14 -13
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/LICENSE +0 -0
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/WHEEL +0 -0
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/top_level.txt +0 -0
spacr/ml.py
CHANGED
@@ -4,15 +4,24 @@ import numpy as np
 from scipy import stats
 from scipy.stats import shapiro
 
+from sklearn.linear_model import Lasso, Ridge, LassoCV, RidgeCV
+from sklearn.metrics import mean_squared_error
+import numpy as np
+
 import matplotlib.pyplot as plt
 from IPython.display import display
-
+import scipy.stats as st
 import statsmodels.api as sm
 import statsmodels.formula.api as smf
+from statsmodels.tools import add_constant
 from statsmodels.regression.mixed_linear_model import MixedLM
 from statsmodels.tools.sm_exceptions import PerfectSeparationError
 from statsmodels.stats.outliers_influence import variance_inflation_factor
-
+from statsmodels.genmod.families import Binomial
+from statsmodels.genmod.families.links import logit
+from statsmodels.othermod.betareg import BetaModel
+from scipy.optimize import minimize
+from scipy.special import gammaln, psi, expit
 from sklearn.linear_model import Lasso, Ridge
 from sklearn.preprocessing import FunctionTransformer
 from patsy import dmatrices
@@ -24,17 +33,30 @@ from sklearn.inspection import permutation_importance
 from sklearn.metrics import classification_report, precision_recall_curve
 from sklearn.preprocessing import StandardScaler
 from sklearn.preprocessing import MinMaxScaler
-
 from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
-
 from xgboost import XGBClassifier
 
+import numpy as np
+from scipy.stats import kstest, normaltest
+import statsmodels.api as sm
+
 import matplotlib
 matplotlib.use('Agg')
 
 import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
+
+class QuasiBinomial(Binomial):
+    """Custom Quasi-Binomial family with adjustable variance."""
+    def __init__(self, link=logit(), dispersion=1.0):
+        super().__init__(link=link)
+        self.dispersion = dispersion
+
+    def variance(self, mu):
+        """Adjust the variance with the dispersion parameter."""
+        return self.dispersion * super().variance(mu)
+
 def calculate_p_values(X, y, model):
     # Predict y values
     y_pred = model.predict(X)
@@ -75,46 +97,6 @@ def perform_mixed_model(y, X, groups, alpha=1.0):
     result = model.fit()
     return result
 
-def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
-
-    def plot_regression_line(X, y, model):
-        """Helper to plot regression line for lasso and ridge models."""
-        y_pred = model.predict(X)
-        plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
-        plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
-        plt.xlabel('Features')
-        plt.ylabel('Dependent Variable')
-        plt.legend()
-        plt.show()
-
-    # Define the dictionary with callables (lambdas) to delay evaluation
-    model_map = {
-        'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
-        'gls': lambda: sm.GLS(y, X).fit(),
-        'wls': lambda: sm.WLS(y, X, weights=1 / np.sqrt(X.iloc[:, 1])).fit(),
-        'rlm': lambda: sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit(),
-        'glm': lambda: sm.GLM(y, X, family=sm.families.Gaussian()).fit(),
-        'quantile': lambda: sm.QuantReg(y, X).fit(q=alpha),
-        'logit': lambda: sm.Logit(y, X).fit(),
-        'probit': lambda: sm.Probit(y, X).fit(),
-        'poisson': lambda: sm.Poisson(y, X).fit(),
-        'lasso': lambda: Lasso(alpha=alpha).fit(X, y),
-        'ridge': lambda: Ridge(alpha=alpha).fit(X, y)
-    }
-
-    # Call the appropriate model only when needed
-    if regression_type in model_map:
-        model = model_map[regression_type]()
-    elif regression_type == 'mixed':
-        model = perform_mixed_model(y, X, groups, alpha=alpha)
-    else:
-        raise ValueError(f"Unsupported regression type {regression_type}")
-
-    if regression_type in ['lasso', 'ridge']:
-        plot_regression_line(X, y, model)
-
-    return model
-
 def create_volcano_filename(csv_path, regression_type, alpha, dst):
     """Create and return the volcano plot filename based on regression type and alpha."""
     volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
@@ -173,6 +155,41 @@ def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
     coef_df['condition'] = coef_df.apply(lambda row: 'nc' if nc in row['feature'] else 'pc' if pc in row['feature'] else ('control' if row['grna'] in controls else 'other'),axis=1)
     return coef_df[~coef_df['feature'].str.contains('row|column')]
 
+
+
+
+
+def check_distribution(y):
+    """Check the type of distribution to recommend a model."""
+    if np.all((y == 0) | (y == 1)):
+        print("Detected binary data.")
+        return 'logit'
+    elif (y > 0).all() and (y < 1).all():
+        print("Detected continuous data between 0 and 1 (excluding 0 and 1).")
+        return 'beta'
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Detected continuous data between 0 and 1 (including 0 or 1).")
+        # Consider quasi-binomial regression
+        return 'quasi_binomial'
+    else:
+        print("Using OLS as a fallback.")
+        return 'ols'
+
+def select_glm_family(y):
+    """Select the appropriate GLM family based on the data."""
+    if np.all((y == 0) | (y == 1)):
+        print("Using Binomial family (for binary data).")
+        return sm.families.Binomial()
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Using Quasi-Binomial family (for proportion data including 0 and 1).")
+        return QuasiBinomial()
+    elif np.all(y.astype(int) == y) and (y >= 0).all():
+        print("Using Poisson family (for count data).")
+        return sm.families.Poisson()
+    else:
+        print("Using Gaussian family (for continuous data).")
+        return sm.families.Gaussian()
+
 def prepare_formula(dependent_variable, random_row_column_effects=False):
     """Return the regression formula using random effects for plate, row, and column."""
     if random_row_column_effects:
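
The selector above keys purely on the range and integrality of the dependent variable. A quick sketch of what it returns for typical inputs, assuming check_distribution from this diff is in scope (illustrative values only):

    import numpy as np

    print(check_distribution(np.array([0, 1, 1, 0, 1])))        # 'logit'  (binary)
    print(check_distribution(np.array([0.12, 0.55, 0.93])))     # 'beta'   (strictly inside 0..1)
    print(check_distribution(np.array([0.0, 0.3, 0.7, 1.0])))   # 'quasi_binomial' (touches 0/1)
    print(check_distribution(np.array([-1.2, 0.4, 2.5, 3.1])))  # 'ols'    (fallback)
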
@@ -290,25 +307,345 @@ def check_and_clean_data(df, dependent_variable):
     print("Data is ready for model fitting.")
     return df_cleaned
 
-def
-
+def check_normality(y, variable_name):
+    """Check if the data is normally distributed using the Shapiro-Wilk test."""
+    from scipy.stats import shapiro
+
+    stat, p = shapiro(y)
+    alpha = 0.05
+    if p > alpha:
+        print(f"{variable_name} is normally distributed (fail to reject H0)")
+        return True
+    else:
+        print(f"{variable_name} is not normally distributed (reject H0)")
+        return False
 
-
-
+def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance=0.02, smoothing=10, increment=10):
+    """
+    Plot the mean absolute difference with standard deviation as shaded area vs. sample size.
+    Detect and mark the elbow point (inflection) with smoothing and tolerance control.
+    """
+
+    from spacr.utils import correct_metadata_column_names
+
+    # Load and process data
+    if isinstance(settings['score_data'], str):
+        settings['score_data'] = [settings['score_data']]
+
+    dfs = []
+    for i, score_data in enumerate(settings['score_data']):
+        df = pd.read_csv(score_data)
+        df = correct_metadata_column_names(df)
+        df['plate'] = f'plate{i + 1}'
+        df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
+        dfs.append(df)
+
+    df = pd.concat(dfs, axis=0)
+
+    # Compute the number of cells per well and select the top 100 wells by cell count
+    cell_counts = df.groupby('prc').size().reset_index(name='cell_count')
+    top_wells = cell_counts.nlargest(sample_size, 'cell_count')['prc']
+
+    # Filter the data to include only the top 100 wells
+    df = df[df['prc'].isin(top_wells)]
+
+    # Initialize storage for absolute difference data
+    diff_data = []
+
+    # Group by wells and iterate over them
+    for i, (prc, group) in enumerate(df.groupby('prc')):
+        original_mean = group[settings['score_column']].mean()  # Original full-well mean
+        max_cells = len(group)
+        sample_sizes = np.arange(2, max_cells + 1, increment)  # Sample sizes from 2 to max cells
+
+        # Iterate over sample sizes and compute absolute difference
+        for sample_size in sample_sizes:
+            abs_diffs = []
+
+            # Perform multiple random samples to reduce noise
+            for _ in range(num_repeats):
+                sample = group.sample(n=sample_size, replace=False)
+                sampled_mean = sample[settings['score_column']].mean()
+                abs_diff = abs(sampled_mean - original_mean)  # Absolute difference
+                abs_diffs.append(abs_diff)
+
+            # Compute the average absolute difference across all repeats
+            avg_abs_diff = np.mean(abs_diffs)
+
+            # Store the result for plotting
+            diff_data.append((sample_size, avg_abs_diff))
+
+    # Convert absolute difference data to DataFrame for plotting
+    diff_df = pd.DataFrame(diff_data, columns=['sample_size', 'avg_abs_diff'])
+
+    # Group by sample size to calculate mean and standard deviation
+    summary_df = diff_df.groupby('sample_size').agg(
+        mean_abs_diff=('avg_abs_diff', 'mean'),
+        std_abs_diff=('avg_abs_diff', 'std')
+    ).reset_index()
+
+    # Apply smoothing using a rolling window
+    summary_df['smoothed_mean_abs_diff'] = summary_df['mean_abs_diff'].rolling(window=smoothing, min_periods=1).mean()
+
+    # Detect the elbow point (where mean_abs_diff < tolerance)
+    elbow_df = summary_df[summary_df['smoothed_mean_abs_diff'] <= tolerance]
+
+    # Select the first occurrence if it exists; otherwise, use the last point
+    if not elbow_df.empty:
+        elbow_point = elbow_df.iloc[0]  # First point where the condition is met
+    else:
+        elbow_point = summary_df.iloc[-1]  # Fallback to the last point
+
+    # Plot the mean absolute difference with standard deviation as shaded area
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.plot(
+        summary_df['sample_size'], summary_df['smoothed_mean_abs_diff'], color='teal', label='Smoothed Mean Absolute Difference'
+    )
+    ax.fill_between(
+        summary_df['sample_size'],
+        summary_df['smoothed_mean_abs_diff'] - summary_df['std_abs_diff'],
+        summary_df['smoothed_mean_abs_diff'] + summary_df['std_abs_diff'],
+        color='teal', alpha=0.3, label='±1 Std. Dev.'
+    )
+
+    # Mark the elbow point (inflection) on the plot
+    ax.axvline(elbow_point['sample_size'], color='black', linestyle='--', label='Elbow Point')
+
+    # Formatting the plot
+    ax.set_xlabel('Sample Size')
+    ax.set_ylabel('Mean Absolute Difference')
+    ax.set_title('Mean Absolute Difference vs. Sample Size with Standard Deviation')
+    ax.legend().remove()
+
+    # Save the plot if a destination is provided
+    dst = os.path.dirname(settings['count_data'][0])
+    if dst is not None:
+        fig_path = os.path.join(dst, 'results')
+        os.makedirs(fig_path, exist_ok=True)
+        fig_file_path = os.path.join(fig_path, 'cell_min_threshold.pdf')
+        fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
+        print(f"Saved {fig_file_path}")
+
+    plt.show()
+    return elbow_point['sample_size']
+
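
The elbow detection above boils down to: smooth the per-sample-size mean absolute difference with a rolling window, then take the first sample size whose smoothed value drops to the tolerance. A standalone sketch of that core step on synthetic data (column names mirror the function; not part of the diff):

    import numpy as np
    import pandas as pd

    sizes = np.arange(2, 202, 10)
    summary_df = pd.DataFrame({
        'sample_size': sizes,
        'mean_abs_diff': 0.5 / np.sqrt(sizes),   # sampling error shrinks roughly as 1/sqrt(n)
    })

    # Rolling-mean smoothing, then the first point at or under the tolerance
    summary_df['smoothed'] = summary_df['mean_abs_diff'].rolling(window=10, min_periods=1).mean()
    below = summary_df[summary_df['smoothed'] <= 0.05]
    elbow = below.iloc[0] if not below.empty else summary_df.iloc[-1]
    print(int(elbow['sample_size']))
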
+def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
+    """Return DataFrame of model coefficients, standard errors, and p-values."""
+
+    if regression_type == 'beta':
+        # Extract coefficients and standard errors
+        coefs = model.params
+        std_err = model.bse
+
+        # Compute Wald test (coefficient / standard error)
+        wald_stats = coefs / std_err
+
+        # Calculate two-tailed p-values
+        p_values = 2 * (1 - st.norm.cdf(np.abs(wald_stats)))
+
+        coef_df = pd.DataFrame({
+            'feature': coefs.index,
+            'coefficient': coefs.values,
+            'std_err': std_err.values,
+            'wald_stat': wald_stats.values,
+            'p_value': p_values
+        })
+
+    elif regression_type in ['ols', 'glm', 'logit', 'probit', 'quasi_binomial']:
+        coefs = model.params
+        p_values = model.pvalues
+
+        coef_df = pd.DataFrame({
+            'feature': coefs.index,
+            'coefficient': coefs.values,
+            'p_value': p_values.values
+        })
+
+    elif regression_type in ['ridge', 'lasso']:
+        coefs = model.coef_.flatten()
+        p_values = calculate_p_values(X, y, model)
+
+        coef_df = pd.DataFrame({
+            'feature': X.columns,
+            'coefficient': coefs,
+            'p_value': p_values
+        })
+
+    else:
+        raise ValueError(f"Unsupported regression type: {regression_type}")
+
+    # Additional formatting
+    coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
+    coef_df['grna'] = coef_df['feature'].str.extract(r'\[(.*?)\]')[0]
+    coef_df['condition'] = coef_df.apply(
+        lambda row: 'nc' if nc in row['feature'] else
+                    'pc' if pc in row['feature'] else
+                    ('control' if row['grna'] in controls else 'other'),
+        axis=1
+    )
 
-
-
+    return coef_df[~coef_df['feature'].str.contains('row|column')]
+
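
For the 'beta' branch, p-values are derived from a Wald z-statistic rather than read off the results object: z = coefficient / standard error, and the two-tailed p-value is 2(1 - Φ(|z|)). A small numeric check of that formula (not part of the diff):

    import scipy.stats as st

    coef, std_err = 0.42, 0.15
    wald = coef / std_err                    # z = 2.8
    p = 2 * (1 - st.norm.cdf(abs(wald)))     # two-tailed normal p-value
    print(round(wald, 2), round(p, 4))       # 2.8 0.0051
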
+def check_distribution(y, epsilon=1e-6):
+    """Check the distribution of y and recommend an appropriate model."""
+
+    # Check if the dependent variable is binary (only 0 and 1)
+    if np.all((y == 0) | (y == 1)):
+        print("Detected binary data.")
+        return 'logit'
+
+    # Continuous data between 0 and 1 (excluding exact 0 and 1)
+    elif (y > 0).all() and (y < 1).all():
+        # Check if the data is close to 0 or 1 (boundary issues)
+        if np.any((y < epsilon) | (y > 1 - epsilon)):
+            print("Detected continuous data near 0 or 1. Using quasi-binomial.")
+            return 'quasi_binomial'
+        else:
+            print("Detected continuous data between 0 and 1 (no boundary issues). Using beta regression.")
+            return 'beta'
+
+    # Continuous data between 0 and 1 (including exact 0 or 1)
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Detected continuous data with boundary values (0 or 1). Using quasi-binomial.")
+        return 'quasi_binomial'
+
+    # Check if the data is normally distributed for OLS suitability
+    stat, p_value = stats.normaltest(y)  # D’Agostino and Pearson’s test for normality
+    print(f"Normality test p-value: {p_value:.4f}")
+
+    if p_value > 0.05:
+        print("Detected normally distributed data. Using OLS.")
+        return 'ols'
+
+    # Check if the data fits a Beta distribution
+    if stats.kstest(y, 'beta', args=(2, 2)).pvalue > 0.05:
+        # Check if the data is close to 0 or 1 (boundary issues)
+        if np.any((y < epsilon) | (y > 1 - epsilon)):
+            print("Detected continuous data near 0 or 1. Using quasi-binomial.")
+            return 'quasi_binomial'
+        else:
+            print("Detected continuous data between 0 and 1 (no boundary issues). Using beta regression.")
+            return 'beta'
+
+    print("Detected non-normally distributed data. Using GLM.")
+    return 'glm'
+
+def pick_glm_family_and_link(y):
+    """Select the appropriate GLM family and link function based on data."""
+    if np.all((y == 0) | (y == 1)):
+        print("Binary data detected. Using Binomial family with Logit link.")
+        return sm.families.Binomial(link=sm.families.links.Logit())
+
+    elif (y > 0).all() and (y < 1).all():
+        print("Data strictly between 0 and 1. Beta regression recommended.")
+        raise ValueError("Use BetaModel for this data; GLM is not applicable.")
+
+    elif (y >= 0).all() and (y <= 1).all():
+        print("Data between 0 and 1 (including boundaries). Using Quasi-Binomial.")
+        return sm.families.Binomial(link=sm.families.links.Logit())
+
+    stat, p_value = normaltest(y)
+    print(f"Normality test p-value: {p_value:.4f}")
+    if p_value > 0.05:
+        print("Normally distributed data detected. Using Gaussian with Identity link.")
+        return sm.families.Gaussian(link=sm.families.links.Identity())
+
+    if (y >= 0).all() and np.all(y.astype(int) == y):
+        print("Count data detected. Using Poisson with Log link.")
+        return sm.families.Poisson(link=sm.families.links.Log())
+
+    if (y > 0).all() and kstest(y, 'invgauss', args=(1,)).pvalue > 0.05:
+        print("Inverse Gaussian distribution detected. Using InverseGaussian with Log link.")
+        return sm.families.InverseGaussian(link=sm.families.links.Log())
+
+    if (y >= 0).all():
+        print("Overdispersed count data detected. Using Negative Binomial with Log link.")
+        return sm.families.NegativeBinomial(link=sm.families.links.Log())
+
+    print("Using default Gaussian family with Identity link.")
+    return sm.families.Gaussian(link=sm.families.links.Identity())
+
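
Because pick_glm_family_and_link returns a ready statsmodels family object (and deliberately raises for data strictly inside (0, 1), steering callers to BetaModel), regression_model below can hand its result straight to sm.GLM. A minimal sketch with synthetic count data, assuming the function above is in scope (the Poisson branch is the expected, not guaranteed, outcome):

    import numpy as np
    import statsmodels.api as sm

    rng = np.random.default_rng(1)
    X = sm.add_constant(rng.normal(size=(200, 1)))
    y = rng.poisson(lam=3.0, size=200).astype(float)   # non-negative integer counts

    family = pick_glm_family_and_link(y)   # expected: Poisson with Log link
    fit = sm.GLM(y, X, family=family).fit()
    print(fit.params)
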
+def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
+    def plot_regression_line(X, y, model):
+        """Helper to plot regression line for lasso and ridge models."""
+        y_pred = model.predict(X)
+        plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
+        plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
+        plt.xlabel('Features')
+        plt.ylabel('Dependent Variable')
+        plt.legend()
+        plt.show()
+
+    def find_best_alpha(model_cls):
+        """Find optimal alpha using cross-validation."""
+        alphas = np.logspace(-5, 5, 100)  # Search over a range of alphas
+        if model_cls == 'lasso':
+            model_cv = LassoCV(alphas=alphas, cv=5).fit(X, y)
+        elif model_cls == 'ridge':
+            model_cv = RidgeCV(alphas=alphas, cv=5).fit(X, y)
+        print(f"Optimal alpha for {model_cls}: {model_cv.alpha_}")
+        return model_cv
+
+    # Dictionary of models
+    model_map = {
+        'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
+        'glm': lambda: sm.GLM(y, X, family=pick_glm_family_and_link(y)).fit(),
+        'beta': lambda: BetaModel(endog=y, exog=X).fit(),
+        'logit': lambda: sm.Logit(y, X).fit(),
+        'probit': lambda: sm.Probit(y, X).fit(),
+        'lasso': lambda: find_best_alpha('lasso') if alpha in [0, None] else Lasso(alpha=alpha).fit(X, y),
+        'ridge': lambda: find_best_alpha('ridge') if alpha in [0, None] else Ridge(alpha=alpha).fit(X, y)
+    }
 
-
-
-
+    # Select the model based on regression_type
+    if regression_type in model_map:
+        model = model_map[regression_type]()
+    elif regression_type == 'mixed':
+        model = perform_mixed_model(y, X, groups, alpha=alpha)
     else:
-
-
+        raise ValueError(f"Unsupported regression type {regression_type}")
+
+    # Plot regression line for Lasso and Ridge
+    if regression_type in ['lasso', 'ridge']:
+        plot_regression_line(X, y, model)
+
+    # Handle GLM-specific statistics
+    if regression_type == 'glm':
+        llf_model = model.llf  # Log-likelihood of the fitted model
+        llf_null = model.null_deviance / -2  # Log-likelihood of the null model
+        mcfadden_r2 = 1 - (llf_model / llf_null)
+        print(f"McFadden's R²: {mcfadden_r2:.4f}")
+        print(model.summary())
+
+    if regression_type in ['lasso', 'ridge']:
+        # Calculate the Mean Squared Error (MSE)
+        mse = mean_squared_error(y, model.predict(X))
+        print(f"{regression_type.capitalize()} Regression MSE: {mse:.4f}")
+
+        # Display coefficients
+        coef_df = pd.DataFrame({
+            'Feature': X.columns,
+            'Coefficient': model.coef_
+        })
+        print(coef_df)
+
+    return model
+
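
When alpha is 0 or None, regression_model now delegates the penalty choice to scikit-learn's cross-validated estimators instead of fitting a fixed penalty. A self-contained sketch of that pattern on synthetic data (the logspace grid mirrors find_best_alpha; not part of the diff):

    import numpy as np
    from sklearn.linear_model import LassoCV

    rng = np.random.default_rng(2)
    X = rng.normal(size=(150, 5))
    y = X @ np.array([1.5, 0.0, -2.0, 0.0, 0.5]) + rng.normal(scale=0.5, size=150)

    # 5-fold CV over a wide logarithmic alpha grid, as in find_best_alpha
    alphas = np.logspace(-5, 5, 100)
    model_cv = LassoCV(alphas=alphas, cv=5).fit(X, y)
    print(f"Optimal alpha: {model_cv.alpha_:.5f}")
    print(model_cv.coef_)
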
+def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0,
+               random_row_column_effects=False, nc='233460', pc='220950', controls=[''],
+               dst=None, cov_type=None, plot=False):
+
+    from spacr.plot import volcano_plot, plot_histogram
+    from spacr.ml import create_volcano_filename, check_and_clean_data, prepare_formula, scale_variables
+
+    # Generate the volcano filename
+    volcano_path = create_volcano_filename(csv_path, regression_type, alpha, dst)
 
     # Determine regression type if not specified
     if regression_type is None:
-        regression_type =
+        regression_type = check_distribution(df[dependent_variable])
+
+    print(f"Using regression type: {regression_type}")
 
     df = check_and_clean_data(df, dependent_variable)
 
@@ -319,36 +656,52 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
         mixed_model, coef_df = fit_mixed_model(df, formula, dst)
         model = mixed_model
     else:
-        #
+        # Prepare the formula
         formula = prepare_formula(dependent_variable, random_row_column_effects=False)
         y, X = dmatrices(formula, data=df, return_type='dataframe')
-
+
         # Plot histogram of the dependent variable
         plot_histogram(y, dependent_variable, dst=dst)
+        plot_histogram(df, 'fraction', dst=dst)
 
         # Scale the independent variables and dependent variable
-
+        if regression_type in ['beta', 'quasi_binomial', 'logit']:
+            print('Data will not be scaled')
+        else:
+            X, y = scale_variables(X, y)
 
     # Perform the regression
     groups = df['prc'] if regression_type == 'mixed' else None
-    print(f'
-
+    print(f'Performing {regression_type} regression')
+
     model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, cov_type=cov_type)
 
     # Process the model coefficients
     coef_df = process_model_coefficients(model, regression_type, X, y, nc, pc, controls)
-
+    display(coef_df)
     if plot:
         volcano_plot(coef_df, volcano_path)
+
+    return model, coef_df, regression_type
 
-
+def save_summary_to_file(model, file_path='summary.csv'):
+    """
+    Save the model's summary output to a CSV or text file.
+    """
+    # Get the summary as a string
+    summary_str = model.summary().as_text()
 
-
+    # Save it as a plain text file or CSV
+    with open(file_path, 'w') as f:
+        f.write(summary_str)
 
+def perform_regression(settings):
+
     from .plot import plot_plates
     from .utils import merge_regression_res_with_metadata, save_settings
     from .settings import get_perform_regression_default_settings
-    from .toxo import go_term_enrichment_by_column, custom_volcano_plot
+    from .toxo import go_term_enrichment_by_column, custom_volcano_plot, plot_gene_phenotypes, plot_gene_heatmaps
+    from .sequencing import graph_sequencing_stats
 
     def _perform_regression_read_data(settings):
 
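
Taken together, regression now returns the fitted model, the coefficient table, and the (possibly auto-detected) regression type, and save_summary_to_file dumps model.summary() as plain text. A hedged usage sketch with a hypothetical DataFrame and paths (the frame must contain the columns prepare_formula expects):

    import pandas as pd

    df = pd.read_csv('regression_data.csv')   # hypothetical path
    model, coef_df, regression_type = regression(
        df, csv_path='regression_data.csv',
        dependent_variable='predictions',
        regression_type=None,                 # None -> auto-detect via check_distribution
        dst='results', plot=False,
    )
    if regression_type in ['ols', 'beta']:
        save_summary_to_file(model, file_path='results/model_summary.txt')
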
@@ -393,7 +746,7 @@ def perform_regression(settings):
         if not settings['class_1_threshold'] is None:
             score_data_df['predictions'] = (score_data_df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)
 
-        reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
+        reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge', None]
         if settings['regression_type'] not in reg_types:
             print(f'Possible regression types: {reg_types}')
             raise ValueError(f"Unsupported regression type {settings['regression_type']}")
@@ -401,7 +754,7 @@ def perform_regression(settings):
         return count_data_df, score_data_df
 
     def _perform_regression_set_paths(settings):
-
+
         if isinstance(settings['score_data'], list):
             score_data = settings['score_data'][0]
         else:
@@ -417,7 +770,11 @@ def perform_regression(settings):
         csv_path = settings['count_data']
 
         settings['src'] = src
-
+
+        if settings['regression_type'] is None:
+            res_folder = os.path.join(src, 'results', score_source, 'auto')
+        else:
+            res_folder = os.path.join(src, 'results', score_source, settings['regression_type'])
 
         if isinstance(settings['count_data'], list):
             res_folder = os.path.join(res_folder, 'list')
@@ -468,9 +825,18 @@ def perform_regression(settings):
     score_data_df = clean_controls(score_data_df, settings['filter_value'], settings['filter_column'])
     print(f"Dependent variable after clean_controls: {len(score_data_df)}")
 
+    if settings['min_cell_count'] is None:
+        settings['min_cell_count'] = minimum_cell_simulation(settings, tolerance=settings['tolerance'])
+    print(f"Minimum cell count: {settings['min_cell_count']}")
+
+    orig_dv = settings['dependent_variable']
+
     dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
     print(f"Dependent variable after process_scores: {len(dependent_df)}")
 
+    if settings['fraction_threshold'] is None:
+        settings['fraction_threshold'] = graph_sequencing_stats(settings)
+
     independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plate'], filter_column=filter_column, filter_value=filter_value)
     independent_df, n_grna, n_gene = _count_variable_instances(independent_df, column_1='grna', column_2='gene')
 
@@ -478,15 +844,16 @@ def perform_regression(settings):
 
     merged_df = pd.merge(independent_df, dependent_df, on='prc')
 
+    os.makedirs(res_folder, exist_ok=True)
     data_path = os.path.join(res_folder, 'regression_data.csv')
     merged_df.to_csv(data_path, index=False)
+    print(f"Saved regression data to {data_path}")
 
     merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)
 
-
-    _ = plot_plates(score_data_df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'], dst = res_folder)
+    _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
 
-    model, coef_df = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
+    model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
 
     coef_df['grna'] = coef_df['feature'].apply(lambda x: re.search(r'grna\[(.*?)\]', x).group(1) if 'grna' in x else None)
     coef_df['gene'] = coef_df['feature'].apply(lambda x: re.search(r'gene\[(.*?)\]', x).group(1) if 'gene' in x else None)
@@ -499,8 +866,12 @@ def perform_regression(settings):
     grna_coef_df = grna_coef_df.dropna(subset=['n_grna'])
 
     if settings['controls'] is not None:
+
         control_coef_df = grna_coef_df[grna_coef_df['grna'].isin(settings['controls'])]
         mean_coef = control_coef_df['coefficient'].mean()
+        significant_c = control_coef_df[control_coef_df['p_value']<= 0.05]
+        mean_coef_c = significant_c['coefficient'].mean()
+        print(mean_coef, mean_coef_c)
 
     if settings['threshold_method'] in ['var','variance']:
         coef_mes = control_coef_df['coefficient'].var()
@@ -508,13 +879,14 @@ def perform_regression(settings):
         coef_mes = control_coef_df['coefficient'].std()
     else:
         raise ValueError(f"Unsupported threshold method {settings['threshold_method']}. Supported methods: ['var','variance','std','standard_deveation']")
+
     reg_threshold = mean_coef + (settings['threshold_multiplier'] * coef_mes)
 
     coef_df.to_csv(results_path, index=False)
     gene_coef_df.to_csv(results_path_gene, index=False)
     grna_coef_df.to_csv(results_path_grna, index=False)
 
-    if
+    if regression_type == 'lasso':
         significant = coef_df[coef_df['coefficient'] > 0]
 
     else:
@@ -526,11 +898,18 @@ def perform_regression(settings):
     significant.sort_values(by='coefficient', ascending=False, inplace=True)
     significant = significant[~significant['feature'].str.contains('row|column')]
 
-    if
+    if regression_type in ['ols', 'beta']:
         print(model.summary())
+        save_summary_to_file(model, file_path=f'{res_folder}/mode_summary.csv')
 
     significant.to_csv(hits_path, index=False)
 
+    significant_grna_filtered = significant[significant['n_grna'] > settings['min_n']]
+    significant_gene_filtered = significant[significant['n_gene'] > settings['min_n']]
+    significant_filtered = pd.concat([significant_grna_filtered, significant_gene_filtered])
+    filtered_hit_path = os.path.join(os.path.dirname(hits_path), 'results_significant_filtered.csv')
+    significant_filtered.to_csv(filtered_hit_path, index=False)
+
     if isinstance(settings['metadata_files'], str):
         settings['metadata_files'] = [settings['metadata_files']]
 
@@ -549,9 +928,32 @@ def perform_regression(settings):
     base_dir = os.path.dirname(os.path.abspath(__file__))
     metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')
 
-
-
-
+    display(data_path)
+
+    if settings['volcano'] == 'all':
+        print('all')
+        gene_list = custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+        display(gene_list)
+    elif settings['volcano'] == 'gene':
+        print('gene')
+        gene_list = custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+        display(gene_list)
+    elif settings['volcano'] == 'grna':
+        print('grna')
+        gene_list = custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+        display(gene_list)
+    phenotype_plot = os.path.join(res_folder,'phenotype_plot.pdf')
+    transcription_heatmap = os.path.join(res_folder,'transcription_heatmap.pdf')
+    data_GT1 = pd.read_csv(settings['metadata_files'][1], low_memory=False)
+    data_ME49 = pd.read_csv(settings['metadata_files'][0], low_memory=False)
+
+    columns = ['sense - Tachyzoites', 'sense - Tissue cysts', 'sense - EES1', 'sense - EES2', 'sense - EES3', 'sense - EES4', 'sense - EES5']
+
+    print('Plotting gene phenotypes and heatmaps')
+    print(gene_list)
+
+    plot_gene_phenotypes(data=data_GT1, gene_list=gene_list, save_path=phenotype_plot)
+    plot_gene_heatmaps(data=data_ME49, gene_list=gene_list, columns=columns, x_column='Gene ID', normalize=True, save_path=transcription_heatmap)
 
     #if len(significant) > 2:
     #    metadata_path = os.path.join(base_dir, 'resources', 'data', 'toxoplasma_metadata.csv')
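
For orientation, these are the settings keys the new auto-detection paths in perform_regression consult, as they appear in this diff; a hedged, partial sketch with placeholder values (other keys required by get_perform_regression_default_settings are omitted):

    settings = {
        'score_data': ['plate1_scores.csv'],    # placeholder paths
        'count_data': ['plate1_counts.csv'],
        'dependent_variable': 'predictions',
        'regression_type': None,      # None -> check_distribution picks the model
        'min_cell_count': None,       # None -> minimum_cell_simulation estimates it
        'fraction_threshold': None,   # None -> graph_sequencing_stats estimates it
        'tolerance': 0.02,            # elbow tolerance for minimum_cell_simulation
        'volcano': 'gene',            # 'all' | 'gene' | 'grna'
        'metadata_files': ['ME49_metadata.csv', 'GT1_metadata.csv'],  # placeholders; [0]=ME49, [1]=GT1
    }
    perform_regression(settings)
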
|