spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sim.py CHANGED
@@ -1,16 +1,19 @@
1
1
 
2
- import os, gc, random, warnings, traceback, itertools, matplotlib, sqlite3
3
- import time as tm
2
+ import os, random, warnings, traceback, sqlite3, shap, math, gc
4
3
  from time import time, sleep
5
4
  from datetime import datetime
6
5
  import numpy as np
7
6
  import pandas as pd
8
7
  import matplotlib.pyplot as plt
8
+ from matplotlib.patches import Patch
9
9
  import seaborn as sns
10
10
  import sklearn.metrics as metrics
11
- from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, precision_recall_curve
11
+ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
12
+ from sklearn.inspection import PartialDependenceDisplay, permutation_importance
13
+ from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
12
14
  import statsmodels.api as sm
13
- from multiprocessing import cpu_count, Value, Array, Lock, Pool, Manager
15
+ from multiprocessing import cpu_count, Pool, Manager
16
+ from copy import deepcopy
14
17
 
15
18
  from .logger import log_function_call
16
19
 
@@ -35,7 +38,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
35
38
 
36
39
  # plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
37
40
  def generate_plate_map(nr_plates):
38
- print('nr_plates',nr_plates)
41
+ #print('nr_plates',nr_plates)
39
42
  """
40
43
  Generate a plate map based on the number of plates.
41
44
 
@@ -241,6 +244,7 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
241
244
  - df_ls (list): A list containing gene counts per well, well counts per gene, Gini coefficients for wells,
242
245
  Gini coefficients for genes, gene weights array, and well weights.
243
246
  """
247
+
244
248
  #generate primary distributions and genes
245
249
  cpw, _ = dist_gen(avg_cells_per_well, sd_cells_per_well, plate_map)
246
250
  gpw, _ = dist_gen(avg_genes_per_well, sd_genes_per_well, plate_map)
@@ -251,14 +255,18 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
251
255
  gene_weights_array = np.array(list(gene_weights.values())) # Convert the values to an array
252
256
 
253
257
  well_weights = generate_power_law_distribution(len(plate_map), well_ineq_coeff)
254
-
255
- gene_to_well_mapping = {}
256
-
257
- for gene in genes:
258
- gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=int(gpw[gene-1]), p=well_weights) # Generate a number of wells for each gene according to well_weights
259
-
260
- gene_to_well_mapping = {gene: wells for gene, wells in gene_to_well_mapping.items() if len(wells) >= 2}
261
258
 
259
+ gene_to_well_mapping = {}
260
+ for gene in range(1, number_of_genes + 1): # ensures gene-1 is within bounds
261
+ if gene-1 < len(gpw):
262
+ max_index = len(plate_map['plate_row_column']) # this should be the number of choices available from plate_map
263
+ num_samples = int(gpw[gene-1])
264
+ if num_samples >= max_index:
265
+ num_samples = max_index - 1 # adjust to maximum possible index
266
+ gene_to_well_mapping[gene] = np.random.choice(plate_map['plate_row_column'], size=num_samples, replace=False, p=well_weights)
267
+ else:
268
+ break # break the loop if gene-1 is out of bounds for gpw
269
+
262
270
  cells = []
263
271
  for i in [*range(0,len(plate_map))]:
264
272
  ciw = random.choice(cpw)
@@ -316,9 +324,80 @@ def run_experiment(plate_map, number_of_genes, active_gene_list, avg_genes_per_w
316
324
  df_ls = [gene_counts_per_well, well_counts_per_gene, gini_well, gini_gene, gene_weights_array, well_weights]
317
325
  return cell_df, genes_per_well_df, wells_per_gene_df, df_ls
318
326
 
319
- # classifier is a function that takes a cell state (active=1/inactive=0) and produces a score in [0, 1]
320
- # For the input cell, it checks if it is active or inactive, and then samples from an appropriate beta distribution to give a score
321
- def classifier(positive_mean, positive_variance, negative_mean, negative_variance, df):
327
+ def classifier(positive_mean, positive_variance, negative_mean, negative_variance, classifier_accuracy, df):
328
+ """
329
+ Classifies the data in the DataFrame based on the given parameters and a classifier error rate.
330
+
331
+ Args:
332
+ positive_mean (float): The mean of the positive distribution.
333
+ positive_variance (float): The variance of the positive distribution.
334
+ negative_mean (float): The mean of the negative distribution.
335
+ negative_variance (float): The variance of the negative distribution.
336
+ classifier_accuracy (float): The probability (0 to 1) that a row's score is drawn from the distribution matching its true 'is_active' label; otherwise it is drawn from the opposite distribution.
337
+ df (pandas.DataFrame): The DataFrame containing the data to be classified.
338
+
339
+ Returns:
340
+ pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
341
+ """
342
+ def calc_alpha_beta(mean, variance):
343
+ if mean <= 0 or mean >= 1:
344
+ raise ValueError("Mean must be between 0 and 1 exclusively.")
345
+ max_variance = mean * (1 - mean)
346
+ if variance <= 0 or variance >= max_variance:
347
+ raise ValueError(f"Variance must be positive and less than {max_variance}.")
348
+
349
+ alpha = mean * (mean * (1 - mean) / variance - 1)
350
+ beta = alpha * (1 - mean) / mean
351
+ return alpha, beta
352
+
353
+ # Apply the beta distribution based on 'is_active' status with consideration for classifier error
354
+ def get_score(is_active):
355
+ if np.random.rand() < classifier_accuracy: # With classifier_accuracy probability, choose the correct distribution
356
+ return np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2)
357
+ else: # With 1-classifier_accuracy probability, choose the incorrect distribution
358
+ return np.random.beta(a2, b2) if is_active else np.random.beta(a1, b1)
359
+
360
+ # Calculate alpha and beta for both distributions
361
+ a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
362
+ a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
363
+ df['score'] = df['is_active'].apply(get_score)
364
+
365
+ return df
366
+
367
+ def classifier_v2(positive_mean, positive_variance, negative_mean, negative_variance, df):
368
+ """
369
+ Classifies the data in the DataFrame based on the given parameters.
370
+
371
+ Args:
372
+ positive_mean (float): The mean of the positive distribution.
373
+ positive_variance (float): The variance of the positive distribution.
374
+ negative_mean (float): The mean of the negative distribution.
375
+ negative_variance (float): The variance of the negative distribution.
376
+ df (pandas.DataFrame): The DataFrame containing the data to be classified.
377
+
378
+ Returns:
379
+ pandas.DataFrame: The DataFrame with an additional 'score' column containing the classification scores.
380
+ """
381
+ def calc_alpha_beta(mean, variance):
382
+ if mean <= 0 or mean >= 1:
383
+ raise ValueError("Mean must be between 0 and 1 exclusively.")
384
+ max_variance = mean * (1 - mean)
385
+ if variance <= 0 or variance >= max_variance:
386
+ raise ValueError(f"Variance must be positive and less than {max_variance}.")
387
+
388
+ alpha = mean * (mean * (1 - mean) / variance - 1)
389
+ beta = alpha * (1 - mean) / mean
390
+ return alpha, beta
391
+
392
+ # Calculate alpha and beta for both distributions
393
+ a1, b1 = calc_alpha_beta(positive_mean, positive_variance)
394
+ a2, b2 = calc_alpha_beta(negative_mean, negative_variance)
395
+
396
+ # Apply the beta distribution based on 'is_active' status
397
+ df['score'] = df['is_active'].apply(lambda is_active: np.random.beta(a1, b1) if is_active else np.random.beta(a2, b2))
398
+ return df
399
+
400
+ def classifier_v1(positive_mean, positive_variance, negative_mean, negative_variance, df):
322
401
  """
323
402
  Classifies the data in the DataFrame based on the given parameters.
324
403
 
@@ -446,6 +525,7 @@ def generate_well_score(cell_scores):
446
525
  return well_score
447
526
 
448
527
  def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_per_gene, sequencing_error=0.01):
528
+
449
529
  """
