spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sim.py CHANGED
@@ -1,16 +1,19 @@
1
1
 
2
- import os, gc, random, warnings, traceback, itertools, matplotlib, sqlite3
3
- import time as tm
2
+ import os, random, warnings, traceback, sqlite3, shap, math, gc
4
3
  from time import time, sleep
5
4
  from datetime import datetime
6
5
  import numpy as np
7
6
  import pandas as pd
8
7
  import matplotlib.pyplot as plt
8
+ from matplotlib.patches import Patch
9
9
  import seaborn as sns
10
10
  import sklearn.metrics as metrics
11
- from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, precision_recall_curve
11
+ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
12
+ from sklearn.inspection import PartialDependenceDisplay, permutation_importance
13
+ from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
12
14
  import statsmodels.api as sm
13
- from multiprocessing import cpu_count, Value, Array, Lock, Pool, Manager
15
+ from multiprocessing import cpu_count, Pool, Manager
16
+ from copy import deepcopy
14
17
 
15
18
  from .logger import log_function_call
16
19
 
@@ -35,7 +38,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
35
38
 
36
39
  # plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
37
40
  def generate_plate_map(nr_plates):
38
- print('nr_plates',nr_plates)
41
+ #print('nr_plates',nr_plates)
39
42
  """
40
43
  Generate a plate map based on the number of plates.
41
44
 
@@ -241,6 +244,7 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
241
244
  - df_ls (list): A list containing gene counts per well, well counts per gene, Gini coefficients for wells,
242
245
  Gini coefficients for genes, gene weights array, and well weights.
243
246
  """
247
+
244
248
  #generate primary distributions and genes
245
249
  cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
246
250
  gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
@@ -251,14 +255,18 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
251
255
  gene_weights_array = np.array(list(gene_weights.values())) # Convert the values to an array
252
256
 
253
257
  well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)
254
-
255
- gene_to_well_mapping = {}
256
-
257
- for gene in genes:
258
- gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=int(gpw[gene-1]), p=well_weights) # Generate a number of wells for each gene according to well_weights
259
-
260
- gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}
261
258
 
259
+ gene_to_well_mapping = {}
260
+ for gene in range(1, number_of_genes + 1): # ensures gene-1 is within bounds
261
+ if gene-1 < len(gpw):
262
+ max_index = len(plate_map['plate_row_column']) # this should be the number of choices available from plate_map
263
+ num_samples = int(gpw[gene-1])
264
+ if num_samples >= max_index:
265
+ num_samples = max_index - 1 # adjust to maximum possible index
266
+ gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=num_samples, replace=False, p=well_weights)
267
+ else:
268
+ break # break the loop if gene-1 is out of bounds for gpw
269
+
262
270
  cells = []
263
271
  for i in [*range(0,len(plate_map))]:
264
272
  ciw = random.choice(cpw)
@@ -316,9 +324,47 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
316
324
  df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
317
325
  return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
318
326
 
319
- # classifier is a function that takes a cell state (active=1/inactive=0) and produces a score in [0, 1]
320
- # For the input cell, it checks if it is active or inactive, and then samples from an appropriate beta distribution to give a score
321
- def classifier(positive_mean, positive_variance, negative_mean, negative_variance, df):
327
+ def classifier(positive_mean, positive_variance, negative_mean, negative_variance, classifier_accuracy, df):
328
+ """
329
+ Classifies the data in the DataFrame based on the given parameters and a classifier error rate.
330
+
331
+ Args:
332
+ positive_mean (float): The mean of the positive distribution.
333
+ positive_variance (float): The variance of the positive distribution.
334
+ negative_mean (float): The mean of the negative distribution.
335
+ negative_variance (float): The variance of the negative distribution.
336
+ classifier_accuracy (float): The likelihood (0 to 1) that a gene is correctly classified according to its true label.
337
+ df (pandas.DataFrame): The DataFrame containing the data to be classified.
338
+
339
+ Returns:
340
+ pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
341
+ """
342
+ def calc_alpha_beta(mean, variance):
343
+ if mean <= 0 or mean >= 1:
344
+ raise ValueError("Mean must be between 0 and 1 exclusively.")
345
+ max_variance = mean * (1 - mean)
346
+ if variance <= 0 or variance >= max_variance:
347
+ raise ValueError(f"Variance must be positive and less than {max_variance}.")
348
+
349
+ alpha = mean * (mean * (1 - mean) / variance - 1)
350
+ beta = alpha * (1 - mean) / mean
351
+ return alpha, beta
352
+
353
+ # Apply the beta distribution based on 'is_active' status with consideration for classifier error
354
+ def get_score(is_active):
355
+ if np.random.rand() < classifier_accuracy: # With classifier_accuracy probability, choose the correct distribution
356
+ return np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2)
357
+ else: # With 1-classifier_accuracy probability, choose the incorrect distribution
358
+ return np.random.beta(a2, b2) if is_active else np.random.beta(a1, b1)
359
+
360
+ # Calculate alpha and beta for both distributions
361
+ a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
362
+ a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
363
+ df['score'] = df['is_active'].apply(get_score)
364
+
365
+ return df
366
+
367
+ def classifier_v2(positive_mean, positive_variance, negative_mean, negative_variance, df):
322
368
  """
323
369
  Classifies the data in the DataFrame based on the given parameters.
324
370
 
@@ -332,12 +378,22 @@ def classifier(positive_mean, positive_variance, negative_mean, negative_varianc
332
378
  Returns:
333
379
  pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
334
380
  """
