spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sim.py CHANGED
@@ -1,16 +1,19 @@
1
1
 
2
- import os, gc, random, warnings, traceback, itertools, matplotlib, sqlite3
3
- import time as tm
2
+ import os, random, warnings, traceback, sqlite3, shap, math, gc
4
3
  from time import time, sleep
5
4
  from datetime import datetime
6
5
  import numpy as np
7
6
  import pandas as pd
8
7
  import matplotlib.pyplot as plt
8
+ from matplotlib.patches import Patch
9
9
  import seaborn as sns
10
10
  import sklearn.metrics as metrics
11
- from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, precision_recall_curve
11
+ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
12
+ from sklearn.inspection import PartialDependenceDisplay, permutation_importance
13
+ from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
12
14
  import statsmodels.api as sm
13
- from multiprocessing import cpu_count, Value, Array, Lock, Pool, Manager
15
+ from multiprocessing import cpu_count, Pool, Manager
16
+ from copy import deepcopy
14
17
 
15
18
  from .logger import log_function_call
16
19
 
@@ -35,6 +38,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
35
38
 
36
39
  # plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
37
40
  def generate_plate_map(nr_plates):
41
+ #print('nr_plates',nr_plates)
38
42
  """
39
43
  Generate a plate map based on the number of plates.
40
44
 
@@ -240,6 +244,7 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
240
244
  - df_ls (list): A list containing gene counts per well, well counts per gene, Gini coefficients for wells,
241
245
  Gini coefficients for genes, gene weights array, and well weights.
242
246
  """
247
+
243
248
  #generate primary distributions and genes
244
249
  cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
245
250
  gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
@@ -250,14 +255,18 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
250
255
  gene_weights_array = np.array(list(gene_weights.values())) # Convert the values to an array
251
256
 
252
257
  well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)
253
-
254
- gene_to_well_mapping = {}
255
-
256
- for gene in genes:
257
- gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=int(gpw[gene-1]), p=well_weights) # Generate a number of wells for each gene according to well_weights
258
-
259
- gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}
260
258
 
259
+ gene_to_well_mapping = {}
260
+ for gene in range(1, number_of_genes + 1): # ensures gene-1 is within bounds
261
+ if gene-1 < len(gpw):
262
+ max_index = len(plate_map['plate_row_column']) # this should be the number of choices available from plate_map
263
+ num_samples = int(gpw[gene-1])
264
+ if num_samples >= max_index:
265
+ num_samples = max_index - 1 # adjust to maximum possible index
266
+ gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=num_samples, replace=False, p=well_weights)
267
+ else:
268
+ break # break the loop if gene-1 is out of bounds for gpw
269
+
261
270
  cells = []
262
271
  for i in [*range(0,len(plate_map))]:
263
272
  ciw = random.choice(cpw)
@@ -315,9 +324,47 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
315
324
  df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
316
325
  return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
317
326
 
318
- # classifier is a function that takes a cell state (active=1/inactive=0) and produces a score in [0, 1]
319
- # For the input cell, it checks if it is active or inactive, and then samples from an appropriate beta distribution to give a score
320
- def classifier(positive_mean, positive_variance, negative_mean, negative_variance, df):
327
def classifier(positive_mean, positive_variance, negative_mean, negative_variance, classifier_accuracy, df):
    """
    Assign a simulated classifier score in (0, 1) to every row of ``df``.

    Each row is scored by sampling a Beta distribution. With probability
    ``classifier_accuracy`` the row is scored from the distribution matching
    its true ``is_active`` state (positive for active, negative for inactive);
    otherwise it is scored from the opposite distribution.

    Args:
        positive_mean (float): The mean of the positive distribution.
        positive_variance (float): The variance of the positive distribution.
        negative_mean (float): The mean of the negative distribution.
        negative_variance (float): The variance of the negative distribution.
        classifier_accuracy (float): The likelihood (0 to 1) that a row is scored
            from the distribution matching its true label.
        df (pandas.DataFrame): The DataFrame to score; must contain an 'is_active' column.

    Returns:
        pandas.DataFrame: The same DataFrame object, modified in place, with a
        'score' column containing the classification scores.

    Raises:
        ValueError: If a mean is not strictly between 0 and 1, if a variance is
            not strictly between 0 and mean * (1 - mean), or if
            classifier_accuracy is outside [0, 1].
    """
    def calc_alpha_beta(mean, variance):
        # Method-of-moments conversion from (mean, variance) to Beta(alpha, beta).
        if mean <= 0 or mean >= 1:
            raise ValueError("Mean must be between 0 and 1 exclusively.")
        max_variance = mean * (1 - mean)
        if variance <= 0 or variance >= max_variance:
            raise ValueError(f"Variance must be positive and less than {max_variance}.")

        alpha = mean * (mean * (1 - mean) / variance - 1)
        beta = alpha * (1 - mean) / mean
        return alpha, beta

    # Written as a negated chained comparison so that NaN is rejected as well.
    if not 0 <= classifier_accuracy <= 1:
        raise ValueError("classifier_accuracy must be between 0 and 1 inclusive.")

    # Calculate alpha and beta for both distributions
    a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
    a2, b2 = calc_alpha_beta(negative_mean, negative_variance)

    # Same truthiness rule as `if is_active` applied row by row.
    is_active = df['is_active'].to_numpy().astype(bool)
    n_rows = len(is_active)

    # One uniform draw per row decides whether the classifier is "right" for that row.
    correctly_classified = np.random.rand(n_rows) < classifier_accuracy

    # Positive distribution is used for (active & correct) or (inactive & incorrect).
    use_positive = is_active == correctly_classified
    df['score'] = np.where(
        use_positive,
        np.random.beta(a1, b1, size=n_rows),
        np.random.beta(a2, b2, size=n_rows),
    )

    return df