450
530
  Simulates the sequencing of plates and calculates gene fractions and metadata.
451
531
 
@@ -460,6 +540,7 @@ def sequence_plates(well_score, number_of_genes, avg_reads_per_gene, sd_reads_pe
460
540
  gene_fraction_map (pd.DataFrame): DataFrame containing gene fractions for each well.
461
541
  metadata (pd.DataFrame): DataFrame containing metadata for each well.
462
542
  """
543
+
463
544
  reads, _ = dist_gen(avg_reads_per_gene, sd_reads_per_gene, well_score)
464
545
  gene_names = [f'gene_{v}' for v in range(number_of_genes+1)]
465
546
  all_wells = well_score.index
@@ -612,8 +693,12 @@ def plot_histogram(data, x_label, ax, color, title, binwidth=0.01, log=False):
612
693
  Returns:
613
694
  None
614
695
  """
615
- sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
616
- legend=False, fill=True, element='step', palette='dark')
696
+ if not binwidth:
697
+ sns.histplot(data=data, x=x_label, ax=ax, color=color, kde=False, stat='density',
698
+ legend=False, fill=True, element='step', palette='dark')
699
+ else:
700
+ sns.histplot(data=data, x=x_label, ax=ax, color=color, binwidth=binwidth, kde=False, stat='density',
701
+ legend=False, fill=True, element='step', palette='dark')
617
702
  if log:
618
703
  ax.set_yscale('log')
619
704
  ax.set_title(title)
@@ -700,7 +785,7 @@ def run_simulation(settings):
700
785
  plate_map = plate_map[~plate_map['column_id'].isin(['c1', 'c2', 'c3', 'c23', 'c24'])] # Extract rows where 'column_id' is not in [1,2,3,23,24]
701
786
 
702
787
  cell_level, genes_per_well_df, wells_per_gene_df, dists = run_experiment(plate_map, settings['number_of_genes'], active_gene_list, settings['avg_genes_per_well'], settings['sd_genes_per_well'], settings['avg_cells_per_well'], settings['sd_cells_per_well'], settings['well_ineq_coeff'], settings['gene_ineq_coeff'])
703
- cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], df=cell_level)
788
+ cell_scores = classifier(settings['positive_mean'], settings['positive_variance'], settings['negative_mean'], settings['negative_variance'], settings['classifier_accuracy'], df=cell_level)
704
789
  cell_roc_dict_df, cell_pr_dict_df, cell_scores, cell_cm = cell_level_roc_auc(cell_scores)
705
790
  well_score = generate_well_score(cell_scores)
706
791
  gene_fraction_map, metadata = sequence_plates(well_score, settings['number_of_genes'], settings['avg_reads_per_gene'], settings['sd_reads_per_gene'], sequencing_error=settings['sequencing_error'])
@@ -720,7 +805,10 @@ def run_simulation(settings):
720
805
  results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats = regression_roc_auc(results_df, active_gene_list, control_gene_list, alpha = 0.05, optimal=False)
721
806
  #except Exception as e:
722
807
  # print(f"An error occurred while saving data: {e}")
723
- return [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df], dists
808
+ output = [cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df]
809
+ del cell_scores, cell_roc_dict_df, cell_pr_dict_df, cell_cm, well_score, gene_fraction_map, metadata, results_df, reg_roc_dict_df, reg_pr_dict_df, reg_cm, sim_stats, genes_per_well_df, wells_per_gene_df
810
+ gc.collect()
811
+ return output, dists
724
812
 
725
813
  def vis_dists(dists, src, v, i):