335
- # alpha and beta for positive distribution
336
- a1 = positive_mean*(positive_mean*(1-positive_mean)/positive_variance - 1)
337
- b1 = a1*(1-positive_mean)/positive_mean
338
- # alpha and beta for negative distribution
339
- a2 = negative_mean*(negative_mean*(1-negative_mean)/negative_variance - 1)
340
- b2 = a2*(1-negative_mean)/negative_mean
381
+ def calc_alpha_beta(mean, variance):
382
+ if mean <= 0 or mean >= 1:
383
+ raise ValueError("Mean must be between 0 and 1 exclusively.")
384
+ max_variance = mean * (1 - mean)
385
+ if variance <= 0 or variance >= max_variance:
386
+ raise ValueError(f"Variance must be positive and less than {max_variance}.")
387
+
388
+ alpha = mean * (mean * (1 - mean) / variance - 1)
389
+ beta = alpha * (1 - mean) / mean
390
+ return alpha, beta
391
+
392
+ # Calculate alpha and beta for both distributions
393
+ a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
394
+ a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
395
+
396
+ # Apply the beta distribution based on 'is_active' status
341
397
  df['score'] = df['is_active'].apply(lambda is_active: np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2))
342
398
  return df
343
399
 
@@ -446,6 +502,7 @@ def generate_well_score(cell_scores):
446
502
  return well_score
447
503
 
448
504
  def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
505
+
449
506
  """
450
507
  Simulates the sequencing of plates and calculates gene fractions and metadata.
451
508
 
@@ -460,6 +517,7 @@ def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_pe
460
517
  gene_fraction_map (pd.DataFrame): DataFrame containing gene fractions for each well.
461
518
  metadata (pd.DataFrame): DataFrame containing metadata for each well.
462
519
  """
520
+
463
521
  reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
464
522
  gene_names = [f'gene_{v}' for v in range(number_of_genes+1)]
465
523
  all_wells = well_score.index
@@ -612,8 +670,12 @@ def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
612
670
  Returns:
613
671
  None
614
672
  """
615
- sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
616
- legend=False, fill=True, element='step', palette='dark')
673
+ if not binwidth:
674
+ sns.histplot(data=data, x=x_label, ax=ax, color=color, kde=False, stat='density',
675
+ legend=False, fill=True, element='step', palette='dark')
676
+ else:
677
+ sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
678
+ legend=False, fill=True, element='step', palette='dark')
617
679
  if log:
618
680
  ax.set_yscale('log')
619
681
  ax.set_title(title)
@@ -700,7 +762,7 @@ def run_simulation(settings):
700
762
  plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])] # Extract rows where 'column_id' is not in [1,2,3,23,24]
701
763
 
702
764
  cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(plate_map, settings['number_of_genes'], active_gene_list, settings['avg_genes_per_well'], settings['sd_genes_per_well'], settings['avg_cells_per_well'], settings['sd_cells_per_well'], settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
703
- cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], df=cell_level)
765
+ cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], settings['classifier_accuracy'], df=cell_level)
704
766
  cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
705
767
  well_score = generate_well_score(cell_scores)
706
768
  gene_fraction_map, metadata = sequence_plates(well_score, settings['number_of_genes'], settings['avg_reads_per_gene'], settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])
@@ -720,7 +782,10 @@ def run_simulation(settings):
720
782
  results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False)
721
783
  #except Exception as e:
722
784
  # print(f"An error occurred while saving data: {e}")
723
- return [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df], dists
785
+ output = [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df]
786
+ del cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df
787
+ gc.collect()
788
+ return output, dists
724
789
 
725
790
  def vis_dists(dists, src, v, i):
726
791
  """
@@ -744,9 +809,14 @@ def vis_dists(dists, src, v, i):
744
809
  for index, dist in enumerate(dists):
745
810
  temp = pd.DataFrame(dist, columns = [f'{names[index]}'])
746
811
  sns.histplot(data=temp, x=f'{names[index]}', kde=False, binwidth=None, stat='count', element="step", ax=ax[n], color='teal', log_scale=False)
747
- #plot_histogram(temp, f'{names[index]}', ax[n], 'slategray', f'{names[index]}', binwidth=None, log=False)
748
812
  n+=1
749
813
  save_plot(fig2, src, 'dists', i)
814
+ plt.close(fig2)
815
+ plt.figure().clear()
816
+ plt.cla()
817
+ plt.clf()
818
+ del dists
819
+
750
820
  return
751
821
 
752
822
  def visualize_all(output):
@@ -773,6 +843,7 @@ def visualize_all(output):
773
843
  Returns:
774
844
  fig (matplotlib.figure.Figure): The generated figure object.