366
+
367
+ def classifier_v2(positive_mean, positive_variance, negative_mean, negative_variance, df):
321
368
  """
322
369
  Classifies the data in the DataFrame based on the given parameters.
323
370
 
@@ -331,12 +378,22 @@ def classifier(positive_mean, positive_variance, negative_mean, negative_varianc
331
378
  Returns:
332
379
  pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
333
380
  """
334
- # alpha and beta for positive distribution
335
- a1 = positive_mean*(positive_mean*(1-positive_mean)/positive_variance - 1)
336
- b1 = a1*(1-positive_mean)/positive_mean
337
- # alpha and beta for negative distribution
338
- a2 = negative_mean*(negative_mean*(1-negative_mean)/negative_variance - 1)
339
- b2 = a2*(1-negative_mean)/negative_mean
381
+ def calc_alpha_beta(mean, variance):
382
+ if mean <= 0 or mean >= 1:
383
+ raise ValueError("Mean must be between 0 and 1 exclusively.")
384
+ max_variance = mean * (1 - mean)
385
+ if variance <= 0 or variance >= max_variance:
386
+ raise ValueError(f"Variance must be positive and less than {max_variance}.")
387
+
388
+ alpha = mean * (mean * (1 - mean) / variance - 1)
389
+ beta = alpha * (1 - mean) / mean
390
+ return alpha, beta
391
+
392
+ # Calculate alpha and beta for both distributions
393
+ a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
394
+ a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
395
+
396
+ # Apply the beta distribution based on 'is_active' status
340
397
  df['score'] = df['is_active'].apply(lambda is_active: np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2))
341
398
  return df
342
399
 
@@ -445,6 +502,7 @@ def generate_well_score(cell_scores):
445
502
  return well_score
446
503
 
447
504
  def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
505
+
448
506
  """
449
507
  Simulates the sequencing of plates and calculates gene fractions and metadata.
450
508
 
@@ -459,6 +517,7 @@ def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_pe
459
517
  gene_fraction_map (pd.DataFrame): DataFrame containing gene fractions for each well.
460
518
  metadata (pd.DataFrame): DataFrame containing metadata for each well.
461
519
  """
520
+
462
521
  reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
463
522
  gene_names = [f'gene_{v}' for v in range(number_of_genes+1)]
464
523
  all_wells = well_score.index
@@ -611,8 +670,12 @@ def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
611
670
  Returns:
612
671
  None
613
672
  """
614
- sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
615
- legend=False, fill=True, element='step', palette='dark')
673
+ if not binwidth:
674
+ sns.histplot(data=data, x=x_label, ax=ax, color=color, kde=False, stat='density',
675
+ legend=False, fill=True, element='step', palette='dark')
676
+ else:
677
+ sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
678
+ legend=False, fill=True, element='step', palette='dark')
616
679
  if log:
617
680
  ax.set_yscale('log')
618
681
  ax.set_title(title)
@@ -699,7 +762,7 @@ def run_simulation(settings):
699
762
  plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])] # Extract rows where 'column_id' is not in [1,2,3,23,24]
700
763
 
701
764
  cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(plate_map, settings['number_of_genes'], active_gene_list, settings['avg_genes_per_well'], settings['sd_genes_per_well'], settings['avg_cells_per_well'], settings['sd_cells_per_well'], settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
702
- cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], df=cell_level)
765
+ cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], settings['classifier_accuracy'], df=cell_level)
703
766
  cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
704
767
  well_score = generate_well_score(cell_scores)
705
768
  gene_fraction_map, metadata = sequence_plates(well_score, settings['number_of_genes'], settings['avg_reads_per_gene'], settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])
@@ -719,7 +782,10 @@ def run_simulation(settings):
719
782
  results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False)
720
783
  #except Exception as e:
721
784
  # print(f"An error occurred while saving data: {e}")
722
- return [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df], dists
785
+ output = [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df]
786
+ del cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df
787
+ gc.collect()
788
+ return output, dists
723
789
 
724
790
  def vis_dists(dists, src, v, i):
725
791
  """
@@ -743,9 +809,14 @@ def vis_dists(dists, src, v, i):
743
809
  for index, dist in enumerate(dists):
744
810
  temp = pd.DataFrame(dist, columns = [f'{names[index]}'])
745
811
  sns.histplot(data=temp, x=f'{names[index]}', kde=False, binwidth=None, stat='count', element="step", ax=ax[n], color='teal', log_scale=False)
746
- #plot_histogram(temp, f'{names[index]}', ax[n], 'slategray', f'{names[index]}', binwidth=None, log=False)
747
812
  n+=1
748
813
  save_plot(fig2, src, 'dists', i)
814
+ plt.close(fig2)
815
+ plt.figure().clear()
816
+ plt.cla()
817
+ plt.clf()
818
+ del dists
819
+
749
820
  return
750
821
 
