spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +6 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +807 -0
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/cli.py +25 -187
- spacr/core.py +1611 -389
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +320 -0
- spacr/graph_learning_lap.py +84 -0
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +187 -0
- spacr/gui_mask_app.py +149 -174
- spacr/gui_measure_app.py +116 -109
- spacr/gui_sim_app.py +0 -0
- spacr/gui_utils.py +679 -139
- spacr/io.py +620 -469
- spacr/mask_app.py +116 -9
- spacr/measure.py +178 -84
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +255 -1
- spacr/plot.py +263 -100
- spacr/sequencing.py +1130 -0
- spacr/sim.py +634 -122
- spacr/timelapse.py +343 -53
- spacr/train.py +195 -22
- spacr/umap.py +0 -689
- spacr/utils.py +1530 -188
- spacr-0.0.6.dist-info/METADATA +118 -0
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.1.dist-info/METADATA +0 -64
- spacr-0.0.1.dist-info/RECORD +0 -26
- spacr-0.0.1.dist-info/entry_points.txt +0 -5
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/sim.py
CHANGED
@@ -1,16 +1,19 @@
|
|
1
1
|
|
2
|
-
import os,
|
3
|
-
import time as tm
|
2
|
+
import os, random, warnings, traceback, sqlite3, shap, math, gc
|
4
3
|
from time import time, sleep
|
5
4
|
from datetime import datetime
|
6
5
|
import numpy as np
|
7
6
|
import pandas as pd
|
8
7
|
import matplotlib.pyplot as plt
|
8
|
+
from matplotlib.patches import Patch
|
9
9
|
import seaborn as sns
|
10
10
|
import sklearn.metrics as metrics
|
11
|
-
from sklearn.
|
11
|
+
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
12
|
+
from sklearn.inspection import PartialDependenceDisplay, permutation_importance
|
13
|
+
from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
|
12
14
|
import statsmodels.api as sm
|
13
|
-
from multiprocessing import cpu_count,
|
15
|
+
from multiprocessing import cpu_count, Pool, Manager
|
16
|
+
from copy import deepcopy
|
14
17
|
|
15
18
|
from .logger import log_function_call
|
16
19
|
|
@@ -35,6 +38,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
|
|
35
38
|
|
36
39
|
# plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
|
37
40
|
def generate_plate_map(nr_plates):
|
41
|
+
#print('nr_plates',nr_plates)
|
38
42
|
"""
|
39
43
|
Generate a plate map based on the number of plates.
|
40
44
|
|
@@ -240,6 +244,7 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
240
244
|
- df_ls (list): A list containing gene counts per well, well counts per gene, Gini coefficients for wells,
|
241
245
|
Gini coefficients for genes, gene weights array, and well weights.
|
242
246
|
"""
|
247
|
+
|
243
248
|
#generate primary distributions and genes
|
244
249
|
cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
|
245
250
|
gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
|
@@ -250,14 +255,18 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
250
255
|
gene_weights_array = np.array(list(gene_weights.values())) # Convert the values to an array
|
251
256
|
|
252
257
|
well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)
|
253
|
-
|
254
|
-
gene_to_well_mapping = {}
|
255
|
-
|
256
|
-
for gene in genes:
|
257
|
-
gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=int(gpw[gene-1]), p=well_weights) # Generate a number of wells for each gene according to well_weights
|
258
|
-
|
259
|
-
gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}
|
260
258
|
|
259
|
+
gene_to_well_mapping = {}
|
260
|
+
for gene in range(1, number_of_genes + 1): # ensures gene-1 is within bounds
|
261
|
+
if gene-1 < len(gpw):
|
262
|
+
max_index = len(plate_map['plate_row_column']) # this should be the number of choices available from plate_map
|
263
|
+
num_samples = int(gpw[gene-1])
|
264
|
+
if num_samples >= max_index:
|
265
|
+
num_samples = max_index - 1 # adjust to maximum possible index
|
266
|
+
gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=num_samples, replace=False, p=well_weights)
|
267
|
+
else:
|
268
|
+
break # break the loop if gene-1 is out of bounds for gpw
|
269
|
+
|
261
270
|
cells = []
|
262
271
|
for i in [*range(0,len(plate_map))]:
|
263
272
|
ciw = random.choice(cpw)
|
@@ -315,9 +324,47 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
|
|
315
324
|
df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
|
316
325
|
return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
|
317
326
|
|
318
|
-
|
319
|
-
|
320
|
-
|
327
|
+
def classifier(positive_mean, positive_variance, negative_mean, negative_variance, classifier_accuracy, df):
|
328
|
+
"""
|
329
|
+
Classifies the data in the DataFrame based on the given parameters and a classifier error rate.
|
330
|
+
|
331
|
+
Args:
|
332
|
+
positive_mean (float): The mean of the positive distribution.
|
333
|
+
positive_variance (float): The variance of the positive distribution.
|
334
|
+
negative_mean (float): The mean of the negative distribution.
|
335
|
+
negative_variance (float): The variance of the negative distribution.
|
336
|
+
classifier_accuracy (float): The likelihood (0 to 1) that a gene is correctly classified according to its true label.
|
337
|
+
df (pandas.DataFrame): The DataFrame containing the data to be classified.
|
338
|
+
|
339
|
+
Returns:
|
340
|
+
pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
|
341
|
+
"""
|
342
|
+
def calc_alpha_beta(mean, variance):
|
343
|
+
if mean <= 0 or mean >= 1:
|
344
|
+
raise ValueError("Mean must be between 0 and 1 exclusively.")
|
345
|
+
max_variance = mean * (1 - mean)
|
346
|
+
if variance <= 0 or variance >= max_variance:
|
347
|
+
raise ValueError(f"Variance must be positive and less than {max_variance}.")
|
348
|
+
|
349
|
+
alpha = mean * (mean * (1 - mean) / variance - 1)
|
350
|
+
beta = alpha * (1 - mean) / mean
|
351
|
+
return alpha, beta
|
352
|
+
|
353
|
+
# Apply the beta distribution based on 'is_active' status with consideration for classifier error
|
354
|
+
def get_score(is_active):
|
355
|
+
if np.random.rand() < classifier_accuracy: # With classifier_accuracy probability, choose the correct distribution
|
356
|
+
return np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2)
|
357
|
+
else: # With 1-classifier_accuracy probability, choose the incorrect distribution
|
358
|
+
return np.random.beta(a2, b2) if is_active else np.random.beta(a1, b1)
|
359
|
+
|
360
|
+
# Calculate alpha and beta for both distributions
|
361
|
+
a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
|
362
|
+
a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
|
363
|
+
df['score'] = df['is_active'].apply(get_score)
|
364
|
+
|
365
|
+
return df
|
366
|
+
|
367
|
+
def classifier_v2(positive_mean, positive_variance, negative_mean, negative_variance, df):
|
321
368
|
"""
|
322
369
|
Classifies the data in the DataFrame based on the given parameters.
|
323
370
|
|
@@ -331,12 +378,22 @@ def classifier(positive_mean, positive_variance, negative_mean, negative_varianc
|
|
331
378
|
Returns:
|
332
379
|
pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
|
333
380
|
"""
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
381
|
+
def calc_alpha_beta(mean, variance):
|
382
|
+
if mean <= 0 or mean >= 1:
|
383
|
+
raise ValueError("Mean must be between 0 and 1 exclusively.")
|
384
|
+
max_variance = mean * (1 - mean)
|
385
|
+
if variance <= 0 or variance >= max_variance:
|
386
|
+
raise ValueError(f"Variance must be positive and less than {max_variance}.")
|
387
|
+
|
388
|
+
alpha = mean * (mean * (1 - mean) / variance - 1)
|
389
|
+
beta = alpha * (1 - mean) / mean
|
390
|
+
return alpha, beta
|
391
|
+
|
392
|
+
# Calculate alpha and beta for both distributions
|
393
|
+
a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
|
394
|
+
a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
|
395
|
+
|
396
|
+
# Apply the beta distribution based on 'is_active' status
|
340
397
|
df['score'] = df['is_active'].apply(lambda is_active: np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2))
|
341
398
|
return df
|
342
399
|
|
@@ -445,6 +502,7 @@ def generate_well_score(cell_scores):
|
|
445
502
|
return well_score
|
446
503
|
|
447
504
|
def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
|
505
|
+
|
448
506
|
"""
|
449
507
|
Simulates the sequencing of plates and calculates gene fractions and metadata.
|
450
508
|
|
@@ -459,6 +517,7 @@ def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_pe
|
|
459
517
|
gene_fraction_map (pd.DataFrame): DataFrame containing gene fractions for each well.
|
460
518
|
metadata (pd.DataFrame): DataFrame containing metadata for each well.
|
461
519
|
"""
|
520
|
+
|
462
521
|
reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
|
463
522
|
gene_names = [f'gene_{v}' for v in range(number_of_genes+1)]
|
464
523
|
all_wells = well_score.index
|
@@ -611,8 +670,12 @@ def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
|
|
611
670
|
Returns:
|
612
671
|
None
|
613
672
|
"""
|
614
|
-
|
615
|
-
|
673
|
+
if not binwidth:
|
674
|
+
sns.histplot(data=data, x=x_label, ax=ax, color=color, kde=False, stat='density',
|
675
|
+
legend=False, fill=True, element='step', palette='dark')
|
676
|
+
else:
|
677
|
+
sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
|
678
|
+
legend=False, fill=True, element='step', palette='dark')
|
616
679
|
if log:
|
617
680
|
ax.set_yscale('log')
|
618
681
|
ax.set_title(title)
|
@@ -699,7 +762,7 @@ def run_simulation(settings):
|
|
699
762
|
plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])] # Extract rows where 'column_id' is not in [1,2,3,23,24]
|
700
763
|
|
701
764
|
cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(plate_map, settings['number_of_genes'], active_gene_list, settings['avg_genes_per_well'], settings['sd_genes_per_well'], settings['avg_cells_per_well'], settings['sd_cells_per_well'], settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
|
702
|
-
cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], df=cell_level)
|
765
|
+
cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], settings['classifier_accuracy'], df=cell_level)
|
703
766
|
cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
|
704
767
|
well_score = generate_well_score(cell_scores)
|
705
768
|
gene_fraction_map, metadata = sequence_plates(well_score, settings['number_of_genes'], settings['avg_reads_per_gene'], settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])
|
@@ -719,7 +782,10 @@ def run_simulation(settings):
|
|
719
782
|
results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False)
|
720
783
|
#except Exception as e:
|
721
784
|
# print(f"An error occurred while saving data: {e}")
|
722
|
-
|
785
|
+
output = [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df]
|
786
|
+
del cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df
|
787
|
+
gc.collect()
|
788
|
+
return output, dists
|
723
789
|
|
724
790
|
def vis_dists(dists, src, v, i):
|
725
791
|
"""
|
@@ -743,9 +809,14 @@ def vis_dists(dists, src, v, i):
|
|
743
809
|
for index, dist in enumerate(dists):
|
744
810
|
temp = pd.DataFrame(dist, columns = [f'{names[index]}'])
|
745
811
|
sns.histplot(data=temp, x=f'{names[index]}', kde=False, binwidth=None, stat='count', element="step", ax=ax[n], color='teal', log_scale=False)
|
746
|
-
#plot_histogram(temp, f'{names[index]}', ax[n], 'slategray', f'{names[index]}', binwidth=None, log=False)
|
747
812
|
n+=1
|
748
813
|
save_plot(fig2, src, 'dists', i)
|
814
|
+
plt.close(fig2)
|
815
|
+
plt.figure().clear()
|
816
|
+
plt.cla()
|
817
|
+
plt.clf()
|
818
|
+
del dists
|
819
|
+
|
749
820
|
return
|
750
821
|
|
751
822
|
def visualize_all(output):
|
@@ -772,6 +843,7 @@ def visualize_all(output):
|
|
772
843
|
Returns:
|
773
844
|
fig (matplotlib.figure.Figure): The generated figure object.
|
774
845
|
"""
|
846
|
+
|
775
847
|
cell_scores = output[0]
|
776
848
|
cell_roc_dict_df = output[1]
|
777
849
|
cell_pr_dict_df = output[2]
|
@@ -809,24 +881,35 @@ def visualize_all(output):
|
|
809
881
|
#plot cell classification score by inactive and active
|
810
882
|
active_distribution = cell_scores[cell_scores['is_active'] == 1]
|
811
883
|
inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
|
812
|
-
plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
|
813
|
-
plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
|
884
|
+
plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', log=False)#, binwidth=0.01, log=False)
|
885
|
+
plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', log=False)#, binwidth=0.01, log=False)
|
886
|
+
|
887
|
+
legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
|
888
|
+
Patch(facecolor='teal', edgecolor='teal', label='Active')]
|
889
|
+
|
890
|
+
ax[n].legend(handles=legend_elements, loc='upper right')
|
891
|
+
|
892
|
+
|
814
893
|
ax[n].set_xlim([0, 1])
|
815
894
|
n+=1
|
816
895
|
|
817
896
|
#plot classifier cell predictions by inactive and active well average
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
897
|
+
inactive_distribution_well = inactive_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
898
|
+
active_distribution_well = active_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
899
|
+
mixed_distribution_well = cell_scores.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
|
900
|
+
|
901
|
+
plot_histogram(inactive_distribution_well, "score", ax[n], 'slategray', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
902
|
+
plot_histogram(active_distribution_well, "score", ax[n], 'teal', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
903
|
+
plot_histogram(mixed_distribution_well, "score", ax[n], 'red', 'Well scores', log=False)#, binwidth=0.01, log=False)
|
823
904
|
|
824
|
-
|
825
|
-
|
905
|
+
legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
|
906
|
+
Patch(facecolor='teal', edgecolor='teal', label='Active'),
|
907
|
+
Patch(facecolor='red', edgecolor='red', label='Mixed')]
|
826
908
|
|
827
|
-
|
828
|
-
|
909
|
+
ax[n].legend(handles=legend_elements, loc='upper right')
|
910
|
+
|
829
911
|
ax[n].set_xlim([0, 1])
|
912
|
+
#ax[n].legend()
|
830
913
|
n+=1
|
831
914
|
|
832
915
|
#plot ROC (cell classification)
|
@@ -845,8 +928,8 @@ def visualize_all(output):
|
|
845
928
|
n+=1
|
846
929
|
|
847
930
|
#plot well score
|
848
|
-
plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=
|
849
|
-
ax[n].set_xlim([0, 1])
|
931
|
+
plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=True)
|
932
|
+
#ax[n].set_xlim([0, 1])
|
850
933
|
n+=1
|
851
934
|
|
852
935
|
control_df = results_df[results_df['color'] == 'control']
|
@@ -910,6 +993,7 @@ def visualize_all(output):
|
|
910
993
|
|
911
994
|
plt.tight_layout()
|
912
995
|
plt.show()
|
996
|
+
gc.collect()
|
913
997
|
return fig
|
914
998
|
|
915
999
|
def create_database(db_path):
|
@@ -991,6 +1075,7 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
|
|
991
1075
|
df_concat[f'variable_{variable}_sim_nr'] = i
|
992
1076
|
|
993
1077
|
append_database(src, df_concat, 'simulations')
|
1078
|
+
del gini_genes_per_well, gini_wells_per_gene, df_concat
|
994
1079
|
|
995
1080
|
if save_all:
|
996
1081
|
for i, df in enumerate(output):
|
@@ -1000,8 +1085,12 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
|
|
1000
1085
|
if not isinstance(df, pd.DataFrame):
|
1001
1086
|
df = pd.DataFrame(df)
|
1002
1087
|
append_database(src, df, table_names[i])
|
1088
|
+
del df
|
1003
1089
|
except Exception as e:
|
1004
1090
|
print(f"An error occurred while saving data: {e}")
|
1091
|
+
print(traceback.format_exc())
|
1092
|
+
|
1093
|
+
del output, settings_df
|
1005
1094
|
return
|
1006
1095
|
|
1007
1096
|
def save_plot(fig, src, variable, i):
|
@@ -1023,6 +1112,7 @@ def save_plot(fig, src, variable, i):
|
|
1023
1112
|
return
|
1024
1113
|
|
1025
1114
|
def run_and_save(i, settings, time_ls, total_sims):
|
1115
|
+
|
1026
1116
|
"""
|
1027
1117
|
Run the simulation and save the results.
|
1028
1118
|
|
@@ -1035,23 +1125,26 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1035
1125
|
Returns:
|
1036
1126
|
tuple: A tuple containing the simulation index, simulation time, and None.
|
1037
1127
|
"""
|
1128
|
+
#print(f'Runnings simulation with the following paramiters')
|
1129
|
+
#print(settings)
|
1130
|
+
settings['random_seed'] = False
|
1038
1131
|
if settings['random_seed']:
|
1039
1132
|
random.seed(42) # sims will be too similar with random seed
|
1040
1133
|
src = settings['src']
|
1041
1134
|
plot = settings['plot']
|
1042
1135
|
v = settings['variable']
|
1043
1136
|
start_time = time() # Start time of the simulation
|
1044
|
-
now = datetime.now() # get current date
|
1045
|
-
date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1137
|
+
#now = datetime.now() # get current date
|
1138
|
+
#date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1139
|
+
date_string = settings['start_time']
|
1046
1140
|
#try:
|
1047
1141
|
output, dists = run_simulation(settings)
|
1048
1142
|
sim_time = time() - start_time # Elapsed time for the simulation
|
1049
1143
|
settings['sim_time'] = sim_time
|
1050
1144
|
src = os.path.join(f'{src}/{date_string}',settings['name'])
|
1051
1145
|
save_data(src, output, settings, save_all=False, i=i, variable=v)
|
1052
|
-
if vis_dists:
|
1053
|
-
vis_dists(dists,src, v, i)
|
1054
1146
|
if plot:
|
1147
|
+
vis_dists(dists,src, v, i)
|
1055
1148
|
fig = visualize_all(output)
|
1056
1149
|
save_plot(fig, src, v, i)
|
1057
1150
|
plt.close(fig)
|
@@ -1060,6 +1153,7 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1060
1153
|
plt.clf()
|
1061
1154
|
del fig
|
1062
1155
|
del output, dists
|
1156
|
+
gc.collect()
|
1063
1157
|
#except Exception as e:
|
1064
1158
|
# print(e, end='\r', flush=True)
|
1065
1159
|
# sim_time = time() - start_time
|
@@ -1067,98 +1161,92 @@ def run_and_save(i, settings, time_ls, total_sims):
|
|
1067
1161
|
time_ls.append(sim_time)
|
1068
1162
|
return i, sim_time, None
|
1069
1163
|
|
1070
|
-
def
|
1164
|
+
def validate_and_adjust_beta_params(sim_params):
|
1071
1165
|
"""
|
1072
|
-
|
1073
|
-
|
1166
|
+
Validates and adjusts Beta distribution parameters in simulation settings to ensure they are possible.
|
1167
|
+
|
1074
1168
|
Args:
|
1075
|
-
|
1076
|
-
|
1169
|
+
sim_params (list of dict): List of dictionaries, each containing the simulation parameters.
|
1170
|
+
|
1077
1171
|
Returns:
|
1078
|
-
|
1172
|
+
list of dict: The adjusted list of simulation parameter sets.
|
1079
1173
|
"""
|
1080
|
-
|
1081
|
-
for
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
sett['well_ineq_coeff'] = well_ineq_coeff
|
1101
|
-
for gene_ineq_coeff in settings['gene_ineq_coeff']:
|
1102
|
-
sett['gene_ineq_coeff'] = gene_ineq_coeff
|
1103
|
-
for nr_plates in settings['nr_plates']:
|
1104
|
-
sett['nr_plates'] = nr_plates
|
1105
|
-
for number_of_genes in settings['number_of_genes']:
|
1106
|
-
sett['number_of_genes'] = number_of_genes
|
1107
|
-
for number_of_active_genes in settings['number_of_active_genes']:
|
1108
|
-
sett['number_of_active_genes'] = number_of_active_genes
|
1109
|
-
for i in [*range(1,replicates+1)]:
|
1110
|
-
sim_ls.append(sett)
|
1111
|
-
#print(sett)
|
1112
|
-
#print('Number of simulations:',len(sim_ls))
|
1113
|
-
return sim_ls
|
1174
|
+
adjusted_params = []
|
1175
|
+
for params in sim_params:
|
1176
|
+
max_pos_variance = params['positive_mean'] * (1 - params['positive_mean'])
|
1177
|
+
max_neg_variance = params['negative_mean'] * (1 - params['negative_mean'])
|
1178
|
+
|
1179
|
+
# Adjust positive variance
|
1180
|
+
if params['positive_variance'] >= max_pos_variance:
|
1181
|
+
print(f'changed positive variance from {params["positive_variance"]} to {max_pos_variance * 0.99}')
|
1182
|
+
params['positive_variance'] = max_pos_variance * 0.99 # Adjust to 99% of the maximum allowed variance
|
1183
|
+
|
1184
|
+
# Adjust negative variance
|
1185
|
+
if params['negative_variance'] >= max_neg_variance:
|
1186
|
+
print(f'changed negative variance from {params["negative_variance"]} to {max_neg_variance * 0.99}')
|
1187
|
+
params['negative_variance'] = max_neg_variance * 0.99 # Adjust to 99% of the maximum allowed variance
|
1188
|
+
|
1189
|
+
adjusted_params.append(params)
|
1190
|
+
|
1191
|
+
return adjusted_params
|
1192
|
+
|
1193
|
+
def generate_paramiters(settings):
|
1114
1194
|
|
1115
|
-
#altered for one set of settings see negative_mean and variance
|
1116
|
-
def generate_paramiters_single(settings):
|
1117
1195
|
"""
|
1118
|
-
Generate a list of parameter sets for
|
1196
|
+
Generate a list of parameter sets for simulation based on the given settings.
|
1119
1197
|
|
1120
1198
|
Args:
|
1121
1199
|
settings (dict): A dictionary containing the simulation settings.
|
1122
1200
|
|
1123
1201
|
Returns:
|
1124
|
-
list: A list of parameter sets for
|
1202
|
+
list: A list of parameter sets for simulation.
|
1125
1203
|
"""
|
1204
|
+
|
1205
|
+
settings['positive_mean'] = [0.8]
|
1206
|
+
|
1126
1207
|
sim_ls = []
|
1127
1208
|
for avg_genes_per_well in settings['avg_genes_per_well']:
|
1128
1209
|
replicates = settings['replicates']
|
1129
|
-
sett = settings.copy()
|
1130
|
-
sett['avg_genes_per_well'] = avg_genes_per_well
|
1131
|
-
sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
|
1132
1210
|
for avg_cells_per_well in settings['avg_cells_per_well']:
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1211
|
+
for classifier_accuracy in settings['classifier_accuracy']:
|
1212
|
+
for positive_mean in settings['positive_mean']:
|
1213
|
+
for avg_reads_per_gene in settings['avg_reads_per_gene']:
|
1214
|
+
for sequencing_error in settings['sequencing_error']:
|
1215
|
+
for well_ineq_coeff in settings['well_ineq_coeff']:
|
1216
|
+
for gene_ineq_coeff in settings['gene_ineq_coeff']:
|
1217
|
+
for nr_plates in settings['nr_plates']:
|
1218
|
+
for number_of_genes in settings['number_of_genes']:
|
1219
|
+
for number_of_active_genes in settings['number_of_active_genes']:
|
1220
|
+
for i in range(1, replicates+1):
|
1221
|
+
sett = deepcopy(settings)
|
1222
|
+
sett['avg_genes_per_well'] = avg_genes_per_well
|
1223
|
+
sett['sd_genes_per_well'] = avg_genes_per_well / 2
|
1224
|
+
sett['avg_cells_per_well'] = avg_cells_per_well
|
1225
|
+
sett['sd_cells_per_well'] = avg_cells_per_well / 2
|
1226
|
+
sett['classifier_accuracy'] = classifier_accuracy
|
1227
|
+
sett['positive_mean'] = positive_mean
|
1228
|
+
sett['negative_mean'] = 1-positive_mean
|
1229
|
+
sett['positive_variance'] = (1-positive_mean)/2
|
1230
|
+
sett['negative_variance'] = (1-positive_mean)/2
|
1231
|
+
sett['avg_reads_per_gene'] = avg_reads_per_gene
|
1232
|
+
sett['sd_reads_per_gene'] = avg_reads_per_gene / 2
|
1233
|
+
sett['sequencing_error'] = sequencing_error
|
1234
|
+
sett['well_ineq_coeff'] = well_ineq_coeff
|
1235
|
+
sett['gene_ineq_coeff'] = gene_ineq_coeff
|
1236
|
+
sett['nr_plates'] = nr_plates
|
1237
|
+
sett['number_of_genes'] = number_of_genes
|
1238
|
+
sett['number_of_active_genes'] = number_of_active_genes
|
1239
|
+
sim_ls.append(sett)
|
1240
|
+
|
1241
|
+
random.shuffle(sim_ls)
|
1242
|
+
sim_ls = validate_and_adjust_beta_params(sim_ls)
|
1243
|
+
print(f'Running {len(sim_ls)} simulations.')
|
1244
|
+
#for x in sim_ls:
|
1245
|
+
# print(x['positive_mean'])
|
1159
1246
|
return sim_ls
|
1160
1247
|
|
1161
1248
|
def run_multiple_simulations(settings):
|
1249
|
+
|
1162
1250
|
"""
|
1163
1251
|
Run multiple simulations in parallel using the provided settings.
|
1164
1252
|
|
@@ -1169,8 +1257,12 @@ def run_multiple_simulations(settings):
|
|
1169
1257
|
None
|
1170
1258
|
"""
|
1171
1259
|
|
1260
|
+
now = datetime.now() # get current date
|
1261
|
+
start_time = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
|
1262
|
+
settings['start_time'] = start_time
|
1263
|
+
|
1172
1264
|
sim_ls = generate_paramiters(settings)
|
1173
|
-
print(f'Running {len(sim_ls)} simulations.
|
1265
|
+
#print(f'Running {len(sim_ls)} simulations.')
|
1174
1266
|
|
1175
1267
|
max_workers = settings['max_workers'] or cpu_count() - 4
|
1176
1268
|
with Manager() as manager:
|
@@ -1179,9 +1271,429 @@ def run_multiple_simulations(settings):
|
|
1179
1271
|
with Pool(max_workers) as pool:
|
1180
1272
|
result = pool.starmap_async(run_and_save, [(index, settings, time_ls, total_sims) for index, settings in enumerate(sim_ls)])
|
1181
1273
|
while not result.ready():
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1274
|
+
try:
|
1275
|
+
sleep(0.01)
|
1276
|
+
sims_processed = len(time_ls)
|
1277
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
1278
|
+
time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
|
1279
|
+
print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
|
1280
|
+
gc.collect()
|
1281
|
+
except Exception as e:
|
1282
|
+
print(e)
|
1283
|
+
print(traceback.format_exc())
|
1284
|
+
try:
|
1285
|
+
result.get()
|
1286
|
+
except Exception as e:
|
1287
|
+
print(e)
|
1288
|
+
print(traceback.format_exc())
|
1289
|
+
|
1290
|
+
def generate_integers(start, stop, step):
|
1291
|
+
return list(range(start, stop + 1, step))
|
1292
|
+
|
1293
|
+
def generate_floats(start, stop, step):
|
1294
|
+
# Determine the number of decimal places in 'step'
|
1295
|
+
num_decimals = str(step)[::-1].find('.')
|
1296
|
+
|
1297
|
+
current = start
|
1298
|
+
floats_list = []
|
1299
|
+
while current <= stop:
|
1300
|
+
# Round each float to the appropriate number of decimal places
|
1301
|
+
floats_list.append(round(current, num_decimals))
|
1302
|
+
current += step
|
1303
|
+
|
1304
|
+
return floats_list
|
1305
|
+
|
1306
|
+
def remove_columns_with_single_value(df):
|
1307
|
+
"""
|
1308
|
+
Removes columns from the DataFrame that have the same value in all rows.
|
1309
|
+
|
1310
|
+
Args:
|
1311
|
+
df (pandas.DataFrame): The original DataFrame.
|
1312
|
+
|
1313
|
+
Returns:
|
1314
|
+
pandas.DataFrame: A DataFrame with the columns removed that contained only one unique value.
|
1315
|
+
"""
|
1316
|
+
to_drop = [column for column in df.columns if df[column].nunique() == 1]
|
1317
|
+
return df.drop(to_drop, axis=1)
|
1318
|
+
|
1319
|
+
def read_simulations_table(db_path):
|
1320
|
+
"""
|
1321
|
+
Reads the 'simulations' table from an SQLite database into a pandas DataFrame.
|
1322
|
+
|
1323
|
+
Args:
|
1324
|
+
db_path (str): The file path to the SQLite database.
|
1325
|
+
|
1326
|
+
Returns:
|
1327
|
+
pandas.DataFrame: DataFrame containing the 'simulations' table data.
|
1328
|
+
"""
|
1329
|
+
# Create a connection object using the connect function
|
1330
|
+
conn = sqlite3.connect(db_path)
|
1331
|
+
|
1332
|
+
# Read the 'simulations' table into a pandas DataFrame
|
1333
|
+
try:
|
1334
|
+
df = pd.read_sql_query("SELECT * FROM simulations", conn)
|
1335
|
+
except Exception as e:
|
1336
|
+
print(f"An error occurred: {e}")
|
1337
|
+
return None
|
1338
|
+
finally:
|
1339
|
+
# Close the connection to SQLite database
|
1340
|
+
conn.close()
|
1341
|
+
|
1342
|
+
return df
|
1343
|
+
|
1344
|
+
def plot_simulations(df, variable, x_rotation=None, legend=False, grid=False, clean=True, verbose=False):
    """
    Plot mean 'prauc' against ``variable`` with a +/- one standard deviation band,
    drawing one panel per unique combination of the remaining grouping columns.

    Args:
        df (pandas.DataFrame): DataFrame containing ``variable``, 'prauc' and the
            grouping columns that end up being used.
        variable (str): Name of the column to use as the x-axis.
        x_rotation (int, optional): Degrees to rotate the x-axis labels (45 when None).
        legend (bool, optional): Whether to display a legend.
        grid (bool, optional): Whether to display grid lines.
        clean (bool, optional): When True, grouping columns that
            ``remove_columns_with_single_value`` drops are not used for grouping.
        verbose (bool, optional): Whether to annotate each panel with its condition values.

    Returns:
        matplotlib.figure.Figure: The figure holding all panels.

    Raises:
        ValueError: If a required column is missing or ``df`` has no rows.
    """
    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        # The previous code read an unassigned `relevant_data` here and always failed.
        # Cleaning is only used to decide which grouping columns matter; `df` itself is
        # left intact so `variable` and 'prauc' are never dropped.
        kept_columns = set(remove_columns_with_single_value(df).columns)
        grouping_vars = [col for col in grouping_vars if col in kept_columns]

    grouping_vars = [col for col in grouping_vars if col != variable]

    # Check if the necessary columns are present in the DataFrame
    required_columns = {variable, 'prauc'} | set(grouping_vars)
    if not required_columns.issubset(df.columns):
        missing_cols = required_columns - set(df.columns)
        raise ValueError(f"DataFrame must contain {missing_cols} columns")

    if len(df) == 0:
        # Avoids a ZeroDivisionError in the grid computation below.
        raise ValueError("DataFrame contains no rows to plot")

    # Build (condition, subset) pairs: one per unique combination of grouping values,
    # or a single panel over all rows when nothing is left to group by.
    if grouping_vars:
        unique_combinations = df[grouping_vars].drop_duplicates()
        panels = []
        for _, row in unique_combinations.iterrows():
            condition = {var: row[var] for var in grouping_vars}
            subset_df = df[df[grouping_vars].eq(row).all(axis=1)]
            panels.append((condition, subset_df))
    else:
        panels = [({}, df)]

    # Determine the layout of the subplots (as square as possible)
    num_panels = len(panels)
    num_rows = math.ceil(math.sqrt(num_panels))
    num_cols = math.ceil(num_panels / num_rows)

    # squeeze=False always yields a 2-D array of axes, so flatten() is always valid.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (condition, subset_df) in zip(axes, panels):
        # Group by 'variable' and calculate mean and std dev of 'prauc', ordered along x
        grouped = subset_df.groupby(variable)['prauc'].agg(['mean', 'std']).sort_index()

        # Mean as a line, std deviation as a shaded band
        ax.plot(grouped.index, grouped['mean'], marker='o', linestyle='-', color='b', label='Mean PRAUC')
        ax.fill_between(grouped.index, grouped['mean'] - grouped['std'], grouped['mean'] + grouped['std'],
                        color='gray', alpha=0.5, label='Std Dev')

        ax.set_xlabel(variable)
        ax.set_ylabel('Precision-Recall AUC (PRAUC)')
        ax.grid(grid)

        if legend:
            ax.legend()

        # Set x-ticks and rotate them as specified
        ax.set_xticks(grouped.index)
        ax.set_xticklabels(grouped.index, rotation=x_rotation if x_rotation is not None else 45)

        if verbose and condition:
            verbose_text = '\n'.join([f"{var}: {val}" for var, val in condition.items()])
            ax.text(0.95, 0.05, verbose_text, transform=ax.transAxes, fontsize=9,
                    verticalalignment='bottom', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    # Hide any unused axes if there are any
    for ax in axes[num_panels:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
|
1432
|
+
|
1433
|
+
def plot_correlation_matrix(df, annot=False, cmap='inferno', clean=True):
    """
    Plot the lower triangle of the correlation matrix of the simulation settings
    and result metrics found in ``df``, then save it through ``save_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        annot (bool, optional): Whether to write the correlation value in each cell.
        cmap (str, optional): Accepted for interface compatibility. It is currently
            overwritten by a seaborn diverging palette (see note in the body).
        clean (bool, optional): When True, constant / single-valued columns are
            removed before the correlation is computed.

    Returns:
        matplotlib.figure.Figure: The figure containing the heatmap.

    Raises:
        ValueError: If none of the expected columns are present in ``df``.
    """
    # NOTE(review): the caller-supplied `cmap` is discarded here. Kept as-is so existing
    # output does not change; confirm whether the argument should be honoured instead.
    cmap = sns.diverging_palette(240, 10, as_cmap=True)

    columns = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
               'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
               'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff',
               'optimal_threshold', 'accuracy', 'prauc', 'roc_auc', 'genes_per_well_gini',
               'wells_per_gene_gini']

    if clean:
        df = remove_constant_columns(df)

    # Restrict to columns that actually exist. Previously this only happened when
    # clean=True, so clean=False raised KeyError on any absent column.
    columns = [col for col in columns if col in df.columns]
    if not columns:
        raise ValueError("DataFrame contains none of the columns needed for the correlation matrix")

    # Subsetting the DataFrame to include only the relevant variables
    relevant_data = df[columns]

    if clean:
        relevant_data = remove_columns_with_single_value(relevant_data)

    # Correlation matrix; the mask hides the upper triangle including the diagonal
    corr_matrix = relevant_data.corr()
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(corr_matrix, mask=mask, annot=annot, cmap=cmap, fmt=".2f", linewidths=.5, robust=True)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='correlation_matrix', i=1)
    return fig
|
1476
|
+
|
1477
|
+
def plot_feature_importance(df, target='prauc', exclude=None, clean=True):
    """
    Fit a RandomForestRegressor on the simulation settings and plot the impurity-based
    importance of each feature for predicting ``target``, most important on top.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable column name.
        exclude (list or str, optional): Column name(s) to exclude from the features.
        clean (bool, optional): When True, columns dropped by
            ``remove_columns_with_single_value`` are not used as features.

    Returns:
        matplotlib.figure.Figure: The figure object containing the feature importance plot.
    """
    # Define the features for the model
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Remove excluded features if specified
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    # Train the model
    model = RandomForestRegressor(n_estimators=1000, random_state=42)
    model.fit(df[features], df[target])

    # Feature importances, ordered from most to least important
    importances = model.feature_importances_
    order = np.argsort(importances)[::-1]
    positions = range(len(order))

    # Plot horizontal bar chart
    fig = plt.figure(figsize=(12, 6))
    plt.barh(positions, importances[order], color="teal", align="center", alpha=0.6)
    # Labels must follow the same ordering as the bars. The previous code reversed the
    # label order only, which attached the wrong feature name to every bar.
    plt.yticks(positions, [features[i] for i in order])
    plt.gca().invert_yaxis()  # Position 0 (highest importance) is drawn at the top
    plt.xlabel('Feature Importance')
    plt.title('Feature Importances')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='feature_importance', i=1)
    return fig
|
1524
|
+
|
1525
|
+
def calculate_permutation_importance(df, target='prauc', exclude=None, n_repeats=10, clean=True):
    """
    Fit a RandomForestRegressor on the simulation settings, compute permutation
    importance of each feature for ``target`` and plot it as a horizontal bar chart.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        exclude (list or str, optional): Column name(s) to exclude from the features.
            Names that are not among the features are ignored.
        n_repeats (int, optional): Number of permutations per feature.
        clean (bool, optional): When True, columns dropped by
            ``remove_columns_with_single_value`` are not used as features.

    Returns:
        matplotlib.figure.Figure: The figure containing the permutation importance plot.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # The previous code ran `features.remove(exclude)` even when `exclude` was a list
    # (plain `if` instead of `elif`), which always raised ValueError. Filtering also
    # tolerates names that cleaning has already removed.
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    X = df[features]
    y = df[target]

    # Initialize a model (you could pass it as an argument if you'd like to use a different one)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    perm_importance = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)

    # Ascending order, so the most important feature ends up at the top of the barh plot
    sorted_idx = perm_importance.importances_mean.argsort()
    positions = range(len(sorted_idx))

    fig, ax = plt.subplots()
    ax.barh(positions, perm_importance.importances_mean[sorted_idx], color="teal", align="center", alpha=0.6)
    ax.set_yticks(positions)
    # Indices refer to the columns of X (i.e. `features`), not to all columns of df.
    ax.set_yticklabels([features[i] for i in sorted_idx])
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='permutation_importance', i=1)
    return fig