775
845
  """
846
+
776
847
  cell_scores = output[0]
777
848
  cell_roc_dict_df = output[1]
778
849
  cell_pr_dict_df = output[2]
@@ -810,24 +881,35 @@ def visualize_all(output):
810
881
  #plot cell classification score by inactive and active
811
882
  active_distribution = cell_scores[cell_scores['is_active'] == 1]
812
883
  inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
813
- plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
814
- plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
884
+ plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', log=False)#, binwidth=0.01, log=False)
885
+ plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', log=False)#, binwidth=0.01, log=False)
886
+
887
+ legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
888
+ Patch(facecolor='teal', edgecolor='teal', label='Active')]
889
+
890
+ ax[n].legend(handles=legend_elements, loc='upper right')
891
+
892
+
815
893
  ax[n].set_xlim([0, 1])
816
894
  n+=1
817
895
 
818
896
  #plot classifier cell predictions by inactive and active well average
819
- ##inactive_distribution_well['score'] = pd.to_numeric(inactive_distribution['score'], errors='coerce')
820
- ##inactive_distribution_well = inactive_distribution_well.groupby('plate_row_column')['score'].mean()
821
-
822
- ##active_distribution_well['score'] = pd.to_numeric(active_distribution['score'], errors='coerce')
823
- ##active_distribution_well = active_distribution_well.groupby('plate_row_column')['score'].mean()
897
+ inactive_distribution_well = inactive_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
898
+ active_distribution_well = active_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
899
+ mixed_distribution_well = cell_scores.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
900
+
901
+ plot_histogram(inactive_distribution_well, "score", ax[n], 'slategray', 'Well scores', log=False)#, binwidth=0.01, log=False)
902
+ plot_histogram(active_distribution_well, "score", ax[n], 'teal', 'Well scores', log=False)#, binwidth=0.01, log=False)
903
+ plot_histogram(mixed_distribution_well, "score", ax[n], 'red', 'Well scores', log=False)#, binwidth=0.01, log=False)
824
904
 
825
- #inactive_distribution_well = inactive_distribution.groupby(['plate_row_column']).mean()
826
- #active_distribution_well = active_distribution.groupby(['plate_row_column']).mean()
905
+ legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
906
+ Patch(facecolor='teal', edgecolor='teal', label='Active'),
907
+ Patch(facecolor='red', edgecolor='red', label='Mixed')]
827
908
 
828
- plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Well scores', binwidth=0.01, log=False)
829
- plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Well scores', binwidth=0.01, log=False)
909
+ ax[n].legend(handles=legend_elements, loc='upper right')
910
+
830
911
  ax[n].set_xlim([0, 1])
912
+ #ax[n].legend()
831
913
  n+=1
832
914
 
833
915
  #plot ROC (cell classification)
@@ -846,8 +928,8 @@ def visualize_all(output):
846
928
  n+=1
847
929
 
848
930
  #plot well score
849
- plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=False)
850
- ax[n].set_xlim([0, 1])
931
+ plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=True)
932
+ #ax[n].set_xlim([0, 1])
851
933
  n+=1
852
934
 
853
935
  control_df = results_df[results_df['color'] == 'control']
@@ -911,6 +993,7 @@ def visualize_all(output):
911
993
 
912
994
  plt.tight_layout()
913
995
  plt.show()
996
+ gc.collect()
914
997
  return fig
915
998
 
916
999
  def create_database(db_path):
@@ -992,6 +1075,7 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
992
1075
  df_concat[f'variable_{variable}_sim_nr'] = i
993
1076
 
994
1077
  append_database(src, df_concat, 'simulations')
1078
+ del gini_genes_per_well, gini_wells_per_gene, df_concat
995
1079
 
996
1080
  if save_all:
997
1081
  for i, df in enumerate(output):
@@ -1001,8 +1085,12 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
1001
1085
  if not isinstance(df, pd.DataFrame):
1002
1086
  df = pd.DataFrame(df)
1003
1087
  append_database(src, df, table_names[i])
1088
+ del df
1004
1089
  except Exception as e:
1005
1090
  print(f"An error occurred while saving data: {e}")
1091
+ print(traceback.format_exc())
1092
+
1093
+ del output, settings_df
1006
1094
  return
1007
1095
 
1008
1096
  def save_plot(fig, src, variable, i):
@@ -1025,7 +1113,6 @@ def save_plot(fig, src, variable, i):
1025
1113
 
1026
1114
  def run_and_save(i, settings, time_ls, total_sims):
1027
1115
 
1028
-
1029
1116
  """
1030
1117
  Run the simulation and save the results.
1031
1118
 
@@ -1040,24 +1127,24 @@ def run_and_save(i, settings, time_ls, total_sims):
1040
1127
  """
1041
1128
  #print(f'Runnings simulation with the following paramiters')
1042
1129
  #print(settings)
1043
-
1130
+ settings['random_seed'] = False
1044
1131
  if settings['random_seed']:
1045
1132
  random.seed(42) # sims will be too similar with random seed
1046
1133
  src = settings['src']
1047
1134
  plot = settings['plot']
1048
1135
  v = settings['variable']
1049
1136
  start_time = time() # Start time of the simulation
1050
- now = datetime.now() # get current date
1051
- date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1137
+ #now = datetime.now() # get current date
1138
+ #date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1139
+ date_string = settings['start_time']
1052
1140
  #try:
1053
1141
  output, dists = run_simulation(settings)
1054
1142
  sim_time = time() - start_time # Elapsed time for the simulation
1055
1143
  settings['sim_time'] = sim_time
1056
1144
  src = os.path.join(f'{src}/{date_string}',settings['name'])
1057
1145
  save_data(src, output, settings, save_all=False, i=i, variable=v)
1058
- if vis_dists:
1059
- vis_dists(dists,src, v, i)
1060
1146
  if plot:
1147
+ vis_dists(dists,src, v, i)
1061
1148
  fig = visualize_all(output)
1062
1149
  save_plot(fig, src, v, i)
1063
1150
  plt.close(fig)
@@ -1066,6 +1153,7 @@ def run_and_save(i, settings, time_ls, total_sims):
1066
1153
  plt.clf()
1067
1154
  del fig
1068
1155
  del output, dists
1156
+ gc.collect()
1069
1157
  #except Exception as e:
1070
1158
  # print(e, end='\r', flush=True)
1071
1159
  # sim_time = time() - start_time
@@ -1073,98 +1161,92 @@ def run_and_save(i, settings, time_ls, total_sims):
1073
1161
  time_ls.append(sim_time)
1074
1162
  return i, sim_time, None
1075
1163
 
1076
- def generate_paramiters(settings):
1164
+ def validate_and_adjust_beta_params(sim_params):
1077
1165
  """