751
822
  def visualize_all(output):
@@ -772,6 +843,7 @@ def visualize_all(output):
772
843
  Returns:
773
844
  fig (matplotlib.figure.Figure): The generated figure object.
774
845
  """
846
+
775
847
  cell_scores = output[0]
776
848
  cell_roc_dict_df = output[1]
777
849
  cell_pr_dict_df = output[2]
@@ -809,24 +881,35 @@ def visualize_all(output):
809
881
  #plot cell classification score by inactive and active
810
882
  active_distribution = cell_scores[cell_scores['is_active'] == 1]
811
883
  inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
812
- plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
813
- plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
884
+ plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', log=False)#, binwidth=0.01, log=False)
885
+ plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', log=False)#, binwidth=0.01, log=False)
886
+
887
+ legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
888
+ Patch(facecolor='teal', edgecolor='teal', label='Active')]
889
+
890
+ ax[n].legend(handles=legend_elements, loc='upper right')
891
+
892
+
814
893
  ax[n].set_xlim([0, 1])
815
894
  n+=1
816
895
 
817
896
  #plot classifier cell predictions by inactive and active well average
818
- ##inactive_distribution_well['score'] = pd.to_numeric(inactive_distribution['score'], errors='coerce')
819
- ##inactive_distribution_well = inactive_distribution_well.groupby('plate_row_column')['score'].mean()
820
-
821
- ##active_distribution_well['score'] = pd.to_numeric(active_distribution['score'], errors='coerce')
822
- ##active_distribution_well = active_distribution_well.groupby('plate_row_column')['score'].mean()
897
+ inactive_distribution_well = inactive_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
898
+ active_distribution_well = active_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
899
+ mixed_distribution_well = cell_scores.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
900
+
901
+ plot_histogram(inactive_distribution_well, "score", ax[n], 'slategray', 'Well scores', log=False)#, binwidth=0.01, log=False)
902
+ plot_histogram(active_distribution_well, "score", ax[n], 'teal', 'Well scores', log=False)#, binwidth=0.01, log=False)
903
+ plot_histogram(mixed_distribution_well, "score", ax[n], 'red', 'Well scores', log=False)#, binwidth=0.01, log=False)
823
904
 
824
- #inactive_distribution_well = inactive_distribution.groupby(['plate_row_column']).mean()
825
- #active_distribution_well = active_distribution.groupby(['plate_row_column']).mean()
905
+ legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
906
+ Patch(facecolor='teal', edgecolor='teal', label='Active'),
907
+ Patch(facecolor='red', edgecolor='red', label='Mixed')]
826
908
 
827
- plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Well scores', binwidth=0.01, log=False)
828
- plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Well scores', binwidth=0.01, log=False)
909
+ ax[n].legend(handles=legend_elements, loc='upper right')
910
+
829
911
  ax[n].set_xlim([0, 1])
912
+ #ax[n].legend()
830
913
  n+=1
831
914
 
832
915
  #plot ROC (cell classification)
@@ -845,8 +928,8 @@ def visualize_all(output):
845
928
  n+=1
846
929
 
847
930
  #plot well score
848
- plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=False)
849
- ax[n].set_xlim([0, 1])
931
+ plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=True)
932
+ #ax[n].set_xlim([0, 1])
850
933
  n+=1
851
934
 
852
935
  control_df = results_df[results_df['color'] == 'control']
@@ -910,6 +993,7 @@ def visualize_all(output):
910
993
 
911
994
  plt.tight_layout()
912
995
  plt.show()
996
+ gc.collect()
913
997
  return fig
914
998
 
915
999
  def create_database(db_path):
@@ -991,6 +1075,7 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
991
1075
  df_concat[f'variable_{variable}_sim_nr'] = i
992
1076
 
993
1077
  append_database(src, df_concat, 'simulations')
1078
+ del gini_genes_per_well, gini_wells_per_gene, df_concat
994
1079
 
995
1080
  if save_all:
996
1081
  for i, df in enumerate(output):
@@ -1000,8 +1085,12 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
1000
1085
  if not isinstance(df, pd.DataFrame):
1001
1086
  df = pd.DataFrame(df)
1002
1087
  append_database(src, df, table_names[i])
1088
+ del df
1003
1089
  except Exception as e:
1004
1090
  print(f"An error occurred while saving data: {e}")
1091
+ print(traceback.format_exc())
1092
+
1093
+ del output, settings_df
1005
1094
  return
1006
1095
 
1007
1096
  def save_plot(fig, src, variable, i):
@@ -1023,6 +1112,7 @@ def save_plot(fig, src, variable, i):
1023
1112
  return
1024
1113
 
1025
1114
  def run_and_save(i, settings, time_ls, total_sims):
1115
+
1026
1116
  """
1027
1117
  Run the simulation and save the results.
1028
1118
 
@@ -1035,23 +1125,26 @@ def run_and_save(i, settings, time_ls, total_sims):
1035
1125
  Returns:
1036
1126
  tuple: A tuple containing the simulation index, simulation time, and None.