726
814
  """
@@ -744,9 +832,14 @@ def vis_dists(dists, src, v, i):
744
832
  for index, dist in enumerate(dists):
745
833
  temp = pd.DataFrame(dist, columns = [f'{names[index]}'])
746
834
  sns.histplot(data=temp, x=f'{names[index]}', kde=False, binwidth=None, stat='count', element="step", ax=ax[n], color='teal', log_scale=False)
747
- #plot_histogram(temp, f'{names[index]}', ax[n], 'slategray', f'{names[index]}', binwidth=None, log=False)
748
835
  n+=1
749
836
  save_plot(fig2, src, 'dists', i)
837
+ plt.close(fig2)
838
+ plt.figure().clear()
839
+ plt.cla()
840
+ plt.clf()
841
+ del dists
842
+
750
843
  return
751
844
 
752
845
  def visualize_all(output):
@@ -773,6 +866,7 @@ def visualize_all(output):
773
866
  Returns:
774
867
  fig (matplotlib.figure.Figure): The generated figure object.
775
868
  """
869
+
776
870
  cell_scores = output[0]
777
871
  cell_roc_dict_df = output[1]
778
872
  cell_pr_dict_df = output[2]
@@ -810,24 +904,35 @@ def visualize_all(output):
810
904
  #plot cell classification score by inactive and active
811
905
  active_distribution = cell_scores[cell_scores['is_active'] == 1]
812
906
  inactive_distribution = cell_scores[cell_scores['is_active'] == 0]
813
- plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', binwidth=0.01, log=False)
814
- plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', binwidth=0.01, log=False)
907
+ plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Cell scores', log=False)#, binwidth=0.01, log=False)
908
+ plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Cell scores', log=False)#, binwidth=0.01, log=False)
909
+
910
+ legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
911
+ Patch(facecolor='teal', edgecolor='teal', label='Active')]
912
+
913
+ ax[n].legend(handles=legend_elements, loc='upper right')
914
+
915
+
815
916
  ax[n].set_xlim([0, 1])
816
917
  n+=1
817
918
 
818
919
  #plot classifier cell predictions by inactive and active well average
819
- ##inactive_distribution_well['score'] = pd.to_numeric(inactive_distribution['score'], errors='coerce')
820
- ##inactive_distribution_well = inactive_distribution_well.groupby('plate_row_column')['score'].mean()
821
-
822
- ##active_distribution_well['score'] = pd.to_numeric(active_distribution['score'], errors='coerce')
823
- ##active_distribution_well = active_distribution_well.groupby('plate_row_column')['score'].mean()
920
+ inactive_distribution_well = inactive_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
921
+ active_distribution_well = active_distribution.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
922
+ mixed_distribution_well = cell_scores.groupby(['plate_id', 'row_id', 'column_id'])['score'].mean().reset_index(name='score')
923
+
924
+ plot_histogram(inactive_distribution_well, "score", ax[n], 'slategray', 'Well scores', log=False)#, binwidth=0.01, log=False)
925
+ plot_histogram(active_distribution_well, "score", ax[n], 'teal', 'Well scores', log=False)#, binwidth=0.01, log=False)
926
+ plot_histogram(mixed_distribution_well, "score", ax[n], 'red', 'Well scores', log=False)#, binwidth=0.01, log=False)
824
927
 
825
- #inactive_distribution_well = inactive_distribution.groupby(['plate_row_column']).mean()
826
- #active_distribution_well = active_distribution.groupby(['plate_row_column']).mean()
928
+ legend_elements = [Patch(facecolor='slategray', edgecolor='slategray', label='Inactive'),
929
+ Patch(facecolor='teal', edgecolor='teal', label='Active'),
930
+ Patch(facecolor='red', edgecolor='red', label='Mixed')]
827
931
 
828
- plot_histogram(active_distribution, "score", ax[n], 'slategray', 'Well scores', binwidth=0.01, log=False)
829
- plot_histogram(inactive_distribution, "score", ax[n], 'teal', 'Well scores', binwidth=0.01, log=False)
932
+ ax[n].legend(handles=legend_elements, loc='upper right')
933
+
830
934
  ax[n].set_xlim([0, 1])
935
+ #ax[n].legend()
831
936
  n+=1
832
937
 
833
938
  #plot ROC (cell classification)
@@ -846,8 +951,8 @@ def visualize_all(output):
846
951
  n+=1
847
952
 
848
953
  #plot well score
849
- plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=False)
850
- ax[n].set_xlim([0, 1])
954
+ plot_histogram(well_score, "score", ax[n], 'teal', 'Well score', binwidth=0.005, log=True)
955
+ #ax[n].set_xlim([0, 1])
851
956
  n+=1
852
957
 
853
958
  control_df = results_df[results_df['color'] == 'control']
@@ -911,6 +1016,7 @@ def visualize_all(output):
911
1016
 
912
1017
  plt.tight_layout()
913
1018
  plt.show()
1019
+ gc.collect()
914
1020
  return fig
915
1021
 
916
1022
  def create_database(db_path):
@@ -992,6 +1098,7 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
992
1098
  df_concat[f'variable_{variable}_sim_nr'] = i
993
1099
 
994
1100
  append_database(src, df_concat, 'simulations')
1101
+ del gini_genes_per_well, gini_wells_per_gene, df_concat
995
1102
 
996
1103
  if save_all:
997
1104
  for i, df in enumerate(output):
@@ -1001,8 +1108,12 @@ def save_data(src, output, settings, save_all=False, i=0, variable='all'):
1001
1108
  if not isinstance(df, pd.DataFrame):
1002
1109
  df = pd.DataFrame(df)
1003
1110
  append_database(src, df, table_names[i])
1111
+ del df
1004
1112
  except Exception as e:
1005
1113
  print(f"An error occurred while saving data: {e}")
1114
+ print(traceback.format_exc())
1115
+
1116
+ del output, settings_df
1006
1117
  return
1007
1118
 
1008
1119
  def save_plot(fig, src, variable, i):
@@ -1025,7 +1136,6 @@ def save_plot(fig, src, variable, i):
1025
1136
 
1026
1137
  def run_and_save(i, settings, time_ls, total_sims):
1027
1138
 
1028
-
1029
1139
  """
1030
1140
  Run the simulation and save the results.
1031
1141
 
@@ -1040,24 +1150,24 @@ def run_and_save(i, settings, time_ls, total_sims):
1040
1150
  """