1078
- Generate a list of parameter sets for simulation based on the given settings.
1079
-
1166
+ Validates and adjusts Beta distribution parameters in simulation settings to ensure they are possible.
1167
+
1080
1168
  Args:
1081
- settings (dict): A dictionary containing the simulation settings.
1082
-
1169
+ sim_params (list of dict): List of dictionaries, each containing the simulation parameters.
1170
+
1083
1171
  Returns:
1084
- list: A list of parameter sets for simulation.
1172
+ list of dict: The adjusted list of simulation parameter sets.
1085
1173
  """
1086
- sim_ls = []
1087
- for avg_genes_per_well in settings['avg_genes_per_well']:
1088
- replicates = settings['replicates']
1089
- sett = settings.copy()
1090
- sett['avg_genes_per_well'] = avg_genes_per_well
1091
- sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
1092
- for avg_cells_per_well in settings['avg_cells_per_well']:
1093
- sett['avg_cells_per_well'] = avg_cells_per_well
1094
- sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
1095
- for positive_mean in settings['positive_mean']:
1096
- sett['positive_mean'] = positive_mean
1097
- sett['negative_mean'] = 1-positive_mean
1098
- sett['positive_variance'] = (1-positive_mean)/2
1099
- sett['negative_variance'] = (1-positive_mean)/2
1100
- for avg_reads_per_gene in settings['avg_reads_per_gene']:
1101
- sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
1102
- sett['sd_reads_per_gene'] = int(avg_reads_per_gene/2)
1103
- for sequencing_error in settings['sequencing_error']:
1104
- sett['sequencing_error'] = sequencing_error
1105
- for well_ineq_coeff in settings['well_ineq_coeff']:
1106
- sett['well_ineq_coeff'] = well_ineq_coeff
1107
- for gene_ineq_coeff in settings['gene_ineq_coeff']:
1108
- sett['gene_ineq_coeff'] = gene_ineq_coeff
1109
- for nr_plates in settings['nr_plates']:
1110
- sett['nr_plates'] = nr_plates
1111
- for number_of_genes in settings['number_of_genes']:
1112
- sett['number_of_genes'] = number_of_genes
1113
- for number_of_active_genes in settings['number_of_active_genes']:
1114
- sett['number_of_active_genes'] = number_of_active_genes
1115
- for i in [*range(1,replicates+1)]:
1116
- sim_ls.append(sett)
1117
- #print(sett)
1118
- #print('Number of simulations:',len(sim_ls))
1119
- return sim_ls
1174
+ adjusted_params = []
1175
+ for params in sim_params:
1176
+ max_pos_variance = params['positive_mean'] * (1 - params['positive_mean'])
1177
+ max_neg_variance = params['negative_mean'] * (1 - params['negative_mean'])
1178
+
1179
+ # Adjust positive variance
1180
+ if params['positive_variance'] >= max_pos_variance:
1181
+ print(f'changed positive variance from {params["positive_variance"]} to {max_pos_variance * 0.99}')
1182
+ params['positive_variance'] = max_pos_variance * 0.99 # Adjust to 99% of the maximum allowed variance
1183
+
1184
+ # Adjust negative variance
1185
+ if params['negative_variance'] >= max_neg_variance:
1186
+ print(f'changed negative variance from {params["negative_variance"]} to {max_neg_variance * 0.99}')
1187
+ params['negative_variance'] = max_neg_variance * 0.99 # Adjust to 99% of the maximum allowed variance
1188
+
1189
+ adjusted_params.append(params)
1190
+
1191
+ return adjusted_params
1192
+
1193
+ def generate_paramiters(settings):
1120
1194
 
1121
- #altered for one set of settings see negative_mean and variance
1122
- def generate_paramiters_single(settings):
1123
1195
  """
1124
- Generate a list of parameter sets for single simulations based on the given settings.
1196
+ Generate a list of parameter sets for simulation based on the given settings.
1125
1197
 
1126
1198
  Args:
1127
1199
  settings (dict): A dictionary containing the simulation settings.
1128
1200
 
1129
1201
  Returns:
1130
- list: A list of parameter sets for single simulations.
1202
+ list: A list of parameter sets for simulation.
1131
1203
  """
1204
+
1205
+ settings['positive_mean'] = [0.8]
1206
+
1132
1207
  sim_ls = []
1133
1208
  for avg_genes_per_well in settings['avg_genes_per_well']:
1134
1209
  replicates = settings['replicates']
1135
- sett = settings.copy()
1136
- sett['avg_genes_per_well'] = avg_genes_per_well
1137
- sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
1138
1210
  for avg_cells_per_well in settings['avg_cells_per_well']:
1139
- sett['avg_cells_per_well'] = avg_cells_per_well
1140
- sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
1141
- for positive_mean in settings['positive_mean']:
1142
- sett['positive_mean'] = positive_mean
1143
- sett['negative_mean'] = 0.2
1144
- sett['positive_variance'] = 0.13
1145
- sett['negative_variance'] = 0.13
1146
- for avg_reads_per_gene in settings['avg_reads_per_gene']:
1147
- sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
1148
- sett['sd_reads_per_gene'] = int(avg_reads_per_gene/2)
1149
- for sequencing_error in settings['sequencing_error']:
1150
- sett['sequencing_error'] = sequencing_error
1151
- for well_ineq_coeff in settings['well_ineq_coeff']:
1152
- sett['well_ineq_coeff'] = well_ineq_coeff
1153
- for gene_ineq_coeff in settings['gene_ineq_coeff']:
1154
- sett['gene_ineq_coeff'] = gene_ineq_coeff
1155
- for nr_plates in settings['nr_plates']:
1156
- sett['nr_plates'] = nr_plates
1157
- for number_of_genes in settings['number_of_genes']:
1158
- sett['number_of_genes'] = number_of_genes
1159
- for number_of_active_genes in settings['number_of_active_genes']:
1160
- sett['number_of_active_genes'] = number_of_active_genes
1161
- for i in [*range(1,replicates+1)]:
1162
- sim_ls.append(sett)
1163
- #print(sett)
1164
- #print('Number of simulations:',len(sim_ls))
1211
+ for classifier_accuracy in settings['classifier_accuracy']:
1212
+ for positive_mean in settings['positive_mean']:
1213
+ for avg_reads_per_gene in settings['avg_reads_per_gene']:
1214
+ for sequencing_error in settings['sequencing_error']:
1215
+ for well_ineq_coeff in settings['well_ineq_coeff']:
1216
+ for gene_ineq_coeff in settings['gene_ineq_coeff']:
1217
+ for nr_plates in settings['nr_plates']:
1218
+ for number_of_genes in settings['number_of_genes']:
1219
+ for number_of_active_genes in settings['number_of_active_genes']:
1220
+ for i in range(1, replicates+1):
1221
+ sett = deepcopy(settings)
1222
+ sett['avg_genes_per_well'] = avg_genes_per_well
1223
+ sett['sd_genes_per_well'] = avg_genes_per_well / 2
1224
+ sett['avg_cells_per_well'] = avg_cells_per_well
1225
+ sett['sd_cells_per_well'] = avg_cells_per_well / 2
1226
+ sett['classifier_accuracy'] = classifier_accuracy
1227
+ sett['positive_mean'] = positive_mean
1228
+ sett['negative_mean'] = 1-positive_mean
1229
+ sett['positive_variance'] = (1-positive_mean)/2
1230
+ sett['negative_variance'] = (1-positive_mean)/2
1231
+ sett['avg_reads_per_gene'] = avg_reads_per_gene
1232
+ sett['sd_reads_per_gene'] = avg_reads_per_gene / 2
1233
+ sett['sequencing_error'] = sequencing_error
1234
+ sett['well_ineq_coeff'] = well_ineq_coeff
1235
+ sett['gene_ineq_coeff'] = gene_ineq_coeff
1236
+ sett['nr_plates'] = nr_plates
1237
+ sett['number_of_genes'] = number_of_genes
1238
+ sett['number_of_active_genes'] = number_of_active_genes
1239
+ sim_ls.append(sett)
1240
+
1241
+ random.shuffle(sim_ls)
1242
+ sim_ls = validate_and_adjust_beta_params(sim_ls)
1243
+ print(f'Running {len(sim_ls)} simulations.')
1244
+ #for x in sim_ls:
1245
+ # print(x['positive_mean'])
1165
1246
  return sim_ls
1166
1247
 
1167
1248
  def run_multiple_simulations(settings):
1249
+
1168
1250
  """
1169
1251
  Run multiple simulations in parallel using the provided settings.
1170
1252
 
@@ -1175,8 +1257,12 @@ def run_multiple_simulations(settings):
1175
1257
  None