1037
1127
  """
1128
+ #print(f'Runnings simulation with the following paramiters')
1129
+ #print(settings)
1130
+ settings['random_seed'] = False
1038
1131
  if settings['random_seed']:
1039
1132
  random.seed(42) # sims will be too similar with random seed
1040
1133
  src = settings['src']
1041
1134
  plot = settings['plot']
1042
1135
  v = settings['variable']
1043
1136
  start_time = time() # Start time of the simulation
1044
- now = datetime.now() # get current date
1045
- date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1137
+ #now = datetime.now() # get current date
1138
+ #date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1139
+ date_string = settings['start_time']
1046
1140
  #try:
1047
1141
  output, dists = run_simulation(settings)
1048
1142
  sim_time = time() - start_time # Elapsed time for the simulation
1049
1143
  settings['sim_time'] = sim_time
1050
1144
  src = os.path.join(f'{src}/{date_string}',settings['name'])
1051
1145
  save_data(src, output, settings, save_all=False, i=i, variable=v)
1052
- if vis_dists:
1053
- vis_dists(dists,src, v, i)
1054
1146
  if plot:
1147
+ vis_dists(dists,src, v, i)
1055
1148
  fig = visualize_all(output)
1056
1149
  save_plot(fig, src, v, i)
1057
1150
  plt.close(fig)
@@ -1060,6 +1153,7 @@ def run_and_save(i, settings, time_ls, total_sims):
1060
1153
  plt.clf()
1061
1154
  del fig
1062
1155
  del output, dists
1156
+ gc.collect()
1063
1157
  #except Exception as e:
1064
1158
  # print(e, end='\r', flush=True)
1065
1159
  # sim_time = time() - start_time
@@ -1067,98 +1161,92 @@ def run_and_save(i, settings, time_ls, total_sims):
1067
1161
  time_ls.append(sim_time)
1068
1162
  return i, sim_time, None
1069
1163
 
1070
- def generate_paramiters(settings):
1164
+ def validate_and_adjust_beta_params(sim_params):
1071
1165
  """
1072
- Generate a list of parameter sets for simulation based on the given settings.
1073
-
1166
+ Validates and adjusts Beta distribution parameters in simulation settings to ensure they are possible.
1167
+
1074
1168
  Args:
1075
- settings (dict): A dictionary containing the simulation settings.
1076
-
1169
+ sim_params (list of dict): List of dictionaries, each containing the simulation parameters.
1170
+
1077
1171
  Returns:
1078
- list: A list of parameter sets for simulation.
1172
+ list of dict: The adjusted list of simulation parameter sets.
1079
1173
  """
1080
- sim_ls = []
1081
- for avg_genes_per_well in settings['avg_genes_per_well']:
1082
- replicates = settings['replicates']
1083
- sett = settings.copy()
1084
- sett['avg_genes_per_well'] = avg_genes_per_well
1085
- sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
1086
- for avg_cells_per_well in settings['avg_cells_per_well']:
1087
- sett['avg_cells_per_well'] = avg_cells_per_well
1088
- sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
1089
- for positive_mean in settings['positive_mean']:
1090
- sett['positive_mean'] = positive_mean
1091
- sett['negative_mean'] = 1-positive_mean
1092
- sett['positive_variance'] = (1-positive_mean)/2
1093
- sett['negative_variance'] = (1-positive_mean)/2
1094
- for avg_reads_per_gene in settings['avg_reads_per_gene']:
1095
- sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
1096
- sett['sd_reads_per_gene'] = int(avg_reads_per_gene/2)
1097
- for sequencing_error in settings['sequencing_error']:
1098
- sett['sequencing_error'] = sequencing_error
1099
- for well_ineq_coeff in settings['well_ineq_coeff']:
1100
- sett['well_ineq_coeff'] = well_ineq_coeff
1101
- for gene_ineq_coeff in settings['gene_ineq_coeff']:
1102
- sett['gene_ineq_coeff'] = gene_ineq_coeff
1103
- for nr_plates in settings['nr_plates']:
1104
- sett['nr_plates'] = nr_plates
1105
- for number_of_genes in settings['number_of_genes']:
1106
- sett['number_of_genes'] = number_of_genes
1107
- for number_of_active_genes in settings['number_of_active_genes']:
1108
- sett['number_of_active_genes'] = number_of_active_genes
1109
- for i in [*range(1,replicates+1)]:
1110
- sim_ls.append(sett)
1111
- #print(sett)
1112
- #print('Number of simulations:',len(sim_ls))
1113
- return sim_ls
1174
+ adjusted_params = []
1175
+ for params in sim_params:
1176
+ max_pos_variance = params['positive_mean'] * (1 - params['positive_mean'])
1177
+ max_neg_variance = params['negative_mean'] * (1 - params['negative_mean'])
1178
+
1179
+ # Adjust positive variance
1180
+ if params['positive_variance'] >= max_pos_variance:
1181
+ print(f'changed positive variance from {params["positive_variance"]} to {max_pos_variance * 0.99}')
1182
+ params['positive_variance'] = max_pos_variance * 0.99 # Adjust to 99% of the maximum allowed variance
1183
+
1184
+ # Adjust negative variance
1185
+ if params['negative_variance'] >= max_neg_variance:
1186
+ print(f'changed negative variance from {params["negative_variance"]} to {max_neg_variance * 0.99}')
1187
+ params['negative_variance'] = max_neg_variance * 0.99 # Adjust to 99% of the maximum allowed variance
1188
+
1189
+ adjusted_params.append(params)
1190
+
1191
+ return adjusted_params
1192
+
1193
+ def generate_paramiters(settings):
1114
1194
 
1115
- #altered for one set of settings see negative_mean and variance
1116
- def generate_paramiters_single(settings):
1117
1195
  """
