spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +7 -5
- spacr/chris.py +50 -0
- spacr/core.py +1301 -426
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/gui.py +144 -0
- spacr/gui_classify_app.py +65 -74
- spacr/gui_mask_app.py +110 -87
- spacr/gui_measure_app.py +104 -81
- spacr/gui_utils.py +276 -31
- spacr/io.py +261 -102
- spacr/mask_app.py +6 -3
- spacr/measure.py +150 -64
- spacr/plot.py +151 -12
- spacr/sim.py +666 -119
- spacr/timelapse.py +139 -9
- spacr/train.py +18 -10
- spacr/utils.py +43 -49
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/METADATA +5 -2
- spacr-0.0.35.dist-info/RECORD +35 -0
- spacr-0.0.35.dist-info/entry_points.txt +8 -0
- spacr-0.0.20.dist-info/RECORD +0 -31
- spacr-0.0.20.dist-info/entry_points.txt +0 -7
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/LICENSE +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/WHEEL +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/top_level.txt +0 -0
spacr/sim.py
CHANGED
@@ -1,16 +1,19 @@
|
|
1
1
|
|
2
|
-
import os,
|
3
|
-
import time as tm
|
2
|
+
import os, random, warnings, traceback, sqlite3, shap, math, gc
|
4
3
|
from time import time, sleep
|
5
4
|
from datetime import datetime
|
6
5
|
import numpy as np
|
7
6
|
import pandas as pd
|
8
7
|
import matplotlib.pyplot as plt
|
8
|
+
from matplotlib.patches import Patch
|
9
9
|
import seaborn as sns
|
10
10
|
import sklearn.metrics as metrics
|
11
|
-
from sklearn.
|
11
|
+
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
12
|
+
from sklearn.inspection import PartialDependenceDisplay, permutation_importance
|
13
|
+
from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
|
12
14
|
import statsmodels.api as sm
|
13
|
-
from multiprocessing import cpu_count,
|
15
|
+
from multiprocessing import cpu_count, Pool, Manager
|
16
|
+
from copy import deepcopy
|
14
17
|
|
15
18
|
from .logger import log_function_call
|
16
19
|
|
@@ -35,7 +38,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
|
|
35
38
|
|
36
39
|
# plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
|
37
40
|
def generate_plate_map(nr_plates):
|
38
|
-
print('nr_plates',nr_plates)
|
41
|
+
#print('nr_plates',nr_plates)
|
39
42
|
"""
|
40
43
|
Generate a plate map based on the number of plates.
|
41
44
|
|
@@ -241,6 +244,7 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
241
244
|
- df_ls (list): A list containing gene counts per well, well counts per gene, Gini coefficients for wells,
|
242
245
|
Gini coefficients for genes, gene weights array, and well weights.
|
243
246
|
"""
|
247
|
+
|
244
248
|
#generate primary distributions and genes
|
245
249
|
cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
|
246
250
|
gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
|
@@ -251,14 +255,18 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
251
255
|
gene_weights_array = np.array(list(gene_weights.values())) # Convert the values to an array
|
252
256
|
|
253
257
|
well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)
|
254
|
-
|
255
|
-
gene_to_well_mapping = {}
|
256
|
-
|
257
|
-
for gene in genes:
|
258
|
-
gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=int(gpw[gene-1]), p=well_weights) # Generate a number of wells for each gene according to well_weights
|
259
|
-
|
260
|
-
gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}
|
261
258
|
|
259
|
+
gene_to_well_mapping = {}
|
260
|
+
for gene in range(1, number_of_genes + 1): # ensures gene-1 is within bounds
|
261
|
+
if gene-1 < len(gpw):
|
262
|
+
max_index = len(plate_map['plate_row_column']) # this should be the number of choices available from plate_map
|
263
|
+
num_samples = int(gpw[gene-1])
|
264
|
+
if num_samples >= max_index:
|
265
|
+
num_samples = max_index - 1 # adjust to maximum possible index
|
266
|
+
gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=num_samples, replace=False, p=well_weights)
|
267
|
+
else:
|
268
|
+
break # break the loop if gene-1 is out of bounds for gpw
|
269
|
+
|
262
270
|
cells = []
|
263
271
|
for i in [*range(0,len(plate_map))]:
|
264
272
|
ciw = random.choice(cpw)
|
@@ -316,9 +324,80 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
316
324
|
df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
|
317
325
|
return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
|
318
326
|
|
319
|
-
|
320
|
-
|
321
|
-
|
327
|
+
def classifier(positive_mean, positive_variance, negative_mean, negative_variance, classifier_accuracy, df):
    """
    Assign a simulated classifier score to every row of ``df``, including label noise.

    Scores are sampled from one of two Beta distributions, each given by its mean and
    variance. With probability ``classifier_accuracy`` a row is scored from the
    distribution that matches its ``is_active`` label: positive for truthy values,
    negative otherwise. With probability ``1 - classifier_accuracy`` it is scored from
    the other distribution.

    Args:
        positive_mean (float): Mean of the positive distribution, in (0, 1).
        positive_variance (float): Variance of the positive distribution, in (0, mean * (1 - mean)).
        negative_mean (float): Mean of the negative distribution, in (0, 1).
        negative_variance (float): Variance of the negative distribution, in (0, mean * (1 - mean)).
        classifier_accuracy (float): Probability in [0, 1] that a row is scored from the
            distribution matching its true label.
        df (pandas.DataFrame): Data with an ``is_active`` column. Modified in place.

    Returns:
        pandas.DataFrame: The same ``df`` object with a ``score`` column added or overwritten.

    Raises:
        ValueError: If a mean/variance pair cannot parameterise a Beta distribution,
            or if ``classifier_accuracy`` is not within [0, 1].
    """
    def calc_alpha_beta(mean, variance):
        # Method-of-moments conversion from (mean, variance) to Beta(alpha, beta).
        if mean <= 0 or mean >= 1:
            raise ValueError("Mean must be between 0 and 1 exclusively.")
        max_variance = mean * (1 - mean)
        if variance <= 0 or variance >= max_variance:
            raise ValueError(f"Variance must be positive and less than {max_variance}.")

        alpha = mean * (mean * (1 - mean) / variance - 1)
        beta = alpha * (1 - mean) / mean
        return alpha, beta

    # The negated chained comparison also rejects NaN. Previously a NaN accuracy was
    # silently treated as "always wrong".
    if not 0 <= classifier_accuracy <= 1:
        raise ValueError("classifier_accuracy must be between 0 and 1 inclusively.")

    a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
    a2, b2 = calc_alpha_beta(negative_mean, negative_variance)

    # Vectorised sampling. It has the same per-row distribution as a row-wise apply,
    # but uses far fewer Python-level RNG calls on large cell tables.
    is_active = df['is_active'].astype(bool).to_numpy()
    n_rows = len(is_active)
    correct = np.random.rand(n_rows) < classifier_accuracy
    # A row is scored from the positive distribution when it is active and correctly
    # classified, or inactive and misclassified.
    use_positive = is_active == correct

    scores = np.empty(n_rows, dtype=float)
    n_positive = int(use_positive.sum())
    scores[use_positive] = np.random.beta(a1, b1, size=n_positive)
    scores[~use_positive] = np.random.beta(a2, b2, size=n_rows - n_positive)

    df['score'] = scores
    return df