1041
1151
  #print(f'Runnings simulation with the following paramiters')
1042
1152
  #print(settings)
1043
-
1153
+ settings['random_seed'] = False
1044
1154
  if settings['random_seed']:
1045
1155
  random.seed(42) # sims will be too similar with random seed
1046
1156
  src = settings['src']
1047
1157
  plot = settings['plot']
1048
1158
  v = settings['variable']
1049
1159
  start_time = time() # Start time of the simulation
1050
- now = datetime.now() # get current date
1051
- date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1160
+ #now = datetime.now() # get current date
1161
+ #date_string = now.strftime("%y%m%d") # format as a string in 'ddmmyy' format
1162
+ date_string = settings['start_time']
1052
1163
  #try:
1053
1164
  output, dists = run_simulation(settings)
1054
1165
  sim_time = time() - start_time # Elapsed time for the simulation
1055
1166
  settings['sim_time'] = sim_time
1056
1167
  src = os.path.join(f'{src}/{date_string}',settings['name'])
1057
1168
  save_data(src, output, settings, save_all=False, i=i, variable=v)
1058
- if vis_dists:
1059
- vis_dists(dists,src, v, i)
1060
1169
  if plot:
1170
+ vis_dists(dists,src, v, i)
1061
1171
  fig = visualize_all(output)
1062
1172
  save_plot(fig, src, v, i)
1063
1173
  plt.close(fig)
@@ -1066,6 +1176,7 @@ def run_and_save(i, settings, time_ls, total_sims):
1066
1176
  plt.clf()
1067
1177
  del fig
1068
1178
  del output, dists
1179
+ gc.collect()
1069
1180
  #except Exception as e:
1070
1181
  # print(e, end='\r', flush=True)
1071
1182
  # sim_time = time() - start_time
@@ -1073,98 +1184,92 @@ def run_and_save(i, settings, time_ls, total_sims):
1073
1184
  time_ls.append(sim_time)
1074
1185
  return i, sim_time, None
1075
1186
 
1076
- def generate_paramiters(settings):
1187
+ def validate_and_adjust_beta_params(sim_params):
1077
1188
  """
1078
- Generate a list of parameter sets for simulation based on the given settings.
1079
-
1189
+ Validates and adjusts Beta distribution parameters in simulation settings to ensure they are possible.
1190
+
1080
1191
  Args:
1081
- settings (dict): A dictionary containing the simulation settings.
1082
-
1192
+ sim_params (list of dict): List of dictionaries, each containing the simulation parameters.
1193
+
1083
1194
  Returns:
1084
- list: A list of parameter sets for simulation.
1195
+ list of dict: The adjusted list of simulation parameter sets.
1085
1196
  """
1086
- sim_ls = []
1087
- for avg_genes_per_well in settings['avg_genes_per_well']:
1088
- replicates = settings['replicates']
1089
- sett = settings.copy()
1090
- sett['avg_genes_per_well'] = avg_genes_per_well
1091
- sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
1092
- for avg_cells_per_well in settings['avg_cells_per_well']:
1093
- sett['avg_cells_per_well'] = avg_cells_per_well
1094
- sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
1095
- for positive_mean in settings['positive_mean']:
1096
- sett['positive_mean'] = positive_mean
1097
- sett['negative_mean'] = 1-positive_mean
1098
- sett['positive_variance'] = (1-positive_mean)/2
1099
- sett['negative_variance'] = (1-positive_mean)/2
1100
- for avg_reads_per_gene in settings['avg_reads_per_gene']:
1101
- sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
1102
- sett['sd_reads_per_gene'] = int(avg_reads_per_gene/2)
1103
- for sequencing_error in settings['sequencing_error']:
1104
- sett['sequencing_error'] = sequencing_error
1105
- for well_ineq_coeff in settings['well_ineq_coeff']:
1106
- sett['well_ineq_coeff'] = well_ineq_coeff
1107
- for gene_ineq_coeff in settings['gene_ineq_coeff']:
1108
- sett['gene_ineq_coeff'] = gene_ineq_coeff
1109
- for nr_plates in settings['nr_plates']:
1110
- sett['nr_plates'] = nr_plates
1111
- for number_of_genes in settings['number_of_genes']:
1112
- sett['number_of_genes'] = number_of_genes
1113
- for number_of_active_genes in settings['number_of_active_genes']:
1114
- sett['number_of_active_genes'] = number_of_active_genes
1115
- for i in [*range(1,replicates+1)]:
1116
- sim_ls.append(sett)
1117
- #print(sett)
1118
- #print('Number of simulations:',len(sim_ls))
1119
- return sim_ls
1197
+ adjusted_params = []
1198
+ for params in sim_params:
1199
+ max_pos_variance = params['positive_mean'] * (1 - params['positive_mean'])
1200
+ max_neg_variance = params['negative_mean'] * (1 - params['negative_mean'])
1201
+
1202
+ # Adjust positive variance
1203
+ if params['positive_variance'] >= max_pos_variance:
1204
+ print(f'changed positive variance from {params["positive_variance"]} to {max_pos_variance * 0.99}')
1205
+ params['positive_variance'] = max_pos_variance * 0.99 # Adjust to 99% of the maximum allowed variance
1206
+
1207
+ # Adjust negative variance
1208
+ if params['negative_variance'] >= max_neg_variance:
1209
+ print(f'changed negative variance from {params["negative_variance"]} to {max_neg_variance * 0.99}')
1210
+ params['negative_variance'] = max_neg_variance * 0.99 # Adjust to 99% of the maximum allowed variance
1211
+
1212
+ adjusted_params.append(params)
1213
+
1214
+ return adjusted_params
1215
+
1216
+ def generate_paramiters(settings):
1120
1217
 