1118
- Generate a list of parameter sets for single simulations based on the given settings.
1196
+ Generate a list of parameter sets for simulation based on the given settings.
1119
1197
 
1120
1198
  Args:
1121
1199
  settings (dict): A dictionary containing the simulation settings.
1122
1200
 
1123
1201
  Returns:
1124
- list: A list of parameter sets for single simulations.
1202
+ list: A list of parameter sets for simulation.
1125
1203
  """
1204
+
1205
+ settings['positive_mean'] = [0.8]
1206
+
1126
1207
  sim_ls = []
1127
1208
  for avg_genes_per_well in settings['avg_genes_per_well']:
1128
1209
  replicates = settings['replicates']
1129
- sett = settings.copy()
1130
- sett['avg_genes_per_well'] = avg_genes_per_well
1131
- sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
1132
1210
  for avg_cells_per_well in settings['avg_cells_per_well']:
1133
- sett['avg_cells_per_well'] = avg_cells_per_well
1134
- sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
1135
- for positive_mean in settings['positive_mean']:
1136
- sett['positive_mean'] = positive_mean
1137
- sett['negative_mean'] = 0.2
1138
- sett['positive_variance'] = 0.13
1139
- sett['negative_variance'] = 0.13
1140
- for avg_reads_per_gene in settings['avg_reads_per_gene']:
1141
- sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
1142
- sett['sd_reads_per_gene'] = int(avg_reads_per_gene/2)
1143
- for sequencing_error in settings['sequencing_error']:
1144
- sett['sequencing_error'] = sequencing_error
1145
- for well_ineq_coeff in settings['well_ineq_coeff']:
1146
- sett['well_ineq_coeff'] = well_ineq_coeff
1147
- for gene_ineq_coeff in settings['gene_ineq_coeff']:
1148
- sett['gene_ineq_coeff'] = gene_ineq_coeff
1149
- for nr_plates in settings['nr_plates']:
1150
- sett['nr_plates'] = nr_plates
1151
- for number_of_genes in settings['number_of_genes']:
1152
- sett['number_of_genes'] = number_of_genes
1153
- for number_of_active_genes in settings['number_of_active_genes']:
1154
- sett['number_of_active_genes'] = number_of_active_genes
1155
- for i in [*range(1,replicates+1)]:
1156
- sim_ls.append(sett)
1157
- #print(sett)
1158
- #print('Number of simulations:',len(sim_ls))
1211
+ for classifier_accuracy in settings['classifier_accuracy']:
1212
+ for positive_mean in settings['positive_mean']:
1213
+ for avg_reads_per_gene in settings['avg_reads_per_gene']:
1214
+ for sequencing_error in settings['sequencing_error']:
1215
+ for well_ineq_coeff in settings['well_ineq_coeff']:
1216
+ for gene_ineq_coeff in settings['gene_ineq_coeff']:
1217
+ for nr_plates in settings['nr_plates']:
1218
+ for number_of_genes in settings['number_of_genes']:
1219
+ for number_of_active_genes in settings['number_of_active_genes']:
1220
+ for i in range(1, replicates+1):
1221
+ sett = deepcopy(settings)
1222
+ sett['avg_genes_per_well'] = avg_genes_per_well
1223
+ sett['sd_genes_per_well'] = avg_genes_per_well / 2
1224
+ sett['avg_cells_per_well'] = avg_cells_per_well
1225
+ sett['sd_cells_per_well'] = avg_cells_per_well / 2
1226
+ sett['classifier_accuracy'] = classifier_accuracy
1227
+ sett['positive_mean'] = positive_mean
1228
+ sett['negative_mean'] = 1-positive_mean
1229
+ sett['positive_variance'] = (1-positive_mean)/2
1230
+ sett['negative_variance'] = (1-positive_mean)/2
1231
+ sett['avg_reads_per_gene'] = avg_reads_per_gene
1232
+ sett['sd_reads_per_gene'] = avg_reads_per_gene / 2
1233
+ sett['sequencing_error'] = sequencing_error
1234
+ sett['well_ineq_coeff'] = well_ineq_coeff
1235
+ sett['gene_ineq_coeff'] = gene_ineq_coeff
1236
+ sett['nr_plates'] = nr_plates
1237
+ sett['number_of_genes'] = number_of_genes
1238
+ sett['number_of_active_genes'] = number_of_active_genes
1239
+ sim_ls.append(sett)
1240
+
1241
+ random.shuffle(sim_ls)
1242
+ sim_ls = validate_and_adjust_beta_params(sim_ls)
1243
+ print(f'Running {len(sim_ls)} simulations.')
1244
+ #for x in sim_ls:
1245
+ # print(x['positive_mean'])
1159
1246
  return sim_ls
1160
1247
 
1161
1248
  def run_multiple_simulations(settings):
1249
+
1162
1250
  """
1163
1251
  Run multiple simulations in parallel using the provided settings.
1164
1252
 
@@ -1169,8 +1257,12 @@ def run_multiple_simulations(settings):
1169
1257
  None