|
366
|
+
|
367
|
+
def classifier_v2(positive_mean, positive_variance, negative_mean, negative_variance, df):
    """
    Score each row of ``df`` by sampling from a Beta distribution chosen by its label.

    Rows whose ``is_active`` value is truthy draw from the positive distribution; all
    other rows draw from the negative one. Each distribution is given by its mean and
    variance, which are converted to Beta shape parameters by the method of moments.

    Args:
        positive_mean (float): Mean of the positive distribution, in (0, 1).
        positive_variance (float): Variance of the positive distribution.
        negative_mean (float): Mean of the negative distribution, in (0, 1).
        negative_variance (float): Variance of the negative distribution.
        df (pandas.DataFrame): Data with an ``is_active`` column. Modified in place.

    Returns:
        pandas.DataFrame: ``df`` itself, with a ``score`` column written in place.

    Raises:
        ValueError: If a mean is outside (0, 1), or a variance is not in (0, mean * (1 - mean)).
    """
    def _shape_params(mu, var):
        if mu <= 0 or mu >= 1:
            raise ValueError("Mean must be between 0 and 1 exclusively.")
        upper = mu * (1 - mu)
        if var <= 0 or var >= upper:
            raise ValueError(f"Variance must be positive and less than {upper}.")
        a = mu * (upper / var - 1)
        return a, a * (1 - mu) / mu

    # Positive parameters are validated first, matching the order in which errors surface.
    pos_a, pos_b = _shape_params(positive_mean, positive_variance)
    neg_a, neg_b = _shape_params(negative_mean, negative_variance)

    def _draw(active):
        # One Beta draw per row, from the distribution selected by the row's label.
        if active:
            return np.random.beta(pos_a, pos_b)
        return np.random.beta(neg_a, neg_b)

    df['score'] = df['is_active'].apply(_draw)
    return df