1121
- #altered for one set of settings see negative_mean and variance
1122
- def generate_paramiters_single(settings):
1123
1218
  """
1124
- Generate a list of parameter sets for single simulations based on the given settings.
1219
+ Generate a list of parameter sets for simulation based on the given settings.
1125
1220
 
1126
1221
  Args:
1127
1222
  settings (dict): A dictionary containing the simulation settings.
1128
1223
 
1129
1224
  Returns:
1130
- list: A list of parameter sets for single simulations.
1225
+ list: A list of parameter sets for simulation.
1131
1226
  """
1227
+
1228
+ settings['positive_mean'] = [0.8]
1229
+
1132
1230
  sim_ls = []
1133
1231
  for avg_genes_per_well in settings['avg_genes_per_well']:
1134
1232
  replicates = settings['replicates']
1135
- sett = settings.copy()
1136
- sett['avg_genes_per_well'] = avg_genes_per_well
1137
- sett['sd_genes_per_well'] = int(avg_genes_per_well / 2)
1138
1233
  for avg_cells_per_well in settings['avg_cells_per_well']:
1139
- sett['avg_cells_per_well'] = avg_cells_per_well
1140
- sett['sd_cells_per_well'] = int(avg_cells_per_well / 2)
1141
- for positive_mean in settings['positive_mean']:
1142
- sett['positive_mean'] = positive_mean
1143
- sett['negative_mean'] = 0.2
1144
- sett['positive_variance'] = 0.13
1145
- sett['negative_variance'] = 0.13
1146
- for avg_reads_per_gene in settings['avg_reads_per_gene']:
1147
- sett['avg_reads_per_gene'] = int(avg_reads_per_gene)
1148
- sett['sd_reads_per_gene'] = int(avg_reads_per_gene/2)
1149
- for sequencing_error in settings['sequencing_error']:
1150
- sett['sequencing_error'] = sequencing_error
1151
- for well_ineq_coeff in settings['well_ineq_coeff']:
1152
- sett['well_ineq_coeff'] = well_ineq_coeff
1153
- for gene_ineq_coeff in settings['gene_ineq_coeff']:
1154
- sett['gene_ineq_coeff'] = gene_ineq_coeff
1155
- for nr_plates in settings['nr_plates']:
1156
- sett['nr_plates'] = nr_plates
1157
- for number_of_genes in settings['number_of_genes']:
1158
- sett['number_of_genes'] = number_of_genes
1159
- for number_of_active_genes in settings['number_of_active_genes']:
1160
- sett['number_of_active_genes'] = number_of_active_genes
1161
- for i in [*range(1,replicates+1)]:
1162
- sim_ls.append(sett)
1163
- #print(sett)
1164
- #print('Number of simulations:',len(sim_ls))
1234
+ for classifier_accuracy in settings['classifier_accuracy']:
1235
+ for positive_mean in settings['positive_mean']:
1236
+ for avg_reads_per_gene in settings['avg_reads_per_gene']:
1237
+ for sequencing_error in settings['sequencing_error']:
1238
+ for well_ineq_coeff in settings['well_ineq_coeff']:
1239
+ for gene_ineq_coeff in settings['gene_ineq_coeff']:
1240
+ for nr_plates in settings['nr_plates']:
1241
+ for number_of_genes in settings['number_of_genes']:
1242
+ for number_of_active_genes in settings['number_of_active_genes']:
1243
+ for i in range(1, replicates+1):
1244
+ sett = deepcopy(settings)
1245
+ sett['avg_genes_per_well'] = avg_genes_per_well
1246
+ sett['sd_genes_per_well'] = avg_genes_per_well / 2
1247
+ sett['avg_cells_per_well'] = avg_cells_per_well
1248
+ sett['sd_cells_per_well'] = avg_cells_per_well / 2
1249
+ sett['classifier_accuracy'] = classifier_accuracy
1250
+ sett['positive_mean'] = positive_mean
1251
+ sett['negative_mean'] = 1-positive_mean
1252
+ sett['positive_variance'] = (1-positive_mean)/2
1253
+ sett['negative_variance'] = (1-positive_mean)/2
1254
+ sett['avg_reads_per_gene'] = avg_reads_per_gene
1255
+ sett['sd_reads_per_gene'] = avg_reads_per_gene / 2
1256
+ sett['sequencing_error'] = sequencing_error
1257
+ sett['well_ineq_coeff'] = well_ineq_coeff
1258
+ sett['gene_ineq_coeff'] = gene_ineq_coeff
1259
+ sett['nr_plates'] = nr_plates
1260
+ sett['number_of_genes'] = number_of_genes
1261
+ sett['number_of_active_genes'] = number_of_active_genes
1262
+ sim_ls.append(sett)
1263
+
1264
+ random.shuffle(sim_ls)
1265
+ sim_ls = validate_and_adjust_beta_params(sim_ls)
1266
+ print(f'Running {len(sim_ls)} simulations.')
1267
+ #for x in sim_ls:
1268
+ # print(x['positive_mean'])
1165
1269
  return sim_ls
1166
1270
 
1167
1271
  def run_multiple_simulations(settings):
1272
+
1168
1273
  """
1169
1274
  Run multiple simulations in parallel using the provided settings.
1170
1275
 
@@ -1175,8 +1280,12 @@ def run_multiple_simulations(settings):
1175
1280
  None
1176
1281
  """
1177
1282
 
1283
+ now = datetime.now() # get current date
1284
+ start_time = now.strftime("%y%m%d") # format as a string in 'yymmdd' format
1285
+ settings['start_time'] = start_time
1286
+
1178
1287
  sim_ls = generate_paramiters(settings)
1179
- print(f'Running {len(sim_ls)} simulations. Standard deviations for each variable are variable / 2')
1288
+ #print(f'Running {len(sim_ls)} simulations.')
1180
1289
 
1181
1290
  max_workers = settings['max_workers'] or cpu_count() - 4
1182
1291
  with Manager() as manager:
@@ -1185,10 +1294,448 @@ def run_multiple_simulations(settings):
1185
1294
  with Pool(max_workers) as pool:
1186
1295
  result = pool.starmap_async(run_and_save, [(index, settings, time_ls, total_sims) for index, settings in enumerate(sim_ls)])