1176
1258
  """
1177
1259
 
1260
+ now = datetime.now() # get current date
1261
+ start_time = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1262
+ settings['start_time'] = start_time
1263
+
1178
1264
  sim_ls = generate_paramiters(settings)
1179
- print(f'Running {len(sim_ls)} simulations. Standard deviations for each variable are variable / 2')
1265
+ #print(f'Running {len(sim_ls)} simulations.')
1180
1266
 
1181
1267
  max_workers = settings['max_workers'] or cpu_count() - 4
1182
1268
  with Manager() as manager:
@@ -1185,10 +1271,429 @@ def run_multiple_simulations(settings):
1185
1271
  with Pool(max_workers) as pool:
1186
1272
  result = pool.starmap_async(run_and_save, [(index, settings, time_ls, total_sims) for index, settings in enumerate(sim_ls)])
1187
1273
  while not result.ready():
1188
- sleep(0.01)
1189
- sims_processed = len(time_ls)
1190
- average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1191
- time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1192
- print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1193
- result.get()
1274
+ try:
1275
+ sleep(0.01)
1276
+ sims_processed = len(time_ls)
1277
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1278
+ time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1279
+ print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1280
+ gc.collect()
1281
+ except Exception as e:
1282
+ print(e)
1283
+ print(traceback.format_exc())
1284
+ try:
1285
+ result.get()
1286
+ except Exception as e:
1287
+ print(e)
1288
+ print(traceback.format_exc())
1194
1289
 
1290
def generate_integers(start, stop, step):
    """
    Build a list of integers from ``start`` up to and including ``stop``.

    Args:
        start (int): First value of the sequence.
        stop (int): Last value that may appear in the sequence (inclusive for positive steps).
        step (int): Increment between consecutive values.

    Returns:
        list: The integers ``start, start + step, ...`` that do not exceed ``stop``.
    """
    # range() excludes its upper bound, so shift it by one to make `stop` inclusive.
    exclusive_end = stop + 1
    return [*range(start, exclusive_end, step)]
1292
+
1293
def generate_floats(start, stop, step):
    """
    Build an evenly spaced list of numbers from ``start`` to ``stop`` (inclusive).

    Each value is computed as ``start + i * step`` (no running sum, so rounding
    errors do not accumulate) and rounded to the number of decimal places of ``step``.

    Args:
        start (int or float): First value of the sequence.
        stop (int or float): Upper bound; included when it lies on the grid.
        step (int or float): Positive increment between consecutive values.

    Returns:
        list: The generated values; empty when ``start`` is greater than ``stop``.

    Raises:
        ValueError: If ``step`` is not positive while ``start <= stop`` (the sequence
            would never terminate).
    """
    from decimal import Decimal

    if start > stop:
        return []
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    # Decimal parses plain ('0.05'), integer ('2') and scientific ('1e-05') notations,
    # which a search for '.' in str(step) does not handle.
    num_decimals = max(0, -Decimal(str(step)).as_tuple().exponent)

    # Small tolerance so that a binary floating point artefact such as
    # 0.1 + 2 * 0.1 == 0.30000000000000004 does not drop the inclusive endpoint.
    tolerance = step * 1e-9

    floats_list = []
    index = 0
    while True:
        current = start + index * step
        if current > stop + tolerance:
            break
        floats_list.append(round(current, num_decimals))
        index += 1

    return floats_list
1305
+
1306
def remove_columns_with_single_value(df):
    """
    Drop every column whose rows all share one distinct (non-null) value.

    Args:
        df (pandas.DataFrame): The DataFrame to inspect; it is not modified.

    Returns:
        pandas.DataFrame: A new DataFrame without the columns that hold exactly
        one unique value.
    """
    single_valued = []
    for name in df.columns:
        # nunique() ignores missing values, so an all-NaN column (0 unique) is kept.
        if df[name].nunique() == 1:
            single_valued.append(name)
    return df.drop(single_valued, axis=1)
1318
+
1319
def read_simulations_table(db_path):
    """
    Load the whole 'simulations' table of an SQLite database into a DataFrame.

    Args:
        db_path (str): The file path to the SQLite database.

    Returns:
        pandas.DataFrame or None: The table contents, or None when the query fails
        (for example when the table does not exist); the error is printed.
    """
    connection = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query("SELECT * FROM simulations", connection)
    except Exception as error:
        print(f"An error occurred: {error}")
        return None
    finally:
        # Runs on both the success and the failure path, so the handle never leaks.
        connection.close()
1343
+
1344
def plot_simulations(df, variable, x_rotation=None, legend=False, grid=False, clean=True, verbose=False):
    """
    Plot mean 'prauc' (with a +/- one standard deviation band) against ``variable``,
    drawing one subplot for each unique combination of the remaining simulation parameters.

    Args:
        df (pandas.DataFrame): DataFrame containing 'prauc', ``variable`` and the grouping columns.
        variable (str): Name of the column to use as the x-axis.
        x_rotation (int, optional): Degrees to rotate the x-axis labels. Defaults to 45 when None.
        legend (bool, optional): Whether to display a legend.
        grid (bool, optional): Whether to display grid lines.
        clean (bool, optional): If True, grouping columns that are absent from ``df`` or hold a
            single value in every row are ignored instead of being required.
        verbose (bool, optional): Whether to print the filter conditions inside each subplot.

    Returns:
        matplotlib.figure.Figure: The figure holding the grid of subplots.

    Raises:
        ValueError: If a required column is missing or ``df`` has no rows.
    """

    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']
    grouping_vars = [col for col in grouping_vars if col != variable]

    if clean:
        # Only the list of grouping columns is narrowed; `df` itself is left intact so
        # that `variable` and 'prauc' survive even when they happen to be constant.
        varying_columns = remove_columns_with_single_value(df).columns
        grouping_vars = [col for col in grouping_vars if col in varying_columns]

    # Check if the necessary columns are present in the DataFrame
    required_columns = {variable, 'prauc'} | set(grouping_vars)
    missing_cols = required_columns - set(df.columns)
    if missing_cols:
        raise ValueError(f"DataFrame must contain {missing_cols} columns")
    if df.empty:
        raise ValueError("DataFrame contains no rows to plot")

    # One panel per unique combination of grouping values: (condition dict, matching rows).
    if grouping_vars:
        unique_combinations = df[grouping_vars].drop_duplicates()
        panels = [(row.to_dict(), df[df[grouping_vars].eq(row).all(axis=1)])
                  for _, row in unique_combinations.iterrows()]
    else:
        # Nothing left to group by: all rows belong to a single panel.
        panels = [({}, df)]

    # Determine the layout of the subplots (as square as possible)
    num_rows = math.ceil(np.sqrt(len(panels)))
    num_cols = math.ceil(len(panels) / num_rows)

    # squeeze=False always yields a 2-D array of axes, whatever the grid size.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (condition, subset_df) in zip(axes, panels):

        # Group by 'variable' and calculate mean and std dev of 'prauc'
        grouped = subset_df.groupby(variable)['prauc'].agg(['mean', 'std'])
        grouped = grouped.sort_index()  # Sort by the variable for orderly plots

        # Plotting the mean of 'prauc' with std deviation as shaded area
        ax.plot(grouped.index, grouped['mean'], marker='o', linestyle='-', color='b', label='Mean PRAUC')
        ax.fill_between(grouped.index, grouped['mean'] - grouped['std'], grouped['mean'] + grouped['std'], color='gray', alpha=0.5, label='Std Dev')

        # Setting plot labels
        ax.set_xlabel(variable)
        ax.set_ylabel('Precision-Recall AUC (PRAUC)')
        ax.grid(grid)

        if legend:
            ax.legend()

        # Set x-ticks and rotate them as specified
        ax.set_xticks(grouped.index)
        ax.set_xticklabels(grouped.index, rotation=x_rotation if x_rotation is not None else 45)

        if verbose:
            verbose_text = '\n'.join([f"{var}: {val}" for var, val in condition.items()])
            ax.text(0.95, 0.05, verbose_text, transform=ax.transAxes, fontsize=9, verticalalignment='bottom', horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    # Hide any unused axes if there are any
    for ax in axes[len(panels):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
1432
+
1433
def plot_correlation_matrix(df, annot=False, cmap=None, clean=True):
    """
    Plot the lower triangle of the correlation matrix of the simulation parameters
    and result metrics, and save it through ``save_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        annot (bool, optional): Whether to write the correlation value in each cell.
        cmap (str or matplotlib colormap, optional): Colormap for the heatmap. When None
            (default) a blue-red diverging palette is used.
        clean (bool, optional): If True, constant columns are dropped and only the known
            columns still present in ``df`` are used. If False, every known column must
            exist in ``df``.

    Returns:
        matplotlib.figure.Figure: The figure containing the heatmap.
    """
    if cmap is None:
        # Diverging palette: negative and positive correlations get distinct hues.
        cmap = sns.diverging_palette(240, 10, as_cmap=True)

    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    grouping_vars = grouping_vars + ['optimal_threshold', 'accuracy', 'prauc', 'roc_auc', 'genes_per_well_gini', 'wells_per_gene_gini']

    if clean:
        # A constant column has no defined correlation, so it is removed up front.
        df = remove_constant_columns(df)
        grouping_vars = [feature for feature in grouping_vars if feature in df.columns]

    # Subsetting the DataFrame to include only the relevant variables
    relevant_data = df[grouping_vars]

    # Calculating the correlation matrix; the mask hides the redundant upper triangle
    corr_matrix = relevant_data.corr()
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    # Plotting the correlation matrix
    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(corr_matrix, mask=mask, annot=annot, cmap=cmap, fmt=".2f", linewidths=.5, robust=True)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='correlation_matrix', i=1)
    return fig
1476
+
1477
def plot_feature_importance(df, target='prauc', exclude=None, clean=True):
    """
    Trains a RandomForestRegressor to determine the importance of each feature in predicting
    the target, plots the importances as a horizontal bar chart and saves it through ``save_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable column name.
        exclude (list or str, optional): Column names to exclude from features.
        clean (bool, optional): If True, features that are absent from ``df`` or hold a single
            value in every row are skipped.

    Returns:
        matplotlib.figure.Figure: The figure object containing the feature importance plot.
    """

    # Define the features for the model
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        # Only the feature list is filtered; `df` keeps the target column even if it is constant.
        varying_columns = remove_columns_with_single_value(df).columns
        features = [feature for feature in features if feature in varying_columns]

    # Remove excluded features if specified
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    # Train the model
    model = RandomForestRegressor(n_estimators=1000, random_state=42)
    model.fit(df[features], df[target])

    # Feature indices ordered from most to least important
    importances = model.feature_importances_
    indices = np.argsort(importances)[::-1]
    positions = range(len(indices))

    # Plot horizontal bar chart; bars and tick labels must use the same ordering
    fig = plt.figure(figsize=(12, 6))
    plt.barh(positions, importances[indices], color="teal", align="center", alpha=0.6)
    plt.yticks(positions, [features[i] for i in indices])
    plt.gca().invert_yaxis()  # Invert the axis to have the highest importance at the top
    plt.xlabel('Feature Importance')
    plt.title('Feature Importances')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='feature_importance', i=1)
    return fig
1524
+
1525
def calculate_permutation_importance(df, target='prauc', exclude=None, n_repeats=10, clean=True):
    """
    Fits a RandomForestRegressor on the simulation parameters, computes the permutation
    importance of each feature, plots it and saves the plot through ``save_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        exclude (list or str, optional): Column name(s) to exclude from the features.
        n_repeats (int, optional): Number of times each feature is permuted.
        clean (bool, optional): If True, features that are absent from ``df`` or hold a single
            value in every row are skipped.

    Returns:
        matplotlib.figure.Figure: The figure containing the permutation importance bar chart.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        # Only the feature list is filtered; `df` keeps the target column even if it is constant.
        varying_columns = remove_columns_with_single_value(df).columns
        features = [feature for feature in features if feature in varying_columns]

    # Accept a single name or a collection of names; names that are not features are ignored.
    if exclude is None:
        excluded = set()
    elif isinstance(exclude, (list, tuple, set)):
        excluded = set(exclude)
    else:
        excluded = {exclude}
    features = [feature for feature in features if feature not in excluded]

    X = df[features]
    y = df[target]

    # Initialize a model (you could pass it as an argument if you'd like to use a different one)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    perm_importance = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)

    # Ascending order, so the most important feature ends up at the top of the chart
    sorted_idx = perm_importance.importances_mean.argsort()

    # Create a figure and a set of subplots
    fig, ax = plt.subplots()
    ax.barh(range(len(sorted_idx)), perm_importance.importances_mean[sorted_idx], color="teal", align="center", alpha=0.6)
    ax.set_yticks(range(len(sorted_idx)))
    # The importances are indexed like the columns of X, i.e. like `features`.
    ax.set_yticklabels([features[i] for i in sorted_idx])
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='permutation_importance', i=1)
    return fig