1170
1258
  """
1171
1259
 
1260
+ now = datetime.now() # get current date
1261
+ start_time = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1262
+ settings['start_time'] = start_time
1263
+
1172
1264
  sim_ls = generate_paramiters(settings)
1173
- print(f'Running {len(sim_ls)} simulations. Standard deviations for each variable are variable / 2')
1265
+ #print(f'Running {len(sim_ls)} simulations.')
1174
1266
 
1175
1267
  max_workers = settings['max_workers'] or cpu_count() - 4
1176
1268
  with Manager() as manager:
@@ -1179,9 +1271,429 @@ def run_multiple_simulations(settings):
1179
1271
  with Pool(max_workers) as pool:
1180
1272
  result = pool.starmap_async(run_and_save, [(index, settings, time_ls, total_sims) for index, settings in enumerate(sim_ls)])
1181
1273
  while not result.ready():
1182
- sleep(0.01)
1183
- sims_processed = len(time_ls)
1184
- average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1185
- time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1186
- print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1187
- result.get()
1274
+ try:
1275
+ sleep(0.01)
1276
+ sims_processed = len(time_ls)
1277
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1278
+ time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1279
+ print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1280
+ gc.collect()
1281
+ except Exception as e:
1282
+ print(e)
1283
+ print(traceback.format_exc())
1284
+ try:
1285
+ result.get()
1286
+ except Exception as e:
1287
+ print(e)
1288
+ print(traceback.format_exc())
1289
+
1290
def generate_integers(start, stop, step):
    """
    Build a list of evenly spaced integers that includes ``stop`` when it is reachable.

    Args:
        start (int): First value of the sequence.
        stop (int): Last candidate value; ``stop + 1`` is used as the exclusive range end.
        step (int): Increment between consecutive values.

    Returns:
        list: The integers produced by ``range(start, stop + 1, step)``.
    """
    exclusive_end = stop + 1  # range() never yields its end point
    return [*range(start, exclusive_end, step)]
1292
+
1293
def generate_floats(start, stop, step):
    """
    Build a list of evenly spaced numbers from ``start`` to ``stop`` (inclusive).

    Each value is computed as ``start + i * step`` rather than by repeated
    addition, so floating-point drift cannot silently drop the end point
    (e.g. ``0.1 + 0.1 + 0.1 > 0.3``). Values are rounded to the number of
    decimal places written in ``step``.

    Args:
        start (int or float): First value of the sequence.
        stop (int or float): Last value of the sequence (included when reachable).
        step (int or float): Positive increment between consecutive values.

    Returns:
        list: The generated values; empty when ``start`` is greater than ``stop``.

    Raises:
        ValueError: If ``step`` is not strictly positive (the sequence would never end).
    """
    from decimal import Decimal

    if step <= 0:
        raise ValueError("step must be a positive number")
    if start > stop:
        return []

    # Decimal parses plain ('0.25'), integer ('1') and scientific ('1e-05')
    # spellings alike; a negative exponent is the number of decimal places.
    exponent = Decimal(str(step)).as_tuple().exponent
    num_decimals = max(0, -exponent)

    # The small tolerance keeps the end point when the division lands a hair
    # below a whole number (0.3 / 0.1 == 2.9999999999999996).
    n_steps = math.floor((stop - start) / step + 1e-9)

    return [round(start + i * step, num_decimals) for i in range(n_steps + 1)]
1305
+
1306
def remove_columns_with_single_value(df):
    """
    Drop every column whose entries collapse to exactly one unique value.

    Uniqueness is measured with ``Series.nunique()``; only columns where that
    count equals 1 are removed.

    Args:
        df (pandas.DataFrame): The original DataFrame.

    Returns:
        pandas.DataFrame: A new DataFrame without the single-valued columns.
    """
    single_valued = []
    for name in df.columns:
        if df[name].nunique() == 1:
            single_valued.append(name)
    return df.drop(columns=single_valued)
1318
+
1319
def read_simulations_table(db_path):
    """
    Load the 'simulations' table of an SQLite database as a pandas DataFrame.

    Args:
        db_path (str): The file path to the SQLite database.

    Returns:
        pandas.DataFrame or None: The table contents, or None when the query
        fails (the error is printed rather than raised).
    """
    connection = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query("SELECT * FROM simulations", connection)
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    finally:
        # Runs on both the success and the failure path.
        connection.close()
1343
+
1344
def plot_simulations(df, variable, x_rotation=None, legend=False, grid=False, clean=True, verbose=False):
    """
    Plot mean 'prauc' (with a +/- one standard deviation band) against ``variable``.

    One subplot is drawn for each unique combination of the remaining grouping
    variables, and the subplots are arranged on a roughly square grid.

    Args:
        df (pandas.DataFrame): DataFrame containing ``variable``, 'prauc' and the grouping columns.
        variable (str): Name of the column to use as the x-axis.
        x_rotation (int, optional): Degrees to rotate the x-axis labels (45 when None).
        legend (bool, optional): Whether to display a legend.
        grid (bool, optional): Whether to display grid lines.
        clean (bool, optional): Whether to ignore grouping variables that do not
            vary in ``df``, so they do not multiply the number of panels.
        verbose (bool, optional): Whether to print the filter conditions inside each subplot.

    Returns:
        matplotlib.figure.Figure: The figure holding all subplots.

    Raises:
        ValueError: If required columns are missing or ``df`` has no rows to plot.
    """
    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']
    grouping_vars = [col for col in grouping_vars if col != variable]

    if clean:
        # Only the grouping list is filtered; df itself is left untouched so
        # that `variable` and 'prauc' survive even if they happen to be constant.
        varying_columns = remove_columns_with_single_value(df).columns
        grouping_vars = [col for col in grouping_vars if col in varying_columns]

    # Check if the necessary columns are present in the DataFrame
    required_columns = {variable, 'prauc'} | set(grouping_vars)
    if not required_columns.issubset(df.columns):
        missing_cols = required_columns - set(df.columns)
        raise ValueError(f"DataFrame must contain {missing_cols} columns")

    # Build one (condition, subset) pair per panel.
    if grouping_vars:
        unique_combinations = df[grouping_vars].drop_duplicates()
        panels = [
            ({var: row[var] for var in grouping_vars},
             df[df[grouping_vars].eq(row).all(axis=1)])
            for _, row in unique_combinations.iterrows()
        ]
    else:
        # Nothing to group by: a single panel over the whole DataFrame.
        panels = [({}, df)]

    if not panels:
        raise ValueError("DataFrame contains no rows to plot")

    # Determine the layout of the subplots
    num_rows = math.ceil(np.sqrt(len(panels)))
    num_cols = math.ceil(len(panels) / num_rows)

    # squeeze=False always yields a 2-D array, so flatten() is valid for 1x1 too.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    rotation = x_rotation if x_rotation is not None else 45

    for ax, (condition, subset_df) in zip(axes, panels):
        # Group by 'variable' and calculate mean and std dev of 'prauc'
        grouped = subset_df.groupby(variable)['prauc'].agg(['mean', 'std']).sort_index()

        # Mean line with the standard deviation as a shaded band
        ax.plot(grouped.index, grouped['mean'], marker='o', linestyle='-', color='b', label='Mean PRAUC')
        ax.fill_between(grouped.index, grouped['mean'] - grouped['std'], grouped['mean'] + grouped['std'],
                        color='gray', alpha=0.5, label='Std Dev')

        ax.set_xlabel(variable)
        ax.set_ylabel('Precision-Recall AUC (PRAUC)')
        ax.grid(grid)

        if legend:
            ax.legend()

        ax.set_xticks(grouped.index)
        ax.set_xticklabels(grouped.index, rotation=rotation)

        if verbose:
            verbose_text = '\n'.join([f"{var}: {val}" for var, val in condition.items()])
            ax.text(0.95, 0.05, verbose_text, transform=ax.transAxes, fontsize=9,
                    verticalalignment='bottom', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    # Hide any unused axes of the grid
    for ax in axes[len(panels):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
1432
+
1433
def plot_correlation_matrix(df, annot=False, cmap='inferno', clean=True):
    """
    Draw a lower-triangle heatmap of the correlations between simulation
    settings and simulation outcome metrics, and save it via ``save_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        annot (bool, optional): Whether to write the correlation values in the cells.
        cmap: Accepted for interface compatibility. NOTE(review): the value is
            replaced by a seaborn diverging palette below, so the argument
            currently has no effect.
        clean (bool, optional): Whether to drop columns that do not vary before
            computing the correlations.

    Returns:
        matplotlib.figure.Figure: The figure containing the heatmap.
    """
    cmap = sns.diverging_palette(240, 10, as_cmap=True)

    setting_columns = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                       'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                       'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']
    outcome_columns = ['optimal_threshold', 'accuracy', 'prauc', 'roc_auc',
                       'genes_per_well_gini', 'wells_per_gene_gini']
    selected = setting_columns + outcome_columns

    if clean:
        df = remove_constant_columns(df)
        selected = [name for name in selected if name in df.columns]
        relevant_data = remove_columns_with_single_value(df[selected])
    else:
        relevant_data = df[selected]

    corr_matrix = relevant_data.corr()
    # Mask the upper triangle (diagonal included) so each pair is shown once.
    upper_triangle = np.triu(np.ones_like(corr_matrix, dtype=bool))

    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(corr_matrix, mask=upper_triangle, annot=annot, cmap=cmap,
                fmt=".2f", linewidths=.5, robust=True)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='correlation_matrix', i=1)
    return fig
1476
+
1477
def plot_feature_importance(df, target='prauc', exclude=None, clean=True):
    """
    Trains a RandomForestRegressor to determine the importance of each feature in predicting the target.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable column name.
        exclude (list or str, optional): Column names to exclude from features.
        clean (bool, optional): Whether to drop single-valued columns before fitting.

    Returns:
        matplotlib.figure.Figure: The figure object containing the feature importance plot.
    """

    # Define the features for the model
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Remove excluded features if specified
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    # Train the model
    model = RandomForestRegressor(n_estimators=1000, random_state=42)
    model.fit(df[features], df[target])

    # Rank features from most to least important
    importances = model.feature_importances_
    order = np.argsort(importances)[::-1]
    ranked_names = [features[i] for i in order]
    positions = range(len(order))

    # Bars and tick labels share `order`, so each bar carries the name of the
    # feature it measures; the inverted axis puts the largest bar on top.
    fig = plt.figure(figsize=(12, 6))
    plt.barh(positions, importances[order], color="teal", align="center", alpha=0.6)
    plt.yticks(positions, ranked_names)
    plt.gca().invert_yaxis()
    plt.xlabel('Feature Importance')
    plt.title('Feature Importances')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='feature_importance', i=1)
    return fig
1524
+
1525
def calculate_permutation_importance(df, target='prauc', exclude=None, n_repeats=10, clean=True):
    """
    Fits a RandomForestRegressor and plots the permutation importance of each feature.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        exclude (list or str, optional): Column names to exclude from features.
        n_repeats (int, optional): Number of times each feature is permuted.
        clean (bool, optional): Whether to drop single-valued columns before fitting.

    Returns:
        matplotlib.figure.Figure: The figure containing the permutation importance bar chart.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Filter rather than list.remove(): a list `exclude` must not also be
    # treated as a single name, and names already dropped by `clean` are
    # simply ignored instead of raising.
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    X = df[features]
    y = df[target]

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    perm_importance = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)

    # Ascending order: barh draws upwards, so the most important ends on top.
    sorted_idx = perm_importance.importances_mean.argsort()
    positions = range(len(sorted_idx))

    fig, ax = plt.subplots()
    ax.barh(positions, perm_importance.importances_mean[sorted_idx], color="teal", align="center", alpha=0.6)
    ax.set_yticks(positions)
    # The importances are indexed like X's columns (i.e. `features`), not df's.
    ax.set_yticklabels([features[i] for i in sorted_idx])
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='permutation_importance', i=1)
    return fig