1187
1296
  while not result.ready():
1188
- sleep(0.01)
1189
- sims_processed = len(time_ls)
1190
- average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1191
- time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1192
- print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1193
- result.get()
1297
+ try:
1298
+ sleep(0.01)
1299
+ sims_processed = len(time_ls)
1300
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1301
+ time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1302
+ print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1303
+ gc.collect()
1304
+ except Exception as e:
1305
+ print(e)
1306
+ print(traceback.format_exc())
1307
+ try:
1308
+ result.get()
1309
+ except Exception as e:
1310
+ print(e)
1311
+ print(traceback.format_exc())
1194
1312
 
1313
def generate_integers(start, stop, step):
    """
    Build the list of integers from 'start' up to and including 'stop', spaced by 'step'.

    Args:
        start (int): First value of the sequence.
        stop (int): Last value that may appear in the sequence.
        step (int): Spacing between consecutive values.

    Returns:
        list: The generated integers.
    """
    # range() excludes its end point, so extend it by one to make 'stop' inclusive
    inclusive_end = stop + 1
    return [*range(start, inclusive_end, step)]
1315
+
1316
def generate_floats(start, stop, step):
    """
    Generate evenly spaced values from 'start' to 'stop' (inclusive) in increments of 'step'.

    Each value is computed as start + i * step instead of by repeated addition, so
    floating-point drift cannot accumulate and silently drop the end point
    (for example 0.1 + 0.1 + 0.1 > 0.3). Values are rounded to the number of
    decimal places used by 'start' or 'step', whichever is larger.

    Args:
        start (int or float): First value of the sequence.
        stop (int or float): Upper bound of the sequence (included when reachable).
        step (int or float): Positive increment between consecutive values.

    Returns:
        list: The generated values; empty if 'start' is greater than 'stop'.

    Raises:
        ValueError: If 'step' is not positive (the sequence would never terminate).
    """
    from decimal import Decimal

    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    def _decimal_places(value):
        # Decimal(str(x)) also copes with ints ('1') and scientific notation ('1e-05'),
        # which a plain search for '.' in the string representation does not.
        return max(0, -Decimal(str(value)).as_tuple().exponent)

    num_decimals = max(_decimal_places(start), _decimal_places(step))

    # Small tolerance so a quotient such as 1.9999999999999998 still counts as 2 steps
    num_steps = math.floor((stop - start) / step + 1e-9)

    return [round(start + i * step, num_decimals) for i in range(num_steps + 1)]
1328
+
1329
def remove_columns_with_single_value_v1(df):
    """
    Return a copy of the DataFrame without the columns that hold one repeated value.

    A column is considered single-valued when `Series.unique()` yields exactly one
    entry, so missing values count as a value of their own here.

    Args:
        df (pandas.DataFrame): The original DataFrame (left unmodified).

    Returns:
        pandas.DataFrame: A copy of 'df' lacking the single-valued columns.
    """
    single_valued = [name for name in df.columns if len(df[name].unique()) == 1]

    # Work on a copy so the caller's DataFrame keeps all of its columns
    trimmed = df.copy()
    trimmed.drop(columns=single_valued, inplace=True)
    return trimmed
1347
+
1348
def remove_columns_with_single_value(df):
    """
    Return the DataFrame without the columns that contain exactly one distinct value.

    Distinct values are counted with `Series.nunique()`, which ignores missing values.

    Args:
        df (pandas.DataFrame): The original DataFrame (left unmodified).

    Returns:
        pandas.DataFrame: A new DataFrame lacking the single-valued columns.
    """
    constant_columns = []
    for name in df.columns:
        if df[name].nunique() == 1:
            constant_columns.append(name)

    return df.drop(columns=constant_columns)
1360
+
1361
def read_simulations_table(db_path):
    """
    Load the 'simulations' table of an SQLite database as a pandas DataFrame.

    Args:
        db_path (str): The file path to the SQLite database.

    Returns:
        pandas.DataFrame: The contents of the 'simulations' table, or None if the
        query failed (the error is printed).
    """
    connection = sqlite3.connect(db_path)

    try:
        return pd.read_sql_query("SELECT * FROM simulations", connection)
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
    finally:
        # Runs on both the success and the failure path
        connection.close()