1574
+
1575
def plot_partial_dependences(df, target='prauc', clean=True):
    """
    Fits a GradientBoostingRegressor on the simulation parameters and draws one partial
    dependence plot per feature on a four-column grid; the figure is saved through ``save_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable.
        clean (bool, optional): If True, features that are absent from ``df`` or hold a single
            value in every row are skipped.

    Returns:
        matplotlib.figure.Figure: The figure containing the partial dependence plots.

    Raises:
        ValueError: If no feature is left to plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        # Only the feature list is filtered; `df` keeps the target column even if it is constant.
        varying_columns = remove_columns_with_single_value(df).columns
        features = [feature for feature in features if feature in varying_columns]

    if not features:
        raise ValueError("No features available for partial dependence plots")

    X = df[features]
    y = df[target]

    # Train a model
    model = GradientBoostingRegressor()
    model.fit(X, y)

    # Determine the number of rows and columns for subplots
    n_cols = 4  # Number of columns in subplot grid
    n_rows = (len(features) + n_cols - 1) // n_cols  # Calculate rows needed

    # squeeze=False always returns a 2-D array of axes, so flattening works for a
    # single row as well as for several rows.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    fig.suptitle('Partial Dependence Plots', fontsize=20, y=1.03)
    axs = axs.flatten()

    for ax, feature in zip(axs, features):
        PartialDependenceDisplay.from_estimator(model, X, features=[feature], ax=ax)
        ax.set_title(feature)  # Set title to the name of the feature

    # Hide unused axes if any
    for ax in axs[len(features):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='partial_dependences', i=1)
    return fig
1627
+
1628
def save_shap_plot(fig, src, variable, i):
    """
    Write ``fig`` as a PDF to ``{src}/{variable}/{i}_figure.pdf``, creating the folder if needed.

    Args:
        fig (matplotlib.figure.Figure): The figure to save.
        src (str): Root output directory.
        variable (str): Name of the sub-directory the file is placed in.
        i (int or str): Prefix of the file name.

    Returns:
        None
    """
    import os
    output_dir = f'{src}/{variable}'
    os.makedirs(output_dir, exist_ok=True)
    filename_fig = f'{output_dir}/{str(i)}_figure.pdf'
    fig.savefig(filename_fig, dpi=600, format='pdf', bbox_inches='tight')
    print(f"Saved figure as {filename_fig}")
1634
+
1635
def generate_shap_summary_plot(df,target='prauc', clean=True):
    """
    Fits a RandomForestRegressor on the simulation parameters, draws the SHAP summary plot
    of its predictions and saves it through ``save_shap_plot``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        clean (bool, optional): If True, features that are absent from ``df`` or hold a single
            value in every row are skipped.

    Returns:
        matplotlib.figure.Figure: The figure containing the SHAP summary plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    if clean:
        # Only the feature list is filtered; `df` keeps the target column even if it is constant.
        varying_columns = remove_columns_with_single_value(df).columns
        features = [feature for feature in features if feature in varying_columns]

    X = df[features]
    y = df[target]

    # Initialize a model (you could pass it as an argument if you'd like to use a different one)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    # Calculate SHAP values
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # show=False keeps the figure open: with the default show=True the plot is displayed
    # inside summary_plot and a later plt.gcf() may hand back a new, empty figure.
    shap.summary_plot(shap_values, X, show=False)
    fig = plt.gcf()
    save_shap_plot(fig, src='figures', variable='shap', i=1)
    plt.show()
    return fig
1672
+
1673
+
1674
def remove_constant_columns(df):
    """
    Keep only the columns that hold more than one distinct (non-null) value.

    Parameters:
        df (pd.DataFrame): The input DataFrame from which to remove constant columns.

    Returns:
        pd.DataFrame: A DataFrame restricted to the columns with at least two unique values.
    """
    distinct_counts = df.nunique()
    is_varying = distinct_counts > 1
    return df.loc[:, is_varying]
1685
+
1686
+
1687
+ # to justify using beta for sim classifier
1688
+
1689
+ # Fit a Beta distribution to these outputs
1690
+ #a, b, loc, scale = beta.fit(predicted_probs, floc=0, fscale=1) # Fix location and scale to match the support of the sigmoid
1691
+
1692
+ # Sample from this fitted Beta distribution
1693
+ #simulated_probs = beta.rvs(a, b, size=1000)
1694
+
1695
+ # Plot the empirical vs simulated distribution
1696
+ #plt.hist(predicted_probs, bins=30, alpha=0.5, label='Empirical')
1697
+ #plt.hist(simulated_probs, bins=30, alpha=0.5, label='Simulated from Beta')
1698
+ #plt.legend()
1699
+ #plt.show()