spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +803 -14
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/core.py +1544 -533
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +297 -253
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +70 -80
- spacr/gui_mask_app.py +114 -91
- spacr/gui_measure_app.py +109 -88
- spacr/gui_utils.py +376 -32
- spacr/io.py +441 -438
- spacr/mask_app.py +116 -9
- spacr/measure.py +169 -69
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +70 -2
- spacr/plot.py +173 -17
- spacr/sequencing.py +1130 -0
- spacr/sim.py +630 -125
- spacr/timelapse.py +139 -10
- spacr/train.py +188 -21
- spacr/umap.py +0 -689
- spacr/utils.py +1360 -119
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/METADATA +17 -29
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.2.dist-info/RECORD +0 -31
- spacr-0.0.2.dist-info/entry_points.txt +0 -7
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.2.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/sim.py
CHANGED
@@ -1,16 +1,19 @@
|
|
1
1
|
|
2
|
-
import os,
|
3
|
-
import time as tm
|
2
|
+
import os, random, warnings, traceback, sqlite3, shap, math, gc
|
4
3
|
from time import time, sleep
|
5
4
|
from datetime import datetime
|
6
5
|
import numpy as np
|
7
6
|
import pandas as pd
|
8
7
|
import matplotlib.pyplot as plt
|
8
|
+
from matplotlib.patches import Patch
|
9
9
|
import seaborn as sns
|
10
10
|
import sklearn.metrics as metrics
|
11
|
-
from sklearn.
|
11
|
+
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
12
|
+
from sklearn.inspection import PartialDependenceDisplay, permutation_importance
|
13
|
+
from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
|
12
14
|
import statsmodels.api as sm
|
13
|
-
from multiprocessing import cpu_count,
|
15
|
+
from multiprocessing import cpu_count, Pool, Manager
|
16
|
+
from copy import deepcopy
|
14
17
|
|
15
18
|
from .logger import log_function_call
|
16
19
|
|
@@ -35,7 +38,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
|
|
35
38
|
|
36
39
|
# plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
|
37
40
|
def generate_plate_map(nr_plates):
|
38
|
-
print('nr_plates',nr_plates)
|
41
|
+
#print('nr_plates',nr_plates)
|
39
42
|
"""
|
40
43
|
Generate a plate map based on the number of plates.
|
41
44
|
|
@@ -241,6 +244,7 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
241
244
|
- df_ls (list): A list containing gene counts per well, well counts per gene, Gini coefficients for wells,
|
242
245
|
Gini coefficients for genes, gene weights array, and well weights.
|
243
246
|
"""
|
247
|
+
|
244
248
|
#generate primary distributions and genes
|
245
249
|
cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
|
246
250
|
gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
|
@@ -251,14 +255,18 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
251
255
|
gene_weights_array = np.array(list(gene_weights.values())) # Convert the values to an array
|
252
256
|
|
253
257
|
well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)
|
254
|
-
|
255
|
-
gene_to_well_mapping = {}
|
256
|
-
|
257
|
-
for gene in genes:
|
258
|
-
gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=int(gpw[gene-1]), p=well_weights) # Generate a number of wells for each gene according to well_weights
|
259
|
-
|
260
|
-
gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}
|
261
258
|
|
259
|
+
gene_to_well_mapping = {}
|
260
|
+
for gene in range(1, number_of_genes + 1): # ensures gene-1 is within bounds
|
261
|
+
if gene-1 < len(gpw):
|
262
|
+
max_index = len(plate_map['plate_row_column']) # this should be the number of choices available from plate_map
|
263
|
+
num_samples = int(gpw[gene-1])
|
264
|
+
if num_samples >= max_index:
|
265
|
+
num_samples = max_index - 1 # adjust to maximum possible index
|
266
|
+
gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=num_samples, replace=False, p=well_weights)
|
267
|
+
else:
|
268
|
+
break # break the loop if gene-1 is out of bounds for gpw
|
269
|
+
|
262
270
|
cells = []
|
263
271
|
for i in [*range(0,len(plate_map))]:
|
264
272
|
ciw = random.choice(cpw)
|
@@ -316,9 +324,47 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
316
324
|
df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
|
317
325
|
return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
|
318
326
|
|
319
|
-
|
320
|
-
|
321
|
-
|
327
|
+
def classifier(positive_mean, positive_variance, negative_mean, negative_variance, classifier_accuracy, df):
|
328
|
+
"""
|
329
|
+
Classifies the data in the DataFrame based on the given parameters and a classifier error rate.
|
330
|
+
|
331
|
+
Args:
|
332
|
+
positive_mean (float): The mean of the positive distribution.
|
333
|
+
positive_variance (float): The variance of the positive distribution.
|
334
|
+
negative_mean (float): The mean of the negative distribution.
|
335
|
+
negative_variance (float): The variance of the negative distribution.
|
336
|
+
classifier_accuracy (float): The likelihood (0 to 1) that a gene is correctly classified according to its true label.
|
337
|
+
df (pandas.DataFrame): The DataFrame containing the data to be classified.
|
338
|
+
|
339
|
+
Returns:
|
340
|
+
pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
|
341
|
+
"""
|
342
|
+
def calc_alpha_beta(mean, variance):
|
343
|
+
if mean <= 0 or mean >= 1:
|
344
|
+
raise ValueError("Mean must be between 0 and 1 exclusively.")
|
345
|
+
max_variance = mean * (1 - mean)
|
346
|
+
if variance <= 0 or variance >= max_variance:
|
347
|
+
raise ValueError(f"Variance must be positive and less than {max_variance}.")
|
348
|
+
|
349
|
+
alpha = mean * (mean * (1 - mean) / variance - 1)
|
350
|
+
beta = alpha * (1 - mean) / mean
|
351
|
+
return alpha, beta
|
352
|
+
|
353
|
+
# Apply the beta distribution based on 'is_active' status with consideration for classifier error
|
354
|
+
def get_score(is_active):
|
355
|
+
if np.random.rand() < classifier_accuracy: # With classifier_accuracy probability, choose the correct distribution
|
356
|
+
return np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2)
|
357
|
+
else: # With 1-classifier_accuracy probability, choose the incorrect distribution
|
358
|
+
return np.random.beta(a2, b2) if is_active else np.random.beta(a1, b1)
|
359
|
+
|
360
|
+
# Calculate alpha and beta for both distributions
|
361
|
+
a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
|
362
|
+
a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
|
363
|
+
df['score'] = df['is_active'].apply(get_score)
|
364
|
+
|
365
|
+
return df
|
366
|
+
|
367
|
+
def classifier_v2(positive_mean, positive_variance, negative_mean, negative_variance, df):
|
322
368
|
"""
|
323
369
|
Classifies the data in the DataFrame based on the given parameters.
|
324
370
|
|
@@ -332,12 +378,22 @@ def classifier(positive_mean, positive_variance, negative_mean, negative_varianc
|
|
332
378
|
Returns:
|
333
379
|
pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
|
334
380
|
"""
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
381
|
+
def calc_alpha_beta(mean, variance):
|
382
|
+
if mean <= 0 or mean >= 1:
|
383
|
+
raise ValueError("Mean must be between 0 and 1 exclusively.")
|
384
|
+
max_variance = mean * (1 - mean)
|
385
|
+
if variance <= 0 or variance >= max_variance:
|
386
|
+
raise ValueError(f"Variance must be positive and less than {max_variance}.")
|
387
|
+
|
388
|
+
alpha = mean * (mean * (1 - mean) / variance - 1)
|
389
|
+
beta = alpha * (1 - mean) / mean
|
390
|
+
return alpha, beta
|
391
|
+
|
392
|
+
# Calculate alpha and beta for both distributions
|
393
|
+
a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
|
394
|
+
a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
|
395
|
+
|
396
|
+
# Apply the beta distribution based on 'is_active' status
|
341
397
|
df['score'] = df['is_active'].apply(lambda is_active: np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2))
|
342
398
|
return df
|
343
399
|
|
@@ -446,6 +502,7 @@ def generate_well_score(cell_scores):
|
|
446
502
|
return well_score
|
447
503
|
|
448
504
|
def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
|
505
|
+
|
449
506
|
"""
|
450
507
|
Simulates the sequencing of plates and calculates gene fractions and metadata.
|
451
508
|
|
@@ -460,6 +517,7 @@ def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_pe
|
|
460
517
|
gene_fraction_map (pd.DataFrame): DataFrame containing gene fractions for each well.
|
461
518
|
metadata (pd.DataFrame): DataFrame containing metadata for each well.
|
462
519
|
"""
|
520
|
+
|
463
521
|
reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
|
464
522
|
gene_names = [f'gene_{v}' for v in range(number_of_genes+1)]
|
465
523
|
all_wells = well_score.index
|
@@ -612,8 +670,12 @@ def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
|
|
612
670
|
Returns:
|
613
671
|
None
|
614
672
|
"""
|
615
|
-
|
616
|
-
|
673
|
+
if not binwidth:
|
674
|
+
sns.histplot(data=data, x=x_label, ax=ax, color=color, kde=False, stat='density',
|
675
|
+
legend=False, fill=True, element='step', palette='dark')
|
676
|
+
else:
|
677
|
+
sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
|
678
|
+
legend=False, fill=True, element='step', palette='dark')
|
617
679
|
if log:
|
618
680
|
ax.set_yscale('log')
|
619
681
|
ax.set_title(title)
|
@@ -700,7 +762,7 @@ def run_simulation(settings):
|
|
700
762
|
plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])] # Extract rows where 'column_id' is not in [1,2,3,23,24]
|
701
763
|
|
702
764
|
cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(plate_map, settings['number_of_genes'], active_gene_list, settings['avg_genes_per_well'], settings['sd_genes_per_well'], settings['avg_cells_per_well'], settings['sd_cells_per_well'], settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
|
703
|
-
cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], df=cell_level)
|
765
|
+
cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], settings['classifier_accuracy'], df=cell_level)
|
704
766
|
cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
|
705
767
|
well_score = generate_well_score(cell_scores)
|
706
768
|
gene_fraction_map, metadata = sequence_plates(well_score, settings['number_of_genes'], settings['avg_reads_per_gene'], settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])
|
@@ -720,7 +782,10 @@ def run_simulation(settings):
|
|
720
782
|
results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False)
|
721
783
|
#except Exception as e:
|
722
784
|
# print(f"An error occurred while saving data: {e}")
|
723
|
-
|
785
|
+
output = [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df]
|
786
|
+
del cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df
|
787
|
+
gc.collect()
|
788
|
+
return output, dists
|
724
789
|
|
725
790
|
def vis_dists(dists, src, v, i):
|
726
791
|
"""
|
@@ -744,9 +809,14 @@ def vis_dists(dists, src, v, i):
|
|
744
809
|
for index, dist in enumerate(dists):
|
745
810
|
temp = pd.DataFrame(dist, columns = [f'{names[index]}'])
|
746
811
|
sns.histplot(data=temp, x=f'{names[index]}', kde=False, binwidth=None, stat='count', element="step", ax=ax[n], color='teal', log_scale=False)
|
747
|
-
#plot_histogram(temp, f'{names[index]}', ax[n], 'slategray', f'{names[index]}', binwidth=None, log=False)
|
748
812
|
n+=1
|
749
813
|
save_plot(fig2, src, 'dists', i)
|
814
|
+
plt.close(fig2)
|
815
|
+
plt.figure().clear()
|
816
|
+
plt.cla()
|
817
|
+
plt.clf()
|
818
|
+
del dists
|
819
|
+
|
750
820
|
return
|
751
821
|
|
752
822
|
def visualize_all(output):
|
@@ -773,6 +843,7 @@ def visualize_all(output):
|
|
773
843
|
Returns:
|
774
844
|
fig (matplotlib.figure.Figure): The generated figure object.
|
775
845
|
"""
|
846
|
+
|
776
847
|
cell_scores = output[0]
|
777
848
|
cell_roc_dict_df = output[1]
|
778
849
|
cell_pr_dict_df = output[2]
|
@@ -810,24 +881,35 @@ def visualize_all(output):
|
|
810
881
|
#plot cell classification score by inactive and active
|
811
882
|
active_distribution = cell_scores[cell_scores['is_active'] == 1]
|
812
883
|
inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
|
813
|
-
plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
|
814
|
-
plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
|
884
|
+
plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', log=False)#, binwidth=0.01, log=False)
|
885
|
+
plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', log=False)#, binwidth=0.01, log=False)
|
886
|
+
|
887
|
+
legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
|
888
|
+
Patch(facecolor='teal', edgecolor='teal', label='Active')]
|
889
|
+
|
890
|
+
ax[n].legend(handles=legend_elements, loc='upper right')
|
891
|
+
|
892
|
+
|
815
893
|
ax[n].set_xlim([0, 1])
|
816
894
|
n+=1
|
817
895
|
|
818
896
|
#plot classifier cell predictions by inactive and active well average
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
897
|
+
inactive_distribution_well = inactive_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
898
|
+
active_distribution_well = active_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
899
|
+
mixed_distribution_well = cell_scores.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
900
|
+
|
901
|
+
plot_histogram(inactive_distribution_well, "score", ax[n], 'slategray', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
902
|
+
plot_histogram(active_distribution_well, "score", ax[n], 'teal', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
903
|
+
plot_histogram(mixed_distribution_well, "score", ax[n], 'red', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
824
904
|
|
825
|
-
|
826
|
-
|
905
|
+
legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
|
906
|
+
Patch(facecolor='teal', edgecolor='teal', label='Active'),
|
907
|
+
Patch(facecolor='red', edgecolor='red', label='Mixed')]
|
827
908
|
|
828
|
-
|
829
|
-
|
909
|
+
ax[n].legend(handles=legend_elements, loc='upper right')
|
910
|
+
|
830
911
|
ax[n].set_xlim([0, 1])
|
912
|
+
#ax[n].legend()
|
831
913
|
n+=1
|
832
914
|
|
833
915
|
#plot ROC (cell classification)
|
@@ -846,8 +928,8 @@ def visualize_all(output):
|
|
846
928
|
n+=1
|
847
929
|
|
848
930
|
#plot well score
|
849
|
-
plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=
|
850
|
-
ax[n].set_xlim([0, 1])
|
931
|
+
plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=True)
|
932
|
+
#ax[n].set_xlim([0, 1])
|
851
933
|
n+=1
|
852
934
|
|
853
935
|
control_df = results_df[results_df['color'] == 'control']
|
@@ -911,6 +993,7 @@ def visualize_all(output):
|
|
911
993
|
|
912
994
|
plt.tight_layout()
|
913
995
|
plt.show()
|
996
|
+
gc.collect()
|
914
997
|
return fig
|
915
998
|
|
916
999
|
def create_database(db_path):
|
@@ -992,6 +1075,7 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
|
|
992
1075
|
df_concat[f'variable_{variable}_sim_nr'] = i
|
993
1076
|
|
994
1077
|
append_database(src, df_concat, 'simulations')
|
1078
|
+
del gini_genes_per_well, gini_wells_per_gene, df_concat
|
995
1079
|
|
996
1080
|
if save_all:
|
997
1081
|
for i, df in enumerate(output):
|
@@ -1001,8 +1085,12 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
|
|
1001
1085
|
if not isinstance(df, pd.DataFrame):
|
1002
1086
|
df = pd.DataFrame(df)
|
1003
1087
|
append_database(src, df, table_names[i])
|
1088
|
+
del df
|
1004
1089
|
except Exception as e:
|
1005
1090
|
print(f"An error occurred while saving data: {e}")
|
1091
|
+
print(traceback.format_exc())
|
1092
|
+
|
1093
|
+
del output, settings_df
|
1006
1094
|
return
|
1007
1095
|
|
1008
1096
|
def save_plot(fig, src, variable, i):
|
@@ -1025,7 +1113,6 @@ def save_plot(fig, src, variable, i):
|
|
1025
1113
|
|
1026
1114
|
def run_and_save(i, settings, time_ls, total_sims):
|
1027
1115
|
|
1028
|
-
|
1029
1116
|
"""
|
1030
1117
|
Run the simulation and save the results.
|
1031
1118
|
|
@@ -1040,24 +1127,24 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1040
1127
|
"""
|
1041
1128
|
#print(f'Runnings simulation with the following paramiters')
|
1042
1129
|
#print(settings)
|
1043
|
-
|
1130
|
+
settings['random_seed'] = False
|
1044
1131
|
if settings['random_seed']:
|
1045
1132
|
random.seed(42) # sims will be too similar with random seed
|
1046
1133
|
src = settings['src']
|
1047
1134
|
plot = settings['plot']
|
1048
1135
|
v = settings['variable']
|
1049
1136
|
start_time = time() # Start time of the simulation
|
1050
|
-
now = datetime.now() # get current date
|
1051
|
-
date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1137
|
+
#now = datetime.now() # get current date
|
1138
|
+
#date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1139
|
+
date_string = settings['start_time']
|
1052
1140
|
#try:
|
1053
1141
|
output, dists = run_simulation(settings)
|
1054
1142
|
sim_time = time() - start_time # Elapsed time for the simulation
|
1055
1143
|
settings['sim_time'] = sim_time
|
1056
1144
|
src = os.path.join(f'{src}/{date_string}',settings['name'])
|
1057
1145
|
save_data(src, output, settings, save_all=False, i=i, variable=v)
|
1058
|
-
if vis_dists:
|
1059
|
-
vis_dists(dists,src, v, i)
|
1060
1146
|
if plot:
|
1147
|
+
vis_dists(dists,src, v, i)
|
1061
1148
|
fig = visualize_all(output)
|
1062
1149
|
save_plot(fig, src, v, i)
|
1063
1150
|
plt.close(fig)
|
@@ -1066,6 +1153,7 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1066
1153
|
plt.clf()
|
1067
1154
|
del fig
|
1068
1155
|
del output, dists
|
1156
|
+
gc.collect()
|
1069
1157
|
#except Exception as e:
|
1070
1158
|
# print(e, end='\r', flush=True)
|
1071
1159
|
# sim_time = time() - start_time
|
@@ -1073,98 +1161,92 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1073
1161
|
time_ls.append(sim_time)
|
1074
1162
|
return i, sim_time, None
|
1075
1163
|
|
1076
|
-
def
|
1164
|
+
def validate_and_adjust_beta_params(sim_params):
|
1077
1165
|
"""
|
1078
|
-
|
1079
|
-
|
1166
|
+
Validates and adjusts Beta distribution parameters in simulation settings to ensure they are possible.
|
1167
|
+
|
1080
1168
|
Args:
|
1081
|
-
|
1082
|
-
|
1169
|
+
sim_params (list of dict): List of dictionaries, each containing the simulation parameters.
|
1170
|
+
|
1083
1171
|
Returns:
|
1084
|
-
|
1172
|
+
list of dict: The adjusted list of simulation parameter sets.
|
1085
1173
|
"""
|
1086
|
-
|
1087
|
-
for
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
sett['well_ineq_coeff'] = well_ineq_coeff
|
1107
|
-
for gene_ineq_coeff in settings['gene_ineq_coeff']:
|
1108
|
-
sett['gene_ineq_coeff'] = gene_ineq_coeff
|
1109
|
-
for nr_plates in settings['nr_plates']:
|
1110
|
-
sett['nr_plates'] = nr_plates
|
1111
|
-
for number_of_genes in settings['number_of_genes']:
|
1112
|
-
sett['number_of_genes'] = number_of_genes
|
1113
|
-
for number_of_active_genes in settings['number_of_active_genes']:
|
1114
|
-
sett['number_of_active_genes'] = number_of_active_genes
|
1115
|
-
for i in [*range(1,replicates+1)]:
|
1116
|
-
sim_ls.append(sett)
|
1117
|
-
#print(sett)
|
1118
|
-
#print('Number of simulations:',len(sim_ls))
|
1119
|
-
return sim_ls
|
1174
|
+
adjusted_params = []
|
1175
|
+
for params in sim_params:
|
1176
|
+
max_pos_variance = params['positive_mean'] * (1 - params['positive_mean'])
|
1177
|
+
max_neg_variance = params['negative_mean'] * (1 - params['negative_mean'])
|
1178
|
+
|
1179
|
+
# Adjust positive variance
|
1180
|
+
if params['positive_variance'] >= max_pos_variance:
|
1181
|
+
print(f'changed positive variance from {params["positive_variance"]} to {max_pos_variance * 0.99}')
|
1182
|
+
params['positive_variance'] = max_pos_variance * 0.99 # Adjust to 99% of the maximum allowed variance
|
1183
|
+
|
1184
|
+
# Adjust negative variance
|
1185
|
+
if params['negative_variance'] >= max_neg_variance:
|
1186
|
+
print(f'changed negative variance from {params["negative_variance"]} to {max_neg_variance * 0.99}')
|
1187
|
+
params['negative_variance'] = max_neg_variance * 0.99 # Adjust to 99% of the maximum allowed variance
|
1188
|
+
|
1189
|
+
adjusted_params.append(params)
|
1190
|
+
|
1191
|
+
return adjusted_params
|
1192
|
+
|
1193
|
+
def generate_paramiters(settings):
|
1120
1194
|
|
1121
|
-
#altered for one set of settings see negative_mean and variance
|
1122
|
-
def generate_paramiters_single(settings):
|
1123
1195
|
"""
|
1124
|
-
Generate a list of parameter sets for
|
1196
|
+
Generate a list of parameter sets for simulation based on the given settings.
|
1125
1197
|
|
1126
1198
|
Args:
|
1127
1199
|
settings (dict): A dictionary containing the simulation settings.
|
1128
1200
|
|
1129
1201
|
Returns:
|
1130
|
-
list: A list of parameter sets for
|
1202
|
+
list: A list of parameter sets for simulation.
|
1131
1203
|
"""
|
1204
|
+
|
1205
|
+
settings['positive_mean'] = [0.8]
|
1206
|
+
|
1132
1207
|
sim_ls = []
|
1133
1208
|
for avg_genes_per_well in settings['avg_genes_per_well']:
|
1134
1209
|
replicates = settings['replicates']
|
1135
|
-
sett = settings.copy()
|
1136
|
-
sett['avg_genes_per_well'] = avg_genes_per_well
|
1137
|
-
sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
|
1138
1210
|
for avg_cells_per_well in settings['avg_cells_per_well']:
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1211
|
+
for classifier_accuracy in settings['classifier_accuracy']:
|
1212
|
+
for positive_mean in settings['positive_mean']:
|
1213
|
+
for avg_reads_per_gene in settings['avg_reads_per_gene']:
|
1214
|
+
for sequencing_error in settings['sequencing_error']:
|
1215
|
+
for well_ineq_coeff in settings['well_ineq_coeff']:
|
1216
|
+
for gene_ineq_coeff in settings['gene_ineq_coeff']:
|
1217
|
+
for nr_plates in settings['nr_plates']:
|
1218
|
+
for number_of_genes in settings['number_of_genes']:
|
1219
|
+
for number_of_active_genes in settings['number_of_active_genes']:
|
1220
|
+
for i in range(1, replicates+1):
|
1221
|
+
sett = deepcopy(settings)
|
1222
|
+
sett['avg_genes_per_well'] = avg_genes_per_well
|
1223
|
+
sett['sd_genes_per_well'] = avg_genes_per_well / 2
|
1224
|
+
sett['avg_cells_per_well'] = avg_cells_per_well
|
1225
|
+
sett['sd_cells_per_well'] = avg_cells_per_well / 2
|
1226
|
+
sett['classifier_accuracy'] = classifier_accuracy
|
1227
|
+
sett['positive_mean'] = positive_mean
|
1228
|
+
sett['negative_mean'] = 1-positive_mean
|
1229
|
+
sett['positive_variance'] = (1-positive_mean)/2
|
1230
|
+
sett['negative_variance'] = (1-positive_mean)/2
|
1231
|
+
sett['avg_reads_per_gene'] = avg_reads_per_gene
|
1232
|
+
sett['sd_reads_per_gene'] = avg_reads_per_gene / 2
|
1233
|
+
sett['sequencing_error'] = sequencing_error
|
1234
|
+
sett['well_ineq_coeff'] = well_ineq_coeff
|
1235
|
+
sett['gene_ineq_coeff'] = gene_ineq_coeff
|
1236
|
+
sett['nr_plates'] = nr_plates
|
1237
|
+
sett['number_of_genes'] = number_of_genes
|
1238
|
+
sett['number_of_active_genes'] = number_of_active_genes
|
1239
|
+
sim_ls.append(sett)
|
1240
|
+
|
1241
|
+
random.shuffle(sim_ls)
|
1242
|
+
sim_ls = validate_and_adjust_beta_params(sim_ls)
|
1243
|
+
print(f'Running {len(sim_ls)} simulations.')
|
1244
|
+
#for x in sim_ls:
|
1245
|
+
# print(x['positive_mean'])
|
1165
1246
|
return sim_ls
|
1166
1247
|
|
1167
1248
|
def run_multiple_simulations(settings):
|
1249
|
+
|
1168
1250
|
"""
|
1169
1251
|
Run multiple simulations in parallel using the provided settings.
|
1170
1252
|
|
@@ -1175,8 +1257,12 @@ def run_multiple_simulations(settings):
|
|
1175
1257
|
None
|
1176
1258
|
"""
|
1177
1259
|
|
1260
|
+
now = datetime.now() # get current date
|
1261
|
+
start_time = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1262
|
+
settings['start_time'] = start_time
|
1263
|
+
|
1178
1264
|
sim_ls = generate_paramiters(settings)
|
1179
|
-
print(f'Running {len(sim_ls)} simulations.
|
1265
|
+
#print(f'Running {len(sim_ls)} simulations.')
|
1180
1266
|
|
1181
1267
|
max_workers = settings['max_workers'] or cpu_count() - 4
|
1182
1268
|
with Manager() as manager:
|
@@ -1185,10 +1271,429 @@ def run_multiple_simulations(settings):
|
|
1185
1271
|
with Pool(max_workers) as pool:
|
1186
1272
|
result = pool.starmap_async(run_and_save, [(index, settings, time_ls, total_sims) for index, settings in enumerate(sim_ls)])
|
1187
1273
|
while not result.ready():
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1274
|
+
try:
|
1275
|
+
sleep(0.01)
|
1276
|
+
sims_processed = len(time_ls)
|
1277
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
1278
|
+
time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
|
1279
|
+
print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
|
1280
|
+
gc.collect()
|
1281
|
+
except Exception as e:
|
1282
|
+
print(e)
|
1283
|
+
print(traceback.format_exc())
|
1284
|
+
try:
|
1285
|
+
result.get()
|
1286
|
+
except Exception as e:
|
1287
|
+
print(e)
|
1288
|
+
print(traceback.format_exc())
|
1194
1289
|
|
1290
|
+
def generate_integers(start, stop, step):
|
1291
|
+
return list(range(start, stop + 1, step))
|
1292
|
+
|
1293
|
+
def generate_floats(start, stop, step):
|
1294
|
+
# Determine the number of decimal places in 'step'
|
1295
|
+
num_decimals = str(step)[::-1].find('.')
|
1296
|
+
|
1297
|
+
current = start
|
1298
|
+
floats_list = []
|
1299
|
+
while current <= stop:
|
1300
|
+
# Round each float to the appropriate number of decimal places
|
1301
|
+
floats_list.append(round(current, num_decimals))
|
1302
|
+
current += step
|
1303
|
+
|
1304
|
+
return floats_list
|
1305
|
+
|
1306
|
+
def remove_columns_with_single_value(df):
|
1307
|
+
"""
|
1308
|
+
Removes columns from the DataFrame that have the same value in all rows.
|
1309
|
+
|
1310
|
+
Args:
|
1311
|
+
df (pandas.DataFrame): The original DataFrame.
|
1312
|
+
|
1313
|
+
Returns:
|
1314
|
+
pandas.DataFrame: A DataFrame with the columns removed that contained only one unique value.
|
1315
|
+
"""
|
1316
|
+
to_drop = [column for column in df.columns if df[column].nunique() == 1]
|
1317
|
+
return df.drop(to_drop, axis=1)
|
1318
|
+
|
1319
|
+
def read_simulations_table(db_path):
|
1320
|
+
"""
|
1321
|
+
Reads the 'simulations' table from an SQLite database into a pandas DataFrame.
|
1322
|
+
|
1323
|
+
Args:
|
1324
|
+
db_path (str): The file path to the SQLite database.
|
1325
|
+
|
1326
|
+
Returns:
|
1327
|
+
pandas.DataFrame: DataFrame containing the 'simulations' table data.
|
1328
|
+
"""
|
1329
|
+
# Create a connection object using the connect function
|
1330
|
+
conn = sqlite3.connect(db_path)
|
1331
|
+
|
1332
|
+
# Read the 'simulations' table into a pandas DataFrame
|
1333
|
+
try:
|
1334
|
+
df = pd.read_sql_query("SELECT * FROM simulations", conn)
|
1335
|
+
except Exception as e:
|
1336
|
+
print(f"An error occurred: {e}")
|
1337
|
+
return None
|
1338
|
+
finally:
|
1339
|
+
# Close the connection to SQLite database
|
1340
|
+
conn.close()
|
1341
|
+
|
1342
|
+
return df
|
1343
|
+
|
1344
|
+
def plot_simulations(df, variable, x_rotation=None, legend=False, grid=False, clean=True, verbose=False):
    """
    Plot mean PRAUC (with a +/- 1 standard deviation band) against `variable`,
    drawing one panel per unique combination of the simulation grouping columns.

    Args:
        df (pandas.DataFrame): Simulation results. Must contain `variable`, 'prauc'
            and every grouping column that is used (see `clean`).
        variable (str): Column used as the x-axis; it is excluded from the grouping columns.
        x_rotation (int, optional): Degrees to rotate the x tick labels (45 when None).
        legend (bool, optional): Whether to draw a legend on each panel.
        grid (bool, optional): Whether to draw grid lines.
        clean (bool, optional): If True, grouping columns that are missing from `df`
            or hold a single unique value are ignored instead of being required.
        verbose (bool, optional): Whether to write each panel's filter conditions
            as a text box inside the panel.

    Returns:
        matplotlib.figure.Figure: The figure containing all panels.

    Raises:
        ValueError: If required columns are missing or `df` has no rows.
    """
    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']
    grouping_vars = [col for col in grouping_vars if col != variable]

    if clean:
        # A constant column does not split the data and a missing column cannot,
        # so neither is required nor used as a panel condition.
        grouping_vars = [col for col in grouping_vars
                         if col in df.columns and df[col].nunique() > 1]

    # Check that the necessary columns are present in the DataFrame
    required_columns = {variable, 'prauc'} | set(grouping_vars)
    missing_cols = required_columns - set(df.columns)
    if missing_cols:
        raise ValueError(f"DataFrame must contain {missing_cols} columns")
    if df.empty:
        raise ValueError("DataFrame contains no rows to plot")

    # One (conditions, subset) pair per unique combination of grouping values.
    if grouping_vars:
        unique_combinations = df[grouping_vars].drop_duplicates()
        panels = []
        for _, row in unique_combinations.iterrows():
            condition = {var: row[var] for var in grouping_vars}
            mask = df[grouping_vars].eq(row).all(axis=1)
            panels.append((condition, df[mask]))
    else:
        # Nothing to group by: a single panel over the whole frame.
        panels = [({}, df)]

    # Lay the panels out on a roughly square grid
    num_panels = len(panels)
    num_rows = math.ceil(np.sqrt(num_panels))
    num_cols = math.ceil(num_panels / num_rows)

    # squeeze=False always returns a 2-D array, so flattening works for any grid size.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (condition, subset_df) in zip(axes, panels):
        # Mean and std of 'prauc' per value of `variable`, ordered along the x-axis
        grouped = subset_df.groupby(variable)['prauc'].agg(['mean', 'std']).sort_index()

        ax.plot(grouped.index, grouped['mean'], marker='o', linestyle='-', color='b', label='Mean PRAUC')
        ax.fill_between(grouped.index, grouped['mean'] - grouped['std'], grouped['mean'] + grouped['std'], color='gray', alpha=0.5, label='Std Dev')

        ax.set_xlabel(variable)
        ax.set_ylabel('Precision-Recall AUC (PRAUC)')
        ax.grid(grid)

        if legend:
            ax.legend()

        ax.set_xticks(grouped.index)
        ax.set_xticklabels(grouped.index, rotation=x_rotation if x_rotation is not None else 45)

        if verbose:
            verbose_text = '\n'.join(f"{var}: {val}" for var, val in condition.items())
            ax.text(0.95, 0.05, verbose_text, transform=ax.transAxes, fontsize=9, verticalalignment='bottom', horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    # Hide grid cells that did not receive a panel
    for ax in axes[num_panels:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
|
1432
|
+
|
1433
|
+
def plot_correlation_matrix(df, annot=False, cmap=None, clean=True):
    """
    Plot the lower triangle of the correlation matrix between the simulation
    parameters and the resulting performance metrics, then save it.

    Args:
        df (pandas.DataFrame): Simulation results.
        annot (bool, optional): Whether to write the correlation value in each cell.
        cmap (str or matplotlib colormap, optional): Colormap for the heatmap. When None
            (the default), a diverging seaborn palette (hues 240 -> 10) is used.
            Previously this argument was silently ignored.
        clean (bool, optional): Whether to drop constant columns before computing correlations.

    Returns:
        matplotlib.figure.Figure: The heatmap figure (also saved under 'figures/correlation_matrix').
    """
    if cmap is None:
        cmap = sns.diverging_palette(240, 10, as_cmap=True)

    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']
    # Include the outcome metrics alongside the simulation parameters
    grouping_vars = grouping_vars + ['optimal_threshold', 'accuracy', 'prauc', 'roc_auc', 'genes_per_well_gini', 'wells_per_gene_gini']

    if clean:
        df = remove_constant_columns(df)
        grouping_vars = [feature for feature in grouping_vars if feature in df.columns]

    # Subset the DataFrame to the relevant variables
    relevant_data = df[grouping_vars]

    if clean:
        relevant_data = remove_columns_with_single_value(relevant_data)

    corr_matrix = relevant_data.corr()
    # Mask the upper triangle (and diagonal): the matrix is symmetric
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(corr_matrix, mask=mask, annot=annot, cmap=cmap, fmt=".2f", linewidths=.5, robust=True)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='correlation_matrix', i=1)
    return fig
|
1476
|
+
|
1477
|
+
def plot_feature_importance(df, target='prauc', exclude=None, clean=True):
    """
    Train a RandomForestRegressor on the simulation parameters and plot the
    impurity-based importance of each feature for predicting `target`.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable column name.
        exclude (list or str, optional): Column name(s) to exclude from the features.
        clean (bool, optional): Whether to drop single-valued columns first; features
            absent afterwards are skipped.

    Returns:
        matplotlib.figure.Figure: The figure containing the feature importance plot.

    Raises:
        ValueError: If no features remain after cleaning/exclusion.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Remove excluded features if specified
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    if not features:
        raise ValueError("No features left to compute feature importance")

    model = RandomForestRegressor(n_estimators=1000, random_state=42)
    model.fit(df[features], df[target])

    importances = model.feature_importances_
    indices = np.argsort(importances)[::-1]  # most important first

    fig = plt.figure(figsize=(12, 6))
    # Bar k shows importances[indices[k]], so its label must be features[indices[k]].
    plt.barh(range(len(indices)), importances[indices], color="teal", align="center", alpha=0.6)
    plt.yticks(range(len(indices)), [features[i] for i in indices])
    plt.gca().invert_yaxis()  # put position 0 (highest importance) at the top
    plt.xlabel('Feature Importance')
    plt.title('Feature Importances')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='feature_importance', i=1)
    return fig
|
1524
|
+
|
1525
|
+
def calculate_permutation_importance(df, target='prauc', exclude=None, n_repeats=10, clean=True):
    """
    Fit a RandomForestRegressor and plot the permutation importance of each
    simulation parameter for predicting `target`.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        exclude (list or str, optional): Column name(s) to exclude from the features.
            Names that are not among the features are ignored.
        n_repeats (int, optional): Number of permutations per feature.
        clean (bool, optional): Whether to drop single-valued columns first; features
            absent afterwards are skipped.

    Returns:
        matplotlib.figure.Figure: The permutation importance bar chart.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Filter instead of list.remove(): remove() raised for a list `exclude`
    # and for names already dropped by cleaning.
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    X = df[features]
    y = df[target]

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    perm_importance = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)

    # Ascending order, so the most important feature ends up at the top of the barh plot
    sorted_idx = perm_importance.importances_mean.argsort()

    fig, ax = plt.subplots()
    ax.barh(range(len(sorted_idx)), perm_importance.importances_mean[sorted_idx], color="teal", align="center", alpha=0.6)
    ax.set_yticks(range(len(sorted_idx)))
    # Indices refer to the columns of X (i.e. `features`), not to df.columns
    ax.set_yticklabels([features[i] for i in sorted_idx])
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='permutation_importance', i=1)
    return fig
|
1574
|
+
|
1575
|
+
def plot_partial_dependences(df, target='prauc', clean=True):
    """
    Fit a GradientBoostingRegressor and draw one partial dependence plot per
    simulation parameter, arranged on a grid with four columns.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable.
        clean (bool, optional): Whether to drop single-valued columns first; features
            absent afterwards are skipped.

    Returns:
        matplotlib.figure.Figure: The grid of partial dependence plots.

    Raises:
        ValueError: If no features remain after cleaning.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    if not features:
        raise ValueError("No features available for partial dependence plots")

    X = df[features]
    y = df[target]

    model = GradientBoostingRegressor()
    model.fit(X, y)

    n_cols = 4  # Number of columns in subplot grid
    n_rows = (len(features) + n_cols - 1) // n_cols  # ceil division

    # squeeze=False guarantees a 2-D array, so a single row (<= 4 features)
    # flattens to individual axes instead of being wrapped as one element.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    fig.suptitle('Partial Dependence Plots', fontsize=20, y=1.03)
    axs = axs.flatten()

    for ax, feature in zip(axs, features):
        PartialDependenceDisplay.from_estimator(model, X, features=[feature], ax=ax)
        ax.set_title(feature)

    # Hide unused grid cells
    for ax in axs[len(features):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='partial_dependences', i=1)
    return fig
|
1627
|
+
|
1628
|
+
def save_shap_plot(fig, src, variable, i):
    """
    Write `fig` as a 600-dpi PDF to '<src>/<variable>/<i>_figure.pdf'.

    The '<src>/<variable>' directory is created when missing, and the
    resulting path is printed.
    """
    out_dir = f'{src}/{variable}'
    os.makedirs(out_dir, exist_ok=True)

    filename_fig = f'{out_dir}/{str(i)}_figure.pdf'
    fig.savefig(filename_fig, format='pdf', dpi=600, bbox_inches='tight')
    print(f"Saved figure as {filename_fig}")
|
1634
|
+
|
1635
|
+
def generate_shap_summary_plot(df, target='prauc', clean=True):
    """
    Fit a RandomForestRegressor on the simulation parameters and draw a SHAP
    summary plot of their contributions to `target`, saving it under 'figures/shap'.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        clean (bool, optional): Whether to drop single-valued columns first; features
            absent afterwards are skipped.

    Returns:
        matplotlib.figure.Figure: The figure holding the SHAP summary plot.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    X = df[features]
    y = df[target]

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # show=False: shap's default calls plt.show() inside summary_plot, after which
    # plt.gcf() may refer to a new, empty figure. Capture and save first, then show.
    shap.summary_plot(shap_values, X, show=False)
    fig = plt.gcf()
    save_shap_plot(fig, src='figures', variable='shap', i=1)
    plt.show()
    return fig
|
1672
|
+
|
1673
|
+
|
1674
|
+
def remove_constant_columns(df):
    """
    Keep only the columns of `df` that take at least two distinct non-null values.

    Parameters:
        df (pd.DataFrame): Input frame; it is not modified.

    Returns:
        pd.DataFrame: `df` restricted to its varying columns (all-null and
        single-valued columns are dropped).
    """
    # nunique() ignores NaN by default, so an all-NaN column counts 0 and is dropped.
    varying = df.nunique() > 1
    return df.loc[:, varying]
|
1685
|
+
|
1686
|
+
|
1687
|
+
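# Reference sketch: check that a Beta distribution fits the classifier's predicted probabilities, which justifies using a Beta distribution in the simulated classifier.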
|
1688
|
+
|
1689
|
+
# Fit a Beta distribution to these outputs
|
1690
|
+
#a, b, loc, scale = beta.fit(predicted_probs, floc=0, fscale=1) # Fix location and scale to match the support of the sigmoid
|
1691
|
+
|
1692
|
+
# Sample from this fitted Beta distribution
|
1693
|
+
#simulated_probs = beta.rvs(a, b, size=1000)
|
1694
|
+
|
1695
|
+
# Plot the empirical vs simulated distribution
|
1696
|
+
#plt.hist(predicted_probs, bins=30, alpha=0.5, label='Empirical')
|
1697
|
+
#plt.hist(simulated_probs, bins=30, alpha=0.5, label='Simulated from Beta')
|
1698
|
+
#plt.legend()
|
1699
|
+
#plt.show()
|