1385
+
1386
def plot_simulations(df, variable, x_rotation=None, legend=False, grid=False, clean=True, verbose=False):
    """
    Creates separate line plots of mean 'prauc' (with a +/- one standard deviation band)
    against 'variable', one subplot per unique combination of the remaining grouping variables.

    Args:
        df (pandas.DataFrame): DataFrame containing 'prauc', 'variable' and the grouping columns.
        variable (str): Name of the column to use as the x-axis.
        x_rotation (int, optional): Degrees to rotate the x-axis labels. Defaults to 45 when None.
        legend (bool, optional): Whether to display a legend.
        grid (bool, optional): Whether to display grid lines.
        clean (bool, optional): If True, grouping columns that are absent from 'df' or hold a
            single value are ignored, since they cannot distinguish conditions.
        verbose (bool, optional): Whether to print the filter conditions inside each subplot.

    Returns:
        matplotlib.figure.Figure: The figure holding the grid of subplots.

    Raises:
        ValueError: If required columns are missing or 'df' has no rows.
    """

    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    # The x-axis variable is never used to split the data into subplots
    grouping_vars = [col for col in grouping_vars if col != variable]

    if clean:
        # Only the grouping list is pruned; 'df' itself is untouched so that
        # 'variable' and 'prauc' stay available even if they happen to be constant.
        grouping_vars = [col for col in grouping_vars if col in df.columns and df[col].nunique() > 1]

    # Check if the necessary columns are present in the DataFrame
    required_columns = {variable, 'prauc'} | set(grouping_vars)
    if not required_columns.issubset(df.columns):
        missing_cols = required_columns - set(df.columns)
        raise ValueError(f"DataFrame must contain {missing_cols} columns")

    if df.empty:
        raise ValueError("DataFrame contains no rows to plot")

    # Build one (condition, subset) pair per subplot
    if grouping_vars:
        unique_combinations = df[grouping_vars].drop_duplicates()
        panels = []
        for _, row in unique_combinations.iterrows():
            condition = {var: row[var] for var in grouping_vars}
            panels.append((condition, df[df[grouping_vars].eq(row).all(axis=1)]))
    else:
        # Nothing to split on: plot the whole DataFrame in a single panel
        panels = [({}, df)]

    # Determine the layout of the subplots (as square as possible)
    num_combinations = len(panels)
    num_rows = math.ceil(np.sqrt(num_combinations))
    num_cols = math.ceil(num_combinations / num_rows)

    # squeeze=False always yields a 2D array of axes, whatever the grid size
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    rotation = x_rotation if x_rotation is not None else 45

    for ax, (condition, subset_df) in zip(axes, panels):

        # Group by 'variable' and calculate mean and std dev of 'prauc'
        grouped = subset_df.groupby(variable)['prauc'].agg(['mean', 'std'])
        grouped = grouped.sort_index()  # Sort by the variable for orderly plots

        # Plotting the mean of 'prauc' with std deviation as shaded area
        ax.plot(grouped.index, grouped['mean'], marker='o', linestyle='-', color='b', label='Mean PRAUC')
        ax.fill_between(grouped.index, grouped['mean'] - grouped['std'], grouped['mean'] + grouped['std'], color='gray', alpha=0.5, label='Std Dev')

        # Setting plot labels
        ax.set_xlabel(variable)
        ax.set_ylabel('Precision-Recall AUC (PRAUC)')
        ax.grid(grid)

        if legend:
            ax.legend()

        # Set x-ticks and rotate them as specified
        ax.set_xticks(grouped.index)
        ax.set_xticklabels(grouped.index, rotation=rotation)

        if verbose and condition:
            verbose_text = '\n'.join([f"{var}: {val}" for var, val in condition.items()])
            ax.text(0.95, 0.05, verbose_text, transform=ax.transAxes, fontsize=9, verticalalignment='bottom', horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    # Hide any unused axes if there are any
    for ax in axes[num_combinations:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    return fig
1474
+
1475
def plot_correlation_matrix(df, annot=False, cmap=None, clean=True):
    """
    Plots the lower triangle of the correlation matrix between the simulation
    parameters and the simulation outcome metrics, and saves it via save_plot.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        annot (bool, optional): Whether to write the correlation value in each cell.
        cmap (optional): Colormap passed to seaborn.heatmap. When None, a blue-to-red
            diverging palette is used.
        clean (bool, optional): If True, constant columns are removed and columns
            absent from 'df' are ignored before computing correlations.

    Returns:
        matplotlib.figure.Figure: The figure containing the heatmap.
    """
    if cmap is None:
        # A diverging palette suits correlations, which are centred on zero
        cmap = sns.diverging_palette(240, 10, as_cmap=True)

    grouping_vars = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                     'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                     'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    grouping_vars = grouping_vars + ['optimal_threshold', 'accuracy', 'prauc', 'roc_auc', 'genes_per_well_gini', 'wells_per_gene_gini']

    if clean:
        # Constant columns have undefined correlation and would only add empty rows
        df = remove_constant_columns(df)
        grouping_vars = [feature for feature in grouping_vars if feature in df.columns]

    # Subsetting the DataFrame to include only the relevant variables
    relevant_data = df[grouping_vars]

    # Calculating the correlation matrix
    corr_matrix = relevant_data.corr()

    # Mask the upper triangle (and diagonal), which mirrors the lower triangle
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    # Plotting the correlation matrix
    fig = plt.figure(figsize=(12, 8))
    sns.heatmap(corr_matrix, mask=mask, annot=annot, cmap=cmap, fmt=".2f", linewidths=.5, robust=True)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='correlation_matrix', i=1)
    return fig
1518
+
1519
def plot_feature_importance(df, target='prauc', exclude=None, clean=True):
    """
    Trains a RandomForestRegressor to determine the importance of each feature in
    predicting the target, plots the importances as a horizontal bar chart
    (most important at the top) and saves the chart via save_plot.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable column name.
        exclude (list or str, optional): Column names to exclude from features.
        clean (bool, optional): If True, constant columns are removed and features
            absent from 'df' are ignored.

    Returns:
        matplotlib.figure.Figure: The figure object containing the feature importance plot.
    """

    # Define the features for the model
    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    # Read the target before cleaning so a constant target column is not lost
    y = df[target]

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Remove excluded features if specified
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    # Train the model
    model = RandomForestRegressor(n_estimators=1000, random_state=42)
    model.fit(df[features], y)

    # Get feature importances, most important first
    importances = model.feature_importances_
    indices = np.argsort(importances)[::-1]
    positions = range(len(indices))

    # Plot horizontal bar chart; bars and labels must share the same ordering
    fig = plt.figure(figsize=(12, 6))
    plt.barh(positions, importances[indices], color="teal", align="center", alpha=0.6)
    plt.yticks(positions, [features[i] for i in indices])
    plt.gca().invert_yaxis()  # Invert the axis to have the highest importance at the top
    plt.xlabel('Feature Importance')
    plt.title('Feature Importances')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='feature_importance', i=1)
    return fig