1574
+
1575
def plot_partial_dependences(df, target='prauc', clean=True):
    """
    Creates one partial dependence plot per feature on a four-column grid,
    using a GradientBoostingRegressor fitted on the features.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable.
        clean (bool, optional): Whether to drop single-valued columns before fitting.

    Returns:
        matplotlib.figure.Figure: The figure containing the partial dependence plots.

    Raises:
        ValueError: If no feature column is left to plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    if not features:
        raise ValueError("No feature columns left to plot partial dependences for")

    X = df[features]
    y = df[target]

    # Train a model
    model = GradientBoostingRegressor()
    model.fit(X, y)

    # Determine the number of rows and columns for subplots
    n_cols = 4
    n_rows = (len(features) + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False keeps `axs` two-dimensional even for a single row, so
    # flatten() always yields individual Axes objects.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    fig.suptitle('Partial Dependence Plots', fontsize=20, y=1.03)
    axs = axs.flatten()

    for ax, feature in zip(axs, features):
        PartialDependenceDisplay.from_estimator(model, X, features=[feature], ax=ax)
        ax.set_title(feature)

    # Hide unused axes if any
    for ax in axs[len(features):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='partial_dependences', i=1)
    return fig
1627
+
1628
def save_shap_plot(fig, src, variable, i):
    """
    Write ``fig`` as a PDF into the ``<src>/<variable>`` directory.

    The directory is created when missing, the file is named
    ``<i>_figure.pdf``, and the path is printed after saving.

    Args:
        fig: Object exposing ``savefig`` (e.g. a matplotlib Figure).
        src (str): Base output directory.
        variable (str): Sub-directory name under ``src``.
        i: Value used as the file name prefix.

    Returns:
        None
    """
    import os

    target_dir = f'{src}/{variable}'
    os.makedirs(target_dir, exist_ok=True)

    filename_fig = f'{src}/{variable}/{str(i)}_figure.pdf'
    fig.savefig(filename_fig, dpi=600, format='pdf', bbox_inches='tight')
    print(f"Saved figure as {filename_fig}")
1634
+
1635
def generate_shap_summary_plot(df,target='prauc', clean=True):
    """
    Fits a RandomForestRegressor, draws a SHAP summary plot of the features
    and saves it under 'figures/shap'.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        clean (bool, optional): Whether to drop single-valued columns before fitting.

    Returns:
        matplotlib.figure.Figure: The figure holding the SHAP summary plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    X = df[features]
    y = df[target]

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    # Calculate SHAP values
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # show=False stops summary_plot from displaying the figure itself, so
    # plt.gcf() still refers to the SHAP figure when it is saved and returned.
    shap.summary_plot(shap_values, X, show=False)
    fig = plt.gcf()
    save_shap_plot(fig, src='figures', variable='shap', i=1)
    plt.show()
    return fig
1672
+
1673
+
1674
def remove_constant_columns(df):
    """
    Keep only the columns of ``df`` that hold more than one unique value.

    Parameters:
    df (pd.DataFrame): The input DataFrame from which to remove constant columns.

    Returns:
    pd.DataFrame: The columns whose ``nunique()`` count is greater than 1.
    """
    varies = df.nunique() > 1  # boolean flag per column
    return df.loc[:, varies]
1685
+
1686
+
1687
+ # Note (not executed): sketch for justifying the use of a Beta distribution to model the simulated classifier's predicted probabilities.
1688
+
1689
+ # Fit a Beta distribution to these outputs
1690
+ #a, b, loc, scale = beta.fit(predicted_probs, floc=0, fscale=1) # Fix location and scale to match the support of the sigmoid
1691
+
1692
+ # Sample from this fitted Beta distribution
1693
+ #simulated_probs = beta.rvs(a, b, size=1000)
1694
+
1695
+ # Plot the empirical vs simulated distribution
1696
+ #plt.hist(predicted_probs, bins=30, alpha=0.5, label='Empirical')
1697
+ #plt.hist(simulated_probs, bins=30, alpha=0.5, label='Simulated from Beta')
1698
+ #plt.legend()
1699
+ #plt.show()