|
1574
|
+
|
1575
|
+
def plot_partial_dependences(df, target='prauc', clean=True):
    """
    Fit a GradientBoostingRegressor on the simulation settings and draw one partial
    dependence plot per feature on a grid with four columns.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable.
        clean (bool, optional): When True, columns dropped by
            ``remove_columns_with_single_value`` are not used as features.

    Returns:
        matplotlib.figure.Figure: The figure containing the partial dependence plots.

    Raises:
        ValueError: If no feature column is left to plot.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    if not features:
        raise ValueError("No feature columns available for partial dependence plots")

    X = df[features]
    y = df[target]

    # Train a model
    model = GradientBoostingRegressor()
    model.fit(X, y)

    # Determine the number of rows and columns for subplots
    n_cols = 4  # Number of columns in subplot grid
    n_rows = (len(features) + n_cols - 1) // n_cols  # Ceiling division: rows needed

    # squeeze=False always returns a 2-D array of axes. The previous `[axs]` fallback for
    # a single row wrapped a 1-D array of four axes in a list, so axs[0] was not an Axes.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    fig.suptitle('Partial Dependence Plots', fontsize=20, y=1.03)
    axs = axs.flatten()

    for ax, feature in zip(axs, features):
        PartialDependenceDisplay.from_estimator(model, X, features=[feature], ax=ax)
        ax.set_title(feature)  # Set title to the name of the feature

    # Hide unused axes if any
    for ax in axs[len(features):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='partial_dependences', i=1)
    return fig
|
1627
|
+
|
1628
|
+
def save_shap_plot(fig, src, variable, i):
    """
    Save ``fig`` as a PDF named ``<i>_figure.pdf`` inside ``<src>/<variable>``,
    creating that directory when it does not exist, and print the saved path.

    Args:
        fig: Object exposing ``savefig`` (e.g. a matplotlib figure).
        src (str): Base output directory.
        variable (str): Sub-directory name under ``src``.
        i: Value used as the file-name prefix.

    Returns:
        None
    """
    import os

    # Make sure the destination folder exists before writing
    os.makedirs(f'{src}/{variable}', exist_ok=True)

    filename_fig = f'{src}/{variable}/{str(i)}_figure.pdf'
    savefig_options = dict(dpi=600, format='pdf', bbox_inches='tight')
    fig.savefig(filename_fig, **savefig_options)

    print(f"Saved figure as {filename_fig}")
|
1634
|
+
|
1635
|
+
def generate_shap_summary_plot(df,target='prauc', clean=True):
    """
    Fit a RandomForestRegressor on the simulation settings, draw a SHAP summary plot
    for ``target`` and save it through ``save_shap_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        clean (bool, optional): When True, columns dropped by
            ``remove_columns_with_single_value`` are not used as features.

    Returns:
        matplotlib.figure.Figure: The figure containing the SHAP summary plot.
    """
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    X = df[features]
    y = df[target]

    # Initialize a model (you could pass it as an argument if you'd like to use a different one)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    # Calculate SHAP values
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # Draw without showing first: with the default show=True the figure is displayed
    # inside summary_plot, and a later plt.gcf() may then return a new, empty figure,
    # so the saved PDF and the returned figure could be blank.
    shap.summary_plot(shap_values, X, show=False)
    fig = plt.gcf()
    save_shap_plot(fig, src='figures', variable='shap', i=1)
    plt.show()
    return fig
|
1672
|
+
|
1673
|
+
|
1674
|
+
def remove_constant_columns(df):
    """
    Drop every column of ``df`` that does not hold more than one distinct value.

    Distinct values are counted with ``DataFrame.nunique()`` using its defaults.

    Args:
        df (pd.DataFrame): The input DataFrame.

    Returns:
        pd.DataFrame: The columns of ``df`` with more than one distinct value.
    """
    distinct_counts = df.nunique()
    varies = distinct_counts > 1
    return df.loc[:, varies]
|
1685
|
+
|
1686
|
+
|
1687
|
+
# Rationale for using a Beta distribution in the simulated classifier (reference snippet, kept commented out):
|
1688
|
+
|
1689
|
+
# Fit a Beta distribution to these outputs
|
1690
|
+
#a, b, loc, scale = beta.fit(predicted_probs, floc=0, fscale=1) # Fix location and scale to match the support of the sigmoid
|
1691
|
+
|
1692
|
+
# Sample from this fitted Beta distribution
|
1693
|
+
#simulated_probs = beta.rvs(a, b, size=1000)
|
1694
|
+
|
1695
|
+
# Plot the empirical vs simulated distribution
|
1696
|
+
#plt.hist(predicted_probs, bins=30, alpha=0.5, label='Empirical')
|
1697
|
+
#plt.hist(simulated_probs, bins=30, alpha=0.5, label='Simulated from Beta')
|
1698
|
+
#plt.legend()
|
1699
|
+
#plt.show()
|