1566
+
1567
def calculate_permutation_importance(df, target='prauc', exclude=None, n_repeats=10, clean=True):
    """
    Fits a RandomForestRegressor on the simulation parameters, computes the
    permutation importance of each feature, plots the mean importances as a
    horizontal bar chart and saves the chart via save_plot.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        exclude (list or str, optional): Column names to exclude from features.
        n_repeats (int, optional): Number of times each feature is permuted.
        clean (bool, optional): If True, constant columns are removed and features
            absent from 'df' are ignored.

    Returns:
        matplotlib.figure.Figure: The figure containing the permutation importance plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    # Read the target before cleaning so a constant target column is not lost
    y = df[target]

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    # Filtering (rather than list.remove) tolerates names already dropped by cleaning
    if isinstance(exclude, list):
        features = [feature for feature in features if feature not in exclude]
    elif exclude is not None:
        features = [feature for feature in features if feature != exclude]

    X = df[features]

    # Initialize a model (you could pass it as an argument if you'd like to use a different one)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    perm_importance = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)

    # Ascending order, so the most important feature ends up at the top of the chart
    sorted_idx = perm_importance.importances_mean.argsort()
    positions = range(len(sorted_idx))

    # Create a figure and a set of subplots
    fig, ax = plt.subplots()
    ax.barh(positions, perm_importance.importances_mean[sorted_idx], color="teal", align="center", alpha=0.6)
    ax.set_yticks(positions)
    # Importances are indexed like the columns of X, i.e. like 'features'
    ax.set_yticklabels([features[i] for i in sorted_idx])
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='permutation_importance', i=1)
    return fig
1616
+
1617
def plot_partial_dependences(df, target='prauc', clean=True):
    """
    Fits a GradientBoostingRegressor on the simulation parameters and draws one
    partial dependence plot per feature on a grid of subplots, then saves the
    figure via save_plot.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The target variable.
        clean (bool, optional): If True, constant columns are removed and features
            absent from 'df' are ignored.

    Returns:
        matplotlib.figure.Figure: The figure containing the partial dependence plots.

    Raises:
        ValueError: If no feature is left to plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    # Read the target before cleaning so a constant target column is not lost
    y = df[target]

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    if not features:
        raise ValueError("No features available for partial dependence plots")

    X = df[features]

    # Train a model
    model = GradientBoostingRegressor()
    model.fit(X, y)

    # Determine the number of rows and columns for subplots
    n_cols = 4  # Number of columns in subplot grid
    n_rows = (len(features) + n_cols - 1) // n_cols  # Calculate rows needed

    # squeeze=False always yields a 2D array of axes, so a single-row grid
    # flattens to individual axes just like a multi-row grid does
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    fig.suptitle('Partial Dependence Plots', fontsize=20, y=1.03)
    axs = axs.flatten()

    for ax, feature in zip(axs, features):
        PartialDependenceDisplay.from_estimator(model, X, features=[feature], ax=ax)
        ax.set_title(feature)  # Set title to the name of the feature

    # Hide unused axes if any
    for ax in axs[len(features):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()
    save_plot(fig, src='figures', variable='partial_dependences', i=1)
    return fig
1669
+
1670
def save_shap_plot(fig, src, variable, i):
    """
    Write a figure as a PDF to '<src>/<variable>/<i>_figure.pdf', creating the
    directory if needed, and print the resulting file name.

    Args:
        fig: Figure object providing a savefig method.
        src (str): Base output directory.
        variable (str): Name of the sub-directory below 'src'.
        i: Identifier used as the file name prefix.

    Returns:
        None
    """
    import os

    target_dir = f'{src}/{variable}'
    os.makedirs(target_dir, exist_ok=True)

    filename_fig = f'{src}/{variable}/{str(i)}_figure.pdf'
    save_options = dict(dpi=600, format='pdf', bbox_inches='tight')
    fig.savefig(filename_fig, **save_options)
    print(f"Saved figure as {filename_fig}")
1676
+
1677
def generate_shap_summary_plot(df,target='prauc', clean=True):
    """
    Fits a RandomForestRegressor on the simulation parameters, draws the SHAP
    summary plot of its predictions and saves it via save_shap_plot.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        target (str): The name of the target variable column.
        clean (bool, optional): If True, constant columns are removed and features
            absent from 'df' are ignored.

    Returns:
        matplotlib.figure.Figure: The figure containing the SHAP summary plot.
    """

    features = ['number_of_active_genes', 'number_of_control_genes', 'avg_reads_per_gene',
                'classifier_accuracy', 'nr_plates', 'number_of_genes', 'avg_genes_per_well',
                'avg_cells_per_well', 'sequencing_error', 'well_ineq_coeff', 'gene_ineq_coeff']

    # Read the target before cleaning so a constant target column is not lost
    y = df[target]

    if clean:
        df = remove_columns_with_single_value(df)
        features = [feature for feature in features if feature in df.columns]

    X = df[features]

    # Initialize a model (you could pass it as an argument if you'd like to use a different one)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X, y)

    # Calculate SHAP values
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # Summary plot. show=False keeps the figure current: with the default show=True,
    # interactive/inline backends display and discard it, so a later plt.gcf()
    # would return a new blank figure and the saved PDF would be empty.
    shap.summary_plot(shap_values, X, show=False)
    fig = plt.gcf()
    save_shap_plot(fig, src='figures', variable='shap', i=1)

    # Display only after the figure has been written to disk
    plt.show()
    return fig
1714
+
1715
+
1716
def remove_constant_columns(df):
    """
    Keep only the columns of the DataFrame that hold more than one distinct value.

    Distinct values are counted with `DataFrame.nunique()`, which ignores missing values.

    Parameters:
        df (pd.DataFrame): The input DataFrame from which to remove constant columns.

    Returns:
        pd.DataFrame: A DataFrame with the constant columns removed.
    """
    distinct_counts = df.nunique()
    is_varying = distinct_counts > 1
    return df.loc[:, is_varying]
1727
+
1728
+
1729
+ # to justify using beta for sim classifier
1730
+
1731
+ # Fit a Beta distribution to these outputs
1732
+ #a, b, loc, scale = beta.fit(predicted_probs, floc=0, fscale=1) # Fix location and scale to match the support of the sigmoid
1733
+
1734
+ # Sample from this fitted Beta distribution
1735
+ #simulated_probs = beta.rvs(a, b, size=1000)
1736
+
1737
+ # Plot the empirical vs simulated distribution
1738
+ #plt.hist(predicted_probs, bins=30, alpha=0.5, label='Empirical')
1739
+ #plt.hist(simulated_probs, bins=30, alpha=0.5, label='Simulated from Beta')
1740
+ #plt.legend()
1741
+ #plt.show()