spacr 0.3.46__py3-none-any.whl → 0.3.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/ml.py CHANGED
@@ -4,15 +4,24 @@ import numpy as np
  from scipy import stats
  from scipy.stats import shapiro

+ from sklearn.linear_model import Lasso, Ridge, LassoCV, RidgeCV
+ from sklearn.metrics import mean_squared_error
+ import numpy as np
+
  import matplotlib.pyplot as plt
  from IPython.display import display
-
+ import scipy.stats as st
  import statsmodels.api as sm
  import statsmodels.formula.api as smf
+ from statsmodels.tools import add_constant
  from statsmodels.regression.mixed_linear_model import MixedLM
  from statsmodels.tools.sm_exceptions import PerfectSeparationError
  from statsmodels.stats.outliers_influence import variance_inflation_factor
-
+ from statsmodels.genmod.families import Binomial
+ from statsmodels.genmod.families.links import logit
+ from statsmodels.othermod.betareg import BetaModel
+ from scipy.optimize import minimize
+ from scipy.special import gammaln, psi, expit
  from sklearn.linear_model import Lasso, Ridge
  from sklearn.preprocessing import FunctionTransformer
  from patsy import dmatrices
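The new imports back the beta-regression and cross-validated shrinkage paths added further down. For orientation, a minimal beta-regression fit with statsmodels' BetaModel on synthetic (0, 1) data; everything below is illustrative, not spacr code:

```python
import numpy as np
import statsmodels.api as sm
from statsmodels.othermod.betareg import BetaModel

rng = np.random.default_rng(0)
X = sm.add_constant(rng.normal(size=(200, 1)))        # intercept + one covariate
mu = 1 / (1 + np.exp(-(X @ np.array([0.0, 1.0]))))    # mean on the logit scale
y = rng.beta(mu * 20, (1 - mu) * 20)                  # responses strictly inside (0, 1)

res = BetaModel(endog=y, exog=X).fit()
print(res.params, res.bse)                            # coefficients and standard errors
```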
@@ -24,17 +33,30 @@ from sklearn.inspection import permutation_importance
  from sklearn.metrics import classification_report, precision_recall_curve
  from sklearn.preprocessing import StandardScaler
  from sklearn.preprocessing import MinMaxScaler
-
  from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
-
  from xgboost import XGBClassifier

+ import numpy as np
+ from scipy.stats import kstest, normaltest
+ import statsmodels.api as sm
+
  import matplotlib
  matplotlib.use('Agg')

  import warnings
  warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")

+
+ class QuasiBinomial(Binomial):
+     """Custom Quasi-Binomial family with adjustable variance."""
+     def __init__(self, link=logit(), dispersion=1.0):
+         super().__init__(link=link)
+         self.dispersion = dispersion
+
+     def variance(self, mu):
+         """Adjust the variance with the dispersion parameter."""
+         return self.dispersion * super().variance(mu)
+
  def calculate_p_values(X, y, model):
      # Predict y values
      y_pred = model.predict(X)
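The QuasiBinomial class above scales the binomial variance by a dispersion factor. For comparison, stock statsmodels can give quasi-binomial-style inference by estimating the scale from the Pearson chi-square; a hedged sketch on synthetic proportion data (none of this is spacr API):

```python
import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(1)
X = sm.add_constant(rng.normal(size=(300, 2)))
p = 1 / (1 + np.exp(-(X @ np.array([0.2, 1.0, -0.5]))))
y = rng.binomial(20, p) / 20                # well-level proportions in [0, 1]

# Binomial family with scale estimated as Pearson chi2 / df_resid:
# standard errors are inflated by the estimated dispersion, as in quasi-binomial.
fit = sm.GLM(y, X, family=sm.families.Binomial()).fit(scale='X2')
print(fit.scale)    # estimated dispersion
print(fit.bse)      # dispersion-adjusted standard errors
```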
@@ -75,46 +97,6 @@ def perform_mixed_model(y, X, groups, alpha=1.0):
      result = model.fit()
      return result

- def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
-
-     def plot_regression_line(X, y, model):
-         """Helper to plot regression line for lasso and ridge models."""
-         y_pred = model.predict(X)
-         plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
-         plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
-         plt.xlabel('Features')
-         plt.ylabel('Dependent Variable')
-         plt.legend()
-         plt.show()
-
-     # Define the dictionary with callables (lambdas) to delay evaluation
-     model_map = {
-         'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
-         'gls': lambda: sm.GLS(y, X).fit(),
-         'wls': lambda: sm.WLS(y, X, weights=1 / np.sqrt(X.iloc[:, 1])).fit(),
-         'rlm': lambda: sm.RLM(y, X, M=sm.robust.norms.HuberT()).fit(),
-         'glm': lambda: sm.GLM(y, X, family=sm.families.Gaussian()).fit(),
-         'quantile': lambda: sm.QuantReg(y, X).fit(q=alpha),
-         'logit': lambda: sm.Logit(y, X).fit(),
-         'probit': lambda: sm.Probit(y, X).fit(),
-         'poisson': lambda: sm.Poisson(y, X).fit(),
-         'lasso': lambda: Lasso(alpha=alpha).fit(X, y),
-         'ridge': lambda: Ridge(alpha=alpha).fit(X, y)
-     }
-
-     # Call the appropriate model only when needed
-     if regression_type in model_map:
-         model = model_map[regression_type]()
-     elif regression_type == 'mixed':
-         model = perform_mixed_model(y, X, groups, alpha=alpha)
-     else:
-         raise ValueError(f"Unsupported regression type {regression_type}")
-
-     if regression_type in ['lasso', 'ridge']:
-         plot_regression_line(X, y, model)
-
-     return model
-
  def create_volcano_filename(csv_path, regression_type, alpha, dst):
      """Create and return the volcano plot filename based on regression type and alpha."""
      volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
@@ -173,6 +155,41 @@ def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
      coef_df['condition'] = coef_df.apply(lambda row: 'nc' if nc in row['feature'] else 'pc' if pc in row['feature'] else ('control' if row['grna'] in controls else 'other'),axis=1)
      return coef_df[~coef_df['feature'].str.contains('row|column')]

+
+
+
+
+ def check_distribution(y):
+     """Check the type of distribution to recommend a model."""
+     if np.all((y == 0) | (y == 1)):
+         print("Detected binary data.")
+         return 'logit'
+     elif (y > 0).all() and (y < 1).all():
+         print("Detected continuous data between 0 and 1 (excluding 0 and 1).")
+         return 'beta'
+     elif (y >= 0).all() and (y <= 1).all():
+         print("Detected continuous data between 0 and 1 (including 0 or 1).")
+         # Consider quasi-binomial regression
+         return 'quasi_binomial'
+     else:
+         print("Using OLS as a fallback.")
+         return 'ols'
+
+ def select_glm_family(y):
+     """Select the appropriate GLM family based on the data."""
+     if np.all((y == 0) | (y == 1)):
+         print("Using Binomial family (for binary data).")
+         return sm.families.Binomial()
+     elif (y >= 0).all() and (y <= 1).all():
+         print("Using Quasi-Binomial family (for proportion data including 0 and 1).")
+         return QuasiBinomial()
+     elif np.all(y.astype(int) == y) and (y >= 0).all():
+         print("Using Poisson family (for count data).")
+         return sm.families.Poisson()
+     else:
+         print("Using Gaussian family (for continuous data).")
+         return sm.families.Gaussian()
+
  def prepare_formula(dependent_variable, random_row_column_effects=False):
      """Return the regression formula using random effects for plate, row, and column."""
      if random_row_column_effects:
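To make the branching above concrete, this is how the two helpers classify a few toy vectors (illustrative values, evaluated against the definitions just added):

```python
import numpy as np

check_distribution(np.array([0, 1, 1, 0]))        # -> 'logit' (binary)
check_distribution(np.array([0.2, 0.5, 0.9]))     # -> 'beta' (strictly inside (0, 1))
check_distribution(np.array([0.0, 0.4, 1.0]))     # -> 'quasi_binomial' (touches the bounds)
check_distribution(np.array([3.1, -2.0, 7.5]))    # -> 'ols' (fallback)

select_glm_family(np.array([0, 2, 5, 1]))         # -> Poisson family for count data
```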
@@ -290,25 +307,345 @@ def check_and_clean_data(df, dependent_variable):
      print("Data is ready for model fitting.")
      return df_cleaned

- def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, random_row_column_effects=False, nc='233460', pc='220950', controls=[''], dst=None, cov_type=None, plot=False):
-     from .plot import volcano_plot, plot_histogram
+ def check_normality(y, variable_name):
+     """Check if the data is normally distributed using the Shapiro-Wilk test."""
+     from scipy.stats import shapiro
+
+     stat, p = shapiro(y)
+     alpha = 0.05
+     if p > alpha:
+         print(f"{variable_name} is normally distributed (fail to reject H0)")
+         return True
+     else:
+         print(f"{variable_name} is not normally distributed (reject H0)")
+         return False

-     # Generate the volcano filename
-     volcano_path = create_volcano_filename(csv_path, regression_type, alpha, dst)
+ def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance=0.02, smoothing=10, increment=10):
+     """
+     Plot the mean absolute difference with standard deviation as shaded area vs. sample size.
+     Detect and mark the elbow point (inflection) with smoothing and tolerance control.
+     """
+
+     from spacr.utils import correct_metadata_column_names
+
+     # Load and process data
+     if isinstance(settings['score_data'], str):
+         settings['score_data'] = [settings['score_data']]
+
+     dfs = []
+     for i, score_data in enumerate(settings['score_data']):
+         df = pd.read_csv(score_data)
+         df = correct_metadata_column_names(df)
+         df['plate'] = f'plate{i + 1}'
+         df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
+         dfs.append(df)
+
+     df = pd.concat(dfs, axis=0)
+
+     # Compute the number of cells per well and select the top 100 wells by cell count
+     cell_counts = df.groupby('prc').size().reset_index(name='cell_count')
+     top_wells = cell_counts.nlargest(sample_size, 'cell_count')['prc']
+
+     # Filter the data to include only the top 100 wells
+     df = df[df['prc'].isin(top_wells)]
+
+     # Initialize storage for absolute difference data
+     diff_data = []
+
+     # Group by wells and iterate over them
+     for i, (prc, group) in enumerate(df.groupby('prc')):
+         original_mean = group[settings['score_column']].mean()  # Original full-well mean
+         max_cells = len(group)
+         sample_sizes = np.arange(2, max_cells + 1, increment)  # Sample sizes from 2 to max cells
+
+         # Iterate over sample sizes and compute absolute difference
+         for sample_size in sample_sizes:
+             abs_diffs = []
+
+             # Perform multiple random samples to reduce noise
+             for _ in range(num_repeats):
+                 sample = group.sample(n=sample_size, replace=False)
+                 sampled_mean = sample[settings['score_column']].mean()
+                 abs_diff = abs(sampled_mean - original_mean)  # Absolute difference
+                 abs_diffs.append(abs_diff)
+
+             # Compute the average absolute difference across all repeats
+             avg_abs_diff = np.mean(abs_diffs)
+
+             # Store the result for plotting
+             diff_data.append((sample_size, avg_abs_diff))
+
+     # Convert absolute difference data to DataFrame for plotting
+     diff_df = pd.DataFrame(diff_data, columns=['sample_size', 'avg_abs_diff'])
+
+     # Group by sample size to calculate mean and standard deviation
+     summary_df = diff_df.groupby('sample_size').agg(
+         mean_abs_diff=('avg_abs_diff', 'mean'),
+         std_abs_diff=('avg_abs_diff', 'std')
+     ).reset_index()
+
+     # Apply smoothing using a rolling window
+     summary_df['smoothed_mean_abs_diff'] = summary_df['mean_abs_diff'].rolling(window=smoothing, min_periods=1).mean()
+
+     # Detect the elbow point (where mean_abs_diff < tolerance)
+     elbow_df = summary_df[summary_df['smoothed_mean_abs_diff'] <= tolerance]
+
+     # Select the first occurrence if it exists; otherwise, use the last point
+     if not elbow_df.empty:
+         elbow_point = elbow_df.iloc[0]  # First point where the condition is met
+     else:
+         elbow_point = summary_df.iloc[-1]  # Fallback to the last point
+
+     # Plot the mean absolute difference with standard deviation as shaded area
+     fig, ax = plt.subplots(figsize=(10, 10))
+     ax.plot(
+         summary_df['sample_size'], summary_df['smoothed_mean_abs_diff'], color='teal', label='Smoothed Mean Absolute Difference'
+     )
+     ax.fill_between(
+         summary_df['sample_size'],
+         summary_df['smoothed_mean_abs_diff'] - summary_df['std_abs_diff'],
+         summary_df['smoothed_mean_abs_diff'] + summary_df['std_abs_diff'],
+         color='teal', alpha=0.3, label='±1 Std. Dev.'
+     )
+
+     # Mark the elbow point (inflection) on the plot
+     ax.axvline(elbow_point['sample_size'], color='black', linestyle='--', label='Elbow Point')
+
+     # Formatting the plot
+     ax.set_xlabel('Sample Size')
+     ax.set_ylabel('Mean Absolute Difference')
+     ax.set_title('Mean Absolute Difference vs. Sample Size with Standard Deviation')
+     ax.legend().remove()
+
+     # Save the plot if a destination is provided
+     dst = os.path.dirname(settings['count_data'][0])
+     if dst is not None:
+         fig_path = os.path.join(dst, 'results')
+         os.makedirs(fig_path, exist_ok=True)
+         fig_file_path = os.path.join(fig_path, 'cell_min_threshold.pdf')
+         fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
+         print(f"Saved {fig_file_path}")
+
+     plt.show()
+     return elbow_point['sample_size']
+
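The function reads its inputs from a settings dict; the keys used in the body are 'score_data', 'score_column' and 'count_data'. An illustrative call (file paths hypothetical):

```python
# Hypothetical paths; only the dict keys are taken from the function body.
settings = {
    'score_data': ['plate1_scores.csv', 'plate2_scores.csv'],  # per-cell scores, one CSV per plate
    'score_column': 'prediction_probability_class_1',          # column averaged per well
    'count_data': ['plate1_counts.csv'],                       # only used to locate the output folder
}
min_cells = minimum_cell_simulation(settings, tolerance=0.02, smoothing=10)
print(f"Wells need at least {min_cells} cells for a stable mean")
```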
+ def process_model_coefficients(model, regression_type, X, y, nc, pc, controls):
+     """Return DataFrame of model coefficients, standard errors, and p-values."""
+
+     if regression_type == 'beta':
+         # Extract coefficients and standard errors
+         coefs = model.params
+         std_err = model.bse
+
+         # Compute Wald test (coefficient / standard error)
+         wald_stats = coefs / std_err
+
+         # Calculate two-tailed p-values
+         p_values = 2 * (1 - st.norm.cdf(np.abs(wald_stats)))
+
+         coef_df = pd.DataFrame({
+             'feature': coefs.index,
+             'coefficient': coefs.values,
+             'std_err': std_err.values,
+             'wald_stat': wald_stats.values,
+             'p_value': p_values
+         })
+
+     elif regression_type in ['ols', 'glm', 'logit', 'probit', 'quasi_binomial']:
+         coefs = model.params
+         p_values = model.pvalues
+
+         coef_df = pd.DataFrame({
+             'feature': coefs.index,
+             'coefficient': coefs.values,
+             'p_value': p_values.values
+         })
+
+     elif regression_type in ['ridge', 'lasso']:
+         coefs = model.coef_.flatten()
+         p_values = calculate_p_values(X, y, model)
+
+         coef_df = pd.DataFrame({
+             'feature': X.columns,
+             'coefficient': coefs,
+             'p_value': p_values
+         })
+
+     else:
+         raise ValueError(f"Unsupported regression type: {regression_type}")
+
+     # Additional formatting
+     coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
+     coef_df['grna'] = coef_df['feature'].str.extract(r'\[(.*?)\]')[0]
+     coef_df['condition'] = coef_df.apply(
+         lambda row: 'nc' if nc in row['feature'] else
+                     'pc' if pc in row['feature'] else
+                     ('control' if row['grna'] in controls else 'other'),
+         axis=1
+     )

-     # Check if the data is normally distributed
-     is_normal = check_normality(df[dependent_variable], dependent_variable)
+     return coef_df[~coef_df['feature'].str.contains('row|column')]
+
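The beta-regression branch derives p-values from a Wald statistic, z = coefficient / SE, with a two-tailed normal tail probability. The same computation in isolation (illustrative numbers):

```python
import numpy as np
import scipy.stats as st

coefs = np.array([0.80, -0.15])     # illustrative coefficients
std_err = np.array([0.25, 0.12])    # illustrative standard errors
wald = coefs / std_err
p_values = 2 * (1 - st.norm.cdf(np.abs(wald)))
print(p_values)                     # two-tailed normal-approximation p-values
```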
+ def check_distribution(y, epsilon=1e-6):
+     """Check the distribution of y and recommend an appropriate model."""
+
+     # Check if the dependent variable is binary (only 0 and 1)
+     if np.all((y == 0) | (y == 1)):
+         print("Detected binary data.")
+         return 'logit'
+
+     # Continuous data between 0 and 1 (excluding exact 0 and 1)
+     elif (y > 0).all() and (y < 1).all():
+         # Check if the data is close to 0 or 1 (boundary issues)
+         if np.any((y < epsilon) | (y > 1 - epsilon)):
+             print("Detected continuous data near 0 or 1. Using quasi-binomial.")
+             return 'quasi_binomial'
+         else:
+             print("Detected continuous data between 0 and 1 (no boundary issues). Using beta regression.")
+             return 'beta'
+
+     # Continuous data between 0 and 1 (including exact 0 or 1)
+     elif (y >= 0).all() and (y <= 1).all():
+         print("Detected continuous data with boundary values (0 or 1). Using quasi-binomial.")
+         return 'quasi_binomial'
+
+     # Check if the data is normally distributed for OLS suitability
+     stat, p_value = stats.normaltest(y)  # D’Agostino and Pearson’s test for normality
+     print(f"Normality test p-value: {p_value:.4f}")
+
+     if p_value > 0.05:
+         print("Detected normally distributed data. Using OLS.")
+         return 'ols'
+
+     # Check if the data fits a Beta distribution
+     if stats.kstest(y, 'beta', args=(2, 2)).pvalue > 0.05:
+         # Check if the data is close to 0 or 1 (boundary issues)
+         if np.any((y < epsilon) | (y > 1 - epsilon)):
+             print("Detected continuous data near 0 or 1. Using quasi-binomial.")
+             return 'quasi_binomial'
+         else:
+             print("Detected continuous data between 0 and 1 (no boundary issues). Using beta regression.")
+             return 'beta'
+
+     print("Detected non-normally distributed data. Using GLM.")
+     return 'glm'
+
+ def pick_glm_family_and_link(y):
+     """Select the appropriate GLM family and link function based on data."""
+     if np.all((y == 0) | (y == 1)):
+         print("Binary data detected. Using Binomial family with Logit link.")
+         return sm.families.Binomial(link=sm.families.links.Logit())
+
+     elif (y > 0).all() and (y < 1).all():
+         print("Data strictly between 0 and 1. Beta regression recommended.")
+         raise ValueError("Use BetaModel for this data; GLM is not applicable.")
+
+     elif (y >= 0).all() and (y <= 1).all():
+         print("Data between 0 and 1 (including boundaries). Using Quasi-Binomial.")
+         return sm.families.Binomial(link=sm.families.links.Logit())
+
+     stat, p_value = normaltest(y)
+     print(f"Normality test p-value: {p_value:.4f}")
+     if p_value > 0.05:
+         print("Normally distributed data detected. Using Gaussian with Identity link.")
+         return sm.families.Gaussian(link=sm.families.links.Identity())
+
+     if (y >= 0).all() and np.all(y.astype(int) == y):
+         print("Count data detected. Using Poisson with Log link.")
+         return sm.families.Poisson(link=sm.families.links.Log())
+
+     if (y > 0).all() and kstest(y, 'invgauss', args=(1,)).pvalue > 0.05:
+         print("Inverse Gaussian distribution detected. Using InverseGaussian with Log link.")
+         return sm.families.InverseGaussian(link=sm.families.links.Log())
+
+     if (y >= 0).all():
+         print("Overdispersed count data detected. Using Negative Binomial with Log link.")
+         return sm.families.NegativeBinomial(link=sm.families.links.Log())
+
+     print("Using default Gaussian family with Identity link.")
+     return sm.families.Gaussian(link=sm.families.links.Identity())
+
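pick_glm_family_and_link falls back to Negative Binomial for non-negative data that matches none of the earlier branches. A common diagnostic behind that kind of choice, shown here as a standalone hedged sketch (synthetic data, not spacr code), is the Pearson dispersion of a Poisson fit:

```python
import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(2)
X = sm.add_constant(rng.normal(size=(400, 1)))
y = rng.negative_binomial(2, 0.3, size=400)       # overdispersed counts

pois = sm.GLM(y, X, family=sm.families.Poisson()).fit()
dispersion = pois.pearson_chi2 / pois.df_resid    # close to 1 for true Poisson data
family = sm.families.NegativeBinomial() if dispersion > 2 else sm.families.Poisson()
print(round(dispersion, 2), type(family).__name__)
```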
+ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, cov_type=None):
+     def plot_regression_line(X, y, model):
+         """Helper to plot regression line for lasso and ridge models."""
+         y_pred = model.predict(X)
+         plt.scatter(X.iloc[:, 1], y, color='blue', label='Data')
+         plt.plot(X.iloc[:, 1], y_pred, color='red', label='Regression line')
+         plt.xlabel('Features')
+         plt.ylabel('Dependent Variable')
+         plt.legend()
+         plt.show()
+
+     def find_best_alpha(model_cls):
+         """Find optimal alpha using cross-validation."""
+         alphas = np.logspace(-5, 5, 100)  # Search over a range of alphas
+         if model_cls == 'lasso':
+             model_cv = LassoCV(alphas=alphas, cv=5).fit(X, y)
+         elif model_cls == 'ridge':
+             model_cv = RidgeCV(alphas=alphas, cv=5).fit(X, y)
+         print(f"Optimal alpha for {model_cls}: {model_cv.alpha_}")
+         return model_cv
+
+     # Dictionary of models
+     model_map = {
+         'ols': lambda: sm.OLS(y, X).fit(cov_type=cov_type) if cov_type else sm.OLS(y, X).fit(),
+         'glm': lambda: sm.GLM(y, X, family=pick_glm_family_and_link(y)).fit(),
+         'beta': lambda: BetaModel(endog=y, exog=X).fit(),
+         'logit': lambda: sm.Logit(y, X).fit(),
+         'probit': lambda: sm.Probit(y, X).fit(),
+         'lasso': lambda: find_best_alpha('lasso') if alpha in [0, None] else Lasso(alpha=alpha).fit(X, y),
+         'ridge': lambda: find_best_alpha('ridge') if alpha in [0, None] else Ridge(alpha=alpha).fit(X, y)
+     }

-     if is_normal:
-         print(f"To avoid violating assumptions, it is recommended to use a regression model that assumes normality.")
-         print(f"Recommended regression type: ols (Ordinary Least Squares)")
+     # Select the model based on regression_type
+     if regression_type in model_map:
+         model = model_map[regression_type]()
+     elif regression_type == 'mixed':
+         model = perform_mixed_model(y, X, groups, alpha=alpha)
      else:
-         print(f"To avoid violating assumptions, it is recommended to use a regression model that does not assume normality.")
-         print(f"Recommended regression type: glm (Generalized Linear Model)")
+         raise ValueError(f"Unsupported regression type {regression_type}")
+
+     # Plot regression line for Lasso and Ridge
+     if regression_type in ['lasso', 'ridge']:
+         plot_regression_line(X, y, model)
+
+     # Handle GLM-specific statistics
+     if regression_type == 'glm':
+         llf_model = model.llf  # Log-likelihood of the fitted model
+         llf_null = model.null_deviance / -2  # Log-likelihood of the null model
+         mcfadden_r2 = 1 - (llf_model / llf_null)
+         print(f"McFadden's R²: {mcfadden_r2:.4f}")
+         print(model.summary())
+
+     if regression_type in ['lasso', 'ridge']:
+         # Calculate the Mean Squared Error (MSE)
+         mse = mean_squared_error(y, model.predict(X))
+         print(f"{regression_type.capitalize()} Regression MSE: {mse:.4f}")
+
+         # Display coefficients
+         coef_df = pd.DataFrame({
+             'Feature': X.columns,
+             'Coefficient': model.coef_
+         })
+         print(coef_df)
+
+     return model
+
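The McFadden pseudo-R² printed above is 1 - llf_model / llf_null, with the null log-likelihood recovered from the null deviance. For a Bernoulli GLM (saturated log-likelihood zero) that identity is exact; a self-contained sketch on synthetic data:

```python
import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(3)
X = sm.add_constant(rng.normal(size=(500, 2)))
p = 1 / (1 + np.exp(-(X @ np.array([0.2, 1.5, -1.0]))))
y = rng.binomial(1, p)

fit = sm.GLM(y, X, family=sm.families.Binomial()).fit()
llf_null = fit.null_deviance / -2       # exact for binary outcomes
mcfadden_r2 = 1 - fit.llf / llf_null
print(f"McFadden's R2: {mcfadden_r2:.3f}")
```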
+ def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0,
+                random_row_column_effects=False, nc='233460', pc='220950', controls=[''],
+                dst=None, cov_type=None, plot=False):
+
+     from spacr.plot import volcano_plot, plot_histogram
+     from spacr.ml import create_volcano_filename, check_and_clean_data, prepare_formula, scale_variables
+
+     # Generate the volcano filename
+     volcano_path = create_volcano_filename(csv_path, regression_type, alpha, dst)

      # Determine regression type if not specified
      if regression_type is None:
-         regression_type = 'ols' if is_normal else 'glm'
+         regression_type = check_distribution(df[dependent_variable])
+
+     print(f"Using regression type: {regression_type}")

      df = check_and_clean_data(df, dependent_variable)

@@ -319,36 +656,52 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
          mixed_model, coef_df = fit_mixed_model(df, formula, dst)
          model = mixed_model
      else:
-         # Regular regression models
+         # Prepare the formula
          formula = prepare_formula(dependent_variable, random_row_column_effects=False)
          y, X = dmatrices(formula, data=df, return_type='dataframe')
-
+
          # Plot histogram of the dependent variable
          plot_histogram(y, dependent_variable, dst=dst)
+         plot_histogram(df, 'fraction', dst=dst)

          # Scale the independent variables and dependent variable
-         X, y = scale_variables(X, y)
+         if regression_type in ['beta', 'quasi_binomial', 'logit']:
+             print('Data will not be scaled')
+         else:
+             X, y = scale_variables(X, y)

      # Perform the regression
      groups = df['prc'] if regression_type == 'mixed' else None
-     print(f'performing {regression_type} regression')
-
+     print(f'Performing {regression_type} regression')
+
      model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, cov_type=cov_type)

      # Process the model coefficients
      coef_df = process_model_coefficients(model, regression_type, X, y, nc, pc, controls)
-
+     display(coef_df)
      if plot:
          volcano_plot(coef_df, volcano_path)
+
+     return model, coef_df, regression_type

-     return model, coef_df
+ def save_summary_to_file(model, file_path='summary.csv'):
+     """
+     Save the model's summary output to a CSV or text file.
+     """
+     # Get the summary as a string
+     summary_str = model.summary().as_text()

- def perform_regression(settings):
+     # Save it as a plain text file or CSV
+     with open(file_path, 'w') as f:
+         f.write(summary_str)

+ def perform_regression(settings):
+
      from .plot import plot_plates
      from .utils import merge_regression_res_with_metadata, save_settings
      from .settings import get_perform_regression_default_settings
-     from .toxo import go_term_enrichment_by_column, custom_volcano_plot
+     from .toxo import go_term_enrichment_by_column, custom_volcano_plot, plot_gene_phenotypes, plot_gene_heatmaps
+     from .sequencing import graph_sequencing_stats

      def _perform_regression_read_data(settings):

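The rewritten wrapper now returns the resolved regression type as a third value, so callers can branch on the auto-selected model. An illustrative call, with hypothetical inputs (the signature is taken from the definition above):

```python
# Hypothetical dataframe and paths; regression_type=None lets check_distribution() pick the model.
model, coef_df, regression_type = regression(
    merged_df, 'count_data.csv',
    dependent_variable='predictions',
    regression_type=None,
    dst='results',
)
if regression_type in ['ols', 'beta']:
    save_summary_to_file(model, file_path='results/model_summary.txt')
```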
@@ -393,7 +746,7 @@ def perform_regression(settings):
          if not settings['class_1_threshold'] is None:
              score_data_df['predictions'] = (score_data_df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)

-         reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
+         reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge', None]
          if settings['regression_type'] not in reg_types:
              print(f'Possible regression types: {reg_types}')
              raise ValueError(f"Unsupported regression type {settings['regression_type']}")
@@ -401,7 +754,7 @@ def perform_regression(settings):
          return count_data_df, score_data_df

      def _perform_regression_set_paths(settings):
-
+
          if isinstance(settings['score_data'], list):
              score_data = settings['score_data'][0]
          else:
@@ -417,7 +770,11 @@ def perform_regression(settings):
              csv_path = settings['count_data']

          settings['src'] = src
-         res_folder = os.path.join(src, 'results', score_source, settings['regression_type'])
+
+         if settings['regression_type'] is None:
+             res_folder = os.path.join(src, 'results', score_source, 'auto')
+         else:
+             res_folder = os.path.join(src, 'results', score_source, settings['regression_type'])

          if isinstance(settings['count_data'], list):
              res_folder = os.path.join(res_folder, 'list')
@@ -468,9 +825,18 @@ def perform_regression(settings):
      score_data_df = clean_controls(score_data_df, settings['filter_value'], settings['filter_column'])
      print(f"Dependent variable after clean_controls: {len(score_data_df)}")

+     if settings['min_cell_count'] is None:
+         settings['min_cell_count'] = minimum_cell_simulation(settings, tolerance=settings['tolerance'])
+     print(f"Minimum cell count: {settings['min_cell_count']}")
+
+     orig_dv = settings['dependent_variable']
+
      dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
      print(f"Dependent variable after process_scores: {len(dependent_df)}")

+     if settings['fraction_threshold'] is None:
+         settings['fraction_threshold'] = graph_sequencing_stats(settings)
+
      independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plate'], filter_column=filter_column, filter_value=filter_value)
      independent_df, n_grna, n_gene = _count_variable_instances(independent_df, column_1='grna', column_2='gene')

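Two auto-calibration hooks are added here: either threshold may be left as None and is then derived from the data. An illustrative settings fragment (keys as used above):

```python
# Leave thresholds unset to trigger data-driven calibration.
settings['min_cell_count'] = None       # -> minimum_cell_simulation(settings, tolerance=...)
settings['fraction_threshold'] = None   # -> graph_sequencing_stats(settings)
```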
@@ -478,15 +844,16 @@ def perform_regression(settings):

      merged_df = pd.merge(independent_df, dependent_df, on='prc')

+     os.makedirs(res_folder, exist_ok=True)
      data_path = os.path.join(res_folder, 'regression_data.csv')
      merged_df.to_csv(data_path, index=False)
+     print(f"Saved regression data to {data_path}")

      merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)

-     if settings['transform'] is None:
-         _ = plot_plates(score_data_df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'], dst = res_folder)
+     _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)

-     model, coef_df = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
+     model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])

      coef_df['grna'] = coef_df['feature'].apply(lambda x: re.search(r'grna\[(.*?)\]', x).group(1) if 'grna' in x else None)
      coef_df['gene'] = coef_df['feature'].apply(lambda x: re.search(r'gene\[(.*?)\]', x).group(1) if 'gene' in x else None)
@@ -499,8 +866,12 @@ def perform_regression(settings):
      grna_coef_df = grna_coef_df.dropna(subset=['n_grna'])

      if settings['controls'] is not None:
+
          control_coef_df = grna_coef_df[grna_coef_df['grna'].isin(settings['controls'])]
          mean_coef = control_coef_df['coefficient'].mean()
+         significant_c = control_coef_df[control_coef_df['p_value']<= 0.05]
+         mean_coef_c = significant_c['coefficient'].mean()
+         print(mean_coef, mean_coef_c)

          if settings['threshold_method'] in ['var','variance']:
              coef_mes = control_coef_df['coefficient'].var()
@@ -508,13 +879,14 @@ def perform_regression(settings):
              coef_mes = control_coef_df['coefficient'].std()
          else:
              raise ValueError(f"Unsupported threshold method {settings['threshold_method']}. Supported methods: ['var','variance','std','standard_deveation']")
+
          reg_threshold = mean_coef + (settings['threshold_multiplier'] * coef_mes)

      coef_df.to_csv(results_path, index=False)
      gene_coef_df.to_csv(results_path_gene, index=False)
      grna_coef_df.to_csv(results_path_grna, index=False)

-     if settings['regression_type'] == 'lasso':
+     if regression_type == 'lasso':
          significant = coef_df[coef_df['coefficient'] > 0]

      else:
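The hit threshold computed above centers on the mean control-gRNA coefficient and widens it by a multiple of the controls' spread. The same arithmetic in isolation (illustrative numbers):

```python
import numpy as np

control_coefs = np.array([0.01, -0.03, 0.02, 0.00])   # coefficients of control gRNAs
threshold_multiplier = 3                              # illustrative setting
reg_threshold = control_coefs.mean() + threshold_multiplier * control_coefs.std()
print(round(reg_threshold, 4))
```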
@@ -526,11 +898,18 @@ def perform_regression(settings):
      significant.sort_values(by='coefficient', ascending=False, inplace=True)
      significant = significant[~significant['feature'].str.contains('row|column')]

-     if settings['regression_type'] == 'ols':
+     if regression_type in ['ols', 'beta']:
          print(model.summary())
+         save_summary_to_file(model, file_path=f'{res_folder}/mode_summary.csv')

      significant.to_csv(hits_path, index=False)

+     significant_grna_filtered = significant[significant['n_grna'] > settings['min_n']]
+     significant_gene_filtered = significant[significant['n_gene'] > settings['min_n']]
+     significant_filtered = pd.concat([significant_grna_filtered, significant_gene_filtered])
+     filtered_hit_path = os.path.join(os.path.dirname(hits_path), 'results_significant_filtered.csv')
+     significant_filtered.to_csv(filtered_hit_path, index=False)
+
      if isinstance(settings['metadata_files'], str):
          settings['metadata_files'] = [settings['metadata_files']]

@@ -549,9 +928,32 @@ def perform_regression(settings):
      base_dir = os.path.dirname(os.path.abspath(__file__))
      metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')

-     custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=200, figsize=20, threshold=reg_threshold, split_axis_lims=settings['split_axis_lims'], save_path=volcano_path)
-     #custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=reg_threshold)
-     #custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=reg_threshold)
+     display(data_path)
+
+     if settings['volcano'] == 'all':
+         print('all')
+         gene_list = custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+         display(gene_list)
+     elif settings['volcano'] == 'gene':
+         print('gene')
+         gene_list = custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+         display(gene_list)
+     elif settings['volcano'] == 'grna':
+         print('grna')
+         gene_list = custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'],y_lims=settings['y_lims'])
+         display(gene_list)
+     phenotype_plot = os.path.join(res_folder,'phenotype_plot.pdf')
+     transcription_heatmap = os.path.join(res_folder,'transcription_heatmap.pdf')
+     data_GT1 = pd.read_csv(settings['metadata_files'][1], low_memory=False)
+     data_ME49 = pd.read_csv(settings['metadata_files'][0], low_memory=False)
+
+     columns = ['sense - Tachyzoites', 'sense - Tissue cysts', 'sense - EES1', 'sense - EES2', 'sense - EES3', 'sense - EES4', 'sense - EES5']
+
+     print('Plotting gene phenotypes and heatmaps')
+     print(gene_list)
+
+     plot_gene_phenotypes(data=data_GT1, gene_list=gene_list, save_path=phenotype_plot)
+     plot_gene_heatmaps(data=data_ME49, gene_list=gene_list, columns=columns, x_column='Gene ID', normalize=True, save_path=transcription_heatmap)

      #if len(significant) > 2:
      #    metadata_path = os.path.join(base_dir, 'resources', 'data', 'toxoplasma_metadata.csv')