|
399
|
+
|
400
|
+
def classifier_v1(positive_mean, positive_variance, negative_mean, negative_variance, df):
|
322
401
|
"""
|
323
402
|
Classifies the data in the DataFrame based on the given parameters.
|
324
403
|
|
@@ -446,6 +525,7 @@ def generate_well_score(cell_scores):
|
|
446
525
|
return well_score
|
447
526
|
|
448
527
|
def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
|
528
|
+
|
449
529
|
"""
|
450
530
|
Simulates the sequencing of plates and calculates gene fractions and metadata.
|
451
531
|
|
@@ -460,6 +540,7 @@ def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_pe
|
|
460
540
|
gene_fraction_map (pd.DataFrame): DataFrame containing gene fractions for each well.
|
461
541
|
metadata (pd.DataFrame): DataFrame containing metadata for each well.
|
462
542
|
"""
|
543
|
+
|
463
544
|
reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
|
464
545
|
gene_names = [f'gene_{v}' for v in range(number_of_genes+1)]
|
465
546
|
all_wells = well_score.index
|
@@ -612,8 +693,12 @@ def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
|
|
612
693
|
Returns:
|
613
694
|
None
|
614
695
|
"""
|
615
|
-
|
616
|
-
|
696
|
+
if not binwidth:
|
697
|
+
sns.histplot(data=data, x=x_label, ax=ax, color=color, kde=False, stat='density',
|
698
|
+
legend=False, fill=True, element='step', palette='dark')
|
699
|
+
else:
|
700
|
+
sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
|
701
|
+
legend=False, fill=True, element='step', palette='dark')
|
617
702
|
if log:
|
618
703
|
ax.set_yscale('log')
|
619
704
|
ax.set_title(title)
|
@@ -700,7 +785,7 @@ def run_simulation(settings):
|
|
700
785
|
plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])] # Extract rows where 'column_id' is not in [1,2,3,23,24]
|
701
786
|
|
702
787
|
cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(plate_map, settings['number_of_genes'], active_gene_list, settings['avg_genes_per_well'], settings['sd_genes_per_well'], settings['avg_cells_per_well'], settings['sd_cells_per_well'], settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
|
703
|
-
cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], df=cell_level)
|
788
|
+
cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], settings['classifier_accuracy'], df=cell_level)
|
704
789
|
cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
|
705
790
|
well_score = generate_well_score(cell_scores)
|
706
791
|
gene_fraction_map, metadata = sequence_plates(well_score, settings['number_of_genes'], settings['avg_reads_per_gene'], settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])
|
@@ -720,7 +805,10 @@ def run_simulation(settings):
|
|
720
805
|
results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False)
|
721
806
|
#except Exception as e:
|
722
807
|
# print(f"An error occurred while saving data: {e}")
|
723
|
-
|
808
|
+
output = [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df]
|
809
|
+
del cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df
|
810
|
+
gc.collect()
|
811
|
+
return output, dists
|
724
812
|
|
725
813
|
def vis_dists(dists, src, v, i):
|
726
814
|
"""
|
@@ -744,9 +832,14 @@ def vis_dists(dists, src, v, i):
|
|
744
832
|
for index, dist in enumerate(dists):
|
745
833
|
temp = pd.DataFrame(dist, columns = [f'{names[index]}'])
|
746
834
|
sns.histplot(data=temp, x=f'{names[index]}', kde=False, binwidth=None, stat='count', element="step", ax=ax[n], color='teal', log_scale=False)
|
747
|
-
#plot_histogram(temp, f'{names[index]}', ax[n], 'slategray', f'{names[index]}', binwidth=None, log=False)
|
748
835
|
n+=1
|
749
836
|
save_plot(fig2, src, 'dists', i)
|
837
|
+
plt.close(fig2)
|
838
|
+
plt.figure().clear()
|
839
|
+
plt.cla()
|
840
|
+
plt.clf()
|
841
|
+
del dists
|
842
|
+
|
750
843
|
return
|
751
844
|
|
752
845
|
def visualize_all(output):
|
@@ -773,6 +866,7 @@ def visualize_all(output):
|
|
773
866
|
Returns:
|
774
867
|
fig (matplotlib.figure.Figure): The generated figure object.
|
775
868
|
"""
|
869
|
+
|
776
870
|
cell_scores = output[0]
|
777
871
|
cell_roc_dict_df = output[1]
|
778
872
|
cell_pr_dict_df = output[2]
|
@@ -810,24 +904,35 @@ def visualize_all(output):
|
|
810
904
|
#plot cell classification score by inactive and active
|
811
905
|
active_distribution = cell_scores[cell_scores['is_active'] == 1]
|
812
906
|
inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
|
813
|
-
plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
|
814
|
-
plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
|
907
|
+
plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', log=False)#, binwidth=0.01, log=False)
|
908
|
+
plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', log=False)#, binwidth=0.01, log=False)
|
909
|
+
|
910
|
+
legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
|
911
|
+
Patch(facecolor='teal', edgecolor='teal', label='Active')]
|
912
|
+
|
913
|
+
ax[n].legend(handles=legend_elements, loc='upper right')
|
914
|
+
|
915
|
+
|
815
916
|
ax[n].set_xlim([0, 1])
|
816
917
|
n+=1
|
817
918
|
|
818
919
|
#plot classifier cell predictions by inactive and active well average
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
920
|
+
inactive_distribution_well = inactive_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
921
|
+
active_distribution_well = active_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
922
|
+
mixed_distribution_well = cell_scores.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
923
|
+
|
924
|
+
plot_histogram(inactive_distribution_well, "score", ax[n], 'slategray', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
925
|
+
plot_histogram(active_distribution_well, "score", ax[n], 'teal', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
926
|
+
plot_histogram(mixed_distribution_well, "score", ax[n], 'red', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
824
927
|
|
825
|
-
|
826
|
-
|
928
|
+
legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
|
929
|
+
Patch(facecolor='teal', edgecolor='teal', label='Active'),
|
930
|
+
Patch(facecolor='red', edgecolor='red', label='Mixed')]
|
827
931
|
|
828
|
-
|
829
|
-
|
932
|
+
ax[n].legend(handles=legend_elements, loc='upper right')
|
933
|
+
|
830
934
|
ax[n].set_xlim([0, 1])
|
935
|
+
#ax[n].legend()
|
831
936
|
n+=1
|
832
937
|
|
833
938
|
#plot ROC (cell classification)
|
@@ -846,8 +951,8 @@ def visualize_all(output):
|
|
846
951
|
n+=1
|
847
952
|
|
848
953
|
#plot well score
|
849
|
-
plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=
|
850
|
-
ax[n].set_xlim([0, 1])
|
954
|
+
plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=True)
|
955
|
+
#ax[n].set_xlim([0, 1])
|
851
956
|
n+=1
|
852
957
|
|
853
958
|
control_df = results_df[results_df['color'] == 'control']
|
@@ -911,6 +1016,7 @@ def visualize_all(output):
|
|
911
1016
|
|
912
1017
|
plt.tight_layout()
|
913
1018
|
plt.show()
|
1019
|
+
gc.collect()
|
914
1020
|
return fig
|
915
1021
|
|
916
1022
|
def create_database(db_path):
|
@@ -992,6 +1098,7 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
|
|
992
1098
|
df_concat[f'variable_{variable}_sim_nr'] = i
|
993
1099
|
|
994
1100
|
append_database(src, df_concat, 'simulations')
|
1101
|
+
del gini_genes_per_well, gini_wells_per_gene, df_concat
|
995
1102
|
|
996
1103
|
if save_all:
|
997
1104
|
for i, df in enumerate(output):
|
@@ -1001,8 +1108,12 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
|
|
1001
1108
|
if not isinstance(df, pd.DataFrame):
|
1002
1109
|
df = pd.DataFrame(df)
|
1003
1110
|
append_database(src, df, table_names[i])
|
1111
|
+
del df
|
1004
1112
|
except Exception as e:
|
1005
1113
|
print(f"An error occurred while saving data: {e}")
|
1114
|
+
print(traceback.format_exc())
|
1115
|
+
|
1116
|
+
del output, settings_df
|
1006
1117
|
return
|
1007
1118
|
|
1008
1119
|
def save_plot(fig, src, variable, i):
|
@@ -1025,7 +1136,6 @@ def save_plot(fig, src, variable, i):
|
|
1025
1136
|
|
1026
1137
|
def run_and_save(i, settings, time_ls, total_sims):
|
1027
1138
|
|
1028
|
-
|
1029
1139
|
"""
|
1030
1140
|
Run the simulation and save the results.
|
1031
1141
|
|
@@ -1040,24 +1150,24 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1040
1150
|
"""
|
1041
1151
|
#print(f'Runnings simulation with the following paramiters')
|
1042
1152
|
#print(settings)
|
1043
|
-
|
1153
|
+
settings['random_seed'] = False
|
1044
1154
|
if settings['random_seed']:
|
1045
1155
|
random.seed(42) # sims will be too similar with random seed
|
1046
1156
|
src = settings['src']
|
1047
1157
|
plot = settings['plot']
|
1048
1158
|
v = settings['variable']
|
1049
1159
|
start_time = time() # Start time of the simulation
|
1050
|
-
now = datetime.now() # get current date
|
1051
|
-
date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1160
|
+
#now = datetime.now() # get current date
|
1161
|
+
#date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1162
|
+
date_string = settings['start_time']
|
1052
1163
|
#try:
|
1053
1164
|
output, dists = run_simulation(settings)
|
1054
1165
|
sim_time = time() - start_time # Elapsed time for the simulation
|
1055
1166
|
settings['sim_time'] = sim_time
|
1056
1167
|
src = os.path.join(f'{src}/{date_string}',settings['name'])
|
1057
1168
|
save_data(src, output, settings, save_all=False, i=i, variable=v)
|
1058
|
-
if vis_dists:
|
1059
|
-
vis_dists(dists,src, v, i)
|
1060
1169
|
if plot:
|
1170
|
+
vis_dists(dists,src, v, i)
|
1061
1171
|
fig = visualize_all(output)
|
1062
1172
|
save_plot(fig, src, v, i)
|
1063
1173
|
plt.close(fig)
|
@@ -1066,6 +1176,7 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1066
1176
|
plt.clf()
|
1067
1177
|
del fig
|
1068
1178
|
del output, dists
|
1179
|
+
gc.collect()
|
1069
1180
|
#except Exception as e:
|
1070
1181
|
# print(e, end='\r', flush=True)
|
1071
1182
|
# sim_time = time() - start_time
|
@@ -1073,98 +1184,92 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1073
1184
|
time_ls.append(sim_time)
|
1074
1185
|
return i, sim_time, None
|
1075
1186
|
|
1076
|
-
def
|
1187
|
+
def validate_and_adjust_beta_params(sim_params):
    """
    Cap the positive/negative Beta variances of every parameter set at a feasible value.

    A Beta distribution with mean ``m`` needs a variance strictly below ``m * (1 - m)``.
    Any ``positive_variance`` or ``negative_variance`` at or above that bound is replaced
    by 99% of the bound, and a message describing the change is printed.

    The dictionaries are updated in place; the returned list holds the same dict objects.

    Args:
        sim_params (list of dict): Parameter sets, each with ``positive_mean``,
            ``positive_variance``, ``negative_mean`` and ``negative_variance`` keys.

    Returns:
        list of dict: A new list containing the (possibly adjusted) parameter sets.
    """
    result = []
    for entry in sim_params:
        # Both bounds are computed up front; the means themselves are never altered.
        bounds = {
            'positive': entry['positive_mean'] * (1 - entry['positive_mean']),
            'negative': entry['negative_mean'] * (1 - entry['negative_mean']),
        }
        for side, bound in bounds.items():
            key = f'{side}_variance'
            if entry[key] >= bound:
                capped = bound * 0.99
                print(f'changed {side} variance from {entry[key]} to {capped}')
                entry[key] = capped
        result.append(entry)
    return result
|
1215
|
+
|
1216
|
+
def generate_paramiters(settings):
    """
    Generate a list of parameter sets for simulation based on the given settings.

    One deep-copied settings dict is produced for every combination of the swept
    values and every replicate. Derived parameters are set as follows:
    - each ``sd_*`` is half of its corresponding average;
    - ``negative_mean`` is ``1 - positive_mean``;
    - both variances are ``(1 - positive_mean) / 2``.

    The list is shuffled with ``random.shuffle``, passed through
    ``validate_and_adjust_beta_params``, and then returned.

    Args:
        settings (dict): Simulation settings. It is not modified. The following keys
            must be iterables of values to sweep:
            ``avg_genes_per_well``, ``avg_cells_per_well``, ``classifier_accuracy``,
            ``avg_reads_per_gene``, ``sequencing_error``, ``well_ineq_coeff``,
            ``gene_ineq_coeff``, ``nr_plates``, ``number_of_genes`` and
            ``number_of_active_genes``. ``replicates`` is the number of copies
            generated per combination.

    Returns:
        list: A list of parameter sets (dicts) for simulation.
    """
    import itertools

    # NOTE(review): positive_mean is pinned to 0.8 regardless of the caller's value,
    # as in the previous implementation. It is now kept local instead of being written
    # back into the caller's ``settings``. Confirm whether the caller's sweep should be honoured.
    positive_means = [0.8]
    replicates = settings['replicates']

    # itertools.product iterates in the same order as the equivalent nested for-loops.
    sweep = itertools.product(
        settings['avg_genes_per_well'],
        settings['avg_cells_per_well'],
        settings['classifier_accuracy'],
        positive_means,
        settings['avg_reads_per_gene'],
        settings['sequencing_error'],
        settings['well_ineq_coeff'],
        settings['gene_ineq_coeff'],
        settings['nr_plates'],
        settings['number_of_genes'],
        settings['number_of_active_genes'],
        range(1, replicates + 1),
    )

    sim_ls = []
    for (avg_genes_per_well, avg_cells_per_well, classifier_accuracy, positive_mean,
         avg_reads_per_gene, sequencing_error, well_ineq_coeff, gene_ineq_coeff,
         nr_plates, number_of_genes, number_of_active_genes, _replicate) in sweep:
        sett = deepcopy(settings)
        sett['avg_genes_per_well'] = avg_genes_per_well
        sett['sd_genes_per_well'] = avg_genes_per_well / 2
        sett['avg_cells_per_well'] = avg_cells_per_well
        sett['sd_cells_per_well'] = avg_cells_per_well / 2
        sett['classifier_accuracy'] = classifier_accuracy
        sett['positive_mean'] = positive_mean
        sett['negative_mean'] = 1 - positive_mean
        sett['positive_variance'] = (1 - positive_mean) / 2
        sett['negative_variance'] = (1 - positive_mean) / 2
        sett['avg_reads_per_gene'] = avg_reads_per_gene
        sett['sd_reads_per_gene'] = avg_reads_per_gene / 2
        sett['sequencing_error'] = sequencing_error
        sett['well_ineq_coeff'] = well_ineq_coeff
        sett['gene_ineq_coeff'] = gene_ineq_coeff
        sett['nr_plates'] = nr_plates
        sett['number_of_genes'] = number_of_genes
        sett['number_of_active_genes'] = number_of_active_genes
        sim_ls.append(sett)

    random.shuffle(sim_ls)
    sim_ls = validate_and_adjust_beta_params(sim_ls)
    print(f'Running {len(sim_ls)} simulations.')
    return sim_ls
|
1166
1270
|
|
1167
1271
|
def run_multiple_simulations(settings):
|
1272
|
+
|
1168
1273
|
"""
|
1169
1274
|
Run multiple simulations in parallel using the provided settings.
|
1170
1275
|
|
@@ -1175,8 +1280,12 @@ def run_multiple_simulations(settings):
|
|
1175
1280
|
None
|
1176
1281
|
"""
|
1177
1282
|
|
1283
|
+
now = datetime.now() # get current date
|
1284
|
+
start_time = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1285
|
+
settings['start_time'] = start_time
|
1286
|
+
|
1178
1287
|
sim_ls = generate_paramiters(settings)
|
1179
|
-
print(f'Running {len(sim_ls)} simulations.
|
1288
|
+
#print(f'Running {len(sim_ls)} simulations.')
|
1180
1289
|
|
1181
1290
|
max_workers = settings['max_workers'] or cpu_count() - 4
|
1182
1291
|
with Manager() as manager:
|
@@ -1185,10 +1294,448 @@ def run_multiple_simulations(settings):
|
|
1185
1294
|
with Pool(max_workers) as pool:
|
1186
1295
|
result = pool.starmap_async(run_and_save, [(index, settings, time_ls, total_sims) for index, settings in enumerate(sim_ls)])
|
1187
1296
|
while not result.ready():
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1297
|
+
try:
|
1298
|
+
sleep(0.01)
|
1299
|
+
sims_processed = len(time_ls)
|
1300
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
1301
|
+
time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
|
1302
|
+
print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
|
1303
|
+
gc.collect()
|
1304
|
+
except Exception as e:
|
1305
|
+
print(e)
|
1306
|
+
print(traceback.format_exc())
|
1307
|
+
try:
|
1308
|
+
result.get()
|
1309
|
+
except Exception as e:
|
1310
|
+
print(e)
|
1311
|
+
print(traceback.format_exc())
|
1194
1312
|
|
1313
|
+
def generate_integers(start, stop, step):
    """
    Return the integers from ``start`` towards ``stop`` (inclusive) in increments of ``step``.

    ``stop`` is included when it lies on the grid ``start + k * step``. Both ascending
    (positive ``step``) and descending (negative ``step``) ranges are supported.

    Args:
        start (int): First value.
        stop (int): Last value, inclusive.
        step (int): Non-zero increment.

    Returns:
        list of int: The generated integers (empty if ``stop`` lies behind ``start``).

    Raises:
        ValueError: If ``step`` is 0 (raised by ``range``).
    """
    # ``range`` excludes its end. Push the end one unit past ``stop`` in the direction
    # of travel; using ``stop + 1`` for a negative step would drop ``stop`` and more.
    end = stop + 1 if step > 0 else stop - 1
    return list(range(start, end, step))
|
1315
|
+
|
1316
|
+
def generate_floats(start, stop, step):
    """
    Return evenly spaced values from ``start`` up to ``stop`` (inclusive) in steps of ``step``.

    Each value is computed as ``start + i * step`` rather than by repeated addition,
    which accumulates floating-point error and can drop ``stop`` itself. It is then
    rounded to the number of decimal places used by ``start`` or ``step``, whichever
    is larger, to strip representation noise such as ``0.30000000000000004``.

    Args:
        start (float): First value.
        stop (float): Upper bound; included when it lies on the grid.
        step (float): Positive spacing.

    Returns:
        list: The generated values (empty when ``start > stop``).

    Raises:
        ValueError: If any argument is not finite, or ``step`` is not positive
            (a non-positive step would otherwise loop forever).
    """
    import math
    from decimal import Decimal

    if not (math.isfinite(start) and math.isfinite(stop) and math.isfinite(step)) or step <= 0:
        raise ValueError("start, stop and step must be finite and step must be positive.")

    def _decimal_places(value):
        # Decimal(str(...)) gives the shortest repr; e.g. 0.1 -> exponent -1, 1 -> 0, 1e-05 -> -5.
        exponent = Decimal(str(value)).as_tuple().exponent
        return max(0, -exponent)

    num_decimals = max(_decimal_places(start), _decimal_places(step))

    floats_list = []
    index = 0
    while True:
        value = round(start + index * step, num_decimals)
        if value > stop:
            break
        floats_list.append(value)
        index += 1
    return floats_list
|
1328
|
+
|
1329
|
+
def remove_columns_with_single_value_v1(df):
    """
    Drop every column whose values are all identical, using a per-column ``unique`` scan.

    ``Series.unique`` keeps NaN as a distinct value. An all-NaN column therefore counts
    as single-valued and is dropped, while a column mixing one value with NaN is kept.
    The input frame itself is not modified.

    Args:
        df (pandas.DataFrame): The original DataFrame.

    Returns:
        pandas.DataFrame: A copy of ``df`` without its single-valued columns.
    """
    result = df.copy()
    single_valued = [name for name in result.columns if len(result[name].unique()) == 1]
    result.drop(single_valued, axis=1, inplace=True)
    return result
|
1347
|
+
|
1348
|
+
def remove_columns_with_single_value(df):
    """
    Remove the columns of a DataFrame that hold the same value in every row.

    Missing values count as a value: an all-NaN column is removed, while a column
    mixing a single value with NaN is kept. This matches
    ``remove_columns_with_single_value_v1``. Duplicate column names are handled,
    because the decision is made positionally.

    Args:
        df (pandas.DataFrame): The original DataFrame. It is not modified.

    Returns:
        pandas.DataFrame: A new DataFrame without the single-valued columns.
    """
    # dropna=False makes NaN count as a distinct value. With the default dropna=True,
    # [1, NaN] would look constant and an all-NaN column would look empty (0 values).
    distinct_counts = df.nunique(dropna=False)
    # Use a positional boolean mask so duplicate column labels cannot mis-align.
    keep = (distinct_counts != 1).to_numpy()
    return df.loc[:, keep]
|
1360
|
+
|
1361
|
+
def read_simulations_table(db_path):
    """
    Load the ``simulations`` table of an SQLite database into a pandas DataFrame.

    The connection is always closed. If the query fails (for example, because the
    table does not exist), the error is printed and ``None`` is returned.

    Args:
        db_path (str): The file path to the SQLite database.

    Returns:
        pandas.DataFrame or None: The table contents, or ``None`` on failure.
    """
    connection = sqlite3.connect(db_path)
    try:
        table = pd.read_sql_query("SELECT * FROM simulations", connection)
    except Exception as err:
        print(f"An error occurred: {err}")
        table = None
    finally:
        connection.close()
    return table
|
1385
|
+
|
1386
|
+
def plot_simulations(df, variable, x_rotation=None, legend=False, grid=False, clean=True, verbose=False):
    """
    Plot mean PRAUC against ``variable`` for each unique combination of the
    remaining simulation parameters, one subplot per combination on a grid.

    Each subplot shows the mean 'prauc' per value of ``variable`` as a line,
    with a shaded band of +/- one standard deviation.

    Args:
        df (pandas.DataFrame): Simulation results. Must contain ``variable``,
            'prauc' and the grouping columns listed below.
        variable (str): Name of the column to use as the x-axis.
        x_rotation (int, optional): Degrees to rotate the x-axis labels
            (defaults to 45 when None).
        legend (bool, optional): Whether to display a legend.
        grid (bool, optional): Whether to display grid lines.
        clean (bool, optional): If True, drop single-valued columns first and
            ignore grouping variables that are no longer present.
        verbose (bool, optional): Whether to annotate each subplot with the
            filter conditions that define it.

    Returns:
        matplotlib.figure.Figure: The figure containing all subplots.

    Raises:
        ValueError: If required columns are missing or ``df`` has no rows.
    """
    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        # Constant grouping columns were just dropped; they carry no grouping
        # information, so simply stop grouping by them.
        grouping_vars = [col for col in grouping_vars if col in df.columns]

    grouping_vars = [col for col in grouping_vars if col != variable]

    # Check that the necessary columns are present in the DataFrame
    required_columns = {variable, 'prauc'} | set(grouping_vars)
    missing_cols = required_columns - set(df.columns)
    if missing_cols:
        raise ValueError(f"DataFrame must contain {missing_cols} columns")
    if df.empty:
        raise ValueError("DataFrame is empty; nothing to plot")

    # Unique combinations of the grouping conditions; with no grouping
    # variables left, the whole DataFrame forms a single group.
    if grouping_vars:
        unique_combinations = df[grouping_vars].drop_duplicates()
    else:
        unique_combinations = pd.DataFrame(index=[0])
    num_combinations = len(unique_combinations)

    # Near-square grid with enough cells for every combination
    num_rows = math.ceil(np.sqrt(num_combinations))
    num_cols = math.ceil(num_combinations / num_rows)

    # squeeze=False always yields a 2-D array, so ravel() handles every grid shape
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.ravel()

    for ax, (_, row) in zip(axes, unique_combinations.iterrows()):
        # Select rows that match the current combination of grouping variables
        condition = {var: row[var] for var in grouping_vars}
        mask = np.ones(len(df), dtype=bool)
        for var, value in condition.items():
            mask &= (df[var] == value).to_numpy()
        subset_df = df[mask]

        # Mean and std dev of 'prauc' per value of `variable`, in x order
        grouped = subset_df.groupby(variable)['prauc'].agg(['mean', 'std']).sort_index()

        ax.plot(grouped.index, grouped['mean'], marker='o', linestyle='-', color='b', label='Mean PRAUC')
        ax.fill_between(grouped.index, grouped['mean'] - grouped['std'], grouped['mean'] + grouped['std'],
                        color='gray', alpha=0.5, label='Std Dev')

        ax.set_xlabel(variable)
        ax.set_ylabel('Precision-Recall AUC (PRAUC)')
        ax.grid(grid)

        if legend:
            ax.legend()

        ax.set_xticks(grouped.index)
        ax.set_xticklabels(grouped.index, rotation=x_rotation if x_rotation is not None else 45)

        if verbose:
            verbose_text = '\n'.join([f"{var}: {val}" for var, val in condition.items()])
            ax.text(0.95, 0.05, verbose_text, transform=ax.transAxes, fontsize=9,
                    verticalalignment='bottom', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    # Hide grid cells that received no combination
    for ax in axes[num_combinations:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
|
1474
|
+
|
1475
|
+
def plot_correlation_matrix(df, annot=False, cmap=None, clean=True):
    """
    Plot the lower triangle of the correlation matrix between the simulation
    parameters and the resulting performance metrics.

    Args:
        df (pandas.DataFrame): The DataFrame containing the simulation results.
        annot (bool, optional): Whether to write the correlation value in each cell.
        cmap (str or matplotlib colormap, optional): Colormap for the heatmap.
            When None (default), a seaborn diverging palette (240 -> 10) is used.
        clean (bool, optional): If True, drop constant/single-valued columns
            before computing correlations.

    Returns:
        matplotlib.figure.Figure: The figure containing the heatmap. It is also
        saved via ``save_plot`` under ``figures/correlation_matrix``.
    """
    # Previously the cmap argument was always overwritten; honour it when given.
    if cmap is None:
        cmap = sns.diverging_palette(240, 10, as_cmap=True)

    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    # Outcome metrics are correlated alongside the input parameters
    grouping_vars = grouping_vars + ['optimal_threshold', 'accuracy', 'prauc', 'roc_auc',
                                     'genes_per_well_gini', 'wells_per_gene_gini']

    if clean:
        df = remove_constant_columns(df)
        grouping_vars = [feature for feature in grouping_vars if feature in df.columns]

    # Subset the DataFrame to the relevant variables only
    relevant_data = df[grouping_vars]

    if clean:
        relevant_data = remove_columns_with_single_value(relevant_data)

    corr_matrix = relevant_data.corr()
    # Mask the upper triangle (incl. diagonal): the matrix is symmetric
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(corr_matrix, mask=mask, annot=annot, cmap=cmap, fmt=".2f", linewidths=.5, robust=True)
    plt.tight_layout()

    # Save before showing so the figure is written out even if the backend
    # closes or clears it when the window is dismissed.
    save_plot(fig, src='figures', variable='correlation_matrix', i=1)
    plt.show()
    return fig
|
1518
|
+
|
1519
|
+
def plot_feature_importance(df, target='prauc', exclude=None, clean=True):
    """
    Train a RandomForestRegressor on the simulation parameters and plot the
    impurity-based importance of each parameter for predicting ``target``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable column name.
        exclude (str or list, optional): Feature name(s) to exclude from the model.
        clean (bool, optional): If True, drop single-valued columns first and
            keep only features that remain.

    Returns:
        matplotlib.figure.Figure: The figure containing the feature importance plot.
        It is also saved via ``save_plot`` under ``figures/feature_importance``.

    Raises:
        ValueError: If no features remain after cleaning/exclusion.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Remove excluded features if specified (single name or a collection of names)
    if isinstance(exclude, (list, tuple, set)):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    if not features:
        raise ValueError("No features left to train the model on")

    model = RandomForestRegressor(n_estimators=1000, random_state=42)
    model.fit(df[features], df[target])

    importances = model.feature_importances_
    indices = np.argsort(importances)[::-1]  # feature indices, most important first

    fig = plt.figure(figsize=(12, 6))
    plt.barh(range(len(indices)), importances[indices], color="teal", align="center", alpha=0.6)
    # Labels must follow the same order as the bars; the previous reversed
    # order attached each bar to the wrong feature name.
    plt.yticks(range(len(indices)), [features[i] for i in indices])
    plt.gca().invert_yaxis()  # bar 0 (highest importance) at the top
    plt.xlabel('Feature Importance')
    plt.title('Feature Importances')
    plt.tight_layout()

    # Save before showing so the figure is written out regardless of backend
    save_plot(fig, src='figures', variable='feature_importance', i=1)
    plt.show()
    return fig
|
1566
|
+
|
1567
|
+
def calculate_permutation_importance(df, target='prauc', exclude=None, n_repeats=10, clean=True):
    """
    Fit a RandomForestRegressor on the simulation parameters and plot the
    permutation importance of each parameter for predicting ``target``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        exclude (str or list, optional): Feature name(s) to leave out. Names
            that are not among the features are ignored.
        n_repeats (int, optional): Number of times each feature is permuted.
        clean (bool, optional): If True, drop single-valued columns first and
            keep only features that remain.

    Returns:
        matplotlib.figure.Figure: Horizontal bar chart of mean permutation
        importances. It is also saved via ``save_plot`` under
        ``figures/permutation_importance``.

    Raises:
        ValueError: If no features remain after cleaning/exclusion.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Filter rather than list.remove(): a list `exclude` previously also hit
    # features.remove(<list>) and names absent from `features` raised.
    if isinstance(exclude, (list, tuple, set)):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    if not features:
        raise ValueError("No features left to compute permutation importance")

    X = df[features]
    y = df[target]

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    perm_importance = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)

    # Ascending order: with barh, the most important feature ends up on top
    sorted_idx = perm_importance.importances_mean.argsort()

    fig, ax = plt.subplots()
    ax.barh(range(len(sorted_idx)), perm_importance.importances_mean[sorted_idx], color="teal", align="center", alpha=0.6)
    ax.set_yticks(range(len(sorted_idx)))
    # Indices refer to positions in `features` (the columns of X), not df.columns
    ax.set_yticklabels([features[i] for i in sorted_idx])
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()

    # Save before showing so the figure is written out regardless of backend
    save_plot(fig, src='figures', variable='permutation_importance', i=1)
    plt.show()
    return fig
|
1616
|
+
|
1617
|
+
def plot_partial_dependences(df, target='prauc', clean=True):
    """
    Fit a GradientBoostingRegressor on the simulation parameters and draw a
    partial dependence plot of ``target`` for each parameter, four per row.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable.
        clean (bool, optional): If True, drop single-valued columns first and
            keep only features that remain.

    Returns:
        matplotlib.figure.Figure: The figure containing all partial dependence
        plots. It is also saved via ``save_plot`` under ``figures/partial_dependences``.

    Raises:
        ValueError: If no features remain after cleaning.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    if not features:
        raise ValueError("No features left to compute partial dependences")

    X = df[features]
    y = df[target]

    model = GradientBoostingRegressor()
    model.fit(X, y)

    n_cols = 4  # Number of columns in subplot grid
    n_rows = (len(features) + n_cols - 1) // n_cols  # ceil(len(features) / n_cols)

    # squeeze=False always returns a 2-D array; ravel() then gives one Axes per
    # cell. The old `[axs]` wrapping broke the single-row case.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axs = axs.ravel()
    fig.suptitle('Partial Dependence Plots', fontsize=20, y=1.03)

    for ax, feature in zip(axs, features):
        PartialDependenceDisplay.from_estimator(model, X, features=[feature], ax=ax)
        ax.set_title(feature)

    # Hide grid cells without a feature
    for ax in axs[len(features):]:
        ax.set_visible(False)

    plt.tight_layout()

    # Save before showing so the figure is written out regardless of backend
    save_plot(fig, src='figures', variable='partial_dependences', i=1)
    plt.show()
    return fig
|
1669
|
+
|
1670
|
+
def save_shap_plot(fig, src, variable, i):
    """
    Write ``fig`` to ``<src>/<variable>/<i>_figure.pdf`` as a 600-dpi PDF.

    The ``<src>/<variable>`` directory is created when missing, and the
    resulting file path is printed afterwards.
    """
    target_dir = f'{src}/{variable}'
    os.makedirs(target_dir, exist_ok=True)

    pdf_path = f'{target_dir}/{str(i)}_figure.pdf'
    fig.savefig(pdf_path, dpi=600, format='pdf', bbox_inches='tight')
    print(f"Saved figure as {pdf_path}")
|
1676
|
+
|
1677
|
+
def generate_shap_summary_plot(df, target='prauc', clean=True):
    """
    Fit a RandomForestRegressor on the simulation parameters and draw a SHAP
    summary plot of each parameter's contribution to ``target``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        clean (bool, optional): If True, drop single-valued columns first and
            keep only features that remain.

    Returns:
        matplotlib.figure.Figure: The figure holding the SHAP summary plot. It is
        also saved via ``save_shap_plot`` under ``figures/shap``.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    X = df[features]
    y = df[target]

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # show=False keeps the summary figure current. With the default show=True,
    # plt.show() runs first and plt.gcf() may return a new, empty figure,
    # which would then be saved and returned instead of the plot.
    shap.summary_plot(shap_values, X, show=False)
    fig = plt.gcf()
    save_shap_plot(fig, src='figures', variable='shap', i=1)
    plt.show()
    return fig
|
1714
|
+
|
1715
|
+
|
1716
|
+
def remove_constant_columns(df):
    """
    Drop every column of ``df`` that holds at most one distinct value.

    Distinct values are counted with ``DataFrame.nunique`` using its default
    (NaN not counted), so a column containing one value plus NaNs also counts as
    constant.

    Parameters:
        df (pd.DataFrame): The DataFrame to filter.

    Returns:
        pd.DataFrame: ``df`` restricted to columns with more than one distinct value.
    """
    distinct_per_column = df.nunique()
    keep = distinct_per_column.gt(1)
    return df.loc[:, keep]
|
1727
|
+
|
1728
|
+
|
1729
|
+
# to justify using beta for sim classifier
|
1730
|
+
|
1731
|
+
# Fit a Beta distribution to these outputs
|
1732
|
+
#a, b, loc, scale = beta.fit(predicted_probs, floc=0, fscale=1) # Fix location and scale to match the support of the sigmoid
|
1733
|
+
|
1734
|
+
# Sample from this fitted Beta distribution
|
1735
|
+
#simulated_probs = beta.rvs(a, b, size=1000)
|
1736
|
+
|
1737
|
+
# Plot the empirical vs simulated distribution
|
1738
|
+
#plt.hist(predicted_probs, bins=30, alpha=0.5, label='Empirical')
|
1739
|
+
#plt.hist(simulated_probs, bins=30, alpha=0.5, label='Simulated from Beta')
|
1740
|
+
#plt.legend()
|
1741
|
+
#plt.show()
|